diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 000000000..22cad94f5 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,30 @@ +root = true + +[*] +charset = utf-8 +end_of_line = lf +trim_trailing_whitespace = true +insert_final_newline = true +max_line_length = 80 + +[*.md] +indent_style = space +indent_size = 4 +trim_trailing_whitespace = false +max_line_length = off + +[*.rs] +indent_style = space +indent_size = 4 + +[*.toml] +indent_style = space +indent_size = 4 + +[*.{yaml,yml}] +indent_style = space +indent_size = 2 + +[Makefile] +indent_style = tab +indent_size = 4 diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml deleted file mode 100644 index 0e43501f7..000000000 --- a/.github/FUNDING.yml +++ /dev/null @@ -1,12 +0,0 @@ -# These are supported funding model platforms - -github: webrtc-rs # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] -patreon: WebRTCrs # Replace with a single Patreon username -open_collective: webrtc-rs # Replace with a single Open Collective username -ko_fi: # Replace with a single Ko-fi username -tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel -community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry -liberapay: # Replace with a single Liberapay username -issuehunt: # Replace with a single IssueHunt username -otechie: # Replace with a single Otechie username -custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] diff --git a/.github/actions-rs/grcov.yml b/.github/actions-rs/grcov.yml deleted file mode 100644 index 94b11b067..000000000 --- a/.github/actions-rs/grcov.yml +++ /dev/null @@ -1,15 +0,0 @@ -branch: true -ignore-not-existing: true -llvm: true -filter: covered -output-type: lcov -output-path: ./lcov.info -source-dir: . 
-ignore: - - "/*" - - "C:/*" - - "../*" -excl-line: "#\\[derive\\(" -excl-start: "mod tests \\{" -excl-br-line: "#\\[derive\\(" -excl-br-start: "mod tests \\{" \ No newline at end of file diff --git a/.github/workflows/cargo.yml b/.github/workflows/cargo.yml deleted file mode 100644 index 9ece39a19..000000000 --- a/.github/workflows/cargo.yml +++ /dev/null @@ -1,120 +0,0 @@ -name: cargo - -on: - push: - branches: [master] - pull_request: - branches: [master] - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -env: - CARGO_TERM_COLOR: always - -jobs: - test: - name: Test - strategy: - matrix: - os: ["ubuntu-latest", "macos-latest"] - toolchain: - # - 1.65.0 # min supported version (https://github.com/webrtc-rs/webrtc/#toolchain) - - stable - runs-on: ${{ matrix.os }} - steps: - - uses: actions/checkout@v3 - - name: Install Rust ${{ matrix.toolchain }} - uses: actions-rs/toolchain@v1 - with: - toolchain: ${{ matrix.toolchain }} - override: true - - name: Install Rust - run: rustup update stable - - name: ๐Ÿ“ฆ Cache cargo registry - uses: actions/cache@v3 - with: - path: ~/.cargo/registry - key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }} - restore-keys: | - ${{ runner.os }}-cargo-registry- - - uses: actions-rs/toolchain@v1 - with: - toolchain: ${{ matrix.toolchain }} - profile: minimal - override: true - - name: ๐Ÿญ Cache dependencies - uses: Swatinem/rust-cache@v2 - - name: Test - run: cargo test - - name: Test with all features enabled - run: cargo test --all-features - - test_windows: - name: Test (windows) - strategy: - matrix: - toolchain: - # - 1.63.0 # min supported version (https://github.com/webrtc-rs/webrtc/#toolchain) - - stable - runs-on: windows-latest - steps: - - uses: actions/checkout@v3 - - name: Install Rust ${{ matrix.toolchain }} - uses: actions-rs/toolchain@v1 - with: - toolchain: ${{ matrix.toolchain }} - override: true - - name: Install Rust - run: rustup update stable - - name: ๐Ÿ“ฆ Cache cargo registry - uses: actions/cache@v3 - with: - path: ~/.cargo/registry - key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }} - restore-keys: | - ${{ runner.os }}-cargo-registry- - - uses: actions-rs/toolchain@v1 - with: - toolchain: ${{ matrix.toolchain }} - profile: minimal - override: true - - name: Copy to C drive - run: cp D:\a C:\ -Recurse - # - name: ๐Ÿญ Cache dependencies - # uses: Swatinem/rust-cache@v2 - - name: Test - working-directory: "C:\\a\\webrtc\\webrtc" - run: cargo test --features metrics - - quality: - name: Check formatting style and run clippy - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - uses: actions-rs/toolchain@v1 - with: - toolchain: stable - profile: minimal - components: clippy, rustfmt - override: true - - name: ๐Ÿ“ฆ Cache cargo registry - uses: actions/cache@v3 - with: - path: ~/.cargo/registry - key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }} - restore-keys: | - ${{ runner.os }}-cargo-registry- - - name: ๐Ÿ“Ž Run clippy - uses: actions-rs/cargo@v1 - with: - command: clippy - args: --workspace --all-targets --all-features --all -- -D warnings - - name: ๐Ÿ’ฌ Check formatting - uses: actions-rs/cargo@v1 - with: - command: fmt - args: --all -- --check - # - name: Check for typos - # uses: crate-ci/typos@master diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 000000000..44efe0993 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,183 @@ +name: CI + +on: + push: + branches: 
["master"] + tags: ["v*"] + pull_request: + branches: ["master"] + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +env: + RUST_BACKTRACE: 1 + +jobs: + + ################ + # Pull Request # + ################ + + pr: + if: ${{ github.event_name == 'pull_request' }} + needs: + - clippy + #- msrv + - rustdoc + - rustfmt + - test + runs-on: ubuntu-latest + steps: + - run: true + + + + + ########################## + # Linting and formatting # + ########################## + + clippy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@v1 + with: + toolchain: stable + components: clippy + + - run: make cargo.lint + + rustfmt: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@v1 + with: + toolchain: nightly + components: rustfmt + + - run: make cargo.fmt check=yes + + + + + ########### + # Testing # + ########### + + msrv: + name: MSRV + if: ${{ false }} # TODO: re-enable once fully refactored + strategy: + fail-fast: false + matrix: + msrv: ["1.70.0"] + os: ["ubuntu", "macOS", "windows"] + runs-on: ${{ matrix.os }}-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@v1 + with: + toolchain: nightly + - uses: dtolnay/rust-toolchain@v1 + with: + toolchain: ${{ matrix.msrv }} + + - run: cargo +nightly update -Z minimal-versions + + - run: make test.cargo + + test: + strategy: + fail-fast: false + matrix: + toolchain: ["stable", "beta", "nightly"] + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@v1 + with: + toolchain: ${{ matrix.toolchain }} + components: rust-src + + - run: cargo install cargo-careful + if: ${{ matrix.toolchain == 'nightly' }} + + - run: make test.cargo + careful=${{ (matrix.toolchain == 'nightly' && 'yes') + || 'no' }} + + + + + ################# + # Documentation # + ################# + + rustdoc: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@v1 + with: + toolchain: stable + + - run: make cargo.doc private=yes open=no + env: + RUSTFLAGS: -D warnings + + + + + ############# + # Releasing # + ############# + + publish: + name: publish (crates.io) + if: ${{ startsWith(github.ref, 'refs/tags/v') }} + needs: ["release-github"] + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@v1 + with: + toolchain: stable + + - run: cargo publish -p medea-turn + env: + CARGO_REGISTRY_TOKEN: ${{ secrets.CRATESIO_TOKEN }} + + release-github: + name: release (GitHub) + if: ${{ startsWith(github.ref, 'refs/tags/v') }} + needs: ["clippy", "msrv", "rustdoc", "rustfmt", "test"] + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Parse release version + id: release + run: echo "version=${GITHUB_REF#refs/tags/v}" + >> $GITHUB_OUTPUT + - name: Verify release version matches Cargo manifest + run: | + test "${{ steps.release.outputs.version }}" \ + == "$(grep -m1 'version = "' Cargo.toml | cut -d'"' -f2)" + + - name: Parse CHANGELOG link + id: changelog + run: echo "link=${{ github.server_url }}/${{ github.repository }}/blob/v${{ steps.release.outputs.version }}/CHANGELOG.md#$(sed -n '/^## \[${{ steps.release.outputs.version }}\]/{s/^## \[\(.*\)\][^0-9]*\([0-9].*\)/\1--\2/;s/[^0-9a-z-]*//g;p;}' CHANGELOG.md)" + >> $GITHUB_OUTPUT + + - name: Create GitHub release + uses: softprops/action-gh-release@v1 + with: + name: ${{ steps.release.outputs.version }} + body: | + [API 
docs](https://docs.rs/medea-turn/${{ steps.release.outputs.version }}) + [Changelog](${{ steps.changelog.outputs.link }}) + prerelease: ${{ contains(steps.release.outputs.version, '-') }} diff --git a/.github/workflows/grcov.yml b/.github/workflows/grcov.yml deleted file mode 100644 index 1d4f25884..000000000 --- a/.github/workflows/grcov.yml +++ /dev/null @@ -1,89 +0,0 @@ -name: coverage - -on: - push: - branches: [ master ] - pull_request: - branches: [ master ] - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -env: - CARGO_TERM_COLOR: always - -jobs: - grcov: - name: Coverage - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: - - ubuntu-latest - toolchain: - - nightly - cargo_flags: - - "--all-features" - steps: - - name: Checkout source code - uses: actions/checkout@v2 - - - name: Install Rust - uses: actions-rs/toolchain@v1 - with: - profile: minimal - toolchain: ${{ matrix.toolchain }} - override: true - - - name: Install grcov - uses: actions-rs/install@v0.1 - with: - crate: grcov - version: latest - use-tool-cache: true - - - name: Test - uses: actions-rs/cargo@v1 - with: - command: test - args: --all --no-fail-fast ${{ matrix.cargo_flags }} - env: - CARGO_INCREMENTAL: "0" - RUSTFLAGS: '-Zprofile -Ccodegen-units=1 -Copt-level=0 -Clink-dead-code -Coverflow-checks=off -Zpanic_abort_tests -Cpanic=abort -Cdebug-assertions=off' - RUSTDOCFLAGS: '-Zprofile -Ccodegen-units=1 -Copt-level=0 -Clink-dead-code -Coverflow-checks=off -Zpanic_abort_tests -Cpanic=abort -Cdebug-assertions=off' - - - name: Generate coverage data - id: grcov - # uses: actions-rs/grcov@v0.1 - run: | - grcov target/debug/ \ - --branch \ - --llvm \ - --source-dir . \ - --output-path lcov.info \ - --ignore='/**' \ - --ignore='C:/**' \ - --ignore='../**' \ - --ignore-not-existing \ - --excl-line "#\\[derive\\(" \ - --excl-br-line "#\\[derive\\(" \ - --excl-start "#\\[cfg\\(test\\)\\]" \ - --excl-br-start "#\\[cfg\\(test\\)\\]" \ - --commit-sha ${{ github.sha }} \ - --service-job-id ${{ github.job }} \ - --service-name "GitHub Actions" \ - --service-number ${{ github.run_id }} - - name: Upload coverage as artifact - uses: actions/upload-artifact@v2 - with: - name: lcov.info - # path: ${{ steps.grcov.outputs.report }} - path: lcov.info - - - name: Upload coverage to codecov.io - uses: codecov/codecov-action@v1 - with: - # file: ${{ steps.grcov.outputs.report }} - file: lcov.info - fail_ci_if_error: true diff --git a/.gitignore b/.gitignore index ef58f6ced..1e25d8ede 100644 --- a/.gitignore +++ b/.gitignore @@ -1,12 +1,7 @@ -# Generated by Cargo -# will have compiled files and executables -/target/ /.idea/ -/crates/target/ -/crates/.idea/ - -# These are backup files generated by rustfmt -**/*.rs.bk +/.vscode/ +/*.iml +.DS_Store -# Editor configs -.vscode +/Cargo.lock +/target/ diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index e69de29bb..000000000 diff --git a/.rustfmt.toml b/.rustfmt.toml new file mode 100644 index 000000000..947f9c57b --- /dev/null +++ b/.rustfmt.toml @@ -0,0 +1,18 @@ +# Project configuration for rustfmt Rust code formatter. 
+# See full list of configurations at: +# https://github.com/rust-lang/rustfmt/blob/master/Configurations.md + +max_width = 80 +use_small_heuristics = "Max" + +format_strings = false +imports_granularity = "Crate" + +format_code_in_doc_comments = true +format_macro_matchers = true +use_try_shorthand = true + +error_on_line_overflow = true +error_on_unformatted = true + +unstable_features = true diff --git a/turn/CHANGELOG.md b/CHANGELOG.md similarity index 57% rename from turn/CHANGELOG.md rename to CHANGELOG.md index df8893093..f71fea8c2 100644 --- a/turn/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,56 @@ +`medea-turn` changelog +====================== + +All user-visible changes to this project will be documented in this file. This project uses [Semantic Versioning 2.0.0]. + + + + +## [0.7.0] · 2024-??-?? (unreleased) [0.7.0]: /../../tree/v0.7.0 + +### Initially re-implemented + +- Performed major refactoring, removing non-server code. ([#1]) +- Added TCP transport. ([#1]) + +### [Upstream changes](https://github.com/webrtc-rs/webrtc/blob/89285ceba23dc57fc99386cb978d2d23fe909437/turn/CHANGELOG.md#unreleased) + +- Fixed server relay not releasing its UDP port. ([webrtc-rs/webrtc#330] by [@clia]) +- Added `alloc_close_notify` config parameter to `ServerConfig` and `Allocation` to receive a notification with metrics data when an allocation is closed. ([webrtc-rs/webrtc#421] by [@clia]) + +[@clia]: https://github.com/clia +[#1]: /../../pull/1 +[webrtc-rs/webrtc#330]: https://github.com/webrtc-rs/webrtc/pull/330 +[webrtc-rs/webrtc#421]: https://github.com/webrtc-rs/webrtc/pull/421 + + + + +## Previous releases + +See [old upstream CHANGELOG](https://github.com/webrtc-rs/webrtc/blob/turn-v0.6.1/turn/CHANGELOG.md). + + + + +[Semantic Versioning 2.0.0]: https://semver.org + + + + + + + # webrtc-turn changelog ## Unreleased * [#330 Fix the problem that the UDP port of the server relay is not released](https://github.com/webrtc-rs/webrtc/pull/330) by [@clia](https://github.com/clia). * Added `alloc_close_notify` config parameter to `ServerConfig` and `Allocation`, to receive notify on allocation close event, with metrics data. +* Major refactor, added TCP transport. ([#1]) + +[#1]: https://github.com/instrumentisto/medea-turn-rs/pull/1 ## v0.6.1 diff --git a/Cargo.lock b/Cargo.lock deleted index 5fe1f90e5..000000000 --- a/Cargo.lock +++ /dev/null @@ -1,3207 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. 
-version = 3 - -[[package]] -name = "addr2line" -version = "0.21.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" -dependencies = [ - "gimli", -] - -[[package]] -name = "adler" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" - -[[package]] -name = "aead" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" -dependencies = [ - "crypto-common", - "generic-array", -] - -[[package]] -name = "aes" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" -dependencies = [ - "cfg-if", - "cipher", - "cpufeatures", -] - -[[package]] -name = "aes-gcm" -version = "0.10.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1" -dependencies = [ - "aead", - "aes", - "cipher", - "ctr", - "ghash", - "subtle", -] - -[[package]] -name = "aho-corasick" -version = "1.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" -dependencies = [ - "memchr", -] - -[[package]] -name = "android-tzdata" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" - -[[package]] -name = "android_system_properties" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" -dependencies = [ - "libc", -] - -[[package]] -name = "anes" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" - -[[package]] -name = "anstyle" -version = "1.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8901269c6307e8d93993578286ac0edf7f195079ffff5ebdeea6a59ffb7e36bc" - -[[package]] -name = "anyhow" -version = "1.0.82" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f538837af36e6f6a9be0faa67f9a314f8119e4e4b5867c6ab40ed60360142519" - -[[package]] -name = "arc-swap" -version = "1.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" - -[[package]] -name = "asn1-rs" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22ad1373757efa0f70ec53939aabc7152e1591cb485208052993070ac8d2429d" -dependencies = [ - "asn1-rs-derive", - "asn1-rs-impl", - "displaydoc", - "nom", - "num-traits", - "rusticata-macros", - "thiserror", - "time", -] - -[[package]] -name = "asn1-rs-derive" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7378575ff571966e99a744addeff0bff98b8ada0dedf1956d59e634db95eaac1" -dependencies = [ - "proc-macro2", - "quote", - "syn", - "synstructure", -] - -[[package]] -name = "asn1-rs-impl" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b18050c2cd6fe86c3a76584ef5e0baf286d038cda203eb6223df2cc413565f7" -dependencies = [ - 
"proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "async-channel" -version = "2.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "136d4d23bcc79e27423727b36823d86233aad06dfea531837b038394d11e9928" -dependencies = [ - "concurrent-queue", - "event-listener 5.3.0", - "event-listener-strategy 0.5.1", - "futures-core", - "pin-project-lite", -] - -[[package]] -name = "async-executor" -version = "1.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b10202063978b3351199d68f8b22c4e47e4b1b822f8d43fd862d5ea8c006b29a" -dependencies = [ - "async-task", - "concurrent-queue", - "fastrand", - "futures-lite", - "slab", -] - -[[package]] -name = "async-global-executor" -version = "2.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05b1b633a2115cd122d73b955eadd9916c18c8f510ec9cd1686404c60ad1c29c" -dependencies = [ - "async-channel", - "async-executor", - "async-io", - "async-lock", - "blocking", - "futures-lite", - "once_cell", -] - -[[package]] -name = "async-io" -version = "2.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcccb0f599cfa2f8ace422d3555572f47424da5648a4382a9dd0310ff8210884" -dependencies = [ - "async-lock", - "cfg-if", - "concurrent-queue", - "futures-io", - "futures-lite", - "parking", - "polling", - "rustix", - "slab", - "tracing", - "windows-sys 0.52.0", -] - -[[package]] -name = "async-lock" -version = "3.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d034b430882f8381900d3fe6f0aaa3ad94f2cb4ac519b429692a1bc2dda4ae7b" -dependencies = [ - "event-listener 4.0.3", - "event-listener-strategy 0.4.0", - "pin-project-lite", -] - -[[package]] -name = "async-stream" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd56dd203fef61ac097dd65721a419ddccb106b2d2b70ba60a6b529f03961a51" -dependencies = [ - "async-stream-impl", - "futures-core", - "pin-project-lite", -] - -[[package]] -name = "async-stream-impl" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "async-task" -version = "4.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbb36e985947064623dbd357f727af08ffd077f93d696782f3c56365fa2e2799" - -[[package]] -name = "async-trait" -version = "0.1.80" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "atomic-waker" -version = "1.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" - -[[package]] -name = "atty" -version = "0.2.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" -dependencies = [ - "hermit-abi 0.1.19", - "libc", - "winapi", -] - -[[package]] -name = "autocfg" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" - -[[package]] -name = "backtrace" -version = "0.3.71" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"26b05800d2e817c8b3b4b54abd461726265fa9789ae34330622f2db9ee696f9d" -dependencies = [ - "addr2line", - "cc", - "cfg-if", - "libc", - "miniz_oxide", - "object", - "rustc-demangle", -] - -[[package]] -name = "base16ct" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c7f02d4ea65f2c1853089ffd8d2787bdbc63de2f0d29dedbcf8ccdfa0ccd4cf" - -[[package]] -name = "base64" -version = "0.21.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" - -[[package]] -name = "base64" -version = "0.22.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9475866fec1451be56a3c2400fd081ff546538961565ccb5b7142cbd22bc7a51" - -[[package]] -name = "base64ct" -version = "1.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" - -[[package]] -name = "bincode" -version = "1.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" -dependencies = [ - "serde", -] - -[[package]] -name = "bitflags" -version = "1.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" - -[[package]] -name = "bitflags" -version = "2.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" - -[[package]] -name = "block-buffer" -version = "0.10.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" -dependencies = [ - "generic-array", -] - -[[package]] -name = "block-padding" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8894febbff9f758034a5b8e12d87918f56dfc64a8e1fe757d65e29041538d93" -dependencies = [ - "generic-array", -] - -[[package]] -name = "blocking" -version = "1.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a37913e8dc4ddcc604f0c6d3bf2887c995153af3611de9e23c352b44c1b9118" -dependencies = [ - "async-channel", - "async-lock", - "async-task", - "fastrand", - "futures-io", - "futures-lite", - "piper", - "tracing", -] - -[[package]] -name = "bumpalo" -version = "3.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" - -[[package]] -name = "byteorder" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" - -[[package]] -name = "bytes" -version = "1.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" - -[[package]] -name = "cast" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" - -[[package]] -name = "cbc" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b52a9543ae338f279b96b0b9fed9c8093744685043739079ce85cd58f289a6" -dependencies = [ - "cipher", -] - -[[package]] -name = "cc" -version = "1.0.95" -source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "d32a725bc159af97c3e629873bb9f88fb8cf8a4867175f76dc987815ea07c83b" - -[[package]] -name = "ccm" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ae3c82e4355234767756212c570e29833699ab63e6ffd161887314cc5b43847" -dependencies = [ - "aead", - "cipher", - "ctr", - "subtle", -] - -[[package]] -name = "cfg-if" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" - -[[package]] -name = "chrono" -version = "0.4.38" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" -dependencies = [ - "android-tzdata", - "iana-time-zone", - "js-sys", - "num-traits", - "wasm-bindgen", - "windows-targets 0.52.5", -] - -[[package]] -name = "ciborium" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" -dependencies = [ - "ciborium-io", - "ciborium-ll", - "serde", -] - -[[package]] -name = "ciborium-io" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" - -[[package]] -name = "ciborium-ll" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" -dependencies = [ - "ciborium-io", - "half", -] - -[[package]] -name = "cipher" -version = "0.4.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" -dependencies = [ - "crypto-common", - "inout", -] - -[[package]] -name = "clap" -version = "3.2.25" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ea181bf566f71cb9a5d17a59e1871af638180a18fb0035c92ae62b705207123" -dependencies = [ - "atty", - "bitflags 1.3.2", - "clap_lex 0.2.4", - "indexmap 1.9.3", - "strsim", - "termcolor", - "textwrap", -] - -[[package]] -name = "clap" -version = "4.5.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90bc066a67923782aa8515dbaea16946c5bcc5addbd668bb80af688e53e548a0" -dependencies = [ - "clap_builder", -] - -[[package]] -name = "clap_builder" -version = "4.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae129e2e766ae0ec03484e609954119f123cc1fe650337e155d03b022f24f7b4" -dependencies = [ - "anstyle", - "clap_lex 0.7.0", -] - -[[package]] -name = "clap_lex" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2850f2f5a82cbf437dd5af4d49848fbdfc27c157c3d010345776f952765261c5" -dependencies = [ - "os_str_bytes", -] - -[[package]] -name = "clap_lex" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce" - -[[package]] -name = "concurrent-queue" -version = "2.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d16048cd947b08fa32c24458a22f5dc5e835264f689f4f5653210c69fd107363" -dependencies = [ - "crossbeam-utils", -] - -[[package]] -name = "const-oid" -version = "0.9.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" - -[[package]] -name = 
"core-foundation-sys" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" - -[[package]] -name = "cpufeatures" -version = "0.2.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" -dependencies = [ - "libc", -] - -[[package]] -name = "crc" -version = "3.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69e6e4d7b33a94f0991c26729976b10ebde1d34c3ee82408fb536164fa10d636" -dependencies = [ - "crc-catalog", -] - -[[package]] -name = "crc-catalog" -version = "2.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" - -[[package]] -name = "criterion" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" -dependencies = [ - "anes", - "cast", - "ciborium", - "clap 4.5.4", - "criterion-plot", - "futures", - "is-terminal", - "itertools", - "num-traits", - "once_cell", - "oorandom", - "plotters", - "rayon", - "regex", - "serde", - "serde_derive", - "serde_json", - "tinytemplate", - "walkdir", -] - -[[package]] -name = "criterion-plot" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" -dependencies = [ - "cast", - "itertools", -] - -[[package]] -name = "crossbeam-deque" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" -dependencies = [ - "crossbeam-epoch", - "crossbeam-utils", -] - -[[package]] -name = "crossbeam-epoch" -version = "0.9.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" -dependencies = [ - "crossbeam-utils", -] - -[[package]] -name = "crossbeam-utils" -version = "0.8.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" - -[[package]] -name = "crunchy" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" - -[[package]] -name = "crypto-bigint" -version = "0.5.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0dc92fb57ca44df6db8059111ab3af99a63d5d0f8375d9972e319a379c6bab76" -dependencies = [ - "generic-array", - "rand_core", - "subtle", - "zeroize", -] - -[[package]] -name = "crypto-common" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" -dependencies = [ - "generic-array", - "rand_core", - "typenum", -] - -[[package]] -name = "ctr" -version = "0.9.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835" -dependencies = [ - "cipher", -] - -[[package]] -name = "curve25519-dalek" -version = "4.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a677b8922c94e01bdbb12126b0bc852f00447528dee1782229af9c720c3f348" -dependencies = [ - "cfg-if", - "cpufeatures", - 
"curve25519-dalek-derive", - "fiat-crypto", - "platforms", - "rustc_version", - "subtle", - "zeroize", -] - -[[package]] -name = "curve25519-dalek-derive" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "data-encoding" -version = "2.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e962a19be5cfc3f3bf6dd8f61eb50107f356ad6270fbb3ed41476571db78be5" - -[[package]] -name = "der" -version = "0.7.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f55bf8e7b65898637379c1b74eb1551107c8294ed26d855ceb9fd1a09cfc9bc0" -dependencies = [ - "const-oid", - "pem-rfc7468", - "zeroize", -] - -[[package]] -name = "der-parser" -version = "9.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cd0a5c643689626bec213c4d8bd4d96acc8ffdb4ad4bb6bc16abf27d5f4b553" -dependencies = [ - "asn1-rs", - "displaydoc", - "nom", - "num-bigint", - "num-traits", - "rusticata-macros", -] - -[[package]] -name = "deranged" -version = "0.3.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" -dependencies = [ - "powerfmt", -] - -[[package]] -name = "digest" -version = "0.10.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" -dependencies = [ - "block-buffer", - "const-oid", - "crypto-common", - "subtle", -] - -[[package]] -name = "displaydoc" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "487585f4d0c6655fe74905e2504d8ad6908e4db67f744eb140876906c2f3175d" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "ecdsa" -version = "0.16.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee27f32b5c5292967d2d4a9d7f1e0b0aed2c15daded5a60300e4abb9d8020bca" -dependencies = [ - "der", - "digest", - "elliptic-curve", - "rfc6979", - "signature", - "spki", -] - -[[package]] -name = "either" -version = "1.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2" - -[[package]] -name = "elliptic-curve" -version = "0.13.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5e6043086bf7973472e0c7dff2142ea0b680d30e18d9cc40f267efbf222bd47" -dependencies = [ - "base16ct", - "crypto-bigint", - "digest", - "ff", - "generic-array", - "group", - "hkdf", - "pem-rfc7468", - "pkcs8", - "rand_core", - "sec1", - "subtle", - "zeroize", -] - -[[package]] -name = "env_logger" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cd405aab171cb85d6735e5c8d9db038c17d3ca007a4d2c25f337935c3d90580" -dependencies = [ - "humantime", - "is-terminal", - "log", - "regex", - "termcolor", -] - -[[package]] -name = "equivalent" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" - -[[package]] -name = "errno" -version = "0.3.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" -dependencies = [ - "libc", - "windows-sys 0.52.0", -] 
- -[[package]] -name = "event-listener" -version = "4.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67b215c49b2b248c855fb73579eb1f4f26c38ffdc12973e20e07b91d78d5646e" -dependencies = [ - "concurrent-queue", - "parking", - "pin-project-lite", -] - -[[package]] -name = "event-listener" -version = "5.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d9944b8ca13534cdfb2800775f8dd4902ff3fc75a50101466decadfdf322a24" -dependencies = [ - "concurrent-queue", - "parking", - "pin-project-lite", -] - -[[package]] -name = "event-listener-strategy" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "958e4d70b6d5e81971bebec42271ec641e7ff4e170a6fa605f2b8a8b65cb97d3" -dependencies = [ - "event-listener 4.0.3", - "pin-project-lite", -] - -[[package]] -name = "event-listener-strategy" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "332f51cb23d20b0de8458b86580878211da09bcd4503cb579c225b3d124cabb3" -dependencies = [ - "event-listener 5.3.0", - "pin-project-lite", -] - -[[package]] -name = "examples" -version = "0.5.0" -dependencies = [ - "anyhow", - "bytes", - "chrono", - "clap 3.2.25", - "env_logger", - "hyper", - "lazy_static", - "log", - "memchr", - "rand", - "serde", - "serde_json", - "signal", - "tokio", - "tokio-util", - "webrtc", -] - -[[package]] -name = "fastrand" -version = "2.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "658bd65b1cf4c852a3cc96f18a8ce7b5640f6b703f905c7d74532294c2a63984" - -[[package]] -name = "ff" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ded41244b729663b1e574f1b4fb731469f69f79c17667b5d776b16cda0479449" -dependencies = [ - "rand_core", - "subtle", -] - -[[package]] -name = "fiat-crypto" -version = "0.2.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c007b1ae3abe1cb6f85a16305acd418b7ca6343b953633fee2b76d8f108b830f" - -[[package]] -name = "fnv" -version = "1.0.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" - -[[package]] -name = "foreign-types" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" -dependencies = [ - "foreign-types-shared", -] - -[[package]] -name = "foreign-types-shared" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" - -[[package]] -name = "form_urlencoded" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" -dependencies = [ - "percent-encoding", -] - -[[package]] -name = "futures" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" -dependencies = [ - "futures-channel", - "futures-core", - "futures-executor", - "futures-io", - "futures-sink", - "futures-task", - "futures-util", -] - -[[package]] -name = "futures-channel" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" -dependencies = [ - "futures-core", - 
"futures-sink", -] - -[[package]] -name = "futures-core" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" - -[[package]] -name = "futures-executor" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" -dependencies = [ - "futures-core", - "futures-task", - "futures-util", -] - -[[package]] -name = "futures-io" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" - -[[package]] -name = "futures-lite" -version = "2.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52527eb5074e35e9339c6b4e8d12600c7128b68fb25dcb9fa9dec18f7c25f3a5" -dependencies = [ - "fastrand", - "futures-core", - "futures-io", - "parking", - "pin-project-lite", -] - -[[package]] -name = "futures-macro" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "futures-sink" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" - -[[package]] -name = "futures-task" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" - -[[package]] -name = "futures-util" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" -dependencies = [ - "futures-channel", - "futures-core", - "futures-io", - "futures-macro", - "futures-sink", - "futures-task", - "memchr", - "pin-project-lite", - "pin-utils", - "slab", -] - -[[package]] -name = "generic-array" -version = "0.14.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" -dependencies = [ - "typenum", - "version_check", - "zeroize", -] - -[[package]] -name = "getrandom" -version = "0.2.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94b22e06ecb0110981051723910cbf0b5f5e09a2062dd7663334ee79a9d1286c" -dependencies = [ - "cfg-if", - "libc", - "wasi", -] - -[[package]] -name = "ghash" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0d8a4362ccb29cb0b265253fb0a2728f592895ee6854fd9bc13f2ffda266ff1" -dependencies = [ - "opaque-debug", - "polyval", -] - -[[package]] -name = "gimli" -version = "0.28.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" - -[[package]] -name = "group" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0f9ef7462f7c099f518d754361858f86d8a07af53ba9af0fe635bbccb151a63" -dependencies = [ - "ff", - "rand_core", - "subtle", -] - -[[package]] -name = "h2" -version = "0.3.26" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81fe527a889e1532da5c525686d96d4c2e74cdd345badf8dfef9f6b39dd5f5e8" -dependencies = [ - "bytes", - "fnv", - "futures-core", - 
"futures-sink", - "futures-util", - "http", - "indexmap 2.2.6", - "slab", - "tokio", - "tokio-util", - "tracing", -] - -[[package]] -name = "half" -version = "2.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" -dependencies = [ - "cfg-if", - "crunchy", -] - -[[package]] -name = "hashbrown" -version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" - -[[package]] -name = "hashbrown" -version = "0.14.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" - -[[package]] -name = "hermit-abi" -version = "0.1.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" -dependencies = [ - "libc", -] - -[[package]] -name = "hermit-abi" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" - -[[package]] -name = "hex" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" - -[[package]] -name = "hkdf" -version = "0.12.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b5f8eb2ad728638ea2c7d47a21db23b7b58a72ed6a38256b8a1849f15fbbdf7" -dependencies = [ - "hmac", -] - -[[package]] -name = "hmac" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" -dependencies = [ - "digest", -] - -[[package]] -name = "http" -version = "0.2.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1" -dependencies = [ - "bytes", - "fnv", - "itoa", -] - -[[package]] -name = "http-body" -version = "0.4.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" -dependencies = [ - "bytes", - "http", - "pin-project-lite", -] - -[[package]] -name = "httparse" -version = "1.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" - -[[package]] -name = "httpdate" -version = "1.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" - -[[package]] -name = "hub" -version = "0.1.0" -dependencies = [ - "rcgen", - "rustls", - "rustls-pemfile", - "thiserror", - "tokio", - "webrtc-dtls", - "webrtc-util", -] - -[[package]] -name = "humantime" -version = "2.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" - -[[package]] -name = "hyper" -version = "0.14.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf96e135eb83a2a8ddf766e426a841d8ddd7449d5f00d34ea02b41d2f19eef80" -dependencies = [ - "bytes", - "futures-channel", - "futures-core", - "futures-util", - "h2", - "http", - "http-body", - "httparse", - "httpdate", - "itoa", - "pin-project-lite", - "socket2", - "tokio", - "tower-service", - "tracing", - "want", 
-] - -[[package]] -name = "iana-time-zone" -version = "0.1.60" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" -dependencies = [ - "android_system_properties", - "core-foundation-sys", - "iana-time-zone-haiku", - "js-sys", - "wasm-bindgen", - "windows-core", -] - -[[package]] -name = "iana-time-zone-haiku" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" -dependencies = [ - "cc", -] - -[[package]] -name = "idna" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" -dependencies = [ - "unicode-bidi", - "unicode-normalization", -] - -[[package]] -name = "indexmap" -version = "1.9.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" -dependencies = [ - "autocfg", - "hashbrown 0.12.3", -] - -[[package]] -name = "indexmap" -version = "2.2.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" -dependencies = [ - "equivalent", - "hashbrown 0.14.3", - "serde", -] - -[[package]] -name = "inout" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5" -dependencies = [ - "block-padding", - "generic-array", -] - -[[package]] -name = "interceptor" -version = "0.12.0" -dependencies = [ - "async-trait", - "bytes", - "chrono", - "log", - "portable-atomic", - "rand", - "rtcp", - "rtp", - "thiserror", - "tokio", - "tokio-test", - "waitgroup", - "webrtc-srtp", - "webrtc-util", -] - -[[package]] -name = "ipnet" -version = "2.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" - -[[package]] -name = "is-terminal" -version = "0.4.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f23ff5ef2b80d608d61efee834934d862cd92461afc0560dedf493e4c033738b" -dependencies = [ - "hermit-abi 0.3.9", - "libc", - "windows-sys 0.52.0", -] - -[[package]] -name = "itertools" -version = "0.10.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" -dependencies = [ - "either", -] - -[[package]] -name = "itoa" -version = "1.0.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" - -[[package]] -name = "js-sys" -version = "0.3.69" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" -dependencies = [ - "wasm-bindgen", -] - -[[package]] -name = "lazy_static" -version = "1.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" - -[[package]] -name = "libc" -version = "0.2.153" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" - -[[package]] -name = "linux-raw-sys" -version = "0.4.13" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" - -[[package]] -name = "lock_api" -version = "0.4.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" -dependencies = [ - "autocfg", - "scopeguard", -] - -[[package]] -name = "log" -version = "0.4.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" - -[[package]] -name = "md-5" -version = "0.10.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" -dependencies = [ - "cfg-if", - "digest", -] - -[[package]] -name = "memchr" -version = "2.7.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" - -[[package]] -name = "memoffset" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4" -dependencies = [ - "autocfg", -] - -[[package]] -name = "minimal-lexical" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" - -[[package]] -name = "miniz_oxide" -version = "0.7.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d811f3e15f28568be3407c8e7fdb6514c1cda3cb30683f15b6a1a1dc4ea14a7" -dependencies = [ - "adler", -] - -[[package]] -name = "mio" -version = "0.8.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" -dependencies = [ - "libc", - "wasi", - "windows-sys 0.48.0", -] - -[[package]] -name = "nearly_eq" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a629868a433328c35d654e1e1fb4648a68a042e3c71de4e507a9bcf4602c5635" - -[[package]] -name = "nix" -version = "0.26.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "598beaf3cc6fdd9a5dfb1630c2800c7acd31df7aaf0f565796fba2b53ca1af1b" -dependencies = [ - "bitflags 1.3.2", - "cfg-if", - "libc", - "memoffset", - "pin-utils", -] - -[[package]] -name = "nom" -version = "7.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" -dependencies = [ - "memchr", - "minimal-lexical", -] - -[[package]] -name = "num-bigint" -version = "0.4.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0" -dependencies = [ - "autocfg", - "num-integer", - "num-traits", -] - -[[package]] -name = "num-conv" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" - -[[package]] -name = "num-integer" -version = "0.1.46" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" -dependencies = [ - "num-traits", -] - -[[package]] -name = "num-traits" -version = "0.2.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a" -dependencies = [ - "autocfg", -] - -[[package]] -name = "num_cpus" -version = "1.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" -dependencies = [ - "hermit-abi 0.3.9", - "libc", -] - -[[package]] -name = "object" -version = "0.32.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" -dependencies = [ - "memchr", -] - -[[package]] -name = "oid-registry" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c958dd45046245b9c3c2547369bb634eb461670b2e7e0de552905801a648d1d" -dependencies = [ - "asn1-rs", -] - -[[package]] -name = "once_cell" -version = "1.19.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" - -[[package]] -name = "oorandom" -version = "11.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" - -[[package]] -name = "opaque-debug" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" - -[[package]] -name = "openssl" -version = "0.10.64" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f" -dependencies = [ - "bitflags 2.5.0", - "cfg-if", - "foreign-types", - "libc", - "once_cell", - "openssl-macros", - "openssl-sys", -] - -[[package]] -name = "openssl-macros" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "openssl-src" -version = "300.2.3+3.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cff92b6f71555b61bb9315f7c64da3ca43d87531622120fea0195fc761b4843" -dependencies = [ - "cc", -] - -[[package]] -name = "openssl-sys" -version = "0.9.102" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c597637d56fbc83893a35eb0dd04b2b8e7a50c91e64e9493e398b5df4fb45fa2" -dependencies = [ - "cc", - "libc", - "openssl-src", - "pkg-config", - "vcpkg", -] - -[[package]] -name = "ordered-float" -version = "4.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a76df7075c7d4d01fdcb46c912dd17fba5b60c78ea480b475f2b6ab6f666584e" -dependencies = [ - "num-traits", -] - -[[package]] -name = "os_str_bytes" -version = "6.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2355d85b9a3786f481747ced0e0ff2ba35213a1f9bd406ed906554d7af805a1" - -[[package]] -name = "p256" -version = "0.13.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9863ad85fa8f4460f9c48cb909d38a0d689dba1f6f6988a5e3e0d31071bcd4b" -dependencies = [ - "ecdsa", - "elliptic-curve", - "primeorder", - "sha2", -] - -[[package]] -name = "p384" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70786f51bcc69f6a4c0360e063a4cac5419ef7c5cd5b3c99ad70f3be5ba79209" -dependencies = [ - "ecdsa", - "elliptic-curve", - "primeorder", - "sha2", -] - -[[package]] -name = 
"parking" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae" - -[[package]] -name = "parking_lot" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" -dependencies = [ - "lock_api", - "parking_lot_core", -] - -[[package]] -name = "parking_lot_core" -version = "0.9.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" -dependencies = [ - "cfg-if", - "libc", - "redox_syscall", - "smallvec", - "windows-targets 0.48.5", -] - -[[package]] -name = "pem" -version = "3.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e459365e590736a54c3fa561947c84837534b8e9af6fc5bf781307e82658fae" -dependencies = [ - "base64 0.22.0", - "serde", -] - -[[package]] -name = "pem-rfc7468" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" -dependencies = [ - "base64ct", -] - -[[package]] -name = "percent-encoding" -version = "2.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" - -[[package]] -name = "pin-project-lite" -version = "0.2.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" - -[[package]] -name = "pin-utils" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" - -[[package]] -name = "piper" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "668d31b1c4eba19242f2088b2bf3316b82ca31082a8335764db4e083db7485d4" -dependencies = [ - "atomic-waker", - "fastrand", - "futures-io", -] - -[[package]] -name = "pkcs8" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" -dependencies = [ - "der", - "spki", -] - -[[package]] -name = "pkg-config" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" - -[[package]] -name = "platforms" -version = "3.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db23d408679286588f4d4644f965003d056e3dd5abcaaa938116871d7ce2fee7" - -[[package]] -name = "plotters" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2c224ba00d7cadd4d5c660deaf2098e5e80e07846537c51f9cfa4be50c1fd45" -dependencies = [ - "num-traits", - "plotters-backend", - "plotters-svg", - "wasm-bindgen", - "web-sys", -] - -[[package]] -name = "plotters-backend" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e76628b4d3a7581389a35d5b6e2139607ad7c75b17aed325f210aa91f4a9609" - -[[package]] -name = "plotters-svg" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38f6d39893cca0701371e3c27294f09797214b86f1fb951b89ade8ec04e2abab" -dependencies = [ - "plotters-backend", -] - -[[package]] -name = "polling" -version = "3.6.0" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0c976a60b2d7e99d6f229e414670a9b85d13ac305cc6d1e9c134de58c5aaaf6" -dependencies = [ - "cfg-if", - "concurrent-queue", - "hermit-abi 0.3.9", - "pin-project-lite", - "rustix", - "tracing", - "windows-sys 0.52.0", -] - -[[package]] -name = "polyval" -version = "0.6.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25" -dependencies = [ - "cfg-if", - "cpufeatures", - "opaque-debug", - "universal-hash", -] - -[[package]] -name = "portable-atomic" -version = "1.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" - -[[package]] -name = "powerfmt" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" - -[[package]] -name = "ppv-lite86" -version = "0.2.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" - -[[package]] -name = "primeorder" -version = "0.13.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "353e1ca18966c16d9deb1c69278edbc5f194139612772bd9537af60ac231e1e6" -dependencies = [ - "elliptic-curve", -] - -[[package]] -name = "proc-macro2" -version = "1.0.81" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d1597b0c024618f09a9c3b8655b7e430397a36d23fdafec26d6965e9eec3eba" -dependencies = [ - "unicode-ident", -] - -[[package]] -name = "quote" -version = "1.0.36" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" -dependencies = [ - "proc-macro2", -] - -[[package]] -name = "rand" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" -dependencies = [ - "libc", - "rand_chacha", - "rand_core", -] - -[[package]] -name = "rand_chacha" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" -dependencies = [ - "ppv-lite86", - "rand_core", -] - -[[package]] -name = "rand_core" -version = "0.6.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" -dependencies = [ - "getrandom", -] - -[[package]] -name = "rayon" -version = "1.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" -dependencies = [ - "either", - "rayon-core", -] - -[[package]] -name = "rayon-core" -version = "1.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" -dependencies = [ - "crossbeam-deque", - "crossbeam-utils", -] - -[[package]] -name = "rcgen" -version = "0.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54077e1872c46788540de1ea3d7f4ccb1983d12f9aa909b234468676c1a36779" -dependencies = [ - "pem", - "ring", - "rustls-pki-types", - "time", - "x509-parser", - "yasna", -] - -[[package]] -name = "redox_syscall" -version = "0.4.1" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" -dependencies = [ - "bitflags 1.3.2", -] - -[[package]] -name = "regex" -version = "1.10.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" -dependencies = [ - "aho-corasick", - "memchr", - "regex-automata", - "regex-syntax", -] - -[[package]] -name = "regex-automata" -version = "0.4.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" -dependencies = [ - "aho-corasick", - "memchr", - "regex-syntax", -] - -[[package]] -name = "regex-syntax" -version = "0.8.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" - -[[package]] -name = "rfc6979" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8dd2a808d456c4a54e300a23e9f5a67e122c3024119acbfd73e3bf664491cb2" -dependencies = [ - "hmac", - "subtle", -] - -[[package]] -name = "ring" -version = "0.17.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" -dependencies = [ - "cc", - "cfg-if", - "getrandom", - "libc", - "spin", - "untrusted", - "windows-sys 0.52.0", -] - -[[package]] -name = "rtcp" -version = "0.11.0" -dependencies = [ - "bytes", - "thiserror", - "webrtc-util", -] - -[[package]] -name = "rtp" -version = "0.11.0" -dependencies = [ - "bytes", - "chrono", - "criterion", - "memchr", - "portable-atomic", - "rand", - "serde", - "thiserror", - "webrtc-util", -] - -[[package]] -name = "rustc-demangle" -version = "0.1.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" - -[[package]] -name = "rustc_version" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" -dependencies = [ - "semver", -] - -[[package]] -name = "rusticata-macros" -version = "4.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "faf0c4a6ece9950b9abdb62b1cfcf2a68b3b67a10ba445b3bb85be2a293d0632" -dependencies = [ - "nom", -] - -[[package]] -name = "rustix" -version = "0.38.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65e04861e65f21776e67888bfbea442b3642beaa0138fdb1dd7a84a52dffdb89" -dependencies = [ - "bitflags 2.5.0", - "errno", - "libc", - "linux-raw-sys", - "windows-sys 0.52.0", -] - -[[package]] -name = "rustls" -version = "0.23.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afabcee0551bd1aa3e18e5adbf2c0544722014b899adb31bd186ec638d3da97e" -dependencies = [ - "once_cell", - "ring", - "rustls-pki-types", - "rustls-webpki", - "subtle", - "zeroize", -] - -[[package]] -name = "rustls-pemfile" -version = "2.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d" -dependencies = [ - "base64 0.22.0", - "rustls-pki-types", -] - -[[package]] -name = "rustls-pki-types" -version = "1.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ecd36cc4259e3e4514335c4a138c6b43171a8d61d8f5c9348f9fc7529416f247" 
- -[[package]] -name = "rustls-webpki" -version = "0.102.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "faaa0a62740bedb9b2ef5afa303da42764c012f743917351dc9a237ea1663610" -dependencies = [ - "ring", - "rustls-pki-types", - "untrusted", -] - -[[package]] -name = "ryu" -version = "1.0.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" - -[[package]] -name = "same-file" -version = "1.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" -dependencies = [ - "winapi-util", -] - -[[package]] -name = "scopeguard" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" - -[[package]] -name = "sdp" -version = "0.6.2" -dependencies = [ - "criterion", - "rand", - "substring", - "thiserror", - "url", -] - -[[package]] -name = "sec1" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3e97a565f76233a6003f9f5c54be1d9c5bdfa3eccfb189469f11ec4901c47dc" -dependencies = [ - "base16ct", - "der", - "generic-array", - "pkcs8", - "subtle", - "zeroize", -] - -[[package]] -name = "semver" -version = "1.0.22" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca" - -[[package]] -name = "serde" -version = "1.0.198" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9846a40c979031340571da2545a4e5b7c4163bdae79b301d5f86d03979451fcc" -dependencies = [ - "serde_derive", -] - -[[package]] -name = "serde_derive" -version = "1.0.198" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e88edab869b01783ba905e7d0153f9fc1a6505a96e4ad3018011eedb838566d9" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "serde_json" -version = "1.0.116" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e17db7126d17feb94eb3fad46bf1a96b034e8aacbc2e775fe81505f8b0b2813" -dependencies = [ - "indexmap 2.2.6", - "itoa", - "ryu", - "serde", -] - -[[package]] -name = "sha1" -version = "0.10.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" -dependencies = [ - "cfg-if", - "cpufeatures", - "digest", -] - -[[package]] -name = "sha2" -version = "0.10.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" -dependencies = [ - "cfg-if", - "cpufeatures", - "digest", -] - -[[package]] -name = "signal" -version = "0.1.0" -dependencies = [ - "anyhow", - "base64 0.21.7", - "hyper", - "lazy_static", - "tokio", -] - -[[package]] -name = "signal-hook-registry" -version = "1.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" -dependencies = [ - "libc", -] - -[[package]] -name = "signature" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" -dependencies = [ - "digest", - "rand_core", -] - -[[package]] -name = "slab" -version = "0.4.9" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" -dependencies = [ - "autocfg", -] - -[[package]] -name = "smallvec" -version = "1.13.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" - -[[package]] -name = "smol_str" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6845563ada680337a52d43bb0b29f396f2d911616f6573012645b9e3d048a49" -dependencies = [ - "serde", -] - -[[package]] -name = "socket2" -version = "0.5.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05ffd9c0a93b7543e062e759284fcf5f5e3b098501104bfbdde4d404db792871" -dependencies = [ - "libc", - "windows-sys 0.52.0", -] - -[[package]] -name = "spin" -version = "0.9.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" - -[[package]] -name = "spki" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" -dependencies = [ - "base64ct", - "der", -] - -[[package]] -name = "strsim" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" - -[[package]] -name = "stun" -version = "0.6.0" -dependencies = [ - "base64 0.21.7", - "clap 3.2.25", - "crc", - "criterion", - "lazy_static", - "md-5", - "rand", - "ring", - "subtle", - "thiserror", - "tokio", - "tokio-test", - "url", - "webrtc-util", -] - -[[package]] -name = "substring" -version = "1.4.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42ee6433ecef213b2e72f587ef64a2f5943e7cd16fbd82dbe8bc07486c534c86" -dependencies = [ - "autocfg", -] - -[[package]] -name = "subtle" -version = "2.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" - -[[package]] -name = "syn" -version = "2.0.60" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "909518bc7b1c9b779f1bbf07f2929d35af9f0f37e47c6e9ef7f9dddc1e1821f3" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "synstructure" -version = "0.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "termcolor" -version = "1.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" -dependencies = [ - "winapi-util", -] - -[[package]] -name = "textwrap" -version = "0.16.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23d434d3f8967a09480fb04132ebe0a3e088c173e6d0ee7897abbdf4eab0f8b9" - -[[package]] -name = "thiserror" -version = "1.0.58" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297" -dependencies = [ - "thiserror-impl", -] - -[[package]] -name = "thiserror-impl" -version = "1.0.58" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "time" -version = "0.3.36" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885" -dependencies = [ - "deranged", - "itoa", - "num-conv", - "powerfmt", - "serde", - "time-core", - "time-macros", -] - -[[package]] -name = "time-core" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" - -[[package]] -name = "time-macros" -version = "0.2.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f252a68540fde3a3877aeea552b832b40ab9a69e318efd078774a01ddee1ccf" -dependencies = [ - "num-conv", - "time-core", -] - -[[package]] -name = "tinytemplate" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" -dependencies = [ - "serde", - "serde_json", -] - -[[package]] -name = "tinyvec" -version = "1.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" -dependencies = [ - "tinyvec_macros", -] - -[[package]] -name = "tinyvec_macros" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" - -[[package]] -name = "tokio" -version = "1.37.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1adbebffeca75fcfd058afa480fb6c0b81e165a0323f9c9d39c9697e37c46787" -dependencies = [ - "backtrace", - "bytes", - "libc", - "mio", - "num_cpus", - "parking_lot", - "pin-project-lite", - "signal-hook-registry", - "socket2", - "tokio-macros", - "windows-sys 0.48.0", -] - -[[package]] -name = "tokio-macros" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "tokio-stream" -version = "0.1.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af" -dependencies = [ - "futures-core", - "pin-project-lite", - "tokio", -] - -[[package]] -name = "tokio-test" -version = "0.4.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2468baabc3311435b55dd935f702f42cd1b8abb7e754fb7dfb16bd36aa88f9f7" -dependencies = [ - "async-stream", - "bytes", - "futures-core", - "tokio", - "tokio-stream", -] - -[[package]] -name = "tokio-util" -version = "0.7.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5419f34732d9eb6ee4c3578b7989078579b7f039cbbb9ca2c4da015749371e15" -dependencies = [ - "bytes", - "futures-core", - "futures-sink", - "pin-project-lite", - "tokio", - "tracing", -] - -[[package]] -name = "tower-service" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" - -[[package]] -name = "tracing" -version = "0.1.40" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" -dependencies = [ - 
"pin-project-lite", - "tracing-core", -] - -[[package]] -name = "tracing-core" -version = "0.1.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" -dependencies = [ - "once_cell", -] - -[[package]] -name = "try-lock" -version = "0.2.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" - -[[package]] -name = "turn" -version = "0.8.0" -dependencies = [ - "async-trait", - "base64 0.21.7", - "chrono", - "clap 3.2.25", - "criterion", - "env_logger", - "futures", - "hex", - "log", - "md-5", - "portable-atomic", - "rand", - "ring", - "stun", - "thiserror", - "tokio", - "tokio-test", - "tokio-util", - "webrtc-util", -] - -[[package]] -name = "typenum" -version = "1.17.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" - -[[package]] -name = "unicode-bidi" -version = "0.3.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" - -[[package]] -name = "unicode-ident" -version = "1.0.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" - -[[package]] -name = "unicode-normalization" -version = "0.1.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" -dependencies = [ - "tinyvec", -] - -[[package]] -name = "universal-hash" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea" -dependencies = [ - "crypto-common", - "subtle", -] - -[[package]] -name = "untrusted" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" - -[[package]] -name = "url" -version = "2.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" -dependencies = [ - "form_urlencoded", - "idna", - "percent-encoding", -] - -[[package]] -name = "uuid" -version = "1.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a183cf7feeba97b4dd1c0d46788634f6221d87fa961b305bed08c851829efcc0" -dependencies = [ - "getrandom", -] - -[[package]] -name = "vcpkg" -version = "0.2.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" - -[[package]] -name = "version_check" -version = "0.9.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" - -[[package]] -name = "waitgroup" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1f50000a783467e6c0200f9d10642f4bc424e39efc1b770203e88b488f79292" -dependencies = [ - "atomic-waker", -] - -[[package]] -name = "walkdir" -version = "2.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" -dependencies = [ - "same-file", - "winapi-util", -] - -[[package]] -name = "want" -version = "0.3.1" 
-source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" -dependencies = [ - "try-lock", -] - -[[package]] -name = "wasi" -version = "0.11.0+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" - -[[package]] -name = "wasm-bindgen" -version = "0.2.92" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" -dependencies = [ - "cfg-if", - "wasm-bindgen-macro", -] - -[[package]] -name = "wasm-bindgen-backend" -version = "0.2.92" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" -dependencies = [ - "bumpalo", - "log", - "once_cell", - "proc-macro2", - "quote", - "syn", - "wasm-bindgen-shared", -] - -[[package]] -name = "wasm-bindgen-macro" -version = "0.2.92" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" -dependencies = [ - "quote", - "wasm-bindgen-macro-support", -] - -[[package]] -name = "wasm-bindgen-macro-support" -version = "0.2.92" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" -dependencies = [ - "proc-macro2", - "quote", - "syn", - "wasm-bindgen-backend", - "wasm-bindgen-shared", -] - -[[package]] -name = "wasm-bindgen-shared" -version = "0.2.92" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" - -[[package]] -name = "web-sys" -version = "0.3.69" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" -dependencies = [ - "js-sys", - "wasm-bindgen", -] - -[[package]] -name = "webrtc" -version = "0.11.0" -dependencies = [ - "arc-swap", - "async-trait", - "bytes", - "cfg-if", - "env_logger", - "hex", - "interceptor", - "lazy_static", - "log", - "pem", - "portable-atomic", - "rand", - "rcgen", - "regex", - "ring", - "rtcp", - "rtp", - "rustls", - "sdp", - "serde", - "serde_json", - "sha2", - "smol_str", - "stun", - "thiserror", - "time", - "tokio", - "tokio-test", - "turn", - "url", - "waitgroup", - "webrtc-data", - "webrtc-dtls", - "webrtc-ice", - "webrtc-mdns", - "webrtc-media", - "webrtc-sctp", - "webrtc-srtp", - "webrtc-util", -] - -[[package]] -name = "webrtc-constraints" -version = "0.1.0" -dependencies = [ - "env_logger", - "indexmap 2.2.6", - "lazy_static", - "ordered-float", - "serde", - "serde_json", - "thiserror", -] - -[[package]] -name = "webrtc-data" -version = "0.9.0" -dependencies = [ - "bytes", - "chrono", - "env_logger", - "log", - "portable-atomic", - "thiserror", - "tokio", - "tokio-test", - "webrtc-sctp", - "webrtc-util", -] - -[[package]] -name = "webrtc-dtls" -version = "0.10.0" -dependencies = [ - "aes", - "aes-gcm", - "async-trait", - "bincode", - "byteorder", - "cbc", - "ccm", - "chrono", - "clap 3.2.25", - "der-parser", - "env_logger", - "hkdf", - "hmac", - "hub", - "log", - "p256", - "p384", - "pem", - "portable-atomic", - "rand", - "rand_core", - "rcgen", - "ring", - "rustls", - "sec1", - "serde", - "sha1", - "sha2", - "subtle", - "thiserror", - "tokio", - "tokio-test", - "webrtc-util", - "x25519-dalek", 
- "x509-parser", -] - -[[package]] -name = "webrtc-ice" -version = "0.11.0" -dependencies = [ - "arc-swap", - "async-trait", - "chrono", - "clap 3.2.25", - "crc", - "env_logger", - "hyper", - "ipnet", - "lazy_static", - "log", - "portable-atomic", - "rand", - "regex", - "serde", - "serde_json", - "sha1", - "stun", - "thiserror", - "tokio", - "tokio-test", - "turn", - "url", - "uuid", - "waitgroup", - "webrtc-mdns", - "webrtc-util", -] - -[[package]] -name = "webrtc-mdns" -version = "0.7.0" -dependencies = [ - "chrono", - "clap 3.2.25", - "env_logger", - "log", - "socket2", - "thiserror", - "tokio", - "webrtc-util", -] - -[[package]] -name = "webrtc-media" -version = "0.8.0" -dependencies = [ - "byteorder", - "bytes", - "criterion", - "nearly_eq", - "rand", - "rtp", - "thiserror", -] - -[[package]] -name = "webrtc-sctp" -version = "0.10.0" -dependencies = [ - "arc-swap", - "async-trait", - "bytes", - "chrono", - "clap 3.2.25", - "crc", - "env_logger", - "lazy_static", - "log", - "portable-atomic", - "rand", - "thiserror", - "tokio", - "tokio-test", - "webrtc-util", -] - -[[package]] -name = "webrtc-srtp" -version = "0.13.0" -dependencies = [ - "aead", - "aes", - "aes-gcm", - "byteorder", - "bytes", - "criterion", - "ctr", - "hmac", - "lazy_static", - "log", - "openssl", - "rtcp", - "rtp", - "sha1", - "subtle", - "thiserror", - "tokio", - "tokio-test", - "webrtc-util", -] - -[[package]] -name = "webrtc-util" -version = "0.9.0" -dependencies = [ - "async-global-executor", - "async-trait", - "bitflags 1.3.2", - "bytes", - "chrono", - "criterion", - "env_logger", - "ipnet", - "lazy_static", - "libc", - "log", - "nix", - "portable-atomic", - "rand", - "thiserror", - "tokio", - "tokio-test", - "winapi", -] - -[[package]] -name = "winapi" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" -dependencies = [ - "winapi-i686-pc-windows-gnu", - "winapi-x86_64-pc-windows-gnu", -] - -[[package]] -name = "winapi-i686-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" - -[[package]] -name = "winapi-util" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596" -dependencies = [ - "winapi", -] - -[[package]] -name = "winapi-x86_64-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" - -[[package]] -name = "windows-core" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" -dependencies = [ - "windows-targets 0.52.5", -] - -[[package]] -name = "windows-sys" -version = "0.48.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" -dependencies = [ - "windows-targets 0.48.5", -] - -[[package]] -name = "windows-sys" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" -dependencies = [ - "windows-targets 0.52.5", -] - -[[package]] -name = "windows-targets" -version = "0.48.5" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" -dependencies = [ - "windows_aarch64_gnullvm 0.48.5", - "windows_aarch64_msvc 0.48.5", - "windows_i686_gnu 0.48.5", - "windows_i686_msvc 0.48.5", - "windows_x86_64_gnu 0.48.5", - "windows_x86_64_gnullvm 0.48.5", - "windows_x86_64_msvc 0.48.5", -] - -[[package]] -name = "windows-targets" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" -dependencies = [ - "windows_aarch64_gnullvm 0.52.5", - "windows_aarch64_msvc 0.52.5", - "windows_i686_gnu 0.52.5", - "windows_i686_gnullvm", - "windows_i686_msvc 0.52.5", - "windows_x86_64_gnu 0.52.5", - "windows_x86_64_gnullvm 0.52.5", - "windows_x86_64_msvc 0.52.5", -] - -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" - -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" - -[[package]] -name = "windows_aarch64_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" - -[[package]] -name = "windows_aarch64_msvc" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" - -[[package]] -name = "windows_i686_gnu" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" - -[[package]] -name = "windows_i686_gnu" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" - -[[package]] -name = "windows_i686_gnullvm" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" - -[[package]] -name = "windows_i686_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" - -[[package]] -name = "windows_i686_msvc" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" - -[[package]] -name = "windows_x86_64_gnu" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" - -[[package]] -name = "windows_x86_64_gnu" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" - -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" - -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" - -[[package]] -name = "windows_x86_64_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" - -[[package]] -name = "windows_x86_64_msvc" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" - -[[package]] -name = "x25519-dalek" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7e468321c81fb07fa7f4c636c3972b9100f0346e5b6a9f2bd0603a52f7ed277" -dependencies = [ - "curve25519-dalek", - "rand_core", - "serde", - "zeroize", -] - -[[package]] -name = "x509-parser" -version = "0.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcbc162f30700d6f3f82a24bf7cc62ffe7caea42c0b2cba8bf7f3ae50cf51f69" -dependencies = [ - "asn1-rs", - "data-encoding", - "der-parser", - "lazy_static", - "nom", - "oid-registry", - "ring", - "rusticata-macros", - "thiserror", - "time", -] - -[[package]] -name = "yasna" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd" -dependencies = [ - "time", -] - -[[package]] -name = "zeroize" -version = "1.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" -dependencies = [ - "zeroize_derive", -] - -[[package]] -name = "zeroize_derive" -version = "1.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] diff --git a/Cargo.toml b/Cargo.toml index 371f7d08c..0e1f97298 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,24 +1,27 @@ -[workspace] -members = [ - "constraints", - "data", - "dtls", - "examples", - "ice", - "interceptor", - "mdns", - "media", - "rtcp", - "rtp", - "sctp", - "sdp", - "srtp", - "stun", - "turn", - "util", - "webrtc", -] -resolver = "2" +[package] +name = "medea-turn" +version = "0.7.0-dev" +authors = ["Instrumentisto Team "] +edition = "2021" +rust-version = "1.70" +description = "TURN implementation used by Medea media server." +license = "MIT OR Apache-2.0" +homepage = "https://github.com/instrumentisto/medea-turn-rs" +repository = "https://github.com/instrumentisto/medea-turn-rs" +publish = false -[profile.dev] -opt-level = 0 +[dependencies] +async-trait = "0.1" +bytecodec = "0.4.15" +bytes = "1.6" +futures = "0.3" +log = "0.4" +rand = "0.8" +stun_codec = "0.3" +thiserror = "1.0" +tokio = { version = "1.32", default-features = false, features = ["io-util", "macros", "net", "rt-multi-thread", "time"] } +tokio-util = { version = "0.7", features = ["codec"] } + +[dev-dependencies] +tokio-test = "0.4" +hex = "0.4" diff --git a/LICENSE-APACHE b/LICENSE-APACHE index 16fe87b06..1b5ec8b78 100644 --- a/LICENSE-APACHE +++ b/LICENSE-APACHE @@ -174,28 +174,3 @@ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS - -APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. 
(Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - -Copyright [yyyy] [name of copyright owner] - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. diff --git a/LICENSE-MIT b/LICENSE-MIT index e11d93bef..59f49a934 100644 --- a/LICENSE-MIT +++ b/LICENSE-MIT @@ -1,21 +1,25 @@ MIT License -Copyright (c) 2021 WebRTC.rs +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 000000000..2971b7768 --- /dev/null +++ b/Makefile @@ -0,0 +1,106 @@ +############################### +# Common defaults/definitions # +############################### + +comma := , + +# Checks two given strings for equality. 
+eq = $(if $(or $(1),$(2)),$(and $(findstring $(1),$(2)),\ + $(findstring $(2),$(1))),1) + + + + +########### +# Aliases # +########### + +all: fmt lint docs test + + +docs: cargo.doc + + +fmt: cargo.fmt + + +lint: cargo.lint + + +test: test.cargo + + + + +################## +# Cargo commands # +################## + +# Generate crate documentation from Rust sources. +# +# Usage: +# make cargo.doc [private=(yes|no)] [open=(no|yes)] [clean=(no|yes)] + +cargo.doc: +ifeq ($(clean),yes) + @rm -rf target/doc/ +endif + cargo doc --all-features \ + $(if $(call eq,$(private),no),,--document-private-items) \ + $(if $(call eq,$(open),yes),--open,) + + +# Format Rust sources with rustfmt. +# +# Usage: +# make cargo.fmt [check=(no|yes)] + +cargo.fmt: + cargo +nightly fmt --all $(if $(call eq,$(check),yes),-- --check,) + + +# Lint Rust sources with Clippy. +# +# Usage: +# make cargo.lint + +cargo.lint: + cargo clippy --all-features -- -D warnings + + +cargo.test: test.cargo + + + + +#################### +# Testing commands # +#################### + +# Run Rust tests. +# +# Usage: +# make test.cargo [careful=(no|yes)] + +test.cargo: +ifeq ($(careful),yes) +ifeq ($(shell cargo install --list | grep cargo-careful),) + cargo install cargo-careful +endif +ifeq ($(shell rustup component list --toolchain=nightly \ + | grep 'rust-src (installed)'),) + rustup component add --toolchain=nightly rust-src +endif +endif + cargo $(if $(call eq,$(careful),yes),+nightly careful,) test --all-features + + + + +################## +# .PHONY section # +################## + +.PHONY: all docs fmt lint test \ + cargo.doc cargo.fmt cargo.lint cargo.test \ + test.cargo diff --git a/README.md b/README.md index 90ec4957b..0921b79dc 100644 --- a/README.md +++ b/README.md @@ -1,146 +1,27 @@ -

- WebRTC.rs
- [badges: License MIT/Apache 2.0, Discord, Twitter]
- A pure Rust implementation of WebRTC stack. Rewrite Pion WebRTC stack in Rust

+`medea-turn` +============ -

-Sponsored with 💖 by
-Silver Sponsors: Stream Chat, ChannelTalk
-Bronze Sponsors: KittyCAD, AdrianEddy

+[![CI](https://github.com/instrumentisto/medea-turn-rs/workflows/CI/badge.svg?branch=main "CI")](https://github.com/instrumentisto/medea-turn-rs/actions?query=workflow%3ACI+branch%3Amain) +[![Rust 1.70+](https://img.shields.io/badge/rustc-1.70+-lightgray.svg "Rust 1.70+")](https://blog.rust-lang.org/2023/06/01/Rust-1.70.0.html) -
-Table of Content -- [Overview](#overview) -- [Features](#features) -- [Building](#building) - - [Toolchain](#toolchain) - - [Monorepo Setup](#monorepo-setup) -- [Open Source License](#open-source-license) -- [Contributing](#contributing) +[Changelog](https://github.com/instrumentisto/medea-turn-rs/blob/master/CHANGELOG.md) -
+TURN implementation used by [Medea media server](https://github.com/instrumentisto/medea). Majorly refactored fork of the [`webrtc-rs/turn` crate](https://github.com/webrtc-rs/webrtc/tree/89285ceba23dc57fc99386cb978d2d23fe909437/turn). -## Overview -WebRTC.rs is a pure Rust implementation of WebRTC stack, which rewrites Pion stack in Rust. -This project is still in active and early development stage, please refer to the [Roadmap](https://github.com/webrtc-rs/webrtc/issues/1) to track the major milestones and releases. -[Examples](https://github.com/webrtc-rs/webrtc/blob/master/examples/examples/README.md) provide code samples to show how to use webrtc-rs to build media and data channel applications. -## Features -

- WebRTC
- Media | Interceptor | Data
- RTP | RTCP | SRTP | SCTP
- DTLS
- mDNS | STUN | TURN | ICE
- SDP | Util
- [image: WebRTC Crates Dependency Graph]
- [image: WebRTC Stack]

+## License -## Building +Copyright ยฉ 2024 Instrumentisto Team, -### Toolchain +Licensed under either of [Apache License, Version 2.0][APACHE] or [MIT license][MIT] at your option. -**Minimum Supported Rust Version:** `1.65.0` +Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in this crate by you, as defined in the [Apache-2.0 license][APACHE], shall be dual licensed as above, without any additional terms or conditions. -Our minimum supported rust version(MSRV) policy is to support versions of the compiler released within the last six months. We don't eagerly bump the minimum version we support, instead the minimum will be bumped on a needed by needed basis, usually because downstream dependencies force us to. -**Note:** Changes to the minimum supported version are not consider breaking from a [semver](https://semver.org/) perspective. -### Monorepo Setup -All webrtc dependent crates and examples are included in this repository at the top level in a Cargo workspace. - -To build all webrtc examples: - -```shell -cd examples -cargo test # build all examples (maybe very slow) -#[ or just build single example (much faster) -cargo build --example play-from-disk-vpx # build play-from-disk-vpx example only -cargo build --example play-from-disk-h264 # build play-from-disk-h264 example only -#... -#] -``` - -To build webrtc crate: - -```shell -cargo build [or clippy or test or fmt] -``` - -## Open Source License - -Dual licensing under both MIT and Apache-2.0 is the currently accepted standard by the Rust language community and has been used for both the compiler and many public libraries since (see ). In order to match the community standards, webrtc-rs is using the dual MIT+Apache-2.0 license. - -## Contributing - -Contributors or Pull Requests are Welcome!!! +[APACHE]: https://github.com/instrumentisto/medea-turn-rs/blob/main/LICENSE-APACHE +[MIT]: https://github.com/instrumentisto/medea-turn-rs/blob/main/LICENSE-MIT diff --git a/_typos.toml b/_typos.toml deleted file mode 100644 index 5c7a08773..000000000 --- a/_typos.toml +++ /dev/null @@ -1,15 +0,0 @@ -[type.po] -extend-glob = ["*.csr"] -check-file = false - -[default.extend-words] -# Additionals is important for WebRTC -additionals = "additionals" -# STAP-A for WebRTC -stap = "stap" -# MIS value -mis = "mis" -# datas is used a lot for plural. 
-datas = "datas" -# 2nd for second -2nd = "2nd" diff --git a/codecov.yml b/codecov.yml deleted file mode 100644 index 99256353c..000000000 --- a/codecov.yml +++ /dev/null @@ -1,23 +0,0 @@ -codecov: - require_ci_to_pass: yes - max_report_age: off - token: 9c7a93c8-b2b2-4da3-9990-7283701dec58 - -coverage: - precision: 2 - round: down - range: 50..90 - status: - project: - default: - enabled: no - threshold: 0.2 - if_not_found: success - patch: - default: - enabled: no - if_not_found: success - changes: - default: - enabled: no - if_not_found: success diff --git a/constraints/.gitignore b/constraints/.gitignore deleted file mode 100644 index 926fe98a1..000000000 --- a/constraints/.gitignore +++ /dev/null @@ -1,88 +0,0 @@ - -# Created by https://www.toptal.com/developers/gitignore/api/rust -# Edit at https://www.toptal.com/developers/gitignore?templates=rust - -### Rust ### -# Generated by Cargo -# will have compiled files and executables -debug/ -target/ - -# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries -# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html -Cargo.lock - -# These are backup files generated by rustfmt -**/*.rs.bk - -# MSVC Windows builds of rustc generate these, which store debugging information -*.pdb - -# End of https://www.toptal.com/developers/gitignore/api/rust - -# Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode -# Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode - -### VisualStudioCode ### -.vscode/* -# !.vscode/settings.json -!.vscode/tasks.json -!.vscode/launch.json -!.vscode/extensions.json -!.vscode/*.code-snippets - -# Local History for Visual Studio Code -.history/ - -# Built Visual Studio Code Extensions -*.vsix - -### VisualStudioCode Patch ### -# Ignore all local history of files -.history -.ionide - -# Support for Project snippet scope -.vscode/*.code-snippets - -# Ignore code-workspaces -*.code-workspace - -# End of https://www.toptal.com/developers/gitignore/api/visualstudiocode - -# Created by https://www.toptal.com/developers/gitignore/api/macos -# Edit at https://www.toptal.com/developers/gitignore?templates=macos - -### macOS ### -# General -.DS_Store -.AppleDouble -.LSOverride - -# Icon must end with two \r -Icon - -# Thumbnails -._* - -# Files that might appear in the root of a volume -.DocumentRevisions-V100 -.fseventsd -.Spotlight-V100 -.TemporaryItems -.Trashes -.VolumeIcon.icns -.com.apple.timemachine.donotpresent - -# Directories potentially created on remote AFP share -.AppleDB -.AppleDesktop -Network Trash Folder -Temporary Items -.apdisk - -### macOS Patch ### -# iCloud generated files -*.icloud - -# End of https://www.toptal.com/developers/gitignore/api/macos diff --git a/constraints/CHANGELOG.md b/constraints/CHANGELOG.md deleted file mode 100644 index c2cce53a4..000000000 --- a/constraints/CHANGELOG.md +++ /dev/null @@ -1,7 +0,0 @@ -# webrtc-constraints changelog - -## Unreleased - -## v0.1.0 - -Initial release. 
diff --git a/constraints/Cargo.toml b/constraints/Cargo.toml deleted file mode 100644 index 0eed74597..000000000 --- a/constraints/Cargo.toml +++ /dev/null @@ -1,30 +0,0 @@ -[package] -name = "webrtc-constraints" -version = "0.1.0" -authors = ["Vincent Esche "] -edition = "2021" -description = "A pure Rust implementation of WebRTC Media Constraints API" -license = "MIT OR Apache-2.0" -documentation = "https://docs.rs/webrtc-constraints" -homepage = "https://webrtc.rs" -repository = "https://github.com/webrtc-rs/webrtc/tree/master/constraints" - -[dependencies] -indexmap = "2" -serde = { version = "1", features = ["derive"], optional = true } -ordered-float = { version = "4", default-features = false } -thiserror = "1" - -[dev-dependencies] -env_logger = "0.10" -lazy_static = "1" -serde_json = { version = "1", features = ["preserve_order"] } - -[features] -default = ["serde"] -serde = ["dep:serde", "indexmap/serde"] - -[[example]] -name = "json" -path = "examples/json.rs" -required-features = ["serde"] diff --git a/constraints/LICENSE-APACHE b/constraints/LICENSE-APACHE deleted file mode 100644 index b2e847a43..000000000 --- a/constraints/LICENSE-APACHE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - version 2.0, January 2004 - http://www.apache.org/licenses/ - -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - -1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. 
- - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - -2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - -3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - -4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - -5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - -6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - -8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - -END OF TERMS AND CONDITIONS - -APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - -Copyright [yyyy] [name of copyright owner] - -Licensed under the Apache License, version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. diff --git a/constraints/LICENSE-MIT b/constraints/LICENSE-MIT deleted file mode 100644 index 5a980079b..000000000 --- a/constraints/LICENSE-MIT +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2022 Vincent Esche - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/constraints/README.md b/constraints/README.md deleted file mode 100644 index 249f66e9f..000000000 --- a/constraints/README.md +++ /dev/null @@ -1,32 +0,0 @@ -

- WebRTC.rs
-
- [badges: License: MIT/Apache 2.0, Discord]
-
- A pure Rust implementation of the SelectSettings algorithm from the WebRTC/W3C "Media Capture and Streams" spec. - - (Last synced with the spec against commit 8cea879 from 2023/01/09.) -

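The crate being removed here ranks candidate settings by the spec's "fitness distance". For orientation, a minimal, self-contained sketch of the numeric-constraint scoring formula follows; it is illustrative only (the function name and the sample values are assumptions, not this crate's API), and the crate's own implementation appears in `src/algorithms/fitness_distance.rs` further down in this diff.

```rust
// Sketch of the W3C fitness-distance formula for positive numeric constraints:
//   (actual == ideal) ? 0 : |actual - ideal| / max(|actual|, |ideal|)
// A distance of 0.0 is a perfect fit; 1.0 is the worst possible fit.
fn relative_fitness_distance(actual: f64, ideal: f64) -> f64 {
    if (actual - ideal).abs() < f64::EPSILON {
        return 0.0;
    }
    let denominator = actual.abs().max(ideal.abs());
    if denominator < f64::EPSILON {
        // Both values are effectively zero; avoid dividing by zero.
        return 0.0;
    }
    (actual - ideal).abs() / denominator
}

fn main() {
    // A 1080-pixel-high candidate scored against an ideal height of 1440:
    assert_eq!(relative_fitness_distance(1080.0, 1440.0), 0.25);
    // An exact match is a perfect fit:
    assert_eq!(relative_fitness_distance(1440.0, 1440.0), 0.0);
}
```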
diff --git a/constraints/examples/json.rs b/constraints/examples/json.rs deleted file mode 100644 index 878cbb142..000000000 --- a/constraints/examples/json.rs +++ /dev/null @@ -1,80 +0,0 @@ -use std::iter::FromIterator; - -use webrtc_constraints::algorithms::{ - select_settings_candidates, ClosestToIdealPolicy, DeviceInformationExposureMode, - TieBreakingPolicy, -}; -use webrtc_constraints::property::all::name::*; -use webrtc_constraints::{ - MediaTrackConstraints, MediaTrackSettings, MediaTrackSupportedConstraints, -}; - -fn main() { - let supported_constraints = - MediaTrackSupportedConstraints::from_iter(vec![&DEVICE_ID, &HEIGHT, &WIDTH, &RESIZE_MODE]); - - // Deserialize possible settings from JSON: - let possible_settings: Vec = { - let json = serde_json::json!([ - { "deviceId": "480p", "width": 720, "height": 480, "resizeMode": "crop-and-scale" }, - { "deviceId": "720p", "width": 1280, "height": 720, "resizeMode": "crop-and-scale" }, - { "deviceId": "1080p", "width": 1920, "height": 1080, "resizeMode": "none" }, - { "deviceId": "1440p", "width": 2560, "height": 1440, "resizeMode": "none" }, - { "deviceId": "2160p", "width": 3840, "height": 2160, "resizeMode": "none" }, - ]); - serde_json::from_value(json).unwrap() - }; - - // Deserialize constraints from JSON: - let constraints: MediaTrackConstraints = { - let json = serde_json::json!({ - "width": { - "max": 2560, - }, - "height": { - "max": 1440, - }, - // Unsupported constraint, which should thus get ignored: - "frameRate": { - "exact": 30.0 - }, - // Ideal resize-mode: - "resizeMode": "none", - "advanced": [ - // The first advanced constraint set of "exact 800p" does not match - // any candidate and should thus get ignored by the algorithm: - { "height": 800 }, - // The second advanced constraint set of "no resizing" does match - // candidates and should thus be applied by the algorithm: - { "resizeMode": "none" }, - ] - }); - serde_json::from_value(json).unwrap() - }; - - // Resolve bare values to proper constraints: - let resolved_constraints = constraints.into_resolved(); - - // Sanitize constraints, removing empty and unsupported constraints: - let sanitized_constraints = resolved_constraints.into_sanitized(&supported_constraints); - - let candidates = select_settings_candidates( - &possible_settings, - &sanitized_constraints, - DeviceInformationExposureMode::Protected, - ) - .unwrap(); - - // Specify a tie-breaking policy - // - // A couple of basic policies are provided batteries-included, - // but for more sophisticated needs you can implement your own `TieBreakingPolicy`: - let tie_breaking_policy = - ClosestToIdealPolicy::new(possible_settings[2].clone(), &supported_constraints); - - let actual = tie_breaking_policy.select_candidate(candidates); - - let expected = &possible_settings[2]; - - assert_eq!(actual, expected); -} diff --git a/constraints/examples/macros.rs b/constraints/examples/macros.rs deleted file mode 100644 index 0cf9def33..000000000 --- a/constraints/examples/macros.rs +++ /dev/null @@ -1,104 +0,0 @@ -use std::iter::FromIterator; - -use webrtc_constraints::algorithms::{ - select_settings_candidates, ClosestToIdealPolicy, DeviceInformationExposureMode, - TieBreakingPolicy, -}; -use webrtc_constraints::macros::*; -use webrtc_constraints::property::all::name::*; -use webrtc_constraints::{settings, MediaTrackSupportedConstraints, ResizeMode}; - -fn main() { - let supported_constraints = - MediaTrackSupportedConstraints::from_iter(vec![&DEVICE_ID, &HEIGHT, &WIDTH, &RESIZE_MODE]); - - let possible_settings = 
vec![ - settings![ - &DEVICE_ID => "480p", - &HEIGHT => 480, - &WIDTH => 720, - &RESIZE_MODE => ResizeMode::crop_and_scale(), - ], - settings![ - &DEVICE_ID => "720p", - &HEIGHT => 720, - &WIDTH => 1280, - &RESIZE_MODE => ResizeMode::crop_and_scale(), - ], - settings![ - &DEVICE_ID => "1080p", - &HEIGHT => 1080, - &WIDTH => 1920, - &RESIZE_MODE => ResizeMode::none(), - ], - settings![ - &DEVICE_ID => "1440p", - &HEIGHT => 1440, - &WIDTH => 2560, - &RESIZE_MODE => ResizeMode::none(), - ], - settings![ - &DEVICE_ID => "2160p", - &HEIGHT => 2160, - &WIDTH => 3840, - &RESIZE_MODE => ResizeMode::none(), - ], - ]; - - let constraints = constraints! { - mandatory: { - &WIDTH => value_range_constraint!{ - max: 2560 - }, - &HEIGHT => value_range_constraint!{ - max: 1440 - }, - // Unsupported constraint, which should thus get ignored: - &FRAME_RATE => value_range_constraint!{ - exact: 30.0 - }, - }, - advanced: [ - // The first advanced constraint set of "exact 800p" does not match - // any candidate and should thus get ignored by the algorithm: - { - &HEIGHT => value_range_constraint!{ - exact: 800 - } - }, - // The second advanced constraint set of "no resizing" does match - // candidates and should thus be applied by the algorithm: - { - &RESIZE_MODE => value_constraint!{ - exact: ResizeMode::none() - } - }, - ] - }; - - // Resolve bare values to proper constraints: - let resolved_constraints = constraints.into_resolved(); - - // Sanitize constraints, removing empty and unsupported constraints: - let sanitized_constraints = resolved_constraints.to_sanitized(&supported_constraints); - - let candidates = select_settings_candidates( - &possible_settings, - &sanitized_constraints, - DeviceInformationExposureMode::Exposed, - ) - .unwrap(); - - // Specify a tie-breaking policy - // - // A couple of basic policies are provided batteries-included, - // but for more sophisticated needs you can implement your own `TieBreakingPolicy`: - let tie_breaking_policy = - ClosestToIdealPolicy::new(possible_settings[2].clone(), &supported_constraints); - - let actual = tie_breaking_policy.select_candidate(candidates); - - let expected = &possible_settings[2]; - - assert_eq!(actual, expected); -} diff --git a/constraints/examples/native.rs b/constraints/examples/native.rs deleted file mode 100644 index 6324ade18..000000000 --- a/constraints/examples/native.rs +++ /dev/null @@ -1,119 +0,0 @@ -use std::iter::FromIterator; - -use webrtc_constraints::algorithms::{ - select_settings_candidates, ClosestToIdealPolicy, DeviceInformationExposureMode, - TieBreakingPolicy, -}; -use webrtc_constraints::property::all::name::*; -use webrtc_constraints::{ - AdvancedMediaTrackConstraints, MandatoryMediaTrackConstraints, MediaTrackConstraintSet, - MediaTrackConstraints, MediaTrackSettings, MediaTrackSupportedConstraints, ResizeMode, - ResolvedValueConstraint, ResolvedValueRangeConstraint, ValueConstraint, ValueRangeConstraint, -}; - -fn main() { - let supported_constraints = - MediaTrackSupportedConstraints::from_iter(vec![&DEVICE_ID, &HEIGHT, &WIDTH, &RESIZE_MODE]); - - let possible_settings = vec![ - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "480p".into()), - (&HEIGHT, 480.into()), - (&WIDTH, 720.into()), - (&RESIZE_MODE, ResizeMode::crop_and_scale().into()), - ]), - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "720p".into()), - (&HEIGHT, 720.into()), - (&WIDTH, 1280.into()), - (&RESIZE_MODE, ResizeMode::crop_and_scale().into()), - ]), - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "1080p".into()), - (&HEIGHT, 
1080.into()), - (&WIDTH, 1920.into()), - (&RESIZE_MODE, ResizeMode::none().into()), - ]), - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "1440p".into()), - (&HEIGHT, 1440.into()), - (&WIDTH, 2560.into()), - (&RESIZE_MODE, ResizeMode::none().into()), - ]), - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "2160p".into()), - (&HEIGHT, 2160.into()), - (&WIDTH, 3840.into()), - (&RESIZE_MODE, ResizeMode::none().into()), - ]), - ]; - - let constraints = MediaTrackConstraints { - mandatory: MandatoryMediaTrackConstraints::from_iter([ - ( - &WIDTH, - ValueRangeConstraint::Constraint(ResolvedValueRangeConstraint::default().max(2560)) - .into(), - ), - ( - &HEIGHT, - ValueRangeConstraint::Constraint(ResolvedValueRangeConstraint::default().max(1440)) - .into(), - ), - // Unsupported constraint, which should thus get ignored: - ( - &FRAME_RATE, - ValueRangeConstraint::Constraint( - ResolvedValueRangeConstraint::default().exact(30.0), - ) - .into(), - ), - ]), - advanced: AdvancedMediaTrackConstraints::from_iter([ - // The first advanced constraint set of "exact 800p" does not match - // any candidate and should thus get ignored by the algorithm: - MediaTrackConstraintSet::from_iter([( - &HEIGHT, - ValueRangeConstraint::Constraint( - ResolvedValueRangeConstraint::default().exact(800), - ) - .into(), - )]), - // The second advanced constraint set of "no resizing" does match - // candidates and should thus be applied by the algorithm: - MediaTrackConstraintSet::from_iter([( - &RESIZE_MODE, - ValueConstraint::Constraint( - ResolvedValueConstraint::default().exact(ResizeMode::none()), - ) - .into(), - )]), - ]), - }; - - // Resolve bare values to proper constraints: - let resolved_constraints = constraints.into_resolved(); - - // Sanitize constraints, removing empty and unsupported constraints: - let sanitized_constraints = resolved_constraints.to_sanitized(&supported_constraints); - - let candidates = select_settings_candidates( - &possible_settings, - &sanitized_constraints, - DeviceInformationExposureMode::Exposed, - ) - .unwrap(); - - // Specify a tie-breaking policy - // - // A couple of basic policies are provided batteries-included, - // but for more sophisticated needs you can implement your own `TieBreakingPolicy`: - let tie_breaking_policy = - ClosestToIdealPolicy::new(possible_settings[2].clone(), &supported_constraints); - - let actual = tie_breaking_policy.select_candidate(candidates); - - let expected = &possible_settings[2]; - - assert_eq!(actual, expected); -} diff --git a/constraints/src/algorithms.rs b/constraints/src/algorithms.rs deleted file mode 100644 index 46d1f4043..000000000 --- a/constraints/src/algorithms.rs +++ /dev/null @@ -1,9 +0,0 @@ -//! Algorithms as defined in the ["Media Capture and Streams"][mediacapture_streams] spec. -//! -//! [mediacapture_streams]: https://www.w3.org/TR/mediacapture-streams/ - -mod fitness_distance; -mod select_settings; - -pub use self::fitness_distance::*; -pub use self::select_settings::*; diff --git a/constraints/src/algorithms/fitness_distance.rs b/constraints/src/algorithms/fitness_distance.rs deleted file mode 100644 index 973fb2e8a..000000000 --- a/constraints/src/algorithms/fitness_distance.rs +++ /dev/null @@ -1,132 +0,0 @@ -/// The function used to compute the "fitness distance" of a [setting][media_track_settings] value of a [`MediaStreamTrack`][media_stream_track] object. 
-/// -/// # W3C Spec Compliance -/// -/// The trait corresponds to the ["fitness distance"][fitness_distance] function in the W3C ["Media Capture and Streams"][media_capture_and_streams_spec] spec. -/// -/// [media_stream_track]: https://www.w3.org/TR/mediacapture-streams/#dom-mediastreamtrack -/// [media_track_settings]: https://www.w3.org/TR/mediacapture-streams/#dom-mediatracksettings -/// [fitness_distance]: https://www.w3.org/TR/mediacapture-streams/#dfn-fitness-distance -/// [media_capture_and_streams_spec]: https://www.w3.org/TR/mediacapture-streams -pub trait FitnessDistance { - /// The type returned in the event of a computation error. - type Error; - - /// Computes the fitness distance of the given `subject` in the range of `0.0..=1.0`. - /// - /// A distance of `0.0` denotes it maximally fit, one of `1.0` as maximally unfit. - fn fitness_distance(&self, subject: Subject) -> Result; -} - -mod empty_constraint; -mod setting; -mod settings; -mod value_constraint; -mod value_range_constraint; -mod value_sequence_constraint; - -use std::cmp::Ordering; - -pub use self::setting::{SettingFitnessDistanceError, SettingFitnessDistanceErrorKind}; -pub use self::settings::SettingsFitnessDistanceError; - -fn nearly_cmp(lhs: f64, rhs: f64) -> Ordering { - // Based on: https://stackoverflow.com/a/32334103/227536 - - let epsilon: f64 = 128.0 * f64::EPSILON; - let abs_th: f64 = f64::MIN; - - debug_assert!(epsilon < 1.0); - - if lhs == rhs { - return Ordering::Equal; - } - - let diff = (lhs - rhs).abs(); - let norm = (lhs.abs() + rhs.abs()).min(f64::MAX); - - if diff < (epsilon * norm).max(abs_th) { - Ordering::Equal - } else if lhs < rhs { - Ordering::Less - } else { - Ordering::Greater - } -} - -fn is_nearly_greater_than_or_equal_to(actual: f64, min: f64) -> bool { - nearly_cmp(actual, min) != Ordering::Less -} - -fn is_nearly_less_than_or_equal_to(actual: f64, max: f64) -> bool { - nearly_cmp(actual, max) != Ordering::Greater -} - -fn is_nearly_equal_to(actual: f64, exact: f64) -> bool { - nearly_cmp(actual, exact) == Ordering::Equal -} - -fn relative_fitness_distance(actual: f64, ideal: f64) -> f64 { - // As specified in step 7 of the `fitness distance` algorithm: - // - // - // > For all positive numeric constraints [โ€ฆ], - // > the fitness distance is the result of the formula - // > - // > ``` - // > (actual == ideal) ? 
0 : |actual - ideal| / max(|actual|, |ideal|) - // > ``` - if (actual - ideal).abs() < f64::EPSILON { - 0.0 - } else { - let numerator = (actual - ideal).abs(); - let denominator = actual.abs().max(ideal.abs()); - if denominator.abs() < f64::EPSILON { - // Avoid division by zero crashes: - 0.0 - } else { - numerator / denominator - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - mod relative_fitness_distance { - #[test] - fn zero_distance() { - // Make sure we're not dividing by zero: - assert_eq!(super::relative_fitness_distance(0.0, 0.0), 0.0); - - assert_eq!(super::relative_fitness_distance(0.5, 0.5), 0.0); - assert_eq!(super::relative_fitness_distance(1.0, 1.0), 0.0); - assert_eq!(super::relative_fitness_distance(2.0, 2.0), 0.0); - } - - #[test] - fn fract_distance() { - assert_eq!(super::relative_fitness_distance(1.0, 2.0), 0.5); - assert_eq!(super::relative_fitness_distance(2.0, 1.0), 0.5); - - assert_eq!(super::relative_fitness_distance(0.5, 1.0), 0.5); - assert_eq!(super::relative_fitness_distance(1.0, 0.5), 0.5); - - assert_eq!(super::relative_fitness_distance(0.25, 0.5), 0.5); - assert_eq!(super::relative_fitness_distance(0.5, 0.25), 0.5); - } - - #[test] - fn one_distance() { - assert_eq!(super::relative_fitness_distance(0.0, 0.5), 1.0); - assert_eq!(super::relative_fitness_distance(0.5, 0.0), 1.0); - - assert_eq!(super::relative_fitness_distance(0.0, 1.0), 1.0); - assert_eq!(super::relative_fitness_distance(1.0, 0.0), 1.0); - - assert_eq!(super::relative_fitness_distance(0.0, 2.0), 1.0); - assert_eq!(super::relative_fitness_distance(2.0, 0.0), 1.0); - } - } -} diff --git a/constraints/src/algorithms/fitness_distance/empty_constraint.rs b/constraints/src/algorithms/fitness_distance/empty_constraint.rs deleted file mode 100644 index f27329a9e..000000000 --- a/constraints/src/algorithms/fitness_distance/empty_constraint.rs +++ /dev/null @@ -1,71 +0,0 @@ -use super::setting::SettingFitnessDistanceError; -use super::FitnessDistance; -use crate::constraint::EmptyConstraint; - -impl<'a, T> FitnessDistance> for EmptyConstraint { - type Error = SettingFitnessDistanceError; - - fn fitness_distance(&self, _setting: Option<&'a T>) -> Result { - // As specified in step 1 of the `SelectSettings` algorithm: - // - // - // > If an empty list has been given as the value for a constraint, - // > it MUST be interpreted as if the constraint were not specified - // > (in other words, an empty constraint == no constraint). - Ok(0.0) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - type Constraint = EmptyConstraint; - - macro_rules! test_empty_constraint { - ( - settings: $t:ty => $s:expr, - expected: $e:expr $(,)? 
- ) => { - let settings: &[Option<$t>] = $s; - let constraint = &Constraint {}; - for setting in settings { - let actual = constraint.fitness_distance(setting.as_ref()); - - assert_eq!(actual, $e); - } - }; - } - - #[test] - fn bool() { - test_empty_constraint!( - settings: bool => &[None, Some(false)], - expected: Ok(0.0) - ); - } - - #[test] - fn string() { - test_empty_constraint!( - settings: String => &[None, Some("foo".to_owned())], - expected: Ok(0.0) - ); - } - - #[test] - fn i64() { - test_empty_constraint!( - settings: i64 => &[None, Some(42)], - expected: Ok(0.0) - ); - } - - #[test] - fn f64() { - test_empty_constraint!( - settings: f64 => &[None, Some(42.0)], - expected: Ok(0.0) - ); - } -} diff --git a/constraints/src/algorithms/fitness_distance/setting.rs b/constraints/src/algorithms/fitness_distance/setting.rs deleted file mode 100644 index 440a65b1c..000000000 --- a/constraints/src/algorithms/fitness_distance/setting.rs +++ /dev/null @@ -1,532 +0,0 @@ -use super::FitnessDistance; -use crate::{MediaTrackSetting, ResolvedMediaTrackConstraint}; - -/// An error indicating a rejected fitness distance computation, -/// likely caused by a mismatched yet required constraint. -#[derive(Debug, Clone, Eq, PartialEq, Hash)] -pub struct SettingFitnessDistanceError { - /// The kind of the error (e.g. missing value, mismatching value, โ€ฆ). - pub kind: SettingFitnessDistanceErrorKind, - /// The required constraint value. - pub constraint: String, - /// The offending setting value. - pub setting: Option, -} - -/// The kind of the error (e.g. missing value, mismatching value, โ€ฆ). -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] -pub enum SettingFitnessDistanceErrorKind { - /// Settings value is missing. - Missing, - /// Settings value is a mismatch. - Mismatch, - /// Settings value is too small. - TooSmall, - /// Settings value is too large. 
- TooLarge, -} - -impl<'a> FitnessDistance> for ResolvedMediaTrackConstraint { - type Error = SettingFitnessDistanceError; - - fn fitness_distance(&self, setting: Option<&'a MediaTrackSetting>) -> Result { - type Setting = MediaTrackSetting; - type Constraint = ResolvedMediaTrackConstraint; - - let setting = match setting { - Some(setting) => setting, - None => { - return if self.is_required() { - Err(Self::Error { - kind: SettingFitnessDistanceErrorKind::Missing, - constraint: format!("{}", self.to_required_only()), - setting: None, - }) - } else { - Ok(1.0) - } - } - }; - - let result = match (self, setting) { - // Empty constraint: - (ResolvedMediaTrackConstraint::Empty(constraint), setting) => { - constraint.fitness_distance(Some(setting)) - } - - // Boolean constraint: - (Constraint::Bool(constraint), Setting::Bool(setting)) => { - constraint.fitness_distance(Some(setting)) - } - (Constraint::Bool(constraint), Setting::Integer(setting)) => { - constraint.fitness_distance(Some(setting)) - } - (Constraint::Bool(constraint), Setting::Float(setting)) => { - constraint.fitness_distance(Some(setting)) - } - (Constraint::Bool(constraint), Setting::String(setting)) => { - constraint.fitness_distance(Some(setting)) - } - - // Integer constraint: - (Constraint::IntegerRange(_constraint), Setting::Bool(_setting)) => Ok(0.0), - (Constraint::IntegerRange(constraint), Setting::Integer(setting)) => { - constraint.fitness_distance(Some(setting)) - } - (Constraint::IntegerRange(constraint), Setting::Float(setting)) => { - constraint.fitness_distance(Some(setting)) - } - (Constraint::IntegerRange(_constraint), Setting::String(_setting)) => Ok(0.0), - - // Float constraint: - (Constraint::FloatRange(_constraint), Setting::Bool(_setting)) => Ok(0.0), - (Constraint::FloatRange(constraint), Setting::Integer(setting)) => { - constraint.fitness_distance(Some(setting)) - } - (Constraint::FloatRange(constraint), Setting::Float(setting)) => { - constraint.fitness_distance(Some(setting)) - } - (Constraint::FloatRange(_constraint), Setting::String(_setting)) => Ok(0.0), - - // String constraint: - (Constraint::String(_constraint), Setting::Bool(_setting)) => Ok(0.0), - (Constraint::String(_constraint), Setting::Integer(_setting)) => Ok(0.0), - (Constraint::String(_constraint), Setting::Float(_setting)) => Ok(0.0), - (Constraint::String(constraint), Setting::String(setting)) => { - constraint.fitness_distance(Some(setting)) - } - - // String sequence constraint: - (Constraint::StringSequence(_constraint), Setting::Bool(_setting)) => Ok(0.0), - (Constraint::StringSequence(_constraint), Setting::Integer(_setting)) => Ok(0.0), - (Constraint::StringSequence(_constraint), Setting::Float(_setting)) => Ok(0.0), - (Constraint::StringSequence(constraint), Setting::String(setting)) => { - constraint.fitness_distance(Some(setting)) - } - }; - - #[cfg(debug_assertions)] - if let Ok(fitness_distance) = result { - debug_assert!({ fitness_distance.is_finite() }); - } - - result - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::constraint::EmptyConstraint; - use crate::{MediaTrackSetting, ResolvedMediaTrackConstraint}; - - #[test] - fn empty_constraint() { - // As per step 1 of the `SelectSettings` algorithm from the W3C spec: - // - // - // > Each constraint specifies one or more values (or a range of values) for its property. - // > A property MAY appear more than once in the list of 'advanced' ConstraintSets. 
- // > If an empty list has been given as the value for a constraint, - // > it MUST be interpreted as if the constraint were not specified - // > (in other words, an empty constraint == no constraint). - let constraint = ResolvedMediaTrackConstraint::Empty(EmptyConstraint {}); - - let settings = [ - MediaTrackSetting::Bool(true), - MediaTrackSetting::Integer(42), - MediaTrackSetting::Float(4.2), - MediaTrackSetting::String("string".to_owned()), - ]; - - let expected = 0.0; - - for setting in settings { - let actual = constraint.fitness_distance(Some(&setting)).unwrap(); - - assert_eq!(actual, expected); - } - } - - mod bool_constraint { - use super::*; - use crate::ResolvedValueConstraint; - - #[test] - fn bool_setting() { - // As per step 8 of the `fitness distance` function from the W3C spec: - // - // - // > For all string, enum and boolean constraints - // > (e.g. deviceId, groupId, facingMode, resizeMode, echoCancellation), - // > the fitness distance is the result of the formula: - // > - // > ``` - // > (actual == ideal) ? 0 : 1 - // > ``` - - let scenarios = [(false, false), (false, true), (true, false), (true, true)]; - - for (constraint_value, setting_value) in scenarios { - let constraint = ResolvedMediaTrackConstraint::Bool(ResolvedValueConstraint { - exact: None, - ideal: Some(constraint_value), - }); - - let setting = MediaTrackSetting::Bool(setting_value); - - let actual = constraint.fitness_distance(Some(&setting)).unwrap(); - - let expected = if constraint_value == setting_value { - 0.0 - } else { - 1.0 - }; - - assert_eq!(actual, expected); - } - } - - #[test] - fn non_bool_settings() { - // As per step 4 of the `fitness distance` function from the W3C spec: - // - // - // > If constraintValue is a boolean, but the constrainable property is not, - // > then the fitness distance is based on whether the settings dictionary's - // > constraintName member exists or not, from the formula: - // > - // > ``` - // > (constraintValue == exists) ? 0 : 1 - // > ``` - - let settings = [ - MediaTrackSetting::Integer(42), - MediaTrackSetting::Float(4.2), - MediaTrackSetting::String("string".to_owned()), - ]; - - let scenarios = [(false, false), (false, true), (true, false), (true, true)]; - - for (constraint_value, setting_value) in scenarios { - let constraint = ResolvedMediaTrackConstraint::Bool(ResolvedValueConstraint { - exact: None, - ideal: Some(constraint_value), - }); - - for setting in settings.iter() { - // TODO: Replace `if { Some(_) } else { None }` with `.then_some(_)` - // once MSRV has passed 1.62.0: - let setting = if setting_value { Some(setting) } else { None }; - let actual = constraint.fitness_distance(setting).unwrap(); - - let expected = if setting_value { 0.0 } else { 1.0 }; - - assert_eq!(actual, expected); - } - } - } - } - - mod numeric_constraint { - use super::*; - use crate::ResolvedValueRangeConstraint; - - #[test] - fn missing_settings() { - // As per step 5 of the `fitness distance` function from the W3C spec: - // - // - // > If the settings dictionary's constraintName member does not exist, - // > the fitness distance is 1. 
- - let constraints = [ - ResolvedMediaTrackConstraint::IntegerRange(ResolvedValueRangeConstraint { - exact: None, - ideal: Some(42), - min: None, - max: None, - }), - ResolvedMediaTrackConstraint::FloatRange(ResolvedValueRangeConstraint { - exact: None, - ideal: Some(42.0), - min: None, - max: None, - }), - ]; - - for constraint in constraints { - let actual = constraint.fitness_distance(None).unwrap(); - - let expected = 1.0; - - assert_eq!(actual, expected); - } - } - - #[test] - fn compatible_settings() { - // As per step 7 of the `fitness distance` function from the W3C spec: - // - // - // > For all positive numeric constraints - // > (such as height, width, frameRate, aspectRatio, sampleRate and sampleSize), - // > the fitness distance is the result of the formula - // > - // > ``` - // > (actual == ideal) ? 0 : |actual - ideal| / max(|actual|, |ideal|) - // > ``` - - let settings = [ - MediaTrackSetting::Integer(21), - MediaTrackSetting::Float(21.0), - ]; - - let constraints = [ - ResolvedMediaTrackConstraint::IntegerRange(ResolvedValueRangeConstraint { - exact: None, - ideal: Some(42), - min: None, - max: None, - }), - ResolvedMediaTrackConstraint::FloatRange(ResolvedValueRangeConstraint { - exact: None, - ideal: Some(42.0), - min: None, - max: None, - }), - ]; - - for constraint in constraints { - for setting in settings.iter() { - let actual = constraint.fitness_distance(Some(setting)).unwrap(); - - let expected = 0.5; - - assert_eq!(actual, expected); - } - } - } - - #[test] - fn incompatible_settings() { - // As per step 3 of the `fitness distance` function from the W3C spec: - // - // - // > If the constraint does not apply for this type of object, the fitness distance is 0 - // > (that is, the constraint does not influence the fitness distance). - - let settings = [ - MediaTrackSetting::Bool(true), - MediaTrackSetting::String("string".to_owned()), - ]; - - let constraints = [ - ResolvedMediaTrackConstraint::IntegerRange(ResolvedValueRangeConstraint { - exact: None, - ideal: Some(42), - min: None, - max: None, - }), - ResolvedMediaTrackConstraint::FloatRange(ResolvedValueRangeConstraint { - exact: None, - ideal: Some(42.0), - min: None, - max: None, - }), - ]; - - for constraint in constraints { - for setting in settings.iter() { - let actual = constraint.fitness_distance(Some(setting)).unwrap(); - - let expected = 0.0; - - println!("constraint: {constraint:?}"); - println!("setting: {setting:?}"); - println!("actual: {actual:?}"); - println!("expected: {expected:?}"); - - assert_eq!(actual, expected); - } - } - } - } - - mod string_constraint { - use super::*; - use crate::ResolvedValueConstraint; - - #[test] - fn missing_settings() { - // As per step 5 of the `fitness distance` function from the W3C spec: - // - // - // > If the settings dictionary's constraintName member does not exist, - // > the fitness distance is 1. - - let constraint = ResolvedMediaTrackConstraint::String(ResolvedValueConstraint { - exact: None, - ideal: Some("constraint".to_owned()), - }); - - let actual = constraint.fitness_distance(None).unwrap(); - - let expected = 1.0; - - assert_eq!(actual, expected); - } - - #[test] - fn compatible_settings() { - // As per step 8 of the `fitness distance` function from the W3C spec: - // - // - // > For all string, enum and boolean constraints - // > (e.g. deviceId, groupId, facingMode, resizeMode, echoCancellation), - // > the fitness distance is the result of the formula: - // > - // > ``` - // > (actual == ideal) ? 
0 : 1 - // > ``` - - let constraint = ResolvedMediaTrackConstraint::String(ResolvedValueConstraint { - exact: None, - ideal: Some("constraint".to_owned()), - }); - - let settings = [MediaTrackSetting::String("setting".to_owned())]; - - for setting in settings { - let actual = constraint.fitness_distance(Some(&setting)).unwrap(); - - let expected = 1.0; - - assert_eq!(actual, expected); - } - } - - #[test] - fn incompatible_settings() { - // As per step 3 of the `fitness distance` function from the W3C spec: - // - // - // > If the constraint does not apply for this type of object, the fitness distance is 0 - // > (that is, the constraint does not influence the fitness distance). - - let constraint = ResolvedMediaTrackConstraint::String(ResolvedValueConstraint { - exact: None, - ideal: Some("string".to_owned()), - }); - - let settings = [ - MediaTrackSetting::Bool(true), - MediaTrackSetting::Integer(42), - MediaTrackSetting::Float(4.2), - ]; - - for setting in settings { - let actual = constraint.fitness_distance(Some(&setting)).unwrap(); - - let expected = 0.0; - - println!("constraint: {constraint:?}"); - println!("setting: {setting:?}"); - println!("actual: {actual:?}"); - println!("expected: {expected:?}"); - - assert_eq!(actual, expected); - } - } - } - - mod string_sequence_constraint { - use super::*; - use crate::ResolvedValueSequenceConstraint; - - #[test] - fn missing_settings() { - // As per step 5 of the `fitness distance` function from the W3C spec: - // - // - // > If the settings dictionary's constraintName member does not exist, - // > the fitness distance is 1. - - let constraint = - ResolvedMediaTrackConstraint::StringSequence(ResolvedValueSequenceConstraint { - exact: None, - ideal: Some(vec!["constraint".to_owned()]), - }); - - let actual = constraint.fitness_distance(None).unwrap(); - - let expected = 1.0; - - assert_eq!(actual, expected); - } - - #[test] - fn compatible_settings() { - // As per step 8 of the `fitness distance` function from the W3C spec: - // - // - // > For all string, enum and boolean constraints - // > (e.g. deviceId, groupId, facingMode, resizeMode, echoCancellation), - // > the fitness distance is the result of the formula: - // > - // > ``` - // > (actual == ideal) ? 0 : 1 - // > ``` - // - // As well as the preliminary definition: - // - // > For string valued constraints, we define "==" below to be true if one of the - // > values in the sequence is exactly the same as the value being compared against. - - let constraint = - ResolvedMediaTrackConstraint::StringSequence(ResolvedValueSequenceConstraint { - exact: None, - ideal: Some(vec!["constraint".to_owned()]), - }); - - let settings = [MediaTrackSetting::String("setting".to_owned())]; - - for setting in settings { - let actual = constraint.fitness_distance(Some(&setting)).unwrap(); - - let expected = 1.0; - - assert_eq!(actual, expected); - } - } - - #[test] - fn incompatible_settings() { - // As per step 3 of the `fitness distance` function from the W3C spec: - // - // - // > If the constraint does not apply for this type of object, the fitness distance is 0 - // > (that is, the constraint does not influence the fitness distance). 
- - let constraint = - ResolvedMediaTrackConstraint::StringSequence(ResolvedValueSequenceConstraint { - exact: None, - ideal: Some(vec!["constraint".to_owned()]), - }); - - let settings = [ - MediaTrackSetting::Bool(true), - MediaTrackSetting::Integer(42), - MediaTrackSetting::Float(4.2), - ]; - - for setting in settings { - let actual = constraint.fitness_distance(Some(&setting)).unwrap(); - - let expected = 0.0; - - assert_eq!(actual, expected); - } - } - } -} diff --git a/constraints/src/algorithms/fitness_distance/settings.rs b/constraints/src/algorithms/fitness_distance/settings.rs deleted file mode 100644 index d37f7c3ec..000000000 --- a/constraints/src/algorithms/fitness_distance/settings.rs +++ /dev/null @@ -1,47 +0,0 @@ -use std::collections::HashMap; - -use super::setting::SettingFitnessDistanceError; -use super::FitnessDistance; -use crate::{MediaTrackProperty, MediaTrackSettings, SanitizedMediaTrackConstraintSet}; - -/// A list of media track properties and their corresponding fitness distance errors. -#[derive(Debug, Clone, Eq, PartialEq)] -pub struct SettingsFitnessDistanceError { - /// Setting errors per media track property. - pub setting_errors: HashMap, -} - -impl<'a> FitnessDistance<&'a MediaTrackSettings> for SanitizedMediaTrackConstraintSet { - type Error = SettingsFitnessDistanceError; - - fn fitness_distance(&self, settings: &'a MediaTrackSettings) -> Result { - let results: HashMap = self - .iter() - .map(|(property, constraint)| { - let setting = settings.get(property); - let result = constraint.fitness_distance(setting); - (property.clone(), result) - }) - .collect(); - - let mut total_fitness_distance = 0.0; - - let mut setting_errors: HashMap = - Default::default(); - - for (property, result) in results.into_iter() { - match result { - Ok(fitness_distance) => total_fitness_distance += fitness_distance, - Err(error) => { - setting_errors.insert(property, error); - } - } - } - - if setting_errors.is_empty() { - Ok(total_fitness_distance) - } else { - Err(SettingsFitnessDistanceError { setting_errors }) - } - } -} diff --git a/constraints/src/algorithms/fitness_distance/value_constraint.rs b/constraints/src/algorithms/fitness_distance/value_constraint.rs deleted file mode 100644 index 3ed6755d9..000000000 --- a/constraints/src/algorithms/fitness_distance/value_constraint.rs +++ /dev/null @@ -1,218 +0,0 @@ -use super::setting::SettingFitnessDistanceError; -use super::{FitnessDistance, SettingFitnessDistanceErrorKind}; -use crate::constraint::ResolvedValueConstraint; - -// Standard implementation for value constraints of arbitrary `Setting` and `Constraint` -// types where `Setting: PartialEq`: -macro_rules! impl_non_numeric_value_constraint { - (setting: $s:ty, constraint: $c:ty) => { - impl<'a> FitnessDistance> for ResolvedValueConstraint<$c> - where - $s: PartialEq<$c>, - { - type Error = SettingFitnessDistanceError; - - fn fitness_distance(&self, setting: Option<&'a $s>) -> Result { - if let Some(exact) = self.exact.as_ref() { - // As specified in step 2 of the `fitness distance` algorithm: - // - // - // > If the constraint is required (constraintValue either contains - // > one or more members named [โ€ฆ] 'exact' [โ€ฆ]), and the settings - // > dictionary's constraintName member's value does not satisfy the - // > constraint or doesn't exist, the fitness distance is positive infinity. 
- match setting { - Some(actual) if actual == exact => {} - Some(setting) => { - return Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Mismatch, - constraint: format!("{}", self.to_required_only()), - setting: Some(format!("{:?}", setting)), - }) - } - None => { - return Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Missing, - constraint: format!("{}", self.to_required_only()), - setting: None, - }) - } - }; - } - - if let Some(ideal) = self.ideal.as_ref() { - match setting { - Some(actual) if actual == ideal => { - // As specified in step 8 of the `fitness distance` algorithm: - // - // - // > For all string, enum and boolean constraints [โ€ฆ], - // > the fitness distance is the result of the formula: - // > - // > ``` - // > (actual == ideal) ? 0 : 1 - // > ``` - Ok(0.0) - } - _ => { - // As specified in step 5 of the `fitness distance` algorithm: - // - // - // > If the settings dictionary's `constraintName` member - // > does not exist, the fitness distance is 1. - Ok(1.0) - } - } - } else { - // As specified in step 6 of the `fitness distance` algorithm: - // - // - // > If no ideal value is specified (constraintValue either - // > contains no member named 'ideal', or, if bare values are to be - // > treated as 'ideal', isn't a bare value), the fitness distance is 0. - Ok(0.0) - } - } - } - }; -} - -impl_non_numeric_value_constraint!(setting: bool, constraint: bool); -impl_non_numeric_value_constraint!(setting: String, constraint: String); - -// Specialized implementations for floating-point value constraints (and settings): - -macro_rules! impl_numeric_value_constraint { - (setting: $s:ty, constraint: $c:ty) => { - impl<'a> FitnessDistance> for ResolvedValueConstraint<$c> { - type Error = SettingFitnessDistanceError; - - fn fitness_distance(&self, setting: Option<&'a $s>) -> Result { - if let Some(exact) = self.exact { - // As specified in step 2 of the `fitness distance` algorithm: - // - // - // > If the constraint is required (constraintValue either contains - // > one or more members named [โ€ฆ] 'exact' [โ€ฆ]), and the settings - // > dictionary's constraintName member's value does not satisfy the - // > constraint or doesn't exist, the fitness distance is positive infinity. - match setting { - Some(&actual) if super::is_nearly_equal_to(actual as f64, exact as f64) => { - } - Some(setting) => { - return Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Mismatch, - constraint: format!("{}", self.to_required_only()), - setting: Some(format!("{:?}", setting)), - }) - } - None => { - return Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Missing, - constraint: format!("{}", self.to_required_only()), - setting: None, - }) - } - }; - } - - if let Some(ideal) = self.ideal { - match setting { - Some(&actual) => { - let actual: f64 = actual as f64; - let ideal: f64 = ideal as f64; - // As specified in step 7 of the `fitness distance` algorithm: - // - // - // > For all positive numeric constraints [โ€ฆ], - // > the fitness distance is the result of the formula - // > - // > ``` - // > (actual == ideal) ? 0 : |actual - ideal| / max(|actual|, |ideal|) - // > ``` - Ok(super::relative_fitness_distance(actual, ideal)) - } - None => { - // As specified in step 5 of the `fitness distance` algorithm: - // - // - // > If the settings dictionary's `constraintName` member - // > does not exist, the fitness distance is 1. 
- Ok(1.0) - } - } - } else { - // As specified in step 6 of the `fitness distance` algorithm: - // - // - // > If no ideal value is specified (constraintValue either - // > contains no member named 'ideal', or, if bare values are to be - // > treated as 'ideal', isn't a bare value), the fitness distance is 0. - Ok(0.0) - } - } - } - }; -} - -impl_numeric_value_constraint!(setting: f64, constraint: f64); -impl_numeric_value_constraint!(setting: i64, constraint: u64); -impl_numeric_value_constraint!(setting: i64, constraint: f64); -impl_numeric_value_constraint!(setting: f64, constraint: u64); - -// Specialized implementations for boolean value constraints of mismatching -// and thus either "existence"-checked or ignored setting types: -macro_rules! impl_exists_value_constraint { - (settings: [$($s:ty),+], constraint: bool) => { - $(impl_exists_value_constraint!(setting: $s, constraint: bool);)+ - }; - (setting: $s:ty, constraint: bool) => { - impl<'a> FitnessDistance> for ResolvedValueConstraint { - type Error = SettingFitnessDistanceError; - - fn fitness_distance(&self, setting: Option<&'a $s>) -> Result { - // A bare boolean value (as described in step 4 of the - // `fitness distance` algorithm) gets parsed as: - // ``` - // ResolvedValueConstraint:: { - // exact: Some(bare), - // ideal: None, - // } - // ``` - // - // For all other configurations we just interpret it as an incompatible constraint. - match self.exact { - // As specified in step 4 of the `fitness distance` algorithm: - // - // - // > If constraintValue is a boolean, but the constrainable property is not, - // > then the fitness distance is based on whether the settings dictionary's - // > `constraintName` member exists or not, from the formula: - // > - // > ``` - // > (constraintValue == exists) ? 0 : 1 - // > ``` - Some(expected) => { - if setting.is_some() == expected { - Ok(0.0) - } else { - Ok(1.0) - } - } - // As specified in step 3 of the `fitness distance` algorithm: - // - // - // > If the constraint does not apply for this type of object, - // > the fitness distance is 0 (that is, the constraint does not - // > influence the fitness distance). - None => Ok(0.0), - } - } - } - }; -} - -impl_exists_value_constraint!(settings: [String, i64, f64], constraint: bool); - -#[cfg(test)] -mod tests; diff --git a/constraints/src/algorithms/fitness_distance/value_constraint/tests.rs b/constraints/src/algorithms/fitness_distance/value_constraint/tests.rs deleted file mode 100644 index 7090c6626..000000000 --- a/constraints/src/algorithms/fitness_distance/value_constraint/tests.rs +++ /dev/null @@ -1,144 +0,0 @@ -use super::*; - -macro_rules! generate_value_constraint_tests { - ( - tests: [ - $({ - name: $ti:ident, - settings: $st:ty => $se:expr $(,)? - }),+ $(,)? - ], - constraints: $ct:ty => $ce:expr, - expected: $e:expr $(,)? - ) => { - generate_value_constraint_tests!( - tests: [ - $({ - name: $ti, - settings: $st => $se, - constraints: $ct => $ce, - }),+ - ], - expected: $e - ); - }; - ( - tests: [ - $({ - name: $ti:ident, - settings: $st:ty => $se:expr, - constraints: $ct:ty => $ce:expr $(,)? - }),+ $(,)? - ], - expected: $e:expr $(,)? - ) => { - generate_value_constraint_tests!( - tests: [ - $({ - name: $ti, - settings: $st => $se, - constraints: $ct => $ce, - }),+ - ], - validate: |result| { - assert_eq!(result, $e); - } - ); - }; - ( - tests: [ - $({ - name: $ti:ident, - settings: $st:ty => $se:expr, - constraints: $ct:ty => $ce:expr $(,)? - }),+ $(,)? 
- ], - validate: |$a:ident| $b:block - ) => { - $( - #[test] - fn $ti() { - test_value_constraint!( - settings: $st => $se, - constraints: $ct => $ce, - validate: |$a| $b - ); - } - )+ - }; -} - -macro_rules! test_value_constraint { - ( - settings: $st:ty => $se:expr, - constraints: $ct:ty => $ce:expr, - expected: $e:expr $(,)? - ) => { - test_value_constraint!( - settings: $st => $se, - constraints: $ct => $ce, - validate: |result| { - assert_eq!(result, $e); - } - ); - }; - ( - settings: $st:ty => $se:expr, - constraints: $ct:ty => $ce:expr, - validate: |$a:ident| $b:block - ) => {{ - let settings: &[Option<$st>] = $se; - let constraints: &[ResolvedValueConstraint<$ct>] = $ce; - - for constraint in constraints { - for setting in settings { - let closure = |$a| $b; - let actual = constraint.fitness_distance(setting.as_ref()); - closure(actual); - } - } - }}; - ( - checks: [ - $({ - setting: $st:ty => $se:expr, - constraint: $ct:ty => $ce:expr, - expected: $ee:expr $(,)? - }),+ $(,)? - ] - ) => { - test_value_constraint!( - checks: [ - $({ - setting: $st => $se, - constraint: $ct => $ce, - expected: $ee, - }),+ - ], - validate: |actual, expected| { - assert_eq!(actual, expected); - } - ); - }; - ( - checks: [ - $({ - setting: $st:ty => $se:expr, - constraint: $ct:ty => $ce:expr, - expected: $ee:expr $(,)? - }),+ $(,)? - ], - validate: |$ai:ident, $ei:ident| $b:block - ) => {{ - $({ - let closure = |$ai, $ei| $b; - let actual = $ce.fitness_distance($se.as_ref()); - closure(actual, $ee); - })+ - }}; -} - -mod bool; -mod f64; -mod string; -mod u64; diff --git a/constraints/src/algorithms/fitness_distance/value_constraint/tests/bool.rs b/constraints/src/algorithms/fitness_distance/value_constraint/tests/bool.rs deleted file mode 100644 index aec57d984..000000000 --- a/constraints/src/algorithms/fitness_distance/value_constraint/tests/bool.rs +++ /dev/null @@ -1,283 +0,0 @@ -use super::*; - -mod basic { - use super::*; - - mod zero_distance { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: bool_setting, - settings: bool => &[Some(true)], - }, - ], - constraints: bool => &[ - ResolvedValueConstraint { - exact: None, - ideal: None, - }, - ResolvedValueConstraint { - exact: None, - ideal: Some(true), - }, - ], - expected: Ok(0.0) - ); - - generate_value_constraint_tests!( - tests: [ - { - name: string_setting, - settings: String => &[Some("foo".to_owned())], - }, - { - name: i64_setting, - settings: i64 => &[Some(42)], - }, - { - name: f64_setting, - settings: f64 => &[Some(42.0)], - }, - ], - constraints: bool => &[ - ResolvedValueConstraint { - exact: None, - ideal: Some(false), - }, - ], - expected: Ok(0.0) - ); - } - - mod one_distance { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: bool_setting, - settings: bool => &[None, Some(false)], - }, - ], - constraints: bool => &[ResolvedValueConstraint { - exact: None, - ideal: Some(true), - }], - expected: Ok(1.0) - ); - } -} - -mod required { - use super::*; - - mod zero_distance { - use super::*; - // A constraint that does apply for a type of setting, - // is expected to return a fitness distance of `0`, - // iff the setting matches the constraint: - generate_value_constraint_tests!( - tests: [ - { - name: bool_setting, - settings: bool => &[Some(true)], - }, - ], - constraints: bool => &[ResolvedValueConstraint { - exact: Some(true), - ideal: None, - }], - expected: Ok(0.0) - ); - } - - mod inf_distance { - use super::*; - - mod missing { - use super::*; - - 
generate_value_constraint_tests!( - tests: [ - { - name: bool_setting, - settings: bool => &[None], - }, - ], - constraints: bool => &[ - ResolvedValueConstraint { - exact: Some(true), - ideal: None, - }, - ResolvedValueConstraint { - exact: Some(true), - ideal: Some(true), - }, - ], - expected: Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Missing, - constraint: "(x == true)".to_owned(), - setting: None, - }) - ); - } - - mod mismatch { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: bool_setting, - settings: bool => &[Some(false)], - }, - ], - constraints: bool => &[ - ResolvedValueConstraint { - exact: Some(true), - ideal: None, - }, - ResolvedValueConstraint { - exact: Some(true), - ideal: Some(true), - }, - ], - expected: Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Mismatch, - constraint: "(x == true)".to_owned(), - setting: Some("false".to_owned()), - }) - ); - } - } - - // Required boolean constraints have specialized logic as per - // rule 4 of the fitness distance algorithm specification: - // - - mod specialization { - use super::*; - - mod expected { - use super::*; - - mod existing { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: string_setting, - settings: String => &[Some("foo".to_owned())], - }, - { - name: i64_setting, - settings: i64 => &[Some(42)], - }, - { - name: f64_setting, - settings: f64 => &[Some(42.0)], - }, - ], - constraints: bool => &[ - ResolvedValueConstraint { - exact: Some(true), - ideal: None, - }, - ], - expected: Ok(0.0) - ); - } - - mod missing { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: string_setting, - settings: String => &[None], - }, - { - name: i64_setting, - settings: i64 => &[None], - }, - { - name: f64_setting, - settings: f64 => &[None], - }, - ], - constraints: bool => &[ - ResolvedValueConstraint { - exact: Some(true), - ideal: None, - }, - ], - expected: Ok(1.0) - ); - } - } - - mod unexpected { - use super::*; - - mod existing { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: string_setting, - settings: String => &[Some("foo".to_owned())], - }, - { - name: i64_setting, - settings: i64 => &[Some(42)], - }, - { - name: f64_setting, - settings: f64 => &[Some(42.0)], - }, - ], - constraints: bool => &[ - ResolvedValueConstraint { - exact: Some(false), - ideal: None, - }, - ], - expected: Ok(1.0) - ); - } - - mod missing { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: string_setting, - settings: String => &[None], - }, - { - name: i64_setting, - settings: i64 => &[None], - }, - { - name: f64_setting, - settings: f64 => &[None], - }, - ], - constraints: bool => &[ - ResolvedValueConstraint { - exact: Some(false), - ideal: None, - }, - ], - expected: Ok(0.0) - ); - } - } - } -} diff --git a/constraints/src/algorithms/fitness_distance/value_constraint/tests/f64.rs b/constraints/src/algorithms/fitness_distance/value_constraint/tests/f64.rs deleted file mode 100644 index 0ad486021..000000000 --- a/constraints/src/algorithms/fitness_distance/value_constraint/tests/f64.rs +++ /dev/null @@ -1,243 +0,0 @@ -use super::*; - -mod basic { - use super::*; - - mod zero_distance { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: i64_setting, - settings: i64 => &[Some(42)], - }, - { - name: f64_setting, - settings: f64 => &[Some(42.0)], - }, - ], - constraints: f64 => &[ - ResolvedValueConstraint { - exact: None, - ideal: 
None, - }, - ResolvedValueConstraint { - exact: None, - ideal: Some(42.0), - }, - ], - expected: Ok(0.0) - ); - } - - mod fract_distance { - use super::*; - - #[test] - fn i64_setting() { - test_value_constraint!( - checks: [ - { - setting: i64 => Some(1), - constraint: f64 => ResolvedValueConstraint { - exact: None, - ideal: Some(4.0), - }, - expected: Ok(0.75), - }, - { - setting: i64 => Some(2), - constraint: f64 => ResolvedValueConstraint { - exact: None, - ideal: Some(4.0), - }, - expected: Ok(0.5), - }, - { - setting: i64 => Some(3), - constraint: f64 => ResolvedValueConstraint { - exact: None, - ideal: Some(4.0), - }, - expected: Ok(0.25), - }, - ], - validate: |actual, expected| { - assert_eq!(actual, expected); - } - ); - } - - #[test] - fn f64_setting() { - test_value_constraint!( - checks: [ - { - setting: f64 => Some(1.0), - constraint: f64 => ResolvedValueConstraint { - exact: None, - ideal: Some(4.0), - }, - expected: Ok(0.75), - }, - { - setting: f64 => Some(2.0), - constraint: f64 => ResolvedValueConstraint { - exact: None, - ideal: Some(4.0), - }, - expected: Ok(0.5), - }, - { - setting: f64 => Some(3.0), - constraint: f64 => ResolvedValueConstraint { - exact: None, - ideal: Some(4.0), - }, - expected: Ok(0.25), - }, - ], - validate: |actual, expected| { - assert_eq!(actual, expected); - } - ); - } - } - - mod one_distance { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: i64_setting, - settings: i64 => &[Some(0)], - }, - { - name: f64_setting, - settings: f64 => &[Some(0.0)], - }, - ], - constraints: f64 => &[ResolvedValueConstraint { - exact: None, - ideal: Some(42.0), - }], - expected: Ok(1.0) - ); - } -} - -mod required { - use super::*; - - mod zero_distance { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: i64_setting, - settings: i64 => &[Some(42)], - }, - { - name: f64_setting, - settings: f64 => &[Some(42.0)], - }, - ], - constraints: f64 => &[ResolvedValueConstraint { - exact: Some(42.0), - ideal: None, - }], - expected: Ok(0.0) - ); - } - - mod inf_distance { - use super::*; - - mod missing { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: i64_setting, - settings: i64 => &[None], - }, - { - name: f64_setting, - settings: f64 => &[None], - }, - ], - constraints: f64 => &[ - ResolvedValueConstraint { - exact: Some(42.0), - ideal: None, - }, - ResolvedValueConstraint { - exact: Some(42.0), - ideal: Some(42.0), - }, - ], - expected: Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Missing, - constraint: "(x == 42.0)".to_owned(), - setting: None, - }) - ); - } - - mod mismatch { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: i64_setting, - settings: i64 => &[Some(0)], - }, - ], - constraints: f64 => &[ - ResolvedValueConstraint { - exact: Some(42.0), - ideal: None, - }, - ResolvedValueConstraint { - exact: Some(42.0), - ideal: Some(42.0), - }, - ], - expected: Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Mismatch, - constraint: "(x == 42.0)".to_owned(), - setting: Some("0".to_owned()), - }) - ); - - generate_value_constraint_tests!( - tests: [ - { - name: f64_setting, - settings: f64 => &[Some(0.0)], - }, - ], - constraints: f64 => &[ - ResolvedValueConstraint { - exact: Some(42.0), - ideal: None, - }, - ResolvedValueConstraint { - exact: Some(42.0), - ideal: Some(42.0), - }, - ], - expected: Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Mismatch, - constraint: "(x == 
42.0)".to_owned(), - setting: Some("0.0".to_owned()), - }) - ); - } - } -} diff --git a/constraints/src/algorithms/fitness_distance/value_constraint/tests/string.rs b/constraints/src/algorithms/fitness_distance/value_constraint/tests/string.rs deleted file mode 100644 index ca663a446..000000000 --- a/constraints/src/algorithms/fitness_distance/value_constraint/tests/string.rs +++ /dev/null @@ -1,132 +0,0 @@ -use super::*; - -mod basic { - use super::*; - - mod zero_distance { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: string_setting, - settings: String => &[Some("foo".to_owned())], - }, - ], - constraints: String => &[ - ResolvedValueConstraint { - exact: None, - ideal: None, - }, - ResolvedValueConstraint { - exact: None, - ideal: Some("foo".to_owned()), - }, - ], - expected: Ok(0.0) - ); - } - - mod one_distance { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: string_setting, - settings: String => &[None, Some("bar".to_owned())], - }, - ], - constraints: String => &[ResolvedValueConstraint { - exact: None, - ideal: Some("foo".to_owned()), - }], - expected: Ok(1.0) - ); - } -} - -mod required { - use super::*; - - mod zero_distance { - use super::*; - - // A constraint that does apply for a type of setting, - // is expected to return a fitness distance of `0`, - // iff the setting matches the constraint: - generate_value_constraint_tests!( - tests: [ - { - name: string_setting, - settings: String => &[Some("foo".to_owned())], - }, - ], - constraints: String => &[ResolvedValueConstraint { - exact: Some("foo".to_owned()), - ideal: None, - }], - expected: Ok(0.0) - ); - } - - mod inf_distance { - use super::*; - - mod missing { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: string_setting, - settings: String => &[None], - }, - ], - constraints: String => &[ - ResolvedValueConstraint { - exact: Some("foo".to_owned()), - ideal: None, - }, - ResolvedValueConstraint { - exact: Some("foo".to_owned()), - ideal: Some("foo".to_owned()), - }, - ], - expected: Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Missing, - constraint: "(x == \"foo\")".to_owned(), - setting: None, - }) - ); - } - - mod mismatch { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: string_setting, - settings: String => &[Some("bar".to_owned())], - }, - ], - constraints: String => &[ - ResolvedValueConstraint { - exact: Some("foo".to_owned()), - ideal: None, - }, - ResolvedValueConstraint { - exact: Some("foo".to_owned()), - ideal: Some("foo".to_owned()), - }, - ], - expected: Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Mismatch, - constraint: "(x == \"foo\")".to_owned(), - setting: Some("\"bar\"".to_owned()), - }) - ); - } - } -} diff --git a/constraints/src/algorithms/fitness_distance/value_constraint/tests/u64.rs b/constraints/src/algorithms/fitness_distance/value_constraint/tests/u64.rs deleted file mode 100644 index 81be604ee..000000000 --- a/constraints/src/algorithms/fitness_distance/value_constraint/tests/u64.rs +++ /dev/null @@ -1,243 +0,0 @@ -use super::*; - -mod basic { - use super::*; - - mod zero_distance { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: i64_setting, - settings: i64 => &[Some(42)], - }, - { - name: f64_setting, - settings: f64 => &[Some(42.0)], - }, - ], - constraints: u64 => &[ - ResolvedValueConstraint { - exact: None, - ideal: None, - }, - ResolvedValueConstraint { - exact: None, - ideal: 
Some(42), - }, - ], - expected: Ok(0.0) - ); - } - - mod fract_distance { - use super::*; - - #[test] - fn i64_setting() { - test_value_constraint!( - checks: [ - { - setting: i64 => Some(1), - constraint: i64 => ResolvedValueConstraint { - exact: None, - ideal: Some(4), - }, - expected: Ok(0.75), - }, - { - setting: i64 => Some(2), - constraint: i64 => ResolvedValueConstraint { - exact: None, - ideal: Some(4), - }, - expected: Ok(0.5), - }, - { - setting: i64 => Some(3), - constraint: i64 => ResolvedValueConstraint { - exact: None, - ideal: Some(4), - }, - expected: Ok(0.25), - }, - ], - validate: |actual, expected| { - assert_eq!(actual, expected); - } - ); - } - - #[test] - fn f64_setting() { - test_value_constraint!( - checks: [ - { - setting: f64 => Some(1.0), - constraint: u64 => ResolvedValueConstraint { - exact: None, - ideal: Some(4), - }, - expected: Ok(0.75), - }, - { - setting: f64 => Some(2.0), - constraint: u64 => ResolvedValueConstraint { - exact: None, - ideal: Some(4), - }, - expected: Ok(0.5), - }, - { - setting: f64 => Some(3.0), - constraint: u64 => ResolvedValueConstraint { - exact: None, - ideal: Some(4), - }, - expected: Ok(0.25), - }, - ], - validate: |actual, expected| { - assert_eq!(actual, expected); - } - ); - } - } - - mod one_distance { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: i64_setting, - settings: i64 => &[None, Some(0)], - }, - { - name: f64_setting, - settings: f64 => &[None, Some(0.0)], - }, - ], - constraints: u64 => &[ResolvedValueConstraint { - exact: None, - ideal: Some(42), - }], - expected: Ok(1.0) - ); - } -} - -mod required { - use super::*; - - mod zero_distance { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: i64_setting, - settings: i64 => &[Some(42)], - }, - { - name: f64_setting, - settings: f64 => &[Some(42.0)], - }, - ], - constraints: u64 => &[ResolvedValueConstraint { - exact: Some(42), - ideal: None, - }], - expected: Ok(0.0) - ); - } - - mod inf_distance { - use super::*; - - mod missing { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: i64_setting, - settings: i64 => &[None], - }, - { - name: f64_setting, - settings: f64 => &[None], - }, - ], - constraints: u64 => &[ - ResolvedValueConstraint { - exact: Some(42), - ideal: None, - }, - ResolvedValueConstraint { - exact: Some(42), - ideal: Some(42), - }, - ], - expected: Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Missing, - constraint: "(x == 42)".to_owned(), - setting: None, - }) - ); - } - - mod mismatch { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: i64_setting, - settings: i64 => &[Some(0)], - }, - ], - constraints: u64 => &[ - ResolvedValueConstraint { - exact: Some(42), - ideal: None, - }, - ResolvedValueConstraint { - exact: Some(42), - ideal: Some(42), - }, - ], - expected: Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Mismatch, - constraint: "(x == 42)".to_owned(), - setting: Some("0".to_owned()), - }) - ); - - generate_value_constraint_tests!( - tests: [ - { - name: f64_setting, - settings: f64 => &[Some(0.0)], - }, - ], - constraints: u64 => &[ - ResolvedValueConstraint { - exact: Some(42), - ideal: None, - }, - ResolvedValueConstraint { - exact: Some(42), - ideal: Some(42), - }, - ], - expected: Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Mismatch, - constraint: "(x == 42)".to_owned(), - setting: Some("0.0".to_owned()), - }) - ); - } - } -} diff --git 
a/constraints/src/algorithms/fitness_distance/value_range_constraint.rs b/constraints/src/algorithms/fitness_distance/value_range_constraint.rs deleted file mode 100644 index 68147c768..000000000 --- a/constraints/src/algorithms/fitness_distance/value_range_constraint.rs +++ /dev/null @@ -1,172 +0,0 @@ -use super::setting::SettingFitnessDistanceError; -use super::{FitnessDistance, SettingFitnessDistanceErrorKind}; -use crate::ResolvedValueRangeConstraint; - -macro_rules! impl_value_range_constraint { - (setting: $s:ty, constraint: $c:ty) => { - impl<'a> FitnessDistance> for ResolvedValueRangeConstraint<$c> { - type Error = SettingFitnessDistanceError; - - fn fitness_distance(&self, setting: Option<&'a $s>) -> Result { - if let Some(exact) = self.exact { - // As specified in step 2 of the `fitness distance` algorithm: - // - // - // > If the constraint is required (constraintValue either contains - // > one or more members named [โ€ฆ] 'exact' [โ€ฆ]), and the settings - // > dictionary's constraintName member's value does not satisfy the - // > constraint or doesn't exist, the fitness distance is positive infinity. - match setting { - Some(&actual) if super::is_nearly_equal_to(actual as f64, exact as f64) => { - } - Some(setting) => { - return Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Mismatch, - constraint: format!("{}", self.to_required_only()), - setting: Some(format!("{:?}", setting)), - }) - } - None => { - return Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Missing, - constraint: format!("{}", self.to_required_only()), - setting: None, - }) - } - }; - } - - if let Some(min) = self.min { - // As specified in step 2 of the `fitness distance` algorithm: - // - // - // > If the constraint is required (constraintValue either contains - // > one or more members named [โ€ฆ] 'min' [โ€ฆ]), and the settings - // > dictionary's constraintName member's value does not satisfy the - // > constraint or doesn't exist, the fitness distance is positive infinity. - match setting { - Some(&actual) - if super::is_nearly_greater_than_or_equal_to( - actual as f64, - min as f64, - ) => {} - Some(setting) => { - return Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::TooSmall, - constraint: format!("{}", self.to_required_only()), - setting: Some(format!("{:?}", setting)), - }) - } - None => { - return Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Missing, - constraint: format!("{}", self.to_required_only()), - setting: None, - }) - } - }; - } - - if let Some(max) = self.max { - // As specified in step 2 of the `fitness distance` algorithm: - // - // - // > If the constraint is required (constraintValue either contains - // > one or more members named [โ€ฆ] 'max' [โ€ฆ]), and the settings - // > dictionary's constraintName member's value does not satisfy the - // > constraint or doesn't exist, the fitness distance is positive infinity. 
- match setting { - Some(&actual) - if super::is_nearly_less_than_or_equal_to( - actual as f64, - max as f64, - ) => {} - Some(setting) => { - return Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::TooLarge, - constraint: format!("{}", self.to_required_only()), - setting: Some(format!("{:?}", setting)), - }) - } - None => { - return Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Missing, - constraint: format!("{}", self.to_required_only()), - setting: None, - }) - } - }; - } - - if let Some(ideal) = self.ideal { - match setting { - Some(&actual) => { - let actual: f64 = actual as f64; - let ideal: f64 = ideal as f64; - // As specified in step 7 of the `fitness distance` algorithm: - // - // - // > For all positive numeric constraints [โ€ฆ], - // > the fitness distance is the result of the formula - // > - // > ``` - // > (actual == ideal) ? 0 : |actual - ideal| / max(|actual|, |ideal|) - // > ``` - Ok(super::relative_fitness_distance(actual, ideal)) - } - None => { - // As specified in step 5 of the `fitness distance` algorithm: - // - // - // > If the settings dictionary's `constraintName` member - // > does not exist, the fitness distance is 1. - Ok(1.0) - } - } - } else { - // As specified in step 6 of the `fitness distance` algorithm: - // - // - // > If no ideal value is specified (constraintValue either - // > contains no member named 'ideal', or, if bare values are to be - // > treated as 'ideal', isn't a bare value), the fitness distance is 0. - Ok(0.0) - } - } - } - }; -} - -impl_value_range_constraint!(setting: f64, constraint: f64); -impl_value_range_constraint!(setting: i64, constraint: u64); -impl_value_range_constraint!(setting: i64, constraint: f64); -impl_value_range_constraint!(setting: f64, constraint: u64); - -// Specialized implementations for non-boolean value constraints of mismatching, -// and thus ignored setting types: -macro_rules! impl_ignored_value_range_constraint { - (settings: [$($s:ty),+], constraint: $c:ty) => { - $(impl_ignored_value_range_constraint!(setting: $s, constraint: $c);)+ - }; - (setting: $s:ty, constraint: $c:ty) => { - impl<'a> FitnessDistance> for ResolvedValueRangeConstraint<$c> { - type Error = SettingFitnessDistanceError; - - fn fitness_distance(&self, _setting: Option<&'a $s>) -> Result { - // As specified in step 3 of the `fitness distance` algorithm: - // - // - // > If the constraint does not apply for this type of object, - // > the fitness distance is 0 (that is, the constraint does not - // > influence the fitness distance). - Ok(0.0) - } - } - }; -} - -impl_ignored_value_range_constraint!(settings: [bool, String], constraint: u64); -impl_ignored_value_range_constraint!(settings: [bool, String], constraint: f64); - -#[cfg(test)] -mod tests; diff --git a/constraints/src/algorithms/fitness_distance/value_range_constraint/tests.rs b/constraints/src/algorithms/fitness_distance/value_range_constraint/tests.rs deleted file mode 100644 index 97e495aee..000000000 --- a/constraints/src/algorithms/fitness_distance/value_range_constraint/tests.rs +++ /dev/null @@ -1,143 +0,0 @@ -use super::*; - -macro_rules! generate_value_range_constraint_tests { - ( - tests: [ - $({ - name: $ti:ident, - settings: $st:ty => $se:expr $(,)? - }),+ $(,)? - ], - constraints: $ct:ty => $ce:expr, - expected: $e:expr $(,)? 
- ) => { - generate_value_range_constraint_tests!( - tests: [ - $({ - name: $ti, - settings: $st => $se, - constraints: $ct => $ce, - }),+ - ], - expected: $e - ); - }; - ( - tests: [ - $({ - name: $ti:ident, - settings: $st:ty => $se:expr, - constraints: $ct:ty => $ce:expr $(,)? - }),+ $(,)? - ], - expected: $e:expr $(,)? - ) => { - generate_value_range_constraint_tests!( - tests: [ - $({ - name: $ti, - settings: $st => $se, - constraints: $ct => $ce, - }),+ - ], - validate: |result| { - assert_eq!(result, $e); - } - ); - }; - ( - tests: [ - $({ - name: $ti:ident, - settings: $st:ty => $se:expr, - constraints: $ct:ty => $ce:expr $(,)? - }),+ $(,)? - ], - validate: |$a:ident| $b:block - ) => { - $( - #[test] - fn $ti() { - test_value_range_constraint!( - settings: $st => $se, - constraints: $ct => $ce, - validate: |$a| $b - ); - } - )+ - }; -} - -macro_rules! test_value_range_constraint { - ( - settings: $st:ty => $se:expr, - constraints: $ct:ty => $ce:expr, - expected: $e:expr $(,)? - ) => { - test_value_range_constraint!( - settings: $st => $se, - constraints: $ct => $ce, - validate: |result| { - assert_eq!(result, $e); - } - ); - }; - ( - settings: $st:ty => $se:expr, - constraints: $ct:ty => $ce:expr, - validate: |$a:ident| $b:block - ) => {{ - let settings: &[Option<$st>] = $se; - let constraints: &[ResolvedValueRangeConstraint<$ct>] = $ce; - - for constraint in constraints { - for setting in settings { - let closure = |$a| $b; - let actual = constraint.fitness_distance(setting.as_ref()); - closure(actual); - } - } - }}; - ( - checks: [ - $({ - setting: $st:ty => $se:expr, - constraint: $ct:ty => $ce:expr, - expected: $ee:expr $(,)? - }),+ $(,)? - ] - ) => { - test_value_range_constraint!( - checks: [ - $({ - setting: $st => $se, - constraint: $ct => $ce, - expected: $ee, - }),+ - ], - validate: |actual, expected| { - assert_eq!(actual, expected); - } - ); - }; - ( - checks: [ - $({ - setting: $st:ty => $se:expr, - constraint: $ct:ty => $ce:expr, - expected: $ee:expr $(,)? - }),+ $(,)? - ], - validate: |$ai:ident, $ei:ident| $b:block - ) => {{ - $({ - let closure = |$ai, $ei| $b; - let actual = $ce.fitness_distance($se.as_ref()); - closure(actual, $ee); - })+ - }}; -} - -mod empty; -mod f64; -mod u64; diff --git a/constraints/src/algorithms/fitness_distance/value_range_constraint/tests/empty.rs b/constraints/src/algorithms/fitness_distance/value_range_constraint/tests/empty.rs deleted file mode 100644 index c2fdce3fc..000000000 --- a/constraints/src/algorithms/fitness_distance/value_range_constraint/tests/empty.rs +++ /dev/null @@ -1,75 +0,0 @@ -use super::*; - -macro_rules! generate_empty_value_range_constraint_tests { - ( - tests: [ - $({ - name: $ti:ident, - settings: $st:ty => $se:expr $(,)? - }),+ $(,)? - ], - constraint: $ct:ty $(,)? 
- ) => { - generate_value_range_constraint_tests!( - tests: [ - $({ - name: $ti, - settings: $st => $se, - }),+ - ], - constraints: $ct => &[ - ResolvedValueRangeConstraint::<$ct> { - min: None, - max: None, - exact: None, - ideal: None, - } - ], - expected: Ok(0.0) - ); - }; -} - -mod u64_constraint { - use super::*; - - generate_empty_value_range_constraint_tests!( - tests: [ - { - name: bool_setting, - settings: bool => &[Some(false)], - }, - { - name: string_setting, - settings: String => &[Some("foo".to_owned())], - }, - { - name: f64_setting, - settings: f64 => &[Some(42.0)], - }, - ], - constraint: u64, - ); -} - -mod f64_constraint { - use super::*; - - generate_empty_value_range_constraint_tests!( - tests: [ - { - name: bool_setting, - settings: bool => &[Some(false)], - }, - { - name: string_setting, - settings: String => &[Some("foo".to_owned())], - }, - { - name: i64_setting, - settings: i64 => &[Some(42)], - }, - ], - constraint: f64, - ); -} diff --git a/constraints/src/algorithms/fitness_distance/value_range_constraint/tests/f64.rs b/constraints/src/algorithms/fitness_distance/value_range_constraint/tests/f64.rs deleted file mode 100644 index 79a4afc45..000000000 --- a/constraints/src/algorithms/fitness_distance/value_range_constraint/tests/f64.rs +++ /dev/null @@ -1,312 +0,0 @@ -use super::*; -use crate::algorithms::SettingFitnessDistanceErrorKind; - -mod basic { - use super::*; - - mod zero_distance { - use super::*; - - generate_value_range_constraint_tests!( - tests: [ - { - name: i64_setting, - settings: i64 => &[Some(42)], - }, - { - name: f64_setting, - settings: f64 => &[Some(42.0)], - }, - ], - constraints: f64 => &[ResolvedValueRangeConstraint { - min: None, - max: None, - exact: None, - ideal: Some(42.0), - }], - expected: Ok(0.0) - ); - - generate_value_range_constraint_tests!( - tests: [ - { - name: bool_setting, - settings: bool => &[Some(true)], - }, - { - name: string_setting, - settings: String => &[Some("foo".to_owned())], - }, - ], - constraints: f64 => &[ - ResolvedValueRangeConstraint { - min: None, - max: None, - exact: None, - ideal: Some(42.0), - } - ], - expected: Ok(0.0) - ); - } - - mod fract_distance { - use super::*; - - #[test] - fn i64_setting() { - test_value_range_constraint!( - checks: [ - { - setting: i64 => Some(1), - constraint: f64 => ResolvedValueRangeConstraint { - min: None, - max: None, - exact: None, - ideal: Some(4.0), - }, - expected: Ok(0.75), - }, - { - setting: i64 => Some(2), - constraint: f64 => ResolvedValueRangeConstraint { - min: None, - max: None, - exact: None, - ideal: Some(4.0), - }, - expected: Ok(0.5), - }, - { - setting: i64 => Some(3), - constraint: f64 => ResolvedValueRangeConstraint { - min: None, - max: None, - exact: None, - ideal: Some(4.0), - }, - expected: Ok(0.25), - }, - ], - validate: |actual, expected| { - assert_eq!(actual, expected); - } - ); - } - - #[test] - fn f64_setting() { - test_value_range_constraint!( - checks: [ - { - setting: f64 => Some(1.0), - constraint: f64 => ResolvedValueRangeConstraint { - min: None, - max: None, - exact: None, - ideal: Some(4.0), - }, - expected: Ok(0.75), - }, - { - setting: f64 => Some(2.0), - constraint: f64 => ResolvedValueRangeConstraint { - min: None, - max: None, - exact: None, - ideal: Some(4.0), - }, - expected: Ok(0.5), - }, - { - setting: f64 => Some(3.0), - constraint: f64 => ResolvedValueRangeConstraint { - min: None, - max: None, - exact: None, - ideal: Some(4.0), - }, - expected: Ok(0.25), - }, - ], - validate: |actual, expected| { - assert_eq!(actual, 
expected); - } - ); - } - } - - mod one_distance { - use super::*; - - generate_value_range_constraint_tests!( - tests: [ - { - name: i64_setting, - settings: i64 => &[Some(0)], - }, - { - name: f64_setting, - settings: f64 => &[Some(0.0)], - }, - ], - constraints: f64 => &[ResolvedValueRangeConstraint { - min: None, - max: None, - exact: None, - ideal: Some(42.0), - }], - expected: Ok(1.0) - ); - } -} - -mod required { - use super::*; - - mod zero_distance { - use super::*; - - generate_value_range_constraint_tests!( - tests: [ - { - name: i64_setting, - settings: i64 => &[Some(42)], - }, - { - name: f64_setting, - settings: f64 => &[Some(42.0)], - }, - ], - constraints: f64 => &[ResolvedValueRangeConstraint { - min: None, - max: None, - exact: Some(42.0), - ideal: None, - }], - expected: Ok(0.0) - ); - - generate_value_range_constraint_tests!( - tests: [ - { - name: bool_setting, - settings: bool => &[Some(true)], - }, - { - name: string_setting, - settings: String => &[Some("foo".to_owned())], - }, - ], - constraints: f64 => &[ - ResolvedValueRangeConstraint { - min: None, - max: None, - exact: Some(42.0), - ideal: None, - } - ], - expected: Ok(0.0) - ); - } - - mod inf_distance { - use super::*; - - mod missing { - use super::*; - - generate_value_range_constraint_tests!( - tests: [ - { - name: i64_setting, - settings: i64 => &[None], - }, - { - name: f64_setting, - settings: f64 => &[None], - }, - ], - constraints: f64 => &[ - ResolvedValueRangeConstraint { - min: None, - max: None, - exact: Some(42.0), - ideal: None, - }, - ResolvedValueRangeConstraint { - min: None, - max: None, - exact: Some(42.0), - ideal: Some(42.0), - }, - ], - expected: Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Missing, - constraint: "(x == 42.0)".to_owned(), - setting: None, - }) - ); - } - - mod mismatch { - use super::*; - - generate_value_range_constraint_tests!( - tests: [ - { - name: i64_setting, - settings: i64 => &[Some(0)], - }, - ], - constraints: f64 => &[ - ResolvedValueRangeConstraint { - min: None, - max: None, - exact: Some(42.0), - ideal: None, - }, - ResolvedValueRangeConstraint { - min: None, - max: None, - exact: Some(42.0), - ideal: Some(42.0), - }, - ], - expected: Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Mismatch, - constraint: "(x == 42.0)".to_owned(), - setting: Some("0".to_owned()), - }) - ); - - generate_value_range_constraint_tests!( - tests: [ - { - name: f64_setting, - settings: f64 => &[Some(0.0)], - }, - ], - constraints: f64 => &[ - ResolvedValueRangeConstraint { - min: None, - max: None, - exact: Some(42.0), - ideal: None, - }, - ResolvedValueRangeConstraint { - min: None, - max: None, - exact: Some(42.0), - ideal: Some(42.0), - }, - ], - expected: Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Mismatch, - constraint: "(x == 42.0)".to_owned(), - setting: Some("0.0".to_owned()), - }) - ); - } - } -} diff --git a/constraints/src/algorithms/fitness_distance/value_range_constraint/tests/u64.rs b/constraints/src/algorithms/fitness_distance/value_range_constraint/tests/u64.rs deleted file mode 100644 index 72ec1845e..000000000 --- a/constraints/src/algorithms/fitness_distance/value_range_constraint/tests/u64.rs +++ /dev/null @@ -1,312 +0,0 @@ -use super::*; -use crate::algorithms::SettingFitnessDistanceErrorKind; - -mod basic { - use super::*; - - mod zero_distance { - use super::*; - - generate_value_range_constraint_tests!( - tests: [ - { - name: i64_setting, - settings: i64 => 
&[Some(42)], - }, - { - name: f64_setting, - settings: f64 => &[Some(42.0)], - }, - ], - constraints: u64 => &[ResolvedValueRangeConstraint { - min: None, - max: None, - exact: None, - ideal: Some(42), - }], - expected: Ok(0.0) - ); - - generate_value_range_constraint_tests!( - tests: [ - { - name: bool_setting, - settings: bool => &[Some(true)], - }, - { - name: string_setting, - settings: String => &[Some("foo".to_owned())], - }, - ], - constraints: u64 => &[ - ResolvedValueRangeConstraint { - min: None, - max: None, - exact: None, - ideal: Some(42), - } - ], - expected: Ok(0.0) - ); - } - - mod fract_distance { - use super::*; - - #[test] - fn i64_setting() { - test_value_range_constraint!( - checks: [ - { - setting: i64 => Some(1), - constraint: u64 => ResolvedValueRangeConstraint { - min: None, - max: None, - exact: None, - ideal: Some(4), - }, - expected: Ok(0.75), - }, - { - setting: i64 => Some(2), - constraint: u64 => ResolvedValueRangeConstraint { - min: None, - max: None, - exact: None, - ideal: Some(4), - }, - expected: Ok(0.5), - }, - { - setting: i64 => Some(3), - constraint: u64 => ResolvedValueRangeConstraint { - min: None, - max: None, - exact: None, - ideal: Some(4), - }, - expected: Ok(0.25), - }, - ], - validate: |actual, expected| { - assert_eq!(actual, expected); - } - ); - } - - #[test] - fn f64_setting() { - test_value_range_constraint!( - checks: [ - { - setting: f64 => Some(1.0), - constraint: u64 => ResolvedValueRangeConstraint { - min: None, - max: None, - exact: None, - ideal: Some(4), - }, - expected: Ok(0.75), - }, - { - setting: f64 => Some(2.0), - constraint: u64 => ResolvedValueRangeConstraint { - min: None, - max: None, - exact: None, - ideal: Some(4), - }, - expected: Ok(0.5), - }, - { - setting: f64 => Some(3.0), - constraint: u64 => ResolvedValueRangeConstraint { - min: None, - max: None, - exact: None, - ideal: Some(4), - }, - expected: Ok(0.25), - }, - ], - validate: |actual, expected| { - assert_eq!(actual, expected); - } - ); - } - } - - mod one_distance { - use super::*; - - generate_value_range_constraint_tests!( - tests: [ - { - name: i64_setting, - settings: i64 => &[None, Some(0)], - }, - { - name: f64_setting, - settings: f64 => &[None, Some(0.0)], - }, - ], - constraints: u64 => &[ResolvedValueRangeConstraint { - min: None, - max: None, - exact: None, - ideal: Some(42), - }], - expected: Ok(1.0) - ); - } -} - -mod required { - use super::*; - - mod zero_distance { - use super::*; - - generate_value_range_constraint_tests!( - tests: [ - { - name: i64_setting, - settings: i64 => &[Some(42)], - }, - { - name: f64_setting, - settings: f64 => &[Some(42.0)], - }, - ], - constraints: u64 => &[ResolvedValueRangeConstraint { - min: None, - max: None, - exact: Some(42), - ideal: None, - }], - expected: Ok(0.0) - ); - - generate_value_range_constraint_tests!( - tests: [ - { - name: bool_setting, - settings: bool => &[Some(true)], - }, - { - name: string_setting, - settings: String => &[Some("foo".to_owned())], - }, - ], - constraints: u64 => &[ - ResolvedValueRangeConstraint { - min: None, - max: None, - exact: Some(42), - ideal: None, - } - ], - expected: Ok(0.0) - ); - } - - mod inf_distance { - use super::*; - - mod missing { - use super::*; - - generate_value_range_constraint_tests!( - tests: [ - { - name: i64_setting, - settings: i64 => &[None], - }, - { - name: f64_setting, - settings: f64 => &[None], - }, - ], - constraints: u64 => &[ - ResolvedValueRangeConstraint { - min: None, - max: None, - exact: Some(42), - ideal: None, - }, - 
ResolvedValueRangeConstraint { - min: None, - max: None, - exact: Some(42), - ideal: Some(42), - }, - ], - expected: Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Missing, - constraint: "(x == 42)".to_owned(), - setting: None, - }) - ); - } - - mod mismatch { - use super::*; - - generate_value_range_constraint_tests!( - tests: [ - { - name: i64_setting, - settings: i64 => &[Some(0)], - }, - ], - constraints: u64 => &[ - ResolvedValueRangeConstraint { - min: None, - max: None, - exact: Some(42), - ideal: None, - }, - ResolvedValueRangeConstraint { - min: None, - max: None, - exact: Some(42), - ideal: Some(42), - }, - ], - expected: Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Mismatch, - constraint: "(x == 42)".to_owned(), - setting: Some("0".to_owned()), - }) - ); - - generate_value_range_constraint_tests!( - tests: [ - { - name: f64_setting, - settings: f64 => &[Some(0.0)], - }, - ], - constraints: u64 => &[ - ResolvedValueRangeConstraint { - min: None, - max: None, - exact: Some(42), - ideal: None, - }, - ResolvedValueRangeConstraint { - min: None, - max: None, - exact: Some(42), - ideal: Some(42), - }, - ], - expected: Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Mismatch, - constraint: "(x == 42)".to_owned(), - setting: Some("0.0".to_owned()), - }) - ); - } - } -} diff --git a/constraints/src/algorithms/fitness_distance/value_sequence_constraint.rs b/constraints/src/algorithms/fitness_distance/value_sequence_constraint.rs deleted file mode 100644 index e8f68c28c..000000000 --- a/constraints/src/algorithms/fitness_distance/value_sequence_constraint.rs +++ /dev/null @@ -1,173 +0,0 @@ -use super::setting::SettingFitnessDistanceError; -use super::{FitnessDistance, SettingFitnessDistanceErrorKind}; -use crate::ResolvedValueSequenceConstraint; - -macro_rules! impl_non_numeric_value_sequence_constraint { - (setting: $s:ty, constraint: $c:ty) => { - impl<'a> FitnessDistance> for ResolvedValueSequenceConstraint<$c> - where - $s: PartialEq<$c>, - { - type Error = SettingFitnessDistanceError; - - fn fitness_distance(&self, setting: Option<&'a $s>) -> Result { - if let Some(exact) = self.exact.as_ref() { - // As specified in step 2 of the `fitness distance` algorithm: - // - // - // > If the constraint is required (constraintValue either contains - // > one or more members named [โ€ฆ] 'exact' [โ€ฆ]), and the settings - // > dictionary's constraintName member's value does not satisfy the - // > constraint or doesn't exist, the fitness distance is positive infinity. - match setting { - Some(actual) if exact.contains(actual) => {} - Some(setting) => { - return Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Mismatch, - constraint: format!("{}", self.to_required_only()), - setting: Some(format!("{:?}", setting)), - }) - } - None => { - return Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Missing, - constraint: format!("{}", self.to_required_only()), - setting: None, - }) - } - }; - } - - if let Some(ideal) = self.ideal.as_ref() { - // As specified in step 8 of the `fitness distance` algorithm: - // - // - // > For all string, enum and boolean constraints [โ€ฆ], - // > the fitness distance is the result of the formula: - // > - // > ``` - // > (actual == ideal) ? 
0 : 1 - // > ``` - // - // As well as step 5 of the `fitness distance` algorithm: - // - // - // > If the settings dictionary's `constraintName` member - // > does not exist, the fitness distance is 1. - match setting { - Some(actual) if ideal.contains(actual) => Ok(0.0), - Some(_) => Ok(1.0), - None => Ok(1.0), - } - } else { - // As specified in step 6 of the `fitness distance` algorithm: - // - // - // > If no ideal value is specified (constraintValue either - // > contains no member named 'ideal', or, if bare values are to be - // > treated as 'ideal', isn't a bare value), the fitness distance is 0. - Ok(0.0) - } - } - } - }; -} - -impl_non_numeric_value_sequence_constraint!(setting: bool, constraint: bool); -impl_non_numeric_value_sequence_constraint!(setting: String, constraint: String); - -macro_rules! impl_numeric_value_sequence_constraint { - (setting: $s:ty, constraint: $c:ty) => { - impl<'a> FitnessDistance> for ResolvedValueSequenceConstraint<$c> { - type Error = SettingFitnessDistanceError; - - fn fitness_distance(&self, setting: Option<&'a $s>) -> Result { - if let Some(exact) = &self.exact { - // As specified in step 2 of the `fitness distance` algorithm: - // - // - // > If the constraint is required (constraintValue either contains - // > one or more members named [โ€ฆ] 'exact' [โ€ฆ]), and the settings - // > dictionary's constraintName member's value does not satisfy the - // > constraint or doesn't exist, the fitness distance is positive infinity. - match setting { - Some(&actual) if exact.contains(&(actual as $c)) => {} - Some(setting) => { - return Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Mismatch, - constraint: format!("{}", self.to_required_only()), - setting: Some(format!("{:?}", setting)), - }) - } - None => { - return Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Missing, - constraint: format!("{}", self.to_required_only()), - setting: None, - }) - } - }; - } - - if let Some(ideal) = &self.ideal { - // As specified in step 8 of the `fitness distance` algorithm: - // - // - // > For all string, enum and boolean constraints [โ€ฆ], - // > the fitness distance is the result of the formula: - // > - // > ``` - // > (actual == ideal) ? 0 : 1 - // > ``` - // - // As well as step 5 of the `fitness distance` algorithm: - // - // - // > If the settings dictionary's `constraintName` member - // > does not exist, the fitness distance is 1. - match setting { - Some(&actual) => { - let actual: f64 = actual as f64; - let mut min_fitness_distance = 1.0; - for ideal in ideal.into_iter() { - let ideal: f64 = (*ideal) as f64; - // As specified in step 7 of the `fitness distance` algorithm: - // - // - // > For all positive numeric constraints [โ€ฆ], - // > the fitness distance is the result of the formula - // > - // > ``` - // > (actual == ideal) ? 0 : |actual - ideal| / max(|actual|, |ideal|) - // > ``` - let fitness_distance = - super::relative_fitness_distance(actual, ideal); - if fitness_distance < min_fitness_distance { - min_fitness_distance = fitness_distance; - } - } - Ok(min_fitness_distance) - } - None => Ok(1.0), - } - } else { - // As specified in step 6 of the `fitness distance` algorithm: - // - // - // > If no ideal value is specified (constraintValue either - // > contains no member named 'ideal', or, if bare values are to be - // > treated as 'ideal', isn't a bare value), the fitness distance is 0. 
- Ok(0.0) - } - } - } - }; -} - -impl_numeric_value_sequence_constraint!(setting: f64, constraint: f64); -impl_numeric_value_sequence_constraint!(setting: i64, constraint: u64); -impl_numeric_value_sequence_constraint!(setting: i64, constraint: f64); -impl_numeric_value_sequence_constraint!(setting: f64, constraint: u64); - -#[cfg(test)] -mod tests; diff --git a/constraints/src/algorithms/fitness_distance/value_sequence_constraint/tests.rs b/constraints/src/algorithms/fitness_distance/value_sequence_constraint/tests.rs deleted file mode 100644 index 5750ae0a8..000000000 --- a/constraints/src/algorithms/fitness_distance/value_sequence_constraint/tests.rs +++ /dev/null @@ -1,144 +0,0 @@ -use super::*; - -macro_rules! generate_value_constraint_tests { - ( - tests: [ - $({ - name: $ti:ident, - settings: $st:ty => $se:expr $(,)? - }),+ $(,)? - ], - constraints: $ct:ty => $ce:expr, - expected: $e:expr $(,)? - ) => { - generate_value_constraint_tests!( - tests: [ - $({ - name: $ti, - settings: $st => $se, - constraints: $ct => $ce, - }),+ - ], - expected: $e - ); - }; - ( - tests: [ - $({ - name: $ti:ident, - settings: $st:ty => $se:expr, - constraints: $ct:ty => $ce:expr $(,)? - }),+ $(,)? - ], - expected: $e:expr $(,)? - ) => { - generate_value_constraint_tests!( - tests: [ - $({ - name: $ti, - settings: $st => $se, - constraints: $ct => $ce, - }),+ - ], - validate: |result| { - assert_eq!(result, $e); - } - ); - }; - ( - tests: [ - $({ - name: $ti:ident, - settings: $st:ty => $se:expr, - constraints: $ct:ty => $ce:expr $(,)? - }),+ $(,)? - ], - validate: |$a:ident| $b:block - ) => { - $( - #[test] - fn $ti() { - test_value_constraint!( - settings: $st => $se, - constraints: $ct => $ce, - validate: |$a| $b - ); - } - )+ - }; -} - -macro_rules! test_value_constraint { - ( - settings: $st:ty => $se:expr, - constraints: $ct:ty => $ce:expr, - expected: $e:expr $(,)? - ) => { - test_value_constraint!( - settings: $st => $se, - constraints: $ct => $ce, - validate: |result| { - assert_eq!(result, $e); - } - ); - }; - ( - settings: $st:ty => $se:expr, - constraints: $ct:ty => $ce:expr, - validate: |$a:ident| $b:block - ) => {{ - let settings: &[Option<$st>] = $se; - let constraints: &[ResolvedValueSequenceConstraint<$ct>] = $ce; - - for constraint in constraints { - for setting in settings { - let closure = |$a| $b; - let actual = constraint.fitness_distance(setting.as_ref()); - closure(actual); - } - } - }}; - ( - checks: [ - $({ - setting: $st:ty => $se:expr, - constraint: $ct:ty => $ce:expr, - expected: $ee:expr $(,)? - }),+ $(,)? - ] - ) => { - test_value_constraint!( - checks: [ - $({ - setting: $st => $se, - constraint: $ct => $ce, - expected: $ee, - }),+ - ], - validate: |actual, expected| { - assert_eq!(actual, expected); - } - ); - }; - ( - checks: [ - $({ - setting: $st:ty => $se:expr, - constraint: $ct:ty => $ce:expr, - expected: $ee:expr $(,)? - }),+ $(,)? 
- ], - validate: |$ai:ident, $ei:ident| $b:block - ) => {{ - $({ - let closure = |$ai, $ei| $b; - let actual = $ce.fitness_distance($se.as_ref()); - closure(actual, $ee); - })+ - }}; -} - -mod bool; -mod f64; -mod string; -mod u64; diff --git a/constraints/src/algorithms/fitness_distance/value_sequence_constraint/tests/bool.rs b/constraints/src/algorithms/fitness_distance/value_sequence_constraint/tests/bool.rs deleted file mode 100644 index b6d4bc245..000000000 --- a/constraints/src/algorithms/fitness_distance/value_sequence_constraint/tests/bool.rs +++ /dev/null @@ -1,132 +0,0 @@ -use super::*; -use crate::algorithms::SettingFitnessDistanceErrorKind; - -mod basic { - use super::*; - - mod zero_distance { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: bool_setting, - settings: bool => &[Some(true)], - }, - ], - constraints: bool => &[ - ResolvedValueSequenceConstraint { - exact: None, - ideal: None, - }, - ResolvedValueSequenceConstraint { - exact: None, - ideal: Some(vec![true]), - }, - ], - expected: Ok(0.0) - ); - } - - mod one_distance { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: bool_setting, - settings: bool => &[None, Some(false)], - }, - ], - constraints: bool => &[ResolvedValueSequenceConstraint { - exact: None, - ideal: Some(vec![true]), - }], - expected: Ok(1.0) - ); - } -} - -mod required { - use super::*; - - mod zero_distance { - use super::*; - // A constraint that does apply for a type of setting, - // is expected to return a fitness distance of `0`, - // iff the setting matches the constraint: - generate_value_constraint_tests!( - tests: [ - { - name: bool_setting, - settings: bool => &[Some(true)], - }, - ], - constraints: bool => &[ResolvedValueSequenceConstraint { - exact: Some(vec![true]), - ideal: None, - }], - expected: Ok(0.0) - ); - } - - mod inf_distance { - use super::*; - - mod missing { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: bool_setting, - settings: bool => &[None], - }, - ], - constraints: bool => &[ - ResolvedValueSequenceConstraint { - exact: Some(vec![true]), - ideal: None, - }, - ResolvedValueSequenceConstraint { - exact: Some(vec![true]), - ideal: Some(vec![true]), - }, - ], - expected: Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Missing, - constraint: "(x == [true])".to_owned(), - setting: None, - }) - ); - } - - mod mismatch { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: bool_setting, - settings: bool => &[Some(false)], - }, - ], - constraints: bool => &[ - ResolvedValueSequenceConstraint { - exact: Some(vec![true]), - ideal: None, - }, - ResolvedValueSequenceConstraint { - exact: Some(vec![true]), - ideal: Some(vec![true]), - }, - ], - expected: Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Mismatch, - constraint: "(x == [true])".to_owned(), - setting: Some("false".to_owned()), - }) - ); - } - } -} diff --git a/constraints/src/algorithms/fitness_distance/value_sequence_constraint/tests/f64.rs b/constraints/src/algorithms/fitness_distance/value_sequence_constraint/tests/f64.rs deleted file mode 100644 index aa3803323..000000000 --- a/constraints/src/algorithms/fitness_distance/value_sequence_constraint/tests/f64.rs +++ /dev/null @@ -1,245 +0,0 @@ -use super::*; -use crate::algorithms::SettingFitnessDistanceErrorKind; - -mod basic { - use super::*; - - mod zero_distance { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: i64_setting, - 
settings: i64 => &[Some(42)], - }, - { - name: f64_setting, - settings: f64 => &[Some(42.0)], - }, - ], - constraints: f64 => &[ - ResolvedValueSequenceConstraint { - exact: None, - ideal: None, - }, - ResolvedValueSequenceConstraint { - exact: None, - ideal: Some(vec![42.0]), - }, - ], - expected: Ok(0.0) - ); - } - - mod fract_distance { - use super::*; - - #[test] - fn i64_setting() { - test_value_constraint!( - checks: [ - { - setting: i64 => Some(1), - constraint: f64 => ResolvedValueSequenceConstraint { - exact: None, - ideal: Some(vec![4.0]), - }, - expected: Ok(0.75), - }, - { - setting: i64 => Some(2), - constraint: f64 => ResolvedValueSequenceConstraint { - exact: None, - ideal: Some(vec![4.0]), - }, - expected: Ok(0.5), - }, - { - setting: i64 => Some(3), - constraint: f64 => ResolvedValueSequenceConstraint { - exact: None, - ideal: Some(vec![4.0]), - }, - expected: Ok(0.25), - }, - ], - validate: |actual, expected| { - assert_eq!(actual, expected); - } - ); - } - - #[test] - fn f64_setting() { - test_value_constraint!( - checks: [ - { - setting: f64 => Some(1.0), - constraint: f64 => ResolvedValueSequenceConstraint { - exact: None, - ideal: Some(vec![4.0]), - }, - expected: Ok(0.75), - }, - { - setting: f64 => Some(2.0), - constraint: f64 => ResolvedValueSequenceConstraint { - exact: None, - ideal: Some(vec![4.0]), - }, - expected: Ok(0.5), - }, - { - setting: f64 => Some(3.0), - constraint: f64 => ResolvedValueSequenceConstraint { - exact: None, - ideal: Some(vec![4.0]), - }, - expected: Ok(0.25), - }, - ], - validate: |actual, expected| { - assert_eq!(actual, expected); - } - ); - } - } - - mod one_distance { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: i64_setting, - settings: i64 => &[Some(0)], - }, - { - name: f64_setting, - settings: f64 => &[Some(0.0)], - }, - ], - constraints: f64 => &[ResolvedValueSequenceConstraint { - exact: None, - ideal: Some(vec![42.0]), - }], - expected: Ok(1.0) - ); - } -} - -mod required { - use super::*; - - mod zero_distance { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: i64_setting, - settings: i64 => &[Some(42)], - }, - { - name: f64_setting, - settings: f64 => &[Some(42.0)], - }, - ], - constraints: f64 => &[ResolvedValueSequenceConstraint { - exact: Some(vec![42.0]), - ideal: None, - }], - expected: Ok(0.0) - ); - } - - mod inf_distance { - use super::*; - - mod missing { - - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: i64_setting, - settings: i64 => &[None], - }, - { - name: f64_setting, - settings: f64 => &[None], - }, - ], - constraints: f64 => &[ - ResolvedValueSequenceConstraint { - exact: Some(vec![1.0, 1.5, 2.0]), - ideal: None, - }, - ResolvedValueSequenceConstraint { - exact: Some(vec![1.0, 1.5, 2.0]), - ideal: Some(vec![1.5]), - }, - ], - expected: Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Missing, - constraint: "(x == [1.0, 1.5, 2.0])".to_owned(), - setting: None, - }) - ); - } - - mod mismatch { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: i64_setting, - settings: i64 => &[Some(0)], - }, - ], - constraints: f64 => &[ - ResolvedValueSequenceConstraint { - exact: Some(vec![1.0, 1.5, 2.0]), - ideal: None, - }, - ResolvedValueSequenceConstraint { - exact: Some(vec![1.0, 1.5, 2.0]), - ideal: Some(vec![1.5]), - }, - ], - expected: Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Mismatch, - constraint: "(x == [1.0, 1.5, 2.0])".to_owned(), - setting: 
Some("0".to_owned()), - }) - ); - - generate_value_constraint_tests!( - tests: [ - { - name: f64_setting, - settings: f64 => &[Some(0.0)], - }, - ], - constraints: f64 => &[ - ResolvedValueSequenceConstraint { - exact: Some(vec![1.0, 1.5, 2.0]), - ideal: None, - }, - ResolvedValueSequenceConstraint { - exact: Some(vec![1.0, 1.5, 2.0]), - ideal: Some(vec![1.5]), - }, - ], - expected: Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Mismatch, - constraint: "(x == [1.0, 1.5, 2.0])".to_owned(), - setting: Some("0.0".to_owned()), - }) - ); - } - } -} diff --git a/constraints/src/algorithms/fitness_distance/value_sequence_constraint/tests/string.rs b/constraints/src/algorithms/fitness_distance/value_sequence_constraint/tests/string.rs deleted file mode 100644 index 5b2b11ece..000000000 --- a/constraints/src/algorithms/fitness_distance/value_sequence_constraint/tests/string.rs +++ /dev/null @@ -1,133 +0,0 @@ -use super::*; -use crate::algorithms::SettingFitnessDistanceErrorKind; - -mod basic { - use super::*; - - mod zero_distance { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: string_setting, - settings: String => &[Some("foo".to_owned())], - }, - ], - constraints: String => &[ - ResolvedValueSequenceConstraint { - exact: None, - ideal: None, - }, - ResolvedValueSequenceConstraint { - exact: None, - ideal: Some(vec!["foo".to_owned()]), - }, - ], - expected: Ok(0.0) - ); - } - - mod one_distance { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: string_setting, - settings: String => &[None, Some("bar".to_owned())], - }, - ], - constraints: String => &[ResolvedValueSequenceConstraint { - exact: None, - ideal: Some(vec!["foo".to_owned()]), - }], - expected: Ok(1.0) - ); - } -} - -mod required { - use super::*; - - mod zero_distance { - use super::*; - - // A constraint that does apply for a type of setting, - // is expected to return a fitness distance of `0`, - // iff the setting matches the constraint: - generate_value_constraint_tests!( - tests: [ - { - name: string_setting, - settings: String => &[Some("foo".to_owned())], - }, - ], - constraints: String => &[ResolvedValueSequenceConstraint { - exact: Some(vec!["foo".to_owned()]), - ideal: None, - }], - expected: Ok(0.0) - ); - } - - mod inf_distance { - use super::*; - - mod missing { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: string_setting, - settings: String => &[None], - }, - ], - constraints: String => &[ - ResolvedValueSequenceConstraint { - exact: Some(vec!["foo".to_owned(), "bar".to_owned(), "baz".to_owned()]), - ideal: None, - }, - ResolvedValueSequenceConstraint { - exact: Some(vec!["foo".to_owned(), "bar".to_owned(), "baz".to_owned()]), - ideal: Some(vec!["foo".to_owned()]), - }, - ], - expected: Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Missing, - constraint: "(x == [\"foo\", \"bar\", \"baz\"])".to_owned(), - setting: None, - }) - ); - } - - mod mismatch { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: string_setting, - settings: String => &[Some("blee".to_owned())], - }, - ], - constraints: String => &[ - ResolvedValueSequenceConstraint { - exact: Some(vec!["foo".to_owned(), "bar".to_owned(), "baz".to_owned()]), - ideal: None, - }, - ResolvedValueSequenceConstraint { - exact: Some(vec!["foo".to_owned(), "bar".to_owned(), "baz".to_owned()]), - ideal: Some(vec!["foo".to_owned()]), - }, - ], - expected: Err(SettingFitnessDistanceError { - kind: 
SettingFitnessDistanceErrorKind::Mismatch, - constraint: "(x == [\"foo\", \"bar\", \"baz\"])".to_owned(), - setting: Some("\"blee\"".to_owned()), - }) - ); - } - } -} diff --git a/constraints/src/algorithms/fitness_distance/value_sequence_constraint/tests/u64.rs b/constraints/src/algorithms/fitness_distance/value_sequence_constraint/tests/u64.rs deleted file mode 100644 index 4693fa143..000000000 --- a/constraints/src/algorithms/fitness_distance/value_sequence_constraint/tests/u64.rs +++ /dev/null @@ -1,244 +0,0 @@ -use super::*; -use crate::algorithms::SettingFitnessDistanceErrorKind; - -mod basic { - use super::*; - - mod zero_distance { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: i64_setting, - settings: i64 => &[Some(42)], - }, - { - name: f64_setting, - settings: f64 => &[Some(42.0)], - }, - ], - constraints: u64 => &[ - ResolvedValueSequenceConstraint { - exact: None, - ideal: None, - }, - ResolvedValueSequenceConstraint { - exact: None, - ideal: Some(vec![42]), - }, - ], - expected: Ok(0.0) - ); - } - - mod fract_distance { - use super::*; - - #[test] - fn i64_setting() { - test_value_constraint!( - checks: [ - { - setting: i64 => Some(1), - constraint: i64 => ResolvedValueSequenceConstraint { - exact: None, - ideal: Some(vec![4]), - }, - expected: Ok(0.75), - }, - { - setting: i64 => Some(2), - constraint: i64 => ResolvedValueSequenceConstraint { - exact: None, - ideal: Some(vec![4]), - }, - expected: Ok(0.5), - }, - { - setting: i64 => Some(3), - constraint: i64 => ResolvedValueSequenceConstraint { - exact: None, - ideal: Some(vec![4]), - }, - expected: Ok(0.25), - }, - ], - validate: |actual, expected| { - assert_eq!(actual, expected); - } - ); - } - - #[test] - fn f64_setting() { - test_value_constraint!( - checks: [ - { - setting: f64 => Some(1.0), - constraint: u64 => ResolvedValueSequenceConstraint { - exact: None, - ideal: Some(vec![4]), - }, - expected: Ok(0.75), - }, - { - setting: f64 => Some(2.0), - constraint: u64 => ResolvedValueSequenceConstraint { - exact: None, - ideal: Some(vec![4]), - }, - expected: Ok(0.5), - }, - { - setting: f64 => Some(3.0), - constraint: u64 => ResolvedValueSequenceConstraint { - exact: None, - ideal: Some(vec![4]), - }, - expected: Ok(0.25), - }, - ], - validate: |actual, expected| { - assert_eq!(actual, expected); - } - ); - } - } - - mod one_distance { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: i64_setting, - settings: i64 => &[None, Some(0)], - }, - { - name: f64_setting, - settings: f64 => &[None, Some(0.0)], - }, - ], - constraints: u64 => &[ResolvedValueSequenceConstraint { - exact: None, - ideal: Some(vec![42]), - }], - expected: Ok(1.0) - ); - } -} - -mod required { - use super::*; - - mod zero_distance { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: i64_setting, - settings: i64 => &[Some(42)], - }, - { - name: f64_setting, - settings: f64 => &[Some(42.0)], - }, - ], - constraints: u64 => &[ResolvedValueSequenceConstraint { - exact: Some(vec![42]), - ideal: None, - }], - expected: Ok(0.0) - ); - } - - mod inf_distance { - use super::*; - - mod missing { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: i64_setting, - settings: i64 => &[None], - }, - { - name: f64_setting, - settings: f64 => &[None], - }, - ], - constraints: u64 => &[ - ResolvedValueSequenceConstraint { - exact: Some(vec![42]), - ideal: None, - }, - ResolvedValueSequenceConstraint { - exact: Some(vec![42]), - ideal: Some(vec![42]), - }, - ], - 
expected: Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Missing, - constraint: "(x == [42])".to_owned(), - setting: None, - }) - ); - } - - mod mismatch { - use super::*; - - generate_value_constraint_tests!( - tests: [ - { - name: i64_setting, - settings: i64 => &[Some(0)], - }, - ], - constraints: u64 => &[ - ResolvedValueSequenceConstraint { - exact: Some(vec![42]), - ideal: None, - }, - ResolvedValueSequenceConstraint { - exact: Some(vec![42]), - ideal: Some(vec![42]), - }, - ], - expected: Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Mismatch, - constraint: "(x == [42])".to_owned(), - setting: Some("0".to_owned()), - }) - ); - - generate_value_constraint_tests!( - tests: [ - { - name: f64_setting, - settings: f64 => &[Some(0.0)], - }, - ], - constraints: u64 => &[ - ResolvedValueSequenceConstraint { - exact: Some(vec![42]), - ideal: None, - }, - ResolvedValueSequenceConstraint { - exact: Some(vec![42]), - ideal: Some(vec![42]), - }, - ], - expected: Err(SettingFitnessDistanceError { - kind: SettingFitnessDistanceErrorKind::Mismatch, - constraint: "(x == [42])".to_owned(), - setting: Some("0.0".to_owned()), - }) - ); - } - } -} diff --git a/constraints/src/algorithms/select_settings.rs b/constraints/src/algorithms/select_settings.rs deleted file mode 100644 index bd0798b3e..000000000 --- a/constraints/src/algorithms/select_settings.rs +++ /dev/null @@ -1,121 +0,0 @@ -use std::collections::HashSet; - -use thiserror::Error; - -use crate::algorithms::fitness_distance::SettingFitnessDistanceError; -use crate::errors::OverconstrainedError; -use crate::{MediaTrackSettings, SanitizedMediaTrackConstraints}; - -mod apply_advanced; -mod apply_mandatory; -mod select_optimal; -mod tie_breaking; - -use self::apply_advanced::*; -use self::apply_mandatory::*; -use self::select_optimal::*; -pub use self::tie_breaking::*; - -/// A mode indicating whether device information may be exposed. -#[derive(Copy, Clone, Eq, PartialEq, Debug)] -pub enum DeviceInformationExposureMode { - /// Device information may be exposed. - Exposed, - /// Device information may NOT be exposed. - Protected, -} - -/// An error type indicating a failure of the `SelectSettings` algorithm. -#[derive(Error, Clone, Eq, PartialEq, Debug)] -pub enum SelectSettingsError { - /// An error caused by one or more over-constrained settings. - #[error(transparent)] - Overconstrained(#[from] OverconstrainedError), -} - -/// This function implements steps 1-5 of the `SelectSettings` algorithm -/// as defined by the W3C spec: -/// -/// -/// Step 6 (tie-breaking) is omitted by this implementation and expected to be performed -/// manually on the returned candidates. -/// For this several implementation of `TieBreakingPolicy` are provided by this crate. -pub fn select_settings_candidates<'a, I>( - possible_settings: I, - constraints: &SanitizedMediaTrackConstraints, - exposure_mode: DeviceInformationExposureMode, -) -> Result, SelectSettingsError> -where - I: IntoIterator, -{ - let possible_settings = possible_settings.into_iter(); - - // As specified in step 1 of the `SelectSettings` algorithm: - // - // - // > Each constraint specifies one or more values (or a range of values) for its property. - // > A property MAY appear more than once in the list of 'advanced' ConstraintSets. - // > If an empty list has been given as the value for a constraint, - // > it MUST be interpreted as if the constraint were not specified - // > (in other words, an empty constraint == no constraint). 
- // > - // > Note that unknown properties are discarded by WebIDL, - // > which means that unknown/unsupported required constraints will silently disappear. - // > To avoid this being a surprise, application authors are expected to first use - // > the `getSupportedConstraints()` method [โ€ฆ]. - - // We expect "sanitized" constraints to not contain empty constraints: - debug_assert!(constraints - .mandatory - .iter() - .all(|(_, constraint)| !constraint.is_empty())); - - // Obtain candidates by filtering possible settings, dropping those with infinite fitness distances: - // - // This function call corresponds to steps 3 & 4 of the `SelectSettings` algorithm: - // - - let candidates_and_fitness_distances = - apply_mandatory_constraints(possible_settings, &constraints.mandatory, exposure_mode)?; - - // As specified in step 5 of the `SelectSettings` algorithm: - // - // - // > Iterate over the 'advanced' ConstraintSets in newConstraints in the order in which they were specified. - // > - // > For each ConstraintSet: - // > - // > 1. compute the fitness distance between it and each settings dictionary in candidates, - // > treating bare values of properties as exact. - // > - // > 2. If the fitness distance is finite for one or more settings dictionaries in candidates, - // > keep those settings dictionaries in candidates, discarding others. - // > - // > If the fitness distance is infinite for all settings dictionaries in candidates, - // > ignore this ConstraintSet. - let candidates = - apply_advanced_constraints(candidates_and_fitness_distances, &constraints.advanced); - - // As specified in step 6 of the `SelectSettings` algorithm: - // - // - // > Select one settings dictionary from candidates, and return it as the result of the `SelectSettings` algorithm. - // > The User Agent MUST use one with the smallest fitness distance, as calculated in step 3. - // > If more than one settings dictionary have the smallest fitness distance, - // > the User Agent chooses one of them based on system default property values and User Agent default property values. - // - // # Important - // Instead of return just ONE settings instance "with the smallest fitness distance, as calculated in step 3" - // we instead return ALL settings instances "with the smallest fitness distance, as calculated in step 3" - // and leave tie-breaking to the User Agent in a separate step: - Ok(select_optimal_candidates(candidates)) -} - -#[derive(Default)] -pub(crate) struct ConstraintFailureInfo { - pub(crate) failures: usize, - pub(crate) errors: HashSet, -} - -#[cfg(test)] -mod tests; diff --git a/constraints/src/algorithms/select_settings/apply_advanced.rs b/constraints/src/algorithms/select_settings/apply_advanced.rs deleted file mode 100644 index b7adfbe51..000000000 --- a/constraints/src/algorithms/select_settings/apply_advanced.rs +++ /dev/null @@ -1,188 +0,0 @@ -use crate::algorithms::FitnessDistance; -use crate::constraints::SanitizedAdvancedMediaTrackConstraints; -use crate::MediaTrackSettings; - -/// Returns the set of settings for which all non-overconstraining advanced constraints' -/// fitness distance is finite. -/// -/// Implements step 5 of the `SelectSettings` algorithm: -/// -/// -/// # Note: -/// This may change the order of items in `feasible_candidates`. -/// In practice however this is not a problem as we have to sort -/// it by fitness-distance eventually anyway. 
-pub(super) fn apply_advanced_constraints<'a>( - mut candidates: Vec<(&'a MediaTrackSettings, f64)>, - advanced_constraints: &SanitizedAdvancedMediaTrackConstraints, -) -> Vec<(&'a MediaTrackSettings, f64)> { - // As specified in step 5 of the `SelectSettings` algorithm: - // - // - // > Iterate over the 'advanced' ConstraintSets in newConstraints in the order in which they were specified. - // > - // > For each ConstraintSet: - // > - // > 1. compute the fitness distance between it and each settings dictionary in candidates, - // > treating bare values of properties as exact. - // > - // > 2. If the fitness distance is finite for one or more settings dictionaries in candidates, - // > keep those settings dictionaries in candidates, discarding others. - // > - // > If the fitness distance is infinite for all settings dictionaries in candidates, - // > ignore this ConstraintSet. - - let mut selected_candidates = Vec::with_capacity(candidates.len()); - - // Double-buffered sieving to avoid excessive vec allocations: - for advanced_constraint_set in advanced_constraints.iter() { - for (candidate, fitness_distance) in candidates.iter() { - if advanced_constraint_set.fitness_distance(candidate).is_ok() { - selected_candidates.push((*candidate, *fitness_distance)); - } - } - - if !selected_candidates.is_empty() { - candidates.clear(); - std::mem::swap(&mut candidates, &mut selected_candidates); - } - } - - candidates -} - -#[cfg(test)] -mod tests { - use std::iter::FromIterator; - - use super::*; - use crate::property::all::name::*; - use crate::{ - MediaTrackSupportedConstraints, ResizeMode, ResolvedAdvancedMediaTrackConstraints, - ResolvedMediaTrackConstraintSet, ResolvedValueConstraint, ResolvedValueRangeConstraint, - }; - - // Advanced constraint sets that doe not match any - // candidates should just get ignored: - #[test] - fn overconstrained() { - let supported_constraints = MediaTrackSupportedConstraints::from_iter(vec![ - &DEVICE_ID, - &HEIGHT, - &WIDTH, - &RESIZE_MODE, - ]); - - let settings = [ - MediaTrackSettings::from_iter([(&DEVICE_ID, "foo".into())]), - MediaTrackSettings::from_iter([(&DEVICE_ID, "bar".into())]), - ]; - - let candidates: Vec<_> = settings - .iter() - // attach a dummy fitness function: - .map(|settings| (settings, 42.0)) - .collect(); - - let constraints = ResolvedAdvancedMediaTrackConstraints::from_iter([ - ResolvedMediaTrackConstraintSet::from_iter([( - &DEVICE_ID, - ResolvedValueConstraint::default() - .exact("bazblee".to_owned()) - .into(), - )]), - ]); - - let sanitized_constraints = constraints.to_sanitized(&supported_constraints); - - let actual: Vec<_> = apply_advanced_constraints(candidates, &sanitized_constraints) - .into_iter() - // drop the dummy fitness distance: - .map(|(settings, _)| settings) - .collect(); - - let expected: Vec<_> = settings.iter().collect(); - - assert_eq!(actual, expected); - } - - #[test] - fn constrained() { - let supported_constraints = MediaTrackSupportedConstraints::from_iter(vec![ - &DEVICE_ID, - &HEIGHT, - &WIDTH, - &RESIZE_MODE, - ]); - - let settings = vec![ - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "480p".into()), - (&HEIGHT, 480.into()), - (&WIDTH, 720.into()), - (&RESIZE_MODE, ResizeMode::crop_and_scale().into()), - ]), - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "720p".into()), - (&HEIGHT, 720.into()), - (&WIDTH, 1280.into()), - (&RESIZE_MODE, ResizeMode::crop_and_scale().into()), - ]), - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "1080p".into()), - (&HEIGHT, 1080.into()), - (&WIDTH, 1920.into()), - 
(&RESIZE_MODE, ResizeMode::none().into()), - ]), - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "1440p".into()), - (&HEIGHT, 1440.into()), - (&WIDTH, 2560.into()), - (&RESIZE_MODE, ResizeMode::none().into()), - ]), - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "2160p".into()), - (&HEIGHT, 2160.into()), - (&WIDTH, 3840.into()), - (&RESIZE_MODE, ResizeMode::none().into()), - ]), - ]; - - let candidates: Vec<_> = settings.iter().map(|settings| (settings, 42.0)).collect(); - - let constraints = ResolvedAdvancedMediaTrackConstraints::from_iter([ - // The first advanced constraint set of "exact 800p" does not match - // any candidate and should thus get ignored by the algorithm: - ResolvedMediaTrackConstraintSet::from_iter([( - &HEIGHT, - ResolvedValueRangeConstraint::default().exact(800).into(), - )]), - // The second advanced constraint set of "no resizing" does match - // candidates and should thus be applied by the algorithm: - ResolvedMediaTrackConstraintSet::from_iter([( - &RESIZE_MODE, - ResolvedValueConstraint::default() - .exact(ResizeMode::none()) - .into(), - )]), - // The second advanced constraint set of "max 1440p" does match - // candidates and should thus be applied by the algorithm: - ResolvedMediaTrackConstraintSet::from_iter([( - &HEIGHT, - ResolvedValueRangeConstraint::default().max(1440).into(), - )]), - ]); - - let sanitized_constraints = constraints.to_sanitized(&supported_constraints); - - let actual: Vec<_> = apply_advanced_constraints(candidates, &sanitized_constraints) - .into_iter() - // drop the dummy fitness distance: - .map(|(settings, _)| settings) - .collect(); - - let expected = vec![&settings[2], &settings[3]]; - - assert_eq!(actual, expected); - } -} diff --git a/constraints/src/algorithms/select_settings/apply_mandatory.rs b/constraints/src/algorithms/select_settings/apply_mandatory.rs deleted file mode 100644 index a9bba70e8..000000000 --- a/constraints/src/algorithms/select_settings/apply_mandatory.rs +++ /dev/null @@ -1,213 +0,0 @@ -use std::collections::HashMap; - -use crate::algorithms::select_settings::{ConstraintFailureInfo, DeviceInformationExposureMode}; -use crate::algorithms::FitnessDistance; -use crate::errors::OverconstrainedError; -use crate::{MediaTrackProperty, MediaTrackSettings, SanitizedMediaTrackConstraintSet}; - -/// Returns the set of settings for which all mandatory constraints' -/// fitness distance is finite. -/// -/// Implements step 5 of the `SelectSettings` algorithm: -/// -pub(super) fn apply_mandatory_constraints<'a, I>( - candidates: I, - mandatory_constraints: &SanitizedMediaTrackConstraintSet, - exposure_mode: DeviceInformationExposureMode, -) -> Result, OverconstrainedError> -where - I: IntoIterator, -{ - // As specified in step 3 of the `SelectSettings` algorithm: - // - // - // > For every possible settings dictionary of copy compute its fitness distance, - // > treating bare values of properties as ideal values. Let candidates be the - // > set of settings dictionaries for which the fitness distance is finite. 
- - let mut feasible_candidates: Vec<(&'a MediaTrackSettings, f64)> = vec![]; - let mut failed_constraints: HashMap = - Default::default(); - - for candidate in candidates { - match mandatory_constraints.fitness_distance(candidate) { - Ok(fitness_distance) => { - debug_assert!(fitness_distance.is_finite()); - - feasible_candidates.push((candidate, fitness_distance)); - } - Err(error) => { - for (property, setting_error) in error.setting_errors { - let entry = failed_constraints.entry(property).or_default(); - entry.failures += 1; - entry.errors.insert(setting_error); - } - } - } - } - - if feasible_candidates.is_empty() { - return Err(match exposure_mode { - DeviceInformationExposureMode::Exposed => { - OverconstrainedError::exposing_device_information(failed_constraints) - } - DeviceInformationExposureMode::Protected => OverconstrainedError::default(), - }); - } - - Ok(feasible_candidates) -} - -#[cfg(test)] -mod tests { - use std::iter::FromIterator; - - use super::*; - use crate::property::all::name::*; - use crate::{ - MediaTrackSupportedConstraints, ResizeMode, ResolvedMandatoryMediaTrackConstraints, - ResolvedValueConstraint, ResolvedValueRangeConstraint, - }; - - // Advanced constraint sets that do not match any candidates should just get ignored: - #[test] - fn overconstrained() { - let supported_constraints = MediaTrackSupportedConstraints::from_iter(vec![ - &DEVICE_ID, - &HEIGHT, - &WIDTH, - &RESIZE_MODE, - ]); - - let settings = [ - MediaTrackSettings::from_iter([(&DEVICE_ID, "foo".into())]), - MediaTrackSettings::from_iter([(&DEVICE_ID, "bar".into())]), - ]; - - let candidates: Vec<_> = settings.iter().collect(); - - let constraints = ResolvedMandatoryMediaTrackConstraints::from_iter([( - &DEVICE_ID, - ResolvedValueConstraint::default() - .exact("mismatched-device".to_owned()) - .into(), - )]); - - let sanitized_constraints = constraints.to_sanitized(&supported_constraints); - - // Exposed exposure mode: - - let error = apply_mandatory_constraints( - candidates.clone(), - &sanitized_constraints, - DeviceInformationExposureMode::Exposed, - ) - .unwrap_err(); - - let constraint = &error.constraint; - let err_message = error.message.as_ref().expect("Error message."); - - assert_eq!(constraint, &DEVICE_ID); - assert_eq!( - err_message, - "Setting was a mismatch ([\"bar\", \"foo\"] do not satisfy (x == \"mismatched-device\"))." 
- ); - - // Protected exposure mode: - - let error = apply_mandatory_constraints( - candidates, - &sanitized_constraints, - DeviceInformationExposureMode::Protected, - ) - .unwrap_err(); - - let constraint = &error.constraint; - let err_message = error.message; - - assert_eq!( - constraint, - &MediaTrackProperty::from(""), - "Constraint should not have been exposed" - ); - assert!( - err_message.is_none(), - "Error message should not have been exposed" - ); - } - - #[test] - fn constrained() { - let supported_constraints = MediaTrackSupportedConstraints::from_iter(vec![ - &DEVICE_ID, - &HEIGHT, - &WIDTH, - &RESIZE_MODE, - ]); - - let settings = vec![ - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "480p".into()), - (&HEIGHT, 480.into()), - (&WIDTH, 720.into()), - (&RESIZE_MODE, ResizeMode::crop_and_scale().into()), - ]), - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "720p".into()), - (&HEIGHT, 720.into()), - (&WIDTH, 1280.into()), - (&RESIZE_MODE, ResizeMode::crop_and_scale().into()), - ]), - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "1080p".into()), - (&HEIGHT, 1080.into()), - (&WIDTH, 1920.into()), - (&RESIZE_MODE, ResizeMode::none().into()), - ]), - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "1440p".into()), - (&HEIGHT, 1440.into()), - (&WIDTH, 2560.into()), - (&RESIZE_MODE, ResizeMode::none().into()), - ]), - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "2160p".into()), - (&HEIGHT, 2160.into()), - (&WIDTH, 3840.into()), - (&RESIZE_MODE, ResizeMode::none().into()), - ]), - ]; - - let candidates: Vec<_> = settings.iter().collect(); - - let constraints = ResolvedMandatoryMediaTrackConstraints::from_iter([ - ( - &RESIZE_MODE, - ResolvedValueConstraint::default() - .exact(ResizeMode::none()) - .into(), - ), - ( - &HEIGHT, - ResolvedValueRangeConstraint::default().min(1000).into(), - ), - ( - &WIDTH, - ResolvedValueRangeConstraint::default().max(2000).into(), - ), - ]); - - let sanitized_constraints = constraints.to_sanitized(&supported_constraints); - - let actual = apply_mandatory_constraints( - candidates, - &sanitized_constraints, - DeviceInformationExposureMode::Exposed, - ) - .unwrap(); - - let expected = vec![(&settings[2], 0.0)]; - - assert_eq!(actual, expected); - } -} diff --git a/constraints/src/algorithms/select_settings/select_optimal.rs b/constraints/src/algorithms/select_settings/select_optimal.rs deleted file mode 100644 index dbc631053..000000000 --- a/constraints/src/algorithms/select_settings/select_optimal.rs +++ /dev/null @@ -1,118 +0,0 @@ -use crate::MediaTrackSettings; - -pub(super) fn select_optimal_candidates<'a, I>(candidates: I) -> Vec<&'a MediaTrackSettings> -where - I: IntoIterator, -{ - let mut optimal_candidates = vec![]; - let mut optimal_fitness_distance = f64::INFINITY; - - for (candidate, fitness_distance) in candidates { - use std::cmp::Ordering; - - #[cfg(feature = "total_cmp")] - let ordering = fitness_distance.total_cmp(&optimal_fitness_distance); - - // TODO: remove fallback once MSRV has been bumped to 1.62 or later: - #[cfg(not(feature = "total_cmp"))] - let ordering = { - // See http://doc.rust-lang.org/1.65.0/core/primitive.f64.html#method.total_cmp: - - let mut left = fitness_distance.to_bits() as i64; - let mut right = optimal_fitness_distance.to_bits() as i64; - - left ^= (((left >> 63) as u64) >> 1) as i64; - right ^= (((right >> 63) as u64) >> 1) as i64; - - left.cmp(&right) - }; - - if ordering == Ordering::Less { - // Candidate is new optimal, so drop current selection: - optimal_candidates.clear(); - optimal_fitness_distance 
= fitness_distance; - } - - if ordering != Ordering::Greater { - // Candidate is optimal, so add to selection: - optimal_candidates.push(candidate); - } - } - - optimal_candidates -} - -#[cfg(test)] -mod tests { - use super::select_optimal_candidates; - use crate::MediaTrackSettings; - - #[test] - fn monotonic_increasing() { - let settings = [ - MediaTrackSettings::default(), - MediaTrackSettings::default(), - MediaTrackSettings::default(), - MediaTrackSettings::default(), - ]; - - let candidates = vec![ - (&settings[0], 0.1), - (&settings[1], 0.1), - (&settings[2], 0.2), - (&settings[3], 0.3), - ]; - - let actual = select_optimal_candidates(candidates); - - let expected = vec![&settings[0], &settings[1]]; - - assert_eq!(actual, expected); - } - - #[test] - fn monotonic_decreasing() { - let settings = [ - MediaTrackSettings::default(), - MediaTrackSettings::default(), - MediaTrackSettings::default(), - MediaTrackSettings::default(), - ]; - - let candidates = vec![ - (&settings[0], 0.3), - (&settings[1], 0.2), - (&settings[2], 0.1), - (&settings[3], 0.1), - ]; - - let actual = select_optimal_candidates(candidates); - - let expected = vec![&settings[2], &settings[3]]; - - assert_eq!(actual, expected); - } - - #[test] - fn alternating() { - let settings = [ - MediaTrackSettings::default(), - MediaTrackSettings::default(), - MediaTrackSettings::default(), - MediaTrackSettings::default(), - ]; - - let candidates = vec![ - (&settings[0], 0.2), - (&settings[1], 0.1), - (&settings[2], 0.2), - (&settings[3], 0.1), - ]; - - let actual = select_optimal_candidates(candidates); - - let expected = vec![&settings[1], &settings[3]]; - - assert_eq!(actual, expected); - } -} diff --git a/constraints/src/algorithms/select_settings/tests.rs b/constraints/src/algorithms/select_settings/tests.rs deleted file mode 100644 index 5745662f4..000000000 --- a/constraints/src/algorithms/select_settings/tests.rs +++ /dev/null @@ -1,778 +0,0 @@ -use std::iter::FromIterator; - -use lazy_static::lazy_static; - -use super::DeviceInformationExposureMode; -use crate::algorithms::{select_settings_candidates, SelectSettingsError}; -use crate::errors::OverconstrainedError; -use crate::property::all::name::*; -use crate::property::all::names as all_properties; -use crate::{ - AdvancedMediaTrackConstraints, FacingMode, MandatoryMediaTrackConstraints, - MediaTrackConstraints, MediaTrackSettings, MediaTrackSupportedConstraints, ResizeMode, - ResolvedAdvancedMediaTrackConstraints, ResolvedMandatoryMediaTrackConstraints, - ResolvedMediaTrackConstraint, ResolvedMediaTrackConstraints, ResolvedValueConstraint, - ResolvedValueRangeConstraint, ResolvedValueSequenceConstraint, SanitizedMediaTrackConstraints, -}; - -lazy_static! 
{ - static ref VIDEO_IDEAL: MediaTrackSettings = MediaTrackSettings::from_iter([ - (&ASPECT_RATIO, 0.5625.into()), - (&FACING_MODE, FacingMode::user().into()), - (&FRAME_RATE, 60.0.into()), - (&WIDTH, 1920.into()), - (&HEIGHT, 1080.into()), - (&RESIZE_MODE, ResizeMode::none().into()), - ]); - static ref VIDEO_480P: MediaTrackSettings = MediaTrackSettings::from_iter([ - (&DEVICE_ID, "480p".into()), - (&ASPECT_RATIO, 0.5625.into()), - (&FACING_MODE, FacingMode::user().into()), - (&FRAME_RATE, 240.into()), - (&WIDTH, 720.into()), - (&HEIGHT, 480.into()), - (&RESIZE_MODE, ResizeMode::crop_and_scale().into()), - ]); - static ref VIDEO_720P: MediaTrackSettings = MediaTrackSettings::from_iter([ - (&DEVICE_ID, "720p".into()), - (&ASPECT_RATIO, 0.5625.into()), - (&FACING_MODE, FacingMode::user().into()), - (&FRAME_RATE, 120.into()), - (&WIDTH, 1280.into()), - (&HEIGHT, 720.into()), - (&RESIZE_MODE, ResizeMode::crop_and_scale().into()), - ]); - static ref VIDEO_1080P: MediaTrackSettings = MediaTrackSettings::from_iter([ - (&DEVICE_ID, "1080p".into()), - (&ASPECT_RATIO, 0.5625.into()), - (&FACING_MODE, FacingMode::user().into()), - (&FRAME_RATE, 60.into()), - (&WIDTH, 1920.into()), - (&HEIGHT, 1080.into()), - (&RESIZE_MODE, ResizeMode::none().into()), - ]); - static ref VIDEO_1440P: MediaTrackSettings = MediaTrackSettings::from_iter([ - (&DEVICE_ID, "1440p".into()), - (&ASPECT_RATIO, 0.5625.into()), - (&FACING_MODE, FacingMode::user().into()), - (&FRAME_RATE, 30.into()), - (&WIDTH, 2560.into()), - (&HEIGHT, 1440.into()), - (&RESIZE_MODE, ResizeMode::none().into()), - ]); - static ref VIDEO_2160P: MediaTrackSettings = MediaTrackSettings::from_iter([ - (&DEVICE_ID, "2160p".into()), - (&ASPECT_RATIO, 0.5625.into()), - (&FACING_MODE, FacingMode::user().into()), - (&FRAME_RATE, 15.into()), - (&WIDTH, 3840.into()), - (&HEIGHT, 2160.into()), - (&RESIZE_MODE, ResizeMode::none().into()), - ]); -} - -fn default_possible_settings() -> Vec { - vec![ - VIDEO_480P.clone(), - VIDEO_720P.clone(), - VIDEO_1080P.clone(), - VIDEO_1440P.clone(), - VIDEO_2160P.clone(), - ] -} - -fn default_supported_constraints() -> MediaTrackSupportedConstraints { - MediaTrackSupportedConstraints::from_iter(all_properties().into_iter().cloned()) -} - -fn test_overconstrained( - possible_settings: &[MediaTrackSettings], - mandatory_constraints: ResolvedMandatoryMediaTrackConstraints, - exposure_mode: DeviceInformationExposureMode, -) -> OverconstrainedError { - let constraints = ResolvedMediaTrackConstraints { - mandatory: mandatory_constraints, - advanced: ResolvedAdvancedMediaTrackConstraints::default(), - } - .to_sanitized(&default_supported_constraints()); - - let result = select_settings_candidates(possible_settings.iter(), &constraints, exposure_mode); - - let actual = result.err().unwrap(); - - let SelectSettingsError::Overconstrained(overconstrained_error) = actual; - - overconstrained_error -} - -fn test_constrained( - possible_settings: &[MediaTrackSettings], - mandatory_constraints: ResolvedMandatoryMediaTrackConstraints, - advanced_constraints: ResolvedAdvancedMediaTrackConstraints, -) -> Vec<&MediaTrackSettings> { - let constraints = ResolvedMediaTrackConstraints { - mandatory: mandatory_constraints, - advanced: advanced_constraints, - } - .to_sanitized(&default_supported_constraints()); - - let result = select_settings_candidates( - possible_settings.iter(), - &constraints, - DeviceInformationExposureMode::Exposed, - ); - - result.unwrap() -} - -mod unconstrained { - use super::*; - - fn default_constraints() -> 
MediaTrackConstraints { - MediaTrackConstraints { - mandatory: MandatoryMediaTrackConstraints::default(), - advanced: AdvancedMediaTrackConstraints::default(), - } - } - - fn default_resolved_constraints() -> ResolvedMediaTrackConstraints { - default_constraints().into_resolved() - } - - fn default_sanitized_constraints() -> SanitizedMediaTrackConstraints { - default_resolved_constraints().into_sanitized(&default_supported_constraints()) - } - - #[test] - fn pass_through() { - let possible_settings = default_possible_settings(); - let sanitized_constraints = default_sanitized_constraints(); - - let actual = select_settings_candidates( - &possible_settings[..], - &sanitized_constraints, - DeviceInformationExposureMode::Exposed, - ) - .unwrap(); - let expected: Vec<_> = possible_settings.iter().collect(); - - assert_eq!(actual, expected); - } -} - -mod overconstrained { - use super::*; - use crate::MediaTrackProperty; - - #[test] - fn protected() { - let error = test_overconstrained( - &default_possible_settings(), - ResolvedMandatoryMediaTrackConstraints::from_iter([( - GROUP_ID.clone(), - ResolvedValueConstraint::default() - .exact("missing-group".to_owned()) - .into(), - )]), - DeviceInformationExposureMode::Protected, - ); - - assert_eq!(error.constraint, MediaTrackProperty::from("")); - assert_eq!(error.message, None); - } - - mod exposed { - use super::*; - - #[test] - fn missing() { - let error = test_overconstrained( - &default_possible_settings(), - ResolvedMandatoryMediaTrackConstraints::from_iter([( - GROUP_ID.clone(), - ResolvedValueConstraint::default() - .exact("missing-group".to_owned()) - .into(), - )]), - DeviceInformationExposureMode::Exposed, - ); - - let constraint = &error.constraint; - let err_message = error.message.as_ref().expect("Error message."); - - assert_eq!(constraint, &GROUP_ID); - assert_eq!( - err_message, - "Setting was missing (does not satisfy (x == \"missing-group\"))." - ); - } - - #[test] - fn mismatch() { - let error = test_overconstrained( - &default_possible_settings(), - ResolvedMandatoryMediaTrackConstraints::from_iter([( - DEVICE_ID.clone(), - ResolvedValueConstraint::default() - .exact("mismatched-device".to_owned()) - .into(), - )]), - DeviceInformationExposureMode::Exposed, - ); - - let constraint = &error.constraint; - let err_message = error.message.as_ref().expect("Error message."); - - assert_eq!(constraint, &DEVICE_ID); - assert_eq!( - err_message, - "Setting was a mismatch ([\"1080p\", \"1440p\", \"2160p\", \"480p\", \"720p\"] do not satisfy (x == \"mismatched-device\"))." - ); - } - - #[test] - fn too_small() { - let error = test_overconstrained( - &default_possible_settings(), - ResolvedMandatoryMediaTrackConstraints::from_iter([( - FRAME_RATE.clone(), - ResolvedValueRangeConstraint::default().min(1000).into(), - )]), - DeviceInformationExposureMode::Exposed, - ); - - let constraint = &error.constraint; - let err_message = error.message.as_ref().expect("Error message."); - - assert_eq!(constraint, &FRAME_RATE); - assert_eq!( - err_message, - "Setting was too small ([120, 15, 240, 30, 60] do not satisfy (1000 <= x))." 
- ); - } - - #[test] - fn too_large() { - let error = test_overconstrained( - &default_possible_settings(), - ResolvedMandatoryMediaTrackConstraints::from_iter([( - FRAME_RATE.clone(), - ResolvedValueRangeConstraint::default().max(10).into(), - )]), - DeviceInformationExposureMode::Exposed, - ); - - let constraint = &error.constraint; - let err_message = error.message.as_ref().expect("Error message."); - - assert_eq!(constraint, &FRAME_RATE); - assert_eq!( - err_message, - "Setting was too large ([120, 15, 240, 30, 60] do not satisfy (x <= 10))." - ); - } - } -} - -mod constrained { - use super::*; - - #[test] - fn specific_device_id() { - let possible_settings = default_possible_settings(); - - for target_settings in possible_settings.iter() { - let setting = match target_settings.get(&DEVICE_ID) { - Some(setting) => setting, - None => continue, - }; - - let actual = test_constrained( - &possible_settings, - ResolvedMandatoryMediaTrackConstraints::from_iter([( - DEVICE_ID.clone(), - ResolvedMediaTrackConstraint::exact_from(setting.clone()), - )]), - ResolvedAdvancedMediaTrackConstraints::default(), - ); - - let expected = vec![target_settings]; - - assert_eq!(actual, expected); - } - } - - mod exact { - use super::*; - - #[test] - fn value() { - let possible_settings = vec![ - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "a".into()), - (&GROUP_ID, "group-0".into()), - ]), - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "b".into()), - (&GROUP_ID, "group-1".into()), - ]), - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "c".into()), - (&GROUP_ID, "group-2".into()), - ]), - ]; - - let actual = test_constrained( - &possible_settings, - ResolvedMandatoryMediaTrackConstraints::from_iter([( - &GROUP_ID, - ResolvedValueConstraint::default() - .exact("group-1".to_owned()) - .into(), - )]), - ResolvedAdvancedMediaTrackConstraints::default(), - ); - - let expected = vec![&possible_settings[1]]; - - assert_eq!(actual, expected); - } - - #[test] - fn value_range() { - let possible_settings = vec![ - MediaTrackSettings::from_iter([(&DEVICE_ID, "a".into()), (&FRAME_RATE, 15.into())]), - MediaTrackSettings::from_iter([(&DEVICE_ID, "b".into()), (&FRAME_RATE, 30.into())]), - MediaTrackSettings::from_iter([(&DEVICE_ID, "c".into()), (&FRAME_RATE, 60.into())]), - ]; - - let actual = test_constrained( - &possible_settings, - ResolvedMandatoryMediaTrackConstraints::from_iter([( - &FRAME_RATE, - ResolvedValueRangeConstraint::default().exact(30).into(), - )]), - ResolvedAdvancedMediaTrackConstraints::default(), - ); - - let expected = vec![&possible_settings[1]]; - - assert_eq!(actual, expected); - } - - #[test] - fn value_sequence() { - let possible_settings = vec![ - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "a".into()), - (&GROUP_ID, "group-0".into()), - ]), - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "b".into()), - (&GROUP_ID, "group-1".into()), - ]), - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "c".into()), - (&GROUP_ID, "group-2".into()), - ]), - ]; - - let actual = test_constrained( - &possible_settings, - ResolvedMandatoryMediaTrackConstraints::from_iter([( - &GROUP_ID, - ResolvedValueSequenceConstraint::default() - .exact(vec!["group-1".to_owned(), "group-3".to_owned()]) - .into(), - )]), - ResolvedAdvancedMediaTrackConstraints::default(), - ); - - let expected = vec![&possible_settings[1]]; - - assert_eq!(actual, expected); - } - } - - mod ideal { - use super::*; - - #[test] - fn value() { - let possible_settings = vec![ - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "a".into()), - 
(&GROUP_ID, "group-0".into()), - ]), - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "b".into()), - (&GROUP_ID, "group-1".into()), - ]), - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "c".into()), - (&GROUP_ID, "group-2".into()), - ]), - ]; - - let actual = test_constrained( - &possible_settings, - ResolvedMandatoryMediaTrackConstraints::from_iter([( - &GROUP_ID, - ResolvedValueConstraint::default() - .ideal("group-1".to_owned()) - .into(), - )]), - ResolvedAdvancedMediaTrackConstraints::default(), - ); - - let expected = vec![&possible_settings[1]]; - - assert_eq!(actual, expected); - } - - #[test] - fn value_range() { - let possible_settings = vec![ - MediaTrackSettings::from_iter([(&DEVICE_ID, "a".into()), (&FRAME_RATE, 15.into())]), - MediaTrackSettings::from_iter([(&DEVICE_ID, "b".into()), (&FRAME_RATE, 30.into())]), - MediaTrackSettings::from_iter([(&DEVICE_ID, "c".into()), (&FRAME_RATE, 60.into())]), - ]; - - let actual = test_constrained( - &possible_settings, - ResolvedMandatoryMediaTrackConstraints::from_iter([( - &FRAME_RATE, - ResolvedValueRangeConstraint::default().ideal(32).into(), - )]), - ResolvedAdvancedMediaTrackConstraints::default(), - ); - - let expected = vec![&possible_settings[1]]; - - assert_eq!(actual, expected); - } - - #[test] - fn value_sequence() { - let possible_settings = vec![ - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "a".into()), - (&GROUP_ID, "group-0".into()), - ]), - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "b".into()), - (&GROUP_ID, "group-1".into()), - ]), - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "c".into()), - (&GROUP_ID, "group-2".into()), - ]), - ]; - - let actual = test_constrained( - &possible_settings, - ResolvedMandatoryMediaTrackConstraints::from_iter([( - &GROUP_ID, - ResolvedValueSequenceConstraint::default() - .ideal(vec!["group-1".to_owned(), "group-3".to_owned()]) - .into(), - )]), - ResolvedAdvancedMediaTrackConstraints::default(), - ); - - let expected = vec![&possible_settings[1]]; - - assert_eq!(actual, expected); - } - } -} - -// ``` -// โ”Œ -// mandatory constraints: โ”ค โ”„โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ค -// โ”” -// โ”Œ -// advanced constraints: โ”ค โ”œโ”€โ”ค โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”„ -// โ”” -// โ”Œ -// possible settings: โ”ค โ—โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ—โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ—โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ—โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ— -// โ”” 480p 720p 1080p 1440p 2160p -// โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”˜ -// selected settings: โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ -// ``` -mod smoke { - use super::*; - use crate::{MediaTrackConstraintSet, ValueConstraint, ValueRangeConstraint}; - - #[test] - fn native() { - let supported_constraints = MediaTrackSupportedConstraints::from_iter(vec![ - &DEVICE_ID, - &HEIGHT, - &WIDTH, - &RESIZE_MODE, - ]); - - let possible_settings = vec![ - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "480p".into()), - (&HEIGHT, 480.into()), - (&WIDTH, 720.into()), - (&RESIZE_MODE, ResizeMode::crop_and_scale().into()), - ]), - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "720p".into()), - (&HEIGHT, 720.into()), - (&WIDTH, 1280.into()), - (&RESIZE_MODE, ResizeMode::crop_and_scale().into()), - ]), - MediaTrackSettings::from_iter([ - (&DEVICE_ID, 
"1080p".into()), - (&HEIGHT, 1080.into()), - (&WIDTH, 1920.into()), - (&RESIZE_MODE, ResizeMode::none().into()), - ]), - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "1440p".into()), - (&HEIGHT, 1440.into()), - (&WIDTH, 2560.into()), - (&RESIZE_MODE, ResizeMode::none().into()), - ]), - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "2160p".into()), - (&HEIGHT, 2160.into()), - (&WIDTH, 3840.into()), - (&RESIZE_MODE, ResizeMode::none().into()), - ]), - ]; - - let constraints = MediaTrackConstraints { - mandatory: MandatoryMediaTrackConstraints::from_iter([ - ( - &WIDTH, - ValueRangeConstraint::Constraint( - ResolvedValueRangeConstraint::default().max(2560), - ) - .into(), - ), - ( - &HEIGHT, - ValueRangeConstraint::Constraint( - ResolvedValueRangeConstraint::default().max(1440), - ) - .into(), - ), - // Unsupported constraint, which should thus get ignored: - ( - &FRAME_RATE, - ValueRangeConstraint::Constraint( - ResolvedValueRangeConstraint::default().exact(30.0), - ) - .into(), - ), - // Ideal resize-mode: - ( - &RESIZE_MODE, - ValueConstraint::Bare(ResizeMode::none()).into(), - ), - ]), - advanced: AdvancedMediaTrackConstraints::from_iter([ - // The first advanced constraint set of "exact 800p" does not match - // any candidate and should thus get ignored by the algorithm: - MediaTrackConstraintSet::from_iter([( - &HEIGHT, - ValueRangeConstraint::Constraint( - ResolvedValueRangeConstraint::default().exact(800), - ) - .into(), - )]), - // The second advanced constraint set of "no resizing" does match - // candidates and should thus be applied by the algorithm: - MediaTrackConstraintSet::from_iter([( - &RESIZE_MODE, - ValueConstraint::Constraint( - ResolvedValueConstraint::default().exact(ResizeMode::none()), - ) - .into(), - )]), - ]), - }; - - // Resolve bare values to proper constraints: - let resolved_constraints = constraints.into_resolved(); - - // Sanitize constraints, removing empty and unsupported constraints: - let sanitized_constraints = resolved_constraints.to_sanitized(&supported_constraints); - - let actual = select_settings_candidates( - &possible_settings, - &sanitized_constraints, - DeviceInformationExposureMode::Exposed, - ) - .unwrap(); - - let expected = vec![&possible_settings[2], &possible_settings[3]]; - - assert_eq!(actual, expected); - } - - #[test] - fn macros() { - use crate::macros::*; - - let supported_constraints = MediaTrackSupportedConstraints::from_iter(vec![ - &DEVICE_ID, - &HEIGHT, - &WIDTH, - &RESIZE_MODE, - ]); - - let possible_settings = vec![ - settings![ - &DEVICE_ID => "480p", - &HEIGHT => 480, - &WIDTH => 720, - &RESIZE_MODE => ResizeMode::crop_and_scale(), - ], - settings![ - &DEVICE_ID => "720p", - &HEIGHT => 720, - &WIDTH => 1280, - &RESIZE_MODE => ResizeMode::crop_and_scale(), - ], - settings![ - &DEVICE_ID => "1080p", - &HEIGHT => 1080, - &WIDTH => 1920, - &RESIZE_MODE => ResizeMode::none(), - ], - settings![ - &DEVICE_ID => "1440p", - &HEIGHT => 1440, - &WIDTH => 2560, - &RESIZE_MODE => ResizeMode::none(), - ], - settings![ - &DEVICE_ID => "2160p", - &HEIGHT => 2160, - &WIDTH => 3840, - &RESIZE_MODE => ResizeMode::none(), - ], - ]; - - let constraints = constraints! 
{ - mandatory: { - &WIDTH => value_range_constraint!{ - max: 2560 - }, - &HEIGHT => value_range_constraint!{ - max: 1440 - }, - // Unsupported constraint, which should thus get ignored: - &FRAME_RATE => value_range_constraint!{ - exact: 30.0 - }, - }, - advanced: [ - // The first advanced constraint set of "exact 800p" does not match - // any candidate and should thus get ignored by the algorithm: - { - &HEIGHT => value_range_constraint!{ - exact: 800 - } - }, - // The second advanced constraint set of "no resizing" does match - // candidates and should thus be applied by the algorithm: - { - &RESIZE_MODE => value_constraint!{ - exact: ResizeMode::none() - } - }, - ] - }; - - // Resolve bare values to proper constraints: - let resolved_constraints = constraints.into_resolved(); - - // Sanitize constraints, removing empty and unsupported constraints: - let sanitized_constraints = resolved_constraints.to_sanitized(&supported_constraints); - - let actual = select_settings_candidates( - &possible_settings, - &sanitized_constraints, - DeviceInformationExposureMode::Exposed, - ) - .unwrap(); - - let expected = vec![&possible_settings[2], &possible_settings[3]]; - - assert_eq!(actual, expected); - } - - #[cfg(feature = "serde")] - #[test] - fn json() { - let supported_constraints = MediaTrackSupportedConstraints::from_iter(vec![ - &DEVICE_ID, - &HEIGHT, - &WIDTH, - &RESIZE_MODE, - ]); - - // Deserialize possible settings from JSON: - let possible_settings: Vec = { - let json = serde_json::json!([ - { "deviceId": "480p", "width": 720, "height": 480, "resizeMode": "crop-and-scale" }, - { "deviceId": "720p", "width": 1280, "height": 720, "resizeMode": "crop-and-scale" }, - { "deviceId": "1080p", "width": 1920, "height": 1080, "resizeMode": "none" }, - { "deviceId": "1440p", "width": 2560, "height": 1440, "resizeMode": "none" }, - { "deviceId": "2160p", "width": 3840, "height": 2160, "resizeMode": "none" }, - ]); - serde_json::from_value(json).unwrap() - }; - - // Deserialize constraints from JSON: - let constraints: MediaTrackConstraints = { - let json = serde_json::json!({ - "width": { - "max": 2560, - }, - "height": { - "max": 1440, - }, - // Unsupported constraint, which should thus get ignored: - "frameRate": { - "exact": 30.0 - }, - // Ideal resize-mode: - "resizeMode": "none", - "advanced": [ - // The first advanced constraint set of "exact 800p" does not match - // any candidate and should thus get ignored by the algorithm: - { "height": 800 }, - // The second advanced constraint set of "no resizing" does match - // candidates and should thus be applied by the algorithm: - { "resizeMode": "none" }, - ] - }); - serde_json::from_value(json).unwrap() - }; - - // Resolve bare values to proper constraints: - let resolved_constraints = constraints.into_resolved(); - - // Sanitize constraints, removing empty and unsupported constraints: - let sanitized_constraints = resolved_constraints.into_sanitized(&supported_constraints); - - let actual = select_settings_candidates( - &possible_settings, - &sanitized_constraints, - DeviceInformationExposureMode::Exposed, - ) - .unwrap(); - - let expected = vec![&possible_settings[2], &possible_settings[3]]; - - assert_eq!(actual, expected); - } -} diff --git a/constraints/src/algorithms/select_settings/tie_breaking.rs b/constraints/src/algorithms/select_settings/tie_breaking.rs deleted file mode 100644 index 52e5c2ac6..000000000 --- a/constraints/src/algorithms/select_settings/tie_breaking.rs +++ /dev/null @@ -1,187 +0,0 @@ -use std::iter::FromIterator; - -use 
ordered_float::NotNan; - -use crate::algorithms::FitnessDistance; -use crate::{ - MandatoryMediaTrackConstraints, MediaTrackSettings, MediaTrackSupportedConstraints, - SanitizedMandatoryMediaTrackConstraints, -}; - -/// A tie-breaking policy used for selecting a single preferred candidate -/// from a list of equally optimal setting candidates. -pub trait TieBreakingPolicy { - /// Selects a preferred candidate from a non-empty selection of optimal candidates. - /// - /// As specified in step 6 of the `SelectSettings` algorithm: - /// - /// - /// > Select one settings dictionary from candidates, and return it as the result - /// > of the SelectSettings algorithm. The User Agent MUST use one with the - /// > smallest fitness distance, as calculated in step 3. - /// > If more than one settings dictionary have the smallest fitness distance, - /// > the User Agent chooses one of them based on system default property values - /// > and User Agent default property values. - fn select_candidate<'a, I>(&self, candidates: I) -> &'a MediaTrackSettings - where - I: IntoIterator<Item = &'a MediaTrackSettings>; -} - -/// A naïve tie-breaking policy that just picks the first settings item it encounters. -pub struct FirstPolicy; - -impl FirstPolicy { - /// Creates a new policy. - pub fn new() -> Self { - Self - } -} - -impl Default for FirstPolicy { - fn default() -> Self { - Self::new() - } -} - -impl TieBreakingPolicy for FirstPolicy { - fn select_candidate<'a, I>(&self, candidates: I) -> &'a MediaTrackSettings - where - I: IntoIterator<Item = &'a MediaTrackSettings>, - { - // Safety: We know that `candidates` is non-empty: - candidates - .into_iter() - .next() - .expect("The `candidates` iterator should have produced at least one item.") - } -} - -/// A tie-breaking policy that picks the settings item that's closest to the specified ideal settings. -pub struct ClosestToIdealPolicy { - sanitized_constraints: SanitizedMandatoryMediaTrackConstraints, -} - -impl ClosestToIdealPolicy { - /// Creates a new policy from the given ideal settings and supported constraints.
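`ClosestToIdealPolicy` resolves such ties by measuring each candidate's fitness distance against a set of ideal settings. A short hypothetical sketch of its use follows; the import paths are assumptions, and the call pattern mirrors the tests further below:

```rust
// Hypothetical sketch: break a tie by picking the candidate closest to an ideal.
use std::iter::FromIterator;

use crate::algorithms::{ClosestToIdealPolicy, TieBreakingPolicy}; // assumed re-export path
use crate::property::all::name::*;
use crate::{MediaTrackSettings, MediaTrackSupportedConstraints};

fn pick_closest(candidates: Vec<&MediaTrackSettings>) -> &MediaTrackSettings {
    let supported_constraints =
        MediaTrackSupportedConstraints::from_iter(vec![&HEIGHT, &WIDTH]);

    // The settings we would ideally like to end up with:
    let ideal =
        MediaTrackSettings::from_iter([(&HEIGHT, 1080.into()), (&WIDTH, 1920.into())]);

    // Panics if `candidates` is empty, matching the trait's non-empty precondition:
    ClosestToIdealPolicy::new(ideal, &supported_constraints).select_candidate(candidates)
}
```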
- pub fn new( - ideal_settings: MediaTrackSettings, - supported_constraints: &MediaTrackSupportedConstraints, - ) -> Self { - let sanitized_constraints = MandatoryMediaTrackConstraints::from_iter( - ideal_settings - .into_iter() - .map(|(property, setting)| (property, setting.into())), - ) - .into_resolved() - .into_sanitized(supported_constraints); - - Self { - sanitized_constraints, - } - } -} - -impl TieBreakingPolicy for ClosestToIdealPolicy { - fn select_candidate<'b, I>(&self, candidates: I) -> &'b MediaTrackSettings - where - I: IntoIterator, - { - candidates - .into_iter() - .min_by_key(|settings| { - let fitness_distance = self - .sanitized_constraints - .fitness_distance(settings) - .expect("Fitness distance should be positive."); - NotNan::new(fitness_distance).expect("Expected non-NaN fitness distance.") - }) - .expect("The `candidates` iterator should have produced at least one item.") - } -} - -#[cfg(test)] -mod tests { - use std::iter::FromIterator; - - use super::*; - use crate::property::all::name::*; - use crate::{MediaTrackSettings, MediaTrackSupportedConstraints, ResizeMode}; - - #[test] - fn first() { - let settings = vec![ - MediaTrackSettings::from_iter([(&DEVICE_ID, "device-id-0".into())]), - MediaTrackSettings::from_iter([(&DEVICE_ID, "device-id-1".into())]), - MediaTrackSettings::from_iter([(&DEVICE_ID, "device-id-2".into())]), - ]; - - let policy = FirstPolicy; - - let actual = policy.select_candidate(&settings); - - let expected = &settings[0]; - - assert_eq!(actual, expected); - } - - #[test] - fn closest_to_ideal() { - let supported_constraints = MediaTrackSupportedConstraints::from_iter(vec![ - &DEVICE_ID, - &HEIGHT, - &WIDTH, - &RESIZE_MODE, - ]); - - let settings = vec![ - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "480p".into()), - (&HEIGHT, 480.into()), - (&WIDTH, 720.into()), - (&RESIZE_MODE, ResizeMode::crop_and_scale().into()), - ]), - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "720p".into()), - (&HEIGHT, 720.into()), - (&WIDTH, 1280.into()), - (&RESIZE_MODE, ResizeMode::crop_and_scale().into()), - ]), - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "1080p".into()), - (&HEIGHT, 1080.into()), - (&WIDTH, 1920.into()), - (&RESIZE_MODE, ResizeMode::none().into()), - ]), - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "1440p".into()), - (&HEIGHT, 1440.into()), - (&WIDTH, 2560.into()), - (&RESIZE_MODE, ResizeMode::none().into()), - ]), - MediaTrackSettings::from_iter([ - (&DEVICE_ID, "2160p".into()), - (&HEIGHT, 2160.into()), - (&WIDTH, 3840.into()), - (&RESIZE_MODE, ResizeMode::none().into()), - ]), - ]; - - let ideal_settings = vec![ - MediaTrackSettings::from_iter([(&HEIGHT, 450.into()), (&WIDTH, 700.into())]), - MediaTrackSettings::from_iter([(&HEIGHT, 700.into()), (&WIDTH, 1250.into())]), - MediaTrackSettings::from_iter([(&HEIGHT, 1000.into()), (&WIDTH, 2000.into())]), - MediaTrackSettings::from_iter([(&HEIGHT, 1500.into()), (&WIDTH, 2500.into())]), - MediaTrackSettings::from_iter([(&HEIGHT, 2000.into()), (&WIDTH, 3750.into())]), - ]; - - for (index, ideal) in ideal_settings.iter().enumerate() { - let policy = ClosestToIdealPolicy::new(ideal.clone(), &supported_constraints); - - let actual = policy.select_candidate(&settings); - - let expected = &settings[index]; - - assert_eq!(actual, expected); - } - } -} diff --git a/constraints/src/capabilities.rs b/constraints/src/capabilities.rs deleted file mode 100644 index b0b0cdc62..000000000 --- a/constraints/src/capabilities.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod stream; -mod track; - -pub(crate) 
use self::stream::MediaStreamCapabilities; -pub use self::track::MediaTrackCapabilities; diff --git a/constraints/src/capabilities/stream.rs b/constraints/src/capabilities/stream.rs deleted file mode 100644 index cc7200099..000000000 --- a/constraints/src/capabilities/stream.rs +++ /dev/null @@ -1,58 +0,0 @@ -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -use crate::MediaTrackCapabilities; - -/// The capabilities of a [`MediaStream`][media_stream] object. -/// -/// # W3C Spec Compliance -/// -/// There exists no corresponding type in the W3C ["Media Capture and Streams"][media_capture_and_streams_spec] spec. -/// -/// [media_stream]: https://www.w3.org/TR/mediacapture-streams/#dom-mediastream -/// [media_capture_and_streams_spec]: https://www.w3.org/TR/mediacapture-streams -#[derive(Default, Debug, Clone, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(rename_all = "camelCase"))] -pub(crate) struct MediaStreamCapabilities { - #[cfg_attr( - feature = "serde", - serde(skip_serializing_if = "core::option::Option::is_none") - )] - pub audio: Option<MediaTrackCapabilities>, - #[cfg_attr( - feature = "serde", - serde(skip_serializing_if = "core::option::Option::is_none") - )] - pub video: Option<MediaTrackCapabilities>, -} - -#[cfg(feature = "serde")] -#[cfg(test)] -mod serde_tests { - use super::*; - use crate::macros::test_serde_symmetry; - - type Subject = MediaStreamCapabilities; - - #[test] - fn default() { - let subject = Subject::default(); - let json = serde_json::json!({}); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn customized() { - let subject = Subject { - audio: Some(MediaTrackCapabilities::default()), - video: None, - }; - let json = serde_json::json!({ - "audio": {} - }); - - test_serde_symmetry!(subject: subject, json: json); - } -} diff --git a/constraints/src/capabilities/track.rs b/constraints/src/capabilities/track.rs deleted file mode 100644 index 6fb7f8ac8..000000000 --- a/constraints/src/capabilities/track.rs +++ /dev/null @@ -1,167 +0,0 @@ -use std::collections::HashMap; -use std::iter::FromIterator; -use std::ops::{Deref, DerefMut}; - -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -use crate::{MediaTrackCapability, MediaTrackProperty}; - -/// The capabilities of a [`MediaStreamTrack`][media_stream_track] object. -/// -/// # W3C Spec Compliance -/// -/// Corresponds to [`MediaTrackCapabilities`][media_track_capabilities] -/// from the W3C ["Media Capture and Streams"][media_capture_and_streams_spec] spec. -/// -/// The W3C spec defines `MediaTrackCapabilities` in terms of a dictionary, -/// which per the [WebIDL spec][webidl_spec] is an ordered map (e.g. `IndexMap`). -/// Since the spec, however, does not make use of the order of items -/// in the map, we use a simple `HashMap`. -/// -/// [media_stream_track]: https://www.w3.org/TR/mediacapture-streams/#dom-mediastreamtrack -/// [media_track_capabilities]: https://www.w3.org/TR/mediacapture-streams/#dom-mediatrackcapabilities -/// [media_capture_and_streams_spec]: https://www.w3.org/TR/mediacapture-streams -/// [webidl_spec]: https://webidl.spec.whatwg.org/#idl-dictionaries -#[derive(Debug, Clone, Default, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(transparent))] -pub struct MediaTrackCapabilities(HashMap<MediaTrackProperty, MediaTrackCapability>); - -impl MediaTrackCapabilities { - /// Creates a capabilities value from its inner hashmap.
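Since `MediaTrackCapabilities` is a thin wrapper around a `HashMap`, it can be built from `(property, capability)` pairs and queried through `Deref`. A minimal sketch follows, mirroring the conversions exercised by the tests below; the crate-root re-exports used in the imports are assumptions:

```rust
// Minimal sketch: building `MediaTrackCapabilities` from (property, capability) pairs.
use std::iter::FromIterator;

use crate::property::all::name::*;
use crate::MediaTrackCapabilities;

fn example() {
    let capabilities = MediaTrackCapabilities::from_iter([
        (&DEVICE_ID, "device-id".into()),   // string capability
        (&AUTO_GAIN_CONTROL, true.into()),  // boolean capability
        (&CHANNEL_COUNT, (12..=34).into()), // integer-range capability
        (&LATENCY, (1.2..=3.4).into()),     // float-range capability
    ]);

    // `MediaTrackCapabilities` dereferences to its inner `HashMap`:
    assert!(capabilities.contains_key(&DEVICE_ID));
}
```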
- pub fn new(capabilities: HashMap) -> Self { - Self(capabilities) - } - - /// Consumes the value, returning its inner hashmap. - pub fn into_inner(self) -> HashMap { - self.0 - } -} - -impl Deref for MediaTrackCapabilities { - type Target = HashMap; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for MediaTrackCapabilities { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -impl FromIterator<(T, MediaTrackCapability)> for MediaTrackCapabilities -where - T: Into, -{ - fn from_iter(iter: I) -> Self - where - I: IntoIterator, - { - Self::new(iter.into_iter().map(|(k, v)| (k.into(), v)).collect()) - } -} - -impl IntoIterator for MediaTrackCapabilities { - type Item = (MediaTrackProperty, MediaTrackCapability); - type IntoIter = std::collections::hash_map::IntoIter; - - fn into_iter(self) -> Self::IntoIter { - self.0.into_iter() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::property::all::name::*; - - type Subject = MediaTrackCapabilities; - - #[test] - fn into_inner() { - let hash_map = HashMap::from_iter([ - (DEVICE_ID.clone(), "device-id".into()), - (AUTO_GAIN_CONTROL.clone(), true.into()), - (CHANNEL_COUNT.clone(), (12..=34).into()), - (LATENCY.clone(), (1.2..=3.4).into()), - ]); - - let subject = Subject::new(hash_map.clone()); - - let actual = subject.into_inner(); - - let expected = hash_map; - - assert_eq!(actual, expected); - } - - #[test] - fn into_iter() { - let hash_map = HashMap::from_iter([ - (DEVICE_ID.clone(), "device-id".into()), - (AUTO_GAIN_CONTROL.clone(), true.into()), - (CHANNEL_COUNT.clone(), (12..=34).into()), - (LATENCY.clone(), (1.2..=3.4).into()), - ]); - - let subject = Subject::new(hash_map.clone()); - - let actual: HashMap<_, _> = subject.into_iter().collect(); - - let expected = hash_map; - - assert_eq!(actual, expected); - } - - #[test] - fn deref_and_deref_mut() { - let mut subject = Subject::default(); - - // Deref mut: - subject.insert(DEVICE_ID.clone(), "device-id".into()); - - // Deref: - assert!(subject.contains_key(&DEVICE_ID)); - } -} - -#[cfg(feature = "serde")] -#[cfg(test)] -mod serde_tests { - use super::*; - use crate::macros::test_serde_symmetry; - use crate::property::all::name::*; - - type Subject = MediaTrackCapabilities; - - #[test] - fn default() { - let subject = Subject::default(); - let json = serde_json::json!({}); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn customized() { - let subject = Subject::from_iter([ - (&DEVICE_ID, "device-id".into()), - (&AUTO_GAIN_CONTROL, true.into()), - (&CHANNEL_COUNT, (12..=34).into()), - (&LATENCY, (1.2..=3.4).into()), - ]); - let json = serde_json::json!({ - "deviceId": "device-id".to_owned(), - "autoGainControl": true, - "channelCount": { "min": 12, "max": 34 }, - "latency": { "min": 1.2, "max": 3.4 }, - }); - - test_serde_symmetry!(subject: subject, json: json); - } -} diff --git a/constraints/src/capability.rs b/constraints/src/capability.rs deleted file mode 100644 index 28621b643..000000000 --- a/constraints/src/capability.rs +++ /dev/null @@ -1,219 +0,0 @@ -mod value; -mod value_range; -mod value_sequence; - -use std::ops::RangeInclusive; - -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -pub use self::value::MediaTrackValueCapability; -pub use self::value_range::MediaTrackValueRangeCapability; -pub use self::value_sequence::MediaTrackValueSequenceCapability; - -/// A single [capability][media_track_capabilities] value of a [`MediaStreamTrack`][media_stream_track] object. 
-/// -/// # W3C Spec Compliance -/// -/// There exists no corresponding type in the W3C ["Media Capture and Streams"][media_capture_and_streams_spec] spec. -/// -/// [media_stream_track]: https://www.w3.org/TR/mediacapture-streams/#dom-mediastreamtrack -/// [media_track_capabilities]: https://www.w3.org/TR/mediacapture-streams/#dom-mediatrackcapabilities -/// [media_capture_and_streams_spec]: https://www.w3.org/TR/mediacapture-streams -#[derive(Debug, Clone, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(untagged))] -pub enum MediaTrackCapability { - // IMPORTANT: - // `BoolSequence` must be ordered before `Bool(…)` in order for - // `serde` to decode the correct variant. - /// A sequence of boolean-valued media track capabilities. - BoolSequence(MediaTrackValueSequenceCapability<bool>), - /// A single boolean-valued media track capability. - Bool(MediaTrackValueCapability<bool>), - // `IntegerRange` must be ordered before `FloatRange(…)` in order for - // `serde` to decode the correct variant. - /// A range of integer-valued media track capabilities. - IntegerRange(MediaTrackValueRangeCapability<u64>), - /// A range of floating-point-valued media track capabilities. - FloatRange(MediaTrackValueRangeCapability<f64>), - // IMPORTANT: - // `StringSequence` must be ordered before `String(…)` in order for - // `serde` to decode the correct variant. - /// A sequence of string-valued media track capabilities. - StringSequence(MediaTrackValueSequenceCapability<String>), - /// A single string-valued media track capability. - String(MediaTrackValueCapability<String>), -} - -impl From<bool> for MediaTrackCapability { - fn from(capability: bool) -> Self { - Self::Bool(capability.into()) - } -} - -impl From<Vec<bool>> for MediaTrackCapability { - fn from(capability: Vec<bool>) -> Self { - Self::BoolSequence(capability.into()) - } -} - -impl From<RangeInclusive<u64>> for MediaTrackCapability { - fn from(capability: RangeInclusive<u64>) -> Self { - Self::IntegerRange(capability.into()) - } -} - -impl From<RangeInclusive<f64>> for MediaTrackCapability { - fn from(capability: RangeInclusive<f64>) -> Self { - Self::FloatRange(capability.into()) - } -} - -impl From<String> for MediaTrackCapability { - fn from(capability: String) -> Self { - Self::String(capability.into()) - } -} - -impl<'a> From<&'a str> for MediaTrackCapability { - fn from(capability: &'a str) -> Self { - let capability: String = capability.to_owned(); - Self::from(capability) - } -} - -impl From<Vec<String>> for MediaTrackCapability { - fn from(capability: Vec<String>) -> Self { - Self::StringSequence(capability.into()) - } -} - -impl From<Vec<&str>> for MediaTrackCapability { - fn from(capability: Vec<&str>) -> Self { - let capability: Vec<String> = capability.into_iter().map(|c| c.to_owned()).collect(); - Self::from(capability) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - type Subject = MediaTrackCapability; - - mod from { - use super::*; - - #[test] - fn bool_sequence() { - let actual = Subject::from(vec![false, true]); - let expected = Subject::BoolSequence(vec![false, true].into()); - - assert_eq!(actual, expected); - } - - #[test] - fn bool() { - let actual = Subject::from(true); - let expected = Subject::Bool(true.into()); - - assert_eq!(actual, expected); - } - - #[test] - fn integer_range() { - let actual = Subject::from(12..=34); - let expected = Subject::IntegerRange((12..=34).into()); - - assert_eq!(actual, expected); - } - - #[test] - fn float() { - let actual = Subject::from(1.2..=3.4); - let expected = Subject::FloatRange((1.2..=3.4).into()); - - assert_eq!(actual, expected); - } - - #[test] -
fn string_sequence() { - let actual = Subject::from(vec!["foo".to_owned(), "bar".to_owned()]); - let expected = Subject::StringSequence(vec!["foo".to_owned(), "bar".to_owned()].into()); - - assert_eq!(actual, expected); - } - - #[test] - fn string() { - let actual = Subject::from("foo".to_owned()); - let expected = Subject::String("foo".to_owned().into()); - - assert_eq!(actual, expected); - } - } -} - -#[cfg(feature = "serde")] -#[cfg(test)] -mod serde_tests { - use super::*; - use crate::macros::test_serde_symmetry; - - type Subject = MediaTrackCapability; - - #[test] - fn bool_sequence() { - let subject = Subject::BoolSequence(vec![false, true].into()); - let json = serde_json::json!([false, true]); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn bool() { - let subject = Subject::Bool(true.into()); - let json = serde_json::json!(true); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn integer_range() { - let subject = Subject::IntegerRange((12..=34).into()); - let json = serde_json::json!({ - "min": 12, - "max": 34, - }); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn float() { - let subject = Subject::FloatRange((1.2..=3.4).into()); - let json = serde_json::json!({ - "min": 1.2, - "max": 3.4, - }); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn string_sequence() { - let subject = Subject::StringSequence(vec!["foo".to_owned(), "bar".to_owned()].into()); - let json = serde_json::json!(["foo", "bar"]); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn string() { - let subject = Subject::String("foo".to_owned().into()); - let json = serde_json::json!("foo"); - - test_serde_symmetry!(subject: subject, json: json); - } -} diff --git a/constraints/src/capability/value.rs b/constraints/src/capability/value.rs deleted file mode 100644 index a55c0a32e..000000000 --- a/constraints/src/capability/value.rs +++ /dev/null @@ -1,74 +0,0 @@ -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -/// A capability specifying a single supported value. 
-/// -/// # W3C Spec Compliance -/// -/// There exists no direct corresponding type in the -/// W3C ["Media Capture and Streams"][media_capture_and_streams_spec] spec, -/// since the `MediaTrackValueCapability` type aims to be a -/// generalization over multiple types in the W3C spec: -/// -/// | Rust | W3C | -/// | ----------------------------------- | ------------------------- | -/// | `MediaTrackValueCapability` | [`DOMString`][dom_string] | -/// -/// [dom_string]: https://webidl.spec.whatwg.org/#idl-DOMString -/// [media_capture_and_streams_spec]: https://www.w3.org/TR/mediacapture-streams/ -#[derive(Debug, Clone, Eq, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(transparent))] -pub struct MediaTrackValueCapability { - pub value: T, -} - -impl From for MediaTrackValueCapability { - fn from(value: T) -> Self { - Self { value } - } -} - -impl From<&str> for MediaTrackValueCapability { - fn from(value: &str) -> Self { - Self { - value: value.to_owned(), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - type Subject = MediaTrackValueCapability; - - #[test] - fn from_str() { - let subject = Subject::from("string"); - - let actual = subject.value.as_str(); - let expected = "string"; - - assert_eq!(actual, expected); - } -} - -#[cfg(feature = "serde")] -#[cfg(test)] -mod serde_tests { - use super::*; - use crate::macros::test_serde_symmetry; - - type Subject = MediaTrackValueCapability; - - #[test] - fn customized() { - let subject = Subject { - value: "string".to_owned(), - }; - let json = serde_json::json!("string"); - - test_serde_symmetry!(subject: subject, json: json); - } -} diff --git a/constraints/src/capability/value_range.rs b/constraints/src/capability/value_range.rs deleted file mode 100644 index 786b9b7c8..000000000 --- a/constraints/src/capability/value_range.rs +++ /dev/null @@ -1,207 +0,0 @@ -use std::ops::{RangeFrom, RangeInclusive, RangeToInclusive}; - -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -/// A capability specifying a range of supported values. 
-/// -/// # W3C Spec Compliance -/// -/// There exists no direct corresponding type in the -/// W3C ["Media Capture and Streams"][media_capture_and_streams_spec] spec, -/// since the `MediaTrackValueRangeCapability` type aims to be a -/// generalization over multiple types in the W3C spec: -/// -/// | Rust | W3C | -/// | ------------------------------------- | ----------------------------- | -/// | `MediaTrackValueRangeCapability` | [`ULongRange`][ulong_range] | -/// | `MediaTrackValueRangeCapability` | [`DoubleRange`][double_range] | -/// -/// [double_range]: https://www.w3.org/TR/mediacapture-streams/#dom-doublerange -/// [media_capture_and_streams_spec]: https://www.w3.org/TR/mediacapture-streams/ -/// [ulong_range]: https://www.w3.org/TR/mediacapture-streams/#dom-ulongrange -#[derive(Debug, Clone, Eq, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(rename_all = "camelCase"))] -pub struct MediaTrackValueRangeCapability { - #[cfg_attr( - feature = "serde", - serde(skip_serializing_if = "core::option::Option::is_none") - )] - pub min: Option, - #[cfg_attr( - feature = "serde", - serde(skip_serializing_if = "core::option::Option::is_none") - )] - pub max: Option, -} - -impl Default for MediaTrackValueRangeCapability { - fn default() -> Self { - Self { - min: Default::default(), - max: Default::default(), - } - } -} - -impl From> for MediaTrackValueRangeCapability { - fn from(range: RangeInclusive) -> Self { - let (min, max) = range.into_inner(); - Self { - min: Some(min), - max: Some(max), - } - } -} - -impl From> for MediaTrackValueRangeCapability { - fn from(range: RangeFrom) -> Self { - Self { - min: Some(range.start), - max: None, - } - } -} - -impl From> for MediaTrackValueRangeCapability { - fn from(range: RangeToInclusive) -> Self { - Self { - min: None, - max: Some(range.end), - } - } -} - -impl MediaTrackValueRangeCapability { - pub fn contains(&self, value: &T) -> bool - where - T: PartialOrd, - { - // FIXME(regexident): replace with if-let-chain, once stabilized: - // Tracking issue: https://github.com/rust-lang/rust/issues/53667 - if let Some(ref min) = self.min { - if min > value { - return false; - } - } - // FIXME(regexident): replace with if-let-chain, once stabilized: - // Tracking issue: https://github.com/rust-lang/rust/issues/53667 - if let Some(ref max) = self.max { - if max < value { - return false; - } - } - true - } -} - -#[cfg(test)] -mod tests { - use super::*; - - type Subject = MediaTrackValueRangeCapability; - - #[test] - fn default() { - let subject = Subject::default(); - - assert_eq!(subject.min, None); - assert_eq!(subject.max, None); - } - - mod from { - use super::*; - - #[test] - fn range_inclusive() { - let subject = Subject::from(1..=5); - - assert_eq!(subject.min, Some(1)); - assert_eq!(subject.max, Some(5)); - } - - #[test] - fn range_from() { - let subject = Subject::from(1..); - - assert_eq!(subject.min, Some(1)); - assert_eq!(subject.max, None); - } - - #[test] - fn range_to_inclusive() { - let subject = Subject::from(..=5); - - assert_eq!(subject.min, None); - assert_eq!(subject.max, Some(5)); - } - } - - mod contains { - use super::*; - - #[test] - fn default() { - let subject = Subject::default(); - - assert!(subject.contains(&0)); - assert!(subject.contains(&1)); - assert!(subject.contains(&5)); - assert!(subject.contains(&6)); - } - - #[test] - fn from_range_inclusive() { - let subject = Subject::from(1..=5); - - assert!(!subject.contains(&0)); - assert!(subject.contains(&1)); - 
assert!(subject.contains(&5)); - assert!(!subject.contains(&6)); - } - - #[test] - fn from_range_from() { - let subject = Subject::from(1..); - - assert!(!subject.contains(&0)); - assert!(subject.contains(&1)); - assert!(subject.contains(&5)); - assert!(subject.contains(&6)); - } - - #[test] - fn from_range_to_inclusive() { - let subject = Subject::from(..=5); - - assert!(subject.contains(&0)); - assert!(subject.contains(&1)); - assert!(subject.contains(&5)); - assert!(!subject.contains(&6)); - } - } -} - -#[cfg(feature = "serde")] -#[cfg(test)] -mod serde_tests { - use super::*; - use crate::macros::test_serde_symmetry; - - type Subject = MediaTrackValueRangeCapability; - - #[test] - fn customized() { - let subject = Subject { - min: Some(12), - max: Some(34), - }; - let json = serde_json::json!({ - "min": 12, - "max": 34, - }); - - test_serde_symmetry!(subject: subject, json: json); - } -} diff --git a/constraints/src/capability/value_sequence.rs b/constraints/src/capability/value_sequence.rs deleted file mode 100644 index 0a090c499..000000000 --- a/constraints/src/capability/value_sequence.rs +++ /dev/null @@ -1,99 +0,0 @@ -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -/// A capability specifying a range of supported values. -/// -/// # W3C Spec Compliance -/// -/// Corresponds to [`sequence`][sequence] from the W3C ["WebIDL"][webidl_spec] spec: -/// -/// | Rust | W3C | -/// | ----------------------------------------- | --------------------- | -/// | `MediaTrackValueSequenceCapability` | `sequence` | -/// | `MediaTrackValueSequenceCapability` | `sequence` | -/// -/// [sequence]: https://webidl.spec.whatwg.org/#idl-sequence -/// [webidl_spec]: https://webidl.spec.whatwg.org/ -#[derive(Default, Debug, Clone, Eq, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(transparent))] -pub struct MediaTrackValueSequenceCapability { - pub values: Vec, -} - -impl From for MediaTrackValueSequenceCapability { - fn from(value: T) -> Self { - Self { - values: vec![value], - } - } -} - -impl From> for MediaTrackValueSequenceCapability { - fn from(values: Vec) -> Self { - Self { values } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - type Subject = MediaTrackValueSequenceCapability; - - #[test] - fn default() { - let subject = Subject::default(); - - let actual = subject.values; - - let expected: Vec = vec![]; - - assert_eq!(actual, expected); - } - - mod from { - use super::*; - - #[test] - fn value() { - let subject = Subject::from("foo".to_owned()); - - let actual = subject.values; - - let expected: Vec = vec!["foo".to_owned()]; - - assert_eq!(actual, expected); - } - - #[test] - fn values() { - let subject = Subject::from(vec!["foo".to_owned(), "bar".to_owned()]); - - let actual = subject.values; - - let expected: Vec = vec!["foo".to_owned(), "bar".to_owned()]; - - assert_eq!(actual, expected); - } - } -} - -#[cfg(feature = "serde")] -#[cfg(test)] -mod serde_tests { - use super::*; - use crate::macros::test_serde_symmetry; - - type Subject = MediaTrackValueSequenceCapability; - - #[test] - fn customized() { - let subject = Subject { - values: vec!["foo".to_owned(), "bar".to_owned()], - }; - let json = serde_json::json!(["foo", "bar"]); - - test_serde_symmetry!(subject: subject, json: json); - } -} diff --git a/constraints/src/constraint.rs b/constraints/src/constraint.rs deleted file mode 100644 index 2ad4f10b4..000000000 --- a/constraints/src/constraint.rs +++ /dev/null @@ -1,789 +0,0 @@ -use std::ops::Deref; 
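// A rough usage sketch for `MediaTrackValueSequenceCapability` (defined above),
// assuming the `serde` feature is enabled; because of `#[serde(transparent)]`
// the capability serializes as a plain JSON array:
//
//     let cap = MediaTrackValueSequenceCapability::from(vec!["user".to_owned(), "environment".to_owned()]);
//     assert_eq!(serde_json::to_value(&cap).unwrap(), serde_json::json!(["user", "environment"]));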
- -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -pub use self::value::{ResolvedValueConstraint, ValueConstraint}; -pub use self::value_range::{ResolvedValueRangeConstraint, ValueRangeConstraint}; -pub use self::value_sequence::{ResolvedValueSequenceConstraint, ValueSequenceConstraint}; -use crate::MediaTrackSetting; - -mod value; -mod value_range; -mod value_sequence; - -/// An empty [constraint][media_track_constraints] value for a [`MediaStreamTrack`][media_stream_track] object. -/// -/// # W3C Spec Compliance -/// -/// There exists no corresponding type in the W3C ["Media Capture and Streams"][media_capture_and_streams_spec] spec. -/// -/// The purpose of this type is to reduce parsing ambiguity, since all constraint variant types -/// support serializing from an empty map, but an empty map isn't typed, really, -/// so parsing to a specifically typed constraint would be wrong, type-wise. -/// -/// [media_stream_track]: https://www.w3.org/TR/mediacapture-streams/#dom-mediastreamtrack -/// [media_track_constraints]: https://www.w3.org/TR/mediacapture-streams/#dom-mediatrackconstraints -/// [media_capture_and_streams_spec]: https://www.w3.org/TR/mediacapture-streams -#[derive(Debug, Clone, Eq, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(deny_unknown_fields))] -pub struct EmptyConstraint {} - -/// The strategy of a track [constraint][constraint]. -/// -/// [constraint]: https://www.w3.org/TR/mediacapture-streams/#dfn-constraint -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] -pub enum MediaTrackConstraintResolutionStrategy { - /// Resolve bare values to `ideal` constraints. - BareToIdeal, - /// Resolve bare values to `exact` constraints. - BareToExact, -} - -/// A single [constraint][media_track_constraints] value for a [`MediaStreamTrack`][media_stream_track] object. -/// -/// # W3C Spec Compliance -/// -/// There exists no corresponding type in the W3C ["Media Capture and Streams"][media_capture_and_streams_spec] spec. -/// -/// [media_stream_track]: https://www.w3.org/TR/mediacapture-streams/#dom-mediastreamtrack -/// [media_track_constraints]: https://www.w3.org/TR/mediacapture-streams/#dom-mediatrackconstraints -/// [media_capture_and_streams_spec]: https://www.w3.org/TR/mediacapture-streams -#[derive(Debug, Clone, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(untagged))] -pub enum MediaTrackConstraint { - /// An empty constraint. - Empty(EmptyConstraint), - // `IntegerRange` must be ordered before `FloatRange(โ€ฆ)` in order for - // `serde` to decode the correct variant. - /// An integer-valued media track range constraint. - IntegerRange(ValueRangeConstraint), - /// An floating-point-valued media track range constraint. - FloatRange(ValueRangeConstraint), - // `Bool` must be ordered after `IntegerRange(โ€ฆ)`/`FloatRange(โ€ฆ)` in order for - // `serde` to decode the correct variant. - /// A single boolean-valued media track constraint. - Bool(ValueConstraint), - // `StringSequence` must be ordered before `String(โ€ฆ)` in order for - // `serde` to decode the correct variant. - /// A sequence of string-valued media track constraints. - StringSequence(ValueSequenceConstraint), - /// A single string-valued media track constraint. 
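// A minimal sketch of how the variant order above drives `serde(untagged)`
// decoding (assuming the `serde` feature and `serde_json`):
//
//     let c: MediaTrackConstraint = serde_json::from_value(serde_json::json!(42)).unwrap();    // IntegerRange
//     let c: MediaTrackConstraint = serde_json::from_value(serde_json::json!(4.2)).unwrap();   // FloatRange
//     let c: MediaTrackConstraint = serde_json::from_value(serde_json::json!(true)).unwrap();  // Bool
//     let c: MediaTrackConstraint = serde_json::from_value(serde_json::json!(["a"])).unwrap(); // StringSequence
//     let c: MediaTrackConstraint = serde_json::from_value(serde_json::json!("a")).unwrap();   // String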
- String(ValueConstraint), -} - -impl Default for MediaTrackConstraint { - fn default() -> Self { - Self::Empty(EmptyConstraint {}) - } -} - -// Bool constraint: - -impl From for MediaTrackConstraint { - fn from(bare: bool) -> Self { - Self::Bool(bare.into()) - } -} - -impl From> for MediaTrackConstraint { - fn from(constraint: ResolvedValueConstraint) -> Self { - Self::Bool(constraint.into()) - } -} - -impl From> for MediaTrackConstraint { - fn from(constraint: ValueConstraint) -> Self { - Self::Bool(constraint) - } -} - -// Unsigned integer range constraint: - -impl From for MediaTrackConstraint { - fn from(bare: u64) -> Self { - Self::IntegerRange(bare.into()) - } -} - -impl From> for MediaTrackConstraint { - fn from(constraint: ResolvedValueRangeConstraint) -> Self { - Self::IntegerRange(constraint.into()) - } -} - -impl From> for MediaTrackConstraint { - fn from(constraint: ValueRangeConstraint) -> Self { - Self::IntegerRange(constraint) - } -} - -// Floating-point range constraint: - -impl From for MediaTrackConstraint { - fn from(bare: f64) -> Self { - Self::FloatRange(bare.into()) - } -} - -impl From> for MediaTrackConstraint { - fn from(constraint: ResolvedValueRangeConstraint) -> Self { - Self::FloatRange(constraint.into()) - } -} - -impl From> for MediaTrackConstraint { - fn from(constraint: ValueRangeConstraint) -> Self { - Self::FloatRange(constraint) - } -} - -// String sequence constraint: - -impl From> for MediaTrackConstraint { - fn from(bare: Vec) -> Self { - Self::StringSequence(bare.into()) - } -} - -impl From> for MediaTrackConstraint { - fn from(bare: Vec<&str>) -> Self { - let bare: Vec = bare.into_iter().map(|c| c.to_owned()).collect(); - Self::from(bare) - } -} - -impl From> for MediaTrackConstraint { - fn from(constraint: ResolvedValueSequenceConstraint) -> Self { - Self::StringSequence(constraint.into()) - } -} - -impl From> for MediaTrackConstraint { - fn from(constraint: ValueSequenceConstraint) -> Self { - Self::StringSequence(constraint) - } -} - -// String constraint: - -impl From for MediaTrackConstraint { - fn from(bare: String) -> Self { - Self::String(bare.into()) - } -} - -impl<'a> From<&'a str> for MediaTrackConstraint { - fn from(bare: &'a str) -> Self { - let bare: String = bare.to_owned(); - Self::from(bare) - } -} - -impl From> for MediaTrackConstraint { - fn from(constraint: ResolvedValueConstraint) -> Self { - Self::String(constraint.into()) - } -} - -impl From> for MediaTrackConstraint { - fn from(constraint: ValueConstraint) -> Self { - Self::String(constraint) - } -} - -// Conversion from settings: - -impl From for MediaTrackConstraint { - fn from(settings: MediaTrackSetting) -> Self { - match settings { - MediaTrackSetting::Bool(value) => Self::Bool(value.into()), - MediaTrackSetting::Integer(value) => { - Self::IntegerRange((value.clamp(0, i64::MAX) as u64).into()) - } - MediaTrackSetting::Float(value) => Self::FloatRange(value.into()), - MediaTrackSetting::String(value) => Self::String(value.into()), - } - } -} - -impl MediaTrackConstraint { - /// Returns `true` if `self` is empty, otherwise `false`. 
- pub fn is_empty(&self) -> bool { - match self { - Self::Empty(_) => true, - Self::IntegerRange(constraint) => constraint.is_empty(), - Self::FloatRange(constraint) => constraint.is_empty(), - Self::Bool(constraint) => constraint.is_empty(), - Self::StringSequence(constraint) => constraint.is_empty(), - Self::String(constraint) => constraint.is_empty(), - } - } - - /// Returns a resolved representation of the constraint - /// with bare values resolved to fully-qualified constraints. - pub fn to_resolved( - &self, - strategy: MediaTrackConstraintResolutionStrategy, - ) -> ResolvedMediaTrackConstraint { - self.clone().into_resolved(strategy) - } - - /// Consumes the constraint, returning a resolved representation of the - /// constraint with bare values resolved to fully-qualified constraints. - pub fn into_resolved( - self, - strategy: MediaTrackConstraintResolutionStrategy, - ) -> ResolvedMediaTrackConstraint { - match self { - Self::Empty(constraint) => ResolvedMediaTrackConstraint::Empty(constraint), - Self::IntegerRange(constraint) => { - ResolvedMediaTrackConstraint::IntegerRange(constraint.into_resolved(strategy)) - } - Self::FloatRange(constraint) => { - ResolvedMediaTrackConstraint::FloatRange(constraint.into_resolved(strategy)) - } - Self::Bool(constraint) => { - ResolvedMediaTrackConstraint::Bool(constraint.into_resolved(strategy)) - } - Self::StringSequence(constraint) => { - ResolvedMediaTrackConstraint::StringSequence(constraint.into_resolved(strategy)) - } - Self::String(constraint) => { - ResolvedMediaTrackConstraint::String(constraint.into_resolved(strategy)) - } - } - } -} - -/// A single [constraint][media_track_constraints] value for a [`MediaStreamTrack`][media_stream_track] object -/// with its potential bare value either resolved to an `exact` or `ideal` constraint. -/// -/// # W3C Spec Compliance -/// -/// There exists no corresponding type in the W3C ["Media Capture and Streams"][media_capture_and_streams_spec] spec. -/// -/// [media_stream_track]: https://www.w3.org/TR/mediacapture-streams/#dom-mediastreamtrack -/// [media_track_constraints]: https://www.w3.org/TR/mediacapture-streams/#dom-mediatrackconstraints -/// [media_capture_and_streams_spec]: https://www.w3.org/TR/mediacapture-streams -#[derive(Debug, Clone, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(untagged))] -pub enum ResolvedMediaTrackConstraint { - /// An empty constraint. - Empty(EmptyConstraint), - /// An integer-valued media track range constraint. - IntegerRange(ResolvedValueRangeConstraint), - /// An floating-point-valued media track range constraint. - FloatRange(ResolvedValueRangeConstraint), - /// A single boolean-valued media track constraint. - Bool(ResolvedValueConstraint), - /// A sequence of string-valued media track constraints. - StringSequence(ResolvedValueSequenceConstraint), - /// A single string-valued media track constraint. 
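// A minimal sketch of resolving a bare value with the two strategies above:
//
//     use MediaTrackConstraintResolutionStrategy::{BareToExact, BareToIdeal};
//
//     let bare = MediaTrackConstraint::from(42_u64);  // bare `42`
//     let exact = bare.to_resolved(BareToExact);      // behaves like `{ "exact": 42 }`
//     let ideal = bare.into_resolved(BareToIdeal);    // behaves like `{ "ideal": 42 }`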
- String(ResolvedValueConstraint), -} - -impl Default for ResolvedMediaTrackConstraint { - fn default() -> Self { - Self::Empty(EmptyConstraint {}) - } -} - -impl std::fmt::Display for ResolvedMediaTrackConstraint { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Empty(_constraint) => "".fmt(f), - Self::IntegerRange(constraint) => constraint.fmt(f), - Self::FloatRange(constraint) => constraint.fmt(f), - Self::Bool(constraint) => constraint.fmt(f), - Self::StringSequence(constraint) => constraint.fmt(f), - Self::String(constraint) => constraint.fmt(f), - } - } -} - -// Bool constraint: - -impl From> for ResolvedMediaTrackConstraint { - fn from(constraint: ResolvedValueConstraint) -> Self { - Self::Bool(constraint) - } -} - -// Unsigned integer range constraint: - -impl From> for ResolvedMediaTrackConstraint { - fn from(constraint: ResolvedValueRangeConstraint) -> Self { - Self::IntegerRange(constraint) - } -} - -// Floating-point range constraint: - -impl From> for ResolvedMediaTrackConstraint { - fn from(constraint: ResolvedValueRangeConstraint) -> Self { - Self::FloatRange(constraint) - } -} - -// String sequence constraint: - -impl From> for ResolvedMediaTrackConstraint { - fn from(constraint: ResolvedValueSequenceConstraint) -> Self { - Self::StringSequence(constraint) - } -} - -// String constraint: - -impl From> for ResolvedMediaTrackConstraint { - fn from(constraint: ResolvedValueConstraint) -> Self { - Self::String(constraint) - } -} - -impl ResolvedMediaTrackConstraint { - /// Creates a resolved media track constraint by resolving - /// bare values to exact constraints: `{ exact: bare }`. - pub fn exact_from(setting: MediaTrackSetting) -> Self { - MediaTrackConstraint::from(setting) - .into_resolved(MediaTrackConstraintResolutionStrategy::BareToExact) - } - - /// Creates a resolved media track constraint by resolving - /// bare values to ideal constraints: `{ ideal: bare }`. - pub fn ideal_from(setting: MediaTrackSetting) -> Self { - MediaTrackConstraint::from(setting) - .into_resolved(MediaTrackConstraintResolutionStrategy::BareToIdeal) - } - - /// Returns `true` if `self` is required, otherwise `false`. - pub fn is_required(&self) -> bool { - match self { - Self::Empty(_constraint) => false, - Self::IntegerRange(constraint) => constraint.is_required(), - Self::FloatRange(constraint) => constraint.is_required(), - Self::Bool(constraint) => constraint.is_required(), - Self::StringSequence(constraint) => constraint.is_required(), - Self::String(constraint) => constraint.is_required(), - } - } - - /// Returns `true` if `self` is empty, otherwise `false`. - pub fn is_empty(&self) -> bool { - match self { - Self::Empty(_constraint) => true, - Self::IntegerRange(constraint) => constraint.is_empty(), - Self::FloatRange(constraint) => constraint.is_empty(), - Self::Bool(constraint) => constraint.is_empty(), - Self::StringSequence(constraint) => constraint.is_empty(), - Self::String(constraint) => constraint.is_empty(), - } - } - - /// Returns a corresponding constraint containing only required values. - pub fn to_required_only(&self) -> Self { - self.clone().into_required_only() - } - - /// Consumes `self, returning a corresponding constraint - /// containing only required values. 
- pub fn into_required_only(self) -> Self { - match self { - Self::Empty(constraint) => Self::Empty(constraint), - Self::IntegerRange(constraint) => Self::IntegerRange(constraint.into_required_only()), - Self::FloatRange(constraint) => Self::FloatRange(constraint.into_required_only()), - Self::Bool(constraint) => Self::Bool(constraint.into_required_only()), - Self::StringSequence(constraint) => { - Self::StringSequence(constraint.into_required_only()) - } - Self::String(constraint) => Self::String(constraint.into_required_only()), - } - } - - /// Returns a corresponding sanitized constraint - /// if `self` is non-empty, otherwise `None`. - pub fn to_sanitized(&self) -> Option { - self.clone().into_sanitized() - } - - /// Consumes `self`, returning a corresponding sanitized constraint - /// if `self` is non-empty, otherwise `None`. - pub fn into_sanitized(self) -> Option { - if self.is_empty() { - return None; - } - - Some(SanitizedMediaTrackConstraint(self)) - } -} - -/// A single non-empty [constraint][media_track_constraints] value for a [`MediaStreamTrack`][media_stream_track] object. -/// -/// # Invariant -/// -/// The wrapped `ResolvedMediaTrackConstraint` MUST not be empty. -/// -/// To enforce this invariant the only way to create an instance of this type -/// is by calling `constraint.to_sanitized()`/`constraint.into_sanitized()` on -/// an instance of `ResolvedMediaTrackConstraint`, which returns `None` if `self` is empty. -/// -/// Further more `self.0` MUST NOT be exposed mutably, -/// as otherwise it could become empty via mutation. -/// -/// [media_stream_track]: https://www.w3.org/TR/mediacapture-streams/#dom-mediastreamtrack -/// [media_track_constraints]: https://www.w3.org/TR/mediacapture-streams/#dom-mediatrackconstraints -#[derive(Debug, Clone, PartialEq)] -pub struct SanitizedMediaTrackConstraint(ResolvedMediaTrackConstraint); - -impl Deref for SanitizedMediaTrackConstraint { - type Target = ResolvedMediaTrackConstraint; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl SanitizedMediaTrackConstraint { - /// Consumes `self` returning its inner resolved constraint. 
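// A minimal sketch of the sanitization invariant described above:
//
//     let empty = ResolvedMediaTrackConstraint::Empty(EmptyConstraint {});
//     assert!(empty.into_sanitized().is_none());  // empty constraints are filtered out
//
//     let exact = ResolvedMediaTrackConstraint::exact_from(MediaTrackSetting::Integer(42));
//     let sanitized = exact.into_sanitized().unwrap();  // guaranteed to be non-empty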
- pub fn into_inner(self) -> ResolvedMediaTrackConstraint { - self.0 - } -} - -#[cfg(test)] -mod tests { - use MediaTrackConstraintResolutionStrategy::*; - - use super::*; - - type Subject = MediaTrackConstraint; - - #[test] - fn default() { - let subject = Subject::default(); - - let actual = subject.is_empty(); - let expected = true; - - assert_eq!(actual, expected); - } - - mod from { - - use super::*; - - #[test] - fn setting() { - use crate::MediaTrackSetting; - - assert!(matches!( - Subject::from(MediaTrackSetting::Bool(true)), - Subject::Bool(ValueConstraint::Bare(_)) - )); - assert!(matches!( - Subject::from(MediaTrackSetting::Integer(42)), - Subject::IntegerRange(ValueRangeConstraint::Bare(_)) - )); - assert!(matches!( - Subject::from(MediaTrackSetting::Float(4.2)), - Subject::FloatRange(ValueRangeConstraint::Bare(_)) - )); - assert!(matches!( - Subject::from(MediaTrackSetting::String("string".to_owned())), - Subject::String(ValueConstraint::Bare(_)) - )); - } - - #[test] - fn bool() { - let subjects = [ - Subject::from(false), - Subject::from(ValueConstraint::::default()), - Subject::from(ResolvedValueConstraint::::default()), - ]; - - for subject in subjects { - // TODO: replace with `assert_matches!(โ€ฆ)`, once stabilized: - // Tracking issue: https://github.com/rust-lang/rust/issues/82775 - assert!(matches!(subject, Subject::Bool(_))); - } - } - - #[test] - fn integer_range() { - let subjects = [ - Subject::from(42_u64), - Subject::from(ValueRangeConstraint::::default()), - Subject::from(ResolvedValueRangeConstraint::::default()), - ]; - - for subject in subjects { - // TODO: replace with `assert_matches!(โ€ฆ)`, once stabilized: - // Tracking issue: https://github.com/rust-lang/rust/issues/82775 - assert!(matches!(subject, Subject::IntegerRange(_))); - } - } - - #[test] - fn float_range() { - let subjects = [ - Subject::from(42.0_f64), - Subject::from(ValueRangeConstraint::::default()), - Subject::from(ResolvedValueRangeConstraint::::default()), - ]; - - for subject in subjects { - // TODO: replace with `assert_matches!(โ€ฆ)`, once stabilized: - // Tracking issue: https://github.com/rust-lang/rust/issues/82775 - assert!(matches!(subject, Subject::FloatRange(_))); - } - } - - #[test] - fn string() { - let subjects = [ - Subject::from(""), - Subject::from(String::new()), - Subject::from(ValueConstraint::::default()), - Subject::from(ResolvedValueConstraint::::default()), - ]; - - for subject in subjects { - // TODO: replace with `assert_matches!(โ€ฆ)`, once stabilized: - // Tracking issue: https://github.com/rust-lang/rust/issues/82775 - assert!(matches!(subject, Subject::String(_))); - } - } - - #[test] - fn string_sequence() { - let subjects = [ - Subject::from(vec![""]), - Subject::from(vec![String::new()]), - Subject::from(ValueSequenceConstraint::::default()), - Subject::from(ResolvedValueSequenceConstraint::::default()), - ]; - - for subject in subjects { - // TODO: replace with `assert_matches!(โ€ฆ)`, once stabilized: - // Tracking issue: https://github.com/rust-lang/rust/issues/82775 - assert!(matches!(subject, Subject::StringSequence(_))); - } - } - } - - #[test] - fn is_empty() { - let empty_subject = Subject::Empty(EmptyConstraint {}); - - assert!(empty_subject.is_empty()); - - let non_empty_subjects = [ - Subject::Bool(ValueConstraint::Bare(true)), - Subject::FloatRange(ValueRangeConstraint::Bare(42.0)), - Subject::IntegerRange(ValueRangeConstraint::Bare(42)), - Subject::String(ValueConstraint::Bare("string".to_owned())), - 
Subject::StringSequence(ValueSequenceConstraint::Bare(vec!["string".to_owned()])), - ]; - - for non_empty_subject in non_empty_subjects { - assert!(!non_empty_subject.is_empty()); - } - } - - #[test] - fn to_resolved() { - let subjects = [ - ( - Subject::Empty(EmptyConstraint {}), - ResolvedMediaTrackConstraint::Empty(EmptyConstraint {}), - ), - ( - Subject::Bool(ValueConstraint::Bare(true)), - ResolvedMediaTrackConstraint::Bool(ResolvedValueConstraint::default().exact(true)), - ), - ( - Subject::FloatRange(ValueRangeConstraint::Bare(42.0)), - ResolvedMediaTrackConstraint::FloatRange( - ResolvedValueRangeConstraint::default().exact(42.0), - ), - ), - ( - Subject::IntegerRange(ValueRangeConstraint::Bare(42)), - ResolvedMediaTrackConstraint::IntegerRange( - ResolvedValueRangeConstraint::default().exact(42), - ), - ), - ( - Subject::String(ValueConstraint::Bare("string".to_owned())), - ResolvedMediaTrackConstraint::String( - ResolvedValueConstraint::default().exact("string".to_owned()), - ), - ), - ( - Subject::StringSequence(ValueSequenceConstraint::Bare(vec!["string".to_owned()])), - ResolvedMediaTrackConstraint::StringSequence( - ResolvedValueSequenceConstraint::default().exact(vec!["string".to_owned()]), - ), - ), - ]; - - for (subject, expected) in subjects { - let actual = subject.to_resolved(BareToExact); - - assert_eq!(actual, expected); - } - } - - mod resolved { - use super::*; - - type Subject = ResolvedMediaTrackConstraint; - - #[test] - fn to_string() { - let scenarios = [ - (Subject::Empty(EmptyConstraint {}), ""), - ( - Subject::Bool(ResolvedValueConstraint::default().exact(true)), - "(x == true)", - ), - ( - Subject::FloatRange(ResolvedValueRangeConstraint::default().exact(42.0)), - "(x == 42.0)", - ), - ( - Subject::IntegerRange(ResolvedValueRangeConstraint::default().exact(42)), - "(x == 42)", - ), - ( - Subject::String(ResolvedValueConstraint::default().exact("string".to_owned())), - "(x == \"string\")", - ), - ( - Subject::StringSequence( - ResolvedValueSequenceConstraint::default().exact(vec!["string".to_owned()]), - ), - "(x == [\"string\"])", - ), - ]; - - for (subject, expected) in scenarios { - let actual = subject.to_string(); - - assert_eq!(actual, expected); - } - } - } -} - -#[cfg(feature = "serde")] -#[cfg(test)] -mod serde_tests { - use super::*; - use crate::macros::test_serde_symmetry; - - type Subject = MediaTrackConstraint; - - #[test] - fn empty() { - let subject = Subject::Empty(EmptyConstraint {}); - let json = serde_json::json!({}); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn bool_bare() { - let subject = Subject::Bool(true.into()); - let json = serde_json::json!(true); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn bool_constraint() { - let subject = Subject::Bool(ResolvedValueConstraint::default().exact(true).into()); - let json = serde_json::json!({ "exact": true }); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn integer_range_bare() { - let subject = Subject::IntegerRange(42.into()); - let json = serde_json::json!(42); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn integer_range_constraint() { - let subject = - Subject::IntegerRange(ResolvedValueRangeConstraint::default().exact(42).into()); - let json = serde_json::json!({ "exact": 42 }); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn float_range_bare() { - let subject = Subject::FloatRange(4.2.into()); - let json = serde_json::json!(4.2); - - 
test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn float_range_constraint() { - let subject = - Subject::FloatRange(ResolvedValueRangeConstraint::default().exact(42.0).into()); - let json = serde_json::json!({ "exact": 42.0 }); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn string_sequence_bare() { - let subject = Subject::StringSequence(vec!["foo".to_owned(), "bar".to_owned()].into()); - let json = serde_json::json!(["foo", "bar"]); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn string_sequence_constraint() { - let subject = Subject::StringSequence( - ResolvedValueSequenceConstraint::default() - .exact(vec!["foo".to_owned(), "bar".to_owned()]) - .into(), - ); - let json = serde_json::json!({ "exact": ["foo", "bar"] }); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn string_bare() { - let subject = Subject::String("foo".to_owned().into()); - let json = serde_json::json!("foo"); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn string_constraint() { - let subject = Subject::String( - ResolvedValueConstraint::default() - .exact("foo".to_owned()) - .into(), - ); - let json = serde_json::json!({ "exact": "foo" }); - - test_serde_symmetry!(subject: subject, json: json); - } -} diff --git a/constraints/src/constraint/value.rs b/constraints/src/constraint/value.rs deleted file mode 100644 index 600074b13..000000000 --- a/constraints/src/constraint/value.rs +++ /dev/null @@ -1,419 +0,0 @@ -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -use crate::MediaTrackConstraintResolutionStrategy; - -/// A bare value or constraint specifying a single accepted value. -/// -/// # W3C Spec Compliance -/// -/// There exists no direct corresponding type in the -/// W3C ["Media Capture and Streams"][media_capture_and_streams_spec] spec, -/// since the `ValueConstraint` type aims to be a generalization over -/// multiple types in the spec. -/// -/// | Rust | W3C | -/// | ------------------------------ | --------------------------------------- | -/// | `ValueConstraint` | [`ConstrainBoolean`][constrain_boolean] | -/// -/// [constrain_boolean]: https://www.w3.org/TR/mediacapture-streams/#dom-constrainboolean -/// [media_capture_and_streams_spec]: https://www.w3.org/TR/mediacapture-streams/ -#[derive(Debug, Clone, Eq, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(untagged))] -pub enum ValueConstraint { - /// A bare-valued media track constraint. - Bare(T), - /// A fully-qualified media track constraint. - Constraint(ResolvedValueConstraint), -} - -impl Default for ValueConstraint { - fn default() -> Self { - Self::Constraint(Default::default()) - } -} - -impl From for ValueConstraint { - fn from(bare: T) -> Self { - Self::Bare(bare) - } -} - -impl From> for ValueConstraint { - fn from(constraint: ResolvedValueConstraint) -> Self { - Self::Constraint(constraint) - } -} - -impl ValueConstraint -where - T: Clone, -{ - /// Returns a resolved representation of the constraint - /// with bare values resolved to fully-qualified constraints. - pub fn to_resolved( - &self, - strategy: MediaTrackConstraintResolutionStrategy, - ) -> ResolvedValueConstraint { - self.clone().into_resolved(strategy) - } - - /// Consumes the constraint, returning a resolved representation of the - /// constraint with bare values resolved to fully-qualified constraints. 
- pub fn into_resolved( - self, - strategy: MediaTrackConstraintResolutionStrategy, - ) -> ResolvedValueConstraint { - match self { - Self::Bare(bare) => match strategy { - MediaTrackConstraintResolutionStrategy::BareToIdeal => { - ResolvedValueConstraint::default().ideal(bare) - } - MediaTrackConstraintResolutionStrategy::BareToExact => { - ResolvedValueConstraint::default().exact(bare) - } - }, - Self::Constraint(constraint) => constraint, - } - } -} - -impl ValueConstraint { - /// Returns `true` if `self` is empty, otherwise `false`. - pub fn is_empty(&self) -> bool { - match self { - Self::Bare(_) => false, - Self::Constraint(constraint) => constraint.is_empty(), - } - } -} - -/// A constraint specifying a single accepted value. -/// -/// # W3C Spec Compliance -/// -/// There exists no direct corresponding type in the -/// W3C ["Media Capture and Streams"][media_capture_and_streams_spec] spec, -/// since the `ValueConstraint` type aims to be a -/// generalization over multiple types in the W3C spec: -/// -/// | Rust | W3C | -/// | ------------------------------ | --------------------------------------- | -/// | `ResolvedValueConstraint` | [`ConstrainBooleanParameters`][constrain_boolean_parameters] | -/// -/// [constrain_boolean_parameters]: https://www.w3.org/TR/mediacapture-streams/#dom-constrainbooleanparameters -/// [media_capture_and_streams_spec]: https://www.w3.org/TR/mediacapture-streams/ -#[derive(Debug, Clone, Eq, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(rename_all = "camelCase"))] -pub struct ResolvedValueConstraint { - /// The exact required value for this property. - /// - /// This is a required value. - #[cfg_attr( - feature = "serde", - serde(skip_serializing_if = "core::option::Option::is_none") - )] - pub exact: Option, - /// The ideal (target) value for this property. - /// - /// This is an optional value. - #[cfg_attr( - feature = "serde", - serde(skip_serializing_if = "core::option::Option::is_none") - )] - pub ideal: Option, -} - -impl ResolvedValueConstraint { - /// Consumes `self`, returning a corresponding constraint - /// with the exact required value set to `exact`. - #[inline] - pub fn exact(mut self, exact: U) -> Self - where - Option: From, - { - self.exact = exact.into(); - self - } - - /// Consumes `self`, returning a corresponding constraint - /// with the ideal required value set to `ideal`. - #[inline] - pub fn ideal(mut self, ideal: U) -> Self - where - Option: From, - { - self.ideal = ideal.into(); - self - } - - /// Returns `true` if `value.is_some()` is `true` for any of its required values, - /// otherwise `false`. - pub fn is_required(&self) -> bool { - self.exact.is_some() - } - - /// Returns `true` if `value.is_none()` is `true` for all of its values, - /// otherwise `false`. - pub fn is_empty(&self) -> bool { - self.exact.is_none() && self.ideal.is_none() - } - - /// Returns a corresponding constraint containing only required values. - pub fn to_required_only(&self) -> Self - where - T: Clone, - { - self.clone().into_required_only() - } - - /// Consumes `self, returning a corresponding constraint - /// containing only required values. 
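// A minimal builder sketch for `ResolvedValueConstraint` as described above:
//
//     let constraint = ResolvedValueConstraint::<bool>::default().exact(true).ideal(false);
//     assert!(constraint.is_required());                        // `exact` makes it required
//     assert_eq!(constraint.into_required_only().ideal, None);  // `ideal` is dropped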
- pub fn into_required_only(self) -> Self { - Self { - exact: self.exact, - ideal: None, - } - } -} - -impl Default for ResolvedValueConstraint { - #[inline] - fn default() -> Self { - Self { - exact: None, - ideal: None, - } - } -} - -impl std::fmt::Display for ResolvedValueConstraint -where - T: std::fmt::Debug, -{ - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut is_first = true; - f.write_str("(")?; - if let Some(ref exact) = &self.exact { - f.write_fmt(format_args!("x == {exact:?}"))?; - is_first = false; - } - if let Some(ref ideal) = &self.ideal { - if !is_first { - f.write_str(" && ")?; - } - f.write_fmt(format_args!("x ~= {ideal:?}"))?; - is_first = false; - } - if is_first { - f.write_str("")?; - } - f.write_str(")")?; - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn to_string() { - let scenarios = [ - (ResolvedValueConstraint::default(), "()"), - ( - ResolvedValueConstraint::default().exact(true), - "(x == true)", - ), - ( - ResolvedValueConstraint::default().ideal(true), - "(x ~= true)", - ), - ( - ResolvedValueConstraint::default().exact(true).ideal(true), - "(x == true && x ~= true)", - ), - ]; - - for (constraint, expected) in scenarios { - let actual = constraint.to_string(); - - assert_eq!(actual, expected); - } - } - - #[test] - fn is_required() { - let scenarios = [ - (ResolvedValueConstraint::default(), false), - (ResolvedValueConstraint::default().exact(true), true), - (ResolvedValueConstraint::default().ideal(true), false), - ( - ResolvedValueConstraint::default().exact(true).ideal(true), - true, - ), - ]; - - for (constraint, expected) in scenarios { - let actual = constraint.is_required(); - - assert_eq!(actual, expected); - } - } - - mod is_empty { - use super::*; - - #[test] - fn bare() { - let constraint = ValueConstraint::Bare(true); - - assert!(!constraint.is_empty()); - } - - #[test] - fn constraint() { - let scenarios = [ - (ResolvedValueConstraint::default(), true), - (ResolvedValueConstraint::default().exact(true), false), - (ResolvedValueConstraint::default().ideal(true), false), - ( - ResolvedValueConstraint::default().exact(true).ideal(true), - false, - ), - ]; - - for (constraint, expected) in scenarios { - let constraint = ValueConstraint::::Constraint(constraint); - - let actual = constraint.is_empty(); - - assert_eq!(actual, expected); - } - } - } - - #[test] - fn resolve_to_advanced() { - let constraints = [ - ValueConstraint::Bare(true), - ValueConstraint::Constraint(ResolvedValueConstraint::default().exact(true)), - ]; - let strategy = MediaTrackConstraintResolutionStrategy::BareToExact; - - for constraint in constraints { - let actuals = [ - constraint.to_resolved(strategy), - constraint.into_resolved(strategy), - ]; - - let expected = ResolvedValueConstraint::default().exact(true); - - for actual in actuals { - assert_eq!(actual, expected); - } - } - } - - #[test] - fn resolve_to_basic() { - let constraints = [ - ValueConstraint::Bare(true), - ValueConstraint::Constraint(ResolvedValueConstraint::default().ideal(true)), - ]; - let strategy = MediaTrackConstraintResolutionStrategy::BareToIdeal; - - for constraint in constraints { - let actuals = [ - constraint.to_resolved(strategy), - constraint.into_resolved(strategy), - ]; - - let expected = ResolvedValueConstraint::default().ideal(true); - - for actual in actuals { - assert_eq!(actual, expected); - } - } - } -} - -#[cfg(feature = "serde")] -#[cfg(test)] -mod serde_tests { - use super::*; - use crate::macros::test_serde_symmetry; - - 
macro_rules! test_serde { - ($t:ty => { - value: $value:expr - }) => { - type Subject = ValueConstraint<$t>; - - #[test] - fn default() { - let subject = Subject::default(); - let json = serde_json::json!({}); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn bare() { - let subject = Subject::Bare($value.to_owned()); - let json = serde_json::json!($value); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn exact_constraint() { - let subject = Subject::Constraint(ResolvedValueConstraint::default().exact($value.to_owned())); - let json = serde_json::json!({ - "exact": $value, - }); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn ideal_constraint() { - let subject = Subject::Constraint(ResolvedValueConstraint::default().ideal($value.to_owned())); - let json = serde_json::json!({ - "ideal": $value, - }); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn full_constraint() { - let subject = Subject::Constraint(ResolvedValueConstraint::default().exact($value.to_owned()).ideal($value.to_owned())); - let json = serde_json::json!({ - "exact": $value, - "ideal": $value, - }); - - test_serde_symmetry!(subject: subject, json: json); - } - }; - } - - mod bool { - use super::*; - - test_serde!(bool => { - value: true - }); - } - - mod string { - use super::*; - - test_serde!(String => { - value: "VALUE" - }); - } -} diff --git a/constraints/src/constraint/value_range.rs b/constraints/src/constraint/value_range.rs deleted file mode 100644 index 4fed9808a..000000000 --- a/constraints/src/constraint/value_range.rs +++ /dev/null @@ -1,513 +0,0 @@ -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -use crate::MediaTrackConstraintResolutionStrategy; - -/// A bare value or constraint specifying a range of accepted values. -/// -/// # W3C Spec Compliance -/// -/// There exists no direct corresponding type in the -/// W3C ["Media Capture and Streams"][media_capture_and_streams_spec] spec, -/// since the `ValueConstraint` type aims to be a generalization over -/// multiple types in the spec. -/// -/// | Rust | W3C | -/// | ---------------------------------- | ------------------------------------- | -/// | `ValueRangeConstraint` | [`ConstrainULong`][constrain_ulong] | -/// | `ValueRangeConstraint` | [`ConstrainDouble`][constrain_double] | -/// -/// [constrain_double]: https://www.w3.org/TR/mediacapture-streams/#dom-constraindouble -/// [constrain_ulong]: https://www.w3.org/TR/mediacapture-streams/#dom-constrainulong -/// [media_capture_and_streams_spec]: https://www.w3.org/TR/mediacapture-streams/ -#[derive(Debug, Clone, Eq, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(untagged))] -pub enum ValueRangeConstraint { - /// A bare-valued media track constraint. - Bare(T), - /// A fully-qualified media track constraint. - Constraint(ResolvedValueRangeConstraint), -} - -impl Default for ValueRangeConstraint { - fn default() -> Self { - Self::Constraint(Default::default()) - } -} - -impl From for ValueRangeConstraint { - fn from(bare: T) -> Self { - Self::Bare(bare) - } -} - -impl From> for ValueRangeConstraint { - fn from(constraint: ResolvedValueRangeConstraint) -> Self { - Self::Constraint(constraint) - } -} - -impl ValueRangeConstraint -where - T: Clone, -{ - /// Returns a resolved representation of the constraint - /// with bare values resolved to fully-qualified constraints. 
- pub fn to_resolved( - &self, - strategy: MediaTrackConstraintResolutionStrategy, - ) -> ResolvedValueRangeConstraint { - self.clone().into_resolved(strategy) - } - - /// Consumes the constraint, returning a resolved representation of the - /// constraint with bare values resolved to fully-qualified constraints. - pub fn into_resolved( - self, - strategy: MediaTrackConstraintResolutionStrategy, - ) -> ResolvedValueRangeConstraint { - match self { - Self::Bare(bare) => match strategy { - MediaTrackConstraintResolutionStrategy::BareToIdeal => { - ResolvedValueRangeConstraint::default().ideal(bare) - } - MediaTrackConstraintResolutionStrategy::BareToExact => { - ResolvedValueRangeConstraint::default().exact(bare) - } - }, - Self::Constraint(constraint) => constraint, - } - } -} - -impl ValueRangeConstraint { - /// Returns `true` if `self` is empty, otherwise `false`. - pub fn is_empty(&self) -> bool { - match self { - Self::Bare(_) => false, - Self::Constraint(constraint) => constraint.is_empty(), - } - } -} - -/// A constraint specifying a range of accepted values. -/// -/// Corresponding W3C spec types as per ["Media Capture and Streams"][spec]: -/// - `ConstrainDouble` => `ResolvedValueRangeConstraint` -/// - `ConstrainULong` => `ResolvedValueRangeConstraint` -/// -/// [spec]: https://www.w3.org/TR/mediacapture-streams -#[derive(Debug, Clone, Eq, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(rename_all = "camelCase"))] -pub struct ResolvedValueRangeConstraint { - /// The minimum legal value of this property. - /// - /// This is a required value. - #[cfg_attr( - feature = "serde", - serde(skip_serializing_if = "core::option::Option::is_none") - )] - pub min: Option, - /// The maximum legal value of this property. - /// - /// This is a required value. - #[cfg_attr( - feature = "serde", - serde(skip_serializing_if = "core::option::Option::is_none") - )] - pub max: Option, - /// The exact required value for this property. - /// - /// This is a required value. - #[cfg_attr( - feature = "serde", - serde(skip_serializing_if = "core::option::Option::is_none") - )] - pub exact: Option, - /// The ideal (target) value for this property. - /// - /// This is an optional value. - #[cfg_attr( - feature = "serde", - serde(skip_serializing_if = "core::option::Option::is_none") - )] - pub ideal: Option, -} - -impl ResolvedValueRangeConstraint { - /// Consumes `self`, returning a corresponding constraint - /// with the exact required value set to `exact`. - #[inline] - pub fn exact(mut self, exact: U) -> Self - where - Option: From, - { - self.exact = exact.into(); - self - } - - /// Consumes `self`, returning a corresponding constraint - /// with the ideal required value set to `ideal`. - #[inline] - pub fn ideal(mut self, ideal: U) -> Self - where - Option: From, - { - self.ideal = ideal.into(); - self - } - - /// Consumes `self`, returning a corresponding constraint - /// with the minimum required value set to `min`. - #[inline] - pub fn min(mut self, min: U) -> Self - where - Option: From, - { - self.min = min.into(); - self - } - - /// Consumes `self`, returning a corresponding constraint - /// with the maximum required value set to `max`. - #[inline] - pub fn max(mut self, max: U) -> Self - where - Option: From, - { - self.max = max.into(); - self - } - - /// Returns `true` if `value.is_some()` is `true` for any of its required values, - /// otherwise `false`. 
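// A minimal builder sketch for `ResolvedValueRangeConstraint`, e.g. for a width-like property:
//
//     let width = ResolvedValueRangeConstraint::<u64>::default().min(640).max(1920).ideal(1280);
//     assert!(width.is_required());  // `min`/`max` count as required
//     assert_eq!(width.to_string(), "(640 <= x <= 1920 && x ~= 1280)");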
- pub fn is_required(&self) -> bool { - self.min.is_some() || self.max.is_some() || self.exact.is_some() - } - - /// Returns `true` if `value.is_none()` is `true` for all of its values, - /// otherwise `false`. - pub fn is_empty(&self) -> bool { - self.min.is_none() && self.max.is_none() && self.exact.is_none() && self.ideal.is_none() - } - - /// Returns a corresponding constraint containing only required values. - pub fn to_required_only(&self) -> Self - where - T: Clone, - { - self.clone().into_required_only() - } - - /// Consumes `self, returning a corresponding constraint - /// containing only required values. - pub fn into_required_only(self) -> Self { - Self { - min: self.min, - max: self.max, - exact: self.exact, - ideal: None, - } - } -} - -impl Default for ResolvedValueRangeConstraint { - #[inline] - fn default() -> Self { - Self { - min: None, - max: None, - exact: None, - ideal: None, - } - } -} - -impl std::fmt::Display for ResolvedValueRangeConstraint -where - T: std::fmt::Debug, -{ - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut is_first = true; - f.write_str("(")?; - if let Some(exact) = &self.exact { - f.write_fmt(format_args!("x == {exact:?}"))?; - is_first = false; - } else if let (Some(min), Some(max)) = (&self.min, &self.max) { - f.write_fmt(format_args!("{min:?} <= x <= {max:?}"))?; - is_first = false; - } else if let Some(min) = &self.min { - f.write_fmt(format_args!("{min:?} <= x"))?; - is_first = false; - } else if let Some(max) = &self.max { - f.write_fmt(format_args!("x <= {max:?}"))?; - is_first = false; - } - if let Some(ideal) = &self.ideal { - if !is_first { - f.write_str(" && ")?; - } - f.write_fmt(format_args!("x ~= {ideal:?}"))?; - is_first = false; - } - if is_first { - f.write_str("")?; - } - f.write_str(")")?; - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn to_string() { - let scenarios = [ - (ResolvedValueRangeConstraint::default(), "()"), - (ResolvedValueRangeConstraint::default().exact(1), "(x == 1)"), - (ResolvedValueRangeConstraint::default().ideal(2), "(x ~= 2)"), - ( - ResolvedValueRangeConstraint::default().exact(1).ideal(2), - "(x == 1 && x ~= 2)", - ), - ]; - - for (constraint, expected) in scenarios { - let actual = constraint.to_string(); - - assert_eq!(actual, expected); - } - } - - #[test] - fn is_required() { - for min_is_some in [false, true] { - // TODO: Replace `if { Some(_) } else { None }` with `.then_some(_)` - // once MSRV has passed 1.62.0: - let min = if min_is_some { Some(1) } else { None }; - for max_is_some in [false, true] { - // TODO: Replace `if { Some(_) } else { None }` with `.then_some(_)` - // once MSRV has passed 1.62.0: - let max = if max_is_some { Some(2) } else { None }; - for exact_is_some in [false, true] { - // TODO: Replace `if { Some(_) } else { None }` with `.then_some(_)` - // once MSRV has passed 1.62.0: - let exact = if exact_is_some { Some(3) } else { None }; - for ideal_is_some in [false, true] { - // TODO: Replace `if { Some(_) } else { None }` with `.then_some(_)` - // once MSRV has passed 1.62.0: - let ideal = if ideal_is_some { Some(4) } else { None }; - - let constraint = ResolvedValueRangeConstraint:: { - min, - max, - exact, - ideal, - }; - - let actual = constraint.is_required(); - let expected = min_is_some || max_is_some || exact_is_some; - - assert_eq!(actual, expected); - } - } - } - } - } - - mod is_empty { - use super::*; - - #[test] - fn bare() { - let constraint = ValueRangeConstraint::Bare(42); - - 
assert!(!constraint.is_empty()); - } - - #[test] - fn constraint() { - for min_is_some in [false, true] { - // TODO: Replace `if { Some(_) } else { None }` with `.then_some(_)` - // once MSRV has passed 1.62.0: - let min = if min_is_some { Some(1) } else { None }; - for max_is_some in [false, true] { - // TODO: Replace `if { Some(_) } else { None }` with `.then_some(_)` - // once MSRV has passed 1.62.0: - let max = if max_is_some { Some(2) } else { None }; - for exact_is_some in [false, true] { - // TODO: Replace `if { Some(_) } else { None }` with `.then_some(_)` - // once MSRV has passed 1.62.0: - let exact = if exact_is_some { Some(3) } else { None }; - for ideal_is_some in [false, true] { - // TODO: Replace `if { Some(_) } else { None }` with `.then_some(_)` - // once MSRV has passed 1.62.0: - let ideal = if ideal_is_some { Some(4) } else { None }; - - let constraint = ResolvedValueRangeConstraint:: { - min, - max, - exact, - ideal, - }; - - let actual = constraint.is_empty(); - let expected = - !(min_is_some || max_is_some || exact_is_some || ideal_is_some); - - assert_eq!(actual, expected); - } - } - } - } - } - } -} - -#[test] -fn resolve_to_advanced() { - let constraints = [ - ValueRangeConstraint::Bare(42), - ValueRangeConstraint::Constraint(ResolvedValueRangeConstraint::default().exact(42)), - ]; - let strategy = MediaTrackConstraintResolutionStrategy::BareToExact; - - for constraint in constraints { - let actuals = [ - constraint.to_resolved(strategy), - constraint.into_resolved(strategy), - ]; - - let expected = ResolvedValueRangeConstraint::default().exact(42); - - for actual in actuals { - assert_eq!(actual, expected); - } - } -} - -#[test] -fn resolve_to_basic() { - let constraints = [ - ValueRangeConstraint::Bare(42), - ValueRangeConstraint::Constraint(ResolvedValueRangeConstraint::default().ideal(42)), - ]; - let strategy = MediaTrackConstraintResolutionStrategy::BareToIdeal; - - for constraint in constraints { - let actuals = [ - constraint.to_resolved(strategy), - constraint.into_resolved(strategy), - ]; - - let expected = ResolvedValueRangeConstraint::default().ideal(42); - - for actual in actuals { - assert_eq!(actual, expected); - } - } -} - -#[cfg(feature = "serde")] -#[cfg(test)] -mod serde_tests { - use super::*; - use crate::macros::test_serde_symmetry; - - macro_rules! 
test_serde { - ($t:ty => { - value: $value:expr - }) => { - type Subject = ValueRangeConstraint<$t>; - - #[test] - fn default() { - let subject = Subject::default(); - let json = serde_json::json!({}); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn bare() { - let subject = Subject::Bare($value.to_owned()); - let json = serde_json::json!($value); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn min_constraint() { - let subject = Subject::Constraint(ResolvedValueRangeConstraint::default().min($value.to_owned())); - let json = serde_json::json!({ - "min": $value, - }); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn max_constraint() { - let subject = Subject::Constraint(ResolvedValueRangeConstraint::default().max($value.to_owned())); - let json = serde_json::json!({ - "max": $value, - }); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn exact_constraint() { - let subject = Subject::Constraint(ResolvedValueRangeConstraint::default().exact($value.to_owned())); - let json = serde_json::json!({ - "exact": $value, - }); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn ideal_constraint() { - let subject = Subject::Constraint(ResolvedValueRangeConstraint::default().ideal($value.to_owned())); - let json = serde_json::json!({ - "ideal": $value, - }); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn full_constraint() { - let subject = Subject::Constraint(ResolvedValueRangeConstraint::default().min($value.to_owned()).max($value.to_owned()).exact($value.to_owned()).ideal($value.to_owned())); - let json = serde_json::json!({ - "min": $value, - "max": $value, - "exact": $value, - "ideal": $value, - }); - - test_serde_symmetry!(subject: subject, json: json); - } - }; - } - - mod f64 { - use super::*; - - test_serde!(f64 => { - value: 42.0 - }); - } - - mod u64 { - use super::*; - - test_serde!(u64 => { - value: 42 - }); - } -} diff --git a/constraints/src/constraint/value_sequence.rs b/constraints/src/constraint/value_sequence.rs deleted file mode 100644 index b3d95517c..000000000 --- a/constraints/src/constraint/value_sequence.rs +++ /dev/null @@ -1,440 +0,0 @@ -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -use crate::MediaTrackConstraintResolutionStrategy; - -/// A bare value or constraint specifying a sequence of accepted values. -/// -/// # W3C Spec Compliance -/// -/// There exists no direct corresponding type in the -/// W3C ["Media Capture and Streams"][media_capture_and_streams_spec] spec, -/// since the `ValueConstraint` type aims to be a generalization over -/// multiple types in the spec. -/// -/// | Rust | W3C | -/// | ---------------------------------------- | -------------------------------------------- | -/// | `ValueSequenceConstraint` | [`ConstrainDOMString`][constrain_dom_string] | -/// -/// [constrain_dom_string]: https://www.w3.org/TR/mediacapture-streams/#dom-constraindomstring -/// [media_capture_and_streams_spec]: https://www.w3.org/TR/mediacapture-streams/ -#[derive(Debug, Clone, Eq, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(untagged))] -pub enum ValueSequenceConstraint { - /// A bare-valued media track constraint. - Bare(Vec), - /// A fully-qualified media track constraint. 
- Constraint(ResolvedValueSequenceConstraint), -} - -impl Default for ValueSequenceConstraint { - fn default() -> Self { - Self::Constraint(Default::default()) - } -} - -impl From for ValueSequenceConstraint { - fn from(bare: T) -> Self { - Self::Bare(vec![bare]) - } -} - -impl From> for ValueSequenceConstraint { - fn from(bare: Vec) -> Self { - Self::Bare(bare) - } -} - -impl From> for ValueSequenceConstraint { - fn from(constraint: ResolvedValueSequenceConstraint) -> Self { - Self::Constraint(constraint) - } -} - -impl ValueSequenceConstraint -where - T: Clone, -{ - /// Returns a resolved representation of the constraint - /// with bare values resolved to fully-qualified constraints. - pub fn to_resolved( - &self, - strategy: MediaTrackConstraintResolutionStrategy, - ) -> ResolvedValueSequenceConstraint { - self.clone().into_resolved(strategy) - } - - /// Consumes the constraint, returning a resolved representation of the - /// constraint with bare values resolved to fully-qualified constraints. - pub fn into_resolved( - self, - strategy: MediaTrackConstraintResolutionStrategy, - ) -> ResolvedValueSequenceConstraint { - match self { - Self::Bare(bare) => match strategy { - MediaTrackConstraintResolutionStrategy::BareToIdeal => { - ResolvedValueSequenceConstraint::default().ideal(bare) - } - MediaTrackConstraintResolutionStrategy::BareToExact => { - ResolvedValueSequenceConstraint::default().exact(bare) - } - }, - Self::Constraint(constraint) => constraint, - } - } -} - -impl ValueSequenceConstraint { - /// Returns `true` if `self` is empty, otherwise `false`. - pub fn is_empty(&self) -> bool { - match self { - Self::Bare(bare) => bare.is_empty(), - Self::Constraint(constraint) => constraint.is_empty(), - } - } -} - -/// A constraint specifying a sequence of accepted values. -/// -/// # W3C Spec Compliance -/// -/// There exists no direct corresponding type in the -/// W3C ["Media Capture and Streams"][media_capture_and_streams_spec] spec, -/// since the `ValueSequenceConstraint` type aims to be a -/// generalization over multiple types in the W3C spec: -/// -/// | Rust | W3C | -/// | --------------------------------- | ----------------------------------------------------------------- | -/// | `ResolvedValueSequenceConstraint` | [`ConstrainDOMStringParameters`][constrain_dom_string_parameters] | -/// -/// [constrain_dom_string_parameters]: https://www.w3.org/TR/mediacapture-streams/#dom-constraindomstringparameters -/// [media_capture_and_streams_spec]: https://www.w3.org/TR/mediacapture-streams/ -#[derive(Debug, Clone, Eq, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(rename_all = "camelCase"))] -pub struct ResolvedValueSequenceConstraint { - /// The exact required value for this property. - /// - /// This is a required value. - #[cfg_attr( - feature = "serde", - serde(skip_serializing_if = "core::option::Option::is_none") - )] - pub exact: Option>, - /// The ideal (target) value for this property. - /// - /// This is an optional value. - #[cfg_attr( - feature = "serde", - serde(skip_serializing_if = "core::option::Option::is_none") - )] - pub ideal: Option>, -} - -impl ResolvedValueSequenceConstraint { - /// Consumes `self`, returning a corresponding constraint - /// with the exact required value set to `exact`. 
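// A minimal builder sketch for `ResolvedValueSequenceConstraint`, e.g. for a facingMode-like property:
//
//     let facing = ResolvedValueSequenceConstraint::<String>::default()
//         .exact(vec!["user".to_owned()])
//         .ideal(vec!["environment".to_owned()]);
//     assert!(facing.is_required());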
- #[inline] - pub fn exact(mut self, exact: U) -> Self - where - Option>: From, - { - self.exact = exact.into(); - self - } - - /// Consumes `self`, returning a corresponding constraint - /// with the ideal required value set to `ideal`. - #[inline] - pub fn ideal(mut self, ideal: U) -> Self - where - Option>: From, - { - self.ideal = ideal.into(); - self - } - - /// Returns `true` if `value.is_some()` is `true` for any of its required values, - /// otherwise `false`. - pub fn is_required(&self) -> bool { - self.exact.is_some() - } - - /// Returns `true` if `value.is_none()` is `true` for all of its values, - /// otherwise `false`. - pub fn is_empty(&self) -> bool { - let exact_is_empty = self.exact.as_ref().map_or(true, Vec::is_empty); - let ideal_is_empty = self.ideal.as_ref().map_or(true, Vec::is_empty); - exact_is_empty && ideal_is_empty - } - - /// Returns a corresponding constraint containing only required values. - pub fn to_required_only(&self) -> Self - where - T: Clone, - { - self.clone().into_required_only() - } - - /// Consumes `self, returning a corresponding constraint - /// containing only required values. - pub fn into_required_only(self) -> Self { - Self { - exact: self.exact, - ideal: None, - } - } -} - -impl Default for ResolvedValueSequenceConstraint { - fn default() -> Self { - Self { - exact: None, - ideal: None, - } - } -} - -impl std::fmt::Display for ResolvedValueSequenceConstraint -where - T: std::fmt::Debug, -{ - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut is_first = true; - f.write_str("(")?; - if let Some(ref exact) = &self.exact { - f.write_fmt(format_args!("x == {exact:?}"))?; - is_first = false; - } - if let Some(ref ideal) = &self.ideal { - if !is_first { - f.write_str(" && ")?; - } - f.write_fmt(format_args!("x ~= {ideal:?}"))?; - is_first = false; - } - if is_first { - f.write_str("")?; - } - f.write_str(")")?; - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn to_string() { - let scenarios = [ - (ResolvedValueSequenceConstraint::default(), "()"), - ( - ResolvedValueSequenceConstraint::default().exact(vec![1, 2]), - "(x == [1, 2])", - ), - ( - ResolvedValueSequenceConstraint::default().ideal(vec![2, 3]), - "(x ~= [2, 3])", - ), - ( - ResolvedValueSequenceConstraint::default() - .exact(vec![1, 2]) - .ideal(vec![2, 3]), - "(x == [1, 2] && x ~= [2, 3])", - ), - ]; - - for (constraint, expected) in scenarios { - let actual = constraint.to_string(); - - assert_eq!(actual, expected); - } - } - - #[test] - fn is_required() { - let scenarios = [ - (ResolvedValueSequenceConstraint::default(), false), - ( - ResolvedValueSequenceConstraint::default().exact(vec![true]), - true, - ), - ( - ResolvedValueSequenceConstraint::default().ideal(vec![true]), - false, - ), - ( - ResolvedValueSequenceConstraint::default() - .exact(vec![true]) - .ideal(vec![true]), - true, - ), - ]; - - for (constraint, expected) in scenarios { - let actual = constraint.is_required(); - - assert_eq!(actual, expected); - } - } - - mod is_empty { - use super::*; - - #[test] - fn bare() { - let constraint = ValueSequenceConstraint::Bare(vec![true]); - - assert!(!constraint.is_empty()); - } - - #[test] - fn constraint() { - let scenarios = [ - (ResolvedValueSequenceConstraint::default(), true), - ( - ResolvedValueSequenceConstraint::default().exact(vec![true]), - false, - ), - ( - ResolvedValueSequenceConstraint::default().ideal(vec![true]), - false, - ), - ( - ResolvedValueSequenceConstraint::default() - .exact(vec![true]) - 
.ideal(vec![true]), - false, - ), - ]; - - for (constraint, expected) in scenarios { - let constraint = ValueSequenceConstraint::::Constraint(constraint); - - let actual = constraint.is_empty(); - - assert_eq!(actual, expected); - } - } - } - - #[test] - fn resolve_to_advanced() { - let constraints = [ - ValueSequenceConstraint::Bare(vec![true]), - ValueSequenceConstraint::Constraint( - ResolvedValueSequenceConstraint::default().exact(vec![true]), - ), - ]; - let strategy = MediaTrackConstraintResolutionStrategy::BareToExact; - - for constraint in constraints { - let actuals = [ - constraint.to_resolved(strategy), - constraint.into_resolved(strategy), - ]; - - let expected = ResolvedValueSequenceConstraint::default().exact(vec![true]); - - for actual in actuals { - assert_eq!(actual, expected); - } - } - } - - #[test] - fn resolve_to_basic() { - let constraints = [ - ValueSequenceConstraint::Bare(vec![true]), - ValueSequenceConstraint::Constraint( - ResolvedValueSequenceConstraint::default().ideal(vec![true]), - ), - ]; - let strategy = MediaTrackConstraintResolutionStrategy::BareToIdeal; - - for constraint in constraints { - let actuals = [ - constraint.to_resolved(strategy), - constraint.into_resolved(strategy), - ]; - - let expected = ResolvedValueSequenceConstraint::default().ideal(vec![true]); - - for actual in actuals { - assert_eq!(actual, expected); - } - } - } -} - -#[cfg(feature = "serde")] -#[cfg(test)] -mod serde_tests { - use super::*; - use crate::macros::test_serde_symmetry; - - macro_rules! test_serde { - ($t:ty => { - values: [$($values:expr),*] - }) => { - type Subject = ValueSequenceConstraint<$t>; - - #[test] - fn default() { - let subject = Subject::default(); - let json = serde_json::json!({}); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn bare() { - let subject = Subject::Bare(vec![$($values.to_owned()),*].into()); - let json = serde_json::json!([$($values),*]); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn exact_constraint() { - let subject = Subject::Constraint(ResolvedValueSequenceConstraint::default().exact(vec![$($values.to_owned()),*])); - let json = serde_json::json!({ - "exact": [$($values),*], - }); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn ideal_constraint() { - let subject = Subject::Constraint(ResolvedValueSequenceConstraint::default().ideal(vec![$($values.to_owned()),*])); - let json = serde_json::json!({ - "ideal": [$($values),*], - }); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn full_constraint() { - let subject = Subject::Constraint(ResolvedValueSequenceConstraint::default().exact(vec![$($values.to_owned()),*]).ideal(vec![$($values.to_owned()),*])); - let json = serde_json::json!({ - "exact": [$($values),*], - "ideal": [$($values),*], - }); - - test_serde_symmetry!(subject: subject, json: json); - } - }; - } - - mod string { - use super::*; - - test_serde!(String => { - values: ["VALUE_0", "VALUE_1"] - }); - } -} diff --git a/constraints/src/constraints.rs b/constraints/src/constraints.rs deleted file mode 100644 index 8b89b909e..000000000 --- a/constraints/src/constraints.rs +++ /dev/null @@ -1,22 +0,0 @@ -mod advanced; -mod constraint_set; -mod mandatory; -mod stream; -mod track; - -pub use self::advanced::{ - AdvancedMediaTrackConstraints, ResolvedAdvancedMediaTrackConstraints, - SanitizedAdvancedMediaTrackConstraints, -}; -pub use self::constraint_set::{ - MediaTrackConstraintSet, ResolvedMediaTrackConstraintSet, 
SanitizedMediaTrackConstraintSet, -}; -pub use self::mandatory::{ - MandatoryMediaTrackConstraints, ResolvedMandatoryMediaTrackConstraints, - SanitizedMandatoryMediaTrackConstraints, -}; -pub use self::stream::MediaStreamConstraints; -pub use self::track::{ - BoolOrMediaTrackConstraints, MediaTrackConstraints, ResolvedMediaTrackConstraints, - SanitizedMediaTrackConstraints, -}; diff --git a/constraints/src/constraints/advanced.rs b/constraints/src/constraints/advanced.rs deleted file mode 100644 index d6e893199..000000000 --- a/constraints/src/constraints/advanced.rs +++ /dev/null @@ -1,191 +0,0 @@ -use std::iter::FromIterator; -use std::ops::{Deref, DerefMut}; - -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -use super::constraint_set::GenericMediaTrackConstraintSet; -use crate::{ - MediaTrackConstraint, MediaTrackConstraintResolutionStrategy, MediaTrackSupportedConstraints, - ResolvedMediaTrackConstraint, SanitizedMediaTrackConstraint, -}; - -/// Advanced media track constraints that contain sets of either bare values or constraints. -pub type AdvancedMediaTrackConstraints = GenericAdvancedMediaTrackConstraints; - -/// Advanced media track constraints that contain sets of constraints (both, empty and non-empty). -pub type ResolvedAdvancedMediaTrackConstraints = - GenericAdvancedMediaTrackConstraints; - -/// Advanced media track constraints that contain sets of only non-empty constraints. -pub type SanitizedAdvancedMediaTrackConstraints = - GenericAdvancedMediaTrackConstraints; - -/// The list of advanced constraint sets for a [`MediaStreamTrack`][media_stream_track] object. -/// -/// # W3C Spec Compliance -/// -/// Corresponds to [`ResolvedMediaTrackConstraints.advanced`][media_track_constraints_advanced] -/// from the W3C ["Media Capture and Streams"][media_capture_and_streams_spec] spec. 
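// ---------------------------------------------------------------------------
// Editor's note: a hypothetical sketch of building the advanced constraints
// type introduced above and resolving its bare values (its `into_resolved`
// follows just below). Not part of the original diff; the
// `webrtc_constraints` import name is an assumption.
// ---------------------------------------------------------------------------
use std::iter::FromIterator;

use webrtc_constraints::property::all::name::{HEIGHT, WIDTH};
use webrtc_constraints::{AdvancedMediaTrackConstraints, MediaTrackConstraintSet};

fn main() {
    let advanced =
        AdvancedMediaTrackConstraints::new(vec![MediaTrackConstraintSet::from_iter([
            (&WIDTH, 1920.into()),
            (&HEIGHT, 1080.into()),
        ])]);

    // Bare values in advanced constraint sets resolve to `exact` constraints:
    let resolved = advanced.into_resolved();
    assert_eq!(resolved.len(), 1);
}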
-/// -/// [media_stream_track]: https://www.w3.org/TR/mediacapture-streams/#dom-mediastreamtrack -/// [media_track_constraints_advanced]: https://www.w3.org/TR/mediacapture-streams/#dom-mediatrackconstraints-advanced -/// [media_capture_and_streams_spec]: https://www.w3.org/TR/mediacapture-streams/ -#[derive(Debug, Clone, Eq, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(transparent))] -pub struct GenericAdvancedMediaTrackConstraints(Vec>); - -impl GenericAdvancedMediaTrackConstraints { - pub fn new(constraints: Vec>) -> Self { - Self(constraints) - } - - pub fn into_inner(self) -> Vec> { - self.0 - } -} - -impl Deref for GenericAdvancedMediaTrackConstraints { - type Target = Vec>; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for GenericAdvancedMediaTrackConstraints { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -impl Default for GenericAdvancedMediaTrackConstraints { - fn default() -> Self { - Self(Default::default()) - } -} - -impl FromIterator> - for GenericAdvancedMediaTrackConstraints -{ - fn from_iter(iter: I) -> Self - where - I: IntoIterator>, - { - Self::new(iter.into_iter().collect()) - } -} - -impl IntoIterator for GenericAdvancedMediaTrackConstraints { - type Item = GenericMediaTrackConstraintSet; - type IntoIter = std::vec::IntoIter>; - - fn into_iter(self) -> Self::IntoIter { - self.0.into_iter() - } -} - -impl AdvancedMediaTrackConstraints { - pub fn to_resolved(&self) -> ResolvedAdvancedMediaTrackConstraints { - self.clone().into_resolved() - } - - pub fn into_resolved(self) -> ResolvedAdvancedMediaTrackConstraints { - let strategy = MediaTrackConstraintResolutionStrategy::BareToExact; - ResolvedAdvancedMediaTrackConstraints::from_iter( - self.into_iter() - .map(|constraint_set| constraint_set.into_resolved(strategy)), - ) - } -} - -impl ResolvedAdvancedMediaTrackConstraints { - pub fn to_sanitized( - &self, - supported_constraints: &MediaTrackSupportedConstraints, - ) -> SanitizedAdvancedMediaTrackConstraints { - self.clone().into_sanitized(supported_constraints) - } - - pub fn into_sanitized( - self, - supported_constraints: &MediaTrackSupportedConstraints, - ) -> SanitizedAdvancedMediaTrackConstraints { - SanitizedAdvancedMediaTrackConstraints::from_iter( - self.into_iter() - .map(|constraint_set| constraint_set.into_sanitized(supported_constraints)) - .filter(|constraint_set| !constraint_set.is_empty()), - ) - } -} - -#[cfg(feature = "serde")] -#[cfg(test)] -mod serde_tests { - use super::*; - use crate::property::all::name::*; - use crate::MediaTrackConstraintSet; - - #[test] - fn serialize_default() { - let advanced = AdvancedMediaTrackConstraints::default(); - let actual = serde_json::to_value(advanced).unwrap(); - let expected = serde_json::json!([]); - - assert_eq!(actual, expected); - } - - #[test] - fn deserialize_default() { - let json = serde_json::json!([]); - let actual: AdvancedMediaTrackConstraints = serde_json::from_value(json).unwrap(); - let expected = AdvancedMediaTrackConstraints::default(); - - assert_eq!(actual, expected); - } - - #[test] - fn serialize() { - let advanced = - AdvancedMediaTrackConstraints::new(vec![MediaTrackConstraintSet::from_iter([ - (&DEVICE_ID, "device-id".into()), - (&AUTO_GAIN_CONTROL, true.into()), - (&CHANNEL_COUNT, 2.into()), - (&LATENCY, 0.123.into()), - ])]); - let actual = serde_json::to_value(advanced).unwrap(); - let expected = serde_json::json!([ - { - "deviceId": "device-id".to_owned(), - 
"autoGainControl": true, - "channelCount": 2, - "latency": 0.123, - } - ]); - - assert_eq!(actual, expected); - } - - #[test] - fn deserialize() { - let json = serde_json::json!([ - { - "deviceId": "device-id".to_owned(), - "autoGainControl": true, - "channelCount": 2, - "latency": 0.123, - } - ]); - let actual: AdvancedMediaTrackConstraints = serde_json::from_value(json).unwrap(); - let expected = - AdvancedMediaTrackConstraints::new(vec![MediaTrackConstraintSet::from_iter([ - (&DEVICE_ID, "device-id".into()), - (&AUTO_GAIN_CONTROL, true.into()), - (&CHANNEL_COUNT, 2.into()), - (&LATENCY, 0.123.into()), - ])]); - - assert_eq!(actual, expected); - } -} diff --git a/constraints/src/constraints/constraint_set.rs b/constraints/src/constraints/constraint_set.rs deleted file mode 100644 index 89680963b..000000000 --- a/constraints/src/constraints/constraint_set.rs +++ /dev/null @@ -1,200 +0,0 @@ -use std::iter::FromIterator; -use std::ops::{Deref, DerefMut}; - -use indexmap::IndexMap; -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -use crate::constraint::SanitizedMediaTrackConstraint; -use crate::{ - MediaTrackConstraint, MediaTrackConstraintResolutionStrategy, MediaTrackProperty, - MediaTrackSupportedConstraints, ResolvedMediaTrackConstraint, -}; - -/// Media track constraint set that contains either bare values or constraints. -pub type MediaTrackConstraintSet = GenericMediaTrackConstraintSet; - -/// Media track constraint set that contains only constraints (both, empty and non-empty). -pub type ResolvedMediaTrackConstraintSet = - GenericMediaTrackConstraintSet; - -/// Media track constraint set that contains only non-empty constraints. -pub type SanitizedMediaTrackConstraintSet = - GenericMediaTrackConstraintSet; - -/// The set of constraints for a [`MediaStreamTrack`][media_stream_track] object. -/// -/// # W3C Spec Compliance -/// -/// Corresponds to [`ResolvedMediaTrackConstraintSet`][media_track_constraint_set] -/// from the W3C ["Media Capture and Streams"][media_capture_and_streams_spec] spec. 
-/// -/// [media_stream_track]: https://www.w3.org/TR/mediacapture-streams/#dom-mediastreamtrack -/// [media_track_constraint_set]: https://www.w3.org/TR/mediacapture-streams/#dom-mediatrackconstraintset -/// [media_capture_and_streams_spec]: https://www.w3.org/TR/mediacapture-streams/ -#[derive(Debug, Clone, Eq, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(transparent))] -pub struct GenericMediaTrackConstraintSet(IndexMap); - -impl GenericMediaTrackConstraintSet { - pub fn new(constraint_set: IndexMap) -> Self { - Self(constraint_set) - } - - pub fn into_inner(self) -> IndexMap { - self.0 - } -} - -impl Deref for GenericMediaTrackConstraintSet { - type Target = IndexMap; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for GenericMediaTrackConstraintSet { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -impl Default for GenericMediaTrackConstraintSet { - fn default() -> Self { - Self(IndexMap::new()) - } -} - -impl FromIterator<(U, T)> for GenericMediaTrackConstraintSet -where - U: Into, -{ - fn from_iter(iter: I) -> Self - where - I: IntoIterator, - { - Self::new(iter.into_iter().map(|(k, v)| (k.into(), v)).collect()) - } -} - -impl IntoIterator for GenericMediaTrackConstraintSet { - type Item = (MediaTrackProperty, T); - type IntoIter = indexmap::map::IntoIter; - - fn into_iter(self) -> Self::IntoIter { - self.0.into_iter() - } -} - -impl MediaTrackConstraintSet { - pub fn to_resolved( - &self, - strategy: MediaTrackConstraintResolutionStrategy, - ) -> ResolvedMediaTrackConstraintSet { - self.clone().into_resolved(strategy) - } - - pub fn into_resolved( - self, - strategy: MediaTrackConstraintResolutionStrategy, - ) -> ResolvedMediaTrackConstraintSet { - ResolvedMediaTrackConstraintSet::new( - self.into_iter() - .map(|(property, constraint)| (property, constraint.into_resolved(strategy))) - .collect(), - ) - } -} - -impl ResolvedMediaTrackConstraintSet { - pub fn to_sanitized( - &self, - supported_constraints: &MediaTrackSupportedConstraints, - ) -> SanitizedMediaTrackConstraintSet { - self.clone().into_sanitized(supported_constraints) - } - - pub fn into_sanitized( - self, - supported_constraints: &MediaTrackSupportedConstraints, - ) -> SanitizedMediaTrackConstraintSet { - let index_map: IndexMap = self - .into_iter() - .filter_map(|(property, constraint)| { - if supported_constraints.contains(&property) { - constraint - .into_sanitized() - .map(|constraint| (property, constraint)) - } else { - None - } - }) - .collect(); - SanitizedMediaTrackConstraintSet::new(index_map) - } -} - -#[cfg(feature = "serde")] -#[cfg(test)] -mod serde_tests { - use super::*; - use crate::property::all::name::*; - - #[test] - fn serialize_default() { - let constraint_set = MediaTrackConstraintSet::default(); - let actual = serde_json::to_value(constraint_set).unwrap(); - let expected = serde_json::json!({}); - - assert_eq!(actual, expected); - } - - #[test] - fn deserialize_default() { - let json = serde_json::json!({}); - let actual: MediaTrackConstraintSet = serde_json::from_value(json).unwrap(); - let expected = MediaTrackConstraintSet::default(); - - assert_eq!(actual, expected); - } - - #[test] - fn serialize() { - let constraint_set = MediaTrackConstraintSet::from_iter([ - (&DEVICE_ID, "device-id".into()), - (&AUTO_GAIN_CONTROL, true.into()), - (&CHANNEL_COUNT, 2.into()), - (&LATENCY, 0.123.into()), - ]); - let actual = serde_json::to_value(constraint_set).unwrap(); - let expected = 
serde_json::json!({ - "deviceId": "device-id".to_owned(), - "autoGainControl": true, - "channelCount": 2, - "latency": 0.123, - }); - - assert_eq!(actual, expected); - } - - #[test] - fn deserialize() { - let json = serde_json::json!({ - "deviceId": "device-id".to_owned(), - "autoGainControl": true, - "channelCount": 2, - "latency": 0.123, - }); - let actual: MediaTrackConstraintSet = serde_json::from_value(json).unwrap(); - let expected = MediaTrackConstraintSet::from_iter([ - (&DEVICE_ID, "device-id".into()), - (&AUTO_GAIN_CONTROL, true.into()), - (&CHANNEL_COUNT, 2.into()), - (&LATENCY, 0.123.into()), - ]); - - assert_eq!(actual, expected); - } -} diff --git a/constraints/src/constraints/mandatory.rs b/constraints/src/constraints/mandatory.rs deleted file mode 100644 index 95dd8fc24..000000000 --- a/constraints/src/constraints/mandatory.rs +++ /dev/null @@ -1,321 +0,0 @@ -use std::iter::FromIterator; -use std::ops::{Deref, DerefMut}; - -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -use super::constraint_set::GenericMediaTrackConstraintSet; -use crate::{ - MediaTrackConstraint, MediaTrackConstraintResolutionStrategy, MediaTrackProperty, - MediaTrackSupportedConstraints, ResolvedMediaTrackConstraint, SanitizedMediaTrackConstraint, -}; - -/// The list of mandatory constraint sets for a [`MediaStreamTrack`][media_stream_track] object. -/// -/// # W3C Spec Compliance -/// -/// Corresponds to [`ResolvedMediaTrackConstraints.mandatory`][media_track_constraints_mandatory] -/// from the W3C ["Media Capture and Streams"][media_capture_and_streams_spec] spec. -/// -/// Unlike `ResolvedMandatoryMediaTrackConstraints` this type may contain constraints with bare values. -/// -/// [media_stream_track]: https://www.w3.org/TR/mediacapture-streams/#dom-mediastreamtrack -/// [media_track_constraints_mandatory]: https://www.w3.org/TR/mediacapture-streams/#dom-mediatrackconstraints-mandatory -/// [media_capture_and_streams_spec]: https://www.w3.org/TR/mediacapture-streams/ -pub type MandatoryMediaTrackConstraints = - GenericMandatoryMediaTrackConstraints; - -/// The list of mandatory constraint sets for a [`MediaStreamTrack`][media_stream_track] object. -/// -/// # W3C Spec Compliance -/// -/// Corresponds to [`ResolvedMediaTrackConstraintSet`][media_track_constraints_mandatory] -/// from the W3C ["Media Capture and Streams"][media_capture_and_streams_spec] spec. -/// -/// Unlike `MandatoryMediaTrackConstraints` this type does not contain constraints -/// with bare values, but has them resolved to full constraints instead. -/// -/// [media_stream_track]: https://www.w3.org/TR/mediacapture-streams/#dom-mediastreamtrack -/// [media_track_constraints_mandatory]: https://www.w3.org/TR/mediacapture-streams/#dom-mediatrackconstraints-mandatory -/// [media_capture_and_streams_spec]: https://www.w3.org/TR/mediacapture-streams/ -pub type ResolvedMandatoryMediaTrackConstraints = - GenericMandatoryMediaTrackConstraints; - -/// Set of mandatory media track constraints that contains only non-empty constraints. -pub type SanitizedMandatoryMediaTrackConstraints = - GenericMandatoryMediaTrackConstraints; - -/// The set of constraints for a [`MediaStreamTrack`][media_stream_track] object. -/// -/// # W3C Spec Compliance -/// -/// Corresponds to [`ResolvedMediaTrackConstraintSet`][media_track_constraint_set] -/// from the W3C ["Media Capture and Streams"][media_capture_and_streams_spec] spec. 
-/// -/// [media_stream_track]: https://www.w3.org/TR/mediacapture-streams/#dom-mediastreamtrack -/// [media_track_constraint_set]: https://www.w3.org/TR/mediacapture-streams/#dom-mediatrackconstraintset -/// [media_capture_and_streams_spec]: https://www.w3.org/TR/mediacapture-streams/ -#[derive(Debug, Clone, Eq, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(transparent))] -pub struct GenericMandatoryMediaTrackConstraints(GenericMediaTrackConstraintSet); - -impl GenericMandatoryMediaTrackConstraints { - pub fn new(constraints: GenericMediaTrackConstraintSet) -> Self { - Self(constraints) - } - - pub fn into_inner(self) -> GenericMediaTrackConstraintSet { - self.0 - } -} - -impl GenericMandatoryMediaTrackConstraints { - pub fn basic(&self) -> GenericMediaTrackConstraintSet { - self.basic_or_required(false) - } - - pub fn required(&self) -> GenericMediaTrackConstraintSet { - self.basic_or_required(true) - } - - fn basic_or_required( - &self, - required: bool, - ) -> GenericMediaTrackConstraintSet { - GenericMediaTrackConstraintSet::new( - self.0 - .iter() - .filter_map(|(property, constraint)| { - if constraint.is_required() == required { - Some((property.clone(), constraint.clone())) - } else { - None - } - }) - .collect(), - ) - } -} - -impl Deref for GenericMandatoryMediaTrackConstraints { - type Target = GenericMediaTrackConstraintSet; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for GenericMandatoryMediaTrackConstraints { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -impl Default for GenericMandatoryMediaTrackConstraints { - fn default() -> Self { - Self(Default::default()) - } -} - -impl FromIterator<(U, T)> for GenericMandatoryMediaTrackConstraints -where - U: Into, -{ - fn from_iter(iter: I) -> Self - where - I: IntoIterator, - { - Self::new(iter.into_iter().map(|(k, v)| (k.into(), v)).collect()) - } -} - -impl IntoIterator for GenericMandatoryMediaTrackConstraints { - type Item = (MediaTrackProperty, T); - type IntoIter = indexmap::map::IntoIter; - - fn into_iter(self) -> Self::IntoIter { - self.0.into_iter() - } -} - -impl MandatoryMediaTrackConstraints { - pub fn to_resolved(&self) -> ResolvedMandatoryMediaTrackConstraints { - self.clone().into_resolved() - } - - pub fn into_resolved(self) -> ResolvedMandatoryMediaTrackConstraints { - let strategy = MediaTrackConstraintResolutionStrategy::BareToIdeal; - ResolvedMandatoryMediaTrackConstraints::new(self.0.into_resolved(strategy)) - } -} - -impl ResolvedMandatoryMediaTrackConstraints { - pub fn to_sanitized( - &self, - supported_constraints: &MediaTrackSupportedConstraints, - ) -> SanitizedMandatoryMediaTrackConstraints { - self.clone().into_sanitized(supported_constraints) - } - - pub fn into_sanitized( - self, - supported_constraints: &MediaTrackSupportedConstraints, - ) -> SanitizedMandatoryMediaTrackConstraints { - SanitizedMandatoryMediaTrackConstraints::new(self.0.into_sanitized(supported_constraints)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::property::all::name::*; - use crate::{ - ResolvedMediaTrackConstraintSet, ResolvedValueConstraint, ResolvedValueRangeConstraint, - }; - - #[test] - fn basic() { - let mandatory = ResolvedMandatoryMediaTrackConstraints::new( - ResolvedMediaTrackConstraintSet::from_iter([ - ( - &DEVICE_ID, - ResolvedValueConstraint::default() - .exact("device-id".to_owned()) - .into(), - ), - ( - &AUTO_GAIN_CONTROL, - 
ResolvedValueConstraint::default().ideal(true).into(), - ), - ( - &CHANNEL_COUNT, - ResolvedValueRangeConstraint::default() - .exact(2) - .ideal(3) - .into(), - ), - ]), - ); - - let actual = mandatory.basic(); - let expected = ResolvedMediaTrackConstraintSet::from_iter([( - &AUTO_GAIN_CONTROL, - ResolvedValueConstraint::default().ideal(true).into(), - )]); - - assert_eq!(actual, expected); - } - - #[test] - fn required() { - let mandatory = ResolvedMandatoryMediaTrackConstraints::new( - ResolvedMediaTrackConstraintSet::from_iter([ - ( - &DEVICE_ID, - ResolvedValueConstraint::default() - .exact("device-id".to_owned()) - .into(), - ), - ( - &AUTO_GAIN_CONTROL, - ResolvedValueConstraint::default().ideal(true).into(), - ), - ( - &CHANNEL_COUNT, - ResolvedValueRangeConstraint::default() - .exact(2) - .ideal(3) - .into(), - ), - ]), - ); - - let actual = mandatory.required(); - let expected = ResolvedMediaTrackConstraintSet::from_iter([ - ( - &DEVICE_ID, - ResolvedValueConstraint::default() - .exact("device-id".to_owned()) - .into(), - ), - ( - &CHANNEL_COUNT, - ResolvedValueRangeConstraint::default() - .exact(2) - .ideal(3) - .into(), - ), - ]); - - assert_eq!(actual, expected); - } -} - -#[cfg(feature = "serde")] -#[cfg(test)] -mod serde_tests { - use super::*; - use crate::property::all::name::*; - use crate::MediaTrackConstraintSet; - - #[test] - fn serialize_default() { - let mandatory = MandatoryMediaTrackConstraints::default(); - let actual = serde_json::to_value(mandatory).unwrap(); - let expected = serde_json::json!({}); - - assert_eq!(actual, expected); - } - - #[test] - fn deserialize_default() { - let json = serde_json::json!({}); - let actual: MandatoryMediaTrackConstraints = serde_json::from_value(json).unwrap(); - let expected = MandatoryMediaTrackConstraints::default(); - - assert_eq!(actual, expected); - } - - #[test] - fn serialize() { - let mandatory = MandatoryMediaTrackConstraints::new(MediaTrackConstraintSet::from_iter([ - (&DEVICE_ID, "device-id".into()), - (&AUTO_GAIN_CONTROL, true.into()), - (&CHANNEL_COUNT, 2.into()), - (&LATENCY, 0.123.into()), - ])); - let actual = serde_json::to_value(mandatory).unwrap(); - let expected = serde_json::json!( - { - "deviceId": "device-id".to_owned(), - "autoGainControl": true, - "channelCount": 2, - "latency": 0.123, - } - ); - - assert_eq!(actual, expected); - } - - #[test] - fn deserialize() { - let json = serde_json::json!( - { - "deviceId": "device-id".to_owned(), - "autoGainControl": true, - "channelCount": 2, - "latency": 0.123, - } - ); - let actual: MandatoryMediaTrackConstraints = serde_json::from_value(json).unwrap(); - let expected = MandatoryMediaTrackConstraints::new(MediaTrackConstraintSet::from_iter([ - (&DEVICE_ID, "device-id".into()), - (&AUTO_GAIN_CONTROL, true.into()), - (&CHANNEL_COUNT, 2.into()), - (&LATENCY, 0.123.into()), - ])); - - assert_eq!(actual, expected); - } -} diff --git a/constraints/src/constraints/stream.rs b/constraints/src/constraints/stream.rs deleted file mode 100644 index c3268160f..000000000 --- a/constraints/src/constraints/stream.rs +++ /dev/null @@ -1,90 +0,0 @@ -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -use super::track::GenericBoolOrMediaTrackConstraints; -use crate::MediaTrackConstraint; - -/// The constraints for a [`MediaStream`][media_stream] object. -/// -/// # W3C Spec Compliance -/// -/// Corresponds to [`MediaStreamConstraints`][media_stream_constraints] -/// from the W3C ["Media Capture and Streams"][media_capture_and_streams_spec] spec. 
-/// -/// [media_stream]: https://www.w3.org/TR/mediacapture-streams/#dom-mediastream -/// [media_stream_constraints]: https://www.w3.org/TR/mediacapture-streams/#dom-mediastreamconstraints -/// [media_capture_and_streams_spec]: https://www.w3.org/TR/mediacapture-streams/ -pub type MediaStreamConstraints = GenericMediaStreamConstraints; - -/// The constraints for a [`MediaStream`][media_stream] object. -/// -/// # W3C Spec Compliance -/// -/// Corresponds to [`MediaStreamConstraints`][media_stream_constraints] -/// from the W3C ["Media Capture and Streams"][media_capture_and_streams_spec] spec. -/// -/// [media_stream]: https://www.w3.org/TR/mediacapture-streams/#dom-mediastream -/// [media_stream_constraints]: https://www.w3.org/TR/mediacapture-streams/#dom-mediastreamconstraints -/// [media_capture_and_streams_spec]: https://www.w3.org/TR/mediacapture-streams/ -#[derive(Debug, Clone, Default, Eq, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(rename_all = "camelCase"))] -pub struct GenericMediaStreamConstraints { - #[cfg_attr(feature = "serde", serde(default))] - pub audio: GenericBoolOrMediaTrackConstraints, - #[cfg_attr(feature = "serde", serde(default))] - pub video: GenericBoolOrMediaTrackConstraints, -} - -#[cfg(feature = "serde")] -#[cfg(test)] -mod tests { - use std::iter::FromIterator; - - use super::*; - use crate::constraints::advanced::AdvancedMediaTrackConstraints; - use crate::constraints::mandatory::MandatoryMediaTrackConstraints; - use crate::constraints::track::{BoolOrMediaTrackConstraints, MediaTrackConstraints}; - use crate::macros::test_serde_symmetry; - use crate::property::all::name::*; - use crate::MediaTrackConstraintSet; - - type Subject = MediaStreamConstraints; - - #[test] - fn default() { - let subject = Subject::default(); - let json = serde_json::json!({ - "audio": false, - "video": false, - }); - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn customized() { - let subject = Subject { - audio: BoolOrMediaTrackConstraints::Constraints(MediaTrackConstraints { - mandatory: MandatoryMediaTrackConstraints::from_iter([ - (&DEVICE_ID, "microphone".into()), - (&CHANNEL_COUNT, 2.into()), - ]), - advanced: AdvancedMediaTrackConstraints::new(vec![ - MediaTrackConstraintSet::from_iter([(&LATENCY, 0.123.into())]), - ]), - }), - video: BoolOrMediaTrackConstraints::Bool(true), - }; - let json = serde_json::json!({ - "audio": { - "deviceId": "microphone", - "channelCount": 2_i64, - "advanced": [ - { "latency": 0.123_f64, } - ] - }, - "video": true, - }); - test_serde_symmetry!(subject: subject, json: json); - } -} diff --git a/constraints/src/constraints/track.rs b/constraints/src/constraints/track.rs deleted file mode 100644 index 425fd6c54..000000000 --- a/constraints/src/constraints/track.rs +++ /dev/null @@ -1,339 +0,0 @@ -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -use super::advanced::GenericAdvancedMediaTrackConstraints; -use super::mandatory::GenericMandatoryMediaTrackConstraints; -use crate::constraint::SanitizedMediaTrackConstraint; -use crate::{MediaTrackConstraint, MediaTrackSupportedConstraints, ResolvedMediaTrackConstraint}; - -/// A boolean on/off flag or bare value or constraints for a [`MediaStreamTrack`][media_stream_track] object. 
-/// -/// [media_stream_track]: https://www.w3.org/TR/mediacapture-streams/#dom-mediastreamtrack -pub type BoolOrMediaTrackConstraints = GenericBoolOrMediaTrackConstraints; - -/// A boolean on/off flag or constraints for a [`MediaStreamTrack`][media_stream_track] object. -/// -/// # W3C Spec Compliance -/// -/// There exists no direct corresponding type in the -/// W3C ["Media Capture and Streams"][media_capture_and_streams_spec] spec, -/// since the `BoolOrMediaTrackConstraints` type aims to be a -/// generalization over multiple types in the W3C spec: -/// -/// | Rust | W3C | -/// | ----------------------------- | -------------------------------------------------------------------------------------------------- | -/// | `BoolOrMediaTrackConstraints` | [`MediaStreamConstraints`][media_stream_constraints]'s [`video`][video] / [`audio`][audio] members | -/// -/// [media_stream_constraints]: https://www.w3.org/TR/mediacapture-streams/#dom-mediastreamconstraints-video -/// [media_stream_track]: https://www.w3.org/TR/mediacapture-streams/#dom-mediastreamtrack -/// [video]: https://www.w3.org/TR/mediacapture-streams/#dom-mediastreamconstraints-video -/// [audio]: https://www.w3.org/TR/mediacapture-streams/#dom-mediastreamconstraints-audio -/// [media_capture_and_streams_spec]: https://www.w3.org/TR/mediacapture-streams/ -#[derive(Debug, Clone, Eq, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(untagged))] -pub enum GenericBoolOrMediaTrackConstraints { - /// Boolean track selector. - Bool(bool), - /// Constraints-based track selector. - Constraints(GenericMediaTrackConstraints), -} - -impl GenericBoolOrMediaTrackConstraints -where - T: Clone, -{ - pub fn to_constraints(&self) -> Option> { - self.clone().into_constraints() - } - - pub fn into_constraints(self) -> Option> { - match self { - Self::Bool(false) => None, - Self::Bool(true) => Some(GenericMediaTrackConstraints::default()), - Self::Constraints(constraints) => Some(constraints), - } - } -} - -impl Default for GenericBoolOrMediaTrackConstraints { - fn default() -> Self { - Self::Bool(false) - } -} - -impl From for GenericBoolOrMediaTrackConstraints { - fn from(flag: bool) -> Self { - Self::Bool(flag) - } -} - -impl From> for GenericBoolOrMediaTrackConstraints { - fn from(constraints: GenericMediaTrackConstraints) -> Self { - Self::Constraints(constraints) - } -} - -/// Media track constraints that contains either bare values or constraints. -pub type MediaTrackConstraints = GenericMediaTrackConstraints; - -/// Media track constraints that contains only constraints (both, empty and non-empty). -pub type ResolvedMediaTrackConstraints = GenericMediaTrackConstraints; - -/// Media track constraints that contains only non-empty constraints. -pub type SanitizedMediaTrackConstraints = - GenericMediaTrackConstraints; - -/// The constraints for a [`MediaStreamTrack`][media_stream_track] object. -/// -/// # W3C Spec Compliance -/// -/// Corresponds to [`MediaTrackConstraints`][media_track_constraints] -/// from the W3C ["Media Capture and Streams"][media_capture_and_streams_spec] spec. 
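// ---------------------------------------------------------------------------
// Editor's note: a minimal sketch of the `into_constraints` behaviour defined
// above for the boolean-or-constraints selector. Not part of the original
// diff; the `webrtc_constraints` import name is an assumption.
// ---------------------------------------------------------------------------
use webrtc_constraints::{BoolOrMediaTrackConstraints, MediaTrackConstraints};

fn main() {
    // `false` selects no track at all, `true` selects a track without
    // further constraints, and `Constraints(_)` carries explicit ones:
    assert_eq!(
        BoolOrMediaTrackConstraints::Bool(false).into_constraints(),
        None
    );
    assert_eq!(
        BoolOrMediaTrackConstraints::Bool(true).into_constraints(),
        Some(MediaTrackConstraints::default())
    );
}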
-/// -/// [media_stream_track]: https://www.w3.org/TR/mediacapture-streams/#dom-mediastreamtrack -/// [media_track_constraints]: https://www.w3.org/TR/mediacapture-streams/#dom-mediatrackconstraints -/// [media_capture_and_streams_spec]: https://www.w3.org/TR/mediacapture-streams/ -#[derive(Debug, Clone, Eq, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct GenericMediaTrackConstraints { - /// Mandatory (i.e required or optional basic) constraints, as defined in the [spec][spec]. - /// - /// [spec]: https://www.w3.org/TR/mediacapture-streams/#dfn-constraint - #[cfg_attr(feature = "serde", serde(flatten))] - pub mandatory: GenericMandatoryMediaTrackConstraints, - - /// Advanced constraints, as defined in the [spec][spec]. - /// - /// [spec]: https://www.w3.org/TR/mediacapture-streams/#dfn-constraint - #[cfg_attr( - feature = "serde", - serde(default = "Default::default"), - serde(skip_serializing_if = "should_skip_advanced") - )] - pub advanced: GenericAdvancedMediaTrackConstraints, -} - -#[cfg(feature = "serde")] -fn should_skip_advanced(advanced: &GenericAdvancedMediaTrackConstraints) -> bool { - advanced.is_empty() -} - -impl Default for GenericMediaTrackConstraints { - fn default() -> Self { - Self { - mandatory: Default::default(), - advanced: Default::default(), - } - } -} - -impl MediaTrackConstraints { - pub fn to_resolved(&self) -> ResolvedMediaTrackConstraints { - self.clone().into_resolved() - } - - pub fn into_resolved(self) -> ResolvedMediaTrackConstraints { - let Self { - mandatory, - advanced, - } = self; - ResolvedMediaTrackConstraints { - mandatory: mandatory.into_resolved(), - advanced: advanced.into_resolved(), - } - } -} - -impl ResolvedMediaTrackConstraints { - pub fn to_sanitized( - &self, - supported_constraints: &MediaTrackSupportedConstraints, - ) -> SanitizedMediaTrackConstraints { - self.clone().into_sanitized(supported_constraints) - } - - pub fn into_sanitized( - self, - supported_constraints: &MediaTrackSupportedConstraints, - ) -> SanitizedMediaTrackConstraints { - let mandatory = self.mandatory.into_sanitized(supported_constraints); - let advanced = self.advanced.into_sanitized(supported_constraints); - SanitizedMediaTrackConstraints { - mandatory, - advanced, - } - } -} - -#[cfg(test)] -mod tests { - use std::iter::FromIterator; - - use super::*; - use crate::constraints::mandatory::MandatoryMediaTrackConstraints; - use crate::property::all::name::*; - use crate::{ - AdvancedMediaTrackConstraints, ResolvedAdvancedMediaTrackConstraints, - ResolvedMandatoryMediaTrackConstraints, ResolvedValueConstraint, - }; - - type Subject = BoolOrMediaTrackConstraints; - - #[test] - fn default() { - let actual = Subject::default(); - let expected = Subject::Bool(false); - - assert_eq!(actual, expected); - } - - mod from { - use super::*; - - #[test] - fn bool() { - for value in [false, true] { - let actual = Subject::from(value); - let expected = Subject::Bool(value); - - assert_eq!(actual, expected); - } - } - - #[test] - fn constraints() { - let constraints = GenericMediaTrackConstraints { - mandatory: MandatoryMediaTrackConstraints::from_iter([( - &DEVICE_ID, - "microphone".into(), - )]), - advanced: AdvancedMediaTrackConstraints::new(vec![]), - }; - - let actual = Subject::from(constraints.clone()); - let expected = Subject::Constraints(constraints); - - assert_eq!(actual, expected); - } - } - - mod to_constraints { - use super::*; - - #[test] - fn bool_false() { - let subject = Subject::Bool(false); - - let actual = 
subject.to_constraints(); - let expected = None; - - assert_eq!(actual, expected); - } - - #[test] - fn bool_true() { - let subject = Subject::Bool(true); - - let actual = subject.to_constraints(); - let expected = Some(GenericMediaTrackConstraints::default()); - - assert_eq!(actual, expected); - } - - #[test] - fn constraints() { - let constraints = GenericMediaTrackConstraints { - mandatory: MandatoryMediaTrackConstraints::from_iter([( - &DEVICE_ID, - "microphone".into(), - )]), - advanced: AdvancedMediaTrackConstraints::new(vec![]), - }; - - let subject = Subject::Constraints(constraints.clone()); - - let actual = subject.to_constraints(); - let expected = Some(constraints); - - assert_eq!(actual, expected); - } - } - - #[test] - fn to_resolved() { - let subject = MediaTrackConstraints { - mandatory: MandatoryMediaTrackConstraints::from_iter([( - &DEVICE_ID, - "microphone".into(), - )]), - advanced: AdvancedMediaTrackConstraints::new(vec![]), - }; - - let actual = subject.to_resolved(); - let expected = ResolvedMediaTrackConstraints { - mandatory: ResolvedMandatoryMediaTrackConstraints::from_iter([( - &DEVICE_ID, - ResolvedValueConstraint::default() - .ideal("microphone".to_owned()) - .into(), - )]), - advanced: ResolvedAdvancedMediaTrackConstraints::new(vec![]), - }; - - assert_eq!(actual, expected); - } -} - -#[cfg(feature = "serde")] -#[cfg(test)] -mod serde_tests { - use std::iter::FromIterator; - - use super::*; - use crate::constraints::mandatory::MandatoryMediaTrackConstraints; - use crate::macros::test_serde_symmetry; - use crate::property::all::name::*; - use crate::{AdvancedMediaTrackConstraints, MediaTrackConstraintSet}; - - type Subject = MediaTrackConstraints; - - #[test] - fn default() { - let subject = Subject::default(); - let json = serde_json::json!({}); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn customized() { - let subject = Subject { - mandatory: MandatoryMediaTrackConstraints::from_iter([( - &DEVICE_ID, - "microphone".into(), - )]), - advanced: AdvancedMediaTrackConstraints::new(vec![ - MediaTrackConstraintSet::from_iter([ - (&AUTO_GAIN_CONTROL, true.into()), - (&CHANNEL_COUNT, 2.into()), - ]), - MediaTrackConstraintSet::from_iter([(&LATENCY, 0.123.into())]), - ]), - }; - let json = serde_json::json!({ - "deviceId": "microphone", - "advanced": [ - { - "autoGainControl": true, - "channelCount": 2, - }, - { - "latency": 0.123, - }, - ] - }); - - test_serde_symmetry!(subject: subject, json: json); - } -} diff --git a/constraints/src/enumerations.rs b/constraints/src/enumerations.rs deleted file mode 100644 index 2c890a866..000000000 --- a/constraints/src/enumerations.rs +++ /dev/null @@ -1,131 +0,0 @@ -/// The directions that the camera can face, as seen from the user's perspective. -/// -/// # Note -/// The enumeration is not exhaustive and merely provides a list of known values. -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] -pub enum FacingMode { - /// The source is facing toward the user (a self-view camera). - User, - - /// The source is facing away from the user (viewing the environment). - Environment, - - /// The source is facing to the left of the user. - Left, - - /// The source is facing to the right of the user. - Right, -} - -impl FacingMode { - /// Returns `"user"`, the string-value of the `User` facing mode. - pub fn user() -> String { - Self::User.to_string() - } - - /// Returns `"environment"`, the string-value of the `Environment` facing mode. 
- pub fn environment() -> String { - Self::Environment.to_string() - } - - /// Returns `"left"`, the string-value of the `Left` facing mode. - pub fn left() -> String { - Self::Left.to_string() - } - - /// Returns `"right"`, the string-value of the `Right` facing mode. - pub fn right() -> String { - Self::Right.to_string() - } -} - -impl std::fmt::Display for FacingMode { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::User => f.write_str("user"), - Self::Environment => f.write_str("environment"), - Self::Left => f.write_str("left"), - Self::Right => f.write_str("right"), - } - } -} - -/// The means by which the resolution can be derived by the client. -/// -/// # Note -/// The enumeration is not exhaustive and merely provides a list of known values. -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] -pub enum ResizeMode { - /// This resolution and frame rate is offered by the camera, its driver, or the OS. - /// - /// # Note - /// The user agent MAY report this value to disguise concurrent use, - /// but only when the camera is in use in another browsing context. - /// - /// # Important - /// This value is a possible finger-printing surface. - None, - - /// This resolution is downscaled and/or cropped from a higher camera resolution by the user agent, - /// or its frame rate is decimated by the User Agent. - /// - /// # Important - /// The media MUST NOT be upscaled, stretched or have fake data created that did not occur in the input source. - CropAndScale, -} - -impl ResizeMode { - /// Returns `"none"`, the string-value of the `None` resize mode. - pub fn none() -> String { - Self::None.to_string() - } - - /// Returns `"crop-and-scale"`, the string-value of the `CropAndScale` resize mode. - pub fn crop_and_scale() -> String { - Self::CropAndScale.to_string() - } -} - -impl std::fmt::Display for ResizeMode { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::None => f.write_str("none"), - Self::CropAndScale => f.write_str("crop-and-scale"), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - mod facing_mode { - use super::*; - - #[test] - fn to_string() { - assert_eq!(FacingMode::User.to_string(), "user"); - assert_eq!(FacingMode::Environment.to_string(), "environment"); - assert_eq!(FacingMode::Left.to_string(), "left"); - assert_eq!(FacingMode::Right.to_string(), "right"); - - assert_eq!(FacingMode::user(), "user"); - assert_eq!(FacingMode::environment(), "environment"); - assert_eq!(FacingMode::left(), "left"); - assert_eq!(FacingMode::right(), "right"); - } - } - - mod resize_mode { - use super::*; - - #[test] - fn to_string() { - assert_eq!(ResizeMode::None.to_string(), "none"); - assert_eq!(ResizeMode::CropAndScale.to_string(), "crop-and-scale"); - - assert_eq!(ResizeMode::none(), "none"); - assert_eq!(ResizeMode::crop_and_scale(), "crop-and-scale"); - } - } -} diff --git a/constraints/src/errors.rs b/constraints/src/errors.rs deleted file mode 100644 index cdd1a2a1c..000000000 --- a/constraints/src/errors.rs +++ /dev/null @@ -1,111 +0,0 @@ -//! Errors, as defined in the ["Media Capture and Streams"][mediacapture_streams] spec. -//! -//! [mediacapture_streams]: https://www.w3.org/TR/mediacapture-streams/ - -use std::collections::HashMap; - -use thiserror::Error; - -use crate::algorithms::{ConstraintFailureInfo, SettingFitnessDistanceErrorKind}; -use crate::MediaTrackProperty; - -/// An error indicating one or more over-constrained settings. 
-#[derive(Error, Clone, Eq, PartialEq, Debug)] -pub struct OverconstrainedError { - /// The offending constraint's name. - pub constraint: MediaTrackProperty, - /// An error message, or `None` if exposure-mode was `Protected`. - pub message: Option, -} - -impl Default for OverconstrainedError { - fn default() -> Self { - Self { - constraint: MediaTrackProperty::from(""), - message: Default::default(), - } - } -} - -impl std::fmt::Display for OverconstrainedError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Over-constrained property {:?}", self.constraint)?; - if let Some(message) = self.message.as_ref() { - write!(f, ": {message}")?; - } - Ok(()) - } -} - -impl OverconstrainedError { - pub(super) fn exposing_device_information( - failed_constraints: HashMap, - ) -> Self { - let failed_constraint = failed_constraints - .into_iter() - .max_by_key(|(_, failure_info)| failure_info.failures); - - let (constraint, failure_info) = - failed_constraint.expect("Empty candidates implies non-empty failed constraints"); - - struct Violation { - constraint: String, - settings: Vec, - } - let mut violators_by_kind: HashMap = - HashMap::default(); - - for error in failure_info.errors { - let violation = violators_by_kind.entry(error.kind).or_insert(Violation { - constraint: error.constraint.clone(), - settings: vec![], - }); - assert_eq!(violation.constraint, error.constraint); - if let Some(setting) = error.setting { - violation.settings.push(setting.clone()); - } - } - - let formatted_reasons: Vec<_> = violators_by_kind - .into_iter() - .map(|(kind, violation)| { - let kind_str = match kind { - SettingFitnessDistanceErrorKind::Missing => "missing", - SettingFitnessDistanceErrorKind::Mismatch => "a mismatch", - SettingFitnessDistanceErrorKind::TooSmall => "too small", - SettingFitnessDistanceErrorKind::TooLarge => "too large", - }; - - let mut settings = violation.settings; - - if settings.is_empty() { - return format!("{} (does not satisfy {})", kind_str, violation.constraint); - } - - settings.sort(); - - format!( - "{} ([{}] do not satisfy {})", - kind_str, - settings.join(", "), - violation.constraint - ) - }) - .collect(); - - let formatted_reason = match &formatted_reasons[..] { - [] => unreachable!(), - [reason] => reason.clone(), - [reasons @ .., reason] => { - let reasons = reasons.join(", "); - format!("either {reasons}, or {reason}") - } - }; - let message = Some(format!("Setting was {formatted_reason}.")); - - Self { - constraint, - message, - } - } -} diff --git a/constraints/src/lib.rs b/constraints/src/lib.rs deleted file mode 100644 index dca56e368..000000000 --- a/constraints/src/lib.rs +++ /dev/null @@ -1,46 +0,0 @@ -//! Pure Rust implementation of the constraint logic defined in the ["Media Capture and Streams"][mediacapture_streams] spec. -//! -//! 
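// ---------------------------------------------------------------------------
// Editor's note: an illustrative sketch of how the `Display` impl for
// `OverconstrainedError` (defined above in errors.rs) formats its message.
// Not part of the original diff; the field values and the
// `webrtc_constraints` import name are assumptions.
// ---------------------------------------------------------------------------
use webrtc_constraints::errors::OverconstrainedError;
use webrtc_constraints::MediaTrackProperty;

fn main() {
    let err = OverconstrainedError {
        constraint: MediaTrackProperty::from("width"),
        message: Some("Setting was too small.".to_owned()),
    };
    // Prints something like:
    //   Over-constrained property MediaTrackProperty("width"): Setting was too small.
    println!("{err}");
}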
[mediacapture_streams]: https://www.w3.org/TR/mediacapture-streams/ -#![warn(rust_2018_idioms)] -#![allow(dead_code)] - -pub mod algorithms; -pub mod errors; -pub mod macros; -pub mod property; - -mod capabilities; -mod capability; -mod constraint; -mod constraints; -mod enumerations; -mod setting; -mod settings; -mod supported_constraints; - -#[allow(unused_imports)] -pub(crate) use self::{capabilities::MediaStreamCapabilities, settings::MediaStreamSettings}; -#[allow(unused_imports)] -pub use self::{ - capabilities::MediaTrackCapabilities, - capability::MediaTrackCapability, - constraint::{ - MediaTrackConstraint, MediaTrackConstraintResolutionStrategy, ResolvedMediaTrackConstraint, - ResolvedValueConstraint, ResolvedValueRangeConstraint, ResolvedValueSequenceConstraint, - SanitizedMediaTrackConstraint, ValueConstraint, ValueRangeConstraint, - ValueSequenceConstraint, - }, - constraints::{ - AdvancedMediaTrackConstraints, BoolOrMediaTrackConstraints, MandatoryMediaTrackConstraints, - MediaStreamConstraints, MediaTrackConstraintSet, MediaTrackConstraints, - ResolvedAdvancedMediaTrackConstraints, ResolvedMandatoryMediaTrackConstraints, - ResolvedMediaTrackConstraintSet, ResolvedMediaTrackConstraints, - SanitizedMandatoryMediaTrackConstraints, SanitizedMediaTrackConstraintSet, - SanitizedMediaTrackConstraints, - }, - enumerations::{FacingMode, ResizeMode}, - property::MediaTrackProperty, - setting::MediaTrackSetting, - settings::MediaTrackSettings, - supported_constraints::MediaTrackSupportedConstraints, -}; diff --git a/constraints/src/macros.rs b/constraints/src/macros.rs deleted file mode 100644 index 73f518204..000000000 --- a/constraints/src/macros.rs +++ /dev/null @@ -1,450 +0,0 @@ -//! Convenience macros. - -/// A convenience macro for defining settings. -#[macro_export] -macro_rules! settings { - [ - $($p:expr => $c:expr),* $(,)? - ] => { - <$crate::MediaTrackSettings as std::iter::FromIterator<_>>::from_iter([ - $(($p, $c.into())),* - ]) - }; -} - -pub use settings; - -/// A convenience macro for defining individual "value" constraints. -#[macro_export] -macro_rules! value_constraint { - ($($p:ident: $c:expr),+ $(,)?) => { - $crate::ValueConstraint::Constraint( - #[allow(clippy::needless_update)] - $crate::ResolvedValueConstraint { - $($p: Some($c)),+, - ..Default::default() - } - ) - }; - ($c:expr) => { - $crate::ValueConstraint::Bare($c) - }; -} - -pub use value_constraint; - -/// A convenience macro for defining individual "value range" constraints. -#[macro_export] -macro_rules! value_range_constraint { - {$($p:ident: $c:expr),+ $(,)?} => { - $crate::ValueRangeConstraint::Constraint( - $crate::ResolvedValueRangeConstraint { - $($p: Some($c)),+, - ..Default::default() - } - ) - }; - ($c:expr) => { - $crate::ValueRangeConstraint::Bare($c) - }; -} - -pub use value_range_constraint; - -/// A convenience macro for defining individual "value sequence" constraints. -#[macro_export] -macro_rules! value_sequence_constraint { - {$($p:ident: $c:expr),+ $(,)?} => { - $crate::ValueSequenceConstraint::Constraint( - $crate::ResolvedValueSequenceConstraint { - $($p: Some($c)),*, - ..Default::default() - } - ) - }; - ($c:expr) => { - $crate::ValueSequenceConstraint::Bare($c) - }; -} - -pub use value_sequence_constraint; - -/// A convenience macro for defining constraint sets. -#[macro_export] -macro_rules! constraint_set { - { - $($p:expr => $c:expr),* $(,)? 
- } => { - <$crate::MediaTrackConstraintSet as std::iter::FromIterator<_>>::from_iter([ - $(($p, $c.into())),* - ]) - }; -} - -pub use constraint_set; - -/// A convenience macro for defining "mandatory" constraints. -#[macro_export] -macro_rules! mandatory_constraints { - { - $($p:expr => $c:expr),* $(,)? - } => { - $crate::MandatoryMediaTrackConstraints::new( - constraint_set!{ - $($p => $c),* - } - ) - }; -} - -pub use mandatory_constraints; - -/// A convenience macro for defining "advanced" constraints. -#[macro_export] -macro_rules! advanced_constraints { - [ - $({ - $($p:expr => $c:expr),* $(,)? - }),* $(,)? - ] => { - <$crate::AdvancedMediaTrackConstraints as std::iter::FromIterator<_>>::from_iter([ - $(constraint_set!{ - $($p => $c),* - }),* - ]) - }; -} - -pub use advanced_constraints; - -/// A convenience macro for defining constraints. -#[macro_export] -macro_rules! constraints { - [ - mandatory: {$($mp:expr => $mc:expr),* $(,)?}, - advanced: [$( - {$($ap:expr => $ac:expr),* $(,)?} - ),* $(,)?] - ] => { - $crate::MediaTrackConstraints { - mandatory: mandatory_constraints!($($mp => $mc),*), - advanced: advanced_constraints!($({ $($ap => $ac),* }),*) - } - }; -} - -pub use constraints; - -#[allow(unused_macros)] -#[cfg(test)] -macro_rules! test_serde_symmetry { - (subject: $s:expr, json: $j:expr) => { - // Serialize: - { - let actual = serde_json::to_value($s.clone()).unwrap(); - let expected = $j.clone(); - - assert_eq!(actual, expected); - } - - // Deserialize: - { - let actual: Subject = serde_json::from_value($j).unwrap(); - let expected = $s; - - assert_eq!(actual, expected); - } - }; -} - -#[allow(unused_imports)] -#[cfg(test)] -pub(crate) use test_serde_symmetry; - -#[cfg(test)] -mod tests { - use crate::property::all::name::*; - use crate::{ - AdvancedMediaTrackConstraints, FacingMode, MandatoryMediaTrackConstraints, - MediaTrackConstraintSet, MediaTrackConstraints, MediaTrackSettings, - ResolvedValueConstraint, ResolvedValueRangeConstraint, ResolvedValueSequenceConstraint, - ValueConstraint, ValueRangeConstraint, ValueSequenceConstraint, - }; - - #[test] - fn settings() { - let actual: MediaTrackSettings = settings![ - &DEVICE_ID => "foobar".to_owned(), - &FRAME_RATE => 30.0, - &HEIGHT => 1080, - &FACING_MODE => FacingMode::user(), - ]; - - let expected = >::from_iter([ - (&DEVICE_ID, "foobar".to_owned().into()), - (&FRAME_RATE, 30.0.into()), - (&HEIGHT, 1080.into()), - (&FACING_MODE, FacingMode::user().into()), - ]); - - assert_eq!(actual, expected); - } - - mod constraint { - use super::*; - - #[test] - fn value() { - // Bare: - - let actual = value_constraint!("foobar".to_owned()); - - let expected = ValueConstraint::Bare("foobar".to_owned()); - - assert_eq!(actual, expected); - - // Constraint: - - let actual = value_constraint! { - exact: "foobar".to_owned(), - ideal: "bazblee".to_owned(), - }; - - let expected = ValueConstraint::Constraint( - ResolvedValueConstraint::default() - .exact("foobar".to_owned()) - .ideal("bazblee".to_owned()), - ); - - assert_eq!(actual, expected); - } - - #[test] - fn range() { - // Bare: - - let actual = value_range_constraint!(42); - - let expected = ValueRangeConstraint::Bare(42); - - assert_eq!(actual, expected); - - // Constraint: - - let actual = value_range_constraint! 
{ - min: 30.0, - max: 60.0, - }; - - let expected = ValueRangeConstraint::Constraint( - ResolvedValueRangeConstraint::default().min(30.0).max(60.0), - ); - - assert_eq!(actual, expected); - } - - #[test] - fn sequence() { - // Bare: - - let actual = value_sequence_constraint![vec![FacingMode::user()]]; - - let expected = ValueSequenceConstraint::Bare(vec![FacingMode::user()]); - - assert_eq!(actual, expected); - - // Constraint: - - let actual = value_sequence_constraint! { - ideal: vec![FacingMode::user()], - }; - - let expected = ValueSequenceConstraint::Constraint( - ResolvedValueSequenceConstraint::default().ideal(vec![FacingMode::user()]), - ); - - assert_eq!(actual, expected); - } - } - - #[test] - fn mandatory_constraints() { - let actual = mandatory_constraints! { - &DEVICE_ID => value_constraint! { - exact: "foobar".to_owned(), - ideal: "bazblee".to_owned(), - }, - &FRAME_RATE => value_range_constraint! { - min: 30.0, - max: 60.0, - }, - &FACING_MODE => value_sequence_constraint! { - exact: vec![FacingMode::user(), FacingMode::environment()] - }, - }; - - let expected = >::from_iter([ - ( - &DEVICE_ID, - ValueConstraint::Constraint( - ResolvedValueConstraint::default() - .exact("foobar".to_owned()) - .ideal("bazblee".to_owned()), - ) - .into(), - ), - ( - &FRAME_RATE, - ValueRangeConstraint::Constraint( - ResolvedValueRangeConstraint::default().min(30.0).max(60.0), - ) - .into(), - ), - ( - &FACING_MODE, - ValueSequenceConstraint::Constraint( - ResolvedValueSequenceConstraint::default() - .exact(vec![FacingMode::user(), FacingMode::environment()]), - ) - .into(), - ), - ]); - - assert_eq!(actual, expected); - } - - #[test] - fn advanced_constraints() { - let actual = advanced_constraints! [ - { - &DEVICE_ID => value_constraint! { - exact: "foobar".to_owned(), - ideal: "bazblee".to_owned(), - }, - }, - { - &FRAME_RATE => value_range_constraint! { - min: 30.0, - max: 60.0, - }, - }, - { - &FACING_MODE => value_sequence_constraint! { - exact: vec![FacingMode::user(), FacingMode::environment()] - }, - }, - ]; - - let expected = >::from_iter([ - >::from_iter([( - &DEVICE_ID, - ResolvedValueConstraint::default() - .exact("foobar".to_owned()) - .ideal("bazblee".to_owned()) - .into(), - )]), - >::from_iter([( - &FRAME_RATE, - ResolvedValueRangeConstraint::default() - .min(30.0) - .max(60.0) - .into(), - )]), - >::from_iter([( - &FACING_MODE, - ResolvedValueSequenceConstraint::default() - .exact(vec![FacingMode::user(), FacingMode::environment()]) - .into(), - )]), - ]); - - assert_eq!(actual, expected); - } - - #[test] - fn constraints() { - let actual: MediaTrackConstraints = constraints!( - mandatory: { - &DEVICE_ID => value_constraint! { - exact: "foobar".to_owned(), - ideal: "bazblee".to_owned(), - }, - &FRAME_RATE => value_range_constraint! { - min: 30.0, - max: 60.0, - }, - &FACING_MODE => value_sequence_constraint! { - exact: vec![FacingMode::user(), FacingMode::environment()] - }, - }, - advanced: [ - { - &DEVICE_ID => value_constraint! { - exact: "foobar".to_owned(), - ideal: "bazblee".to_owned(), - }, - }, - { - &FRAME_RATE => value_range_constraint! { - min: 30.0, - max: 60.0, - }, - }, - { - &FACING_MODE => value_sequence_constraint! 
{ - exact: vec![FacingMode::user(), FacingMode::environment()] - }, - }, - ] - ); - - let expected = MediaTrackConstraints { - mandatory: >::from_iter([ - ( - &DEVICE_ID, - ResolvedValueConstraint::default() - .exact("foobar".to_owned()) - .ideal("bazblee".to_owned()) - .into(), - ), - ( - &FRAME_RATE, - ResolvedValueRangeConstraint::default() - .min(30.0) - .max(60.0) - .into(), - ), - ( - &FACING_MODE, - ResolvedValueSequenceConstraint::default() - .exact(vec![FacingMode::user(), FacingMode::environment()]) - .into(), - ), - ]), - advanced: >::from_iter([ - >::from_iter([( - &DEVICE_ID, - ResolvedValueConstraint::default() - .exact("foobar".to_owned()) - .ideal("bazblee".to_owned()) - .into(), - )]), - >::from_iter([( - &FRAME_RATE, - ResolvedValueRangeConstraint::default() - .min(30.0) - .max(60.0) - .into(), - )]), - >::from_iter([( - &FACING_MODE, - ResolvedValueSequenceConstraint::default() - .exact(vec![FacingMode::user(), FacingMode::environment()]) - .into(), - )]), - ]), - }; - - assert_eq!(actual, expected); - } -} diff --git a/constraints/src/property.rs b/constraints/src/property.rs deleted file mode 100644 index 473095773..000000000 --- a/constraints/src/property.rs +++ /dev/null @@ -1,301 +0,0 @@ -//! Constants identifying the properties of a [`MediaStreamTrack`][media_stream_track] object, -//! as defined in the ["Media Capture and Streams"][media_track_supported_constraints] spec. -//! -//! [media_stream_track]: https://www.w3.org/TR/mediacapture-streams/#mediastreamtrack -//! [media_track_supported_constraints]: https://www.w3.org/TR/mediacapture-streams/#dom-mediatracksupportedconstraints - -use std::borrow::Cow; -use std::fmt::Display; - -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -/// An identifier for a media track property. -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(transparent))] -#[derive(Debug, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)] -pub struct MediaTrackProperty(Cow<'static, str>); - -impl From<&MediaTrackProperty> for MediaTrackProperty { - fn from(borrowed: &MediaTrackProperty) -> Self { - borrowed.clone() - } -} - -impl From for MediaTrackProperty { - /// Creates a property from an owned representation of its name. - fn from(owned: String) -> Self { - Self(Cow::Owned(owned)) - } -} - -impl From<&str> for MediaTrackProperty { - /// Creates a property from an owned representation of its name. - /// - /// Use `MediaTrackProperty::named(str)` if your property name - /// is statically borrowed (i.e. `&'static str`). - fn from(borrowed: &str) -> Self { - Self(Cow::Owned(borrowed.to_owned())) - } -} - -impl Display for MediaTrackProperty { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(&self.0) - } -} - -impl MediaTrackProperty { - /// Creates a property from a statically borrowed representation of its name. - pub const fn named(name: &'static str) -> Self { - Self(Cow::Borrowed(name)) - } - - /// The property's name. - pub fn name(&self) -> &str { - &self.0 - } -} - -/// Standard properties that apply to both, audio and video device types. -pub mod common { - use super::*; - - /// Names of common properties. - pub mod name { - use super::*; - - /// The identifier of the device generating the content of the track, - /// as defined in the [spec][spec]. 
- /// - /// [spec]: https://www.w3.org/TR/mediacapture-streams/#dfn-deviceid - pub static DEVICE_ID: MediaTrackProperty = MediaTrackProperty::named("deviceId"); - - /// The document-unique group identifier for the device generating the content - /// of the track, as defined in the [spec][spec]. - /// - /// [spec]: https://www.w3.org/TR/mediacapture-streams/#dfn-groupid - pub static GROUP_ID: MediaTrackProperty = MediaTrackProperty::named("groupId"); - } - - /// Names of common properties. - pub fn names() -> Vec<&'static MediaTrackProperty> { - use self::name::*; - - vec![&DEVICE_ID, &GROUP_ID] - } -} - -/// Standard properties that apply only to audio device types. -pub mod audio_only { - use super::*; - - /// Names of audio-only properties. - pub mod name { - use super::*; - - /// Automatic gain control is often desirable on the input signal recorded - /// by the microphone, as defined in the [spec][spec]. - /// - /// [spec]: https://www.w3.org/TR/mediacapture-streams/#dfn-autogaincontrol - pub static AUTO_GAIN_CONTROL: MediaTrackProperty = - MediaTrackProperty::named("autoGainControl"); - - /// The number of independent channels of sound that the audio data contains, - /// i.e. the number of audio samples per sample frame, as defined in the [spec][spec]. - /// - /// [spec]: https://www.w3.org/TR/mediacapture-streams/#dfn-channelcount - pub static CHANNEL_COUNT: MediaTrackProperty = MediaTrackProperty::named("channelCount"); - - /// When one or more audio streams is being played in the processes of - /// various microphones, it is often desirable to attempt to remove - /// all the sound being played from the input signals recorded by the microphones. - /// This is referred to as echo cancellation, as defined in the [spec][spec]. - /// - /// [spec]: https://www.w3.org/TR/mediacapture-streams/#dfn-echocancellation - pub static ECHO_CANCELLATION: MediaTrackProperty = - MediaTrackProperty::named("echoCancellation"); - - /// The latency or latency range, in seconds, as defined in the [spec][spec]. - /// - /// [spec]: https://www.w3.org/TR/mediacapture-streams/#dfn-latency - pub static LATENCY: MediaTrackProperty = MediaTrackProperty::named("latency"); - - /// Noise suppression is often desirable on the input signal recorded by the microphone, - /// as defined in the [spec][spec]. - /// - /// [spec]: https://www.w3.org/TR/mediacapture-streams/#dfn-noisesuppression - pub static NOISE_SUPPRESSION: MediaTrackProperty = - MediaTrackProperty::named("noiseSuppression"); - - /// The sample rate in samples per second for the audio data, as defined in the [spec][spec]. - /// - /// [spec]: https://www.w3.org/TR/mediacapture-streams/#dfn-samplerate - pub static SAMPLE_RATE: MediaTrackProperty = MediaTrackProperty::named("sampleRate"); - - /// The linear sample size in bits. This constraint can only - /// be satisfied for audio devices that produce linear samples, as defined in the [spec][spec]. - /// - /// [spec]: https://www.w3.org/TR/mediacapture-streams/#dfn-samplesize - pub static SAMPLE_SIZE: MediaTrackProperty = MediaTrackProperty::named("sampleSize"); - } - - /// Names of all audio-only properties. - pub fn names() -> Vec<&'static MediaTrackProperty> { - use self::name::*; - - vec![ - &AUTO_GAIN_CONTROL, - &CHANNEL_COUNT, - &ECHO_CANCELLATION, - &LATENCY, - &NOISE_SUPPRESSION, - &SAMPLE_RATE, - &SAMPLE_SIZE, - ] - } -} - -/// Standard properties that apply only to video device types. -pub mod video_only { - use super::*; - - /// Names of audio-only properties. 
- pub mod name { - use super::*; - - /// The exact aspect ratio (width in pixels divided by height in pixels, - /// represented as a double rounded to the tenth decimal place), - /// as defined in the [spec][spec]. - /// - /// [spec]: https://www.w3.org/TR/mediacapture-streams/#dfn-aspectratio - pub static ASPECT_RATIO: MediaTrackProperty = MediaTrackProperty::named("aspectRatio"); - - /// The directions that the camera can face, as seen from the user's perspective, - /// as defined in the [spec][spec]. - /// - /// [spec]: https://www.w3.org/TR/mediacapture-streams/#dfn-facingmode - pub static FACING_MODE: MediaTrackProperty = MediaTrackProperty::named("facingMode"); - - /// The exact frame rate (frames per second) or frame rate range, - /// as defined in the [spec][spec]. - /// - /// [spec]: https://www.w3.org/TR/mediacapture-streams/#dfn-framerate - pub static FRAME_RATE: MediaTrackProperty = MediaTrackProperty::named("frameRate"); - - /// The height or height range, in pixels, as defined in the [spec][spec]. - /// - /// [spec]: https://www.w3.org/TR/mediacapture-streams/#dfn-height - pub static HEIGHT: MediaTrackProperty = MediaTrackProperty::named("height"); - - /// The width or width range, in pixels, as defined in the [spec][spec]. - /// - /// [spec]: https://www.w3.org/TR/mediacapture-streams/#dfn-width - pub static WIDTH: MediaTrackProperty = MediaTrackProperty::named("width"); - - /// The means by which the resolution can be derived by the client, as defined in the [spec][spec]. - /// - /// In other words, whether the client is allowed to use cropping and downscaling on the camera output. - /// - /// [spec]: https://www.w3.org/TR/mediacapture-streams/#dfn-resizemode - pub static RESIZE_MODE: MediaTrackProperty = MediaTrackProperty::named("resizeMode"); - } - - /// Names of all video-only properties. - pub fn names() -> Vec<&'static MediaTrackProperty> { - use self::name::*; - vec![ - &ASPECT_RATIO, - &FACING_MODE, - &FRAME_RATE, - &HEIGHT, - &WIDTH, - &RESIZE_MODE, - ] - } -} - -/// The union of all standard properties (i.e. common + audio + video). -pub mod all { - use super::*; - - /// Names of all properties. - pub mod name { - pub use super::audio_only::name::*; - pub use super::common::name::*; - pub use super::video_only::name::*; - } - - /// Names of all properties. 
- pub fn names() -> Vec<&'static MediaTrackProperty> { - let mut all = vec![]; - all.append(&mut self::common::names()); - all.append(&mut self::audio_only::names()); - all.append(&mut self::video_only::names()); - all - } -} - -#[cfg(test)] -mod tests { - use super::*; - - type Subject = MediaTrackProperty; - - mod from { - use super::*; - - #[test] - fn owned() { - let actuals = [Subject::from("string"), Subject::from("string".to_owned())]; - let expected = MediaTrackProperty(Cow::Owned("string".to_owned())); - - for actual in actuals { - assert_eq!(actual, expected); - - // TODO: remove feature-gate, once stabilized: - #[cfg(feature = "cow_is_borrowed")] - assert!(actual.0.is_owned()); - } - } - - #[test] - fn borrowed() { - let actual = Subject::named("string"); - let expected = MediaTrackProperty(Cow::Borrowed("string")); - - assert_eq!(actual, expected); - - // TODO: remove feature-gate, once stabilized: - #[cfg(feature = "cow_is_borrowed")] - assert!(actual.0.is_borrowed()); - } - } - - #[test] - fn name() { - assert_eq!(Subject::named("string").name(), "string"); - } - - #[test] - fn to_string() { - assert_eq!(Subject::named("string").to_string(), "string"); - } -} - -#[cfg(feature = "serde")] -#[cfg(test)] -mod serde_tests { - use super::*; - use crate::macros::test_serde_symmetry; - - type Subject = MediaTrackProperty; - - #[test] - fn is_symmetric() { - let subject = Subject::named("string"); - let json = serde_json::json!("string"); - - test_serde_symmetry!(subject: subject, json: json); - } -} diff --git a/constraints/src/setting.rs b/constraints/src/setting.rs deleted file mode 100644 index 15dc24c64..000000000 --- a/constraints/src/setting.rs +++ /dev/null @@ -1,159 +0,0 @@ -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -/// A single [setting][media_track_settings] value of a [`MediaStreamTrack`][media_stream_track] object. -/// -/// # W3C Spec Compliance -/// -/// There exists no corresponding type in the W3C ["Media Capture and Streams"][media_capture_and_streams_spec] spec. -/// -/// [media_stream_track]: https://www.w3.org/TR/mediacapture-streams/#dom-mediastreamtrack -/// [media_track_settings]: https://www.w3.org/TR/mediacapture-streams/#dom-mediatracksettings -/// [media_track_supported_constraints]: https://www.w3.org/TR/mediacapture-streams/#dom-mediatracksupportedconstraints -/// [media_capture_and_streams_spec]: https://www.w3.org/TR/mediacapture-streams -#[derive(Debug, Clone, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(untagged))] -pub enum MediaTrackSetting { - /// A boolean-valued track setting. - Bool(bool), - /// An integer-valued track setting. - Integer(i64), - /// A floating-point-valued track setting. - Float(f64), - /// A string-valued track setting. 
- String(String), -} - -impl std::fmt::Display for MediaTrackSetting { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Bool(setting) => f.write_fmt(format_args!("{setting:?}")), - Self::Integer(setting) => f.write_fmt(format_args!("{setting:?}")), - Self::Float(setting) => f.write_fmt(format_args!("{setting:?}")), - Self::String(setting) => f.write_fmt(format_args!("{setting:?}")), - } - } -} - -impl From for MediaTrackSetting { - fn from(setting: bool) -> Self { - Self::Bool(setting) - } -} - -impl From for MediaTrackSetting { - fn from(setting: i64) -> Self { - Self::Integer(setting) - } -} - -impl From for MediaTrackSetting { - fn from(setting: f64) -> Self { - Self::Float(setting) - } -} - -impl From for MediaTrackSetting { - fn from(setting: String) -> Self { - Self::String(setting) - } -} - -impl<'a> From<&'a str> for MediaTrackSetting { - fn from(setting: &'a str) -> Self { - Self::String(setting.to_owned()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - type Subject = MediaTrackSetting; - - mod from { - use super::*; - - #[test] - fn bool() { - let actual = Subject::from(true); - let expected = Subject::Bool(true); - - assert_eq!(actual, expected); - } - - #[test] - fn integer() { - let actual = Subject::from(42); - let expected = Subject::Integer(42); - - assert_eq!(actual, expected); - } - - #[test] - fn float() { - let actual = Subject::from(4.2); - let expected = Subject::Float(4.2); - - assert_eq!(actual, expected); - } - - #[test] - fn string() { - let actual = Subject::from("string".to_owned()); - let expected = Subject::String("string".to_owned()); - - assert_eq!(actual, expected); - } - } - - #[test] - fn to_string() { - assert_eq!(Subject::from(true).to_string(), "true"); - assert_eq!(Subject::from(42).to_string(), "42"); - assert_eq!(Subject::from(4.2).to_string(), "4.2"); - assert_eq!(Subject::from("string".to_owned()).to_string(), "\"string\""); - } -} - -#[cfg(feature = "serde")] -#[cfg(test)] -mod serde_tests { - use super::*; - use crate::macros::test_serde_symmetry; - - type Subject = MediaTrackSetting; - - #[test] - fn bool() { - let subject = Subject::Bool(true); - let json = serde_json::json!(true); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn integer() { - let subject = Subject::Integer(42); - let json = serde_json::json!(42); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn float() { - let subject = Subject::Float(4.2); - let json = serde_json::json!(4.2); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn string() { - let subject = Subject::String("string".to_owned()); - let json = serde_json::json!("string"); - - test_serde_symmetry!(subject: subject, json: json); - } -} diff --git a/constraints/src/settings.rs b/constraints/src/settings.rs deleted file mode 100644 index b69a2c76a..000000000 --- a/constraints/src/settings.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod stream; -mod track; - -pub(crate) use self::stream::MediaStreamSettings; -pub use self::track::MediaTrackSettings; diff --git a/constraints/src/settings/stream.rs b/constraints/src/settings/stream.rs deleted file mode 100644 index 79f99cc99..000000000 --- a/constraints/src/settings/stream.rs +++ /dev/null @@ -1,58 +0,0 @@ -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -use crate::MediaTrackSettings; - -/// The settings of a [`MediaStream`][media_stream] object. 
-/// -/// # W3C Spec Compliance -/// -/// There exists no corresponding type in the W3C ["Media Capture and Streams"][media_capture_and_streams_spec] spec. -/// -/// [media_stream]: https://www.w3.org/TR/mediacapture-streams/#dom-mediastream -/// [media_capture_and_streams_spec]: https://www.w3.org/TR/mediacapture-streams -#[derive(Default, Debug, Clone, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(rename_all = "camelCase"))] -pub(crate) struct MediaStreamSettings { - #[cfg_attr( - feature = "serde", - serde(skip_serializing_if = "core::option::Option::is_none") - )] - pub audio: Option, - #[cfg_attr( - feature = "serde", - serde(skip_serializing_if = "core::option::Option::is_none") - )] - pub video: Option, -} - -#[cfg(feature = "serde")] -#[cfg(test)] -mod serde_tests { - use super::*; - use crate::macros::test_serde_symmetry; - - type Subject = MediaStreamSettings; - - #[test] - fn default() { - let subject = Subject::default(); - let json = serde_json::json!({}); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn customized() { - let subject = MediaStreamSettings { - audio: Some(MediaTrackSettings::default()), - video: None, - }; - let json = serde_json::json!({ - "audio": {} - }); - - test_serde_symmetry!(subject: subject, json: json); - } -} diff --git a/constraints/src/settings/track.rs b/constraints/src/settings/track.rs deleted file mode 100644 index 625e7bc4d..000000000 --- a/constraints/src/settings/track.rs +++ /dev/null @@ -1,169 +0,0 @@ -use std::collections::HashMap; -use std::iter::FromIterator; -use std::ops::{Deref, DerefMut}; - -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -use crate::{MediaTrackProperty, MediaTrackSetting}; - -/// The settings of a [`MediaStreamTrack`][media_stream_track] object. -/// -/// # W3C Spec Compliance -/// -/// Corresponds to [`MediaTrackSettings`][media_track_settings] -/// from the W3C ["Media Capture and Streams"][media_capture_and_streams_spec] spec. -/// -/// The W3C spec defines `MediaTrackSettings` in terma of a dictionary, -/// which per the [WebIDL spec][webidl_spec] is an ordered map (e.g. [`IndexMap`][index_map]). -/// Since the spec however does not make use of the order of items -/// in the map we use a simple [`HashMap`][hash_map]. -/// -/// [hash_map]: https://doc.rust-lang.org/std/collections/struct.HashMap.html -/// [index_map]: https://docs.rs/indexmap/latest/indexmap/set/struct.IndexMap.html -/// [media_stream_track]: https://www.w3.org/TR/mediacapture-streams/#dom-mediastreamtrack -/// [media_track_settings]: https://www.w3.org/TR/mediacapture-streams/#dom-mediatracksettings -/// [media_capture_and_streams_spec]: https://www.w3.org/TR/mediacapture-streams -/// [webidl_spec]: https://webidl.spec.whatwg.org/#idl-dictionaries -#[derive(Debug, Clone, Default, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(transparent))] -pub struct MediaTrackSettings(HashMap); - -impl MediaTrackSettings { - /// Creates a settings value from its inner hashmap. - pub fn new(settings: HashMap) -> Self { - Self(settings) - } - - /// Consumes the value, returning its inner hashmap. 
- pub fn into_inner(self) -> HashMap { - self.0 - } -} - -impl Deref for MediaTrackSettings { - type Target = HashMap; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for MediaTrackSettings { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -impl FromIterator<(T, MediaTrackSetting)> for MediaTrackSettings -where - T: Into, -{ - fn from_iter(iter: I) -> Self - where - I: IntoIterator, - { - Self::new(iter.into_iter().map(|(k, v)| (k.into(), v)).collect()) - } -} - -impl IntoIterator for MediaTrackSettings { - type Item = (MediaTrackProperty, MediaTrackSetting); - type IntoIter = std::collections::hash_map::IntoIter; - - fn into_iter(self) -> Self::IntoIter { - self.0.into_iter() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::property::all::name::*; - - type Subject = MediaTrackSettings; - - #[test] - fn into_inner() { - let hash_map = HashMap::from_iter([ - (DEVICE_ID.clone(), "device-id".into()), - (AUTO_GAIN_CONTROL.clone(), true.into()), - (CHANNEL_COUNT.clone(), 20.into()), - (LATENCY.clone(), 2.0.into()), - ]); - - let subject = Subject::new(hash_map.clone()); - - let actual = subject.into_inner(); - - let expected = hash_map; - - assert_eq!(actual, expected); - } - - #[test] - fn into_iter() { - let hash_map = HashMap::from_iter([ - (DEVICE_ID.clone(), "device-id".into()), - (AUTO_GAIN_CONTROL.clone(), true.into()), - (CHANNEL_COUNT.clone(), 20.into()), - (LATENCY.clone(), 2.0.into()), - ]); - - let subject = Subject::new(hash_map.clone()); - - let actual: HashMap<_, _> = subject.into_iter().collect(); - - let expected = hash_map; - - assert_eq!(actual, expected); - } - - #[test] - fn deref_and_deref_mut() { - let mut subject = Subject::default(); - - // Deref mut: - subject.insert(DEVICE_ID.clone(), "device-id".into()); - - // Deref: - assert!(subject.contains_key(&DEVICE_ID)); - } -} - -#[cfg(feature = "serde")] -#[cfg(test)] -mod serde_tests { - use super::*; - use crate::macros::test_serde_symmetry; - use crate::property::all::name::*; - - type Subject = MediaTrackSettings; - - #[test] - fn default() { - let subject = Subject::default(); - let json = serde_json::json!({}); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn customized() { - let subject = Subject::from_iter([ - (&DEVICE_ID, "device-id".into()), - (&AUTO_GAIN_CONTROL, true.into()), - (&CHANNEL_COUNT, 2.into()), - (&LATENCY, 0.123.into()), - ]); - let json = serde_json::json!({ - "deviceId": "device-id".to_owned(), - "autoGainControl": true, - "channelCount": 2, - "latency": 0.123, - }); - - test_serde_symmetry!(subject: subject, json: json); - } -} diff --git a/constraints/src/supported_constraints.rs b/constraints/src/supported_constraints.rs deleted file mode 100644 index 21ca96b08..000000000 --- a/constraints/src/supported_constraints.rs +++ /dev/null @@ -1,252 +0,0 @@ -use std::collections::HashSet; -use std::iter::FromIterator; -use std::ops::{Deref, DerefMut}; - -#[cfg(feature = "serde")] -use serde::{ - de::{MapAccess, Visitor}, - ser::SerializeMap, - Deserialize, Deserializer, Serialize, Serializer, -}; - -use crate::MediaTrackProperty; - -/// The list of constraints recognized by a User Agent for controlling the -/// capabilities of a [`MediaStreamTrack`][media_stream_track] object. -/// -/// # W3C Spec Compliance -/// -/// Corresponds to [`MediaTrackSupportedConstraints`][media_track_supported_constraints] -/// from the W3C ["Media Capture and Streams"][media_capture_and_streams_spec] spec. 
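-///
-/// With the `serde` feature enabled, a value built from a handful of properties
-/// (for instance `MediaTrackSupportedConstraints::from_iter([&DEVICE_ID, &FRAME_RATE])`)
-/// serializes to the spec's map-of-booleans form, i.e. `{"deviceId": true, "frameRate": true}`,
-/// as exercised by the serde tests at the bottom of this file.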
-/// -/// The W3C spec defines `MediaTrackSupportedConstraints` in terma of a dictionary, -/// which per the [WebIDL spec][webidl_spec] is an ordered map (e.g. [`IndexSet`][index_set]). -/// Since the spec however does not make use of the order of items -/// in the map we use a simple `HashSet`. -/// -/// [hash_set]: https://doc.rust-lang.org/std/collections/struct.HashSet.html -/// [index_set]: https://docs.rs/indexmap/latest/indexmap/set/struct.IndexSet.html -/// [media_stream_track]: https://www.w3.org/TR/mediacapture-streams/#dom-mediastreamtrack -/// [media_track_supported_constraints]: https://www.w3.org/TR/mediacapture-streams/#dom-mediatracksupportedconstraints -/// [media_capture_and_streams_spec]: https://www.w3.org/TR/mediacapture-streams -/// [webidl_spec]: https://webidl.spec.whatwg.org/#idl-dictionaries -#[derive(Debug, Clone, Eq, PartialEq)] -pub struct MediaTrackSupportedConstraints(HashSet); - -impl MediaTrackSupportedConstraints { - /// Creates a supported constraints value from its inner hashmap. - pub fn new(properties: HashSet) -> Self { - Self(properties) - } - - /// Consumes the value, returning its inner hashmap. - pub fn into_inner(self) -> HashSet { - self.0 - } -} - -impl Deref for MediaTrackSupportedConstraints { - type Target = HashSet; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for MediaTrackSupportedConstraints { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -impl Default for MediaTrackSupportedConstraints { - /// [Default values][default_values] as defined by the W3C specification. - /// - /// [default_values]: https://www.w3.org/TR/mediacapture-streams/#dictionary-mediatracksupportedconstraints-members - fn default() -> Self { - use crate::property::all::names as property_names; - - Self::from_iter(property_names().into_iter().cloned()) - } -} - -impl FromIterator for MediaTrackSupportedConstraints -where - T: Into, -{ - fn from_iter(iter: I) -> Self - where - I: IntoIterator, - { - Self(iter.into_iter().map(|property| property.into()).collect()) - } -} - -impl IntoIterator for MediaTrackSupportedConstraints { - type Item = MediaTrackProperty; - type IntoIter = std::collections::hash_set::IntoIter; - - fn into_iter(self) -> Self::IntoIter { - self.0.into_iter() - } -} - -#[cfg(feature = "serde")] -impl<'de> Deserialize<'de> for MediaTrackSupportedConstraints { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - deserializer.deserialize_map(SerdeVisitor) - } -} - -#[cfg(feature = "serde")] -impl Serialize for MediaTrackSupportedConstraints { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - let mut map = serializer.serialize_map(Some(self.0.len()))?; - for property in &self.0 { - map.serialize_entry(property, &true)?; - } - map.end() - } -} - -#[cfg(feature = "serde")] -struct SerdeVisitor; - -#[cfg(feature = "serde")] -impl<'de> Visitor<'de> for SerdeVisitor { - type Value = MediaTrackSupportedConstraints; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("an object with strings as keys and `true` as values") - } - - fn visit_map(self, mut access: M) -> Result - where - M: MapAccess<'de>, - { - let mut set = HashSet::with_capacity(access.size_hint().unwrap_or(0)); - while let Some((key, value)) = access.next_entry()? 
{ - if value { - set.insert(key); - } - } - Ok(MediaTrackSupportedConstraints(set)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::property::all::name::*; - - type Subject = MediaTrackSupportedConstraints; - - #[test] - fn into_inner() { - let hash_set = HashSet::from_iter([ - DEVICE_ID.clone(), - AUTO_GAIN_CONTROL.clone(), - CHANNEL_COUNT.clone(), - LATENCY.clone(), - ]); - - let subject = Subject::new(hash_set.clone()); - - let actual = subject.into_inner(); - - let expected = hash_set; - - assert_eq!(actual, expected); - } - - #[test] - fn into_iter() { - let hash_set = HashSet::from_iter([ - DEVICE_ID.clone(), - AUTO_GAIN_CONTROL.clone(), - CHANNEL_COUNT.clone(), - LATENCY.clone(), - ]); - - let subject = Subject::new(hash_set.clone()); - - let actual: HashSet<_, _> = subject.into_iter().collect(); - - let expected = hash_set; - - assert_eq!(actual, expected); - } - - #[test] - fn deref_and_deref_mut() { - let mut subject = Subject::default(); - - // Deref mut: - subject.insert(DEVICE_ID.clone()); - - // Deref: - assert!(subject.contains(&DEVICE_ID)); - } -} - -#[cfg(feature = "serde")] -#[cfg(test)] -mod serde_tests { - use super::*; - use crate::macros::test_serde_symmetry; - use crate::property::all::name::*; - - type Subject = MediaTrackSupportedConstraints; - - #[test] - fn default() { - let subject = Subject::default(); - let json = serde_json::json!({ - "deviceId": true, - "groupId": true, - "autoGainControl": true, - "channelCount": true, - "echoCancellation": true, - "latency": true, - "noiseSuppression": true, - "sampleRate": true, - "sampleSize": true, - "aspectRatio": true, - "facingMode": true, - "frameRate": true, - "height": true, - "width": true, - "resizeMode": true, - }); - - test_serde_symmetry!(subject: subject, json: json); - } - - #[test] - fn customized() { - let subject = Subject::from_iter([ - &DEVICE_ID, - &GROUP_ID, - &AUTO_GAIN_CONTROL, - &CHANNEL_COUNT, - &ASPECT_RATIO, - &FACING_MODE, - ]); - let json = serde_json::json!({ - "deviceId": true, - "groupId": true, - "autoGainControl": true, - "channelCount": true, - "aspectRatio": true, - "facingMode": true - }); - - test_serde_symmetry!(subject: subject, json: json); - } -} diff --git a/constraints/tests/w3c_spec_examples.rs b/constraints/tests/w3c_spec_examples.rs deleted file mode 100644 index f46144f00..000000000 --- a/constraints/tests/w3c_spec_examples.rs +++ /dev/null @@ -1,216 +0,0 @@ -#[cfg(feature = "serde")] -use webrtc_constraints::{ - property::all::name::*, AdvancedMediaTrackConstraints, BoolOrMediaTrackConstraints, - MediaTrackConstraintSet, MediaTrackConstraints, ResolvedValueRangeConstraint, - ValueRangeConstraint, -}; - -// -#[cfg(feature = "serde")] -#[test] -fn w3c_spec_example_1() { - use std::iter::FromIterator; - - use webrtc_constraints::{MandatoryMediaTrackConstraints, MediaStreamConstraints}; - - let actual: MediaStreamConstraints = { - let json = serde_json::json!({ - "video": { - "width": 1280, - "height": 720, - "aspectRatio": 1.5, - } - }); - serde_json::from_value(json).unwrap() - }; - let expected = MediaStreamConstraints { - audio: BoolOrMediaTrackConstraints::Bool(false), - video: BoolOrMediaTrackConstraints::Constraints(MediaTrackConstraints { - mandatory: MandatoryMediaTrackConstraints::from_iter([ - (&WIDTH, 1280.into()), - (&HEIGHT, 720.into()), - (&ASPECT_RATIO, 1.5.into()), - ]), - advanced: AdvancedMediaTrackConstraints::default(), - }), - }; - - assert_eq!(actual, expected); -} - -// -#[cfg(feature = "serde")] -#[test] -fn w3c_spec_example_2() { - 
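- // Bare JSON values (like "aspectRatio" below) deserialize to bare constraints, while
- // `{ "min": ..., "ideal": ... }` objects become resolved range constraints; the omitted
- // "audio" member defaults to `false`.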
use std::iter::FromIterator; - - use webrtc_constraints::{MandatoryMediaTrackConstraints, MediaStreamConstraints}; - - let actual: MediaStreamConstraints = { - let json = serde_json::json!({ - "video": { - "width": { "min": 640, "ideal": 1280 }, - "height": { "min": 480, "ideal": 720 }, - "aspectRatio": 1.5, - "frameRate": { "min": 20.0 }, - } - }); - serde_json::from_value(json).unwrap() - }; - - let expected = MediaStreamConstraints { - audio: BoolOrMediaTrackConstraints::Bool(false), - video: BoolOrMediaTrackConstraints::Constraints(MediaTrackConstraints { - mandatory: MandatoryMediaTrackConstraints::from_iter([ - ( - &WIDTH, - ValueRangeConstraint::Constraint(ResolvedValueRangeConstraint { - min: Some(640), - max: None, - exact: None, - ideal: Some(1280), - }) - .into(), - ), - ( - &HEIGHT, - ValueRangeConstraint::Constraint(ResolvedValueRangeConstraint { - min: Some(480), - max: None, - exact: None, - ideal: Some(720), - }) - .into(), - ), - (&ASPECT_RATIO, ValueRangeConstraint::Bare(1.5).into()), - ( - &FRAME_RATE, - ValueRangeConstraint::Constraint(ResolvedValueRangeConstraint { - min: Some(20.0), - max: None, - exact: None, - ideal: None, - }) - .into(), - ), - ]), - advanced: AdvancedMediaTrackConstraints::default(), - }), - }; - - assert_eq!(actual, expected); -} - -// -#[cfg(feature = "serde")] -#[test] -fn w3c_spec_example_3() { - use std::iter::FromIterator; - - use webrtc_constraints::{MandatoryMediaTrackConstraints, MediaStreamConstraints}; - - let actual: MediaStreamConstraints = { - let json = serde_json::json!({ - "video": { - "height": { "min": 480, "ideal": 720 }, - "width": { "min": 640, "ideal": 1280 }, - "frameRate": { "min": 30.0 }, - "advanced": [ - {"width": 1920, "height": 1280 }, - {"aspectRatio": 1.333}, - {"frameRate": {"min": 50.0 } }, - {"frameRate": {"min": 40.0 } } - ] - } - }); - serde_json::from_value(json).unwrap() - }; - - let expected = MediaStreamConstraints { - audio: BoolOrMediaTrackConstraints::Bool(false), - video: BoolOrMediaTrackConstraints::Constraints(MediaTrackConstraints { - mandatory: MandatoryMediaTrackConstraints::from_iter([ - ( - &HEIGHT, - ResolvedValueRangeConstraint { - min: Some(480), - max: None, - exact: None, - ideal: Some(720), - } - .into(), - ), - ( - &WIDTH, - ResolvedValueRangeConstraint { - min: Some(640), - max: None, - exact: None, - ideal: Some(1280), - } - .into(), - ), - ( - &FRAME_RATE, - ResolvedValueRangeConstraint { - min: Some(30.0), - max: None, - exact: None, - ideal: None, - } - .into(), - ), - ]), - advanced: AdvancedMediaTrackConstraints::new(vec![ - MediaTrackConstraintSet::from_iter([(&WIDTH, 1920.into()), (&HEIGHT, 1280.into())]), - MediaTrackConstraintSet::from_iter([(&ASPECT_RATIO, 1.333.into())]), - MediaTrackConstraintSet::from_iter([( - &FRAME_RATE, - ResolvedValueRangeConstraint { - min: Some(50.0), - max: None, - exact: None, - ideal: None, - } - .into(), - )]), - MediaTrackConstraintSet::from_iter([( - &FRAME_RATE, - ResolvedValueRangeConstraint { - min: Some(40.0), - max: None, - exact: None, - ideal: None, - } - .into(), - )]), - ]), - }), - }; - - assert_eq!(actual, expected); -} - -// -#[cfg(feature = "serde")] -#[test] -fn w3c_spec_example_4() { - use std::iter::FromIterator; - - let actual: MediaTrackConstraintSet = { - let json = serde_json::json!({ - "width": 1920, - "height": 1080, - "frameRate": 30, - }); - serde_json::from_value(json).unwrap() - }; - - let expected = MediaTrackConstraintSet::from_iter([ - (&WIDTH, ValueRangeConstraint::Bare(1920).into()), - (&HEIGHT, 
ValueRangeConstraint::Bare(1080).into()), - (&FRAME_RATE, ValueRangeConstraint::Bare(30).into()), - ]); - - assert_eq!(actual, expected); -} diff --git a/data/.gitignore b/data/.gitignore deleted file mode 100644 index 81561ed32..000000000 --- a/data/.gitignore +++ /dev/null @@ -1,11 +0,0 @@ -# Generated by Cargo -# will have compiled files and executables -/target/ -/.idea/ - -# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries -# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html -Cargo.lock - -# These are backup files generated by rustfmt -**/*.rs.bk diff --git a/data/CHANGELOG.md b/data/CHANGELOG.md deleted file mode 100644 index f632b1312..000000000 --- a/data/CHANGELOG.md +++ /dev/null @@ -1,30 +0,0 @@ -# webrtc-data changelog - -## Unreleased - -* Remove builder pattern from `data_channel::Config` [#411](https://github.com/webrtc-rs/webrtc/pull/411). - -## v0.7.0 - -* Increased required `webrtc-sctp` version to `0.8.0`. - -## v0.6.0 - -* Increased minimum support rust version to `1.60.0`. -* Do not loose data in `PollDataChannel::poll_write` [#341](https://github.com/webrtc-rs/webrtc/pull/341). -* `PollDataChannel::poll_shutdown`: make sure to flush any writes before shutting down [#340](https://github.com/webrtc-rs/webrtc/pull/340) -* Increased required `webrtc-util` version to `0.7.0`. -* Increased required `webrtc-sctp` version to `0.7.0`. - -### Breaking changes - -* Make `DataChannel::on_buffered_amount_low` function non-async [#338](https://github.com/webrtc-rs/webrtc/pull/338). - -## v0.5.0 - -* [#16 [PollDataChannel] reset shutdown_fut future after done](https://github.com/webrtc-rs/data/pull/16) by [@melekes](https://github.com/melekes). -* Increase min version of `log` dependency to `0.4.16`. [#250 Fix log at ^0.4.16 to make tests compile](https://github.com/webrtc-rs/webrtc/pull/250) by [@k0nserv](https://github.com/k0nserv). - -## Prior to 0.4.0 - -Before 0.4.0 there was no changelog, previous changes are sometimes, but not always, available in the [GitHub Releases](https://github.com/webrtc-rs/data/releases). diff --git a/data/Cargo.toml b/data/Cargo.toml deleted file mode 100644 index f03fb72c7..000000000 --- a/data/Cargo.toml +++ /dev/null @@ -1,36 +0,0 @@ -[package] -name = "webrtc-data" -version = "0.9.0" -authors = ["Rain Liu "] -edition = "2021" -description = "A pure Rust implementation of WebRTC DataChannel API" -license = "MIT OR Apache-2.0" -documentation = "https://docs.rs/webrtc-data" -homepage = "https://webrtc.rs" -repository = "https://github.com/webrtc-rs/webrtc/tree/master/data" - -[dependencies] -util = { version = "0.9.0", path = "../util", package = "webrtc-util", default-features = false, features = ["conn", "marshal"] } -sctp = { version = "0.10.0", path = "../sctp", package = "webrtc-sctp" } - -tokio = { version = "1.32.0", features = [ - "fs", - "io-util", - "io-std", - "macros", - "net", - "parking_lot", - "rt", - "rt-multi-thread", - "sync", - "time", -] } -bytes = "1" -log = "0.4" -thiserror = "1" -portable-atomic = "1.6" - -[dev-dependencies] -tokio-test = "0.4" # must match the min version of the `tokio` crate above -env_logger = "0.10" -chrono = "0.4.28" diff --git a/data/LICENSE-APACHE b/data/LICENSE-APACHE deleted file mode 100644 index 16fe87b06..000000000 --- a/data/LICENSE-APACHE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - -1. 
Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - -2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - -3. Grant of Patent License. 
Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - -4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - -5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - -6. Trademarks. 
This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - -8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - -END OF TERMS AND CONDITIONS - -APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - -Copyright [yyyy] [name of copyright owner] - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
diff --git a/data/LICENSE-MIT b/data/LICENSE-MIT deleted file mode 100644 index e11d93bef..000000000 --- a/data/LICENSE-MIT +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2021 WebRTC.rs - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/data/README.md b/data/README.md deleted file mode 100644 index 8230caac5..000000000 --- a/data/README.md +++ /dev/null @@ -1,30 +0,0 @@ -

- WebRTC.rs
- [badges] License: MIT/Apache 2.0, Discord
- A pure Rust implementation of WebRTC DataChannel. Rewrite Pion DataChannel in Rust -
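For orientation, since the deleted README above carries only the tagline, here is a minimal sketch of how the crate's `DataChannel` API is driven, distilled from the tests in `data/src/data_channel/data_channel_test.rs` later in this patch. The import paths, the `open_and_echo` helper name, and the pre-established SCTP association are assumptions for illustration, not something this diff shows:

use std::sync::Arc;

use bytes::Bytes;
// Assumed public paths; the tests below reach these types via `super::*` and the
// crate-internal `sctp` alias for `webrtc-sctp`.
use webrtc_data::data_channel::{Config, DataChannel};
use webrtc_sctp::association::Association;

async fn open_and_echo(association: &Arc<Association>) -> webrtc_data::error::Result<()> {
    // Open SCTP stream 100 as an ordered, reliable channel labelled "data",
    // mirroring the configs used throughout the tests.
    let cfg = Config {
        label: "data".to_owned(),
        ..Default::default()
    };
    let dc = DataChannel::dial(association, 100, cfg).await?;

    // Send one message, then block until the peer's reply arrives.
    dc.write(&Bytes::from_static(b"hello")).await?;
    let mut buf = vec![0u8; 1500];
    let n = dc.read(&mut buf[..]).await?;
    println!("received {n} bytes");

    dc.close().await?;
    Ok(())
}

The answering side of the tests uses `DataChannel::accept(&association, Config::default(), &existing_data_channels)` instead of `dial`, and exchanges data with the same `read`/`write` calls.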

diff --git a/data/codecov.yml b/data/codecov.yml deleted file mode 100644 index 616770e51..000000000 --- a/data/codecov.yml +++ /dev/null @@ -1,23 +0,0 @@ -codecov: - require_ci_to_pass: yes - max_report_age: off - token: 00d131c6-1478-4018-b481-be1b44f5f094 - -coverage: - precision: 2 - round: down - range: 50..90 - status: - project: - default: - enabled: no - threshold: 0.2 - if_not_found: success - patch: - default: - enabled: no - if_not_found: success - changes: - default: - enabled: no - if_not_found: success diff --git a/data/doc/webrtc.rs.png b/data/doc/webrtc.rs.png deleted file mode 100644 index 7bf0dda2a..000000000 Binary files a/data/doc/webrtc.rs.png and /dev/null differ diff --git a/data/src/data_channel/data_channel_test.rs b/data/src/data_channel/data_channel_test.rs deleted file mode 100644 index 382ae85b9..000000000 --- a/data/src/data_channel/data_channel_test.rs +++ /dev/null @@ -1,670 +0,0 @@ -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::sync::{broadcast, mpsc}; -use tokio::time::Duration; -use util::conn::conn_bridge::*; -use util::conn::*; - -use super::*; -use crate::error::Result; - -async fn bridge_process_at_least_one(br: &Arc) { - let mut n_sum = 0; - loop { - tokio::time::sleep(Duration::from_millis(10)).await; - n_sum += br.tick().await; - if br.len(0).await == 0 && br.len(1).await == 0 && n_sum > 0 { - break; - } - } -} - -async fn create_new_association_pair( - br: &Arc, - ca: Arc, - cb: Arc, -) -> Result<(Arc, Arc)> { - let (handshake0ch_tx, mut handshake0ch_rx) = mpsc::channel(1); - let (handshake1ch_tx, mut handshake1ch_rx) = mpsc::channel(1); - let (closed_tx, mut closed_rx0) = broadcast::channel::<()>(1); - let mut closed_rx1 = closed_tx.subscribe(); - - // Setup client - tokio::spawn(async move { - let client = Association::client(sctp::association::Config { - net_conn: ca, - max_receive_buffer_size: 0, - max_message_size: 0, - name: "client".to_owned(), - }) - .await; - - let _ = handshake0ch_tx.send(client).await; - let _ = closed_rx0.recv().await; - - Result::<()>::Ok(()) - }); - - // Setup server - tokio::spawn(async move { - let server = Association::server(sctp::association::Config { - net_conn: cb, - max_receive_buffer_size: 0, - max_message_size: 0, - name: "server".to_owned(), - }) - .await; - - let _ = handshake1ch_tx.send(server).await; - let _ = closed_rx1.recv().await; - - Result::<()>::Ok(()) - }); - - let mut client = None; - let mut server = None; - let mut a0handshake_done = false; - let mut a1handshake_done = false; - let mut i = 0; - while (!a0handshake_done || !a1handshake_done) && i < 100 { - br.tick().await; - - let timer = tokio::time::sleep(Duration::from_millis(10)); - tokio::pin!(timer); - - tokio::select! 
{ - _ = timer.as_mut() =>{}, - r0 = handshake0ch_rx.recv() => { - if let Ok(c) = r0.unwrap() { - client = Some(c); - } - a0handshake_done = true; - }, - r1 = handshake1ch_rx.recv() => { - if let Ok(s) = r1.unwrap() { - server = Some(s); - } - a1handshake_done = true; - }, - }; - i += 1; - } - - if !a0handshake_done || !a1handshake_done { - return Err(Error::new("handshake failed".to_owned())); - } - - drop(closed_tx); - - Ok((Arc::new(client.unwrap()), Arc::new(server.unwrap()))) -} - -async fn close_association_pair( - br: &Arc, - client: Arc, - server: Arc, -) { - let (handshake0ch_tx, mut handshake0ch_rx) = mpsc::channel(1); - let (handshake1ch_tx, mut handshake1ch_rx) = mpsc::channel(1); - let (closed_tx, mut closed_rx0) = broadcast::channel::<()>(1); - let mut closed_rx1 = closed_tx.subscribe(); - - // Close client - tokio::spawn(async move { - client.close().await?; - let _ = handshake0ch_tx.send(()).await; - let _ = closed_rx0.recv().await; - - Result::<()>::Ok(()) - }); - - // Close server - tokio::spawn(async move { - server.close().await?; - let _ = handshake1ch_tx.send(()).await; - let _ = closed_rx1.recv().await; - - Result::<()>::Ok(()) - }); - - let mut a0handshake_done = false; - let mut a1handshake_done = false; - let mut i = 0; - while (!a0handshake_done || !a1handshake_done) && i < 100 { - br.tick().await; - - let timer = tokio::time::sleep(Duration::from_millis(10)); - tokio::pin!(timer); - - tokio::select! { - _ = timer.as_mut() =>{}, - _ = handshake0ch_rx.recv() => { - a0handshake_done = true; - }, - _ = handshake1ch_rx.recv() => { - a1handshake_done = true; - }, - }; - i += 1; - } - - drop(closed_tx); -} - -//use std::io::Write; - -async fn pr_ordered_unordered_test(channel_type: ChannelType, is_ordered: bool) -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - let mut sbuf = vec![0u8; 1000]; - let mut rbuf = vec![0u8; 2000]; - - let (br, ca, cb) = Bridge::new(0, None, None); - - let (a0, a1) = create_new_association_pair(&br, Arc::new(ca), Arc::new(cb)).await?; - - let cfg = Config { - channel_type, - reliability_parameter: 0, - label: "data".to_string(), - ..Default::default() - }; - - let dc0 = DataChannel::dial(&a0, 100, cfg.clone()).await?; - bridge_process_at_least_one(&br).await; - - let existing_data_channels: Vec = Vec::new(); - let dc1 = DataChannel::accept(&a1, Config::default(), &existing_data_channels).await?; - bridge_process_at_least_one(&br).await; - - assert_eq!(dc0.config, cfg, "local config should match"); - assert_eq!(dc1.config, cfg, "remote config should match"); - - dc0.commit_reliability_params(); - dc1.commit_reliability_params(); - - sbuf[0..4].copy_from_slice(&1u32.to_be_bytes()); - let n = dc0 - .write_data_channel(&Bytes::from(sbuf.clone()), true) - .await?; - assert_eq!(sbuf.len(), n, "data length should match"); - - sbuf[0..4].copy_from_slice(&2u32.to_be_bytes()); - let n = dc0 - .write_data_channel(&Bytes::from(sbuf.clone()), true) - .await?; - assert_eq!(sbuf.len(), n, "data length should match"); - - if !is_ordered { - sbuf[0..4].copy_from_slice(&3u32.to_be_bytes()); - let n = dc0 - .write_data_channel(&Bytes::from(sbuf.clone()), true) - .await?; - assert_eq!(sbuf.len(), n, "data length should match"); - } - - 
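- // Give the queued messages ~100ms to reach the wire, then drop the first packet so the
- // partially-reliable channel abandons that message instead of retransmitting it; the
- // reads below should therefore never see message 1.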
tokio::time::sleep(Duration::from_millis(100)).await; - br.drop_offset(0, 0, 1).await; // drop the first packet on the wire - if !is_ordered { - br.reorder(0).await; - } else { - tokio::time::sleep(Duration::from_millis(100)).await; - } - bridge_process_at_least_one(&br).await; - - if !is_ordered { - let (n, is_string) = dc1.read_data_channel(&mut rbuf[..]).await?; - assert!(is_string, "should return isString being true"); - assert_eq!(sbuf.len(), n, "data length should match"); - assert_eq!( - 3, - u32::from_be_bytes([rbuf[0], rbuf[1], rbuf[2], rbuf[3]]), - "data should match" - ); - } - - let (n, is_string) = dc1.read_data_channel(&mut rbuf[..]).await?; - assert!(is_string, "should return isString being true"); - assert_eq!(sbuf.len(), n, "data length should match"); - assert_eq!( - 2, - u32::from_be_bytes([rbuf[0], rbuf[1], rbuf[2], rbuf[3]]), - "data should match" - ); - - dc0.close().await?; - dc1.close().await?; - bridge_process_at_least_one(&br).await; - - close_association_pair(&br, a0, a1).await; - - Ok(()) -} - -#[tokio::test] -async fn test_data_channel_channel_type_reliable_ordered() -> Result<()> { - let mut sbuf = vec![0u8; 1000]; - let mut rbuf = vec![0u8; 1500]; - - let (br, ca, cb) = Bridge::new(0, None, None); - - let (a0, a1) = create_new_association_pair(&br, Arc::new(ca), Arc::new(cb)).await?; - - let cfg = Config { - channel_type: ChannelType::Reliable, - reliability_parameter: 123, - label: "data".to_string(), - ..Default::default() - }; - - let dc0 = DataChannel::dial(&a0, 100, cfg.clone()).await?; - bridge_process_at_least_one(&br).await; - - let existing_data_channels: Vec = Vec::new(); - let dc1 = DataChannel::accept(&a1, Config::default(), &existing_data_channels).await?; - bridge_process_at_least_one(&br).await; - - assert_eq!(dc0.config, cfg, "local config should match"); - assert_eq!(dc1.config, cfg, "remote config should match"); - - br.reorder_next_nwrites(0, 2); // reordering on the wire - - sbuf[0..4].copy_from_slice(&1u32.to_be_bytes()); - let n = dc0.write(&Bytes::from(sbuf.clone())).await?; - assert_eq!(sbuf.len(), n, "data length should match"); - - sbuf[0..4].copy_from_slice(&2u32.to_be_bytes()); - let n = dc0.write(&Bytes::from(sbuf.clone())).await?; - assert_eq!(sbuf.len(), n, "data length should match"); - - bridge_process_at_least_one(&br).await; - - let n = dc1.read(&mut rbuf[..]).await?; - assert_eq!(sbuf.len(), n, "data length should match"); - assert_eq!( - 1, - u32::from_be_bytes([rbuf[0], rbuf[1], rbuf[2], rbuf[3]]), - "data should match" - ); - - let n = dc1.read(&mut rbuf[..]).await?; - assert_eq!(sbuf.len(), n, "data length should match"); - assert_eq!( - 2, - u32::from_be_bytes([rbuf[0], rbuf[1], rbuf[2], rbuf[3]]), - "data should match" - ); - - dc0.close().await?; - dc1.close().await?; - bridge_process_at_least_one(&br).await; - - close_association_pair(&br, a0, a1).await; - - Ok(()) -} - -#[tokio::test] -async fn test_data_channel_channel_type_reliable_unordered() -> Result<()> { - let mut sbuf = vec![0u8; 1000]; - let mut rbuf = vec![0u8; 1500]; - - let (br, ca, cb) = Bridge::new(0, None, None); - - let (a0, a1) = create_new_association_pair(&br, Arc::new(ca), Arc::new(cb)).await?; - - let cfg = Config { - channel_type: ChannelType::ReliableUnordered, - reliability_parameter: 123, - label: "data".to_string(), - ..Default::default() - }; - - let dc0 = DataChannel::dial(&a0, 100, cfg.clone()).await?; - bridge_process_at_least_one(&br).await; - - let existing_data_channels: Vec = Vec::new(); - let dc1 = DataChannel::accept(&a1, 
Config::default(), &existing_data_channels).await?; - bridge_process_at_least_one(&br).await; - - assert_eq!(dc0.config, cfg, "local config should match"); - assert_eq!(dc1.config, cfg, "remote config should match"); - - dc0.commit_reliability_params(); - dc1.commit_reliability_params(); - - sbuf[0..4].copy_from_slice(&1u32.to_be_bytes()); - let n = dc0 - .write_data_channel(&Bytes::from(sbuf.clone()), true) - .await?; - assert_eq!(sbuf.len(), n, "data length should match"); - - sbuf[0..4].copy_from_slice(&2u32.to_be_bytes()); - let n = dc0 - .write_data_channel(&Bytes::from(sbuf.clone()), true) - .await?; - assert_eq!(sbuf.len(), n, "data length should match"); - - tokio::time::sleep(Duration::from_millis(100)).await; - br.reorder(0).await; // reordering on the wire - bridge_process_at_least_one(&br).await; - - let (n, is_string) = dc1.read_data_channel(&mut rbuf[..]).await?; - assert!(is_string, "should return isString being true"); - assert_eq!(sbuf.len(), n, "data length should match"); - assert_eq!( - 2, - u32::from_be_bytes([rbuf[0], rbuf[1], rbuf[2], rbuf[3]]), - "data should match" - ); - - let (n, is_string) = dc1.read_data_channel(&mut rbuf[..]).await?; - assert!(is_string, "should return isString being true"); - assert_eq!(sbuf.len(), n, "data length should match"); - assert_eq!( - 1, - u32::from_be_bytes([rbuf[0], rbuf[1], rbuf[2], rbuf[3]]), - "data should match" - ); - - dc0.close().await?; - dc1.close().await?; - bridge_process_at_least_one(&br).await; - - close_association_pair(&br, a0, a1).await; - - Ok(()) -} - -#[cfg(not(target_os = "windows"))] // this times out in CI on windows. -#[tokio::test] -async fn test_data_channel_channel_type_partial_reliable_rexmit() -> Result<()> { - pr_ordered_unordered_test(ChannelType::PartialReliableRexmit, true).await -} - -#[cfg(not(target_os = "windows"))] // this times out in CI on windows. -#[tokio::test] -async fn test_data_channel_channel_type_partial_reliable_rexmit_unordered() -> Result<()> { - pr_ordered_unordered_test(ChannelType::PartialReliableRexmitUnordered, false).await -} - -#[cfg(not(target_os = "windows"))] // this times out in CI on windows. -#[tokio::test] -async fn test_data_channel_channel_type_partial_reliable_timed() -> Result<()> { - pr_ordered_unordered_test(ChannelType::PartialReliableTimed, true).await -} - -#[cfg(not(target_os = "windows"))] // this times out in CI on windows. 
-#[tokio::test] -async fn test_data_channel_channel_type_partial_reliable_timed_unordered() -> Result<()> { - pr_ordered_unordered_test(ChannelType::PartialReliableTimedUnordered, false).await -} - -//TODO: remove this conditional test -#[cfg(not(any(target_os = "macos", target_os = "windows")))] -#[tokio::test] -async fn test_data_channel_buffered_amount() -> Result<()> { - let sbuf = vec![0u8; 1000]; - let mut rbuf = vec![0u8; 1000]; - - let n_cbs = Arc::new(AtomicUsize::new(0)); - - let (br, ca, cb) = Bridge::new(0, None, None); - - let (a0, a1) = create_new_association_pair(&br, Arc::new(ca), Arc::new(cb)).await?; - - let dc0 = Arc::new( - DataChannel::dial( - &a0, - 100, - Config { - label: "data".to_owned(), - ..Default::default() - }, - ) - .await?, - ); - bridge_process_at_least_one(&br).await; - - let existing_data_channels: Vec = Vec::new(); - let dc1 = Arc::new(DataChannel::accept(&a1, Config::default(), &existing_data_channels).await?); - bridge_process_at_least_one(&br).await; - - while dc0.buffered_amount() > 0 { - bridge_process_at_least_one(&br).await; - } - - let n = dc0.write(&Bytes::new()).await?; - assert_eq!(n, 0, "data length should match"); - assert_eq!(dc0.buffered_amount(), 1, "incorrect bufferedAmount"); - - let n = dc0.write(&Bytes::from_static(&[0])).await?; - assert_eq!(n, 1, "data length should match"); - assert_eq!(dc0.buffered_amount(), 2, "incorrect bufferedAmount"); - - bridge_process_at_least_one(&br).await; - - let n = dc1.read(&mut rbuf[..]).await?; - assert_eq!(n, 0, "received length should match"); - - let n = dc1.read(&mut rbuf[..]).await?; - assert_eq!(n, 1, "received length should match"); - - dc0.set_buffered_amount_low_threshold(1500); - assert_eq!( - dc0.buffered_amount_low_threshold(), - 1500, - "incorrect bufferedAmountLowThreshold" - ); - let n_cbs2 = Arc::clone(&n_cbs); - dc0.on_buffered_amount_low(Box::new(move || { - n_cbs2.fetch_add(1, Ordering::SeqCst); - Box::pin(async {}) - })); - - // Write 10 1000-byte packets (total 10,000 bytes) - for i in 0..10 { - let n = dc0.write(&Bytes::from(sbuf.clone())).await?; - assert_eq!(sbuf.len(), n, "data length should match"); - assert_eq!( - sbuf.len() * (i + 1) + 2, - dc0.buffered_amount(), - "incorrect bufferedAmount" - ); - } - - let dc1_cloned = Arc::clone(&dc1); - tokio::spawn(async move { - while let Ok(n) = dc1_cloned.read(&mut rbuf[..]).await { - if n == 0 { - break; - } - assert_eq!(n, rbuf.len(), "received length should match"); - } - }); - - let since = tokio::time::Instant::now(); - loop { - br.tick().await; - tokio::time::sleep(Duration::from_millis(10)).await; - if tokio::time::Instant::now().duration_since(since) > Duration::from_millis(500) { - break; - } - } - - dc0.close().await?; - dc1.close().await?; - bridge_process_at_least_one(&br).await; - - assert!( - n_cbs.load(Ordering::SeqCst) > 0, - "should make at least one callback" - ); - - close_association_pair(&br, a0, a1).await; - - Ok(()) -} - -//TODO: remove this conditional test -#[cfg(not(any(target_os = "macos", target_os = "windows")))] // this times out in CI on windows. 
-#[tokio::test] -async fn test_stats() -> Result<()> { - let sbuf = vec![0u8; 1000]; - let mut rbuf = vec![0u8; 1500]; - - let (br, ca, cb) = Bridge::new(0, None, None); - - let (a0, a1) = create_new_association_pair(&br, Arc::new(ca), Arc::new(cb)).await?; - - let cfg = Config { - channel_type: ChannelType::Reliable, - reliability_parameter: 123, - label: "data".to_owned(), - ..Default::default() - }; - - let dc0 = DataChannel::dial(&a0, 100, cfg.clone()).await?; - bridge_process_at_least_one(&br).await; - - let existing_data_channels: Vec = Vec::new(); - let dc1 = DataChannel::accept(&a1, Config::default(), &existing_data_channels).await?; - bridge_process_at_least_one(&br).await; - - let mut bytes_sent = 0; - - let n = dc0.write(&Bytes::from(sbuf.clone())).await?; - assert_eq!(n, sbuf.len(), "data length should match"); - bytes_sent += n; - - assert_eq!(dc0.bytes_sent(), bytes_sent); - assert_eq!(dc0.messages_sent(), 1); - - let n = dc0.write(&Bytes::from(sbuf.clone())).await?; - assert_eq!(n, sbuf.len(), "data length should match"); - bytes_sent += n; - - assert_eq!(dc0.bytes_sent(), bytes_sent); - assert_eq!(dc0.messages_sent(), 2); - - let n = dc0.write(&Bytes::from_static(&[0])).await?; - assert_eq!(n, 1, "data length should match"); - bytes_sent += n; - - assert_eq!(dc0.bytes_sent(), bytes_sent); - assert_eq!(dc0.messages_sent(), 3); - - let n = dc0.write(&Bytes::from_static(&[])).await?; - assert_eq!(n, 0, "data length should match"); - bytes_sent += n; - - assert_eq!(dc0.bytes_sent(), bytes_sent); - assert_eq!(dc0.messages_sent(), 4); - - bridge_process_at_least_one(&br).await; - - let mut bytes_read = 0; - - let n = dc1.read(&mut rbuf[..]).await?; - assert_eq!(n, sbuf.len(), "data length should match"); - bytes_read += n; - - assert_eq!(dc1.bytes_received(), bytes_read); - assert_eq!(dc1.messages_received(), 1); - - let n = dc1.read(&mut rbuf[..]).await?; - assert_eq!(n, sbuf.len(), "data length should match"); - bytes_read += n; - - assert_eq!(dc1.bytes_received(), bytes_read); - assert_eq!(dc1.messages_received(), 2); - - let n = dc1.read(&mut rbuf[..]).await?; - assert_eq!(n, 1, "data length should match"); - bytes_read += n; - - assert_eq!(dc1.bytes_received(), bytes_read); - assert_eq!(dc1.messages_received(), 3); - - let n = dc1.read(&mut rbuf[..]).await?; - assert_eq!(n, 0, "data length should match"); - bytes_read += n; - - assert_eq!(dc1.bytes_received(), bytes_read); - assert_eq!(dc1.messages_received(), 4); - - dc0.close().await?; - dc1.close().await?; - bridge_process_at_least_one(&br).await; - - close_association_pair(&br, a0, a1).await; - - Ok(()) -} - -#[tokio::test] -async fn test_poll_data_channel() -> Result<()> { - let mut sbuf = vec![0u8; 1000]; - let mut rbuf = vec![0u8; 1500]; - - let (br, ca, cb) = Bridge::new(0, None, None); - - let (a0, a1) = create_new_association_pair(&br, Arc::new(ca), Arc::new(cb)).await?; - - let cfg = Config { - channel_type: ChannelType::Reliable, - reliability_parameter: 123, - label: "data".to_string(), - ..Default::default() - }; - - let dc0 = Arc::new(DataChannel::dial(&a0, 100, cfg.clone()).await?); - bridge_process_at_least_one(&br).await; - - let existing_data_channels: Vec = Vec::new(); - let dc1 = Arc::new(DataChannel::accept(&a1, Config::default(), &existing_data_channels).await?); - bridge_process_at_least_one(&br).await; - - let mut poll_dc0 = PollDataChannel::new(dc0); - let mut poll_dc1 = PollDataChannel::new(dc1); - - sbuf[0..4].copy_from_slice(&1u32.to_be_bytes()); - let n = poll_dc0 - 
.write(&Bytes::from(sbuf.clone())) - .await - .map_err(|e| Error::new(e.to_string()))?; - assert_eq!(sbuf.len(), n, "data length should match"); - - bridge_process_at_least_one(&br).await; - - let n = poll_dc1 - .read(&mut rbuf[..]) - .await - .map_err(|e| Error::new(e.to_string()))?; - assert_eq!(sbuf.len(), n, "data length should match"); - assert_eq!( - 1, - u32::from_be_bytes([rbuf[0], rbuf[1], rbuf[2], rbuf[3]]), - "data should match" - ); - - poll_dc0.into_inner().close().await?; - poll_dc1.into_inner().close().await?; - bridge_process_at_least_one(&br).await; - - close_association_pair(&br, a0, a1).await; - - Ok(()) -} diff --git a/data/src/data_channel/mod.rs b/data/src/data_channel/mod.rs deleted file mode 100644 index 17f12b9e8..000000000 --- a/data/src/data_channel/mod.rs +++ /dev/null @@ -1,682 +0,0 @@ -#[cfg(test)] -mod data_channel_test; - -use std::borrow::Borrow; -use std::future::Future; -use std::net::Shutdown; -use std::pin::Pin; -use std::sync::atomic::Ordering; -use std::sync::Arc; -use std::task::{Context, Poll}; -use std::{fmt, io}; - -use bytes::{Buf, Bytes}; -use portable_atomic::AtomicUsize; -use sctp::association::Association; -use sctp::chunk::chunk_payload_data::PayloadProtocolIdentifier; -use sctp::stream::*; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use util::marshal::*; - -use crate::error::{Error, Result}; -use crate::message::message_channel_ack::*; -use crate::message::message_channel_open::*; -use crate::message::*; - -const RECEIVE_MTU: usize = 8192; - -/// Config is used to configure the data channel. -#[derive(Eq, PartialEq, Default, Clone, Debug)] -pub struct Config { - pub channel_type: ChannelType, - pub negotiated: bool, - pub priority: u16, - pub reliability_parameter: u32, - pub label: String, - pub protocol: String, -} - -/// DataChannel represents a data channel -#[derive(Debug, Default, Clone)] -pub struct DataChannel { - pub config: Config, - stream: Arc, - - // stats - messages_sent: Arc, - messages_received: Arc, - bytes_sent: Arc, - bytes_received: Arc, -} - -impl DataChannel { - pub fn new(stream: Arc, config: Config) -> Self { - Self { - config, - stream, - ..Default::default() - } - } - - /// Dial opens a data channels over SCTP - pub async fn dial( - association: &Arc, - identifier: u16, - config: Config, - ) -> Result { - let stream = association - .open_stream(identifier, PayloadProtocolIdentifier::Binary) - .await?; - - Self::client(stream, config).await - } - - /// Accept is used to accept incoming data channels over SCTP - pub async fn accept( - association: &Arc, - config: Config, - existing_channels: &[T], - ) -> Result - where - T: Borrow, - { - let stream = association - .accept_stream() - .await - .ok_or(Error::ErrStreamClosed)?; - - for channel in existing_channels.iter().map(|ch| ch.borrow()) { - if channel.stream_identifier() == stream.stream_identifier() { - let ch = channel.to_owned(); - ch.stream - .set_default_payload_type(PayloadProtocolIdentifier::Binary); - return Ok(ch); - } - } - - stream.set_default_payload_type(PayloadProtocolIdentifier::Binary); - - Self::server(stream, config).await - } - - /// Client opens a data channel over an SCTP stream - pub async fn client(stream: Arc, config: Config) -> Result { - if !config.negotiated { - let msg = Message::DataChannelOpen(DataChannelOpen { - channel_type: config.channel_type, - priority: config.priority, - reliability_parameter: config.reliability_parameter, - label: config.label.bytes().collect(), - protocol: config.protocol.bytes().collect(), - }) - 
.marshal()?; - - stream - .write_sctp(&msg, PayloadProtocolIdentifier::Dcep) - .await?; - } - Ok(DataChannel::new(stream, config)) - } - - /// Server accepts a data channel over an SCTP stream - pub async fn server(stream: Arc, mut config: Config) -> Result { - let mut buf = vec![0u8; RECEIVE_MTU]; - - let (n, ppi) = stream.read_sctp(&mut buf).await?; - - if ppi != PayloadProtocolIdentifier::Dcep { - return Err(Error::InvalidPayloadProtocolIdentifier(ppi as u8)); - } - - let mut read_buf = &buf[..n]; - let msg = Message::unmarshal(&mut read_buf)?; - - if let Message::DataChannelOpen(dco) = msg { - config.channel_type = dco.channel_type; - config.priority = dco.priority; - config.reliability_parameter = dco.reliability_parameter; - config.label = String::from_utf8(dco.label)?; - config.protocol = String::from_utf8(dco.protocol)?; - } else { - return Err(Error::InvalidMessageType(msg.message_type() as u8)); - }; - - let data_channel = DataChannel::new(stream, config); - - data_channel.write_data_channel_ack().await?; - data_channel.commit_reliability_params(); - - Ok(data_channel) - } - - /// Read reads a packet of len(p) bytes as binary data. - /// - /// See [`sctp::stream::Stream::read_sctp`]. - pub async fn read(&self, buf: &mut [u8]) -> Result { - self.read_data_channel(buf).await.map(|(n, _)| n) - } - - /// ReadDataChannel reads a packet of len(p) bytes. It returns the number of bytes read and - /// `true` if the data read is a string. - /// - /// See [`sctp::stream::Stream::read_sctp`]. - pub async fn read_data_channel(&self, buf: &mut [u8]) -> Result<(usize, bool)> { - loop { - //TODO: add handling of cancel read_data_channel - let (mut n, ppi) = match self.stream.read_sctp(buf).await { - Ok((0, PayloadProtocolIdentifier::Unknown)) => { - // The incoming stream was reset or the reading half was shutdown - return Ok((0, false)); - } - Ok((n, ppi)) => (n, ppi), - Err(err) => { - // Shutdown the stream and send the reset request to the remote. - self.close().await?; - return Err(err.into()); - } - }; - - let mut is_string = false; - match ppi { - PayloadProtocolIdentifier::Dcep => { - let mut data = &buf[..n]; - match self.handle_dcep(&mut data).await { - Ok(()) => {} - Err(err) => { - log::error!("Failed to handle DCEP: {:?}", err); - } - } - continue; - } - PayloadProtocolIdentifier::String | PayloadProtocolIdentifier::StringEmpty => { - is_string = true; - } - _ => {} - }; - - match ppi { - PayloadProtocolIdentifier::StringEmpty | PayloadProtocolIdentifier::BinaryEmpty => { - n = 0; - } - _ => {} - }; - - self.messages_received.fetch_add(1, Ordering::SeqCst); - self.bytes_received.fetch_add(n, Ordering::SeqCst); - - return Ok((n, is_string)); - } - } - - /// MessagesSent returns the number of messages sent - pub fn messages_sent(&self) -> usize { - self.messages_sent.load(Ordering::SeqCst) - } - - /// MessagesReceived returns the number of messages received - pub fn messages_received(&self) -> usize { - self.messages_received.load(Ordering::SeqCst) - } - - /// BytesSent returns the number of bytes sent - pub fn bytes_sent(&self) -> usize { - self.bytes_sent.load(Ordering::SeqCst) - } - - /// BytesReceived returns the number of bytes received - pub fn bytes_received(&self) -> usize { - self.bytes_received.load(Ordering::SeqCst) - } - - /// StreamIdentifier returns the Stream identifier associated to the stream. 
- pub fn stream_identifier(&self) -> u16 { - self.stream.stream_identifier() - } - - async fn handle_dcep(&self, data: &mut B) -> Result<()> - where - B: Buf, - { - let msg = Message::unmarshal(data)?; - - match msg { - Message::DataChannelOpen(_) => { - // Note: DATA_CHANNEL_OPEN message is handled inside Server() method. - // Therefore, the message will not reach here. - log::debug!("Received DATA_CHANNEL_OPEN"); - let _ = self.write_data_channel_ack().await?; - } - Message::DataChannelAck(_) => { - log::debug!("Received DATA_CHANNEL_ACK"); - self.commit_reliability_params(); - } - }; - - Ok(()) - } - - /// Write writes len(p) bytes from p as binary data - pub async fn write(&self, data: &Bytes) -> Result { - self.write_data_channel(data, false).await - } - - /// WriteDataChannel writes len(p) bytes from p - pub async fn write_data_channel(&self, data: &Bytes, is_string: bool) -> Result { - let data_len = data.len(); - - // https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-12#section-6.6 - // SCTP does not support the sending of empty user messages. Therefore, - // if an empty message has to be sent, the appropriate PPID (WebRTC - // String Empty or WebRTC Binary Empty) is used and the SCTP user - // message of one zero byte is sent. When receiving an SCTP user - // message with one of these PPIDs, the receiver MUST ignore the SCTP - // user message and process it as an empty message. - let ppi = match (is_string, data_len) { - (false, 0) => PayloadProtocolIdentifier::BinaryEmpty, - (false, _) => PayloadProtocolIdentifier::Binary, - (true, 0) => PayloadProtocolIdentifier::StringEmpty, - (true, _) => PayloadProtocolIdentifier::String, - }; - - let n = if data_len == 0 { - let _ = self - .stream - .write_sctp(&Bytes::from_static(&[0]), ppi) - .await?; - 0 - } else { - let n = self.stream.write_sctp(data, ppi).await?; - self.bytes_sent.fetch_add(n, Ordering::SeqCst); - n - }; - - self.messages_sent.fetch_add(1, Ordering::SeqCst); - Ok(n) - } - - async fn write_data_channel_ack(&self) -> Result { - let ack = Message::DataChannelAck(DataChannelAck {}).marshal()?; - Ok(self - .stream - .write_sctp(&ack, PayloadProtocolIdentifier::Dcep) - .await?) - } - - /// Close closes the DataChannel and the underlying SCTP stream. - pub async fn close(&self) -> Result<()> { - // https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-13#section-6.7 - // Closing of a data channel MUST be signaled by resetting the - // corresponding outgoing streams [RFC6525]. This means that if one - // side decides to close the data channel, it resets the corresponding - // outgoing stream. When the peer sees that an incoming stream was - // reset, it also resets its corresponding outgoing stream. Once this - // is completed, the data channel is closed. Resetting a stream sets - // the Stream Sequence Numbers (SSNs) of the stream back to 'zero' with - // a corresponding notification to the application layer that the reset - // has been performed. Streams are available for reuse after a reset - // has been performed. - Ok(self.stream.shutdown(Shutdown::Both).await?) - } - - /// BufferedAmount returns the number of bytes of data currently queued to be - /// sent over this stream. - pub fn buffered_amount(&self) -> usize { - self.stream.buffered_amount() - } - - /// BufferedAmountLowThreshold returns the number of bytes of buffered outgoing - /// data that is considered "low." Defaults to 0. 
- pub fn buffered_amount_low_threshold(&self) -> usize { - self.stream.buffered_amount_low_threshold() - } - - /// SetBufferedAmountLowThreshold is used to update the threshold. - /// See BufferedAmountLowThreshold(). - pub fn set_buffered_amount_low_threshold(&self, threshold: usize) { - self.stream.set_buffered_amount_low_threshold(threshold) - } - - /// OnBufferedAmountLow sets the callback handler which would be called when the - /// number of bytes of outgoing data buffered is lower than the threshold. - pub fn on_buffered_amount_low(&self, f: OnBufferedAmountLowFn) { - self.stream.on_buffered_amount_low(f) - } - - fn commit_reliability_params(&self) { - let (unordered, reliability_type) = match self.config.channel_type { - ChannelType::Reliable => (false, ReliabilityType::Reliable), - ChannelType::ReliableUnordered => (true, ReliabilityType::Reliable), - ChannelType::PartialReliableRexmit => (false, ReliabilityType::Rexmit), - ChannelType::PartialReliableRexmitUnordered => (true, ReliabilityType::Rexmit), - ChannelType::PartialReliableTimed => (false, ReliabilityType::Timed), - ChannelType::PartialReliableTimedUnordered => (true, ReliabilityType::Timed), - }; - - self.stream.set_reliability_params( - unordered, - reliability_type, - self.config.reliability_parameter, - ); - } -} - -/// Default capacity of the temporary read buffer used by [`PollStream`]. -const DEFAULT_READ_BUF_SIZE: usize = 8192; - -/// State of the read `Future` in [`PollStream`]. -enum ReadFut { - /// Nothing in progress. - Idle, - /// Reading data from the underlying stream. - Reading(Pin>> + Send>>), - /// Finished reading, but there's unread data in the temporary buffer. - RemainingData(Vec), -} - -impl ReadFut { - /// Gets a mutable reference to the future stored inside `Reading(future)`. - /// - /// # Panics - /// - /// Panics if `ReadFut` variant is not `Reading`. - fn get_reading_mut(&mut self) -> &mut Pin>> + Send>> { - match self { - ReadFut::Reading(ref mut fut) => fut, - _ => panic!("expected ReadFut to be Reading"), - } - } -} - -/// A wrapper around around [`DataChannel`], which implements [`AsyncRead`] and -/// [`AsyncWrite`]. -/// -/// Both `poll_read` and `poll_write` calls allocate temporary buffers, which results in an -/// additional overhead. -pub struct PollDataChannel { - data_channel: Arc, - - read_fut: ReadFut, - write_fut: Option> + Send>>>, - shutdown_fut: Option> + Send>>>, - - read_buf_cap: usize, -} - -impl PollDataChannel { - /// Constructs a new `PollDataChannel`. - /// - /// # Examples - /// - /// ``` - /// use webrtc_data::data_channel::{DataChannel, PollDataChannel, Config}; - /// use sctp::stream::Stream; - /// use std::sync::Arc; - /// - /// let dc = Arc::new(DataChannel::new(Arc::new(Stream::default()), Config::default())); - /// let poll_dc = PollDataChannel::new(dc); - /// ``` - pub fn new(data_channel: Arc) -> Self { - Self { - data_channel, - read_fut: ReadFut::Idle, - write_fut: None, - shutdown_fut: None, - read_buf_cap: DEFAULT_READ_BUF_SIZE, - } - } - - /// Get back the inner data_channel. - pub fn into_inner(self) -> Arc { - self.data_channel - } - - /// Obtain a clone of the inner data_channel. 
- pub fn clone_inner(&self) -> Arc { - self.data_channel.clone() - } - - /// MessagesSent returns the number of messages sent - pub fn messages_sent(&self) -> usize { - self.data_channel.messages_sent() - } - - /// MessagesReceived returns the number of messages received - pub fn messages_received(&self) -> usize { - self.data_channel.messages_received() - } - - /// BytesSent returns the number of bytes sent - pub fn bytes_sent(&self) -> usize { - self.data_channel.bytes_sent() - } - - /// BytesReceived returns the number of bytes received - pub fn bytes_received(&self) -> usize { - self.data_channel.bytes_received() - } - - /// StreamIdentifier returns the Stream identifier associated to the stream. - pub fn stream_identifier(&self) -> u16 { - self.data_channel.stream_identifier() - } - - /// BufferedAmount returns the number of bytes of data currently queued to be - /// sent over this stream. - pub fn buffered_amount(&self) -> usize { - self.data_channel.buffered_amount() - } - - /// BufferedAmountLowThreshold returns the number of bytes of buffered outgoing - /// data that is considered "low." Defaults to 0. - pub fn buffered_amount_low_threshold(&self) -> usize { - self.data_channel.buffered_amount_low_threshold() - } - - /// Set the capacity of the temporary read buffer (default: 8192). - pub fn set_read_buf_capacity(&mut self, capacity: usize) { - self.read_buf_cap = capacity - } -} - -impl AsyncRead for PollDataChannel { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - if buf.remaining() == 0 { - return Poll::Ready(Ok(())); - } - - let fut = match self.read_fut { - ReadFut::Idle => { - // read into a temporary buffer because `buf` has an unonymous lifetime, which can - // be shorter than the lifetime of `read_fut`. - let data_channel = self.data_channel.clone(); - let mut temp_buf = vec![0; self.read_buf_cap]; - self.read_fut = ReadFut::Reading(Box::pin(async move { - data_channel.read(temp_buf.as_mut_slice()).await.map(|n| { - temp_buf.truncate(n); - temp_buf - }) - })); - self.read_fut.get_reading_mut() - } - ReadFut::Reading(ref mut fut) => fut, - ReadFut::RemainingData(ref mut data) => { - let remaining = buf.remaining(); - let len = std::cmp::min(data.len(), remaining); - buf.put_slice(&data[..len]); - if data.len() > remaining { - // ReadFut remains to be RemainingData - data.drain(..len); - } else { - self.read_fut = ReadFut::Idle; - } - return Poll::Ready(Ok(())); - } - }; - - loop { - match fut.as_mut().poll(cx) { - Poll::Pending => return Poll::Pending, - // retry immediately upon empty data or incomplete chunks - // since there's no way to setup a waker. 
- Poll::Ready(Err(Error::Sctp(sctp::Error::ErrTryAgain))) => {} - // EOF has been reached => don't touch buf and just return Ok - Poll::Ready(Err(Error::Sctp(sctp::Error::ErrEof))) => { - self.read_fut = ReadFut::Idle; - return Poll::Ready(Ok(())); - } - Poll::Ready(Err(e)) => { - self.read_fut = ReadFut::Idle; - return Poll::Ready(Err(e.into())); - } - Poll::Ready(Ok(mut temp_buf)) => { - let remaining = buf.remaining(); - let len = std::cmp::min(temp_buf.len(), remaining); - buf.put_slice(&temp_buf[..len]); - if temp_buf.len() > remaining { - temp_buf.drain(..len); - self.read_fut = ReadFut::RemainingData(temp_buf); - } else { - self.read_fut = ReadFut::Idle; - } - return Poll::Ready(Ok(())); - } - } - } - } -} - -impl AsyncWrite for PollDataChannel { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - if buf.is_empty() { - return Poll::Ready(Ok(0)); - } - - if let Some(fut) = self.write_fut.as_mut() { - match fut.as_mut().poll(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(Err(e)) => { - let data_channel = self.data_channel.clone(); - let bytes = Bytes::copy_from_slice(buf); - self.write_fut = - Some(Box::pin(async move { data_channel.write(&bytes).await })); - Poll::Ready(Err(e.into())) - } - // Given the data is buffered, it's okay to ignore the number of written bytes. - // - // TODO: In the long term, `data_channel.write` should be made sync. Then we could - // remove the whole `if` condition and just call `data_channel.write`. - Poll::Ready(Ok(_)) => { - let data_channel = self.data_channel.clone(); - let bytes = Bytes::copy_from_slice(buf); - self.write_fut = - Some(Box::pin(async move { data_channel.write(&bytes).await })); - Poll::Ready(Ok(buf.len())) - } - } - } else { - let data_channel = self.data_channel.clone(); - let bytes = Bytes::copy_from_slice(buf); - let fut = self - .write_fut - .insert(Box::pin(async move { data_channel.write(&bytes).await })); - - match fut.as_mut().poll(cx) { - // If it's the first time we're polling the future, `Poll::Pending` can't be - // returned because that would mean the `PollDataChannel` is not ready for writing. - // And this is not true since we've just created a future, which is going to write - // the buf to the underlying stream. - // - // It's okay to return `Poll::Ready` if the data is buffered (this is what the - // buffered writer and `File` do). 
- Poll::Pending => Poll::Ready(Ok(buf.len())), - Poll::Ready(Err(e)) => { - self.write_fut = None; - Poll::Ready(Err(e.into())) - } - Poll::Ready(Ok(n)) => { - self.write_fut = None; - Poll::Ready(Ok(n)) - } - } - } - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.write_fut.as_mut() { - Some(fut) => match fut.as_mut().poll(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(Err(e)) => { - self.write_fut = None; - Poll::Ready(Err(e.into())) - } - Poll::Ready(Ok(_)) => { - self.write_fut = None; - Poll::Ready(Ok(())) - } - }, - None => Poll::Ready(Ok(())), - } - } - - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.as_mut().poll_flush(cx) { - Poll::Pending => return Poll::Pending, - Poll::Ready(_) => {} - } - - let fut = match self.shutdown_fut.as_mut() { - Some(fut) => fut, - None => { - let data_channel = self.data_channel.clone(); - self.shutdown_fut.get_or_insert(Box::pin(async move { - data_channel - .stream - .shutdown(Shutdown::Write) - .await - .map_err(Error::Sctp) - })) - } - }; - - match fut.as_mut().poll(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(Err(e)) => { - self.shutdown_fut = None; - Poll::Ready(Err(e.into())) - } - Poll::Ready(Ok(_)) => { - self.shutdown_fut = None; - Poll::Ready(Ok(())) - } - } - } -} - -impl Clone for PollDataChannel { - fn clone(&self) -> PollDataChannel { - PollDataChannel::new(self.clone_inner()) - } -} - -impl fmt::Debug for PollDataChannel { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("PollDataChannel") - .field("data_channel", &self.data_channel) - .field("read_buf_cap", &self.read_buf_cap) - .finish() - } -} - -impl AsRef for PollDataChannel { - fn as_ref(&self) -> &DataChannel { - &self.data_channel - } -} diff --git a/data/src/error.rs b/data/src/error.rs deleted file mode 100644 index 4d6b1b84c..000000000 --- a/data/src/error.rs +++ /dev/null @@ -1,72 +0,0 @@ -use std::io; -use std::string::FromUtf8Error; - -use thiserror::Error; - -pub type Result = std::result::Result; - -#[derive(Debug, Error, PartialEq)] -#[non_exhaustive] -pub enum Error { - #[error( - "DataChannel message is not long enough to determine type: (expected: {expected}, actual: {actual})" - )] - UnexpectedEndOfBuffer { expected: usize, actual: usize }, - #[error("Unknown MessageType {0}")] - InvalidMessageType(u8), - #[error("Unknown ChannelType {0}")] - InvalidChannelType(u8), - #[error("Unknown PayloadProtocolIdentifier {0}")] - InvalidPayloadProtocolIdentifier(u8), - #[error("Stream closed")] - ErrStreamClosed, - - #[error("{0}")] - Util(#[from] util::Error), - #[error("{0}")] - Sctp(#[from] sctp::Error), - #[error("utf-8 error: {0}")] - Utf8(#[from] FromUtf8Error), - - #[allow(non_camel_case_types)] - #[error("{0}")] - new(String), -} - -impl From for util::Error { - fn from(e: Error) -> Self { - util::Error::from_std(e) - } -} - -impl From for io::Error { - fn from(error: Error) -> Self { - match error { - e @ Error::Sctp(sctp::Error::ErrEof) => { - io::Error::new(io::ErrorKind::UnexpectedEof, e.to_string()) - } - e @ Error::ErrStreamClosed => { - io::Error::new(io::ErrorKind::ConnectionAborted, e.to_string()) - } - e => io::Error::new(io::ErrorKind::Other, e.to_string()), - } - } -} - -impl PartialEq for Error { - fn eq(&self, other: &util::Error) -> bool { - if let Some(down) = other.downcast_ref::() { - return self == down; - } - false - } -} - -impl PartialEq for util::Error { - fn eq(&self, other: &Error) -> bool { - if let Some(down) = 
self.downcast_ref::() { - return other == down; - } - false - } -} diff --git a/data/src/lib.rs b/data/src/lib.rs deleted file mode 100644 index 7dff238a4..000000000 --- a/data/src/lib.rs +++ /dev/null @@ -1,8 +0,0 @@ -#![warn(rust_2018_idioms)] -#![allow(dead_code)] - -pub mod data_channel; -mod error; -pub mod message; - -pub use error::Error; diff --git a/data/src/message/message_channel_ack.rs b/data/src/message/message_channel_ack.rs deleted file mode 100644 index dbe4796ca..000000000 --- a/data/src/message/message_channel_ack.rs +++ /dev/null @@ -1,77 +0,0 @@ -use super::*; - -type Result = std::result::Result; - -/// The data-part of an data-channel ACK message without the message type. -/// -/// # Memory layout -/// -/// ```plain -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Message Type | -///+-+-+-+-+-+-+-+-+ -/// ``` -#[derive(Eq, PartialEq, Clone, Debug)] -pub struct DataChannelAck; - -impl MarshalSize for DataChannelAck { - fn marshal_size(&self) -> usize { - 0 - } -} - -impl Marshal for DataChannelAck { - fn marshal_to(&self, _buf: &mut [u8]) -> Result { - Ok(0) - } -} - -impl Unmarshal for DataChannelAck { - fn unmarshal(_buf: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - Ok(Self) - } -} - -#[cfg(test)] -mod tests { - use bytes::{Bytes, BytesMut}; - - use super::*; - - #[test] - fn test_channel_ack_unmarshal() -> Result<()> { - let mut bytes = Bytes::from_static(&[]); - - let channel_ack = DataChannelAck::unmarshal(&mut bytes)?; - - assert_eq!(channel_ack, DataChannelAck); - Ok(()) - } - - #[test] - fn test_channel_ack_marshal_size() -> Result<()> { - let channel_ack = DataChannelAck; - let marshal_size = channel_ack.marshal_size(); - - assert_eq!(marshal_size, 0); - Ok(()) - } - - #[test] - fn test_channel_ack_marshal() -> Result<()> { - let channel_ack = DataChannelAck; - let mut buf = BytesMut::with_capacity(0); - let bytes_written = channel_ack.marshal_to(&mut buf)?; - let bytes = buf.freeze(); - - assert_eq!(bytes_written, channel_ack.marshal_size()); - assert_eq!(&bytes[..], &[]); - Ok(()) - } -} diff --git a/data/src/message/message_channel_open.rs b/data/src/message/message_channel_open.rs deleted file mode 100644 index a93a017d3..000000000 --- a/data/src/message/message_channel_open.rs +++ /dev/null @@ -1,441 +0,0 @@ -use super::*; -use crate::error::Error; - -type Result = std::result::Result; - -const CHANNEL_TYPE_RELIABLE: u8 = 0x00; -const CHANNEL_TYPE_RELIABLE_UNORDERED: u8 = 0x80; -const CHANNEL_TYPE_PARTIAL_RELIABLE_REXMIT: u8 = 0x01; -const CHANNEL_TYPE_PARTIAL_RELIABLE_REXMIT_UNORDERED: u8 = 0x81; -const CHANNEL_TYPE_PARTIAL_RELIABLE_TIMED: u8 = 0x02; -const CHANNEL_TYPE_PARTIAL_RELIABLE_TIMED_UNORDERED: u8 = 0x82; -const CHANNEL_TYPE_LEN: usize = 1; - -/// ChannelPriority -pub const CHANNEL_PRIORITY_BELOW_NORMAL: u16 = 128; -pub const CHANNEL_PRIORITY_NORMAL: u16 = 256; -pub const CHANNEL_PRIORITY_HIGH: u16 = 512; -pub const CHANNEL_PRIORITY_EXTRA_HIGH: u16 = 1024; - -#[derive(Eq, PartialEq, Copy, Clone, Debug)] -pub enum ChannelType { - // `Reliable` determines the Data Channel provides a - // reliable in-order bi-directional communication. - Reliable, - // `ReliableUnordered` determines the Data Channel - // provides a reliable unordered bi-directional communication. - ReliableUnordered, - // `PartialReliableRexmit` determines the Data Channel - // provides a partially-reliable in-order bi-directional communication. 
- // User messages will not be retransmitted more times than specified in the Reliability Parameter. - PartialReliableRexmit, - // `PartialReliableRexmitUnordered` determines - // the Data Channel provides a partial reliable unordered bi-directional communication. - // User messages will not be retransmitted more times than specified in the Reliability Parameter. - PartialReliableRexmitUnordered, - // `PartialReliableTimed` determines the Data Channel - // provides a partial reliable in-order bi-directional communication. - // User messages might not be transmitted or retransmitted after - // a specified life-time given in milli- seconds in the Reliability Parameter. - // This life-time starts when providing the user message to the protocol stack. - PartialReliableTimed, - // The Data Channel provides a partial reliable unordered bi-directional - // communication. User messages might not be transmitted or retransmitted - // after a specified life-time given in milli- seconds in the Reliability Parameter. - // This life-time starts when providing the user message to the protocol stack. - PartialReliableTimedUnordered, -} - -impl Default for ChannelType { - fn default() -> Self { - Self::Reliable - } -} - -impl MarshalSize for ChannelType { - fn marshal_size(&self) -> usize { - CHANNEL_TYPE_LEN - } -} - -impl Marshal for ChannelType { - fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - let required_len = self.marshal_size(); - if buf.remaining_mut() < required_len { - return Err(Error::UnexpectedEndOfBuffer { - expected: required_len, - actual: buf.remaining_mut(), - } - .into()); - } - - let byte = match self { - Self::Reliable => CHANNEL_TYPE_RELIABLE, - Self::ReliableUnordered => CHANNEL_TYPE_RELIABLE_UNORDERED, - Self::PartialReliableRexmit => CHANNEL_TYPE_PARTIAL_RELIABLE_REXMIT, - Self::PartialReliableRexmitUnordered => CHANNEL_TYPE_PARTIAL_RELIABLE_REXMIT_UNORDERED, - Self::PartialReliableTimed => CHANNEL_TYPE_PARTIAL_RELIABLE_TIMED, - Self::PartialReliableTimedUnordered => CHANNEL_TYPE_PARTIAL_RELIABLE_TIMED_UNORDERED, - }; - - buf.put_u8(byte); - - Ok(1) - } -} - -impl Unmarshal for ChannelType { - fn unmarshal(buf: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - let required_len = CHANNEL_TYPE_LEN; - if buf.remaining() < required_len { - return Err(Error::UnexpectedEndOfBuffer { - expected: required_len, - actual: buf.remaining(), - } - .into()); - } - - let b0 = buf.get_u8(); - - match b0 { - CHANNEL_TYPE_RELIABLE => Ok(Self::Reliable), - CHANNEL_TYPE_RELIABLE_UNORDERED => Ok(Self::ReliableUnordered), - CHANNEL_TYPE_PARTIAL_RELIABLE_REXMIT => Ok(Self::PartialReliableRexmit), - CHANNEL_TYPE_PARTIAL_RELIABLE_REXMIT_UNORDERED => { - Ok(Self::PartialReliableRexmitUnordered) - } - CHANNEL_TYPE_PARTIAL_RELIABLE_TIMED => Ok(Self::PartialReliableTimed), - CHANNEL_TYPE_PARTIAL_RELIABLE_TIMED_UNORDERED => { - Ok(Self::PartialReliableTimedUnordered) - } - _ => Err(Error::InvalidChannelType(b0).into()), - } - } -} - -const CHANNEL_OPEN_HEADER_LEN: usize = 11; - -/// The data-part of an data-channel OPEN message without the message type. 
-/// -/// # Memory layout -/// -/// ```plain -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | (Message Type)| Channel Type | Priority | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | Reliability Parameter | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | Label Length | Protocol Length | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | | -/// | Label | -/// | | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | | -/// | Protocol | -/// | | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// ``` -#[derive(Eq, PartialEq, Clone, Debug)] -pub struct DataChannelOpen { - pub channel_type: ChannelType, - pub priority: u16, - pub reliability_parameter: u32, - pub label: Vec, - pub protocol: Vec, -} - -impl MarshalSize for DataChannelOpen { - fn marshal_size(&self) -> usize { - let label_len = self.label.len(); - let protocol_len = self.protocol.len(); - - CHANNEL_OPEN_HEADER_LEN + label_len + protocol_len - } -} - -impl Marshal for DataChannelOpen { - fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - let required_len = self.marshal_size(); - if buf.remaining_mut() < required_len { - return Err(Error::UnexpectedEndOfBuffer { - expected: required_len, - actual: buf.remaining_mut(), - } - .into()); - } - - let n = self.channel_type.marshal_to(buf)?; - buf = &mut buf[n..]; - buf.put_u16(self.priority); - buf.put_u32(self.reliability_parameter); - buf.put_u16(self.label.len() as u16); - buf.put_u16(self.protocol.len() as u16); - buf.put_slice(self.label.as_slice()); - buf.put_slice(self.protocol.as_slice()); - Ok(self.marshal_size()) - } -} - -impl Unmarshal for DataChannelOpen { - fn unmarshal(buf: &mut B) -> Result - where - B: Buf, - { - let required_len = CHANNEL_OPEN_HEADER_LEN; - if buf.remaining() < required_len { - return Err(Error::UnexpectedEndOfBuffer { - expected: required_len, - actual: buf.remaining(), - } - .into()); - } - - let channel_type = ChannelType::unmarshal(buf)?; - let priority = buf.get_u16(); - let reliability_parameter = buf.get_u32(); - let label_len = buf.get_u16() as usize; - let protocol_len = buf.get_u16() as usize; - - let required_len = label_len + protocol_len; - if buf.remaining() < required_len { - return Err(Error::UnexpectedEndOfBuffer { - expected: required_len, - actual: buf.remaining(), - } - .into()); - } - - let mut label = vec![0; label_len]; - let mut protocol = vec![0; protocol_len]; - - buf.copy_to_slice(&mut label[..]); - buf.copy_to_slice(&mut protocol[..]); - - Ok(Self { - channel_type, - priority, - reliability_parameter, - label, - protocol, - }) - } -} - -#[cfg(test)] -mod tests { - use bytes::{Bytes, BytesMut}; - - use super::*; - - #[test] - fn test_channel_type_unmarshal_success() -> Result<()> { - let mut bytes = Bytes::from_static(&[0x00]); - let channel_type = ChannelType::unmarshal(&mut bytes)?; - - assert_eq!(channel_type, ChannelType::Reliable); - Ok(()) - } - - #[test] - fn test_channel_type_unmarshal_invalid() -> Result<()> { - let mut bytes = Bytes::from_static(&[0x11]); - match ChannelType::unmarshal(&mut bytes) { - Ok(_) => panic!("expected Error, but got Ok"), - Err(err) => { - if let Some(&Error::InvalidChannelType(0x11)) = err.downcast_ref::() { - return Ok(()); - } - panic!( - "unexpected err {:?}, want {:?}", - err, - Error::InvalidMessageType(0x01) - ); - } - } - } - - #[test] 
- fn test_channel_type_unmarshal_unexpected_end_of_buffer() -> Result<()> { - let mut bytes = Bytes::from_static(&[]); - match ChannelType::unmarshal(&mut bytes) { - Ok(_) => panic!("expected Error, but got Ok"), - Err(err) => { - if let Some(&Error::UnexpectedEndOfBuffer { - expected: 1, - actual: 0, - }) = err.downcast_ref::() - { - return Ok(()); - } - panic!( - "unexpected err {:?}, want {:?}", - err, - Error::InvalidMessageType(0x01) - ); - } - } - } - - #[test] - fn test_channel_type_marshal_size() -> Result<()> { - let channel_type = ChannelType::Reliable; - let marshal_size = channel_type.marshal_size(); - - assert_eq!(marshal_size, 1); - Ok(()) - } - - #[test] - fn test_channel_type_marshal() -> Result<()> { - let mut buf = BytesMut::with_capacity(1); - buf.resize(1, 0u8); - let channel_type = ChannelType::Reliable; - let bytes_written = channel_type.marshal_to(&mut buf)?; - assert_eq!(bytes_written, channel_type.marshal_size()); - - let bytes = buf.freeze(); - assert_eq!(&bytes[..], &[0x00]); - Ok(()) - } - - static MARSHALED_BYTES: [u8; 24] = [ - 0x00, // channel type - 0x0f, 0x35, // priority - 0x00, 0xff, 0x0f, 0x35, // reliability parameter - 0x00, 0x05, // label length - 0x00, 0x08, // protocol length - 0x6c, 0x61, 0x62, 0x65, 0x6c, // label - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, // protocol - ]; - - #[test] - fn test_channel_open_unmarshal_success() -> Result<()> { - let mut bytes = Bytes::from_static(&MARSHALED_BYTES); - - let channel_open = DataChannelOpen::unmarshal(&mut bytes)?; - - assert_eq!(channel_open.channel_type, ChannelType::Reliable); - assert_eq!(channel_open.priority, 3893); - assert_eq!(channel_open.reliability_parameter, 16715573); - assert_eq!(channel_open.label, b"label"); - assert_eq!(channel_open.protocol, b"protocol"); - Ok(()) - } - - #[test] - fn test_channel_open_unmarshal_invalid_channel_type() -> Result<()> { - let mut bytes = Bytes::from_static(&[ - 0x11, // channel type - 0x0f, 0x35, // priority - 0x00, 0xff, 0x0f, 0x35, // reliability parameter - 0x00, 0x05, // label length - 0x00, 0x08, // protocol length - ]); - match DataChannelOpen::unmarshal(&mut bytes) { - Ok(_) => panic!("expected Error, but got Ok"), - Err(err) => { - if let Some(&Error::InvalidChannelType(0x11)) = err.downcast_ref::() { - return Ok(()); - } - panic!( - "unexpected err {:?}, want {:?}", - err, - Error::InvalidMessageType(0x01) - ); - } - } - } - - #[test] - fn test_channel_open_unmarshal_unexpected_end_of_buffer() -> Result<()> { - let mut bytes = Bytes::from_static(&[0x00; 5]); - match DataChannelOpen::unmarshal(&mut bytes) { - Ok(_) => panic!("expected Error, but got Ok"), - Err(err) => { - if let Some(&Error::UnexpectedEndOfBuffer { - expected: 11, - actual: 5, - }) = err.downcast_ref::() - { - return Ok(()); - } - panic!( - "unexpected err {:?}, want {:?}", - err, - Error::InvalidMessageType(0x01) - ); - } - } - } - - #[test] - fn test_channel_open_unmarshal_unexpected_length_mismatch() -> Result<()> { - let mut bytes = Bytes::from_static(&[ - 0x01, // channel type - 0x00, 0x00, // priority - 0x00, 0x00, 0x00, 0x00, // Reliability parameter - 0x00, 0x05, // Label length - 0x00, 0x08, // Protocol length - ]); - match DataChannelOpen::unmarshal(&mut bytes) { - Ok(_) => panic!("expected Error, but got Ok"), - Err(err) => { - if let Some(&Error::UnexpectedEndOfBuffer { - expected: 13, - actual: 0, - }) = err.downcast_ref::() - { - return Ok(()); - } - panic!( - "unexpected err {:?}, want {:?}", - err, - Error::InvalidMessageType(0x01) - ); - } - } - } - - 
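// Illustrative sketch (not part of the crate being removed): the wire layout these
// tests exercise can be reproduced with plain std Rust. The function below hand-encodes
// the DATA_CHANNEL_OPEN payload exactly as the memory-layout diagram above describes
// (1-byte channel type, big-endian 2-byte priority, 4-byte reliability parameter,
// 2-byte label and protocol lengths, then the label and protocol bytes), and the test
// checks it against the same values as the 24-byte MARSHALED_BYTES vector. The name
// `encode_data_channel_open` is assumed for illustration only; it is not an API of webrtc-data.
fn encode_data_channel_open(
    channel_type: u8,
    priority: u16,
    reliability_parameter: u32,
    label: &[u8],
    protocol: &[u8],
) -> Vec<u8> {
    let mut out = Vec::with_capacity(11 + label.len() + protocol.len());
    out.push(channel_type); // 1-byte channel type (0x00 = reliable)
    out.extend_from_slice(&priority.to_be_bytes()); // 2-byte priority
    out.extend_from_slice(&reliability_parameter.to_be_bytes()); // 4-byte reliability parameter
    out.extend_from_slice(&(label.len() as u16).to_be_bytes()); // 2-byte label length
    out.extend_from_slice(&(protocol.len() as u16).to_be_bytes()); // 2-byte protocol length
    out.extend_from_slice(label);
    out.extend_from_slice(protocol);
    out
}

#[test]
fn sketch_matches_channel_open_layout() {
    // Same field values as the unmarshal/marshal tests around this sketch.
    let bytes = encode_data_channel_open(0x00, 3893, 16_715_573, b"label", b"protocol");
    assert_eq!(bytes.len(), 11 + 5 + 8);
    assert_eq!(
        &bytes[..11],
        &[0x00, 0x0f, 0x35, 0x00, 0xff, 0x0f, 0x35, 0x00, 0x05, 0x00, 0x08]
    );
    assert_eq!(&bytes[11..16], b"label");
    assert_eq!(&bytes[16..], b"protocol");
}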
#[test] - fn test_channel_open_marshal_size() -> Result<()> { - let channel_open = DataChannelOpen { - channel_type: ChannelType::Reliable, - priority: 3893, - reliability_parameter: 16715573, - label: b"label".to_vec(), - protocol: b"protocol".to_vec(), - }; - - let marshal_size = channel_open.marshal_size(); - - assert_eq!(marshal_size, 11 + 5 + 8); - Ok(()) - } - - #[test] - fn test_channel_open_marshal() -> Result<()> { - let channel_open = DataChannelOpen { - channel_type: ChannelType::Reliable, - priority: 3893, - reliability_parameter: 16715573, - label: b"label".to_vec(), - protocol: b"protocol".to_vec(), - }; - - let mut buf = BytesMut::with_capacity(11 + 5 + 8); - buf.resize(11 + 5 + 8, 0u8); - let bytes_written = channel_open.marshal_to(&mut buf).unwrap(); - let bytes = buf.freeze(); - - assert_eq!(bytes_written, channel_open.marshal_size()); - assert_eq!(&bytes[..], &MARSHALED_BYTES); - Ok(()) - } -} diff --git a/data/src/message/message_test.rs b/data/src/message/message_test.rs deleted file mode 100644 index 932c501bf..000000000 --- a/data/src/message/message_test.rs +++ /dev/null @@ -1,96 +0,0 @@ -use bytes::{Bytes, BytesMut}; - -use super::*; -use crate::error::Result; - -#[test] -fn test_message_unmarshal_open_success() { - let mut bytes = Bytes::from_static(&[ - 0x03, // message type - 0x00, // channel type - 0x0f, 0x35, // priority - 0x00, 0xff, 0x0f, 0x35, // reliability parameter - 0x00, 0x05, // label length - 0x00, 0x08, // protocol length - 0x6c, 0x61, 0x62, 0x65, 0x6c, // label - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, // protocol - ]); - - let actual = Message::unmarshal(&mut bytes).unwrap(); - - let expected = Message::DataChannelOpen(DataChannelOpen { - channel_type: ChannelType::Reliable, - priority: 3893, - reliability_parameter: 16715573, - label: b"label".to_vec(), - protocol: b"protocol".to_vec(), - }); - - assert_eq!(actual, expected); -} - -#[test] -fn test_message_unmarshal_ack_success() -> Result<()> { - let mut bytes = Bytes::from_static(&[0x02]); - - let actual = Message::unmarshal(&mut bytes)?; - let expected = Message::DataChannelAck(DataChannelAck {}); - - assert_eq!(actual, expected); - - Ok(()) -} - -#[test] -fn test_message_unmarshal_invalid_message_type() { - let mut bytes = Bytes::from_static(&[0x01]); - let expected = Error::InvalidMessageType(0x01); - let result = Message::unmarshal(&mut bytes); - let actual = result.expect_err("expected err, but got ok"); - assert_eq!(actual, expected); -} - -#[test] -fn test_message_marshal_size() { - let msg = Message::DataChannelAck(DataChannelAck {}); - - let actual = msg.marshal_size(); - let expected = 1; - - assert_eq!(actual, expected); -} - -#[test] -fn test_message_marshal() { - let marshal_size = 12 + 5 + 8; - let mut buf = BytesMut::with_capacity(marshal_size); - buf.resize(marshal_size, 0u8); - - let msg = Message::DataChannelOpen(DataChannelOpen { - channel_type: ChannelType::Reliable, - priority: 3893, - reliability_parameter: 16715573, - label: b"label".to_vec(), - protocol: b"protocol".to_vec(), - }); - - let actual = msg.marshal_to(&mut buf).unwrap(); - let expected = marshal_size; - assert_eq!(actual, expected); - - let bytes = buf.freeze(); - - let actual = &bytes[..]; - let expected = &[ - 0x03, // message type - 0x00, // channel type - 0x0f, 0x35, // priority - 0x00, 0xff, 0x0f, 0x35, // reliability parameter - 0x00, 0x05, // label length - 0x00, 0x08, // protocol length - 0x6c, 0x61, 0x62, 0x65, 0x6c, // label - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, // protocol - ]; 
- - assert_eq!(actual, expected); -} diff --git a/data/src/message/message_type.rs b/data/src/message/message_type.rs deleted file mode 100644 index 0a62af6f6..000000000 --- a/data/src/message/message_type.rs +++ /dev/null @@ -1,125 +0,0 @@ -use super::*; -use crate::error::Error; - -// The first byte in a `Message` that specifies its type: -pub(crate) const MESSAGE_TYPE_ACK: u8 = 0x02; -pub(crate) const MESSAGE_TYPE_OPEN: u8 = 0x03; -pub(crate) const MESSAGE_TYPE_LEN: usize = 1; - -type Result = std::result::Result; - -// A parsed DataChannel message -#[derive(Eq, PartialEq, Copy, Clone, Debug)] -pub enum MessageType { - DataChannelAck, - DataChannelOpen, -} - -impl MarshalSize for MessageType { - fn marshal_size(&self) -> usize { - MESSAGE_TYPE_LEN - } -} - -impl Marshal for MessageType { - fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - let b = match self { - MessageType::DataChannelAck => MESSAGE_TYPE_ACK, - MessageType::DataChannelOpen => MESSAGE_TYPE_OPEN, - }; - - buf.put_u8(b); - - Ok(1) - } -} - -impl Unmarshal for MessageType { - fn unmarshal(buf: &mut B) -> Result - where - B: Buf, - { - let required_len = MESSAGE_TYPE_LEN; - if buf.remaining() < required_len { - return Err(Error::UnexpectedEndOfBuffer { - expected: required_len, - actual: buf.remaining(), - } - .into()); - } - - let b = buf.get_u8(); - - match b { - MESSAGE_TYPE_ACK => Ok(Self::DataChannelAck), - MESSAGE_TYPE_OPEN => Ok(Self::DataChannelOpen), - _ => Err(Error::InvalidMessageType(b).into()), - } - } -} - -#[cfg(test)] -mod tests { - use bytes::{Bytes, BytesMut}; - - use super::*; - - #[test] - fn test_message_type_unmarshal_open_success() -> Result<()> { - let mut bytes = Bytes::from_static(&[0x03]); - let msg_type = MessageType::unmarshal(&mut bytes)?; - - assert_eq!(msg_type, MessageType::DataChannelOpen); - - Ok(()) - } - - #[test] - fn test_message_type_unmarshal_ack_success() -> Result<()> { - let mut bytes = Bytes::from_static(&[0x02]); - let msg_type = MessageType::unmarshal(&mut bytes)?; - - assert_eq!(msg_type, MessageType::DataChannelAck); - Ok(()) - } - - #[test] - fn test_message_type_unmarshal_invalid() -> Result<()> { - let mut bytes = Bytes::from_static(&[0x01]); - match MessageType::unmarshal(&mut bytes) { - Ok(_) => panic!("expected Error, but got Ok"), - Err(err) => { - if let Some(&Error::InvalidMessageType(0x01)) = err.downcast_ref::() { - return Ok(()); - } - panic!( - "unexpected err {:?}, want {:?}", - err, - Error::InvalidMessageType(0x01) - ); - } - } - } - - #[test] - fn test_message_type_marshal_size() -> Result<()> { - let ack = MessageType::DataChannelAck; - let marshal_size = ack.marshal_size(); - - assert_eq!(marshal_size, MESSAGE_TYPE_LEN); - Ok(()) - } - - #[test] - fn test_message_type_marshal() -> Result<()> { - let mut buf = BytesMut::with_capacity(MESSAGE_TYPE_LEN); - buf.resize(MESSAGE_TYPE_LEN, 0u8); - let msg_type = MessageType::DataChannelAck; - let n = msg_type.marshal_to(&mut buf)?; - let bytes = buf.freeze(); - - assert_eq!(n, MESSAGE_TYPE_LEN); - assert_eq!(&bytes[..], &[0x02]); - Ok(()) - } -} diff --git a/data/src/message/mod.rs b/data/src/message/mod.rs deleted file mode 100644 index d9a450fe0..000000000 --- a/data/src/message/mod.rs +++ /dev/null @@ -1,76 +0,0 @@ -#[cfg(test)] -mod message_test; - -pub mod message_channel_ack; -pub mod message_channel_open; -pub mod message_type; - -use bytes::{Buf, BufMut}; -use message_channel_ack::*; -use message_channel_open::*; -use message_type::*; -use util::marshal::*; - -use crate::error::Error; - -/// A parsed 
DataChannel message -#[derive(Eq, PartialEq, Clone, Debug)] -pub enum Message { - DataChannelAck(DataChannelAck), - DataChannelOpen(DataChannelOpen), -} - -impl MarshalSize for Message { - fn marshal_size(&self) -> usize { - match self { - Message::DataChannelAck(m) => m.marshal_size() + MESSAGE_TYPE_LEN, - Message::DataChannelOpen(m) => m.marshal_size() + MESSAGE_TYPE_LEN, - } - } -} - -impl Marshal for Message { - fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - let mut bytes_written = 0; - let n = self.message_type().marshal_to(buf)?; - buf = &mut buf[n..]; - bytes_written += n; - bytes_written += match self { - Message::DataChannelAck(_) => 0, - Message::DataChannelOpen(open) => open.marshal_to(buf)?, - }; - Ok(bytes_written) - } -} - -impl Unmarshal for Message { - fn unmarshal(buf: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - if buf.remaining() < MESSAGE_TYPE_LEN { - return Err(Error::UnexpectedEndOfBuffer { - expected: MESSAGE_TYPE_LEN, - actual: buf.remaining(), - } - .into()); - } - - match MessageType::unmarshal(buf)? { - MessageType::DataChannelAck => Ok(Self::DataChannelAck(DataChannelAck {})), - MessageType::DataChannelOpen => { - Ok(Self::DataChannelOpen(DataChannelOpen::unmarshal(buf)?)) - } - } - } -} - -impl Message { - pub fn message_type(&self) -> MessageType { - match self { - Self::DataChannelAck(_) => MessageType::DataChannelAck, - Self::DataChannelOpen(_) => MessageType::DataChannelOpen, - } - } -} diff --git a/doc/AVStack.jpg b/doc/AVStack.jpg deleted file mode 100644 index 2c1d5d5c4..000000000 Binary files a/doc/AVStack.jpg and /dev/null differ diff --git a/doc/ChannelTalk_logo.png b/doc/ChannelTalk_logo.png deleted file mode 100644 index d9cd28428..000000000 Binary files a/doc/ChannelTalk_logo.png and /dev/null differ diff --git a/doc/KittyCAD.png b/doc/KittyCAD.png deleted file mode 100644 index dc46348a7..000000000 Binary files a/doc/KittyCAD.png and /dev/null differ diff --git a/doc/check.png b/doc/check.png deleted file mode 100644 index 17bc28eb4..000000000 Binary files a/doc/check.png and /dev/null differ diff --git a/doc/embark.jpg b/doc/embark.jpg deleted file mode 100644 index 3ae23a6be..000000000 Binary files a/doc/embark.jpg and /dev/null differ diff --git a/doc/parity.png b/doc/parity.png deleted file mode 100644 index 660a00437..000000000 Binary files a/doc/parity.png and /dev/null differ diff --git a/doc/uncheck.png b/doc/uncheck.png deleted file mode 100644 index d71f0f7e7..000000000 Binary files a/doc/uncheck.png and /dev/null differ diff --git a/doc/webrtc.rs.png b/doc/webrtc.rs.png deleted file mode 100644 index 3041e749c..000000000 Binary files a/doc/webrtc.rs.png and /dev/null differ diff --git a/doc/webrtc.rs.xcf b/doc/webrtc.rs.xcf deleted file mode 100644 index 8cc72dfe5..000000000 Binary files a/doc/webrtc.rs.xcf and /dev/null differ diff --git a/doc/webrtc_crab.png b/doc/webrtc_crab.png deleted file mode 100644 index 42f785d73..000000000 Binary files a/doc/webrtc_crab.png and /dev/null differ diff --git a/doc/webrtc_crates_dep_graph.odg b/doc/webrtc_crates_dep_graph.odg deleted file mode 100644 index c489207a0..000000000 Binary files a/doc/webrtc_crates_dep_graph.odg and /dev/null differ diff --git a/doc/webrtc_crates_dep_graph.png b/doc/webrtc_crates_dep_graph.png deleted file mode 100644 index c041bb705..000000000 Binary files a/doc/webrtc_crates_dep_graph.png and /dev/null differ diff --git a/doc/webrtc_stack.png b/doc/webrtc_stack.png deleted file mode 100644 index 9a9bc84df..000000000 Binary files 
a/doc/webrtc_stack.png and /dev/null differ diff --git a/dtls/.gitignore b/dtls/.gitignore deleted file mode 100644 index 9db22ce71..000000000 --- a/dtls/.gitignore +++ /dev/null @@ -1,12 +0,0 @@ -# Generated by Cargo -# will have compiled files and executables -/target/ -/examples/hub/target -/.idea/ - -# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries -# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html -Cargo.lock - -# These are backup files generated by rustfmt -**/*.rs.bk diff --git a/dtls/CHANGELOG.md b/dtls/CHANGELOG.md deleted file mode 100644 index 071897700..000000000 --- a/dtls/CHANGELOG.md +++ /dev/null @@ -1,26 +0,0 @@ -# webrtc-dtls changelog - -## Unreleased - -## v0.7.1 - -* Added support for insecure/deprecated signature verification algorithms [#342](https://github.com/webrtc-rs/webrtc/pull/342) by [@chuigda](https://github.com/chuigda). - -## v0.7.0 - -* Increased minimum support rust version to `1.60.0`. -* Add `RTCCertificate::from_pem` and `RTCCertificate::serialize_pem` (only work with `pem` feature enabled) [#333](https://github.com/webrtc-rs/webrtc/pull/333) - -### Breaking - -* Increased required `webrtc-util` version to `0.7.0`, with this change some methods in `DTLSConn` that implement `webrtc_util::Conn` have changed from async to sync. - -## v0.6.0 - -* [#254 [DTLS] Add NamedCurve::P384](https://github.com/webrtc-rs/webrtc/pull/254) contributed by [neonphog](https://github.com/neonphog) -* Increased min version of `log` dependency to `0.4.16`. [#250 Fix log at ^0.4.16 to make tests compile](https://github.com/webrtc-rs/webrtc/pull/250) by [@k0nserv](https://github.com/k0nserv). -* Increased serde's minimum version to 1.0.110 [#243 Fixes for cargo minimal-versions](https://github.com/webrtc-rs/webrtc/pull/243) contributed by [algesten](https://github.com/algesten) - -## Prior to 0.6.0 - -Before 0.6.0 there was no changelog, previous changes are sometimes, but not always, available in the [GitHub Releases](https://github.com/webrtc-rs/dtls/releases). 
diff --git a/dtls/Cargo.toml b/dtls/Cargo.toml deleted file mode 100644 index 7f31126d7..000000000 --- a/dtls/Cargo.toml +++ /dev/null @@ -1,94 +0,0 @@ -[package] -name = "webrtc-dtls" -version = "0.10.0" -authors = ["Rain Liu "] -edition = "2021" -description = "A pure Rust implementation of DTLS" -license = "MIT OR Apache-2.0" -documentation = "https://docs.rs/webrtc-dtls" -homepage = "https://webrtc.rs" -repository = "https://github.com/webrtc-rs/webrtc/tree/master/dtls" - -[dependencies] -util = { version = "0.9.0", path = "../util", package = "webrtc-util", default-features = false, features = ["conn"] } - -byteorder = "1" -rand_core = "0.6" -hkdf = "0.12" -p256 = { version = "0.13", features = ["default", "ecdh", "ecdsa"] } -p384 = "0.13" -rand = "0.8" -hmac = "0.12" -sec1 = { version = "0.7", features = [ "std" ] } -sha1 = "0.10" -sha2 = "0.10" -aes = "0.8" -cbc = { version = "0.1", features = [ "block-padding", "alloc"] } -aes-gcm = "0.10" -ccm = "0.5" -tokio = { version = "1.32.0", features = [ - "fs", - "io-util", - "io-std", - "macros", - "net", - "parking_lot", - "rt", - "rt-multi-thread", - "sync", - "time", -] } -async-trait = "0.1" -x25519-dalek = { version = "2", features = ["static_secrets"] } -x509-parser = "0.16" -der-parser = "9.0" -rcgen = "0.13" -ring = "0.17" -rustls = { version = "0.23", default-features = false, features = ["std", "ring"] } -bincode = "1" -serde = { version = "1", features = ["derive"] } -subtle = "2" -log = "0.4" -thiserror = "1" -pem = { version = "3", optional = true } -portable-atomic = "1.6" - -[dev-dependencies] -tokio-test = "0.4" -env_logger = "0.10" -chrono = "0.4.28" -clap = "3" -hub = {path = "examples/hub"} - -[features] -pem = ["dep:pem"] - -[[example]] -name = "dial_psk" -path = "examples/dial/psk/dial_psk.rs" -bench = false - -[[example]] -name = "dial_selfsign" -path = "examples/dial/selfsign/dial_selfsign.rs" -bench = false - -[[example]] -name = "dial_verify" -path = "examples/dial/verify/dial_verify.rs" -bench = false - -[[example]] -name = "listen_psk" -path = "examples/listen/psk/listen_psk.rs" -bench = false - -[[example]] -name = "listen_selfsign" -path = "examples/listen/selfsign/listen_selfsign.rs" -bench = false - -[[example]] -name = "listen_verify" -path = "examples/listen/verify/listen_verify.rs" -bench = false diff --git a/dtls/LICENSE-APACHE b/dtls/LICENSE-APACHE deleted file mode 100644 index 16fe87b06..000000000 --- a/dtls/LICENSE-APACHE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - -1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. 
- - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - -2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - -3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. 
If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - -4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - -5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - -6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. 
Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - -8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - -END OF TERMS AND CONDITIONS - -APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - -Copyright [yyyy] [name of copyright owner] - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
diff --git a/dtls/LICENSE-MIT b/dtls/LICENSE-MIT deleted file mode 100644 index e11d93bef..000000000 --- a/dtls/LICENSE-MIT +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2021 WebRTC.rs - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/dtls/README.md b/dtls/README.md deleted file mode 100644 index 1ef5cb95c..000000000 --- a/dtls/README.md +++ /dev/null @@ -1,30 +0,0 @@ -

- WebRTC.rs
-
- [badges: License: MIT/Apache 2.0, Discord]
-
- A pure Rust implementation of DTLS. Rewrite Pion DTLS in Rust

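Editor's note: to complement the client sketch above and make the deleted README's one-line description more concrete, here is a minimal server-side sketch adapted from the `listen_selfsign` example removed later in this diff. The listen address is a placeholder, and only a single accepted connection is handled for brevity.

```rust
use std::sync::Arc;

use util::conn::*;
use webrtc_dtls::config::{Config, ExtendedMasterSecretType};
use webrtc_dtls::crypto::Certificate;
use webrtc_dtls::listener::listen;
use webrtc_dtls::Error;

#[tokio::main]
async fn main() -> Result<(), Error> {
    // Secure the listener with a self-signed certificate, as the selfsign examples do.
    let certificate = Certificate::generate_self_signed(vec!["localhost".to_owned()])?;
    let cfg = Config {
        certificates: vec![certificate],
        extended_master_secret: ExtendedMasterSecretType::Require,
        ..Default::default()
    };

    // Listen on a placeholder address and accept a single DTLS connection.
    let listener = Arc::new(listen("127.0.0.1:4444".to_owned(), cfg).await?);
    if let Ok((dtls_conn, remote_addr)) = listener.accept().await {
        println!("accepted DTLS connection from {remote_addr}");
        let _ = dtls_conn.close().await;
    }

    Ok(listener.close().await?)
}
```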
diff --git a/dtls/codecov.yml b/dtls/codecov.yml deleted file mode 100644 index a405022ef..000000000 --- a/dtls/codecov.yml +++ /dev/null @@ -1,23 +0,0 @@ -codecov: - require_ci_to_pass: yes - max_report_age: off - token: 14bc9fb7-bdcf-4355-8e0e-ebec14066ae5 - -coverage: - precision: 2 - round: down - range: 50..90 - status: - project: - default: - enabled: no - threshold: 0.2 - if_not_found: success - patch: - default: - enabled: no - if_not_found: success - changes: - default: - enabled: no - if_not_found: success diff --git a/dtls/doc/webrtc.rs.png b/dtls/doc/webrtc.rs.png deleted file mode 100644 index 7bf0dda2a..000000000 Binary files a/dtls/doc/webrtc.rs.png and /dev/null differ diff --git a/dtls/examples/certificates/README.md b/dtls/examples/certificates/README.md deleted file mode 100644 index cd498e0ef..000000000 --- a/dtls/examples/certificates/README.md +++ /dev/null @@ -1,60 +0,0 @@ -# Certificates - -The certificates in for the examples are generated using the commands shown below. - -Note that this was run on OpenSSL 1.1.1d, of which the arguments can be found in the [OpenSSL Manpages](https://www.openssl.org/docs/man1.1.1/man1), and is not guaranteed to work on different OpenSSL versions. - -```shell -# Extensions required for certificate validation. -$ EXTFILE='extfile.conf' -$ echo 'subjectAltName = DNS:webrtc.rs' > "${EXTFILE}" - -# Server. -$ SERVER_NAME='server' -$ openssl ecparam -name prime256v1 -genkey -noout -out "${SERVER_NAME}.pem" -$ openssl req -key "${SERVER_NAME}.pem" -new -sha256 -subj '/C=NL' -out "${SERVER_NAME}.csr" -$ openssl x509 -req -in "${SERVER_NAME}.csr" -extfile "${EXTFILE}" -days 365 -signkey "${SERVER_NAME}.pem" -sha256 -out "${SERVER_NAME}.pub.pem" - -# Client. -$ CLIENT_NAME='client' -$ openssl ecparam -name prime256v1 -genkey -noout -out "${CLIENT_NAME}.pem" -$ openssl req -key "${CLIENT_NAME}.pem" -new -sha256 -subj '/C=NL' -out "${CLIENT_NAME}.csr" -$ openssl x509 -req -in "${CLIENT_NAME}.csr" -extfile "${EXTFILE}" -days 365 -CA "${SERVER_NAME}.pub.pem" -CAkey "${SERVER_NAME}.pem" -set_serial '0xabcd' -sha256 -out "${CLIENT_NAME}.pub.pem" - -# Cleanup. 
-$ rm "${EXTFILE}" "${SERVER_NAME}.csr" "${CLIENT_NAME}.csr" -``` - -## Converting EC private key to PKCS#8 in Rust - -`Cargo.toml`: - -```toml -[dependencies] -topk8 = "0.0.1" -``` - -`main.rs`: - -```rust -fn main() { - let ec_pem = " ------BEGIN EC PRIVATE KEY----- -MHcCAQEEIAL4r6d9lPq3XEDSZTL9l0D6thrPM7RiAhl3Fjuw9Ji2oAoGCCqGSM49 -AwEHoUQDQgAE4U64dviQRMujGK0g80dwzgjV7fnwLkj6RfvINMHvD6eiCsphWIlq -cddTAoOjXVQDu3qMAS1Ghfyk1F377EW1Sw== ------END EC PRIVATE KEY----- -"; - - let pkcs8_pem = topk8::from_sec1_pem(ec_pem).unwrap(); - - println!("{}", pkcs8_pem); - - // -----BEGIN PRIVATE KEY----- - // MIGTAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBHkwdwIBAQQgAvivp32U+rdcQNJl - // Mv2XQPq2Gs8ztGICGXcWO7D0mLagCgYIKoZIzj0DAQehRANCAAThTrh2+JBEy6MY - // rSDzR3DOCNXt+fAuSPpF+8g0we8Pp6IKymFYiWpx11MCg6NdVAO7eowBLUaF/KTU - // XfvsRbVL - // -----END PRIVATE KEY----- -} -``` diff --git a/dtls/examples/certificates/client.csr b/dtls/examples/certificates/client.csr deleted file mode 100644 index 3a41d57b8..000000000 --- a/dtls/examples/certificates/client.csr +++ /dev/null @@ -1,7 +0,0 @@ ------BEGIN CERTIFICATE REQUEST----- -MIHHMG8CAQAwDTELMAkGA1UEBhMCTkwwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNC -AAQhwOK0F+5DQy5FG0dRdN0GF20p3MsTaBk73IpNzrlK+WtgxdVxmRm55LWCTgkA -RhnOcmzXW+raCEWQgTadaLd5oAAwCgYIKoZIzj0EAwIDSAAwRQIhANDZpyL2lr50 -Xr5DrD19SOa7LXpXz3DcM8RDLcBQvx05AiB7mbtcY6I18diHU0jSxHAGcUn5nAeD -EP4tqFOz7QRzgQ== ------END CERTIFICATE REQUEST----- diff --git a/dtls/examples/certificates/client.pem b/dtls/examples/certificates/client.pem deleted file mode 100644 index aa73a43a2..000000000 --- a/dtls/examples/certificates/client.pem +++ /dev/null @@ -1,5 +0,0 @@ ------BEGIN EC PRIVATE KEY----- -MHcCAQEEIO7fb5dmM2P0F71o/Clo0ElO29ud+JbtA3fhDIL15AgioAoGCCqGSM49 -AwEHoUQDQgAEIcDitBfuQ0MuRRtHUXTdBhdtKdzLE2gZO9yKTc65SvlrYMXVcZkZ -ueS1gk4JAEYZznJs11vq2ghFkIE2nWi3eQ== ------END EC PRIVATE KEY----- diff --git a/dtls/examples/certificates/client.pem.private_key.pem b/dtls/examples/certificates/client.pem.private_key.pem deleted file mode 100644 index fbcbd476f..000000000 --- a/dtls/examples/certificates/client.pem.private_key.pem +++ /dev/null @@ -1,5 +0,0 @@ ------BEGIN PRIVATE KEY----- -MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQg7t9vl2YzY/QXvWj8 -KWjQSU7b2534lu0Dd+EMgvXkCCKhRANCAAQhwOK0F+5DQy5FG0dRdN0GF20p3MsT -aBk73IpNzrlK+WtgxdVxmRm55LWCTgkARhnOcmzXW+raCEWQgTadaLd5 ------END PRIVATE KEY----- diff --git a/dtls/examples/certificates/client.pub.pem b/dtls/examples/certificates/client.pub.pem deleted file mode 100644 index 688c9f05a..000000000 --- a/dtls/examples/certificates/client.pub.pem +++ /dev/null @@ -1,9 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIBITCByaADAgECAgMAq80wCgYIKoZIzj0EAwIwDTELMAkGA1UEBhMCTkwwHhcN -MjEwOTE4MjAzNzE1WhcNMjIwOTE4MjAzNzE1WjANMQswCQYDVQQGEwJOTDBZMBMG -ByqGSM49AgEGCCqGSM49AwEHA0IABCHA4rQX7kNDLkUbR1F03QYXbSncyxNoGTvc -ik3OuUr5a2DF1XGZGbnktYJOCQBGGc5ybNdb6toIRZCBNp1ot3mjGDAWMBQGA1Ud -EQQNMAuCCXdlYnJ0Yy5yczAKBggqhkjOPQQDAgNHADBEAiA8mpJVfaCw+RwALmxN -XD28Ze3DUPomlfXhx+NGuePt5QIgAcRxvuDctyL07f8pQ5n22NOioNHdjwOjxww+ -ZekD+Lg= ------END CERTIFICATE----- diff --git a/dtls/examples/certificates/extfile.conf b/dtls/examples/certificates/extfile.conf deleted file mode 100644 index c4b7ccea7..000000000 --- a/dtls/examples/certificates/extfile.conf +++ /dev/null @@ -1 +0,0 @@ -subjectAltName = DNS:webrtc.rs diff --git a/dtls/examples/certificates/server.csr b/dtls/examples/certificates/server.csr deleted file mode 100644 index cf9a98384..000000000 --- a/dtls/examples/certificates/server.csr +++ /dev/null @@ -1,7 +0,0 @@ 
------BEGIN CERTIFICATE REQUEST----- -MIHHMG8CAQAwDTELMAkGA1UEBhMCTkwwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNC -AASRyJbgbcieCbC1/HbiqmADkPxfk5Bmwjei2YXhPE+oYS3F5+df4BKNBgs7py7H -sxc768+6X8HmvYlfvk2kHXAVoAAwCgYIKoZIzj0EAwIDSAAwRQIhAKR9rI22Xk/U -L3xp2dzn7q3nyWqgDvp5uTflP4t0MBpJAiAJDKmcOCXNMhhgg4T2lhdfz/pZVfu5 -lxLcZm2ELiYImQ== ------END CERTIFICATE REQUEST----- diff --git a/dtls/examples/certificates/server.pem b/dtls/examples/certificates/server.pem deleted file mode 100644 index c6dd91395..000000000 --- a/dtls/examples/certificates/server.pem +++ /dev/null @@ -1,5 +0,0 @@ ------BEGIN EC PRIVATE KEY----- -MHcCAQEEID358pSfZXZqwqURqBLvLYcqhOdZVVNR2toCMER39YHboAoGCCqGSM49 -AwEHoUQDQgAEkciW4G3Ingmwtfx24qpgA5D8X5OQZsI3otmF4TxPqGEtxefnX+AS -jQYLO6cux7MXO+vPul/B5r2JX75NpB1wFQ== ------END EC PRIVATE KEY----- diff --git a/dtls/examples/certificates/server.pem.private_key.pem b/dtls/examples/certificates/server.pem.private_key.pem deleted file mode 100644 index b5088e564..000000000 --- a/dtls/examples/certificates/server.pem.private_key.pem +++ /dev/null @@ -1,5 +0,0 @@ ------BEGIN PRIVATE KEY----- -MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgPfnylJ9ldmrCpRGo -Eu8thyqE51lVU1Ha2gIwRHf1gduhRANCAASRyJbgbcieCbC1/HbiqmADkPxfk5Bm -wjei2YXhPE+oYS3F5+df4BKNBgs7py7Hsxc768+6X8HmvYlfvk2kHXAV ------END PRIVATE KEY----- diff --git a/dtls/examples/certificates/server.pub.pem b/dtls/examples/certificates/server.pub.pem deleted file mode 100644 index bfdd542be..000000000 --- a/dtls/examples/certificates/server.pub.pem +++ /dev/null @@ -1,9 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIBMjCB2qADAgECAhQzsfVoH1cRfcxCY/drp/cKdjvBDDAKBggqhkjOPQQDAjAN -MQswCQYDVQQGEwJOTDAeFw0yMTA5MTgyMDM2NTVaFw0yMjA5MTgyMDM2NTVaMA0x -CzAJBgNVBAYTAk5MMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEkciW4G3Ingmw -tfx24qpgA5D8X5OQZsI3otmF4TxPqGEtxefnX+ASjQYLO6cux7MXO+vPul/B5r2J -X75NpB1wFaMYMBYwFAYDVR0RBA0wC4IJd2VicnRjLnJzMAoGCCqGSM49BAMCA0cA -MEQCIBZBGmNM3qig7OTMZLL4PYj4JrGMjIj/jZFHEhqeQn6HAiBpRte9WzCjJzZX -vzRkUKfCs1NMa/XR0hfdaa8KJAdKyQ== ------END CERTIFICATE----- diff --git a/dtls/examples/dial/psk/dial_psk.rs b/dtls/examples/dial/psk/dial_psk.rs deleted file mode 100644 index 9f73de915..000000000 --- a/dtls/examples/dial/psk/dial_psk.rs +++ /dev/null @@ -1,83 +0,0 @@ -use std::io::Write; -use std::sync::Arc; - -use clap::{App, AppSettings, Arg}; -use tokio::net::UdpSocket; -use util::Conn; -use webrtc_dtls::cipher_suite::CipherSuiteId; -use webrtc_dtls::config::*; -use webrtc_dtls::conn::DTLSConn; -use webrtc_dtls::Error; - -// cargo run --example dial_psk -- --server 127.0.0.1:4444 - -#[tokio::main] -async fn main() -> Result<(), Error> { - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init(); - - let mut app = App::new("DTLS Client") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of DTLS Client") - .setting(AppSettings::DeriveDisplayOrder) - .setting(AppSettings::SubcommandsNegateReqs) - .arg( - Arg::with_name("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::with_name("server") - .required_unless("FULLHELP") - .takes_value(true) - .default_value("127.0.0.1:4444") - .long("server") - .help("DTLS Server name."), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - 
std::process::exit(0); - } - - let server = matches.value_of("server").unwrap(); - - let conn = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); - conn.connect(server).await?; - println!("connecting {server}.."); - - let config = Config { - psk: Some(Arc::new(|hint: &[u8]| -> Result, Error> { - println!("Server's hint: {}", String::from_utf8(hint.to_vec())?); - Ok(vec![0xAB, 0xC1, 0x23]) - })), - psk_identity_hint: Some("webrtc-rs DTLS Server".as_bytes().to_vec()), - cipher_suites: vec![CipherSuiteId::Tls_Psk_With_Aes_128_Ccm_8], - extended_master_secret: ExtendedMasterSecretType::Require, - ..Default::default() - }; - let dtls_conn: Arc = - Arc::new(DTLSConn::new(conn, config, true, None).await?); - - println!("Connected; type 'exit' to shutdown gracefully"); - let _ = hub::utilities::chat(Arc::clone(&dtls_conn)).await; - - dtls_conn.close().await?; - - Ok(()) -} diff --git a/dtls/examples/dial/selfsign/dial_selfsign.rs b/dtls/examples/dial/selfsign/dial_selfsign.rs deleted file mode 100644 index 340a4635d..000000000 --- a/dtls/examples/dial/selfsign/dial_selfsign.rs +++ /dev/null @@ -1,82 +0,0 @@ -use std::io::Write; -use std::sync::Arc; - -use clap::{App, AppSettings, Arg}; -use tokio::net::UdpSocket; -use util::Conn; -use webrtc_dtls::config::*; -use webrtc_dtls::conn::DTLSConn; -use webrtc_dtls::crypto::Certificate; -use webrtc_dtls::Error; - -// cargo run --example dial_selfsign -- --server 127.0.0.1:4444 - -#[tokio::main] -async fn main() -> Result<(), Error> { - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init(); - - let mut app = App::new("DTLS Client") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of DTLS Client") - .setting(AppSettings::DeriveDisplayOrder) - .setting(AppSettings::SubcommandsNegateReqs) - .arg( - Arg::with_name("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::with_name("server") - .required_unless("FULLHELP") - .takes_value(true) - .default_value("127.0.0.1:4444") - .long("server") - .help("DTLS Server name."), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let server = matches.value_of("server").unwrap(); - - let conn = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); - conn.connect(server).await?; - println!("connecting {server}.."); - - // Generate a certificate and private key to secure the connection - let certificate = Certificate::generate_self_signed(vec!["localhost".to_owned()])?; - - let config = Config { - certificates: vec![certificate], - insecure_skip_verify: true, - extended_master_secret: ExtendedMasterSecretType::Require, - ..Default::default() - }; - let dtls_conn: Arc = - Arc::new(DTLSConn::new(conn, config, true, None).await?); - - println!("Connected; type 'exit' to shutdown gracefully"); - let _ = hub::utilities::chat(Arc::clone(&dtls_conn)).await; - - dtls_conn.close().await?; - - Ok(()) -} diff --git a/dtls/examples/dial/verify/dial_verify.rs b/dtls/examples/dial/verify/dial_verify.rs deleted file mode 100644 index 2f79ef224..000000000 --- a/dtls/examples/dial/verify/dial_verify.rs +++ /dev/null @@ -1,93 +0,0 @@ -use std::io::Write; -use std::sync::Arc; - -use clap::{App, AppSettings, Arg}; -use 
hub::utilities::load_certificate; -use tokio::net::UdpSocket; -use util::Conn; -use webrtc_dtls::config::*; -use webrtc_dtls::conn::DTLSConn; -use webrtc_dtls::Error; - -// cargo run --example dial_verify -- --server 127.0.0.1:4444 - -#[tokio::main] -async fn main() -> Result<(), Error> { - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init(); - - let mut app = App::new("DTLS Client") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of DTLS Client") - .setting(AppSettings::DeriveDisplayOrder) - .setting(AppSettings::SubcommandsNegateReqs) - .arg( - Arg::with_name("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::with_name("server") - .required_unless("FULLHELP") - .takes_value(true) - .default_value("127.0.0.1:4444") - .long("server") - .help("DTLS Server name."), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let server = matches.value_of("server").unwrap(); - - let conn = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); - conn.connect(server).await?; - println!("connecting {server}.."); - - let certificate = hub::utilities::load_key_and_certificate( - "dtls/examples/certificates/client.pem.private_key.pem".into(), - "dtls/examples/certificates/client.pub.pem".into(), - )?; - - let mut cert_pool = rustls::RootCertStore::empty(); - let certs = load_certificate("dtls/examples/certificates/server.pub.pem".into())?; - for cert in &certs { - if cert_pool.add(cert.to_owned()).is_err() { - return Err(Error::Other("cert_pool add_pem_file failed".to_owned())); - } - } - - let config = Config { - certificates: vec![certificate], - extended_master_secret: ExtendedMasterSecretType::Require, - roots_cas: cert_pool, - server_name: "webrtc.rs".to_owned(), - ..Default::default() - }; - let dtls_conn: Arc = - Arc::new(DTLSConn::new(conn, config, true, None).await?); - - println!("Connected; type 'exit' to shutdown gracefully"); - let _ = hub::utilities::chat(Arc::clone(&dtls_conn)).await; - - dtls_conn.close().await?; - - Ok(()) -} diff --git a/dtls/examples/hub/Cargo.toml b/dtls/examples/hub/Cargo.toml deleted file mode 100644 index 72f3992e8..000000000 --- a/dtls/examples/hub/Cargo.toml +++ /dev/null @@ -1,16 +0,0 @@ -[package] -name = "hub" -version = "0.1.0" -edition = "2021" - -[dependencies] -util = { path = "../../../util", package = "webrtc-util", default-features = false, features = [ - "conn" -] } -dtls = { package = "webrtc-dtls", path = "../../" } - -tokio = { version = "1.32.0", features = ["full"] } -rcgen = { version = "0.13", features = ["pem", "x509-parser"] } -rustls = { version = "0.23", default-features = false } -rustls-pemfile = "2" -thiserror = "1" diff --git a/dtls/examples/hub/src/lib.rs b/dtls/examples/hub/src/lib.rs deleted file mode 100644 index 2b07063a2..000000000 --- a/dtls/examples/hub/src/lib.rs +++ /dev/null @@ -1,112 +0,0 @@ -#![warn(rust_2018_idioms)] -#![allow(dead_code)] - -pub mod utilities; - -use std::collections::HashMap; -use std::io::{BufRead, BufReader}; -use std::sync::Arc; - -use dtls::Error; -use tokio::sync::Mutex; -use util::Conn; - -const BUF_SIZE: usize = 8192; - -/// Hub is a helper to handle one to many chat -#[derive(Default)] -pub 
struct Hub { - conns: Arc>>>, -} - -impl Hub { - /// new builds a new hub - pub fn new() -> Self { - Hub { - conns: Arc::new(Mutex::new(HashMap::new())), - } - } - - /// register adds a new conn to the Hub - pub async fn register(&self, conn: Arc) { - println!("Connected to {}", conn.remote_addr().unwrap()); - - if let Some(remote_addr) = conn.remote_addr() { - let mut conns = self.conns.lock().await; - conns.insert(remote_addr.to_string(), Arc::clone(&conn)); - } - - let conns = Arc::clone(&self.conns); - tokio::spawn(async move { - let _ = Hub::read_loop(conns, conn).await; - }); - } - - async fn read_loop( - conns: Arc>>>, - conn: Arc, - ) -> Result<(), Error> { - let mut b = vec![0u8; BUF_SIZE]; - - while let Ok(n) = conn.recv(&mut b).await { - let msg = String::from_utf8(b[..n].to_vec())?; - print!("Got message: {msg}"); - } - - Hub::unregister(conns, conn).await - } - - async fn unregister( - conns: Arc>>>, - conn: Arc, - ) -> Result<(), Error> { - if let Some(remote_addr) = conn.remote_addr() { - { - let mut cs = conns.lock().await; - cs.remove(&remote_addr.to_string()); - } - - if let Err(err) = conn.close().await { - println!("Failed to disconnect: {remote_addr} with err {err}"); - } else { - println!("Disconnected: {remote_addr} "); - } - } - - Ok(()) - } - - async fn broadcast(&self, msg: &[u8]) { - let conns = self.conns.lock().await; - for conn in conns.values() { - if let Err(err) = conn.send(msg).await { - println!( - "Failed to write message to {:?}: {}", - conn.remote_addr(), - err - ); - } - } - } - - /// Chat starts the stdin readloop to dispatch messages to the hub - pub async fn chat(&self) { - let input = std::io::stdin(); - let mut reader = BufReader::new(input.lock()); - loop { - let mut msg = String::new(); - match reader.read_line(&mut msg) { - Ok(0) => return, - Err(err) => { - println!("stdin read err: {err}"); - return; - } - _ => {} - }; - if msg.trim() == "exit" { - return; - } - self.broadcast(msg.as_bytes()).await; - } - } -} diff --git a/dtls/examples/hub/src/utilities.rs b/dtls/examples/hub/src/utilities.rs deleted file mode 100644 index 912b74bc7..000000000 --- a/dtls/examples/hub/src/utilities.rs +++ /dev/null @@ -1,113 +0,0 @@ -use std::fs::File; -use std::io::{self, Read}; -use std::path::PathBuf; - -use dtls::crypto::{Certificate, CryptoPrivateKey}; -use rcgen::KeyPair; -use rustls::pki_types::CertificateDer; -use thiserror::Error; - -use super::*; - -#[derive(Debug, Error, PartialEq, Eq)] -pub enum Error { - #[error("block is not a private key, unable to load key")] - ErrBlockIsNotPrivateKey, - #[error("unknown key time in PKCS#8 wrapping, unable to load key")] - ErrUnknownKeyTime, - #[error("no private key found, unable to load key")] - ErrNoPrivateKeyFound, - #[error("block is not a certificate, unable to load certificates")] - ErrBlockIsNotCertificate, - #[error("no certificate found, unable to load certificates")] - ErrNoCertificateFound, - - #[error("{0}")] - Other(String), -} - -impl From for dtls::Error { - fn from(e: Error) -> Self { - dtls::Error::Other(e.to_string()) - } -} - -impl From for Error { - fn from(e: io::Error) -> Self { - Error::Other(e.to_string()) - } -} - -/// chat simulates a simple text chat session over the connection -pub async fn chat(conn: Arc) -> Result<(), Error> { - let conn_rx = Arc::clone(&conn); - tokio::spawn(async move { - let mut b = vec![0u8; BUF_SIZE]; - - while let Ok(n) = conn_rx.recv(&mut b).await { - let msg = String::from_utf8(b[..n].to_vec()).expect("utf8"); - print!("Got message: {msg}"); - } - - 
Result::<(), Error>::Ok(()) - }); - - let input = std::io::stdin(); - let mut reader = BufReader::new(input.lock()); - loop { - let mut msg = String::new(); - match reader.read_line(&mut msg) { - Ok(0) => return Ok(()), - Err(err) => { - println!("stdin read err: {err}"); - return Ok(()); - } - _ => {} - }; - if msg.trim() == "exit" { - return Ok(()); - } - - let _ = conn.send(msg.as_bytes()).await; - } -} - -/// load_key_and_certificate reads certificates or key from file -pub fn load_key_and_certificate( - key_path: PathBuf, - certificate_path: PathBuf, -) -> Result { - let private_key = load_key(key_path)?; - - let certificate = load_certificate(certificate_path)?; - - Ok(Certificate { - certificate, - private_key, - }) -} - -/// load_key Load/read key from file -pub fn load_key(path: PathBuf) -> Result { - let f = File::open(path)?; - let mut reader = BufReader::new(f); - let mut buf = vec![]; - reader.read_to_end(&mut buf)?; - - let s = String::from_utf8(buf).expect("utf8 of file"); - - let key_pair = KeyPair::from_pem(s.as_str()).expect("key pair in file"); - - Ok(CryptoPrivateKey::from_key_pair(&key_pair).expect("crypto key pair")) -} - -/// load_certificate Load/read certificate(s) from file -pub fn load_certificate(path: PathBuf) -> Result>, Error> { - let f = File::open(path)?; - - let mut reader = BufReader::new(f); - match rustls_pemfile::certs(&mut reader).collect::, _>>() { - Ok(certs) => Ok(certs.into_iter().map(CertificateDer::from).collect()), - Err(_) => Err(Error::ErrNoCertificateFound), - } -} diff --git a/dtls/examples/listen/psk/listen_psk.rs b/dtls/examples/listen/psk/listen_psk.rs deleted file mode 100644 index fc536bb6e..000000000 --- a/dtls/examples/listen/psk/listen_psk.rs +++ /dev/null @@ -1,89 +0,0 @@ -use std::io::Write; -use std::sync::Arc; - -use clap::{App, AppSettings, Arg}; -use util::conn::*; -use webrtc_dtls::cipher_suite::CipherSuiteId; -use webrtc_dtls::config::{Config, ExtendedMasterSecretType}; -use webrtc_dtls::listener::listen; -use webrtc_dtls::Error; - -// cargo run --example listen_psk -- --host 127.0.0.1:4444 - -#[tokio::main] -async fn main() -> Result<(), Error> { - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init(); - - let mut app = App::new("DTLS Server") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of DTLS Server") - .setting(AppSettings::DeriveDisplayOrder) - .setting(AppSettings::SubcommandsNegateReqs) - .arg( - Arg::with_name("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::with_name("host") - .required_unless("FULLHELP") - .takes_value(true) - .default_value("127.0.0.1:4444") - .long("host") - .help("DTLS host name."), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let host = matches.value_of("host").unwrap().to_owned(); - - let cfg = Config { - psk: Some(Arc::new(|hint: &[u8]| -> Result, Error> { - println!("Client's hint: {}", String::from_utf8(hint.to_vec())?); - Ok(vec![0xAB, 0xC1, 0x23]) - })), - psk_identity_hint: Some("webrtc-rs DTLS Client".as_bytes().to_vec()), - cipher_suites: vec![CipherSuiteId::Tls_Psk_With_Aes_128_Ccm_8], - extended_master_secret: ExtendedMasterSecretType::Require, - 
..Default::default() - }; - - println!("listening {host}...\ntype 'exit' to shutdown gracefully"); - - let listener = Arc::new(listen(host, cfg).await?); - - // Simulate a chat session - let h = Arc::new(hub::Hub::new()); - - let listener2 = Arc::clone(&listener); - let h2 = Arc::clone(&h); - tokio::spawn(async move { - while let Ok((dtls_conn, _remote_addr)) = listener2.accept().await { - // Register the connection with the chat hub - h2.register(dtls_conn).await; - } - }); - - h.chat().await; - - Ok(listener.close().await?) -} diff --git a/dtls/examples/listen/selfsign/listen_selfsign.rs b/dtls/examples/listen/selfsign/listen_selfsign.rs deleted file mode 100644 index 19a54fdc9..000000000 --- a/dtls/examples/listen/selfsign/listen_selfsign.rs +++ /dev/null @@ -1,87 +0,0 @@ -use std::io::Write; -use std::sync::Arc; - -use clap::{App, AppSettings, Arg}; -use util::conn::*; -use webrtc_dtls::config::{Config, ExtendedMasterSecretType}; -use webrtc_dtls::crypto::Certificate; -use webrtc_dtls::listener::listen; -use webrtc_dtls::Error; - -// cargo run --example listen_selfsign -- --host 127.0.0.1:4444 - -#[tokio::main] -async fn main() -> Result<(), Error> { - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init(); - - let mut app = App::new("DTLS Server") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of DTLS Server") - .setting(AppSettings::DeriveDisplayOrder) - .setting(AppSettings::SubcommandsNegateReqs) - .arg( - Arg::with_name("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::with_name("host") - .required_unless("FULLHELP") - .takes_value(true) - .default_value("127.0.0.1:4444") - .long("host") - .help("DTLS host name."), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let host = matches.value_of("host").unwrap().to_owned(); - - // Generate a certificate and private key to secure the connection - let certificate = Certificate::generate_self_signed(vec!["localhost".to_owned()])?; - - let cfg = Config { - certificates: vec![certificate], - extended_master_secret: ExtendedMasterSecretType::Require, - ..Default::default() - }; - - println!("listening {host}...\ntype 'exit' to shutdown gracefully"); - - let listener = Arc::new(listen(host, cfg).await?); - - // Simulate a chat session - let h = Arc::new(hub::Hub::new()); - - let listener2 = Arc::clone(&listener); - let h2 = Arc::clone(&h); - tokio::spawn(async move { - while let Ok((dtls_conn, _remote_addr)) = listener2.accept().await { - // Register the connection with the chat hub - h2.register(dtls_conn).await; - } - }); - - h.chat().await; - - Ok(listener.close().await?) 
-} diff --git a/dtls/examples/listen/verify/listen_verify.rs b/dtls/examples/listen/verify/listen_verify.rs deleted file mode 100644 index e63f7ea64..000000000 --- a/dtls/examples/listen/verify/listen_verify.rs +++ /dev/null @@ -1,118 +0,0 @@ -use std::io::Write; -use std::sync::Arc; - -use clap::{App, AppSettings, Arg}; -use hub::utilities::load_certificate; -use util::conn::*; -use webrtc_dtls::config::{ClientAuthType, Config, ExtendedMasterSecretType}; -use webrtc_dtls::listener::listen; -use webrtc_dtls::Error; - -// cargo run --example listen_verify -- --host 127.0.0.1:4444 - -#[tokio::main] -async fn main() -> Result<(), Error> { - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init(); - - let mut app = App::new("DTLS Server") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of DTLS Server") - .setting(AppSettings::DeriveDisplayOrder) - .setting(AppSettings::SubcommandsNegateReqs) - .arg( - Arg::with_name("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::with_name("host") - .required_unless("FULLHELP") - .takes_value(true) - .default_value("127.0.0.1:4444") - .long("host") - .help("DTLS host name."), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let host = matches.value_of("host").unwrap().to_owned(); - - let certificate = hub::utilities::load_key_and_certificate( - "dtls/examples/certificates/server.pem.private_key.pem".into(), - "dtls/examples/certificates/server.pub.pem".into(), - )?; - - let mut cert_pool = rustls::RootCertStore::empty(); - let certs = load_certificate("dtls/examples/certificates/server.pub.pem".into())?; - for cert in &certs { - if cert_pool.add(cert.to_owned()).is_err() { - return Err(Error::Other("cert_pool add_pem_file failed".to_owned())); - } - } - - let cfg = Config { - certificates: vec![certificate], - extended_master_secret: ExtendedMasterSecretType::Require, - client_auth: ClientAuthType::RequireAndVerifyClientCert, //RequireAnyClientCert, // - client_cas: cert_pool, - ..Default::default() - }; - - println!("listening {host}...\ntype 'exit' to shutdown gracefully"); - - let listener = Arc::new(listen(host, cfg).await?); - - // Simulate a chat session - let h = Arc::new(hub::Hub::new()); - - let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - let mut done_tx = Some(done_tx); - - let listener2 = Arc::clone(&listener); - let h2 = Arc::clone(&h); - tokio::spawn(async move { - loop { - tokio::select! { - _ = done_rx.recv() => { - break; - } - result = listener2.accept() => { - match result{ - Ok((dtls_conn, _)) => { - // Register the connection with the chat hub - h2.register(dtls_conn).await; - } - Err(err) => { - println!("connecting failed with error: {err}"); - } - } - } - } - } - }); - - h.chat().await; - - done_tx.take(); - - Ok(listener.close().await?) 
-} diff --git a/dtls/src/alert/alert_test.rs b/dtls/src/alert/alert_test.rs deleted file mode 100644 index ad9fff9fb..000000000 --- a/dtls/src/alert/alert_test.rs +++ /dev/null @@ -1,50 +0,0 @@ -use std::io::{BufReader, BufWriter}; - -use super::*; -use crate::error::Error; - -#[test] -fn test_alert() -> Result<()> { - let tests = vec![ - ( - "Valid Alert", - vec![0x02, 0x0A], - Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::UnexpectedMessage, - }, - None, - ), - ( - "Invalid alert length", - vec![0x00], - Alert { - alert_level: AlertLevel::Invalid, - alert_description: AlertDescription::Invalid, - }, - Some(Error::Other("io".to_owned())), - ), - ]; - - for (name, data, wanted, unmarshal_error) in tests { - let mut reader = BufReader::new(data.as_slice()); - let result = Alert::unmarshal(&mut reader); - - if let Some(err) = unmarshal_error { - assert!(result.is_err(), "{name} expected error: {err}"); - } else if let Ok(alert) = result { - assert_eq!(wanted, alert, "{name} expected {wanted}, but got {alert}"); - - let mut data2: Vec = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(data2.as_mut()); - alert.marshal(&mut writer)?; - } - assert_eq!(data, data2, "{name} expected {data:?}, but got {data2:?}"); - } else { - assert!(result.is_ok(), "{name} expected Ok, but has error"); - } - } - - Ok(()) -} diff --git a/dtls/src/alert/mod.rs b/dtls/src/alert/mod.rs deleted file mode 100644 index 197aec22a..000000000 --- a/dtls/src/alert/mod.rs +++ /dev/null @@ -1,185 +0,0 @@ -#[cfg(test)] -mod alert_test; - -use std::fmt; -use std::io::{Read, Write}; - -use byteorder::{ReadBytesExt, WriteBytesExt}; - -use super::content::*; -use crate::error::Result; - -#[derive(Copy, Clone, PartialEq, Debug)] -pub(crate) enum AlertLevel { - Warning = 1, - Fatal = 2, - Invalid, -} - -impl fmt::Display for AlertLevel { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self { - AlertLevel::Warning => write!(f, "LevelWarning"), - AlertLevel::Fatal => write!(f, "LevelFatal"), - _ => write!(f, "Invalid alert level"), - } - } -} - -impl From for AlertLevel { - fn from(val: u8) -> Self { - match val { - 1 => AlertLevel::Warning, - 2 => AlertLevel::Fatal, - _ => AlertLevel::Invalid, - } - } -} - -#[derive(Copy, Clone, PartialEq, Debug)] -pub(crate) enum AlertDescription { - CloseNotify = 0, - UnexpectedMessage = 10, - BadRecordMac = 20, - DecryptionFailed = 21, - RecordOverflow = 22, - DecompressionFailure = 30, - HandshakeFailure = 40, - NoCertificate = 41, - BadCertificate = 42, - UnsupportedCertificate = 43, - CertificateRevoked = 44, - CertificateExpired = 45, - CertificateUnknown = 46, - IllegalParameter = 47, - UnknownCa = 48, - AccessDenied = 49, - DecodeError = 50, - DecryptError = 51, - ExportRestriction = 60, - ProtocolVersion = 70, - InsufficientSecurity = 71, - InternalError = 80, - UserCanceled = 90, - NoRenegotiation = 100, - UnsupportedExtension = 110, - UnknownPskIdentity = 115, - Invalid, -} - -impl fmt::Display for AlertDescription { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self { - AlertDescription::CloseNotify => write!(f, "CloseNotify"), - AlertDescription::UnexpectedMessage => write!(f, "UnexpectedMessage"), - AlertDescription::BadRecordMac => write!(f, "BadRecordMac"), - AlertDescription::DecryptionFailed => write!(f, "DecryptionFailed"), - AlertDescription::RecordOverflow => write!(f, "RecordOverflow"), - AlertDescription::DecompressionFailure => write!(f, "DecompressionFailure"), - 
AlertDescription::HandshakeFailure => write!(f, "HandshakeFailure"), - AlertDescription::NoCertificate => write!(f, "NoCertificate"), - AlertDescription::BadCertificate => write!(f, "BadCertificate"), - AlertDescription::UnsupportedCertificate => write!(f, "UnsupportedCertificate"), - AlertDescription::CertificateRevoked => write!(f, "CertificateRevoked"), - AlertDescription::CertificateExpired => write!(f, "CertificateExpired"), - AlertDescription::CertificateUnknown => write!(f, "CertificateUnknown"), - AlertDescription::IllegalParameter => write!(f, "IllegalParameter"), - AlertDescription::UnknownCa => write!(f, "UnknownCA"), - AlertDescription::AccessDenied => write!(f, "AccessDenied"), - AlertDescription::DecodeError => write!(f, "DecodeError"), - AlertDescription::DecryptError => write!(f, "DecryptError"), - AlertDescription::ExportRestriction => write!(f, "ExportRestriction"), - AlertDescription::ProtocolVersion => write!(f, "ProtocolVersion"), - AlertDescription::InsufficientSecurity => write!(f, "InsufficientSecurity"), - AlertDescription::InternalError => write!(f, "InternalError"), - AlertDescription::UserCanceled => write!(f, "UserCanceled"), - AlertDescription::NoRenegotiation => write!(f, "NoRenegotiation"), - AlertDescription::UnsupportedExtension => write!(f, "UnsupportedExtension"), - AlertDescription::UnknownPskIdentity => write!(f, "UnknownPskIdentity"), - _ => write!(f, "Invalid alert description"), - } - } -} - -impl From for AlertDescription { - fn from(val: u8) -> Self { - match val { - 0 => AlertDescription::CloseNotify, - 10 => AlertDescription::UnexpectedMessage, - 20 => AlertDescription::BadRecordMac, - 21 => AlertDescription::DecryptionFailed, - 22 => AlertDescription::RecordOverflow, - 30 => AlertDescription::DecompressionFailure, - 40 => AlertDescription::HandshakeFailure, - 41 => AlertDescription::NoCertificate, - 42 => AlertDescription::BadCertificate, - 43 => AlertDescription::UnsupportedCertificate, - 44 => AlertDescription::CertificateRevoked, - 45 => AlertDescription::CertificateExpired, - 46 => AlertDescription::CertificateUnknown, - 47 => AlertDescription::IllegalParameter, - 48 => AlertDescription::UnknownCa, - 49 => AlertDescription::AccessDenied, - 50 => AlertDescription::DecodeError, - 51 => AlertDescription::DecryptError, - 60 => AlertDescription::ExportRestriction, - 70 => AlertDescription::ProtocolVersion, - 71 => AlertDescription::InsufficientSecurity, - 80 => AlertDescription::InternalError, - 90 => AlertDescription::UserCanceled, - 100 => AlertDescription::NoRenegotiation, - 110 => AlertDescription::UnsupportedExtension, - 115 => AlertDescription::UnknownPskIdentity, - _ => AlertDescription::Invalid, - } - } -} - -// One of the content types supported by the TLS record layer is the -// alert type. Alert messages convey the severity of the message -// (warning or fatal) and a description of the alert. Alert messages -// with a level of fatal result in the immediate termination of the -// connection. In this case, other connections corresponding to the -// session may continue, but the session identifier MUST be invalidated, -// preventing the failed session from being used to establish new -// connections. Like other messages, alert messages are encrypted and -// compressed, as specified by the current connection state. 
-// https://tools.ietf.org/html/rfc5246#section-7.2 -#[derive(Copy, Clone, PartialEq, Debug)] -pub struct Alert { - pub(crate) alert_level: AlertLevel, - pub(crate) alert_description: AlertDescription, -} - -impl fmt::Display for Alert { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Alert {}: {}", self.alert_level, self.alert_description) - } -} - -impl Alert { - pub fn content_type(&self) -> ContentType { - ContentType::Alert - } - - pub fn size(&self) -> usize { - 2 - } - - pub fn marshal(&self, writer: &mut W) -> Result<()> { - writer.write_u8(self.alert_level as u8)?; - writer.write_u8(self.alert_description as u8)?; - - Ok(writer.flush()?) - } - - pub fn unmarshal(reader: &mut R) -> Result { - let alert_level = reader.read_u8()?.into(); - let alert_description = reader.read_u8()?.into(); - - Ok(Alert { - alert_level, - alert_description, - }) - } -} diff --git a/dtls/src/application_data.rs b/dtls/src/application_data.rs deleted file mode 100644 index 4897430dc..000000000 --- a/dtls/src/application_data.rs +++ /dev/null @@ -1,37 +0,0 @@ -use std::io::{Read, Write}; - -use super::content::*; -use crate::error::Result; - -// Application data messages are carried by the record layer and are -// fragmented, compressed, and encrypted based on the current connection -// state. The messages are treated as transparent data to the record -// layer. -// https://tools.ietf.org/html/rfc5246#section-10 -#[derive(Clone, PartialEq, Eq, Debug)] -pub struct ApplicationData { - pub data: Vec, -} - -impl ApplicationData { - pub fn content_type(&self) -> ContentType { - ContentType::ApplicationData - } - - pub fn size(&self) -> usize { - self.data.len() - } - - pub fn marshal(&self, writer: &mut W) -> Result<()> { - writer.write_all(&self.data)?; - - Ok(writer.flush()?) 
- } - - pub fn unmarshal(reader: &mut R) -> Result { - let mut data: Vec = vec![]; - reader.read_to_end(&mut data)?; - - Ok(ApplicationData { data }) - } -} diff --git a/dtls/src/change_cipher_spec/change_cipher_spec_test.rs b/dtls/src/change_cipher_spec/change_cipher_spec_test.rs deleted file mode 100644 index 868d96df0..000000000 --- a/dtls/src/change_cipher_spec/change_cipher_spec_test.rs +++ /dev/null @@ -1,37 +0,0 @@ -use std::io::{BufReader, BufWriter}; - -use super::*; - -#[test] -fn test_change_cipher_spec_round_trip() -> Result<()> { - let c = ChangeCipherSpec {}; - let mut raw = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(raw.as_mut()); - c.marshal(&mut writer)?; - } - - let mut reader = BufReader::new(raw.as_slice()); - let cnew = ChangeCipherSpec::unmarshal(&mut reader)?; - assert_eq!( - c, cnew, - "ChangeCipherSpec round trip: got {cnew:?}, want {c:?}" - ); - - Ok(()) -} - -#[test] -fn test_change_cipher_spec_invalid() -> Result<()> { - let data = vec![0x00]; - - let mut reader = BufReader::new(data.as_slice()); - let result = ChangeCipherSpec::unmarshal(&mut reader); - - match result { - Ok(_) => panic!("must be error"), - Err(err) => assert_eq!(err.to_string(), Error::ErrInvalidCipherSpec.to_string()), - }; - - Ok(()) -} diff --git a/dtls/src/change_cipher_spec/mod.rs b/dtls/src/change_cipher_spec/mod.rs deleted file mode 100644 index 51c08b055..000000000 --- a/dtls/src/change_cipher_spec/mod.rs +++ /dev/null @@ -1,42 +0,0 @@ -#[cfg(test)] -mod change_cipher_spec_test; - -use std::io::{Read, Write}; - -use byteorder::{ReadBytesExt, WriteBytesExt}; - -use super::content::*; -use super::error::*; - -// The change cipher spec protocol exists to signal transitions in -// ciphering strategies. The protocol consists of a single message, -// which is encrypted and compressed under the current (not the pending) -// connection state. The message consists of a single byte of value 1. -// https://tools.ietf.org/html/rfc5246#section-7.1 -#[derive(Clone, PartialEq, Eq, Debug)] -pub struct ChangeCipherSpec; - -impl ChangeCipherSpec { - pub fn content_type(&self) -> ContentType { - ContentType::ChangeCipherSpec - } - - pub fn size(&self) -> usize { - 1 - } - - pub fn marshal(&self, writer: &mut W) -> Result<()> { - writer.write_u8(0x01)?; - - Ok(writer.flush()?) 
- } - - pub fn unmarshal(reader: &mut R) -> Result { - let data = reader.read_u8()?; - if data != 0x01 { - return Err(Error::ErrInvalidCipherSpec); - } - - Ok(ChangeCipherSpec {}) - } -} diff --git a/dtls/src/cipher_suite/cipher_suite_aes_128_ccm.rs b/dtls/src/cipher_suite/cipher_suite_aes_128_ccm.rs deleted file mode 100644 index a28ee0f29..000000000 --- a/dtls/src/cipher_suite/cipher_suite_aes_128_ccm.rs +++ /dev/null @@ -1,118 +0,0 @@ -use super::*; -use crate::client_certificate_type::ClientCertificateType; -use crate::crypto::crypto_ccm::{CryptoCcm, CryptoCcmTagLen}; -use crate::prf::*; - -#[derive(Clone)] -pub struct CipherSuiteAes128Ccm { - ccm: Option, - client_certificate_type: ClientCertificateType, - id: CipherSuiteId, - psk: bool, - crypto_ccm_tag_len: CryptoCcmTagLen, -} - -impl CipherSuiteAes128Ccm { - const PRF_MAC_LEN: usize = 0; - const PRF_KEY_LEN: usize = 16; - const PRF_IV_LEN: usize = 4; - - pub fn new( - client_certificate_type: ClientCertificateType, - id: CipherSuiteId, - psk: bool, - crypto_ccm_tag_len: CryptoCcmTagLen, - ) -> Self { - CipherSuiteAes128Ccm { - ccm: None, - client_certificate_type, - id, - psk, - crypto_ccm_tag_len, - } - } -} - -impl CipherSuite for CipherSuiteAes128Ccm { - fn to_string(&self) -> String { - format!("{}", self.id) - } - - fn id(&self) -> CipherSuiteId { - self.id - } - - fn certificate_type(&self) -> ClientCertificateType { - self.client_certificate_type - } - - fn hash_func(&self) -> CipherSuiteHash { - CipherSuiteHash::Sha256 - } - - fn is_psk(&self) -> bool { - self.psk - } - - fn is_initialized(&self) -> bool { - self.ccm.is_some() - } - - fn init( - &mut self, - master_secret: &[u8], - client_random: &[u8], - server_random: &[u8], - is_client: bool, - ) -> Result<()> { - let keys = prf_encryption_keys( - master_secret, - client_random, - server_random, - CipherSuiteAes128Ccm::PRF_MAC_LEN, - CipherSuiteAes128Ccm::PRF_KEY_LEN, - CipherSuiteAes128Ccm::PRF_IV_LEN, - self.hash_func(), - )?; - - if is_client { - self.ccm = Some(CryptoCcm::new( - &self.crypto_ccm_tag_len, - &keys.client_write_key, - &keys.client_write_iv, - &keys.server_write_key, - &keys.server_write_iv, - )); - } else { - self.ccm = Some(CryptoCcm::new( - &self.crypto_ccm_tag_len, - &keys.server_write_key, - &keys.server_write_iv, - &keys.client_write_key, - &keys.client_write_iv, - )); - } - - Ok(()) - } - - fn encrypt(&self, pkt_rlh: &RecordLayerHeader, raw: &[u8]) -> Result> { - if let Some(ccm) = &self.ccm { - ccm.encrypt(pkt_rlh, raw) - } else { - Err(Error::Other( - "CipherSuite has not been initialized, unable to encrypt".to_owned(), - )) - } - } - - fn decrypt(&self, input: &[u8]) -> Result> { - if let Some(ccm) = &self.ccm { - ccm.decrypt(input) - } else { - Err(Error::Other( - "CipherSuite has not been initialized, unable to decrypt".to_owned(), - )) - } - } -} diff --git a/dtls/src/cipher_suite/cipher_suite_aes_128_gcm_sha256.rs b/dtls/src/cipher_suite/cipher_suite_aes_128_gcm_sha256.rs deleted file mode 100644 index fe7fd9a26..000000000 --- a/dtls/src/cipher_suite/cipher_suite_aes_128_gcm_sha256.rs +++ /dev/null @@ -1,113 +0,0 @@ -use super::*; -use crate::crypto::crypto_gcm::*; -use crate::prf::*; - -#[derive(Clone)] -pub struct CipherSuiteAes128GcmSha256 { - gcm: Option, - rsa: bool, -} - -impl CipherSuiteAes128GcmSha256 { - const PRF_MAC_LEN: usize = 0; - const PRF_KEY_LEN: usize = 16; - const PRF_IV_LEN: usize = 4; - - pub fn new(rsa: bool) -> Self { - CipherSuiteAes128GcmSha256 { gcm: None, rsa } - } -} - -impl CipherSuite for 
CipherSuiteAes128GcmSha256 { - fn to_string(&self) -> String { - if self.rsa { - "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256".to_owned() - } else { - "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256".to_owned() - } - } - - fn id(&self) -> CipherSuiteId { - if self.rsa { - CipherSuiteId::Tls_Ecdhe_Rsa_With_Aes_128_Gcm_Sha256 - } else { - CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Gcm_Sha256 - } - } - - fn certificate_type(&self) -> ClientCertificateType { - if self.rsa { - ClientCertificateType::RsaSign - } else { - ClientCertificateType::EcdsaSign - } - } - - fn hash_func(&self) -> CipherSuiteHash { - CipherSuiteHash::Sha256 - } - - fn is_psk(&self) -> bool { - false - } - - fn is_initialized(&self) -> bool { - self.gcm.is_some() - } - - fn init( - &mut self, - master_secret: &[u8], - client_random: &[u8], - server_random: &[u8], - is_client: bool, - ) -> Result<()> { - let keys = prf_encryption_keys( - master_secret, - client_random, - server_random, - CipherSuiteAes128GcmSha256::PRF_MAC_LEN, - CipherSuiteAes128GcmSha256::PRF_KEY_LEN, - CipherSuiteAes128GcmSha256::PRF_IV_LEN, - self.hash_func(), - )?; - - if is_client { - self.gcm = Some(CryptoGcm::new( - &keys.client_write_key, - &keys.client_write_iv, - &keys.server_write_key, - &keys.server_write_iv, - )); - } else { - self.gcm = Some(CryptoGcm::new( - &keys.server_write_key, - &keys.server_write_iv, - &keys.client_write_key, - &keys.client_write_iv, - )); - } - - Ok(()) - } - - fn encrypt(&self, pkt_rlh: &RecordLayerHeader, raw: &[u8]) -> Result> { - if let Some(cg) = &self.gcm { - cg.encrypt(pkt_rlh, raw) - } else { - Err(Error::Other( - "CipherSuite has not been initialized, unable to encrypt".to_owned(), - )) - } - } - - fn decrypt(&self, input: &[u8]) -> Result> { - if let Some(cg) = &self.gcm { - cg.decrypt(input) - } else { - Err(Error::Other( - "CipherSuite has not been initialized, unable to decrypt".to_owned(), - )) - } - } -} diff --git a/dtls/src/cipher_suite/cipher_suite_aes_256_cbc_sha.rs b/dtls/src/cipher_suite/cipher_suite_aes_256_cbc_sha.rs deleted file mode 100644 index e358279b7..000000000 --- a/dtls/src/cipher_suite/cipher_suite_aes_256_cbc_sha.rs +++ /dev/null @@ -1,113 +0,0 @@ -use super::*; -use crate::crypto::crypto_cbc::*; -use crate::prf::*; - -#[derive(Clone)] -pub struct CipherSuiteAes256CbcSha { - cbc: Option, - rsa: bool, -} - -impl CipherSuiteAes256CbcSha { - const PRF_MAC_LEN: usize = 20; - const PRF_KEY_LEN: usize = 32; - const PRF_IV_LEN: usize = 16; - - pub fn new(rsa: bool) -> Self { - CipherSuiteAes256CbcSha { cbc: None, rsa } - } -} - -impl CipherSuite for CipherSuiteAes256CbcSha { - fn to_string(&self) -> String { - if self.rsa { - "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA".to_owned() - } else { - "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA".to_owned() - } - } - - fn id(&self) -> CipherSuiteId { - if self.rsa { - CipherSuiteId::Tls_Ecdhe_Rsa_With_Aes_256_Cbc_Sha - } else { - CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_256_Cbc_Sha - } - } - - fn certificate_type(&self) -> ClientCertificateType { - if self.rsa { - ClientCertificateType::RsaSign - } else { - ClientCertificateType::EcdsaSign - } - } - - fn hash_func(&self) -> CipherSuiteHash { - CipherSuiteHash::Sha256 - } - - fn is_psk(&self) -> bool { - false - } - - fn is_initialized(&self) -> bool { - self.cbc.is_some() - } - - fn init( - &mut self, - master_secret: &[u8], - client_random: &[u8], - server_random: &[u8], - is_client: bool, - ) -> Result<()> { - let keys = prf_encryption_keys( - master_secret, - client_random, - server_random, - 
CipherSuiteAes256CbcSha::PRF_MAC_LEN, - CipherSuiteAes256CbcSha::PRF_KEY_LEN, - CipherSuiteAes256CbcSha::PRF_IV_LEN, - self.hash_func(), - )?; - - if is_client { - self.cbc = Some(CryptoCbc::new( - &keys.client_write_key, - &keys.client_mac_key, - &keys.server_write_key, - &keys.server_mac_key, - )?); - } else { - self.cbc = Some(CryptoCbc::new( - &keys.server_write_key, - &keys.server_mac_key, - &keys.client_write_key, - &keys.client_mac_key, - )?); - } - - Ok(()) - } - - fn encrypt(&self, pkt_rlh: &RecordLayerHeader, raw: &[u8]) -> Result> { - if let Some(cg) = &self.cbc { - cg.encrypt(pkt_rlh, raw) - } else { - Err(Error::Other( - "CipherSuite has not been initialized, unable to encrypt".to_owned(), - )) - } - } - - fn decrypt(&self, input: &[u8]) -> Result> { - if let Some(cg) = &self.cbc { - cg.decrypt(input) - } else { - Err(Error::Other( - "CipherSuite has not been initialized, unable to decrypt".to_owned(), - )) - } - } -} diff --git a/dtls/src/cipher_suite/cipher_suite_tls_ecdhe_ecdsa_with_aes_128_ccm.rs b/dtls/src/cipher_suite/cipher_suite_tls_ecdhe_ecdsa_with_aes_128_ccm.rs deleted file mode 100644 index c5bf2dfd1..000000000 --- a/dtls/src/cipher_suite/cipher_suite_tls_ecdhe_ecdsa_with_aes_128_ccm.rs +++ /dev/null @@ -1,12 +0,0 @@ -use super::*; -use crate::cipher_suite::cipher_suite_aes_128_ccm::CipherSuiteAes128Ccm; -use crate::crypto::crypto_ccm::CryptoCcmTagLen; - -pub fn new_cipher_suite_tls_ecdhe_ecdsa_with_aes_128_ccm() -> CipherSuiteAes128Ccm { - CipherSuiteAes128Ccm::new( - ClientCertificateType::EcdsaSign, - CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Ccm, - false, - CryptoCcmTagLen::CryptoCcmTagLength, - ) -} diff --git a/dtls/src/cipher_suite/cipher_suite_tls_ecdhe_ecdsa_with_aes_128_ccm8.rs b/dtls/src/cipher_suite/cipher_suite_tls_ecdhe_ecdsa_with_aes_128_ccm8.rs deleted file mode 100644 index aa9a92ce0..000000000 --- a/dtls/src/cipher_suite/cipher_suite_tls_ecdhe_ecdsa_with_aes_128_ccm8.rs +++ /dev/null @@ -1,12 +0,0 @@ -use super::*; -use crate::cipher_suite::cipher_suite_aes_128_ccm::CipherSuiteAes128Ccm; -use crate::crypto::crypto_ccm::CryptoCcmTagLen; - -pub fn new_cipher_suite_tls_ecdhe_ecdsa_with_aes_128_ccm8() -> CipherSuiteAes128Ccm { - CipherSuiteAes128Ccm::new( - ClientCertificateType::EcdsaSign, - CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Ccm_8, - false, - CryptoCcmTagLen::CryptoCcm8TagLength, - ) -} diff --git a/dtls/src/cipher_suite/cipher_suite_tls_psk_with_aes_128_ccm.rs b/dtls/src/cipher_suite/cipher_suite_tls_psk_with_aes_128_ccm.rs deleted file mode 100644 index 6f506e0ef..000000000 --- a/dtls/src/cipher_suite/cipher_suite_tls_psk_with_aes_128_ccm.rs +++ /dev/null @@ -1,12 +0,0 @@ -use super::*; -use crate::cipher_suite::cipher_suite_aes_128_ccm::CipherSuiteAes128Ccm; -use crate::crypto::crypto_ccm::CryptoCcmTagLen; - -pub fn new_cipher_suite_tls_psk_with_aes_128_ccm() -> CipherSuiteAes128Ccm { - CipherSuiteAes128Ccm::new( - ClientCertificateType::Unsupported, - CipherSuiteId::Tls_Psk_With_Aes_128_Ccm, - true, - CryptoCcmTagLen::CryptoCcmTagLength, - ) -} diff --git a/dtls/src/cipher_suite/cipher_suite_tls_psk_with_aes_128_ccm8.rs b/dtls/src/cipher_suite/cipher_suite_tls_psk_with_aes_128_ccm8.rs deleted file mode 100644 index 64b9f1a50..000000000 --- a/dtls/src/cipher_suite/cipher_suite_tls_psk_with_aes_128_ccm8.rs +++ /dev/null @@ -1,12 +0,0 @@ -use super::*; -use crate::cipher_suite::cipher_suite_aes_128_ccm::CipherSuiteAes128Ccm; -use crate::crypto::crypto_ccm::CryptoCcmTagLen; - -pub fn new_cipher_suite_tls_psk_with_aes_128_ccm8() -> 
CipherSuiteAes128Ccm { - CipherSuiteAes128Ccm::new( - ClientCertificateType::Unsupported, - CipherSuiteId::Tls_Psk_With_Aes_128_Ccm_8, - true, - CryptoCcmTagLen::CryptoCcm8TagLength, - ) -} diff --git a/dtls/src/cipher_suite/cipher_suite_tls_psk_with_aes_128_gcm_sha256.rs b/dtls/src/cipher_suite/cipher_suite_tls_psk_with_aes_128_gcm_sha256.rs deleted file mode 100644 index 2204ace29..000000000 --- a/dtls/src/cipher_suite/cipher_suite_tls_psk_with_aes_128_gcm_sha256.rs +++ /dev/null @@ -1,96 +0,0 @@ -use super::*; -use crate::crypto::crypto_gcm::*; -use crate::prf::*; - -#[derive(Clone, Default)] -pub struct CipherSuiteTlsPskWithAes128GcmSha256 { - gcm: Option, -} - -impl CipherSuiteTlsPskWithAes128GcmSha256 { - const PRF_MAC_LEN: usize = 0; - const PRF_KEY_LEN: usize = 16; - const PRF_IV_LEN: usize = 4; -} - -impl CipherSuite for CipherSuiteTlsPskWithAes128GcmSha256 { - fn to_string(&self) -> String { - "TLS_PSK_WITH_AES_128_GCM_SHA256".to_owned() - } - - fn id(&self) -> CipherSuiteId { - CipherSuiteId::Tls_Psk_With_Aes_128_Gcm_Sha256 - } - - fn certificate_type(&self) -> ClientCertificateType { - ClientCertificateType::Unsupported - } - - fn hash_func(&self) -> CipherSuiteHash { - CipherSuiteHash::Sha256 - } - - fn is_psk(&self) -> bool { - true - } - - fn is_initialized(&self) -> bool { - self.gcm.is_some() - } - - fn init( - &mut self, - master_secret: &[u8], - client_random: &[u8], - server_random: &[u8], - is_client: bool, - ) -> Result<()> { - let keys = prf_encryption_keys( - master_secret, - client_random, - server_random, - CipherSuiteTlsPskWithAes128GcmSha256::PRF_MAC_LEN, - CipherSuiteTlsPskWithAes128GcmSha256::PRF_KEY_LEN, - CipherSuiteTlsPskWithAes128GcmSha256::PRF_IV_LEN, - self.hash_func(), - )?; - - if is_client { - self.gcm = Some(CryptoGcm::new( - &keys.client_write_key, - &keys.client_write_iv, - &keys.server_write_key, - &keys.server_write_iv, - )); - } else { - self.gcm = Some(CryptoGcm::new( - &keys.server_write_key, - &keys.server_write_iv, - &keys.client_write_key, - &keys.client_write_iv, - )); - } - - Ok(()) - } - - fn encrypt(&self, pkt_rlh: &RecordLayerHeader, raw: &[u8]) -> Result> { - if let Some(cg) = &self.gcm { - cg.encrypt(pkt_rlh, raw) - } else { - Err(Error::Other( - "CipherSuite has not been initialized, unable to encrypt".to_owned(), - )) - } - } - - fn decrypt(&self, input: &[u8]) -> Result> { - if let Some(cg) = &self.gcm { - cg.decrypt(input) - } else { - Err(Error::Other( - "CipherSuite has not been initialized, unable to decrypt".to_owned(), - )) - } - } -} diff --git a/dtls/src/cipher_suite/mod.rs b/dtls/src/cipher_suite/mod.rs deleted file mode 100644 index 195492c21..000000000 --- a/dtls/src/cipher_suite/mod.rs +++ /dev/null @@ -1,227 +0,0 @@ -pub mod cipher_suite_aes_128_ccm; -pub mod cipher_suite_aes_128_gcm_sha256; -pub mod cipher_suite_aes_256_cbc_sha; -pub mod cipher_suite_tls_ecdhe_ecdsa_with_aes_128_ccm; -pub mod cipher_suite_tls_ecdhe_ecdsa_with_aes_128_ccm8; -pub mod cipher_suite_tls_psk_with_aes_128_ccm; -pub mod cipher_suite_tls_psk_with_aes_128_ccm8; -pub mod cipher_suite_tls_psk_with_aes_128_gcm_sha256; - -use std::fmt; -use std::marker::{Send, Sync}; - -use cipher_suite_aes_128_gcm_sha256::*; -use cipher_suite_aes_256_cbc_sha::*; -use cipher_suite_tls_ecdhe_ecdsa_with_aes_128_ccm::*; -use cipher_suite_tls_ecdhe_ecdsa_with_aes_128_ccm8::*; -use cipher_suite_tls_psk_with_aes_128_ccm::*; -use cipher_suite_tls_psk_with_aes_128_ccm8::*; -use cipher_suite_tls_psk_with_aes_128_gcm_sha256::*; - -use super::client_certificate_type::*; -use 
super::error::*; -use super::record_layer::record_layer_header::*; - -// CipherSuiteID is an ID for our supported CipherSuites -// Supported Cipher Suites -#[allow(non_camel_case_types)] -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum CipherSuiteId { - // AES-128-CCM - Tls_Ecdhe_Ecdsa_With_Aes_128_Ccm = 0xc0ac, - Tls_Ecdhe_Ecdsa_With_Aes_128_Ccm_8 = 0xc0ae, - - // AES-128-GCM-SHA256 - Tls_Ecdhe_Ecdsa_With_Aes_128_Gcm_Sha256 = 0xc02b, - Tls_Ecdhe_Rsa_With_Aes_128_Gcm_Sha256 = 0xc02f, - - // AES-256-CBC-SHA - Tls_Ecdhe_Ecdsa_With_Aes_256_Cbc_Sha = 0xc00a, - Tls_Ecdhe_Rsa_With_Aes_256_Cbc_Sha = 0xc014, - - Tls_Psk_With_Aes_128_Ccm = 0xc0a4, - Tls_Psk_With_Aes_128_Ccm_8 = 0xc0a8, - Tls_Psk_With_Aes_128_Gcm_Sha256 = 0x00a8, - - Unsupported, -} - -impl fmt::Display for CipherSuiteId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self { - CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Ccm => { - write!(f, "TLS_ECDHE_ECDSA_WITH_AES_128_CCM") - } - CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Ccm_8 => { - write!(f, "TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8") - } - CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Gcm_Sha256 => { - write!(f, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256") - } - CipherSuiteId::Tls_Ecdhe_Rsa_With_Aes_128_Gcm_Sha256 => { - write!(f, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256") - } - CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_256_Cbc_Sha => { - write!(f, "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA") - } - CipherSuiteId::Tls_Ecdhe_Rsa_With_Aes_256_Cbc_Sha => { - write!(f, "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA") - } - CipherSuiteId::Tls_Psk_With_Aes_128_Ccm => write!(f, "TLS_PSK_WITH_AES_128_CCM"), - CipherSuiteId::Tls_Psk_With_Aes_128_Ccm_8 => write!(f, "TLS_PSK_WITH_AES_128_CCM_8"), - CipherSuiteId::Tls_Psk_With_Aes_128_Gcm_Sha256 => { - write!(f, "TLS_PSK_WITH_AES_128_GCM_SHA256") - } - _ => write!(f, "Unsupported CipherSuiteID"), - } - } -} - -impl From for CipherSuiteId { - fn from(val: u16) -> Self { - match val { - // AES-128-CCM - 0xc0ac => CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Ccm, - 0xc0ae => CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Ccm_8, - - // AES-128-GCM-SHA256 - 0xc02b => CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Gcm_Sha256, - 0xc02f => CipherSuiteId::Tls_Ecdhe_Rsa_With_Aes_128_Gcm_Sha256, - - // AES-256-CBC-SHA - 0xc00a => CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_256_Cbc_Sha, - 0xc014 => CipherSuiteId::Tls_Ecdhe_Rsa_With_Aes_256_Cbc_Sha, - - 0xc0a4 => CipherSuiteId::Tls_Psk_With_Aes_128_Ccm, - 0xc0a8 => CipherSuiteId::Tls_Psk_With_Aes_128_Ccm_8, - 0x00a8 => CipherSuiteId::Tls_Psk_With_Aes_128_Gcm_Sha256, - - _ => CipherSuiteId::Unsupported, - } - } -} - -#[derive(Copy, Clone, Debug)] -pub enum CipherSuiteHash { - Sha256, -} - -impl CipherSuiteHash { - pub(crate) fn size(&self) -> usize { - match *self { - CipherSuiteHash::Sha256 => 32, - } - } -} - -pub trait CipherSuite { - fn to_string(&self) -> String; - fn id(&self) -> CipherSuiteId; - fn certificate_type(&self) -> ClientCertificateType; - fn hash_func(&self) -> CipherSuiteHash; - fn is_psk(&self) -> bool; - fn is_initialized(&self) -> bool; - - // Generate the internal encryption state - fn init( - &mut self, - master_secret: &[u8], - client_random: &[u8], - server_random: &[u8], - is_client: bool, - ) -> Result<()>; - - fn encrypt(&self, pkt_rlh: &RecordLayerHeader, raw: &[u8]) -> Result>; - fn decrypt(&self, input: &[u8]) -> Result>; -} - -// Taken from https://www.iana.org/assignments/tls-parameters/tls-parameters.xml -// A cipher_suite is a specific combination of key agreement, cipher and 
MAC -// function. -pub fn cipher_suite_for_id(id: CipherSuiteId) -> Result> { - match id { - CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Ccm => { - Ok(Box::new(new_cipher_suite_tls_ecdhe_ecdsa_with_aes_128_ccm())) - } - CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Ccm_8 => Ok(Box::new( - new_cipher_suite_tls_ecdhe_ecdsa_with_aes_128_ccm8(), - )), - CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Gcm_Sha256 => { - Ok(Box::new(CipherSuiteAes128GcmSha256::new(false))) - } - CipherSuiteId::Tls_Ecdhe_Rsa_With_Aes_128_Gcm_Sha256 => { - Ok(Box::new(CipherSuiteAes128GcmSha256::new(true))) - } - CipherSuiteId::Tls_Ecdhe_Rsa_With_Aes_256_Cbc_Sha => { - Ok(Box::new(CipherSuiteAes256CbcSha::new(true))) - } - CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_256_Cbc_Sha => { - Ok(Box::new(CipherSuiteAes256CbcSha::new(false))) - } - CipherSuiteId::Tls_Psk_With_Aes_128_Ccm => { - Ok(Box::new(new_cipher_suite_tls_psk_with_aes_128_ccm())) - } - CipherSuiteId::Tls_Psk_With_Aes_128_Ccm_8 => { - Ok(Box::new(new_cipher_suite_tls_psk_with_aes_128_ccm8())) - } - CipherSuiteId::Tls_Psk_With_Aes_128_Gcm_Sha256 => { - Ok(Box::::default()) - } - _ => Err(Error::ErrInvalidCipherSuite), - } -} - -// CipherSuites we support in order of preference -pub(crate) fn default_cipher_suites() -> Vec> { - vec![ - Box::new(CipherSuiteAes128GcmSha256::new(false)), - Box::new(CipherSuiteAes256CbcSha::new(false)), - Box::new(CipherSuiteAes128GcmSha256::new(true)), - Box::new(CipherSuiteAes256CbcSha::new(true)), - ] -} - -fn all_cipher_suites() -> Vec> { - vec![ - Box::new(new_cipher_suite_tls_ecdhe_ecdsa_with_aes_128_ccm()), - Box::new(new_cipher_suite_tls_ecdhe_ecdsa_with_aes_128_ccm8()), - Box::new(CipherSuiteAes128GcmSha256::new(false)), - Box::new(CipherSuiteAes128GcmSha256::new(true)), - Box::new(CipherSuiteAes256CbcSha::new(false)), - Box::new(CipherSuiteAes256CbcSha::new(true)), - Box::new(new_cipher_suite_tls_psk_with_aes_128_ccm()), - Box::new(new_cipher_suite_tls_psk_with_aes_128_ccm8()), - Box::::default(), - ] -} - -fn cipher_suites_for_ids(ids: &[CipherSuiteId]) -> Result>> { - let mut cipher_suites = vec![]; - for id in ids { - cipher_suites.push(cipher_suite_for_id(*id)?); - } - Ok(cipher_suites) -} - -pub(crate) fn parse_cipher_suites( - user_selected_suites: &[CipherSuiteId], - exclude_psk: bool, - exclude_non_psk: bool, -) -> Result>> { - let cipher_suites = if !user_selected_suites.is_empty() { - cipher_suites_for_ids(user_selected_suites)? 
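For reference, the deleted registry above maps IANA numeric identifiers onto `CipherSuiteId` values and back, and `cipher_suite_for_id` turns an identifier into a boxed, uninitialized suite. A minimal sketch of that mapping being exercised, assuming the crate-internal items removed in this diff are still in scope somewhere (illustrative only, not part of the change):

```rust
use crate::cipher_suite::{cipher_suite_for_id, CipherSuiteId};
use crate::error::Result;

fn demo() -> Result<()> {
    // 0xc02b is TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 per the From<u16> impl above.
    let id = CipherSuiteId::from(0xc02b_u16);
    assert_eq!(id, CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Gcm_Sha256);
    assert_eq!(id.to_string(), "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256");

    // Unknown code points fall back to CipherSuiteId::Unsupported.
    assert_eq!(CipherSuiteId::from(0xffff_u16), CipherSuiteId::Unsupported);

    // The factory returns an uninitialized suite; init() must run with the master
    // secret and randoms before encrypt()/decrypt() become usable.
    let suite = cipher_suite_for_id(id)?;
    assert_eq!(suite.id(), id);
    assert!(!suite.is_initialized());
    Ok(())
}
```

`parse_cipher_suites` then filters the chosen list so a PSK-only or certificate-only configuration never advertises suites it cannot complete, and returns `ErrNoAvailableCipherSuites` if nothing survives the filter.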
- } else { - default_cipher_suites() - }; - - let filtered_cipher_suites: Vec> = cipher_suites - .into_iter() - .filter(|c| !((exclude_psk && c.is_psk()) || (exclude_non_psk && !c.is_psk()))) - .collect(); - - if filtered_cipher_suites.is_empty() { - Err(Error::ErrNoAvailableCipherSuites) - } else { - Ok(filtered_cipher_suites) - } -} diff --git a/dtls/src/client_certificate_type.rs b/dtls/src/client_certificate_type.rs deleted file mode 100644 index d87aacb6c..000000000 --- a/dtls/src/client_certificate_type.rs +++ /dev/null @@ -1,16 +0,0 @@ -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum ClientCertificateType { - RsaSign = 1, - EcdsaSign = 64, - Unsupported, -} - -impl From for ClientCertificateType { - fn from(val: u8) -> Self { - match val { - 1 => ClientCertificateType::RsaSign, - 64 => ClientCertificateType::EcdsaSign, - _ => ClientCertificateType::Unsupported, - } - } -} diff --git a/dtls/src/compression_methods.rs b/dtls/src/compression_methods.rs deleted file mode 100644 index a03195ec7..000000000 --- a/dtls/src/compression_methods.rs +++ /dev/null @@ -1,60 +0,0 @@ -use std::io::{Read, Write}; - -use byteorder::{ReadBytesExt, WriteBytesExt}; - -use crate::error::Result; - -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum CompressionMethodId { - Null = 0, - Unsupported, -} - -impl From for CompressionMethodId { - fn from(val: u8) -> Self { - match val { - 0 => CompressionMethodId::Null, - _ => CompressionMethodId::Unsupported, - } - } -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct CompressionMethods { - pub ids: Vec, -} - -impl CompressionMethods { - pub fn size(&self) -> usize { - 1 + self.ids.len() - } - - pub fn marshal(&self, writer: &mut W) -> Result<()> { - writer.write_u8(self.ids.len() as u8)?; - - for id in &self.ids { - writer.write_u8(*id as u8)?; - } - - Ok(writer.flush()?) - } - - pub fn unmarshal(reader: &mut R) -> Result { - let compression_methods_count = reader.read_u8()? as usize; - let mut ids = vec![]; - for _ in 0..compression_methods_count { - let id = reader.read_u8()?.into(); - if id != CompressionMethodId::Unsupported { - ids.push(id); - } - } - - Ok(CompressionMethods { ids }) - } -} - -pub fn default_compression_methods() -> CompressionMethods { - CompressionMethods { - ids: vec![CompressionMethodId::Null], - } -} diff --git a/dtls/src/config.rs b/dtls/src/config.rs deleted file mode 100644 index 0e1e23c6f..000000000 --- a/dtls/src/config.rs +++ /dev/null @@ -1,190 +0,0 @@ -use std::sync::Arc; - -use tokio::time::Duration; - -use crate::cipher_suite::*; -use crate::crypto::*; -use crate::error::*; -use crate::extension::extension_use_srtp::SrtpProtectionProfile; -use crate::handshaker::VerifyPeerCertificateFn; -use crate::signature_hash_algorithm::SignatureScheme; - -/// Config is used to configure a DTLS client or server. -/// After a Config is passed to a DTLS function it must not be modified. -#[derive(Clone)] -pub struct Config { - /// certificates contains certificate chain to present to the other side of the connection. - /// Server MUST set this if psk is non-nil - /// client SHOULD sets this so CertificateRequests can be handled if psk is non-nil - pub certificates: Vec, - - /// cipher_suites is a list of supported cipher suites. - /// If cipher_suites is nil, a default list is used - pub cipher_suites: Vec, - - /// signature_schemes contains the signature and hash schemes that the peer requests to verify. 
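The `CompressionMethods` codec deleted just above follows the crate's usual size/marshal/unmarshal pattern: a one-byte count followed by one byte per method id. A sketch of a round trip through plain byte buffers, assuming the crate-internal module is in scope (illustrative only):

```rust
use crate::compression_methods::{default_compression_methods, CompressionMethods};
use crate::error::Result;

fn round_trip() -> Result<()> {
    let methods = default_compression_methods(); // ids: [CompressionMethodId::Null]

    // marshal writes the count byte, then each id as a single byte.
    let mut raw: Vec<u8> = vec![];
    methods.marshal(&mut raw)?;
    assert_eq!(raw, vec![0x01u8, 0x00]);
    assert_eq!(methods.size(), raw.len());

    // unmarshal reads the count and silently drops any unsupported ids.
    let mut reader = raw.as_slice();
    let parsed = CompressionMethods::unmarshal(&mut reader)?;
    assert_eq!(parsed, methods);
    Ok(())
}
```

Only the null method survives parsing; unsupported ids are dropped rather than treated as an error.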
- pub signature_schemes: Vec, - - /// srtp_protection_profiles are the supported protection profiles - /// Clients will send this via use_srtp and assert that the server properly responds - /// Servers will assert that clients send one of these profiles and will respond as needed - pub srtp_protection_profiles: Vec, - - /// client_auth determines the server's policy for - /// TLS Client Authentication. The default is NoClientCert. - pub client_auth: ClientAuthType, - - /// extended_master_secret determines if the "Extended Master Secret" extension - /// should be disabled, requested, or required (default requested). - pub extended_master_secret: ExtendedMasterSecretType, - - /// flight_interval controls how often we send outbound handshake messages - /// defaults to time.Second - pub flight_interval: Duration, - - /// psk sets the pre-shared key used by this DTLS connection - /// If psk is non-nil only psk cipher_suites will be used - pub psk: Option, - pub psk_identity_hint: Option>, - - /// insecure_skip_verify controls whether a client verifies the - /// server's certificate chain and host name. - /// If insecure_skip_verify is true, TLS accepts any certificate - /// presented by the server and any host name in that certificate. - /// In this mode, TLS is susceptible to man-in-the-middle attacks. - /// This should be used only for testing. - pub insecure_skip_verify: bool, - - /// insecure_hashes allows the use of hashing algorithms that are known - /// to be vulnerable. - pub insecure_hashes: bool, - - /// insecure_verification allows the use of verification algorithms that are - /// known to be vulnerable or deprecated - pub insecure_verification: bool, - /// VerifyPeerCertificate, if not nil, is called after normal - /// certificate verification by either a client or server. It - /// receives the certificate provided by the peer and also a flag - /// that tells if normal verification has succeeded. If it returns a - /// non-nil error, the handshake is aborted and that error results. - /// - /// If normal verification fails then the handshake will abort before - /// considering this callback. If normal verification is disabled by - /// setting insecure_skip_verify, or (for a server) when client_auth is - /// RequestClientCert or RequireAnyClientCert, then this callback will - /// be considered but the verifiedChains will always be nil. - pub verify_peer_certificate: Option, - - /// roots_cas defines the set of root certificate authorities - /// that one peer uses when verifying the other peer's certificates. - /// If RootCAs is nil, TLS uses the host's root CA set. - /// Used by Client to verify server's certificate - pub roots_cas: rustls::RootCertStore, - - /// client_cas defines the set of root certificate authorities - /// that servers use if required to verify a client certificate - /// by the policy in client_auth. - /// Used by Server to verify client's certificate - pub client_cas: rustls::RootCertStore, - - /// server_name is used to verify the hostname on the returned - /// certificates unless insecure_skip_verify is given. - pub server_name: String, - - /// mtu is the length at which handshake messages will be fragmented to - /// fit within the maximum transmission unit (default is 1200 bytes) - pub mtu: usize, - - /// replay_protection_window is the size of the replay attack protection window. - /// Duplication of the sequence number is checked in this window size. 
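The field documentation above maps directly onto how a `Config` is built in the tests later in this diff. A minimal sketch for a PSK-only endpoint, using the `Default` impl shown further below; the hint and key bytes are placeholders, not values taken from the change:

```rust
use std::sync::Arc;

use crate::cipher_suite::CipherSuiteId;
use crate::config::Config;
use crate::error::Result;

fn psk_callback(_hint: &[u8]) -> Result<Vec<u8>> {
    // Placeholder pre-shared key; a real deployment would look this up securely.
    Ok(vec![0xAB, 0xC1, 0x23])
}

fn psk_config() -> Config {
    Config {
        // psk set and certificates left empty: only PSK cipher suites are usable.
        psk: Some(Arc::new(psk_callback)),
        psk_identity_hint: Some(b"webrtc-rs".to_vec()),
        cipher_suites: vec![CipherSuiteId::Tls_Psk_With_Aes_128_Ccm_8],
        ..Default::default()
    }
}
```

`validate_config`, which appears a little further down in this file, enforces the invariants implied by these fields: a client-side psk must come with a psk_identity_hint, a server must have either a psk or at least one certificate, psk and certificates are mutually exclusive, and a psk_identity_hint without a psk is rejected.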
- /// Packet with sequence number older than this value compared to the latest - /// accepted packet will be discarded. (default is 64) - pub replay_protection_window: usize, -} - -impl Default for Config { - fn default() -> Self { - Config { - certificates: vec![], - cipher_suites: vec![], - signature_schemes: vec![], - srtp_protection_profiles: vec![], - client_auth: ClientAuthType::default(), - extended_master_secret: ExtendedMasterSecretType::default(), - flight_interval: Duration::default(), - psk: None, - psk_identity_hint: None, - insecure_skip_verify: false, - insecure_hashes: false, - insecure_verification: false, - verify_peer_certificate: None, - roots_cas: rustls::RootCertStore::empty(), - client_cas: rustls::RootCertStore::empty(), - server_name: String::default(), - mtu: 0, - replay_protection_window: 0, - } - } -} - -pub(crate) const DEFAULT_MTU: usize = 1200; // bytes - -// PSKCallback is called once we have the remote's psk_identity_hint. -// If the remote provided none it will be nil -pub(crate) type PskCallback = Arc Result>) + Send + Sync>; - -// ClientAuthType declares the policy the server will follow for -// TLS Client Authentication. -#[derive(Default, Copy, Clone, PartialEq, Eq)] -pub enum ClientAuthType { - #[default] - NoClientCert = 0, - RequestClientCert = 1, - RequireAnyClientCert = 2, - VerifyClientCertIfGiven = 3, - RequireAndVerifyClientCert = 4, -} - -// ExtendedMasterSecretType declares the policy the client and server -// will follow for the Extended Master Secret extension -#[derive(Default, PartialEq, Eq, Copy, Clone)] -pub enum ExtendedMasterSecretType { - #[default] - Request = 0, - Require = 1, - Disable = 2, -} - -pub(crate) fn validate_config(is_client: bool, config: &Config) -> Result<()> { - if is_client && config.psk.is_some() && config.psk_identity_hint.is_none() { - return Err(Error::ErrPskAndIdentityMustBeSetForClient); - } - - if !is_client && config.psk.is_none() && config.certificates.is_empty() { - return Err(Error::ErrServerMustHaveCertificate); - } - - if !config.certificates.is_empty() && config.psk.is_some() { - return Err(Error::ErrPskAndCertificate); - } - - if config.psk_identity_hint.is_some() && config.psk.is_none() { - return Err(Error::ErrIdentityNoPsk); - } - - for cert in &config.certificates { - match cert.private_key.kind { - CryptoPrivateKeyKind::Ed25519(_) => {} - CryptoPrivateKeyKind::Ecdsa256(_) => {} - _ => return Err(Error::ErrInvalidPrivateKey), - } - } - - parse_cipher_suites( - &config.cipher_suites, - config.psk.is_none(), - config.psk.is_some(), - )?; - - Ok(()) -} diff --git a/dtls/src/conn/conn_test.rs b/dtls/src/conn/conn_test.rs deleted file mode 100644 index 926b058f9..000000000 --- a/dtls/src/conn/conn_test.rs +++ /dev/null @@ -1,2458 +0,0 @@ -use std::time::SystemTime; - -use rand::Rng; -use rustls::pki_types::CertificateDer; -use util::conn::conn_pipe::*; -use util::KeyingMaterialExporter; - -use super::*; -use crate::cipher_suite::cipher_suite_aes_128_gcm_sha256::*; -use crate::cipher_suite::*; -use crate::compression_methods::*; -use crate::crypto::*; -use crate::curve::*; -use crate::error::*; -use crate::extension::extension_supported_elliptic_curves::*; -use crate::extension::extension_supported_point_formats::*; -use crate::extension::extension_supported_signature_algorithms::*; -use crate::extension::renegotiation_info::ExtensionRenegotiationInfo; -use crate::extension::*; -use crate::handshake::handshake_message_certificate::*; -use crate::handshake::handshake_message_client_hello::*; -use 
crate::handshake::handshake_message_hello_verify_request::*; -use crate::handshake::handshake_message_server_hello::*; -use crate::handshake::handshake_message_server_hello_done::*; -use crate::handshake::handshake_message_server_key_exchange::*; -use crate::handshake::handshake_random::*; -use crate::signature_hash_algorithm::*; - -const ERR_TEST_PSK_INVALID_IDENTITY: &str = "TestPSK: Server got invalid identity"; -const ERR_PSK_REJECTED: &str = "PSK Rejected"; -const ERR_NOT_EXPECTED_CHAIN: &str = "not expected chain"; -const ERR_EXPECTED_CHAIN: &str = "expected chain"; -const ERR_WRONG_CERT: &str = "wrong cert"; - -async fn build_pipe() -> Result<(DTLSConn, DTLSConn)> { - let (ua, ub) = pipe(); - - pipe_conn(Arc::new(ua), Arc::new(ub)).await -} - -async fn pipe_conn( - ca: Arc, - cb: Arc, -) -> Result<(DTLSConn, DTLSConn)> { - let (c_tx, mut c_rx) = mpsc::channel(1); - - // Setup client - tokio::spawn(async move { - let client = create_test_client( - ca, - Config { - srtp_protection_profiles: vec![SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_80], - ..Default::default() - }, - true, - ) - .await; - - let _ = c_tx.send(client).await; - }); - - // Setup server - let sever = create_test_server( - cb, - Config { - srtp_protection_profiles: vec![SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_80], - ..Default::default() - }, - true, - ) - .await?; - - // Receive client - let client = match c_rx.recv().await.unwrap() { - Ok(client) => client, - Err(err) => return Err(err), - }; - - Ok((client, sever)) -} - -fn psk_callback_client(hint: &[u8]) -> Result> { - trace!( - "Server's hint: {}", - String::from_utf8(hint.to_vec()).unwrap() - ); - Ok(vec![0xAB, 0xC1, 0x23]) -} - -fn psk_callback_server(hint: &[u8]) -> Result> { - trace!( - "Client's hint: {}", - String::from_utf8(hint.to_vec()).unwrap() - ); - Ok(vec![0xAB, 0xC1, 0x23]) -} - -fn psk_callback_hint_fail(_hint: &[u8]) -> Result> { - Err(Error::Other(ERR_PSK_REJECTED.to_owned())) -} - -async fn create_test_client( - ca: Arc, - mut cfg: Config, - generate_certificate: bool, -) -> Result { - if generate_certificate { - let client_cert = Certificate::generate_self_signed(vec!["localhost".to_owned()])?; - cfg.certificates = vec![client_cert]; - } - - cfg.insecure_skip_verify = true; - DTLSConn::new(ca, cfg, true, None).await -} - -async fn create_test_server( - cb: Arc, - mut cfg: Config, - generate_certificate: bool, -) -> Result { - if generate_certificate { - let server_cert = Certificate::generate_self_signed(vec!["localhost".to_owned()])?; - cfg.certificates = vec![server_cert]; - } - - DTLSConn::new(cb, cfg, false, None).await -} - -#[tokio::test] -async fn test_routine_leak_on_close() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, LevelFilter::Trace) - .init();*/ - - let (ca, cb) = build_pipe().await?; - - let buf_a = vec![0xFA; 100]; - let n_a = ca.write(&buf_a, Some(Duration::from_secs(5))).await?; - assert_eq!(n_a, 100); - - let mut buf_b = vec![0; 1024]; - let n_b = cb.read(&mut buf_b, Some(Duration::from_secs(5))).await?; - assert_eq!(n_a, 100); - assert_eq!(&buf_a[..], &buf_b[0..n_b]); - - cb.close().await?; - ca.close().await?; - - { - drop(ca); - drop(cb); - } - - tokio::time::sleep(Duration::from_millis(1)).await; - - Ok(()) -} - -#[tokio::test] -async fn 
test_sequence_number_overflow_on_application_data() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, LevelFilter::Trace) - .init();*/ - - let (ca, cb) = build_pipe().await?; - - { - let mut lsn = ca.state.local_sequence_number.lock().await; - lsn[1] = MAX_SEQUENCE_NUMBER; - } - - let buf_a = vec![0xFA; 100]; - let n_a = ca.write(&buf_a, Some(Duration::from_secs(5))).await?; - assert_eq!(n_a, 100); - - let mut buf_b = vec![0; 1024]; - let n_b = cb.read(&mut buf_b, Some(Duration::from_secs(5))).await?; - assert_eq!(n_a, 100); - assert_eq!(&buf_a[..], &buf_b[0..n_b]); - - let result = ca.write(&buf_a, Some(Duration::from_secs(5))).await; - if let Err(err) = result { - assert_eq!( - err.to_string(), - Error::ErrSequenceNumberOverflow.to_string() - ); - } else { - panic!("Expected error but it is OK"); - } - - cb.close().await?; - - if let Err(err) = ca.close().await { - assert_eq!( - err.to_string(), - Error::ErrSequenceNumberOverflow.to_string() - ); - } else { - panic!("Expected error but it is OK"); - } - - { - drop(ca); - drop(cb); - } - - tokio::time::sleep(Duration::from_millis(1)).await; - - Ok(()) -} - -#[tokio::test] -async fn test_sequence_number_overflow_on_handshake() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, LevelFilter::Trace) - .init();*/ - - let (ca, cb) = build_pipe().await?; - - { - let mut lsn = ca.state.local_sequence_number.lock().await; - lsn[0] = MAX_SEQUENCE_NUMBER + 1; - } - - // Try to send handshake packet. 
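Both overflow tests drive the per-epoch `local_sequence_number` up to the crate's `MAX_SEQUENCE_NUMBER` and expect `Error::ErrSequenceNumberOverflow` instead of a silent wrap. DTLS record-layer sequence numbers are 48 bits wide, so the largest representable value is 2^48 - 1. A standalone sketch of the bound being enforced; the constant here mirrors the one used in the tests and is not the crate's internal writer:

```rust
/// DTLS record-layer sequence numbers are 48 bits (RFC 6347), so writers must
/// fail rather than wrap once this value has been passed.
const MAX_SEQUENCE_NUMBER: u64 = (1u64 << 48) - 1;

#[derive(Debug, PartialEq)]
enum SeqError {
    Overflow, // corresponds to Error::ErrSequenceNumberOverflow in the crate
}

fn bump_sequence_number(seq: &mut u64) -> Result<u64, SeqError> {
    if *seq > MAX_SEQUENCE_NUMBER {
        return Err(SeqError::Overflow);
    }
    let current = *seq;
    *seq += 1;
    Ok(current)
}

#[test]
fn rejects_overflow() {
    assert_eq!(MAX_SEQUENCE_NUMBER, 281_474_976_710_655);

    // A record may still be sent at the maximum value, matching the tests above,
    // but the next attempt must error out.
    let mut seq = MAX_SEQUENCE_NUMBER;
    assert!(bump_sequence_number(&mut seq).is_ok());
    assert_eq!(bump_sequence_number(&mut seq), Err(SeqError::Overflow));
}
```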
- if let Err(err) = ca - .write_packets(vec![Packet { - record: RecordLayer::new( - PROTOCOL_VERSION1_2, - 0, - Content::Handshake(Handshake::new(HandshakeMessage::ClientHello( - HandshakeMessageClientHello { - version: PROTOCOL_VERSION1_2, - random: HandshakeRandom::default(), - cookie: vec![0; 64], - - cipher_suites: vec![CipherSuiteId::Tls_Psk_With_Aes_128_Gcm_Sha256], - compression_methods: default_compression_methods(), - extensions: vec![], - }, - ))), - ), - should_encrypt: false, - reset_local_sequence_number: false, - }]) - .await - { - assert_eq!( - err.to_string(), - Error::ErrSequenceNumberOverflow.to_string() - ); - } else { - panic!("Expected error but it is OK"); - } - - cb.close().await?; - ca.close().await?; - - { - drop(ca); - drop(cb); - } - - tokio::time::sleep(Duration::from_millis(1)).await; - - Ok(()) -} - -#[tokio::test] -async fn test_handshake_with_alert() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, LevelFilter::Trace) - .init();*/ - - let cases = vec![ - ( - "CipherSuiteNoIntersection", - Config { - // Server - cipher_suites: vec![CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Gcm_Sha256], - ..Default::default() - }, - Config { - // Client - cipher_suites: vec![CipherSuiteId::Tls_Ecdhe_Rsa_With_Aes_128_Gcm_Sha256], - ..Default::default() - }, - Error::ErrCipherSuiteNoIntersection, - Error::ErrAlertFatalOrClose, //errClient: &errAlert{&alert{alertLevelFatal, alertInsufficientSecurity}}, - ), - ( - "SignatureSchemesNoIntersection", - Config { - // Server - cipher_suites: vec![CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Gcm_Sha256], - signature_schemes: vec![SignatureScheme::EcdsaWithP256AndSha256], - ..Default::default() - }, - Config { - // Client - cipher_suites: vec![CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Gcm_Sha256], - signature_schemes: vec![SignatureScheme::EcdsaWithP521AndSha512], - ..Default::default() - }, - Error::ErrAlertFatalOrClose, //errServer: &errAlert{&alert{alertLevelFatal, alertInsufficientSecurity}}, - Error::ErrNoAvailableSignatureSchemes, //NoAvailableSignatureSchemes, - ), - ]; - - for (name, config_server, config_client, err_server, err_client) in cases { - let (client_err_tx, mut client_err_rx) = mpsc::channel(1); - - let (ca, cb) = pipe(); - tokio::spawn(async move { - let result = create_test_client(Arc::new(ca), config_client, true).await; - let _ = client_err_tx.send(result).await; - }); - - let result_server = create_test_server(Arc::new(cb), config_server, true).await; - if let Err(err) = result_server { - assert_eq!( - err.to_string(), - err_server.to_string(), - "{name} Server error exp({err_server}) failed({err})" - ); - } else { - panic!("{name} expected error but create_test_server return OK"); - } - - let result_client = client_err_rx.recv().await; - if let Some(result_client) = result_client { - if let Err(err) = result_client { - assert_eq!( - err.to_string(), - err_client.to_string(), - "{name} Client error exp({err_client}) failed({err})" - ); - } else { - panic!("{name} expected error but create_test_client return OK"); - } - } - } - - Ok(()) -} - -#[tokio::test] -async fn test_export_keying_material() -> Result<()> { - let export_label = "EXTRACTOR-dtls_srtp"; - let expected_server_key = vec![0x61, 0x09, 0x9d, 0x7d, 0xcb, 0x08, 0x52, 0x2c, 0xe7, 0x7b]; - let expected_client_key = 
vec![0x87, 0xf0, 0x40, 0x02, 0xf6, 0x1c, 0xf1, 0xfe, 0x8c, 0x77]; - - let (_decrypted_tx, decrypted_rx) = mpsc::channel(1); - let (_handshake_tx, handshake_rx) = mpsc::channel(1); - let (packet_tx, _packet_rx) = mpsc::channel(1); - let (handle_queue_tx, _handle_queue_rx) = mpsc::channel(1); - let (ca, _cb) = pipe(); - - let mut c = DTLSConn { - conn: Arc::new(ca), - state: State { - local_random: HandshakeRandom { - gmt_unix_time: SystemTime::UNIX_EPOCH - .checked_add(Duration::new(500, 0)) - .unwrap(), - ..Default::default() - }, - remote_random: HandshakeRandom { - gmt_unix_time: SystemTime::UNIX_EPOCH - .checked_add(Duration::new(1000, 0)) - .unwrap(), - ..Default::default() - }, - local_sequence_number: Arc::new(Mutex::new(vec![0, 0])), - cipher_suite: Arc::new(Mutex::new(Some(Box::new(CipherSuiteAes128GcmSha256::new( - false, - ))))), - ..Default::default() - }, - cache: HandshakeCache::new(), - decrypted_rx: Mutex::new(decrypted_rx), - handshake_completed_successfully: Arc::new(AtomicBool::new(false)), - connection_closed_by_user: false, - closed: AtomicBool::new(false), - current_flight: Box::new(Flight0 {}) as Box, - flights: None, - cfg: HandshakeConfig::default(), - retransmit: false, - handshake_rx, - - packet_tx: Arc::new(packet_tx), - handle_queue_tx, - handshake_done_tx: None, - - reader_close_tx: Mutex::new(None), - }; - - c.set_local_epoch(0); - let state = c.connection_state().await; - if let Err(err) = state.export_keying_material(export_label, &[], 0).await { - assert!( - err.to_string() - .contains(&Error::ErrHandshakeInProgress.to_string()), - "ExportKeyingMaterial when epoch == 0: expected '{}' actual '{}'", - Error::ErrHandshakeInProgress, - err, - ); - } else { - panic!("expect error but export_keying_material returns OK"); - } - - c.set_local_epoch(1); - let state = c.connection_state().await; - if let Err(err) = state.export_keying_material(export_label, &[0x00], 0).await { - assert!( - err.to_string() - .contains(&Error::ErrContextUnsupported.to_string()), - "ExportKeyingMaterial with context: expected '{}' actual '{}'", - Error::ErrContextUnsupported, - err - ); - } else { - panic!("expect error but export_keying_material returns OK"); - } - - for k in INVALID_KEYING_LABELS.iter() { - let state = c.connection_state().await; - if let Err(err) = state.export_keying_material(k, &[], 0).await { - assert!( - err.to_string() - .contains(&Error::ErrReservedExportKeyingMaterial.to_string()), - "ExportKeyingMaterial reserved label: expected '{}' actual '{}'", - Error::ErrReservedExportKeyingMaterial, - err, - ); - } else { - panic!("expect error but export_keying_material returns OK"); - } - } - - let state = c.connection_state().await; - let keying_material = state.export_keying_material(export_label, &[], 10).await?; - assert_eq!( - &keying_material, &expected_server_key, - "ExportKeyingMaterial client export: expected ({:?}) actual ({:?})", - &expected_server_key, &keying_material, - ); - - c.state.is_client = true; - let state = c.connection_state().await; - let keying_material = state.export_keying_material(export_label, &[], 10).await?; - assert_eq!( - &keying_material, &expected_client_key, - "ExportKeyingMaterial client export: expected ({:?}) actual ({:?})", - &expected_client_key, &keying_material, - ); - - Ok(()) -} - -#[tokio::test] -async fn test_psk() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - 
chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, LevelFilter::Trace) - .init();*/ - - let tests = vec![ - ( - "Server identity specified", - Some("Test Identity".as_bytes().to_vec()), - ), - ("Server identity nil", None), - ]; - - for (name, server_identity) in tests { - let client_identity = "Client Identity".as_bytes(); - let (client_res_tx, mut client_res_rx) = mpsc::channel(1); - - let (ca, cb) = pipe(); - tokio::spawn(async move { - let conf = Config { - psk: Some(Arc::new(psk_callback_client)), - psk_identity_hint: Some(client_identity.to_vec()), - cipher_suites: vec![CipherSuiteId::Tls_Psk_With_Aes_128_Ccm_8], - ..Default::default() - }; - - let result = create_test_client(Arc::new(ca), conf, false).await; - let _ = client_res_tx.send(result).await; - }); - - let config = Config { - psk: Some(Arc::new(psk_callback_server)), - psk_identity_hint: server_identity, - cipher_suites: vec![CipherSuiteId::Tls_Psk_With_Aes_128_Ccm_8], - ..Default::default() - }; - - let server = create_test_server(Arc::new(cb), config, false).await?; - - let actual_psk_identity_hint = &server.connection_state().await.identity_hint; - assert_eq!( - actual_psk_identity_hint, client_identity, - "TestPSK: Server ClientPSKIdentity Mismatch '{name}': expected({client_identity:?}) actual({actual_psk_identity_hint:?})", - ); - - if let Some(result) = client_res_rx.recv().await { - if let Ok(client) = result { - client.close().await?; - } else { - panic!("{name}: Expected create_test_client successfully, but got error",); - } - } - - let _ = server.close().await; - } - - Ok(()) -} - -#[tokio::test] -async fn test_psk_hint_fail() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, LevelFilter::Trace) - .init();*/ - - let (client_res_tx, mut client_res_rx) = mpsc::channel(1); - - let (ca, cb) = pipe(); - tokio::spawn(async move { - let conf = Config { - psk: Some(Arc::new(psk_callback_hint_fail)), - psk_identity_hint: Some(vec![]), - cipher_suites: vec![CipherSuiteId::Tls_Psk_With_Aes_128_Ccm_8], - ..Default::default() - }; - - let result = create_test_client(Arc::new(ca), conf, false).await; - let _ = client_res_tx.send(result).await; - }); - - let config = Config { - psk: Some(Arc::new(psk_callback_hint_fail)), - psk_identity_hint: Some(vec![]), - cipher_suites: vec![CipherSuiteId::Tls_Psk_With_Aes_128_Ccm_8], - ..Default::default() - }; - - if let Err(server_err) = create_test_server(Arc::new(cb), config, false).await { - assert_eq!( - server_err.to_string(), - Error::ErrAlertFatalOrClose.to_string(), - "TestPSK: Server error exp({}) failed({})", - Error::ErrAlertFatalOrClose, - server_err, - ); - } else { - panic!("Expected server error, but got OK"); - } - - let result = client_res_rx.recv().await; - if let Some(client) = result { - if let Err(client_err) = client { - assert!( - client_err.to_string().contains(ERR_PSK_REJECTED), - "TestPSK: Client error exp({ERR_PSK_REJECTED}) failed({client_err})", - ); - } else { - panic!("Expected client error, but got OK"); - } - } - - Ok(()) -} - -#[tokio::test] -async fn test_client_timeout() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - 
chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, LevelFilter::Trace) - .init();*/ - - let (client_res_tx, mut client_res_rx) = mpsc::channel(1); - - let (ca, _cb) = pipe(); - tokio::spawn(async move { - let conf = Config::default(); - let result = tokio::time::timeout( - Duration::from_millis(100), - create_test_client(Arc::new(ca), conf, true), - ) - .await; - let _ = client_res_tx.send(result).await; - }); - - // no server! - let result = client_res_rx.recv().await; - if let Some(client_timeout_result) = result { - assert!(client_timeout_result.is_err(), "Expected Error but got Ok"); - } - - Ok(()) -} - -//use std::io::Write; - -#[tokio::test] -async fn test_srtp_configuration() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, LevelFilter::Trace) - .init();*/ - - #[allow(clippy::type_complexity)] - let tests: Vec<( - &str, - Vec, - Vec, - SrtpProtectionProfile, - Option, - Option, - )> = vec![ - ( - "No SRTP in use", - vec![], - vec![], - SrtpProtectionProfile::Unsupported, - None, - None, - ), - ( - "SRTP both ends", - vec![SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_80], - vec![SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_80], - SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_80, - None, - None, - ), - ( - "SRTP client only", - vec![SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_80], - vec![], - SrtpProtectionProfile::Unsupported, - Some(Error::ErrAlertFatalOrClose), - Some(Error::ErrServerNoMatchingSrtpProfile), - ), - ( - "SRTP server only", - vec![], - vec![SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_80], - SrtpProtectionProfile::Unsupported, - None, - None, - ), - ( - "Multiple Suites", - vec![ - SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_80, - SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_32, - ], - vec![ - SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_80, - SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_32, - ], - SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_80, - None, - None, - ), - ( - "Multiple Suites, Client Chooses", - vec![ - SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_80, - SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_32, - ], - vec![ - SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_32, - SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_80, - ], - SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_80, - None, - None, - ), - ]; - - for (name, client_srtp, server_srtp, expected_profile, want_client_err, want_server_err) in - tests - { - let (client_res_tx, mut client_res_rx) = mpsc::channel(1); - let (ca, cb) = pipe(); - tokio::spawn(async move { - let conf = Config { - srtp_protection_profiles: client_srtp, - ..Default::default() - }; - - let result = create_test_client(Arc::new(ca), conf, true).await; - let _ = client_res_tx.send(result).await; - }); - - let config = Config { - srtp_protection_profiles: server_srtp, - ..Default::default() - }; - - let result = create_test_server(Arc::new(cb), config, true).await; - if let Some(expected_err) = want_server_err { - if let Err(err) = result { - assert_eq!( - err.to_string(), - expected_err.to_string(), - "{name} TestPSK: Server error exp({expected_err}) failed({err})", - ); - } else { - panic!("{name} expected error, but got ok"); - } - } else { - match result { - Ok(server) => { - let actual_server_srtp = 
server.selected_srtpprotection_profile(); - assert_eq!(actual_server_srtp, expected_profile, - "test_srtp_configuration: Server SRTPProtectionProfile Mismatch '{name}': expected({expected_profile:?}) actual({actual_server_srtp:?})"); - } - Err(err) => { - panic!("{name} expected no error: {err}"); - } - }; - } - - let client_result = client_res_rx.recv().await; - if let Some(result) = client_result { - if let Some(expected_err) = want_client_err { - if let Err(err) = result { - assert_eq!( - err.to_string(), - expected_err.to_string(), - "TestPSK: Client error exp({expected_err}) failed({err})", - ); - } else { - panic!("{name} expected error, but got ok"); - } - } else if let Ok(client) = result { - let actual_client_srtp = client.selected_srtpprotection_profile(); - assert_eq!(actual_client_srtp, expected_profile, - "test_srtp_configuration: Client SRTPProtectionProfile Mismatch '{name}': expected({expected_profile:?}) actual({actual_client_srtp:?})"); - } else { - panic!("{name} expected no error"); - } - } else { - panic!("{name} expected client, but got none"); - } - } - - Ok(()) -} - -#[tokio::test] -async fn test_client_certificate() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, LevelFilter::Trace) - .init();*/ - - let server_name = "localhost".to_owned(); - - let srv_cert = Certificate::generate_self_signed(vec!["localhost".to_owned()])?; - let mut srv_ca_pool = rustls::RootCertStore::empty(); - srv_ca_pool - .add(srv_cert.certificate[0].to_owned()) - .map_err(|_err| Error::Other("add srv_cert error".to_owned()))?; - - let cert = Certificate::generate_self_signed(vec!["localhost".to_owned()])?; - let mut ca_pool = rustls::RootCertStore::empty(); - ca_pool - .add(cert.certificate[0].to_owned()) - .map_err(|_err| Error::Other("add cert error".to_owned()))?; - - let tests = vec![ - ( - "NoClientCert", - Config { - roots_cas: srv_ca_pool.clone(), - server_name: server_name.clone(), - ..Default::default() - }, - Config { - certificates: vec![srv_cert.clone()], - client_auth: ClientAuthType::NoClientCert, - ..Default::default() - }, - false, - ), - ( - "NoClientCert_cert", - Config { - roots_cas: srv_ca_pool.clone(), - server_name: server_name.clone(), - certificates: vec![cert.clone()], - ..Default::default() - }, - Config { - certificates: vec![srv_cert.clone()], - client_auth: ClientAuthType::RequireAnyClientCert, - ..Default::default() - }, - false, - ), - ( - "RequestClientCert_cert", - Config { - roots_cas: srv_ca_pool.clone(), - server_name: server_name.clone(), - certificates: vec![cert.clone()], - ..Default::default() - }, - Config { - certificates: vec![srv_cert.clone()], - client_auth: ClientAuthType::RequestClientCert, - ..Default::default() - }, - false, - ), - ( - "RequestClientCert_no_cert", - Config { - roots_cas: srv_ca_pool.clone(), - server_name: server_name.clone(), - ..Default::default() - }, - Config { - certificates: vec![srv_cert.clone()], - client_auth: ClientAuthType::RequestClientCert, - ..Default::default() - }, - false, - ), - ( - "RequireAnyClientCert", - Config { - roots_cas: srv_ca_pool.clone(), - server_name: server_name.clone(), - certificates: vec![cert.clone()], - ..Default::default() - }, - Config { - certificates: vec![srv_cert.clone()], - client_auth: ClientAuthType::RequireAnyClientCert, - 
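The surrounding table walks every `ClientAuthType` policy end to end. As a condensed illustration, here is a sketch of the client/server `Config` pair for the mutual-authentication (`RequireAndVerifyClientCert`) case, built with the same helpers used in these tests; the host name and error text are illustrative:

```rust
use crate::config::{ClientAuthType, Config};
use crate::crypto::Certificate;
use crate::error::{Error, Result};

fn mutual_auth_configs() -> Result<(Config, Config)> {
    let client_cert = Certificate::generate_self_signed(vec!["localhost".to_owned()])?;
    let server_cert = Certificate::generate_self_signed(vec!["localhost".to_owned()])?;

    // The server only accepts clients whose certificate chains to this store.
    let mut client_cas = rustls::RootCertStore::empty();
    client_cas
        .add(client_cert.certificate[0].clone())
        .map_err(|_| Error::Other("add client cert".to_owned()))?;

    // A client would normally verify the server via roots_cas + server_name;
    // insecure_skip_verify is used here only because the cert is self-signed.
    let client_cfg = Config {
        certificates: vec![client_cert],
        insecure_skip_verify: true,
        ..Default::default()
    };
    let server_cfg = Config {
        certificates: vec![server_cert],
        client_auth: ClientAuthType::RequireAndVerifyClientCert,
        client_cas,
        ..Default::default()
    };
    Ok((client_cfg, server_cfg))
}
```

These two configs would then be handed to `DTLSConn::new` on each side, exactly as the test loop above does.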
..Default::default() - }, - false, - ), - ( - "RequireAnyClientCert_error", - Config { - roots_cas: srv_ca_pool.clone(), - server_name: server_name.clone(), - ..Default::default() - }, - Config { - certificates: vec![srv_cert.clone()], - client_auth: ClientAuthType::RequireAnyClientCert, - ..Default::default() - }, - true, - ), - ( - "VerifyClientCertIfGiven_no_cert", - Config { - roots_cas: srv_ca_pool.clone(), - server_name: server_name.clone(), - ..Default::default() - }, - Config { - certificates: vec![srv_cert.clone()], - client_auth: ClientAuthType::VerifyClientCertIfGiven, - client_cas: ca_pool.clone(), - ..Default::default() - }, - false, - ), - ( - "VerifyClientCertIfGiven_cert", - Config { - roots_cas: srv_ca_pool.clone(), - server_name: server_name.clone(), - certificates: vec![cert.clone()], - ..Default::default() - }, - Config { - certificates: vec![srv_cert.clone()], - client_auth: ClientAuthType::VerifyClientCertIfGiven, - client_cas: ca_pool.clone(), - ..Default::default() - }, - false, - ), - ( - "VerifyClientCertIfGiven_error", - Config { - roots_cas: srv_ca_pool.clone(), - server_name: server_name.clone(), - certificates: vec![cert.clone()], - ..Default::default() - }, - Config { - certificates: vec![srv_cert.clone()], - client_auth: ClientAuthType::VerifyClientCertIfGiven, - ..Default::default() - }, - true, - ), - ( - "RequireAndVerifyClientCert", - Config { - roots_cas: srv_ca_pool.clone(), - server_name: server_name.clone(), - certificates: vec![cert.clone()], - ..Default::default() - }, - Config { - certificates: vec![srv_cert.clone()], - client_auth: ClientAuthType::RequireAndVerifyClientCert, - client_cas: ca_pool.clone(), - ..Default::default() - }, - false, - ), - ]; - - for (name, client_cfg, server_cfg, want_err) in tests { - let (client_res_tx, mut client_res_rx) = mpsc::channel(1); - let (ca, cb) = pipe(); - let client_cfg_clone = client_cfg.clone(); - tokio::spawn(async move { - let result = DTLSConn::new(Arc::new(ca), client_cfg_clone, true, None).await; - let _ = client_res_tx.send(result).await; - }); - - let result = DTLSConn::new(Arc::new(cb), server_cfg.clone(), false, None).await; - let client_result = client_res_rx.recv().await; - - if want_err { - if result.is_err() { - continue; - } - panic!("{name} Error expected"); - } - - assert!( - result.is_ok(), - "{} Server failed({:?})", - name, - result.err().unwrap() - ); - assert!(client_result.is_some(), "{name}, expected client conn"); - - let res = client_result.unwrap(); - assert!( - res.is_ok(), - "{} Client failed({:?})", - name, - res.err().unwrap() - ); - - let server = result.unwrap(); - let client = res.unwrap(); - - let actual_client_cert = &server.connection_state().await.peer_certificates; - if server_cfg.client_auth == ClientAuthType::RequireAnyClientCert - || server_cfg.client_auth == ClientAuthType::RequireAndVerifyClientCert - { - assert!( - !actual_client_cert.is_empty(), - "{name} Client did not provide a certificate", - ); - //if actual_client_cert.len() != len(tt.clientCfg.Certificates[0].Certificate) || !bytes.Equal(tt.clientCfg.Certificates[0].Certificate[0], actual_client_cert[0]) { - assert_eq!( - actual_client_cert[0], - client_cfg.certificates[0].certificate[0].as_ref(), - "{name} Client certificate was not communicated correctly", - ); - } - - if server_cfg.client_auth == ClientAuthType::NoClientCert { - assert!( - actual_client_cert.is_empty(), - "{name} Client certificate wasn't expected", - ); - } - - let actual_server_cert = 
&client.connection_state().await.peer_certificates; - assert!( - !actual_server_cert.is_empty(), - "{name} Server did not provide a certificate", - ); - - /*if len(actual_server_cert) != len(tt.serverCfg.Certificates[0].Certificate) - || !bytes.Equal( - tt.serverCfg.Certificates[0].Certificate[0], - actual_server_cert[0], - )*/ - assert_eq!( - actual_server_cert[0].len(), - server_cfg.certificates[0].certificate[0].as_ref().len(), - "{name} Server certificate was not communicated correctly", - ); - assert_eq!( - actual_server_cert[0], - server_cfg.certificates[0].certificate[0].as_ref(), - "{name} Server certificate was not communicated correctly", - ); - } - - Ok(()) -} - -#[tokio::test] -async fn test_extended_master_secret() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, LevelFilter::Trace) - .init();*/ - - let tests = vec![ - ( - "Request_Request_ExtendedMasterSecret", - Config { - extended_master_secret: ExtendedMasterSecretType::Request, - ..Default::default() - }, - Config { - extended_master_secret: ExtendedMasterSecretType::Request, - ..Default::default() - }, - None, - None, - ), - ( - "Request_Require_ExtendedMasterSecret", - Config { - extended_master_secret: ExtendedMasterSecretType::Request, - ..Default::default() - }, - Config { - extended_master_secret: ExtendedMasterSecretType::Require, - ..Default::default() - }, - None, - None, - ), - ( - "Request_Disable_ExtendedMasterSecret", - Config { - extended_master_secret: ExtendedMasterSecretType::Request, - ..Default::default() - }, - Config { - extended_master_secret: ExtendedMasterSecretType::Disable, - ..Default::default() - }, - None, - None, - ), - ( - "Require_Request_ExtendedMasterSecret", - Config { - extended_master_secret: ExtendedMasterSecretType::Require, - ..Default::default() - }, - Config { - extended_master_secret: ExtendedMasterSecretType::Request, - ..Default::default() - }, - None, - None, - ), - ( - "Require_Require_ExtendedMasterSecret", - Config { - extended_master_secret: ExtendedMasterSecretType::Require, - ..Default::default() - }, - Config { - extended_master_secret: ExtendedMasterSecretType::Require, - ..Default::default() - }, - None, - None, - ), - ( - "Require_Disable_ExtendedMasterSecret", - Config { - extended_master_secret: ExtendedMasterSecretType::Require, - ..Default::default() - }, - Config { - extended_master_secret: ExtendedMasterSecretType::Disable, - ..Default::default() - }, - Some(Error::ErrClientRequiredButNoServerEms), - Some(Error::ErrAlertFatalOrClose), - ), - ( - "Disable_Request_ExtendedMasterSecret", - Config { - extended_master_secret: ExtendedMasterSecretType::Disable, - ..Default::default() - }, - Config { - extended_master_secret: ExtendedMasterSecretType::Request, - ..Default::default() - }, - None, - None, - ), - ( - "Disable_Require_ExtendedMasterSecret", - Config { - extended_master_secret: ExtendedMasterSecretType::Disable, - ..Default::default() - }, - Config { - extended_master_secret: ExtendedMasterSecretType::Require, - ..Default::default() - }, - Some(Error::ErrAlertFatalOrClose), - Some(Error::ErrServerRequiredButNoClientEms), - ), - ( - "Disable_Disable_ExtendedMasterSecret", - Config { - extended_master_secret: ExtendedMasterSecretType::Disable, - ..Default::default() - }, - Config { - extended_master_secret: 
ExtendedMasterSecretType::Disable, - ..Default::default() - }, - None, - None, - ), - ]; - - for (name, client_cfg, server_cfg, expected_client_err, expected_server_err) in tests { - let (client_res_tx, mut client_res_rx) = mpsc::channel(1); - let (ca, cb) = pipe(); - let client_cfg_clone = client_cfg.clone(); - tokio::spawn(async move { - let result = create_test_client(Arc::new(ca), client_cfg_clone, true).await; - let _ = client_res_tx.send(result).await; - }); - - let result = create_test_server(Arc::new(cb), server_cfg.clone(), true).await; - let client_result = client_res_rx.recv().await; - assert!(client_result.is_some(), "{name}, expected client conn"); - let res = client_result.unwrap(); - - if let Some(client_err) = expected_client_err { - if let Err(err) = res { - assert_eq!( - err.to_string(), - client_err.to_string(), - "Client error expected: \"{client_err}\" but got \"{err}\"", - ); - } else { - panic!("{name} expected err, but got ok"); - } - } else { - assert!(res.is_ok(), "{name} expected ok, but got err"); - } - - if let Some(server_err) = expected_server_err { - if let Err(err) = result { - assert_eq!( - err.to_string(), - server_err.to_string(), - "Server error expected: \"{server_err}\" but got \"{err}\"", - ); - } else { - panic!("{name} expected err, but got ok"); - } - } else { - assert!(result.is_ok(), "{name} expected ok, but got err"); - } - } - - Ok(()) -} - -fn fn_not_expected_chain(_cert: &[Vec], chain: &[CertificateDer<'static>]) -> Result<()> { - if !chain.is_empty() { - return Err(Error::Other(ERR_NOT_EXPECTED_CHAIN.to_owned())); - } - Ok(()) -} - -fn fn_expected_chain(_cert: &[Vec], chain: &[CertificateDer<'static>]) -> Result<()> { - if chain.is_empty() { - return Err(Error::Other(ERR_EXPECTED_CHAIN.to_owned())); - } - Ok(()) -} - -fn fn_wrong_cert(_cert: &[Vec], _chain: &[CertificateDer<'static>]) -> Result<()> { - Err(Error::Other(ERR_WRONG_CERT.to_owned())) -} - -#[tokio::test] -async fn test_server_certificate() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, LevelFilter::Trace) - .init();*/ - - let server_name = "localhost".to_owned(); - let cert = Certificate::generate_self_signed(vec![server_name.clone()])?; - let mut ca_pool = rustls::RootCertStore::empty(); - ca_pool - .add(cert.certificate[0].clone()) - .map_err(|_err| Error::Other("add cert error".to_owned()))?; - - let tests = vec![ - ( - "no_ca", - Config { - server_name: server_name.clone(), - ..Default::default() - }, - Config { - certificates: vec![cert.clone()], - client_auth: ClientAuthType::NoClientCert, - ..Default::default() - }, - true, - ), - ( - "good_ca", - Config { - roots_cas: ca_pool.clone(), - server_name: server_name.clone(), - ..Default::default() - }, - Config { - certificates: vec![cert.clone()], - client_auth: ClientAuthType::NoClientCert, - ..Default::default() - }, - false, - ), - ( - "no_ca_skip_verify", - Config { - insecure_skip_verify: true, - server_name: server_name.clone(), - ..Default::default() - }, - Config { - certificates: vec![cert.clone()], - client_auth: ClientAuthType::NoClientCert, - ..Default::default() - }, - false, - ), - ( - "good_ca_skip_verify_custom_verify_peer", - Config { - roots_cas: ca_pool.clone(), - server_name: server_name.clone(), - certificates: vec![cert.clone()], - ..Default::default() - }, - 
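`fn_expected_chain`, `fn_not_expected_chain`, and `fn_wrong_cert` above show the shape of the `verify_peer_certificate` hook: it receives the peer's raw DER certificates plus the chains that passed normal verification. A sketch of a callback that pins the peer's leaf certificate; the pinned bytes are a placeholder, not real certificate data:

```rust
use rustls::pki_types::CertificateDer;

use crate::error::{Error, Result};

/// Hypothetical pinned leaf certificate in DER form; in practice this would be
/// loaded from configuration rather than hard-coded.
const PINNED_LEAF_DER: &[u8] = &[0x30, 0x82];

/// Matches the callback signature used in the tests above.
fn pin_leaf_certificate(
    certs: &[Vec<u8>],
    _verified_chains: &[CertificateDer<'static>],
) -> Result<()> {
    match certs.first() {
        Some(leaf) if leaf.as_slice() == PINNED_LEAF_DER => Ok(()),
        Some(_) => Err(Error::Other("peer certificate does not match pin".to_owned())),
        None => Err(Error::Other("peer presented no certificate".to_owned())),
    }
}

// Installed on a Config via: verify_peer_certificate: Some(Arc::new(pin_leaf_certificate))
```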
Config { - certificates: vec![cert.clone()], - client_auth: ClientAuthType::RequireAnyClientCert, - verify_peer_certificate: Some(Arc::new(fn_not_expected_chain)), - ..Default::default() - }, - false, - ), - ( - "good_ca_verify_custom_verify_peer", - Config { - roots_cas: ca_pool.clone(), - server_name: server_name.clone(), - certificates: vec![cert.clone()], - ..Default::default() - }, - Config { - certificates: vec![cert.clone()], - client_auth: ClientAuthType::RequireAndVerifyClientCert, - client_cas: ca_pool.clone(), - verify_peer_certificate: Some(Arc::new(fn_expected_chain)), - ..Default::default() - }, - false, - ), - ( - "good_ca_custom_verify_peer", - Config { - roots_cas: ca_pool.clone(), - server_name: server_name.clone(), - verify_peer_certificate: Some(Arc::new(fn_wrong_cert)), - ..Default::default() - }, - Config { - certificates: vec![cert.clone()], - client_auth: ClientAuthType::NoClientCert, - ..Default::default() - }, - true, - ), - ( - "server_name", - Config { - roots_cas: ca_pool.clone(), - server_name: server_name.clone(), - ..Default::default() - }, - Config { - certificates: vec![cert.clone()], - client_auth: ClientAuthType::NoClientCert, - ..Default::default() - }, - false, - ), - ( - "server_name_error", - Config { - roots_cas: ca_pool.clone(), - server_name: "barfoo".to_owned(), - ..Default::default() - }, - Config { - certificates: vec![cert.clone()], - client_auth: ClientAuthType::NoClientCert, - ..Default::default() - }, - true, - ), - ]; - - for (name, client_cfg, server_cfg, want_err) in tests { - let (res_tx, mut res_rx) = mpsc::channel(1); - let (ca, cb) = pipe(); - - tokio::spawn(async move { - let result = DTLSConn::new(Arc::new(cb), server_cfg, false, None).await; - let _ = res_tx.send(result).await; - }); - - let cli_result = DTLSConn::new(Arc::new(ca), client_cfg, true, None).await; - - if !want_err && cli_result.is_err() { - panic!("{}: Client failed({})", name, cli_result.err().unwrap()); - } - if want_err && cli_result.is_ok() { - panic!("{name}: Error expected"); - } - - let _ = res_rx.recv().await; - } - Ok(()) -} - -#[tokio::test] -async fn test_cipher_suite_configuration() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, LevelFilter::Trace) - .init();*/ - - let tests = vec![ - ( - "No CipherSuites specified", - vec![], - vec![], - None, - None, - None, - ), - ( - "Invalid CipherSuite", - vec![CipherSuiteId::Unsupported], - vec![CipherSuiteId::Unsupported], - Some(Error::ErrInvalidCipherSuite), - Some(Error::ErrInvalidCipherSuite), - None, - ), - ( - "Valid CipherSuites specified", - vec![CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Gcm_Sha256], - vec![CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Gcm_Sha256], - None, - None, - Some(CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Gcm_Sha256), - ), - ( - "CipherSuites mismatch", - vec![CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Gcm_Sha256], - vec![CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_256_Cbc_Sha], - Some(Error::ErrAlertFatalOrClose), - Some(Error::ErrCipherSuiteNoIntersection), - None, - ), - ( - "Valid CipherSuites CCM specified", - vec![CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Ccm], - vec![CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Ccm], - None, - None, - Some(CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Ccm), - ), - ( - "Valid CipherSuites CCM-8 specified", 
- vec![CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Ccm_8], - vec![CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Ccm_8], - None, - None, - Some(CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Ccm_8), - ), - ( - "Server supports subset of client suites", - vec![ - CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Gcm_Sha256, - CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_256_Cbc_Sha, - ], - vec![CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_256_Cbc_Sha], - None, - None, - Some(CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_256_Cbc_Sha), - ), - ]; - - for ( - name, - client_cipher_suites, - server_cipher_suites, - want_client_error, - want_server_error, - want_selected_cipher_suite, - ) in tests - { - let (client_res_tx, mut client_res_rx) = mpsc::channel(1); - let (ca, cb) = pipe(); - tokio::spawn(async move { - let conf = Config { - cipher_suites: client_cipher_suites, - ..Default::default() - }; - - let result = create_test_client(Arc::new(ca), conf, true).await; - let _ = client_res_tx.send(result).await; - }); - - let config = Config { - cipher_suites: server_cipher_suites, - ..Default::default() - }; - - let result = create_test_server(Arc::new(cb), config, true).await; - if let Some(expected_err) = want_server_error { - if let Err(err) = result { - assert_eq!( - err.to_string(), - expected_err.to_string(), - "{name} test_cipher_suite_configuration: Server error exp({expected_err}) failed({err})", - ); - } else { - panic!("{name} expected error, but got ok"); - } - } else { - assert!(result.is_ok(), "{name} expected ok, but got error") - } - - let client_result = client_res_rx.recv().await; - if let Some(result) = client_result { - if let Some(expected_err) = want_client_error { - if let Err(err) = result { - assert_eq!( - err.to_string(), - expected_err.to_string(), - "{name} test_cipher_suite_configuration: Client error exp({expected_err}) failed({err})", - ); - } else { - panic!("{name} expected error, but got ok"); - } - } else { - assert!(result.is_ok(), "{name} expected ok, but got error"); - let client = result.unwrap(); - if let Some(want_cs) = want_selected_cipher_suite { - let cipher_suite = client.state.cipher_suite.lock().await; - assert!(cipher_suite.is_some(), "{name} expected some, but got none"); - if let Some(cs) = &*cipher_suite { - assert_eq!(cs.id(), want_cs, - "test_cipher_suite_configuration: Server Selected Bad Cipher Suite '{}': expected({}) actual({})", - name, want_cs, cs.id()); - } - } - } - } else { - panic!("{name} expected Some, but got None"); - } - } - - Ok(()) -} - -fn psk_callback(_b: &[u8]) -> Result> { - Ok(vec![0x00, 0x01, 0x02]) -} - -#[tokio::test] -async fn test_psk_configuration() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, LevelFilter::Trace) - .init();*/ - - let tests = vec![ - ( - "PSK specified", - false, - false, - true, //Some(psk_callback), - true, //Some(psk_callback), - Some(vec![0x00]), - Some(vec![0x00]), - Some(Error::ErrNoAvailableCipherSuites), - Some(Error::ErrNoAvailableCipherSuites), - ), - ( - "PSK and certificate specified", - true, - true, - true, //Some(psk_callback), - true, //Some(psk_callback), - Some(vec![0x00]), - Some(vec![0x00]), - Some(Error::ErrPskAndCertificate), - Some(Error::ErrPskAndCertificate), - ), - ( - "PSK and no identity specified", - false, - false, - true, //Some(psk_callback), - true, 
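
The mismatch and subset cases in test_cipher_suite_configuration come down to an intersection rule: the peers settle on a suite both sides offered and abort when none is shared. A rough sketch of that selection rule, using suite names as plain strings (hypothetical select_cipher_suite helper, not the crate's actual negotiation code or preference order):

// Illustrative only: picks the first server-supported suite the client offered.
fn select_cipher_suite<'a>(
    server_suites: &'a [&'a str],
    client_offered: &[&str],
) -> Result<&'a str, &'static str> {
    server_suites
        .iter()
        .copied()
        .find(|s| client_offered.contains(s))
        .ok_or("no cipher suite intersection")
}

fn main() {
    let client = [
        "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
        "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA",
    ];
    // "Server supports subset of client suites": the shared suite is selected.
    let server = ["TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA"];
    assert_eq!(
        select_cipher_suite(&server, &client),
        Ok("TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA")
    );
    // "CipherSuites mismatch": no intersection, so the handshake must fail.
    let gcm_only = ["TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"];
    let cbc_only = ["TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA"];
    assert!(select_cipher_suite(&gcm_only, &cbc_only).is_err());
}
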
//Some(psk_callback), - None, - None, - Some(Error::ErrPskAndIdentityMustBeSetForClient), - Some(Error::ErrNoAvailableCipherSuites), - ), - ( - "No PSK and identity specified", - false, - false, - false, - false, - Some(vec![0x00]), - Some(vec![0x00]), - Some(Error::ErrIdentityNoPsk), - Some(Error::ErrServerMustHaveCertificate), - ), - ]; - - for ( - name, - client_has_certificate, - server_has_certificate, - client_psk, - server_psk, - client_psk_identity, - server_psk_identity, - want_client_error, - want_server_error, - ) in tests - { - let (client_res_tx, mut client_res_rx) = mpsc::channel(1); - let (ca, cb) = pipe(); - tokio::spawn(async move { - let conf = Config { - psk: if client_psk { - Some(Arc::new(psk_callback)) - } else { - None - }, - psk_identity_hint: client_psk_identity, - ..Default::default() - }; - - let result = create_test_client(Arc::new(ca), conf, client_has_certificate).await; - let _ = client_res_tx.send(result).await; - }); - - let config = Config { - psk: if server_psk { - Some(Arc::new(psk_callback)) - } else { - None - }, - psk_identity_hint: server_psk_identity, - ..Default::default() - }; - - let result = create_test_server(Arc::new(cb), config, server_has_certificate).await; - if let Some(expected_err) = want_server_error { - if let Err(err) = result { - assert_eq!( - err.to_string(), - expected_err.to_string(), - "{name} test_psk_configuration: Server error exp({expected_err}) failed({err})", - ); - } else { - panic!("{name} expected error, but got ok"); - } - } else { - assert!(result.is_ok(), "{name} expected ok, but got error") - } - - let client_result = client_res_rx.recv().await; - if let Some(result) = client_result { - if let Some(expected_err) = want_client_error { - if let Err(err) = result { - assert_eq!( - err.to_string(), - expected_err.to_string(), - "{name} test_psk_configuration: Client error exp({expected_err}) failed({err})", - ); - } else { - panic!("{name} expected error, but got ok"); - } - } else { - assert!(result.is_ok(), "{name} expected ok, but got error"); - } - } else { - panic!("{name} expected Some, but got None"); - } - } - - Ok(()) -} - -#[tokio::test] -async fn test_server_timeout() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, LevelFilter::Trace) - .init();*/ - - let mut cookie = vec![0u8; 20]; - rand::thread_rng().fill(cookie.as_mut_slice()); - - let random_bytes = [0u8; RANDOM_BYTES_LENGTH]; - let gmt_unix_time = SystemTime::UNIX_EPOCH - .checked_add(Duration::new(500, 0)) - .unwrap(); - let random = HandshakeRandom { - gmt_unix_time, - random_bytes, - }; - - let cipher_suites = vec![ - CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Gcm_Sha256, //&cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256{}, - CipherSuiteId::Tls_Ecdhe_Rsa_With_Aes_128_Gcm_Sha256, //&cipherSuiteTLSEcdheRsaWithAes128GcmSha256{}, - ]; - - let extensions = vec![ - Extension::SupportedSignatureAlgorithms(ExtensionSupportedSignatureAlgorithms { - signature_hash_algorithms: vec![ - SignatureHashAlgorithm { - hash: HashAlgorithm::Sha256, - signature: SignatureAlgorithm::Ecdsa, - }, - SignatureHashAlgorithm { - hash: HashAlgorithm::Sha384, - signature: SignatureAlgorithm::Ecdsa, - }, - SignatureHashAlgorithm { - hash: HashAlgorithm::Sha512, - signature: SignatureAlgorithm::Ecdsa, - }, - SignatureHashAlgorithm { - hash: 
HashAlgorithm::Sha256, - signature: SignatureAlgorithm::Rsa, - }, - SignatureHashAlgorithm { - hash: HashAlgorithm::Sha384, - signature: SignatureAlgorithm::Rsa, - }, - SignatureHashAlgorithm { - hash: HashAlgorithm::Sha512, - signature: SignatureAlgorithm::Rsa, - }, - ], - }), - Extension::SupportedEllipticCurves(ExtensionSupportedEllipticCurves { - elliptic_curves: vec![NamedCurve::X25519, NamedCurve::P256, NamedCurve::P384], - }), - Extension::SupportedPointFormats(ExtensionSupportedPointFormats { - point_formats: vec![ELLIPTIC_CURVE_POINT_FORMAT_UNCOMPRESSED], - }), - ]; - - let record = RecordLayer::new( - PROTOCOL_VERSION1_2, - 0, - Content::Handshake(Handshake::new(HandshakeMessage::ClientHello( - HandshakeMessageClientHello { - version: PROTOCOL_VERSION1_2, - cookie, - random, - cipher_suites, - compression_methods: default_compression_methods(), - extensions, - }, - ))), - ); - - let mut packet = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(packet.as_mut()); - record.marshal(&mut writer)?; - } - - use util::Conn; - let (ca, cb) = pipe(); - - // Client reader - let (ca_read_chan_tx, mut ca_read_chan_rx) = mpsc::channel(1000); - - let ca_rx = Arc::new(ca); - let ca_tx = Arc::clone(&ca_rx); - - tokio::spawn(async move { - let mut data = vec![0; 8192]; - loop { - if let Ok(n) = ca_rx.recv(&mut data).await { - let result = ca_read_chan_tx.send(data[..n].to_vec()).await; - if result.is_ok() { - return; - } - } else { - return; - } - } - }); - - // Start sending ClientHello packets until server responds with first packet - tokio::spawn(async move { - loop { - let timer = tokio::time::sleep(Duration::from_millis(10)); - tokio::pin!(timer); - - tokio::select! { - _ = timer.as_mut() => { - let result = ca_tx.send(&packet).await; - if result.is_err() { - return; - } - } - _ = ca_read_chan_rx.recv() => return, - } - } - }); - - let config = Config { - cipher_suites: vec![CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Gcm_Sha256], - flight_interval: Duration::from_millis(100), - ..Default::default() - }; - - let result = tokio::time::timeout( - Duration::from_millis(50), - create_test_server(Arc::new(cb), config, true), - ) - .await; - assert!(result.is_err(), "Expected Error but got Ok"); - - // Wait a little longer to ensure no additional messages have been sent by the server - //tokio::time::sleep(Duration::from_millis(300)).await; - - /*tokio::select! 
{ - case msg := <-caReadChan: - t.Fatalf("Expected no additional messages from server, got: %+v", msg) - default: - }*/ - - Ok(()) -} - -#[tokio::test] -async fn test_protocol_version_validation() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, LevelFilter::Trace) - .init();*/ - - let mut cookie = vec![0; 20]; - rand::thread_rng().fill(cookie.as_mut_slice()); - - let random_bytes = [0u8; RANDOM_BYTES_LENGTH]; - let gmt_unix_time = SystemTime::UNIX_EPOCH - .checked_add(Duration::new(500, 0)) - .unwrap(); - let random = HandshakeRandom { - gmt_unix_time, - random_bytes, - }; - - let local_keypair = NamedCurve::X25519.generate_keypair()?; - - //|"Server"| - { - let server_cases = vec![ - ( - "ClientHelloVersion", - vec![RecordLayer::new( - PROTOCOL_VERSION1_2, - 0, - Content::Handshake(Handshake::new(HandshakeMessage::ClientHello( - HandshakeMessageClientHello { - version: ProtocolVersion { - major: 0xfe, - minor: 0xff, - }, // try to downgrade - cookie: cookie.clone(), - random: random.clone(), - cipher_suites: vec![ - CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Gcm_Sha256, - ], - compression_methods: default_compression_methods(), - extensions: vec![], - }, - ))), - )], - ), - ( - "SecondsClientHelloVersion", - vec![ - RecordLayer::new( - PROTOCOL_VERSION1_2, - 0, - Content::Handshake(Handshake::new(HandshakeMessage::ClientHello( - HandshakeMessageClientHello { - version: PROTOCOL_VERSION1_2, - cookie: cookie.clone(), - random: random.clone(), - cipher_suites: vec![ - CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Gcm_Sha256, - ], - compression_methods: default_compression_methods(), - extensions: vec![], - }, - ))), - ), - { - let mut handshake = Handshake::new(HandshakeMessage::ClientHello( - HandshakeMessageClientHello { - version: ProtocolVersion { - major: 0xfe, - minor: 0xff, - }, // try to downgrade - cookie: cookie.clone(), - random: random.clone(), - cipher_suites: vec![ - CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Gcm_Sha256, - ], - compression_methods: default_compression_methods(), - extensions: vec![], - }, - )); - handshake.handshake_header.message_sequence = 1; - let mut record_layer = - RecordLayer::new(PROTOCOL_VERSION1_2, 0, Content::Handshake(handshake)); - record_layer.record_layer_header.sequence_number = 1; - - record_layer - }, - ], - ), - ]; - - use util::Conn; - for (name, records) in server_cases { - let (ca, cb) = pipe(); - - tokio::spawn(async move { - let config = Config { - cipher_suites: vec![CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Gcm_Sha256], - flight_interval: Duration::from_millis(100), - ..Default::default() - }; - let timeout_result = tokio::time::timeout( - Duration::from_millis(1000), - create_test_server(Arc::new(cb), config, true), - ) - .await; - match timeout_result { - Ok(result) => { - if let Err(err) = result { - assert_eq!( - err.to_string(), - Error::ErrUnsupportedProtocolVersion.to_string(), - "{} Client error exp({}) failed({})", - name, - Error::ErrUnsupportedProtocolVersion, - err, - ); - } else { - panic!("{name} expected error, but got ok"); - } - } - Err(err) => { - panic!("server timeout {err}"); - } - }; - }); - - tokio::time::sleep(Duration::from_millis(50)).await; - - let mut resp = vec![0; 1024]; - let mut n = 0; - for record in records { - let mut packet = vec![]; - { - let mut writer = 
BufWriter::<&mut Vec>::new(packet.as_mut()); - record.marshal(&mut writer)?; - } - - let _ = ca.send(&packet).await; - n = ca.recv(&mut resp).await?; - } - - let mut reader = BufReader::new(&resp[..n]); - let h = RecordLayerHeader::unmarshal(&mut reader)?; - assert_eq!( - h.content_type, - ContentType::Alert, - "Peer must return alert to unsupported protocol version" - ); - } - } - - //"Client" - { - let client_cases = vec![( - "ServerHelloVersion", - vec![ - RecordLayer::new( - PROTOCOL_VERSION1_2, - 0, - Content::Handshake(Handshake::new(HandshakeMessage::HelloVerifyRequest( - HandshakeMessageHelloVerifyRequest { - version: PROTOCOL_VERSION1_2, - cookie: cookie.clone(), - }, - ))), - ), - { - let mut handshake = Handshake::new(HandshakeMessage::ServerHello( - HandshakeMessageServerHello { - version: ProtocolVersion { - major: 0xfe, - minor: 0xff, - }, // try to downgrade - random: random.clone(), - cipher_suite: CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Gcm_Sha256, - compression_method: default_compression_methods().ids[0], - extensions: vec![], - }, - )); - handshake.handshake_header.message_sequence = 1; - let mut record = - RecordLayer::new(PROTOCOL_VERSION1_2, 0, Content::Handshake(handshake)); - record.record_layer_header.sequence_number = 1; - record - }, - { - let mut handshake = Handshake::new(HandshakeMessage::Certificate( - HandshakeMessageCertificate { - certificate: vec![], - }, - )); - handshake.handshake_header.message_sequence = 2; - let mut record = - RecordLayer::new(PROTOCOL_VERSION1_2, 0, Content::Handshake(handshake)); - record.record_layer_header.sequence_number = 2; - record - }, - { - let mut handshake = Handshake::new(HandshakeMessage::ServerKeyExchange( - HandshakeMessageServerKeyExchange { - identity_hint: vec![], - elliptic_curve_type: EllipticCurveType::NamedCurve, - named_curve: NamedCurve::X25519, - public_key: local_keypair.public_key.clone(), - algorithm: SignatureHashAlgorithm { - hash: HashAlgorithm::Sha256, - signature: SignatureAlgorithm::Ecdsa, - }, - signature: vec![0; 64], - }, - )); - handshake.handshake_header.message_sequence = 3; - let mut record = - RecordLayer::new(PROTOCOL_VERSION1_2, 0, Content::Handshake(handshake)); - record.record_layer_header.sequence_number = 3; - record - }, - { - let mut handshake = Handshake::new(HandshakeMessage::ServerHelloDone( - HandshakeMessageServerHelloDone {}, - )); - handshake.handshake_header.message_sequence = 4; - let mut record = - RecordLayer::new(PROTOCOL_VERSION1_2, 0, Content::Handshake(handshake)); - record.record_layer_header.sequence_number = 4; - record - }, - ], - )]; - - use util::Conn; - for (name, records) in client_cases { - let (ca, cb) = pipe(); - - tokio::spawn(async move { - let config = Config { - cipher_suites: vec![CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Gcm_Sha256], - flight_interval: Duration::from_millis(100), - ..Default::default() - }; - let timeout_result = tokio::time::timeout( - Duration::from_millis(1000), - create_test_client(Arc::new(cb), config, true), - ) - .await; - match timeout_result { - Ok(result) => { - if let Err(err) = result { - assert_eq!( - err.to_string(), - Error::ErrUnsupportedProtocolVersion.to_string(), - "{} Server error exp({}) failed({})", - name, - Error::ErrUnsupportedProtocolVersion, - err, - ); - } else { - panic!("{name} expected error, but got ok"); - } - } - Err(err) => { - panic!("server timeout {err}"); - } - }; - }); - - tokio::time::sleep(Duration::from_millis(50)).await; - - let mut resp = vec![0; 1024]; - for record in records { - let 
_ = ca.recv(&mut resp).await?; - - let mut packet = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(packet.as_mut()); - record.marshal(&mut writer)?; - } - let _ = ca.send(&packet).await; - } - - let n = ca.recv(&mut resp).await?; - - let mut reader = BufReader::new(&resp[..n]); - let h = RecordLayerHeader::unmarshal(&mut reader)?; - - assert_eq!( - h.content_type, - ContentType::Alert, - "Peer must return alert to unsupported protocol version" - ); - } - } - - Ok(()) -} - -#[tokio::test] -async fn test_multiple_hello_verify_request() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, LevelFilter::Trace) - .init();*/ - - let mut cookies = vec![ - // first clientHello contains an empty cookie - vec![], - ]; - - let mut packets = vec![]; - for i in 0..2 { - let mut cookie = vec![0; 20]; - rand::thread_rng().fill(cookie.as_mut_slice()); - cookies.push(cookie.clone()); - - let mut handshake = Handshake::new(HandshakeMessage::HelloVerifyRequest( - HandshakeMessageHelloVerifyRequest { - version: PROTOCOL_VERSION1_2, - cookie, - }, - )); - handshake.handshake_header.message_sequence = i as u16; - - let mut record = RecordLayer::new(PROTOCOL_VERSION1_2, 0, Content::Handshake(handshake)); - record.record_layer_header.sequence_number = i as u64; - - let mut packet = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(packet.as_mut()); - record.marshal(&mut writer)?; - } - - packets.push(packet); - } - - let (ca, cb) = pipe(); - - tokio::spawn(async move { - let conf = Config::default(); - let _ = tokio::time::timeout( - Duration::from_millis(100), - create_test_client(Arc::new(ca), conf, true), - ) - .await; - }); - - for i in 0..cookies.len() { - let cookie = &cookies[i]; - trace!("cookie {}: {:?}", i, cookie); - - // read client hello - let mut resp = vec![0; 1024]; - let n = cb.recv(&mut resp).await?; - let mut reader = BufReader::new(&resp[..n]); - let record = RecordLayer::unmarshal(&mut reader)?; - match record.content { - Content::Handshake(h) => match h.handshake_message { - HandshakeMessage::ClientHello(client_hello) => { - assert_eq!( - &client_hello.cookie, cookie, - "Wrong cookie {}, expected: {:?}, got: {:?}", - i, &client_hello.cookie, cookie - ); - } - _ => panic!("unexpected handshake message"), - }, - _ => panic!("unexpected content"), - }; - - if packets.len() <= i { - break; - } - // write hello verify request - cb.send(&packets[i]).await?; - } - - Ok(()) -} - -async fn send_client_hello( - cookie: Vec, - ca: &Arc, - sequence_number: u64, - send_renegotiation_info: bool, -) -> Result<()> { - let mut extensions = vec![]; - if send_renegotiation_info { - extensions.push(Extension::RenegotiationInfo(ExtensionRenegotiationInfo { - renegotiated_connection: 0, - })); - } - - let mut h = Handshake::new(HandshakeMessage::ClientHello(HandshakeMessageClientHello { - version: PROTOCOL_VERSION1_2, - random: HandshakeRandom::default(), - cookie, - - cipher_suites: vec![CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Gcm_Sha256], - compression_methods: default_compression_methods(), - extensions, - })); - h.handshake_header.message_sequence = sequence_number as u16; - - let mut record = RecordLayer::new(PROTOCOL_VERSION1_2, 0, Content::Handshake(h)); - record.record_layer_header.sequence_number = sequence_number; - - let mut packet = vec![]; - { - 
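
send_client_hello and the hello-verify test above drive the DTLS cookie exchange by hand: the first ClientHello carries an empty cookie, the server answers with a HelloVerifyRequest containing one, and the client repeats the ClientHello echoing it. A toy model of the server's side of that exchange (hypothetical types, not the crate's real messages):

// Toy model of stateless cookie verification; real handling lives in the flights.
struct ClientHello {
    cookie: Vec<u8>,
}

fn server_response(hello: &ClientHello, expected_cookie: &[u8]) -> &'static str {
    if hello.cookie.is_empty() {
        "HelloVerifyRequest" // ask the client to echo a cookie first
    } else if hello.cookie == expected_cookie {
        "ServerHello" // cookie echoed correctly, continue the handshake
    } else {
        "Alert"
    }
}

fn main() {
    let cookie = vec![0xab; 20];
    let first = ClientHello { cookie: vec![] };
    let second = ClientHello { cookie: cookie.clone() };
    assert_eq!(server_response(&first, &cookie), "HelloVerifyRequest");
    assert_eq!(server_response(&second, &cookie), "ServerHello");
}
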
let mut writer = BufWriter::<&mut Vec>::new(packet.as_mut()); - record.marshal(&mut writer)?; - } - - ca.send(&packet).await?; - - Ok(()) -} - -// Assert that a DTLS Server always responds with RenegotiationInfo if -// a ClientHello contained that extension or not -#[cfg(not(target_os = "windows"))] // this times out in CI on windows. -#[tokio::test] -async fn test_renegotiation_info() -> Result<()> { - let mut resp = vec![0u8; 1024]; - - let tests = vec![ - ("Include RenegotiationInfo", true), - ("No RenegotiationInfo", false), - ]; - - for (name, send_renegotiation_info) in tests { - let (ca, cb) = pipe(); - - tokio::spawn(async move { - let conf = Config::default(); - let _ = tokio::time::timeout( - Duration::from_millis(100), - create_test_server(Arc::new(cb), conf, true), - ) - .await; - }); - - tokio::time::sleep(Duration::from_millis(5)).await; - - let ca: Arc = Arc::new(ca); - send_client_hello(vec![], &ca, 0, send_renegotiation_info).await?; - - let n = ca.recv(&mut resp).await?; - let mut reader = BufReader::new(&resp[..n]); - let record = RecordLayer::unmarshal(&mut reader)?; - - let hello_verify_request = match record.content { - Content::Handshake(h) => match h.handshake_message { - HandshakeMessage::HelloVerifyRequest(hvr) => hvr, - _ => { - panic!("unexpected handshake message"); - } - }, - _ => { - panic!("unexpected content"); - } - }; - - send_client_hello( - hello_verify_request.cookie.clone(), - &ca, - 1, - send_renegotiation_info, - ) - .await?; - let n = ca.recv(&mut resp).await?; - let messages = unpack_datagram(&resp[..n])?; - - let mut reader = BufReader::new(&messages[0][..]); - let record = RecordLayer::unmarshal(&mut reader)?; - - let server_hello = match record.content { - Content::Handshake(h) => match h.handshake_message { - HandshakeMessage::ServerHello(sh) => sh, - _ => { - panic!("unexpected handshake message"); - } - }, - _ => { - panic!("unexpected content"); - } - }; - - let got_negotiation_info = server_hello - .extensions - .iter() - .any(|v| matches!(v, Extension::RenegotiationInfo(_))); - - assert!( - got_negotiation_info, - "{name}: Received ServerHello without RenegotiationInfo" - ); - - ca.close().await?; - } - - Ok(()) -} diff --git a/dtls/src/conn/mod.rs b/dtls/src/conn/mod.rs deleted file mode 100644 index f160443c2..000000000 --- a/dtls/src/conn/mod.rs +++ /dev/null @@ -1,1215 +0,0 @@ -#[cfg(test)] -mod conn_test; - -use std::io::{BufReader, BufWriter}; -use std::marker::{Send, Sync}; -use std::net::SocketAddr; -use std::sync::atomic::Ordering; -use std::sync::Arc; - -use async_trait::async_trait; -use log::*; -use portable_atomic::{AtomicBool, AtomicU16}; -use tokio::sync::{mpsc, Mutex}; -use tokio::time::Duration; -use util::replay_detector::*; -use util::Conn; - -use crate::alert::*; -use crate::application_data::*; -use crate::cipher_suite::*; -use crate::config::*; -use crate::content::*; -use crate::curve::named_curve::NamedCurve; -use crate::error::*; -use crate::extension::extension_use_srtp::*; -use crate::flight::flight0::*; -use crate::flight::flight1::*; -use crate::flight::flight5::*; -use crate::flight::flight6::*; -use crate::flight::*; -use crate::fragment_buffer::*; -use crate::handshake::handshake_cache::*; -use crate::handshake::handshake_header::HandshakeHeader; -use crate::handshake::*; -use crate::handshaker::*; -use crate::record_layer::record_layer_header::*; -use crate::record_layer::*; -use crate::signature_hash_algorithm::parse_signature_schemes; -use crate::state::*; - -pub(crate) const INITIAL_TICKER_INTERVAL: 
Duration = Duration::from_secs(1); -pub(crate) const COOKIE_LENGTH: usize = 20; -pub(crate) const DEFAULT_NAMED_CURVE: NamedCurve = NamedCurve::X25519; -pub(crate) const INBOUND_BUFFER_SIZE: usize = 8192; -// Default replay protection window is specified by RFC 6347 Section 4.1.2.6 -pub(crate) const DEFAULT_REPLAY_PROTECTION_WINDOW: usize = 64; - -pub static INVALID_KEYING_LABELS: &[&str] = &[ - "client finished", - "server finished", - "master secret", - "key expansion", -]; - -type PacketSendRequest = (Vec, Option>>); - -struct ConnReaderContext { - is_client: bool, - replay_protection_window: usize, - replay_detector: Vec>, - decrypted_tx: mpsc::Sender>>, - encrypted_packets: Vec>, - fragment_buffer: FragmentBuffer, - cache: HandshakeCache, - cipher_suite: Arc>>>, - remote_epoch: Arc, - handshake_tx: mpsc::Sender>, - handshake_done_rx: mpsc::Receiver<()>, - packet_tx: Arc>, -} - -// Conn represents a DTLS connection -pub struct DTLSConn { - conn: Arc, - pub(crate) cache: HandshakeCache, // caching of handshake messages for verifyData generation - decrypted_rx: Mutex>>>, // Decrypted Application Data or error, pull by calling `Read` - pub(crate) state: State, // Internal state - - handshake_completed_successfully: Arc, - connection_closed_by_user: bool, - // closeLock sync.Mutex - closed: AtomicBool, // *closer.Closer - //handshakeLoopsFinished sync.WaitGroup - - //readDeadline :deadline.Deadline, - //writeDeadline :deadline.Deadline, - - //log logging.LeveledLogger - /* - reading chan struct{} - handshakeRecv chan chan struct{} - cancelHandshaker func() - cancelHandshakeReader func() - */ - pub(crate) current_flight: Box, - pub(crate) flights: Option>, - pub(crate) cfg: HandshakeConfig, - pub(crate) retransmit: bool, - pub(crate) handshake_rx: mpsc::Receiver>, - - pub(crate) packet_tx: Arc>, - pub(crate) handle_queue_tx: mpsc::Sender>, - pub(crate) handshake_done_tx: Option>, - - reader_close_tx: Mutex>>, -} - -type UtilResult = std::result::Result; - -#[async_trait] -impl Conn for DTLSConn { - async fn connect(&self, _addr: SocketAddr) -> UtilResult<()> { - Err(util::Error::Other("Not applicable".to_owned())) - } - async fn recv(&self, buf: &mut [u8]) -> UtilResult { - self.read(buf, None).await.map_err(util::Error::from_std) - } - async fn recv_from(&self, buf: &mut [u8]) -> UtilResult<(usize, SocketAddr)> { - if let Some(raddr) = self.conn.remote_addr() { - let n = self.read(buf, None).await.map_err(util::Error::from_std)?; - Ok((n, raddr)) - } else { - Err(util::Error::Other( - "No remote address is provided by underlying Conn".to_owned(), - )) - } - } - async fn send(&self, buf: &[u8]) -> UtilResult { - self.write(buf, None).await.map_err(util::Error::from_std) - } - async fn send_to(&self, _buf: &[u8], _target: SocketAddr) -> UtilResult { - Err(util::Error::Other("Not applicable".to_owned())) - } - fn local_addr(&self) -> UtilResult { - self.conn.local_addr() - } - fn remote_addr(&self) -> Option { - self.conn.remote_addr() - } - async fn close(&self) -> UtilResult<()> { - self.close().await.map_err(util::Error::from_std) - } - - fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) { - self - } -} - -impl DTLSConn { - pub async fn new( - conn: Arc, - mut config: Config, - is_client: bool, - initial_state: Option, - ) -> Result { - validate_config(is_client, &config)?; - - let local_cipher_suites: Vec = parse_cipher_suites( - &config.cipher_suites, - config.psk.is_none(), - config.psk.is_some(), - )? 
- .iter() - .map(|cs| cs.id()) - .collect(); - - let sigs: Vec = config.signature_schemes.iter().map(|x| *x as u16).collect(); - let local_signature_schemes = parse_signature_schemes(&sigs, config.insecure_hashes)?; - - let retransmit_interval = if config.flight_interval != Duration::from_secs(0) { - config.flight_interval - } else { - INITIAL_TICKER_INTERVAL - }; - - /* - loggerFactory := config.LoggerFactory - if loggerFactory == nil { - loggerFactory = logging.NewDefaultLoggerFactory() - } - - logger := loggerFactory.NewLogger("dtls") - */ - let maximum_transmission_unit = if config.mtu == 0 { - DEFAULT_MTU - } else { - config.mtu - }; - - let replay_protection_window = if config.replay_protection_window == 0 { - DEFAULT_REPLAY_PROTECTION_WINDOW - } else { - config.replay_protection_window - }; - - let mut server_name = config.server_name.clone(); - - // Use host from conn address when server_name is not provided - if is_client && server_name.is_empty() { - if let Some(remote_addr) = conn.remote_addr() { - server_name = remote_addr.ip().to_string(); - } else { - warn!("conn.remote_addr is empty, please set explicitly server_name in Config! Use default \"localhost\" as server_name now"); - "localhost".clone_into(&mut server_name); - } - } - - let cfg = HandshakeConfig { - local_psk_callback: config.psk.take(), - local_psk_identity_hint: config.psk_identity_hint.take(), - local_cipher_suites, - local_signature_schemes, - extended_master_secret: config.extended_master_secret, - local_srtp_protection_profiles: config.srtp_protection_profiles.clone(), - server_name, - client_auth: config.client_auth, - local_certificates: config.certificates.clone(), - insecure_skip_verify: config.insecure_skip_verify, - insecure_verification: config.insecure_verification, - verify_peer_certificate: config.verify_peer_certificate.take(), - client_cert_verifier: if config.client_auth as u8 - >= ClientAuthType::VerifyClientCertIfGiven as u8 - { - Some( - rustls::server::WebPkiClientVerifier::builder(Arc::new(config.client_cas)) - .allow_unauthenticated() - .build() - .unwrap_or( - rustls::server::WebPkiClientVerifier::builder(Arc::new( - gen_self_signed_root_cert(), - )) - .allow_unauthenticated() - .build() - .unwrap(), - ), - ) - } else { - None - }, - server_cert_verifier: rustls::client::WebPkiServerVerifier::builder(Arc::new( - config.roots_cas, - )) - .build() - .unwrap_or( - rustls::client::WebPkiServerVerifier::builder( - Arc::new(gen_self_signed_root_cert()), - ) - .build() - .unwrap(), - ), - retransmit_interval, - //log: logger, - initial_epoch: 0, - ..Default::default() - }; - - let (state, flight, initial_fsm_state) = if let Some(state) = initial_state { - let flight = if is_client { - Box::new(Flight5 {}) as Box - } else { - Box::new(Flight6 {}) as Box - }; - - (state, flight, HandshakeState::Finished) - } else { - let flight = if is_client { - Box::new(Flight1 {}) as Box - } else { - Box::new(Flight0 {}) as Box - }; - - ( - State { - is_client, - ..Default::default() - }, - flight, - HandshakeState::Preparing, - ) - }; - - let (decrypted_tx, decrypted_rx) = mpsc::channel(1); - let (handshake_tx, handshake_rx) = mpsc::channel(1); - let (handshake_done_tx, handshake_done_rx) = mpsc::channel(1); - let (packet_tx, mut packet_rx) = mpsc::channel(1); - let (handle_queue_tx, mut handle_queue_rx) = mpsc::channel(1); - let (reader_close_tx, mut reader_close_rx) = mpsc::channel(1); - - let packet_tx = Arc::new(packet_tx); - let packet_tx2 = Arc::clone(&packet_tx); - let next_conn_rx = Arc::clone(&conn); - 
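
From here the constructor hands the same underlying transport to two background tasks, one draining packet_rx for outgoing records and one feeding read_and_buffer; each task simply holds its own Arc clone of the connection. A stripped-down sketch of that ownership pattern (hypothetical Transport type, plain tokio; assumes the tokio rt and macros features):

use std::sync::Arc;
use tokio::sync::mpsc;

struct Transport;

impl Transport {
    async fn send(&self, pkt: &[u8]) {
        println!("sent {} bytes", pkt.len());
    }
}

#[tokio::main]
async fn main() {
    let conn = Arc::new(Transport);
    let conn_tx = Arc::clone(&conn); // writer task's handle
    let _conn_rx = Arc::clone(&conn); // a reader task would hold this one

    let (packet_tx, mut packet_rx) = mpsc::channel::<Vec<u8>>(1);

    let writer = tokio::spawn(async move {
        // Outgoing loop: ends once every sender has been dropped.
        while let Some(pkt) = packet_rx.recv().await {
            conn_tx.send(&pkt).await;
        }
    });

    packet_tx.send(vec![0u8; 3]).await.unwrap();
    drop(packet_tx);
    writer.await.unwrap();
}
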
let next_conn_tx = Arc::clone(&conn); - let cache = HandshakeCache::new(); - let mut cache1 = cache.clone(); - let cache2 = cache.clone(); - let handshake_completed_successfully = Arc::new(AtomicBool::new(false)); - let handshake_completed_successfully2 = Arc::clone(&handshake_completed_successfully); - - let mut c = DTLSConn { - conn: Arc::clone(&conn), - cache, - decrypted_rx: Mutex::new(decrypted_rx), - state, - handshake_completed_successfully, - connection_closed_by_user: false, - closed: AtomicBool::new(false), - - current_flight: flight, - flights: None, - cfg, - retransmit: false, - handshake_rx, - packet_tx, - handle_queue_tx, - handshake_done_tx: Some(handshake_done_tx), - reader_close_tx: Mutex::new(Some(reader_close_tx)), - }; - - let cipher_suite1 = Arc::clone(&c.state.cipher_suite); - let sequence_number = Arc::clone(&c.state.local_sequence_number); - - tokio::spawn(async move { - loop { - let rx = packet_rx.recv().await; - if let Some(r) = rx { - let (pkt, result_tx) = r; - - let result = DTLSConn::handle_outgoing_packets( - &next_conn_tx, - pkt, - &mut cache1, - is_client, - &sequence_number, - &cipher_suite1, - maximum_transmission_unit, - ) - .await; - - if let Some(tx) = result_tx { - let _ = tx.send(result).await; - } - } else { - trace!("{}: handle_outgoing_packets exit", srv_cli_str(is_client)); - break; - } - } - }); - - let local_epoch = Arc::clone(&c.state.local_epoch); - let remote_epoch = Arc::clone(&c.state.remote_epoch); - let cipher_suite2 = Arc::clone(&c.state.cipher_suite); - - tokio::spawn(async move { - let mut buf = vec![0u8; INBOUND_BUFFER_SIZE]; - let mut ctx = ConnReaderContext { - is_client, - replay_protection_window, - replay_detector: vec![], - decrypted_tx, - encrypted_packets: vec![], - fragment_buffer: FragmentBuffer::new(), - cache: cache2, - cipher_suite: cipher_suite2, - remote_epoch, - handshake_tx, - handshake_done_rx, - packet_tx: packet_tx2, - }; - - //trace!("before enter read_and_buffer: {}] ", srv_cli_str(is_client)); - loop { - tokio::select! { - _ = reader_close_rx.recv() => { - trace!( - "{}: read_and_buffer exit", - srv_cli_str(ctx.is_client), - ); - break; - } - result = DTLSConn::read_and_buffer( - &mut ctx, - &next_conn_rx, - &mut handle_queue_rx, - &mut buf, - &local_epoch, - &handshake_completed_successfully2, - ) => { - if let Err(err) = result { - trace!( - "{}: read_and_buffer return err: {}", - srv_cli_str(is_client), - err - ); - if Error::ErrAlertFatalOrClose == err { - trace!( - "{}: read_and_buffer exit with {}", - srv_cli_str(ctx.is_client), - err - ); - - break; - } - } - } - } - } - }); - - // Do handshake - c.handshake(initial_fsm_state).await?; - - trace!("Handshake Completed"); - - Ok(c) - } - - // Read reads data from the connection. - pub async fn read(&self, p: &mut [u8], duration: Option) -> Result { - if !self.is_handshake_completed_successfully() { - return Err(Error::ErrHandshakeInProgress); - } - - let rx = { - let mut decrypted_rx = self.decrypted_rx.lock().await; - if let Some(d) = duration { - let timer = tokio::time::sleep(d); - tokio::pin!(timer); - - tokio::select! 
{ - r = decrypted_rx.recv() => r, - _ = timer.as_mut() => return Err(Error::ErrDeadlineExceeded), - } - } else { - decrypted_rx.recv().await - } - }; - - if let Some(out) = rx { - match out { - Ok(val) => { - let n = val.len(); - if p.len() < n { - return Err(Error::ErrBufferTooSmall); - } - p[..n].copy_from_slice(&val); - Ok(n) - } - Err(err) => Err(err), - } - } else { - Err(Error::ErrAlertFatalOrClose) - } - } - - // Write writes len(p) bytes from p to the DTLS connection - pub async fn write(&self, p: &[u8], duration: Option) -> Result { - if self.is_connection_closed() { - return Err(Error::ErrConnClosed); - } - - if !self.is_handshake_completed_successfully() { - return Err(Error::ErrHandshakeInProgress); - } - - let pkts = vec![Packet { - record: RecordLayer::new( - PROTOCOL_VERSION1_2, - self.get_local_epoch(), - Content::ApplicationData(ApplicationData { data: p.to_vec() }), - ), - should_encrypt: true, - reset_local_sequence_number: false, - }]; - - if let Some(d) = duration { - let timer = tokio::time::sleep(d); - tokio::pin!(timer); - - tokio::select! { - result = self.write_packets(pkts) => { - result?; - } - _ = timer.as_mut() => return Err(Error::ErrDeadlineExceeded), - } - } else { - self.write_packets(pkts).await?; - } - - Ok(p.len()) - } - - // Close closes the connection. - pub async fn close(&self) -> Result<()> { - if !self.closed.load(Ordering::SeqCst) { - self.closed.store(true, Ordering::SeqCst); - - // Discard error from notify() to return non-error on the first user call of Close() - // even if the underlying connection is already closed. - self.notify(AlertLevel::Warning, AlertDescription::CloseNotify) - .await?; - - { - let mut reader_close_tx = self.reader_close_tx.lock().await; - reader_close_tx.take(); - } - self.conn.close().await?; - } - - Ok(()) - } - - /// connection_state returns basic DTLS details about the connection. - /// Note that this replaced the `Export` function of v1. 
- pub async fn connection_state(&self) -> State { - self.state.clone().await - } - - /// selected_srtpprotection_profile returns the selected SRTPProtectionProfile - pub fn selected_srtpprotection_profile(&self) -> SrtpProtectionProfile { - self.state.srtp_protection_profile - } - - pub(crate) async fn notify(&self, level: AlertLevel, desc: AlertDescription) -> Result<()> { - self.write_packets(vec![Packet { - record: RecordLayer::new( - PROTOCOL_VERSION1_2, - self.get_local_epoch(), - Content::Alert(Alert { - alert_level: level, - alert_description: desc, - }), - ), - should_encrypt: self.is_handshake_completed_successfully(), - reset_local_sequence_number: false, - }]) - .await - } - - pub(crate) async fn write_packets(&self, pkts: Vec) -> Result<()> { - let (tx, mut rx) = mpsc::channel(1); - - self.packet_tx.send((pkts, Some(tx))).await?; - - if let Some(result) = rx.recv().await { - result - } else { - Ok(()) - } - } - - async fn handle_outgoing_packets( - next_conn: &Arc, - mut pkts: Vec, - cache: &mut HandshakeCache, - is_client: bool, - local_sequence_number: &Arc>>, - cipher_suite: &Arc>>>, - maximum_transmission_unit: usize, - ) -> Result<()> { - let mut raw_packets = vec![]; - for p in &mut pkts { - if let Content::Handshake(h) = &p.record.content { - let mut handshake_raw = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(handshake_raw.as_mut()); - p.record.marshal(&mut writer)?; - } - trace!( - "Send [handshake:{}] -> {} (epoch: {}, seq: {})", - srv_cli_str(is_client), - h.handshake_header.handshake_type.to_string(), - p.record.record_layer_header.epoch, - h.handshake_header.message_sequence - ); - cache - .push( - handshake_raw[RECORD_LAYER_HEADER_SIZE..].to_vec(), - p.record.record_layer_header.epoch, - h.handshake_header.message_sequence, - h.handshake_header.handshake_type, - is_client, - ) - .await; - - let raw_handshake_packets = DTLSConn::process_handshake_packet( - local_sequence_number, - cipher_suite, - maximum_transmission_unit, - p, - h, - ) - .await?; - raw_packets.extend_from_slice(&raw_handshake_packets); - } else { - /*if let Content::Alert(a) = &p.record.content { - if a.alert_description == AlertDescription::CloseNotify { - closed = true; - } - }*/ - - let raw_packet = - DTLSConn::process_packet(local_sequence_number, cipher_suite, p).await?; - raw_packets.push(raw_packet); - } - } - - if !raw_packets.is_empty() { - let compacted_raw_packets = - compact_raw_packets(&raw_packets, maximum_transmission_unit); - - for compacted_raw_packets in &compacted_raw_packets { - next_conn.send(compacted_raw_packets).await?; - } - } - - Ok(()) - } - - async fn process_packet( - local_sequence_number: &Arc>>, - cipher_suite: &Arc>>>, - p: &mut Packet, - ) -> Result> { - let epoch = p.record.record_layer_header.epoch as usize; - let seq = { - let mut lsn = local_sequence_number.lock().await; - while lsn.len() <= epoch { - lsn.push(0); - } - - lsn[epoch] += 1; - lsn[epoch] - 1 - }; - //trace!("{}: seq = {}", srv_cli_str(is_client), seq); - - if seq > MAX_SEQUENCE_NUMBER { - // RFC 6347 Section 4.1.0 - // The implementation must either abandon an association or rehandshake - // prior to allowing the sequence number to wrap. 
- return Err(Error::ErrSequenceNumberOverflow); - } - p.record.record_layer_header.sequence_number = seq; - - let mut raw_packet = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(raw_packet.as_mut()); - p.record.marshal(&mut writer)?; - } - - if p.should_encrypt { - let cipher_suite = cipher_suite.lock().await; - if let Some(cipher_suite) = &*cipher_suite { - raw_packet = cipher_suite.encrypt(&p.record.record_layer_header, &raw_packet)?; - } - } - - Ok(raw_packet) - } - - async fn process_handshake_packet( - local_sequence_number: &Arc>>, - cipher_suite: &Arc>>>, - maximum_transmission_unit: usize, - p: &Packet, - h: &Handshake, - ) -> Result>> { - let mut raw_packets = vec![]; - - let handshake_fragments = DTLSConn::fragment_handshake(maximum_transmission_unit, h)?; - - let epoch = p.record.record_layer_header.epoch as usize; - - let mut lsn = local_sequence_number.lock().await; - while lsn.len() <= epoch { - lsn.push(0); - } - - for handshake_fragment in &handshake_fragments { - let seq = { - lsn[epoch] += 1; - lsn[epoch] - 1 - }; - //trace!("seq = {}", seq); - if seq > MAX_SEQUENCE_NUMBER { - return Err(Error::ErrSequenceNumberOverflow); - } - - let record_layer_header = RecordLayerHeader { - protocol_version: p.record.record_layer_header.protocol_version, - content_type: p.record.record_layer_header.content_type, - content_len: handshake_fragment.len() as u16, - epoch: p.record.record_layer_header.epoch, - sequence_number: seq, - }; - - let mut record_layer_header_bytes = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(record_layer_header_bytes.as_mut()); - record_layer_header.marshal(&mut writer)?; - } - - //p.record.record_layer_header = record_layer_header; - - let mut raw_packet = vec![]; - raw_packet.extend_from_slice(&record_layer_header_bytes); - raw_packet.extend_from_slice(handshake_fragment); - if p.should_encrypt { - let cipher_suite = cipher_suite.lock().await; - if let Some(cipher_suite) = &*cipher_suite { - raw_packet = cipher_suite.encrypt(&record_layer_header, &raw_packet)?; - } - } - - raw_packets.push(raw_packet); - } - - Ok(raw_packets) - } - - fn fragment_handshake(maximum_transmission_unit: usize, h: &Handshake) -> Result>> { - let mut content = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(content.as_mut()); - h.handshake_message.marshal(&mut writer)?; - } - - let mut fragmented_handshakes = vec![]; - - let mut content_fragments = split_bytes(&content, maximum_transmission_unit); - if content_fragments.is_empty() { - content_fragments = vec![vec![]]; - } - - let mut offset = 0; - for content_fragment in &content_fragments { - let content_fragment_len = content_fragment.len(); - - let handshake_header_fragment = HandshakeHeader { - handshake_type: h.handshake_header.handshake_type, - length: h.handshake_header.length, - message_sequence: h.handshake_header.message_sequence, - fragment_offset: offset as u32, - fragment_length: content_fragment_len as u32, - }; - - offset += content_fragment_len; - - let mut handshake_header_fragment_raw = vec![]; - { - let mut writer = - BufWriter::<&mut Vec>::new(handshake_header_fragment_raw.as_mut()); - handshake_header_fragment.marshal(&mut writer)?; - } - - let mut fragmented_handshake = vec![]; - fragmented_handshake.extend_from_slice(&handshake_header_fragment_raw); - fragmented_handshake.extend_from_slice(content_fragment); - - fragmented_handshakes.push(fragmented_handshake); - } - - Ok(fragmented_handshakes) - } - - pub(crate) fn set_handshake_completed_successfully(&mut self) { - 
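
fragment_handshake above splits a marshalled handshake body into MTU-sized pieces, and every piece repeats the full message length while carrying its own fragment_offset and fragment_length. A small sketch of just that offset bookkeeping (hypothetical names, no real marshalling):

#[derive(Debug)]
struct FragmentHeader {
    length: u32,          // length of the complete handshake message
    fragment_offset: u32, // where this piece starts in the message
    fragment_length: u32, // how many body bytes this piece carries
}

fn fragment_offsets(message_len: usize, budget: usize) -> Vec<FragmentHeader> {
    let mut out = vec![];
    let mut offset = 0usize;
    while offset < message_len {
        let len = (message_len - offset).min(budget);
        out.push(FragmentHeader {
            length: message_len as u32,
            fragment_offset: offset as u32,
            fragment_length: len as u32,
        });
        offset += len;
    }
    out
}

fn main() {
    // A 300-byte handshake body over a 120-byte budget becomes 120 + 120 + 60.
    let frags = fragment_offsets(300, 120);
    assert_eq!(frags.len(), 3);
    assert_eq!(frags[0].length, 300);
    assert_eq!(frags[2].fragment_offset, 240);
    assert_eq!(frags[2].fragment_length, 60);
}
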
self.handshake_completed_successfully - .store(true, Ordering::SeqCst); - } - - pub(crate) fn is_handshake_completed_successfully(&self) -> bool { - self.handshake_completed_successfully.load(Ordering::SeqCst) - } - - async fn read_and_buffer( - ctx: &mut ConnReaderContext, - next_conn: &Arc, - handle_queue_rx: &mut mpsc::Receiver>, - buf: &mut [u8], - local_epoch: &Arc, - handshake_completed_successfully: &Arc, - ) -> Result<()> { - let n = next_conn.recv(buf).await?; - let pkts = unpack_datagram(&buf[..n])?; - let mut has_handshake = false; - for pkt in pkts { - let (hs, alert, mut err) = DTLSConn::handle_incoming_packet(ctx, pkt, true).await; - if let Some(alert) = alert { - let alert_err = ctx - .packet_tx - .send(( - vec![Packet { - record: RecordLayer::new( - PROTOCOL_VERSION1_2, - local_epoch.load(Ordering::SeqCst), - Content::Alert(Alert { - alert_level: alert.alert_level, - alert_description: alert.alert_description, - }), - ), - should_encrypt: handshake_completed_successfully.load(Ordering::SeqCst), - reset_local_sequence_number: false, - }], - None, - )) - .await; - - if let Err(alert_err) = alert_err { - if err.is_none() { - err = Some(Error::Other(alert_err.to_string())); - } - } - - if alert.alert_level == AlertLevel::Fatal - || alert.alert_description == AlertDescription::CloseNotify - { - return Err(Error::ErrAlertFatalOrClose); - } - } - - if let Some(err) = err { - return Err(err); - } - - if hs { - has_handshake = true - } - } - - if has_handshake { - let (done_tx, mut done_rx) = mpsc::channel(1); - - tokio::select! { - _ = ctx.handshake_tx.send(done_tx) => { - let mut wait_done_rx = true; - while wait_done_rx{ - tokio::select!{ - _ = done_rx.recv() => { - // If the other party may retransmit the flight, - // we should respond even if it not a new message. 
- wait_done_rx = false; - } - done = handle_queue_rx.recv() => { - //trace!("recv handle_queue: {} ", srv_cli_str(ctx.is_client)); - - let pkts = ctx.encrypted_packets.drain(..).collect(); - DTLSConn::handle_queued_packets(ctx, local_epoch, handshake_completed_successfully, pkts).await?; - - drop(done); - } - } - } - } - _ = ctx.handshake_done_rx.recv() => {} - } - } - - Ok(()) - } - - async fn handle_queued_packets( - ctx: &mut ConnReaderContext, - local_epoch: &Arc, - handshake_completed_successfully: &Arc, - pkts: Vec>, - ) -> Result<()> { - for p in pkts { - let (_, alert, mut err) = DTLSConn::handle_incoming_packet(ctx, p, false).await; // don't re-enqueue - if let Some(alert) = alert { - let alert_err = ctx - .packet_tx - .send(( - vec![Packet { - record: RecordLayer::new( - PROTOCOL_VERSION1_2, - local_epoch.load(Ordering::SeqCst), - Content::Alert(Alert { - alert_level: alert.alert_level, - alert_description: alert.alert_description, - }), - ), - should_encrypt: handshake_completed_successfully.load(Ordering::SeqCst), - reset_local_sequence_number: false, - }], - None, - )) - .await; - - if let Err(alert_err) = alert_err { - if err.is_none() { - err = Some(Error::Other(alert_err.to_string())); - } - } - if alert.alert_level == AlertLevel::Fatal - || alert.alert_description == AlertDescription::CloseNotify - { - return Err(Error::ErrAlertFatalOrClose); - } - } - - if let Some(err) = err { - return Err(err); - } - } - - Ok(()) - } - - async fn handle_incoming_packet( - ctx: &mut ConnReaderContext, - mut pkt: Vec, - enqueue: bool, - ) -> (bool, Option, Option) { - let mut reader = BufReader::new(pkt.as_slice()); - let h = match RecordLayerHeader::unmarshal(&mut reader) { - Ok(h) => h, - Err(err) => { - // Decode error must be silently discarded - // [RFC6347 Section-4.1.2.7] - debug!( - "{}: discarded broken packet: {}", - srv_cli_str(ctx.is_client), - err - ); - return (false, None, None); - } - }; - - // Validate epoch - let epoch = ctx.remote_epoch.load(Ordering::SeqCst); - if h.epoch > epoch { - if h.epoch > epoch + 1 { - debug!( - "{}: discarded future packet (epoch: {}, seq: {})", - srv_cli_str(ctx.is_client), - h.epoch, - h.sequence_number, - ); - return (false, None, None); - } - if enqueue { - debug!( - "{}: received packet of next epoch, queuing packet", - srv_cli_str(ctx.is_client) - ); - ctx.encrypted_packets.push(pkt); - } - return (false, None, None); - } - - // Anti-replay protection - while ctx.replay_detector.len() <= h.epoch as usize { - ctx.replay_detector - .push(Box::new(SlidingWindowDetector::new( - ctx.replay_protection_window, - MAX_SEQUENCE_NUMBER, - ))); - } - - let ok = ctx.replay_detector[h.epoch as usize].check(h.sequence_number); - if !ok { - debug!( - "{}: discarded duplicated packet (epoch: {}, seq: {})", - srv_cli_str(ctx.is_client), - h.epoch, - h.sequence_number, - ); - return (false, None, None); - } - - // Decrypt - if h.epoch != 0 { - let invalid_cipher_suite = { - let cipher_suite = ctx.cipher_suite.lock().await; - if cipher_suite.is_none() { - true - } else if let Some(cipher_suite) = &*cipher_suite { - !cipher_suite.is_initialized() - } else { - false - } - }; - if invalid_cipher_suite { - if enqueue { - debug!( - "{}: handshake not finished, queuing packet", - srv_cli_str(ctx.is_client) - ); - ctx.encrypted_packets.push(pkt); - } - return (false, None, None); - } - - let cipher_suite = ctx.cipher_suite.lock().await; - if let Some(cipher_suite) = &*cipher_suite { - pkt = match cipher_suite.decrypt(&pkt) { - Ok(pkt) => pkt, - Err(err) => { - 
debug!("{}: decrypt failed: {}", srv_cli_str(ctx.is_client), err); - - // If we get an error for PSK we need to return an error. - if cipher_suite.is_psk() { - return ( - false, - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::UnknownPskIdentity, - }), - None, - ); - } else { - return (false, None, None); - } - } - }; - } - } - - let is_handshake = match ctx.fragment_buffer.push(&pkt) { - Ok(is_handshake) => is_handshake, - Err(err) => { - // Decode error must be silently discarded - // [RFC6347 Section-4.1.2.7] - debug!("{}: defragment failed: {}", srv_cli_str(ctx.is_client), err); - return (false, None, None); - } - }; - if is_handshake { - ctx.replay_detector[h.epoch as usize].accept(); - while let Ok((out, epoch)) = ctx.fragment_buffer.pop() { - //log::debug!("Extension Debug: out.len()={}", out.len()); - let mut reader = BufReader::new(out.as_slice()); - let raw_handshake = match Handshake::unmarshal(&mut reader) { - Ok(rh) => { - trace!( - "Recv [handshake:{}] -> {} (epoch: {}, seq: {})", - srv_cli_str(ctx.is_client), - rh.handshake_header.handshake_type.to_string(), - h.epoch, - rh.handshake_header.message_sequence - ); - rh - } - Err(err) => { - debug!( - "{}: handshake parse failed: {}", - srv_cli_str(ctx.is_client), - err - ); - continue; - } - }; - - ctx.cache - .push( - out, - epoch, - raw_handshake.handshake_header.message_sequence, - raw_handshake.handshake_header.handshake_type, - !ctx.is_client, - ) - .await; - } - - return (true, None, None); - } - - let mut reader = BufReader::new(pkt.as_slice()); - let r = match RecordLayer::unmarshal(&mut reader) { - Ok(r) => r, - Err(err) => { - return ( - false, - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::DecodeError, - }), - Some(err), - ); - } - }; - - match r.content { - Content::Alert(mut a) => { - trace!("{}: <- {}", srv_cli_str(ctx.is_client), a.to_string()); - if a.alert_description == AlertDescription::CloseNotify { - // Respond with a close_notify [RFC5246 Section 7.2.1] - a = Alert { - alert_level: AlertLevel::Warning, - alert_description: AlertDescription::CloseNotify, - }; - } - ctx.replay_detector[h.epoch as usize].accept(); - return ( - false, - Some(a), - Some(Error::Other(format!("Error of Alert {a}"))), - ); - } - Content::ChangeCipherSpec(_) => { - let invalid_cipher_suite = { - let cipher_suite = ctx.cipher_suite.lock().await; - if cipher_suite.is_none() { - true - } else if let Some(cipher_suite) = &*cipher_suite { - !cipher_suite.is_initialized() - } else { - false - } - }; - - if invalid_cipher_suite { - if enqueue { - debug!( - "{}: CipherSuite not initialized, queuing packet", - srv_cli_str(ctx.is_client) - ); - ctx.encrypted_packets.push(pkt); - } - return (false, None, None); - } - - let new_remote_epoch = h.epoch + 1; - trace!( - "{}: <- ChangeCipherSpec (epoch: {})", - srv_cli_str(ctx.is_client), - new_remote_epoch - ); - - if epoch + 1 == new_remote_epoch { - ctx.remote_epoch.store(new_remote_epoch, Ordering::SeqCst); - ctx.replay_detector[h.epoch as usize].accept(); - } - } - Content::ApplicationData(a) => { - if h.epoch == 0 { - return ( - false, - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::UnexpectedMessage, - }), - Some(Error::ErrApplicationDataEpochZero), - ); - } - - ctx.replay_detector[h.epoch as usize].accept(); - - let _ = ctx.decrypted_tx.send(Ok(a.data)).await; - //TODO - /*select { - case self.decrypted < - content.data: - case < -c.closed.Done(): - }*/ - } - _ => { - 
return ( - false, - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::UnexpectedMessage, - }), - Some(Error::ErrUnhandledContextType), - ); - } - }; - - (false, None, None) - } - - fn is_connection_closed(&self) -> bool { - self.closed.load(Ordering::SeqCst) - } - - pub(crate) fn set_local_epoch(&mut self, epoch: u16) { - self.state.local_epoch.store(epoch, Ordering::SeqCst); - } - - pub(crate) fn get_local_epoch(&self) -> u16 { - self.state.local_epoch.load(Ordering::SeqCst) - } -} - -fn compact_raw_packets(raw_packets: &[Vec], maximum_transmission_unit: usize) -> Vec> { - let mut combined_raw_packets = vec![]; - let mut current_combined_raw_packet = vec![]; - - for raw_packet in raw_packets { - if !current_combined_raw_packet.is_empty() - && current_combined_raw_packet.len() + raw_packet.len() >= maximum_transmission_unit - { - combined_raw_packets.push(current_combined_raw_packet); - current_combined_raw_packet = vec![]; - } - current_combined_raw_packet.extend_from_slice(raw_packet); - } - - combined_raw_packets.push(current_combined_raw_packet); - - combined_raw_packets -} - -fn split_bytes(bytes: &[u8], split_len: usize) -> Vec> { - let mut splits = vec![]; - let num_bytes = bytes.len(); - for i in (0..num_bytes).step_by(split_len) { - let mut j = i + split_len; - if j > num_bytes { - j = num_bytes; - } - - splits.push(bytes[i..j].to_vec()); - } - - splits -} diff --git a/dtls/src/content.rs b/dtls/src/content.rs deleted file mode 100644 index f1b0f623f..000000000 --- a/dtls/src/content.rs +++ /dev/null @@ -1,81 +0,0 @@ -use std::io::{Read, Write}; - -use super::alert::*; -use super::application_data::*; -use super::change_cipher_spec::*; -use super::handshake::*; -use crate::error::*; - -// https://tools.ietf.org/html/rfc4346#section-6.2.1 -#[derive(Default, Copy, Clone, PartialEq, Eq, Debug)] -pub enum ContentType { - ChangeCipherSpec = 20, - Alert = 21, - Handshake = 22, - ApplicationData = 23, - #[default] - Invalid, -} - -impl From for ContentType { - fn from(val: u8) -> Self { - match val { - 20 => ContentType::ChangeCipherSpec, - 21 => ContentType::Alert, - 22 => ContentType::Handshake, - 23 => ContentType::ApplicationData, - _ => ContentType::Invalid, - } - } -} - -#[derive(PartialEq, Debug, Clone)] -pub enum Content { - ChangeCipherSpec(ChangeCipherSpec), - Alert(Alert), - Handshake(Handshake), - ApplicationData(ApplicationData), -} - -impl Content { - pub fn content_type(&self) -> ContentType { - match self { - Content::ChangeCipherSpec(c) => c.content_type(), - Content::Alert(c) => c.content_type(), - Content::Handshake(c) => c.content_type(), - Content::ApplicationData(c) => c.content_type(), - } - } - - pub fn size(&self) -> usize { - match self { - Content::ChangeCipherSpec(c) => c.size(), - Content::Alert(c) => c.size(), - Content::Handshake(c) => c.size(), - Content::ApplicationData(c) => c.size(), - } - } - - pub fn marshal(&self, writer: &mut W) -> Result<()> { - match self { - Content::ChangeCipherSpec(c) => c.marshal(writer), - Content::Alert(c) => c.marshal(writer), - Content::Handshake(c) => c.marshal(writer), - Content::ApplicationData(c) => c.marshal(writer), - } - } - - pub fn unmarshal(content_type: ContentType, reader: &mut R) -> Result { - match content_type { - ContentType::ChangeCipherSpec => Ok(Content::ChangeCipherSpec( - ChangeCipherSpec::unmarshal(reader)?, - )), - ContentType::Alert => Ok(Content::Alert(Alert::unmarshal(reader)?)), - ContentType::Handshake => Ok(Content::Handshake(Handshake::unmarshal(reader)?)), 
- ContentType::ApplicationData => Ok(Content::ApplicationData( - ApplicationData::unmarshal(reader)?, - )), - _ => Err(Error::ErrInvalidContentType), - } - } -} diff --git a/dtls/src/crypto/crypto_cbc.rs b/dtls/src/crypto/crypto_cbc.rs deleted file mode 100644 index 74b1681ff..000000000 --- a/dtls/src/crypto/crypto_cbc.rs +++ /dev/null @@ -1,129 +0,0 @@ -// AES-CBC (Cipher Block Chaining) -// First historic block cipher for AES. -// CBC mode is insecure and must not be used. Itโ€™s been progressively deprecated and -// removed from SSL libraries. -// Introduced with TLS 1.0 year 2002. Superseded by GCM in TLS 1.2 year 2008. -// Removed in TLS 1.3 year 2018. -// RFC 3268 year 2002 https://tools.ietf.org/html/rfc3268 - -// https://github.com/RustCrypto/block-ciphers - -use aes::cipher::{BlockDecryptMut, BlockEncryptMut, KeyIvInit}; -use p256::elliptic_curve::subtle::ConstantTimeEq; -use rand::Rng; -use std::io::Cursor; -use std::ops::Not; - -use super::padding::DtlsPadding; -use crate::content::*; -use crate::error::*; -use crate::prf::*; -use crate::record_layer::record_layer_header::*; -type Aes256CbcEnc = cbc::Encryptor; -type Aes256CbcDec = cbc::Decryptor; - -// State needed to handle encrypted input/output -#[derive(Clone)] -pub struct CryptoCbc { - local_key: Vec, - remote_key: Vec, - write_mac: Vec, - read_mac: Vec, -} - -impl CryptoCbc { - const BLOCK_SIZE: usize = 16; - const MAC_SIZE: usize = 20; - - pub fn new( - local_key: &[u8], - local_mac: &[u8], - remote_key: &[u8], - remote_mac: &[u8], - ) -> Result { - Ok(CryptoCbc { - local_key: local_key.to_vec(), - write_mac: local_mac.to_vec(), - - remote_key: remote_key.to_vec(), - read_mac: remote_mac.to_vec(), - }) - } - - pub fn encrypt(&self, pkt_rlh: &RecordLayerHeader, raw: &[u8]) -> Result> { - let mut payload = raw[RECORD_LAYER_HEADER_SIZE..].to_vec(); - let raw = &raw[..RECORD_LAYER_HEADER_SIZE]; - - // Generate + Append MAC - let h = pkt_rlh; - - let mac = prf_mac( - h.epoch, - h.sequence_number, - h.content_type, - h.protocol_version, - &payload, - &self.write_mac, - )?; - payload.extend_from_slice(&mac); - - let mut iv: Vec = vec![0; Self::BLOCK_SIZE]; - rand::thread_rng().fill(iv.as_mut_slice()); - - let write_cbc = Aes256CbcEnc::new_from_slices(&self.local_key, &iv)?; - let encrypted = write_cbc.encrypt_padded_vec_mut::(&payload); - - // Prepend unencrypte header with encrypted payload - let mut r = vec![]; - r.extend_from_slice(raw); - r.extend_from_slice(&iv); - r.extend_from_slice(&encrypted); - - let r_len = (r.len() - RECORD_LAYER_HEADER_SIZE) as u16; - r[RECORD_LAYER_HEADER_SIZE - 2..RECORD_LAYER_HEADER_SIZE] - .copy_from_slice(&r_len.to_be_bytes()); - - Ok(r) - } - - pub fn decrypt(&self, r: &[u8]) -> Result> { - let mut reader = Cursor::new(r); - let h = RecordLayerHeader::unmarshal(&mut reader)?; - if h.content_type == ContentType::ChangeCipherSpec { - // Nothing to encrypt with ChangeCipherSpec - return Ok(r.to_vec()); - } - - let body = &r[RECORD_LAYER_HEADER_SIZE..]; - let iv = &body[0..Self::BLOCK_SIZE]; - let body = &body[Self::BLOCK_SIZE..]; - //TODO: add body.len() check - - let read_cbc = Aes256CbcDec::new_from_slices(&self.remote_key, iv)?; - - let decrypted = read_cbc - .decrypt_padded_vec_mut::(body) - .map_err(|_| Error::ErrInvalidPacketLength)?; - - let recv_mac = &decrypted[decrypted.len() - Self::MAC_SIZE..]; - let decrypted = &decrypted[0..decrypted.len() - Self::MAC_SIZE]; - let mac = prf_mac( - h.epoch, - h.sequence_number, - h.content_type, - h.protocol_version, - decrypted, - &self.read_mac, - 
)?; - - if recv_mac.ct_eq(&mac).not().into() { - return Err(Error::ErrInvalidMac); - } - - let mut d = Vec::with_capacity(RECORD_LAYER_HEADER_SIZE + decrypted.len()); - d.extend_from_slice(&r[..RECORD_LAYER_HEADER_SIZE]); - d.extend_from_slice(decrypted); - - Ok(d) - } -} diff --git a/dtls/src/crypto/crypto_ccm.rs b/dtls/src/crypto/crypto_ccm.rs deleted file mode 100644 index a932e5439..000000000 --- a/dtls/src/crypto/crypto_ccm.rs +++ /dev/null @@ -1,188 +0,0 @@ -// AES-CCM (Counter with CBC-MAC) -// Alternative to GCM mode. -// Available in OpenSSL as of TLS 1.3 (2018), but disabled by default. -// Two AES computations per block, thus expected to be somewhat slower than AES-GCM. -// RFC 6655 year 2012 https://tools.ietf.org/html/rfc6655 -// Much lower adoption, probably because it came after GCM and offer no significant benefit. - -// https://github.com/RustCrypto/AEADs -// https://docs.rs/ccm/0.3.0/ccm/ Or https://crates.io/crates/aes-ccm? - -use std::io::Cursor; - -use aes::Aes128; -use ccm::aead::generic_array::GenericArray; -use ccm::aead::AeadInPlace; -use ccm::consts::{U12, U16, U8}; -use ccm::Ccm; -use ccm::KeyInit; -use rand::Rng; - -use super::*; -use crate::content::*; -use crate::error::*; -use crate::record_layer::record_layer_header::*; - -const CRYPTO_CCM_8_TAG_LENGTH: usize = 8; -const CRYPTO_CCM_TAG_LENGTH: usize = 16; -const CRYPTO_CCM_NONCE_LENGTH: usize = 12; - -type AesCcm8 = Ccm; -type AesCcm = Ccm; - -#[derive(Clone)] -pub enum CryptoCcmTagLen { - CryptoCcm8TagLength, - CryptoCcmTagLength, -} - -enum CryptoCcmType { - CryptoCcm8(AesCcm8), - CryptoCcm(AesCcm), -} - -// State needed to handle encrypted input/output -pub struct CryptoCcm { - local_ccm: CryptoCcmType, - remote_ccm: CryptoCcmType, - local_write_iv: Vec, - remote_write_iv: Vec, - // used by clone() - local_write_key: Vec, - remote_write_key: Vec, -} - -impl Clone for CryptoCcm { - fn clone(&self) -> Self { - match self.local_ccm { - CryptoCcmType::CryptoCcm(_) => Self::new( - &CryptoCcmTagLen::CryptoCcmTagLength, - &self.local_write_key, - &self.local_write_iv, - &self.remote_write_key, - &self.remote_write_iv, - ), - CryptoCcmType::CryptoCcm8(_) => Self::new( - &CryptoCcmTagLen::CryptoCcm8TagLength, - &self.local_write_key, - &self.local_write_iv, - &self.remote_write_key, - &self.remote_write_iv, - ), - } - } -} - -impl CryptoCcm { - pub fn new( - tag_len: &CryptoCcmTagLen, - local_key: &[u8], - local_write_iv: &[u8], - remote_key: &[u8], - remote_write_iv: &[u8], - ) -> Self { - let key = GenericArray::from_slice(local_key); - let local_ccm = match tag_len { - CryptoCcmTagLen::CryptoCcmTagLength => CryptoCcmType::CryptoCcm(AesCcm::new(key)), - CryptoCcmTagLen::CryptoCcm8TagLength => CryptoCcmType::CryptoCcm8(AesCcm8::new(key)), - }; - - let key = GenericArray::from_slice(remote_key); - let remote_ccm = match tag_len { - CryptoCcmTagLen::CryptoCcmTagLength => CryptoCcmType::CryptoCcm(AesCcm::new(key)), - CryptoCcmTagLen::CryptoCcm8TagLength => CryptoCcmType::CryptoCcm8(AesCcm8::new(key)), - }; - - CryptoCcm { - local_ccm, - local_write_key: local_key.to_vec(), - local_write_iv: local_write_iv.to_vec(), - remote_ccm, - remote_write_key: remote_key.to_vec(), - remote_write_iv: remote_write_iv.to_vec(), - } - } - - pub fn encrypt(&self, pkt_rlh: &RecordLayerHeader, raw: &[u8]) -> Result> { - let payload = &raw[RECORD_LAYER_HEADER_SIZE..]; - let raw = &raw[..RECORD_LAYER_HEADER_SIZE]; - - let mut nonce = vec![0u8; CRYPTO_CCM_NONCE_LENGTH]; - nonce[..4].copy_from_slice(&self.local_write_iv[..4]); - 
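// Added note: the 12-byte CCM nonce is assembled from a 4-byte implicit part
// (the local write IV derived during key expansion, copied just above) and an
// 8-byte explicit part filled with fresh random bytes below. Only the explicit
// 8 bytes are sent on the wire, prefixed to the ciphertext; the receiver
// re-derives the implicit 4 bytes from its copy of the remote write IV. The
// GCM cipher in crypto_gcm.rs builds its nonce the same way.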
rand::thread_rng().fill(&mut nonce[4..]); - let nonce = GenericArray::from_slice(&nonce); - - let additional_data = generate_aead_additional_data(pkt_rlh, payload.len()); - - let mut buffer: Vec = Vec::new(); - buffer.extend_from_slice(payload); - - match &self.local_ccm { - CryptoCcmType::CryptoCcm(ccm) => { - ccm.encrypt_in_place(nonce, &additional_data, &mut buffer) - .map_err(|e| Error::Other(e.to_string()))?; - } - CryptoCcmType::CryptoCcm8(ccm8) => { - ccm8.encrypt_in_place(nonce, &additional_data, &mut buffer) - .map_err(|e| Error::Other(e.to_string()))?; - } - } - - let mut r = Vec::with_capacity(raw.len() + nonce.len() + buffer.len()); - - r.extend_from_slice(raw); - r.extend_from_slice(&nonce[4..]); - r.extend_from_slice(&buffer); - - // Update recordLayer size to include explicit nonce - let r_len = (r.len() - RECORD_LAYER_HEADER_SIZE) as u16; - r[RECORD_LAYER_HEADER_SIZE - 2..RECORD_LAYER_HEADER_SIZE] - .copy_from_slice(&r_len.to_be_bytes()); - - Ok(r) - } - - pub fn decrypt(&self, r: &[u8]) -> Result> { - let mut reader = Cursor::new(r); - let h = RecordLayerHeader::unmarshal(&mut reader)?; - if h.content_type == ContentType::ChangeCipherSpec { - // Nothing to encrypt with ChangeCipherSpec - return Ok(r.to_vec()); - } - - if r.len() <= (RECORD_LAYER_HEADER_SIZE + 8) { - return Err(Error::ErrNotEnoughRoomForNonce); - } - - let mut nonce = vec![]; - nonce.extend_from_slice(&self.remote_write_iv[..4]); - nonce.extend_from_slice(&r[RECORD_LAYER_HEADER_SIZE..RECORD_LAYER_HEADER_SIZE + 8]); - let nonce = GenericArray::from_slice(&nonce); - - let out = &r[RECORD_LAYER_HEADER_SIZE + 8..]; - - let mut buffer: Vec = Vec::new(); - buffer.extend_from_slice(out); - - match &self.remote_ccm { - CryptoCcmType::CryptoCcm(ccm) => { - let additional_data = - generate_aead_additional_data(&h, out.len() - CRYPTO_CCM_TAG_LENGTH); - ccm.decrypt_in_place(nonce, &additional_data, &mut buffer) - .map_err(|e| Error::Other(e.to_string()))?; - } - CryptoCcmType::CryptoCcm8(ccm8) => { - let additional_data = - generate_aead_additional_data(&h, out.len() - CRYPTO_CCM_8_TAG_LENGTH); - ccm8.decrypt_in_place(nonce, &additional_data, &mut buffer) - .map_err(|e| Error::Other(e.to_string()))?; - } - } - - let mut d = Vec::with_capacity(RECORD_LAYER_HEADER_SIZE + buffer.len()); - d.extend_from_slice(&r[..RECORD_LAYER_HEADER_SIZE]); - d.extend_from_slice(&buffer); - - Ok(d) - } -} diff --git a/dtls/src/crypto/crypto_gcm.rs b/dtls/src/crypto/crypto_gcm.rs deleted file mode 100644 index 9928ccad0..000000000 --- a/dtls/src/crypto/crypto_gcm.rs +++ /dev/null @@ -1,119 +0,0 @@ -// AES-GCM (Galois Counter Mode) -// The most widely used block cipher worldwide. -// Mandatory as of TLS 1.2 (2008) and used by default by most clients. -// RFC 5288 year 2008 https://tools.ietf.org/html/rfc5288 - -// https://github.com/RustCrypto/AEADs -// https://docs.rs/aes-gcm/0.8.0/aes_gcm/ - -use std::io::Cursor; - -use aes_gcm::aead::generic_array::GenericArray; -use aes_gcm::aead::AeadInPlace; -use aes_gcm::{Aes128Gcm, KeyInit}; -use rand::Rng; - -use super::*; -use crate::content::*; -use crate::error::*; -use crate::record_layer::record_layer_header::*; // what about Aes256Gcm? 
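// Editor's sketch (not part of the original source): the CCM code above and the
// GCM code below produce records of the same shape and both patch the record
// length field after sealing:
//
//   [ 13-byte record header | 8-byte explicit nonce | ciphertext + auth tag ]
//
// The 13-byte additional data that the AEAD authenticates is
// epoch(2) || sequence_number(6) || content_type(1) || version(2) || length(2),
// built by generate_aead_additional_data in crypto/mod.rs. The length fixup
// both ciphers perform amounts to:
fn patch_record_length(record: &mut [u8]) {
    const RECORD_LAYER_HEADER_SIZE: usize = 13; // type(1) + version(2) + epoch(2) + seq(6) + length(2)
    let body_len = (record.len() - RECORD_LAYER_HEADER_SIZE) as u16;
    // the last two header bytes hold the body length, big-endian
    record[RECORD_LAYER_HEADER_SIZE - 2..RECORD_LAYER_HEADER_SIZE]
        .copy_from_slice(&body_len.to_be_bytes());
}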
- -const CRYPTO_GCM_TAG_LENGTH: usize = 16; -const CRYPTO_GCM_NONCE_LENGTH: usize = 12; - -// State needed to handle encrypted input/output -#[derive(Clone)] -pub struct CryptoGcm { - local_gcm: Aes128Gcm, - remote_gcm: Aes128Gcm, - local_write_iv: Vec, - remote_write_iv: Vec, -} - -impl CryptoGcm { - pub fn new( - local_key: &[u8], - local_write_iv: &[u8], - remote_key: &[u8], - remote_write_iv: &[u8], - ) -> Self { - let key = GenericArray::from_slice(local_key); - let local_gcm = Aes128Gcm::new(key); - - let key = GenericArray::from_slice(remote_key); - let remote_gcm = Aes128Gcm::new(key); - - CryptoGcm { - local_gcm, - local_write_iv: local_write_iv.to_vec(), - remote_gcm, - remote_write_iv: remote_write_iv.to_vec(), - } - } - - pub fn encrypt(&self, pkt_rlh: &RecordLayerHeader, raw: &[u8]) -> Result> { - let payload = &raw[RECORD_LAYER_HEADER_SIZE..]; - let raw = &raw[..RECORD_LAYER_HEADER_SIZE]; - - let mut nonce = vec![0u8; CRYPTO_GCM_NONCE_LENGTH]; - nonce[..4].copy_from_slice(&self.local_write_iv[..4]); - rand::thread_rng().fill(&mut nonce[4..]); - let nonce = GenericArray::from_slice(&nonce); - - let additional_data = generate_aead_additional_data(pkt_rlh, payload.len()); - - let mut buffer: Vec = Vec::new(); - buffer.extend_from_slice(payload); - - self.local_gcm - .encrypt_in_place(nonce, &additional_data, &mut buffer) - .map_err(|e| Error::Other(e.to_string()))?; - - let mut r = Vec::with_capacity(raw.len() + nonce.len() + buffer.len()); - r.extend_from_slice(raw); - r.extend_from_slice(&nonce[4..]); - r.extend_from_slice(&buffer); - - // Update recordLayer size to include explicit nonce - let r_len = (r.len() - RECORD_LAYER_HEADER_SIZE) as u16; - r[RECORD_LAYER_HEADER_SIZE - 2..RECORD_LAYER_HEADER_SIZE] - .copy_from_slice(&r_len.to_be_bytes()); - - Ok(r) - } - - pub fn decrypt(&self, r: &[u8]) -> Result> { - let mut reader = Cursor::new(r); - let h = RecordLayerHeader::unmarshal(&mut reader)?; - if h.content_type == ContentType::ChangeCipherSpec { - // Nothing to encrypt with ChangeCipherSpec - return Ok(r.to_vec()); - } - - if r.len() <= (RECORD_LAYER_HEADER_SIZE + 8) { - return Err(Error::ErrNotEnoughRoomForNonce); - } - - let mut nonce = vec![]; - nonce.extend_from_slice(&self.remote_write_iv[..4]); - nonce.extend_from_slice(&r[RECORD_LAYER_HEADER_SIZE..RECORD_LAYER_HEADER_SIZE + 8]); - let nonce = GenericArray::from_slice(&nonce); - - let out = &r[RECORD_LAYER_HEADER_SIZE + 8..]; - - let additional_data = generate_aead_additional_data(&h, out.len() - CRYPTO_GCM_TAG_LENGTH); - - let mut buffer: Vec = Vec::new(); - buffer.extend_from_slice(out); - - self.remote_gcm - .decrypt_in_place(nonce, &additional_data, &mut buffer) - .map_err(|e| Error::Other(e.to_string()))?; - - let mut d = Vec::with_capacity(RECORD_LAYER_HEADER_SIZE + buffer.len()); - d.extend_from_slice(&r[..RECORD_LAYER_HEADER_SIZE]); - d.extend_from_slice(&buffer); - - Ok(d) - } -} diff --git a/dtls/src/crypto/crypto_test.rs b/dtls/src/crypto/crypto_test.rs deleted file mode 100644 index f06988bc6..000000000 --- a/dtls/src/crypto/crypto_test.rs +++ /dev/null @@ -1,221 +0,0 @@ -use std::io::Cursor; - -use x509_parser::pem::Pem; - -use super::crypto_ccm::*; -use super::*; -use crate::content::ContentType; -use crate::record_layer::record_layer_header::{ProtocolVersion, RECORD_LAYER_HEADER_SIZE}; - -const RAW_PRIVATE_KEY: &str = " ------BEGIN RSA PRIVATE KEY----- -MIIEowIBAAKCAQEAxIA2BrrnR2sIlATsp7aRBD/3krwZ7vt9dNeoDQAee0s6SuYP -6MBx/HPnAkwNvPS90R05a7pwRkoT6Ur4PfPhCVlUe8lV+0Eto3ZSEeHz3HdsqlM3 
-bso67L7Dqrc7MdVstlKcgJi8yeAoGOIL9/igOv0XBFCeznm9nznx6mnsR5cugw+1 -ypXelaHmBCLV7r5SeVSh57+KhvZGbQ2fFpUaTPegRpJZXBNS8lSeWvtOv9d6N5UB -ROTAJodMZT5AfX0jB0QB9IT/0I96H6BSENH08NXOeXApMuLKvnAf361rS7cRAfRL -rWZqERMP4u6Cnk0Cnckc3WcW27kGGIbtwbqUIQIDAQABAoIBAGF7OVIdZp8Hejn0 -N3L8HvT8xtUEe9kS6ioM0lGgvX5s035Uo4/T6LhUx0VcdXRH9eLHnLTUyN4V4cra -ZkxVsE3zAvZl60G6E+oDyLMWZOP6Wu4kWlub9597A5atT7BpMIVCdmFVZFLB4SJ3 -AXkC3nplFAYP+Lh1rJxRIrIn2g+pEeBboWbYA++oDNuMQffDZaokTkJ8Bn1JZYh0 -xEXKY8Bi2Egd5NMeZa1UFO6y8tUbZfwgVs6Enq5uOgtfayq79vZwyjj1kd29MBUD -8g8byV053ZKxbUOiOuUts97eb+fN3DIDRTcT2c+lXt/4C54M1FclJAbtYRK/qwsl -pYWKQAECgYEA4ZUbqQnTo1ICvj81ifGrz+H4LKQqe92Hbf/W51D/Umk2kP702W22 -HP4CvrJRtALThJIG9m2TwUjl/WAuZIBrhSAbIvc3Fcoa2HjdRp+sO5U1ueDq7d/S -Z+PxRI8cbLbRpEdIaoR46qr/2uWZ943PHMv9h4VHPYn1w8b94hwD6vkCgYEA3v87 -mFLzyM9ercnEv9zHMRlMZFQhlcUGQZvfb8BuJYl/WogyT6vRrUuM0QXULNEPlrin -mBQTqc1nCYbgkFFsD2VVt1qIyiAJsB9MD1LNV6YuvE7T2KOSadmsA4fa9PUqbr71 -hf3lTTq+LeR09LebO7WgSGYY+5YKVOEGpYMR1GkCgYEAxPVQmk3HKHEhjgRYdaG5 -lp9A9ZE8uruYVJWtiHgzBTxx9TV2iST+fd/We7PsHFTfY3+wbpcMDBXfIVRKDVwH -BMwchXH9+Ztlxx34bYJaegd0SmA0Hw9ugWEHNgoSEmWpM1s9wir5/ELjc7dGsFtz -uzvsl9fpdLSxDYgAAdzeGtkCgYBAzKIgrVox7DBzB8KojhtD5ToRnXD0+H/M6OKQ -srZPKhlb0V/tTtxrIx0UUEFLlKSXA6mPw6XDHfDnD86JoV9pSeUSlrhRI+Ysy6tq -eIE7CwthpPZiaYXORHZ7wCqcK/HcpJjsCs9rFbrV0yE5S3FMdIbTAvgXg44VBB7O -UbwIoQKBgDuY8gSrA5/A747wjjmsdRWK4DMTMEV4eCW1BEP7Tg7Cxd5n3xPJiYhr -nhLGN+mMnVIcv2zEMS0/eNZr1j/0BtEdx+3IC6Eq+ONY0anZ4Irt57/5QeKgKn/L -JPhfPySIPG4UmwE4gW8t79vfOKxnUu2fDD1ZXUYopan6EckACNH/ ------END RSA PRIVATE KEY----- -"; - -#[test] -fn test_generate_key_signature() -> Result<()> { - let reader = Cursor::new(RAW_PRIVATE_KEY.as_bytes()); - let pem = match Pem::read(reader) { - Ok((pem, _)) => pem, - Err(_) => return Err(Error::Other("Pem::read error".to_owned())), - }; - //let private_key = rsa::RSAPrivateKey::from_pkcs1(&pem.contents)?; - - let client_random = vec![ - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, - 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, - 0x1e, 0x1f, - ]; - let server_random = vec![ - 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, - 0x7f, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, - 0x8e, 0x8f, - ]; - let public_key = vec![ - 0x20, 0x9f, 0xd7, 0xad, 0x6d, 0xcf, 0xf4, 0x29, 0x8d, 0xd3, 0xf9, 0x6d, 0x5b, 0x1b, 0x2a, - 0xf9, 0x10, 0xa0, 0x53, 0x5b, 0x14, 0x88, 0xd7, 0xf8, 0xfa, 0xbb, 0x34, 0x9a, 0x98, 0x28, - 0x80, 0xb6, 0x15, - ]; - let expected_signature = vec![ - 0x6f, 0x47, 0x97, 0x85, 0xcc, 0x76, 0x50, 0x93, 0xbd, 0xe2, 0x6a, 0x69, 0x0b, 0xc3, 0x03, - 0xd1, 0xb7, 0xe4, 0xab, 0x88, 0x7b, 0xa6, 0x52, 0x80, 0xdf, 0xaa, 0x25, 0x7a, 0xdb, 0x29, - 0x32, 0xe4, 0xd8, 0x28, 0x28, 0xb3, 0xe8, 0x04, 0x3c, 0x38, 0x16, 0xfc, 0x78, 0xe9, 0x15, - 0x7b, 0xc5, 0xbd, 0x7d, 0xfc, 0xcd, 0x83, 0x00, 0x57, 0x4a, 0x3c, 0x23, 0x85, 0x75, 0x6b, - 0x37, 0xd5, 0x89, 0x72, 0x73, 0xf0, 0x44, 0x8c, 0x00, 0x70, 0x1f, 0x6e, 0xa2, 0x81, 0xd0, - 0x09, 0xc5, 0x20, 0x36, 0xab, 0x23, 0x09, 0x40, 0x1f, 0x4d, 0x45, 0x96, 0x62, 0xbb, 0x81, - 0xb0, 0x30, 0x72, 0xad, 0x3a, 0x0a, 0xac, 0x31, 0x63, 0x40, 0x52, 0x0a, 0x27, 0xf3, 0x34, - 0xde, 0x27, 0x7d, 0xb7, 0x54, 0xff, 0x0f, 0x9f, 0x5a, 0xfe, 0x07, 0x0f, 0x4e, 0x9f, 0x53, - 0x04, 0x34, 0x62, 0xf4, 0x30, 0x74, 0x83, 0x35, 0xfc, 0xe4, 0x7e, 0xbf, 0x5a, 0xc4, 0x52, - 0xd0, 0xea, 0xf9, 0x61, 0x4e, 0xf5, 0x1c, 0x0e, 0x58, 0x02, 0x71, 0xfb, 0x1f, 0x34, 0x55, - 0xe8, 0x36, 0x70, 0x3c, 0xc1, 0xcb, 
0xc9, 0xb7, 0xbb, 0xb5, 0x1c, 0x44, 0x9a, 0x6d, 0x88, - 0x78, 0x98, 0xd4, 0x91, 0x2e, 0xeb, 0x98, 0x81, 0x23, 0x30, 0x73, 0x39, 0x43, 0xd5, 0xbb, - 0x70, 0x39, 0xba, 0x1f, 0xdb, 0x70, 0x9f, 0x91, 0x83, 0x56, 0xc2, 0xde, 0xed, 0x17, 0x6d, - 0x2c, 0x3e, 0x21, 0xea, 0x36, 0xb4, 0x91, 0xd8, 0x31, 0x05, 0x60, 0x90, 0xfd, 0xc6, 0x74, - 0xa9, 0x7b, 0x18, 0xfc, 0x1c, 0x6a, 0x1c, 0x6e, 0xec, 0xd3, 0xc1, 0xc0, 0x0d, 0x11, 0x25, - 0x48, 0x37, 0x3d, 0x45, 0x11, 0xa2, 0x31, 0x14, 0x0a, 0x66, 0x9f, 0xd8, 0xac, 0x74, 0xa2, - 0xcd, 0xc8, 0x79, 0xb3, 0x9e, 0xc6, 0x66, 0x25, 0xcf, 0x2c, 0x87, 0x5e, 0x5c, 0x36, 0x75, - 0x86, - ]; - - let signature = generate_key_signature( - &client_random, - &server_random, - &public_key, - NamedCurve::X25519, - &CryptoPrivateKey { - kind: CryptoPrivateKeyKind::Rsa256( - ring::rsa::KeyPair::from_der(&pem.contents) - .map_err(|e| Error::Other(e.to_string()))?, - ), - serialized_der: pem.contents.clone(), - }, //hashAlgorithmSHA256, - )?; - - assert_eq!( - signature, expected_signature, - "Signature generation failed \nexp {expected_signature:?} \nactual {signature:?} " - ); - - Ok(()) -} - -#[test] -fn test_ccm_encryption_and_decryption() -> Result<()> { - let key = vec![ - 0x18, 0x78, 0xac, 0xc2, 0x2a, 0xd8, 0xbd, 0xd8, 0xc6, 0x01, 0xa6, 0x17, 0x12, 0x6f, 0x63, - 0x54, - ]; - let iv = vec![0x0e, 0xb2, 0x09, 0x06]; - - let ccm = CryptoCcm::new(&CryptoCcmTagLen::CryptoCcmTagLength, &key, &iv, &key, &iv); - - let rlh = RecordLayerHeader { - content_type: ContentType::ApplicationData, - protocol_version: ProtocolVersion { - major: 0xfe, - minor: 0xff, - }, - epoch: 0, - sequence_number: 18, - content_len: 3, - }; - - let raw = vec![ - 0x17, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0x03, 0xff, 0xaa, - 0xbb, - ]; - - let cipher_text = ccm.encrypt(&rlh, &raw)?; - - assert_eq!( - &cipher_text[RECORD_LAYER_HEADER_SIZE - 2..RECORD_LAYER_HEADER_SIZE], - [0, 27], - "RecordLayer size updating failed \nexp: {:?} \nactual {:?} ", - [0, 27], - &cipher_text[RECORD_LAYER_HEADER_SIZE - 2..RECORD_LAYER_HEADER_SIZE] - ); - - let plain_text = ccm.decrypt(&cipher_text)?; - - assert_eq!( - raw[RECORD_LAYER_HEADER_SIZE..], - plain_text[RECORD_LAYER_HEADER_SIZE..], - "Decryption failed \nexp: {:?} \nactual {:?} ", - &raw[RECORD_LAYER_HEADER_SIZE..], - &plain_text[RECORD_LAYER_HEADER_SIZE..] 
- ); - - Ok(()) -} - -#[test] -fn test_certificate_verify() -> Result<()> { - let plain_text: Vec = vec![ - 0x6f, 0x47, 0x97, 0x85, 0xcc, 0x76, 0x50, 0x93, 0xbd, 0xe2, 0x6a, 0x69, 0x0b, 0xc3, 0x03, - 0xd1, 0xb7, 0xe4, 0xab, 0x88, 0x7b, 0xa6, 0x52, 0x80, 0xdf, 0xaa, 0x25, 0x7a, 0xdb, 0x29, - 0x32, 0xe4, 0xd8, 0x28, 0x28, 0xb3, 0xe8, 0x04, 0x3c, 0x38, 0x16, 0xfc, 0x78, 0xe9, 0x15, - 0x7b, 0xc5, 0xbd, 0x7d, 0xfc, 0xcd, 0x83, 0x00, 0x57, 0x4a, 0x3c, 0x23, 0x85, 0x75, 0x6b, - 0x37, 0xd5, 0x89, 0x72, 0x73, 0xf0, 0x44, 0x8c, 0x00, 0x70, 0x1f, 0x6e, 0xa2, 0x81, 0xd0, - 0x09, 0xc5, 0x20, 0x36, 0xab, 0x23, 0x09, 0x40, 0x1f, 0x4d, 0x45, 0x96, 0x62, 0xbb, 0x81, - 0xb0, 0x30, 0x72, 0xad, 0x3a, 0x0a, 0xac, 0x31, 0x63, 0x40, 0x52, 0x0a, 0x27, 0xf3, 0x34, - 0xde, 0x27, 0x7d, 0xb7, 0x54, 0xff, 0x0f, 0x9f, 0x5a, 0xfe, 0x07, 0x0f, 0x4e, 0x9f, 0x53, - 0x04, 0x34, 0x62, 0xf4, 0x30, 0x74, 0x83, 0x35, 0xfc, 0xe4, 0x7e, 0xbf, 0x5a, 0xc4, 0x52, - 0xd0, 0xea, 0xf9, 0x61, 0x4e, 0xf5, 0x1c, 0x0e, 0x58, 0x02, 0x71, 0xfb, 0x1f, 0x34, 0x55, - 0xe8, 0x36, 0x70, 0x3c, 0xc1, 0xcb, 0xc9, 0xb7, 0xbb, 0xb5, 0x1c, 0x44, 0x9a, 0x6d, 0x88, - 0x78, 0x98, 0xd4, 0x91, 0x2e, 0xeb, 0x98, 0x81, 0x23, 0x30, 0x73, 0x39, 0x43, 0xd5, 0xbb, - 0x70, 0x39, 0xba, 0x1f, 0xdb, 0x70, 0x9f, 0x91, 0x83, 0x56, 0xc2, 0xde, 0xed, 0x17, 0x6d, - 0x2c, 0x3e, 0x21, 0xea, 0x36, 0xb4, 0x91, 0xd8, 0x31, 0x05, 0x60, 0x90, 0xfd, 0xc6, 0x74, - 0xa9, 0x7b, 0x18, 0xfc, 0x1c, 0x6a, 0x1c, 0x6e, 0xec, 0xd3, 0xc1, 0xc0, 0x0d, 0x11, 0x25, - 0x48, 0x37, 0x3d, 0x45, 0x11, 0xa2, 0x31, 0x14, 0x0a, 0x66, 0x9f, 0xd8, 0xac, 0x74, 0xa2, - 0xcd, 0xc8, 0x79, 0xb3, 0x9e, 0xc6, 0x66, 0x25, 0xcf, 0x2c, 0x87, 0x5e, 0x5c, 0x36, 0x75, - 0x86, - ]; - - //test ECDSA256 - let certificate_ecdsa256 = Certificate::generate_self_signed(vec!["localhost".to_owned()])?; - let cert_verify_ecdsa256 = - generate_certificate_verify(&plain_text, &certificate_ecdsa256.private_key)?; - verify_certificate_verify( - &plain_text, - &SignatureHashAlgorithm { - hash: HashAlgorithm::Sha256, - signature: SignatureAlgorithm::Ecdsa, - }, - &cert_verify_ecdsa256, - &certificate_ecdsa256 - .certificate - .iter() - .map(|x| x.as_ref().to_owned()) - .collect::>>(), - false, - )?; - - //test ED25519 - let certificate_ed25519 = Certificate::generate_self_signed_with_alg( - vec!["localhost".to_owned()], - &rcgen::PKCS_ED25519, - )?; - let cert_verify_ed25519 = - generate_certificate_verify(&plain_text, &certificate_ed25519.private_key)?; - verify_certificate_verify( - &plain_text, - &SignatureHashAlgorithm { - hash: HashAlgorithm::Sha256, - signature: SignatureAlgorithm::Ed25519, - }, - &cert_verify_ed25519, - &certificate_ed25519 - .certificate - .iter() - .map(|x| x.as_ref().to_owned()) - .collect::>>(), - false, - )?; - - Ok(()) -} diff --git a/dtls/src/crypto/mod.rs b/dtls/src/crypto/mod.rs deleted file mode 100644 index 92bba0977..000000000 --- a/dtls/src/crypto/mod.rs +++ /dev/null @@ -1,528 +0,0 @@ -#[cfg(test)] -mod crypto_test; - -pub mod crypto_cbc; -pub mod crypto_ccm; -pub mod crypto_gcm; -pub mod padding; - -use std::convert::TryFrom; -use std::sync::Arc; - -use der_parser::oid; -use der_parser::oid::Oid; - -use rustls::client::danger::ServerCertVerifier; -use rustls::pki_types::{CertificateDer, ServerName}; -use rustls::server::danger::ClientCertVerifier; - -use rcgen::{generate_simple_self_signed, CertifiedKey, KeyPair}; -use ring::rand::SystemRandom; -use ring::signature::{EcdsaKeyPair, Ed25519KeyPair}; - -use crate::curve::named_curve::*; -use crate::error::*; -use 
crate::record_layer::record_layer_header::*; -use crate::signature_hash_algorithm::{HashAlgorithm, SignatureAlgorithm, SignatureHashAlgorithm}; - -/// A X.509 certificate(s) used to authenticate a DTLS connection. -#[derive(Clone, PartialEq, Debug)] -pub struct Certificate { - /// DER-encoded certificates. - pub certificate: Vec>, - /// Private key. - pub private_key: CryptoPrivateKey, -} - -impl Certificate { - /// Generate a self-signed certificate. - /// - /// See [`rcgen::generate_simple_self_signed`]. - pub fn generate_self_signed(subject_alt_names: impl Into>) -> Result { - let CertifiedKey { cert, key_pair } = - generate_simple_self_signed(subject_alt_names).unwrap(); - Ok(Certificate { - certificate: vec![cert.der().to_owned()], - private_key: CryptoPrivateKey::try_from(&key_pair)?, - }) - } - - /// Generate a self-signed certificate with the given algorithm. - /// - /// See [`rcgen::Certificate::from_params`]. - pub fn generate_self_signed_with_alg( - subject_alt_names: impl Into>, - alg: &'static rcgen::SignatureAlgorithm, - ) -> Result { - let params = rcgen::CertificateParams::new(subject_alt_names).unwrap(); - let key_pair = rcgen::KeyPair::generate_for(alg).unwrap(); - let cert = params.self_signed(&key_pair).unwrap(); - - Ok(Certificate { - certificate: vec![cert.der().to_owned()], - private_key: CryptoPrivateKey::try_from(&key_pair)?, - }) - } - - /// Parses a certificate from the ASCII PEM format. - #[cfg(feature = "pem")] - pub fn from_pem(pem_str: &str) -> Result { - let mut pems = pem::parse_many(pem_str).map_err(|e| Error::InvalidPEM(e.to_string()))?; - if pems.len() < 2 { - return Err(Error::InvalidPEM(format!( - "expected at least two PEM blocks, got {}", - pems.len() - ))); - } - if pems[0].tag() != "PRIVATE_KEY" { - return Err(Error::InvalidPEM(format!( - "invalid tag (expected: 'PRIVATE_KEY', got: '{}')", - pems[0].tag() - ))); - } - - let keypair = KeyPair::try_from(pems[0].contents()) - .map_err(|e| Error::InvalidPEM(format!("can't decode keypair: {e}")))?; - - let mut rustls_certs = Vec::new(); - for p in pems.drain(1..) { - if p.tag() != "CERTIFICATE" { - return Err(Error::InvalidPEM(format!( - "invalid tag (expected: 'CERTIFICATE', got: '{}')", - p.tag() - ))); - } - rustls_certs.push(CertificateDer::from(p.contents().to_vec())); - } - - Ok(Certificate { - certificate: rustls_certs, - private_key: CryptoPrivateKey::try_from(&keypair)?, - }) - } - - /// Serializes the certificate (including the private key) in PKCS#8 format in PEM. - #[cfg(feature = "pem")] - pub fn serialize_pem(&self) -> String { - let mut data = vec![pem::Pem::new( - "PRIVATE_KEY".to_string(), - self.private_key.serialized_der.clone(), - )]; - for rustls_cert in &self.certificate { - data.push(pem::Pem::new( - "CERTIFICATE".to_string(), - rustls_cert.as_ref(), - )); - } - pem::encode_many(&data) - } -} - -pub(crate) fn value_key_message( - client_random: &[u8], - server_random: &[u8], - public_key: &[u8], - named_curve: NamedCurve, -) -> Vec { - let mut server_ecdh_params = vec![0u8; 4]; - server_ecdh_params[0] = 3; // named curve - server_ecdh_params[1..3].copy_from_slice(&(named_curve as u16).to_be_bytes()); - server_ecdh_params[3] = public_key.len() as u8; - - let mut plaintext = vec![]; - plaintext.extend_from_slice(client_random); - plaintext.extend_from_slice(server_random); - plaintext.extend_from_slice(&server_ecdh_params); - plaintext.extend_from_slice(public_key); - - plaintext -} - -/// Either ED25519, ECDSA or RSA keypair. 
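// Added note (illustrative): the buffer built by `value_key_message` above is
// what generate_key_signature later signs for the ServerKeyExchange. With
// NamedCurve::X25519 (0x001d) and a 32-byte public key it lays out as
//
//   client_random (32) || server_random (32)
//   || 0x03          curve_type = named_curve
//   || 0x00 0x1d     curve id (X25519)
//   || 0x20          public key length (32)
//   || public_key (32)
//
// for a total of 32 + 32 + 4 + 32 = 100 bytes.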
-#[derive(Debug)] -pub enum CryptoPrivateKeyKind { - Ed25519(Ed25519KeyPair), - Ecdsa256(EcdsaKeyPair), - Rsa256(ring::rsa::KeyPair), -} - -/// Private key. -#[derive(Debug)] -pub struct CryptoPrivateKey { - /// Keypair. - pub kind: CryptoPrivateKeyKind, - /// DER-encoded keypair. - pub serialized_der: Vec, -} - -impl PartialEq for CryptoPrivateKey { - fn eq(&self, other: &Self) -> bool { - if self.serialized_der != other.serialized_der { - return false; - } - - matches!( - (&self.kind, &other.kind), - ( - CryptoPrivateKeyKind::Rsa256(_), - CryptoPrivateKeyKind::Rsa256(_) - ) | ( - CryptoPrivateKeyKind::Ecdsa256(_), - CryptoPrivateKeyKind::Ecdsa256(_) - ) | ( - CryptoPrivateKeyKind::Ed25519(_), - CryptoPrivateKeyKind::Ed25519(_) - ) - ) - } -} - -impl Clone for CryptoPrivateKey { - fn clone(&self) -> Self { - match self.kind { - CryptoPrivateKeyKind::Ed25519(_) => CryptoPrivateKey { - kind: CryptoPrivateKeyKind::Ed25519( - Ed25519KeyPair::from_pkcs8(&self.serialized_der).unwrap(), - ), - serialized_der: self.serialized_der.clone(), - }, - CryptoPrivateKeyKind::Ecdsa256(_) => CryptoPrivateKey { - kind: CryptoPrivateKeyKind::Ecdsa256( - EcdsaKeyPair::from_pkcs8( - &ring::signature::ECDSA_P256_SHA256_ASN1_SIGNING, - &self.serialized_der, - &SystemRandom::new(), - ) - .unwrap(), - ), - serialized_der: self.serialized_der.clone(), - }, - CryptoPrivateKeyKind::Rsa256(_) => CryptoPrivateKey { - kind: CryptoPrivateKeyKind::Rsa256( - ring::rsa::KeyPair::from_pkcs8(&self.serialized_der).unwrap(), - ), - serialized_der: self.serialized_der.clone(), - }, - } - } -} - -impl TryFrom<&KeyPair> for CryptoPrivateKey { - type Error = Error; - - fn try_from(key_pair: &KeyPair) -> Result { - Self::from_key_pair(key_pair) - } -} - -impl CryptoPrivateKey { - pub fn from_key_pair(key_pair: &KeyPair) -> Result { - let serialized_der = key_pair.serialize_der(); - if key_pair.is_compatible(&rcgen::PKCS_ED25519) { - Ok(CryptoPrivateKey { - kind: CryptoPrivateKeyKind::Ed25519( - Ed25519KeyPair::from_pkcs8(&serialized_der) - .map_err(|e| Error::Other(e.to_string()))?, - ), - serialized_der, - }) - } else if key_pair.is_compatible(&rcgen::PKCS_ECDSA_P256_SHA256) { - Ok(CryptoPrivateKey { - kind: CryptoPrivateKeyKind::Ecdsa256( - EcdsaKeyPair::from_pkcs8( - &ring::signature::ECDSA_P256_SHA256_ASN1_SIGNING, - &serialized_der, - &SystemRandom::new(), - ) - .map_err(|e| Error::Other(e.to_string()))?, - ), - serialized_der, - }) - } else if key_pair.is_compatible(&rcgen::PKCS_RSA_SHA256) { - Ok(CryptoPrivateKey { - kind: CryptoPrivateKeyKind::Rsa256( - ring::rsa::KeyPair::from_pkcs8(&serialized_der) - .map_err(|e| Error::Other(e.to_string()))?, - ), - serialized_der, - }) - } else { - Err(Error::Other("Unsupported key_pair".to_owned())) - } - } -} - -// If the client provided a "signature_algorithms" extension, then all -// certificates provided by the server MUST be signed by a -// hash/signature algorithm pair that appears in that extension -// -// https://tools.ietf.org/html/rfc5246#section-7.4.2 -pub(crate) fn generate_key_signature( - client_random: &[u8], - server_random: &[u8], - public_key: &[u8], - named_curve: NamedCurve, - private_key: &CryptoPrivateKey, /*, hash_algorithm: HashAlgorithm*/ -) -> Result> { - let msg = value_key_message(client_random, server_random, public_key, named_curve); - let signature = match &private_key.kind { - CryptoPrivateKeyKind::Ed25519(kp) => kp.sign(&msg).as_ref().to_vec(), - CryptoPrivateKeyKind::Ecdsa256(kp) => { - let system_random = SystemRandom::new(); - kp.sign(&system_random, 
&msg) - .map_err(|e| Error::Other(e.to_string()))? - .as_ref() - .to_vec() - } - CryptoPrivateKeyKind::Rsa256(kp) => { - let system_random = SystemRandom::new(); - let mut signature = vec![0; kp.public().modulus_len()]; - kp.sign( - &ring::signature::RSA_PKCS1_SHA256, - &system_random, - &msg, - &mut signature, - ) - .map_err(|e| Error::Other(e.to_string()))?; - - signature - } - }; - - Ok(signature) -} - -// add OID_ED25519 which is not defined in x509_parser -pub const OID_ED25519: Oid<'static> = oid!(1.3.101 .112); -pub const OID_ECDSA: Oid<'static> = oid!(1.2.840 .10045 .2 .1); - -fn verify_signature( - message: &[u8], - hash_algorithm: &SignatureHashAlgorithm, - remote_key_signature: &[u8], - raw_certificates: &[Vec], - insecure_verification: bool, -) -> Result<()> { - if raw_certificates.is_empty() { - return Err(Error::ErrLengthMismatch); - } - - let (_, certificate) = x509_parser::parse_x509_certificate(&raw_certificates[0]) - .map_err(|e| Error::Other(e.to_string()))?; - - let verify_alg: &dyn ring::signature::VerificationAlgorithm = match hash_algorithm.signature { - SignatureAlgorithm::Ed25519 => &ring::signature::ED25519, - SignatureAlgorithm::Ecdsa if hash_algorithm.hash == HashAlgorithm::Sha256 => { - &ring::signature::ECDSA_P256_SHA256_ASN1 - } - SignatureAlgorithm::Ecdsa if hash_algorithm.hash == HashAlgorithm::Sha384 => { - &ring::signature::ECDSA_P384_SHA384_ASN1 - } - SignatureAlgorithm::Rsa if hash_algorithm.hash == HashAlgorithm::Sha1 => { - &ring::signature::RSA_PKCS1_1024_8192_SHA1_FOR_LEGACY_USE_ONLY - } - SignatureAlgorithm::Rsa if (hash_algorithm.hash == HashAlgorithm::Sha256) => { - if remote_key_signature.len() < 256 && insecure_verification { - &ring::signature::RSA_PKCS1_1024_8192_SHA256_FOR_LEGACY_USE_ONLY - } else { - &ring::signature::RSA_PKCS1_2048_8192_SHA256 - } - } - SignatureAlgorithm::Rsa if hash_algorithm.hash == HashAlgorithm::Sha384 => { - &ring::signature::RSA_PKCS1_2048_8192_SHA384 - } - SignatureAlgorithm::Rsa if hash_algorithm.hash == HashAlgorithm::Sha512 => { - if remote_key_signature.len() < 256 && insecure_verification { - &ring::signature::RSA_PKCS1_1024_8192_SHA512_FOR_LEGACY_USE_ONLY - } else { - &ring::signature::RSA_PKCS1_2048_8192_SHA512 - } - } - _ => return Err(Error::ErrKeySignatureVerifyUnimplemented), - }; - - log::trace!("Picked an algorithm {:?}", verify_alg); - - let public_key = ring::signature::UnparsedPublicKey::new( - verify_alg, - certificate - .tbs_certificate - .subject_pki - .subject_public_key - .data, - ); - - public_key - .verify(message, remote_key_signature) - .map_err(|e| Error::Other(e.to_string()))?; - - Ok(()) -} - -pub(crate) fn verify_key_signature( - message: &[u8], - hash_algorithm: &SignatureHashAlgorithm, - remote_key_signature: &[u8], - raw_certificates: &[Vec], - insecure_verification: bool, -) -> Result<()> { - verify_signature( - message, - hash_algorithm, - remote_key_signature, - raw_certificates, - insecure_verification, - ) -} - -// If the server has sent a CertificateRequest message, the client MUST send the Certificate -// message. The ClientKeyExchange message is now sent, and the content -// of that message will depend on the public key algorithm selected -// between the ClientHello and the ServerHello. If the client has sent -// a certificate with signing ability, a digitally-signed -// CertificateVerify message is sent to explicitly verify possession of -// the private key in the certificate. 
-// https://tools.ietf.org/html/rfc5246#section-7.3 -pub(crate) fn generate_certificate_verify( - handshake_bodies: &[u8], - private_key: &CryptoPrivateKey, /*, hashAlgorithm hashAlgorithm*/ -) -> Result> { - let signature = match &private_key.kind { - CryptoPrivateKeyKind::Ed25519(kp) => kp.sign(handshake_bodies).as_ref().to_vec(), - CryptoPrivateKeyKind::Ecdsa256(kp) => { - let system_random = SystemRandom::new(); - kp.sign(&system_random, handshake_bodies) - .map_err(|e| Error::Other(e.to_string()))? - .as_ref() - .to_vec() - } - CryptoPrivateKeyKind::Rsa256(kp) => { - let system_random = SystemRandom::new(); - let mut signature = vec![0; kp.public().modulus_len()]; - kp.sign( - &ring::signature::RSA_PKCS1_SHA256, - &system_random, - handshake_bodies, - &mut signature, - ) - .map_err(|e| Error::Other(e.to_string()))?; - - signature - } - }; - - Ok(signature) -} - -pub(crate) fn verify_certificate_verify( - handshake_bodies: &[u8], - hash_algorithm: &SignatureHashAlgorithm, - remote_key_signature: &[u8], - raw_certificates: &[Vec], - insecure_verification: bool, -) -> Result<()> { - verify_signature( - handshake_bodies, - hash_algorithm, - remote_key_signature, - raw_certificates, - insecure_verification, - ) -} - -pub(crate) fn load_certs(raw_certificates: &[Vec]) -> Result>> { - if raw_certificates.is_empty() { - return Err(Error::ErrLengthMismatch); - } - - let mut certs = vec![]; - for raw_cert in raw_certificates { - let cert = CertificateDer::from(raw_cert.to_vec()); - certs.push(cert); - } - - Ok(certs) -} - -pub(crate) fn verify_client_cert( - raw_certificates: &[Vec], - cert_verifier: &Arc, -) -> Result>> { - let chains = load_certs(raw_certificates)?; - - let (end_entity, intermediates) = chains - .split_first() - .ok_or(Error::ErrClientCertificateRequired)?; - - match cert_verifier.verify_client_cert( - end_entity, - intermediates, - rustls::pki_types::UnixTime::now(), - ) { - Ok(_) => {} - Err(err) => return Err(Error::Other(err.to_string())), - }; - - Ok(chains) -} - -pub(crate) fn verify_server_cert( - raw_certificates: &[Vec], - cert_verifier: &Arc, - server_name: &str, -) -> Result>> { - let chains = load_certs(raw_certificates)?; - let server_name = match ServerName::try_from(server_name) { - Ok(server_name) => server_name, - Err(err) => return Err(Error::Other(err.to_string())), - }; - - let (end_entity, intermediates) = chains - .split_first() - .ok_or(Error::ErrServerMustHaveCertificate)?; - match cert_verifier.verify_server_cert( - end_entity, - intermediates, - &server_name, - &[], - rustls::pki_types::UnixTime::now(), - ) { - Ok(_) => {} - Err(err) => return Err(Error::Other(err.to_string())), - }; - - Ok(chains) -} - -pub(crate) fn generate_aead_additional_data(h: &RecordLayerHeader, payload_len: usize) -> Vec { - let mut additional_data = vec![0u8; 13]; - // SequenceNumber MUST be set first - // we only want uint48, clobbering an extra 2 (using uint64, rust doesn't have uint48) - additional_data[..8].copy_from_slice(&h.sequence_number.to_be_bytes()); - additional_data[..2].copy_from_slice(&h.epoch.to_be_bytes()); - additional_data[8] = h.content_type as u8; - additional_data[9] = h.protocol_version.major; - additional_data[10] = h.protocol_version.minor; - additional_data[11..].copy_from_slice(&(payload_len as u16).to_be_bytes()); - - additional_data -} - -#[cfg(test)] -mod test { - #[cfg(feature = "pem")] - use super::*; - - #[cfg(feature = "pem")] - #[test] - fn test_certificate_serialize_pem_and_from_pem() -> crate::error::Result<()> { - let cert = 
Certificate::generate_self_signed(vec!["webrtc.rs".to_owned()])?; - - let pem = cert.serialize_pem(); - let loaded_cert = Certificate::from_pem(&pem)?; - - assert_eq!(loaded_cert, cert); - - Ok(()) - } -} diff --git a/dtls/src/crypto/padding.rs b/dtls/src/crypto/padding.rs deleted file mode 100644 index 7c5c82d74..000000000 --- a/dtls/src/crypto/padding.rs +++ /dev/null @@ -1,122 +0,0 @@ -use cbc::cipher::block_padding::{PadType, RawPadding, UnpadError}; -use core::panic; - -pub enum DtlsPadding {} -/// Reference: RFC5246, 6.2.3.2 -impl RawPadding for DtlsPadding { - const TYPE: PadType = PadType::Reversible; - - fn raw_pad(block: &mut [u8], pos: usize) { - if pos >= block.len() { - panic!("`pos` is bigger or equal to block size"); - } - - let padding_length = block.len() - pos - 1; - if padding_length > 255 { - panic!("block size is too big for DTLS"); - } - - set(&mut block[pos..], padding_length as u8); - } - - fn raw_unpad(data: &[u8]) -> Result<&[u8], UnpadError> { - let padding_length = data.last().copied().unwrap_or(1) as usize; - if padding_length + 1 > data.len() { - return Err(UnpadError); - } - - let padding_begin = data.len() - padding_length - 1; - - if data[padding_begin..data.len() - 1] - .iter() - .any(|&byte| byte as usize != padding_length) - { - return Err(UnpadError); - } - - Ok(&data[0..padding_begin]) - } -} - -/// Sets all bytes in `dst` equal to `value` -#[inline(always)] -fn set(dst: &mut [u8], value: u8) { - // SAFETY: we overwrite valid memory behind `dst` - // note: loop is not used here because it produces - // unnecessary branch which tests for zero-length slices - unsafe { - core::ptr::write_bytes(dst.as_mut_ptr(), value, dst.len()); - } -} - -#[cfg(test)] -pub mod tests { - use rand::Rng; - - use super::*; - - #[test] - fn padding_length_is_amount_of_bytes_excluding_the_padding_length_itself() -> Result<(), ()> { - for original_length in 0..128 { - for padding_length in 0..(256 - original_length) { - let mut block = vec![0; original_length + padding_length + 1]; - rand::thread_rng().fill(&mut block[0..original_length]); - let original = block[0..original_length].to_vec(); - DtlsPadding::raw_pad(&mut block, original_length); - - for byte in block[original_length..].iter() { - assert_eq!(*byte as usize, padding_length); - } - assert_eq!(block[0..original_length], original); - } - } - - Ok(()) - } - - #[test] - #[should_panic] - fn full_block_is_padding_error() { - for original_length in 0..256 { - let mut block = vec![0; original_length]; - DtlsPadding::raw_pad(&mut block, original_length); - } - } - - #[test] - #[should_panic] - fn padding_length_bigger_than_255_is_a_pad_error() { - let padding_length = 256; - for original_length in 0..128 { - let mut block = vec![0; original_length + padding_length + 1]; - DtlsPadding::raw_pad(&mut block, original_length); - } - } - - #[test] - fn empty_block_is_unpadding_error() { - let r = DtlsPadding::raw_unpad(&[]); - assert!(r.is_err()); - } - - #[test] - fn padding_too_big_for_block_is_unpadding_error() { - let r = DtlsPadding::raw_unpad(&[1]); - assert!(r.is_err()); - } - - #[test] - fn one_of_the_padding_bytes_with_value_different_than_padding_length_is_unpadding_error() { - for padding_length in 0..16 { - for invalid_byte in 0..padding_length { - let mut block = vec![0; padding_length + 1]; - DtlsPadding::raw_pad(&mut block, 0); - - assert_eq!(DtlsPadding::raw_unpad(&block).ok(), Some(&[][..])); - block[invalid_byte] = (padding_length - 1) as u8; - let r = DtlsPadding::raw_unpad(&block); - assert!(r.is_err()); - } - 
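// Added note (illustrative, not one of the original tests): for a 13-byte
// plaintext padded to a 16-byte block, `raw_pad` is called with pos = 13, so
// padding_length = 16 - 13 - 1 = 2 and the final three bytes all carry the
// value 2:
//
//   [ .. 13 plaintext bytes .. | 0x02 0x02 0x02 ]
//
// `raw_unpad` then reads the last byte (2), verifies the two bytes before it
// are also 2, and strips all three.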
} - } -} diff --git a/dtls/src/curve/mod.rs b/dtls/src/curve/mod.rs deleted file mode 100644 index dd48d5ff2..000000000 --- a/dtls/src/curve/mod.rs +++ /dev/null @@ -1,17 +0,0 @@ -pub mod named_curve; - -// https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-10 -#[derive(Copy, Clone, PartialEq, Eq, Debug)] -pub enum EllipticCurveType { - NamedCurve = 0x03, - Unsupported, -} - -impl From for EllipticCurveType { - fn from(val: u8) -> Self { - match val { - 0x03 => EllipticCurveType::NamedCurve, - _ => EllipticCurveType::Unsupported, - } - } -} diff --git a/dtls/src/curve/named_curve.rs b/dtls/src/curve/named_curve.rs deleted file mode 100644 index 6d7d97f8e..000000000 --- a/dtls/src/curve/named_curve.rs +++ /dev/null @@ -1,83 +0,0 @@ -use rand_core::OsRng; // requires 'getrandom' feature - -use crate::error::*; - -// https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-8 -#[repr(u16)] -#[derive(Copy, Clone, PartialEq, Eq, Debug)] -pub enum NamedCurve { - Unsupported = 0x0000, - P256 = 0x0017, - P384 = 0x0018, - X25519 = 0x001d, -} - -impl From for NamedCurve { - fn from(val: u16) -> Self { - match val { - 0x0017 => NamedCurve::P256, - 0x0018 => NamedCurve::P384, - 0x001d => NamedCurve::X25519, - _ => NamedCurve::Unsupported, - } - } -} - -pub(crate) enum NamedCurvePrivateKey { - EphemeralSecretP256(p256::ecdh::EphemeralSecret), - EphemeralSecretP384(p384::ecdh::EphemeralSecret), - StaticSecretX25519(x25519_dalek::StaticSecret), -} - -pub struct NamedCurveKeypair { - pub(crate) curve: NamedCurve, - pub(crate) public_key: Vec, - pub(crate) private_key: NamedCurvePrivateKey, -} - -fn elliptic_curve_keypair(curve: NamedCurve) -> Result { - let (public_key, private_key) = match curve { - NamedCurve::P256 => { - let secret_key = p256::ecdh::EphemeralSecret::random(&mut OsRng); - let public_key = p256::EncodedPoint::from(secret_key.public_key()); - ( - public_key.as_bytes().to_vec(), - NamedCurvePrivateKey::EphemeralSecretP256(secret_key), - ) - } - NamedCurve::P384 => { - let secret_key = p384::ecdh::EphemeralSecret::random(&mut OsRng); - let public_key = p384::EncodedPoint::from(secret_key.public_key()); - ( - public_key.as_bytes().to_vec(), - NamedCurvePrivateKey::EphemeralSecretP384(secret_key), - ) - } - NamedCurve::X25519 => { - let secret_key = x25519_dalek::StaticSecret::random_from_rng(OsRng); - let public_key = x25519_dalek::PublicKey::from(&secret_key); - ( - public_key.as_bytes().to_vec(), - NamedCurvePrivateKey::StaticSecretX25519(secret_key), - ) - } - _ => return Err(Error::ErrInvalidNamedCurve), - }; - - Ok(NamedCurveKeypair { - curve, - public_key, - private_key, - }) -} - -impl NamedCurve { - pub fn generate_keypair(&self) -> Result { - match *self { - NamedCurve::X25519 => elliptic_curve_keypair(NamedCurve::X25519), - NamedCurve::P256 => elliptic_curve_keypair(NamedCurve::P256), - NamedCurve::P384 => elliptic_curve_keypair(NamedCurve::P384), - _ => Err(Error::ErrInvalidNamedCurve), - } - } -} diff --git a/dtls/src/error.rs b/dtls/src/error.rs deleted file mode 100644 index a90d9fa8b..000000000 --- a/dtls/src/error.rs +++ /dev/null @@ -1,223 +0,0 @@ -use std::io; -use std::string::FromUtf8Error; - -use thiserror::Error; -use tokio::sync::mpsc::error::SendError as MpscSendError; -use util::KeyingMaterialExporterError; - -pub type Result = std::result::Result; - -#[derive(Debug, Error, PartialEq)] -#[non_exhaustive] -pub enum Error { - #[error("conn is closed")] - ErrConnClosed, - #[error("read/write timeout")] - 
ErrDeadlineExceeded, - #[error("buffer is too small")] - ErrBufferTooSmall, - #[error("context is not supported for export_keying_material")] - ErrContextUnsupported, - #[error("packet is too short")] - ErrDtlspacketInvalidLength, - #[error("handshake is in progress")] - ErrHandshakeInProgress, - #[error("invalid content type")] - ErrInvalidContentType, - #[error("invalid mac")] - ErrInvalidMac, - #[error("packet length and declared length do not match")] - ErrInvalidPacketLength, - #[error("export_keying_material can not be used with a reserved label")] - ErrReservedExportKeyingMaterial, - #[error("client sent certificate verify but we have no certificate to verify")] - ErrCertificateVerifyNoCertificate, - #[error("client+server do not support any shared cipher suites")] - ErrCipherSuiteNoIntersection, - #[error("server hello can not be created without a cipher suite")] - ErrCipherSuiteUnset, - #[error("client sent certificate but did not verify it")] - ErrClientCertificateNotVerified, - #[error("server required client verification, but got none")] - ErrClientCertificateRequired, - #[error("server responded with SRTP Profile we do not support")] - ErrClientNoMatchingSrtpProfile, - #[error("client required Extended Master Secret extension, but server does not support it")] - ErrClientRequiredButNoServerEms, - #[error("server hello can not be created without a compression method")] - ErrCompressionMethodUnset, - #[error("client+server cookie does not match")] - ErrCookieMismatch, - #[error("cookie must not be longer then 255 bytes")] - ErrCookieTooLong, - #[error("PSK Identity Hint provided but PSK is nil")] - ErrIdentityNoPsk, - #[error("no certificate provided")] - ErrInvalidCertificate, - #[error("cipher spec invalid")] - ErrInvalidCipherSpec, - #[error("invalid or unknown cipher suite")] - ErrInvalidCipherSuite, - #[error("unable to determine if ClientKeyExchange is a public key or PSK Identity")] - ErrInvalidClientKeyExchange, - #[error("invalid or unknown compression method")] - ErrInvalidCompressionMethod, - #[error("ECDSA signature contained zero or negative values")] - ErrInvalidEcdsasignature, - #[error("invalid or unknown elliptic curve type")] - ErrInvalidEllipticCurveType, - #[error("invalid extension type")] - ErrInvalidExtensionType, - #[error("invalid hash algorithm")] - ErrInvalidHashAlgorithm, - #[error("invalid named curve")] - ErrInvalidNamedCurve, - #[error("invalid private key type")] - ErrInvalidPrivateKey, - #[error("named curve and private key type does not match")] - ErrNamedCurveAndPrivateKeyMismatch, - #[error("invalid server name format")] - ErrInvalidSniFormat, - #[error("invalid signature algorithm")] - ErrInvalidSignatureAlgorithm, - #[error("expected and actual key signature do not match")] - ErrKeySignatureMismatch, - #[error("Conn can not be created with a nil nextConn")] - ErrNilNextConn, - #[error("connection can not be created, no CipherSuites satisfy this Config")] - ErrNoAvailableCipherSuites, - #[error("connection can not be created, no SignatureScheme satisfy this Config")] - ErrNoAvailableSignatureSchemes, - #[error("no certificates configured")] - ErrNoCertificates, - #[error("no config provided")] - ErrNoConfigProvided, - #[error("client requested zero or more elliptic curves that are not supported by the server")] - ErrNoSupportedEllipticCurves, - #[error("unsupported protocol version")] - ErrUnsupportedProtocolVersion, - #[error("Certificate and PSK provided")] - ErrPskAndCertificate, - #[error("PSK and PSK Identity Hint must both be set for 
client")] - ErrPskAndIdentityMustBeSetForClient, - #[error("SRTP support was requested but server did not respond with use_srtp extension")] - ErrRequestedButNoSrtpExtension, - #[error("Certificate is mandatory for server")] - ErrServerMustHaveCertificate, - #[error("client requested SRTP but we have no matching profiles")] - ErrServerNoMatchingSrtpProfile, - #[error( - "server requires the Extended Master Secret extension, but the client does not support it" - )] - ErrServerRequiredButNoClientEms, - #[error("expected and actual verify data does not match")] - ErrVerifyDataMismatch, - #[error("handshake message unset, unable to marshal")] - ErrHandshakeMessageUnset, - #[error("invalid flight number")] - ErrInvalidFlight, - #[error("unable to generate key signature, unimplemented")] - ErrKeySignatureGenerateUnimplemented, - #[error("unable to verify key signature, unimplemented")] - ErrKeySignatureVerifyUnimplemented, - #[error("data length and declared length do not match")] - ErrLengthMismatch, - #[error("buffer not long enough to contain nonce")] - ErrNotEnoughRoomForNonce, - #[error("feature has not been implemented yet")] - ErrNotImplemented, - #[error("sequence number overflow")] - ErrSequenceNumberOverflow, - #[error("unable to marshal fragmented handshakes")] - ErrUnableToMarshalFragmented, - #[error("invalid state machine transition")] - ErrInvalidFsmTransition, - #[error("ApplicationData with epoch of 0")] - ErrApplicationDataEpochZero, - #[error("unhandled contentType")] - ErrUnhandledContextType, - #[error("context canceled")] - ErrContextCanceled, - #[error("empty fragment")] - ErrEmptyFragment, - #[error("Alert is Fatal or Close Notify")] - ErrAlertFatalOrClose, - - #[error( - "Fragment buffer overflow. New size {new_size} is greater than specified max {max_size}" - )] - ErrFragmentBufferOverflow { new_size: usize, max_size: usize }, - - #[error("{0}")] - Io(#[source] IoError), - #[error("{0}")] - Util(#[from] util::Error), - #[error("utf8: {0}")] - Utf8(#[from] FromUtf8Error), - #[error("{0}")] - Sec1(#[source] sec1::Error), - #[error("{0}")] - Aes(#[from] aes::cipher::InvalidLength), - #[error("{0}")] - P256(#[source] P256Error), - #[error("{0}")] - RcGen(#[from] rcgen::Error), - #[error("mpsc send: {0}")] - MpscSend(String), - #[error("keying material: {0}")] - KeyingMaterial(#[from] KeyingMaterialExporterError), - - /// Error parsing a given PEM string. - #[error("invalid PEM: {0}")] - InvalidPEM(String), - - #[allow(non_camel_case_types)] - #[error("{0}")] - Other(String), -} - -#[derive(Debug, Error)] -#[error("io error: {0}")] -pub struct IoError(#[from] pub io::Error); - -// Workaround for wanting PartialEq for io::Error. -impl PartialEq for IoError { - fn eq(&self, other: &Self) -> bool { - self.0.kind() == other.0.kind() - } -} - -impl From for Error { - fn from(e: io::Error) -> Self { - Error::Io(IoError(e)) - } -} - -impl From for Error { - fn from(e: sec1::Error) -> Self { - Error::Sec1(e) - } -} - -#[derive(Debug, Error)] -#[error("{0}")] -pub struct P256Error(#[source] p256::elliptic_curve::Error); - -impl PartialEq for P256Error { - fn eq(&self, _: &Self) -> bool { - false - } -} - -impl From for Error { - fn from(e: p256::elliptic_curve::Error) -> Self { - Error::P256(P256Error(e)) - } -} - -// Because Tokio SendError is parameterized, we sadly lose the backtrace. 
-impl From> for Error { - fn from(e: MpscSendError) -> Self { - Error::MpscSend(e.to_string()) - } -} diff --git a/dtls/src/extension/extension_server_name.rs b/dtls/src/extension/extension_server_name.rs deleted file mode 100644 index dbe018e9f..000000000 --- a/dtls/src/extension/extension_server_name.rs +++ /dev/null @@ -1,56 +0,0 @@ -#[cfg(test)] -mod extension_server_name_test; - -use std::io::{Read, Write}; - -use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; - -use super::*; - -const EXTENSION_SERVER_NAME_TYPE_DNSHOST_NAME: u8 = 0; - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct ExtensionServerName { - pub(crate) server_name: String, -} - -impl ExtensionServerName { - pub fn extension_value(&self) -> ExtensionValue { - ExtensionValue::ServerName - } - - pub fn size(&self) -> usize { - //TODO: check how to do cryptobyte? - 2 + 2 + 1 + 2 + self.server_name.as_bytes().len() - } - - pub fn marshal(&self, writer: &mut W) -> Result<()> { - //TODO: check how to do cryptobyte? - writer.write_u16::(2 + 1 + 2 + self.server_name.len() as u16)?; - writer.write_u16::(1 + 2 + self.server_name.len() as u16)?; - writer.write_u8(EXTENSION_SERVER_NAME_TYPE_DNSHOST_NAME)?; - writer.write_u16::(self.server_name.len() as u16)?; - writer.write_all(self.server_name.as_bytes())?; - - Ok(writer.flush()?) - } - - pub fn unmarshal(reader: &mut R) -> Result { - //TODO: check how to do cryptobyte? - let _ = reader.read_u16::()? as usize; - let _ = reader.read_u16::()? as usize; - - let name_type = reader.read_u8()?; - if name_type != EXTENSION_SERVER_NAME_TYPE_DNSHOST_NAME { - return Err(Error::ErrInvalidSniFormat); - } - - let buf_len = reader.read_u16::()? as usize; - let mut buf: Vec = vec![0u8; buf_len]; - reader.read_exact(&mut buf)?; - - let server_name = String::from_utf8(buf)?; - - Ok(ExtensionServerName { server_name }) - } -} diff --git a/dtls/src/extension/extension_server_name/extension_server_name_test.rs b/dtls/src/extension/extension_server_name/extension_server_name_test.rs deleted file mode 100644 index 4cfb7a1a1..000000000 --- a/dtls/src/extension/extension_server_name/extension_server_name_test.rs +++ /dev/null @@ -1,26 +0,0 @@ -use std::io::{BufReader, BufWriter}; - -use super::*; - -#[test] -fn test_extension_server_name() -> Result<()> { - let extension = ExtensionServerName { - server_name: "test.domain".to_owned(), - }; - - let mut raw = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(raw.as_mut()); - extension.marshal(&mut writer)?; - } - - let mut reader = BufReader::new(raw.as_slice()); - let new_extension = ExtensionServerName::unmarshal(&mut reader)?; - - assert_eq!( - new_extension, extension, - "extensionServerName marshal: got {new_extension:?} expected {extension:?}", - ); - - Ok(()) -} diff --git a/dtls/src/extension/extension_supported_elliptic_curves.rs b/dtls/src/extension/extension_supported_elliptic_curves.rs deleted file mode 100644 index 88caf6f76..000000000 --- a/dtls/src/extension/extension_supported_elliptic_curves.rs +++ /dev/null @@ -1,46 +0,0 @@ -#[cfg(test)] -mod extension_supported_elliptic_curves_test; - -use super::*; -use crate::curve::named_curve::*; - -const EXTENSION_SUPPORTED_GROUPS_HEADER_SIZE: usize = 6; - -// https://tools.ietf.org/html/rfc8422#section-5.1.1 -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct ExtensionSupportedEllipticCurves { - pub elliptic_curves: Vec, -} - -impl ExtensionSupportedEllipticCurves { - pub fn extension_value(&self) -> ExtensionValue { - ExtensionValue::SupportedEllipticCurves - } - - pub fn 
size(&self) -> usize { - 2 + 2 + self.elliptic_curves.len() * 2 - } - - pub fn marshal(&self, writer: &mut W) -> Result<()> { - writer.write_u16::(2 + 2 * self.elliptic_curves.len() as u16)?; - writer.write_u16::(2 * self.elliptic_curves.len() as u16)?; - for v in &self.elliptic_curves { - writer.write_u16::(*v as u16)?; - } - - Ok(writer.flush()?) - } - - pub fn unmarshal(reader: &mut R) -> Result { - let _ = reader.read_u16::()?; - - let group_count = reader.read_u16::()? as usize / 2; - let mut elliptic_curves = vec![]; - for _ in 0..group_count { - let elliptic_curve = reader.read_u16::()?.into(); - elliptic_curves.push(elliptic_curve); - } - - Ok(ExtensionSupportedEllipticCurves { elliptic_curves }) - } -} diff --git a/dtls/src/extension/extension_supported_elliptic_curves/extension_supported_elliptic_curves_test.rs b/dtls/src/extension/extension_supported_elliptic_curves/extension_supported_elliptic_curves_test.rs deleted file mode 100644 index 2bcf70c17..000000000 --- a/dtls/src/extension/extension_supported_elliptic_curves/extension_supported_elliptic_curves_test.rs +++ /dev/null @@ -1,32 +0,0 @@ -use std::io::{BufReader, BufWriter}; - -use super::*; - -#[test] -fn test_extension_supported_groups() -> Result<()> { - let raw_supported_groups = vec![0x0, 0x4, 0x0, 0x2, 0x0, 0x1d]; // 0x0, 0xa, - let parsed_supported_groups = ExtensionSupportedEllipticCurves { - elliptic_curves: vec![NamedCurve::X25519], - }; - - let mut raw = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(raw.as_mut()); - parsed_supported_groups.marshal(&mut writer)?; - } - - assert_eq!( - raw, raw_supported_groups, - "extensionSupportedGroups marshal: got {raw:?}, want {raw_supported_groups:?}" - ); - - let mut reader = BufReader::new(raw.as_slice()); - let new_supported_groups = ExtensionSupportedEllipticCurves::unmarshal(&mut reader)?; - - assert_eq!( - new_supported_groups, parsed_supported_groups, - "extensionSupportedGroups unmarshal: got {new_supported_groups:?}, want {parsed_supported_groups:?}" - ); - - Ok(()) -} diff --git a/dtls/src/extension/extension_supported_point_formats.rs b/dtls/src/extension/extension_supported_point_formats.rs deleted file mode 100644 index 17e2448af..000000000 --- a/dtls/src/extension/extension_supported_point_formats.rs +++ /dev/null @@ -1,49 +0,0 @@ -#[cfg(test)] -mod extension_supported_point_formats_test; - -use super::*; - -const EXTENSION_SUPPORTED_POINT_FORMATS_SIZE: usize = 5; - -pub type EllipticCurvePointFormat = u8; - -pub const ELLIPTIC_CURVE_POINT_FORMAT_UNCOMPRESSED: EllipticCurvePointFormat = 0; - -// https://tools.ietf.org/html/rfc4492#section-5.1.2 -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct ExtensionSupportedPointFormats { - pub(crate) point_formats: Vec, -} - -impl ExtensionSupportedPointFormats { - pub fn extension_value(&self) -> ExtensionValue { - ExtensionValue::SupportedPointFormats - } - - pub fn size(&self) -> usize { - 2 + 1 + self.point_formats.len() - } - - pub fn marshal(&self, writer: &mut W) -> Result<()> { - writer.write_u16::(1 + self.point_formats.len() as u16)?; - writer.write_u8(self.point_formats.len() as u8)?; - for v in &self.point_formats { - writer.write_u8(*v)?; - } - - Ok(writer.flush()?) - } - - pub fn unmarshal(reader: &mut R) -> Result { - let _ = reader.read_u16::()?; - - let point_format_count = reader.read_u8()? 
as usize; - let mut point_formats = vec![]; - for _ in 0..point_format_count { - let point_format = reader.read_u8()?; - point_formats.push(point_format); - } - - Ok(ExtensionSupportedPointFormats { point_formats }) - } -} diff --git a/dtls/src/extension/extension_supported_point_formats/extension_supported_point_formats_test.rs b/dtls/src/extension/extension_supported_point_formats/extension_supported_point_formats_test.rs deleted file mode 100644 index e624a98a1..000000000 --- a/dtls/src/extension/extension_supported_point_formats/extension_supported_point_formats_test.rs +++ /dev/null @@ -1,33 +0,0 @@ -use std::io::{BufReader, BufWriter}; - -use super::*; - -#[test] -fn test_extension_supported_point_formats() -> Result<()> { - let raw_extension_supported_point_formats = vec![0x00, 0x02, 0x01, 0x00]; // 0x00, 0x0b, - let parsed_extension_supported_point_formats = ExtensionSupportedPointFormats { - point_formats: vec![ELLIPTIC_CURVE_POINT_FORMAT_UNCOMPRESSED], - }; - - let mut raw = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(raw.as_mut()); - parsed_extension_supported_point_formats.marshal(&mut writer)?; - } - - assert_eq!( - raw, raw_extension_supported_point_formats, - "extensionSupportedPointFormats marshal: got {raw:?}, want {raw_extension_supported_point_formats:?}" - ); - - let mut reader = BufReader::new(raw.as_slice()); - let new_extension_supported_point_formats = - ExtensionSupportedPointFormats::unmarshal(&mut reader)?; - - assert_eq!( - new_extension_supported_point_formats, parsed_extension_supported_point_formats, - "extensionSupportedPointFormats unmarshal: got {new_extension_supported_point_formats:?}, want {parsed_extension_supported_point_formats:?}" - ); - - Ok(()) -} diff --git a/dtls/src/extension/extension_supported_signature_algorithms.rs b/dtls/src/extension/extension_supported_signature_algorithms.rs deleted file mode 100644 index e15210b2e..000000000 --- a/dtls/src/extension/extension_supported_signature_algorithms.rs +++ /dev/null @@ -1,50 +0,0 @@ -#[cfg(test)] -mod extension_supported_signature_algorithms_test; - -use super::*; -use crate::signature_hash_algorithm::*; - -const EXTENSION_SUPPORTED_SIGNATURE_ALGORITHMS_HEADER_SIZE: usize = 6; - -// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1 -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct ExtensionSupportedSignatureAlgorithms { - pub(crate) signature_hash_algorithms: Vec, -} - -impl ExtensionSupportedSignatureAlgorithms { - pub fn extension_value(&self) -> ExtensionValue { - ExtensionValue::SupportedSignatureAlgorithms - } - - pub fn size(&self) -> usize { - 2 + 2 + self.signature_hash_algorithms.len() * 2 - } - - pub fn marshal(&self, writer: &mut W) -> Result<()> { - writer.write_u16::(2 + 2 * self.signature_hash_algorithms.len() as u16)?; - writer.write_u16::(2 * self.signature_hash_algorithms.len() as u16)?; - for v in &self.signature_hash_algorithms { - writer.write_u8(v.hash as u8)?; - writer.write_u8(v.signature as u8)?; - } - - Ok(writer.flush()?) - } - - pub fn unmarshal(reader: &mut R) -> Result { - let _ = reader.read_u16::()?; - - let algorithm_count = reader.read_u16::()? 
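// A minimal standalone sketch (not crate API) of the ec_point_formats body the
// marshal above produces: a u16 body length, a u8 format count, then one byte
// per format (0 = uncompressed). The result matches the
// raw_extension_supported_point_formats test vector.
fn encode_point_formats_body(formats: &[u8]) -> Vec<u8> {
    let mut out = Vec::new();
    out.extend_from_slice(&((1 + formats.len()) as u16).to_be_bytes());
    out.push(formats.len() as u8);
    out.extend_from_slice(formats);
    out
}

#[test]
fn sketch_point_formats_encoding() {
    assert_eq!(
        encode_point_formats_body(&[0x00]), // uncompressed only
        vec![0x00, 0x02, 0x01, 0x00]
    );
}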
as usize / 2; - let mut signature_hash_algorithms = vec![]; - for _ in 0..algorithm_count { - let hash = reader.read_u8()?.into(); - let signature = reader.read_u8()?.into(); - signature_hash_algorithms.push(SignatureHashAlgorithm { hash, signature }); - } - - Ok(ExtensionSupportedSignatureAlgorithms { - signature_hash_algorithms, - }) - } -} diff --git a/dtls/src/extension/extension_supported_signature_algorithms/extension_supported_signature_algorithms_test.rs b/dtls/src/extension/extension_supported_signature_algorithms/extension_supported_signature_algorithms_test.rs deleted file mode 100644 index e531fee03..000000000 --- a/dtls/src/extension/extension_supported_signature_algorithms/extension_supported_signature_algorithms_test.rs +++ /dev/null @@ -1,48 +0,0 @@ -use std::io::{BufReader, BufWriter}; - -use super::*; - -#[test] -fn test_extension_supported_signature_algorithms() -> Result<()> { - let raw_extension_supported_signature_algorithms = - vec![0x00, 0x08, 0x00, 0x06, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03]; //0x00, 0x0d, - let parsed_extension_supported_signature_algorithms = ExtensionSupportedSignatureAlgorithms { - signature_hash_algorithms: vec![ - SignatureHashAlgorithm { - hash: HashAlgorithm::Sha256, - signature: SignatureAlgorithm::Ecdsa, - }, - SignatureHashAlgorithm { - hash: HashAlgorithm::Sha384, - signature: SignatureAlgorithm::Ecdsa, - }, - SignatureHashAlgorithm { - hash: HashAlgorithm::Sha512, - signature: SignatureAlgorithm::Ecdsa, - }, - ], - }; - - let mut raw = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(raw.as_mut()); - parsed_extension_supported_signature_algorithms.marshal(&mut writer)?; - } - - assert_eq!( - raw, raw_extension_supported_signature_algorithms, - "extensionSupportedSignatureAlgorithms marshal: got {raw:?}, want {raw_extension_supported_signature_algorithms:?}" - ); - - let mut reader = BufReader::new(raw.as_slice()); - let new_extension_supported_signature_algorithms = - ExtensionSupportedSignatureAlgorithms::unmarshal(&mut reader)?; - - assert_eq!( - new_extension_supported_signature_algorithms, - parsed_extension_supported_signature_algorithms, - "extensionSupportedSignatureAlgorithms unmarshal: got {new_extension_supported_signature_algorithms:?}, want {parsed_extension_supported_signature_algorithms:?}" - ); - - Ok(()) -} diff --git a/dtls/src/extension/extension_use_extended_master_secret.rs b/dtls/src/extension/extension_use_extended_master_secret.rs deleted file mode 100644 index 9d62c4eaa..000000000 --- a/dtls/src/extension/extension_use_extended_master_secret.rs +++ /dev/null @@ -1,35 +0,0 @@ -#[cfg(test)] -mod extension_use_extended_master_secret_test; - -use super::*; - -const EXTENSION_USE_EXTENDED_MASTER_SECRET_HEADER_SIZE: usize = 4; - -// https://tools.ietf.org/html/rfc8422 -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct ExtensionUseExtendedMasterSecret { - pub(crate) supported: bool, -} - -impl ExtensionUseExtendedMasterSecret { - pub fn extension_value(&self) -> ExtensionValue { - ExtensionValue::UseExtendedMasterSecret - } - - pub fn size(&self) -> usize { - 2 - } - - pub fn marshal(&self, writer: &mut W) -> Result<()> { - // length - writer.write_u16::(0)?; - - Ok(writer.flush()?) 
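// A minimal standalone sketch (not crate API) of the signature_algorithms body
// the marshal above produces: a u16 body length, a u16 list length, then one
// (hash, signature) byte pair per algorithm. 0x04/0x05/0x06 are SHA-256/384/512
// and 0x03 is ECDSA, reproducing the
// raw_extension_supported_signature_algorithms test vector.
fn encode_signature_algorithms_body(pairs: &[(u8, u8)]) -> Vec<u8> {
    let mut out = Vec::new();
    out.extend_from_slice(&((2 + 2 * pairs.len()) as u16).to_be_bytes());
    out.extend_from_slice(&((2 * pairs.len()) as u16).to_be_bytes());
    for (hash, signature) in pairs {
        out.push(*hash);
        out.push(*signature);
    }
    out
}

#[test]
fn sketch_signature_algorithms_encoding() {
    assert_eq!(
        encode_signature_algorithms_body(&[(0x04, 0x03), (0x05, 0x03), (0x06, 0x03)]),
        vec![0x00, 0x08, 0x00, 0x06, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03]
    );
}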
- } - - pub fn unmarshal(reader: &mut R) -> Result { - let _ = reader.read_u16::()?; - - Ok(ExtensionUseExtendedMasterSecret { supported: true }) - } -} diff --git a/dtls/src/extension/extension_use_extended_master_secret/extension_use_extended_master_secret_test.rs b/dtls/src/extension/extension_use_extended_master_secret/extension_use_extended_master_secret_test.rs deleted file mode 100644 index d2b0dd424..000000000 --- a/dtls/src/extension/extension_use_extended_master_secret/extension_use_extended_master_secret_test.rs +++ /dev/null @@ -1,32 +0,0 @@ -use std::io::{BufReader, BufWriter}; - -use super::*; - -#[test] -fn test_extension_use_extended_master_secret() -> Result<()> { - let raw_extension_use_extended_master_secret = vec![0x00, 0x00]; - let parsed_extension_use_extended_master_secret = - ExtensionUseExtendedMasterSecret { supported: true }; - - let mut raw = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(raw.as_mut()); - parsed_extension_use_extended_master_secret.marshal(&mut writer)?; - } - - assert_eq!( - raw, raw_extension_use_extended_master_secret, - "extension_use_extended_master_secret marshal: got {raw:?}, want {raw_extension_use_extended_master_secret:?}" - ); - - let mut reader = BufReader::new(raw.as_slice()); - let new_extension_use_extended_master_secret = - ExtensionUseExtendedMasterSecret::unmarshal(&mut reader)?; - - assert_eq!( - new_extension_use_extended_master_secret, parsed_extension_use_extended_master_secret, - "extension_use_extended_master_secret unmarshal: got {new_extension_use_extended_master_secret:?}, want {parsed_extension_use_extended_master_secret:?}" - ); - - Ok(()) -} diff --git a/dtls/src/extension/extension_use_srtp.rs b/dtls/src/extension/extension_use_srtp.rs deleted file mode 100644 index b8620d7ea..000000000 --- a/dtls/src/extension/extension_use_srtp.rs +++ /dev/null @@ -1,80 +0,0 @@ -#[cfg(test)] -mod extension_use_srtp_test; - -use super::*; - -// SRTPProtectionProfile defines the parameters and options that are in effect for the SRTP processing -// https://tools.ietf.org/html/rfc5764#section-4.1.2 -#[allow(non_camel_case_types)] -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum SrtpProtectionProfile { - Srtp_Aes128_Cm_Hmac_Sha1_80 = 0x0001, - Srtp_Aes128_Cm_Hmac_Sha1_32 = 0x0002, - Srtp_Aead_Aes_128_Gcm = 0x0007, - Srtp_Aead_Aes_256_Gcm = 0x0008, - Unsupported, -} - -impl From for SrtpProtectionProfile { - fn from(val: u16) -> Self { - match val { - 0x0001 => SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_80, - 0x0002 => SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_32, - 0x0007 => SrtpProtectionProfile::Srtp_Aead_Aes_128_Gcm, - 0x0008 => SrtpProtectionProfile::Srtp_Aead_Aes_256_Gcm, - _ => SrtpProtectionProfile::Unsupported, - } - } -} - -const EXTENSION_USE_SRTPHEADER_SIZE: usize = 6; - -// https://tools.ietf.org/html/rfc8422 -#[allow(non_camel_case_types)] -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct ExtensionUseSrtp { - pub(crate) protection_profiles: Vec, -} - -impl ExtensionUseSrtp { - pub fn extension_value(&self) -> ExtensionValue { - ExtensionValue::UseSrtp - } - - pub fn size(&self) -> usize { - 2 + 2 + self.protection_profiles.len() * 2 + 1 - } - - pub fn marshal(&self, writer: &mut W) -> Result<()> { - writer.write_u16::( - 2 + /* MKI Length */ 1 + 2 * self.protection_profiles.len() as u16, - )?; - writer.write_u16::(2 * self.protection_profiles.len() as u16)?; - for v in &self.protection_profiles { - writer.write_u16::(*v as u16)?; - } - - /* MKI Length */ - writer.write_u8(0x00)?; - - 
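// A minimal standalone sketch (not crate API) of the use_srtp body the marshal
// above produces: a u16 body length, a u16 profile-list length, one u16 per SRTP
// protection profile, then a one-byte MKI length of zero. 0x0001 is
// SRTP_AES128_CM_HMAC_SHA1_80, matching the raw_use_srtp test vector below.
fn encode_use_srtp_body(profiles: &[u16]) -> Vec<u8> {
    let mut out = Vec::new();
    out.extend_from_slice(&((2 + 1 + 2 * profiles.len()) as u16).to_be_bytes());
    out.extend_from_slice(&((2 * profiles.len()) as u16).to_be_bytes());
    for profile in profiles {
        out.extend_from_slice(&profile.to_be_bytes());
    }
    out.push(0x00); // MKI length: no master key identifier
    out
}

#[test]
fn sketch_use_srtp_encoding() {
    assert_eq!(
        encode_use_srtp_body(&[0x0001]),
        vec![0x00, 0x05, 0x00, 0x02, 0x00, 0x01, 0x00]
    );
}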
Ok(writer.flush()?) - } - - pub fn unmarshal(reader: &mut R) -> Result { - let _ = reader.read_u16::()?; - - let profile_count = reader.read_u16::()? as usize / 2; - let mut protection_profiles = vec![]; - for _ in 0..profile_count { - let protection_profile = reader.read_u16::()?.into(); - protection_profiles.push(protection_profile); - } - - /* MKI Length */ - let _ = reader.read_u8()?; - - Ok(ExtensionUseSrtp { - protection_profiles, - }) - } -} diff --git a/dtls/src/extension/extension_use_srtp/extension_use_srtp_test.rs b/dtls/src/extension/extension_use_srtp/extension_use_srtp_test.rs deleted file mode 100644 index c58ba56a9..000000000 --- a/dtls/src/extension/extension_use_srtp/extension_use_srtp_test.rs +++ /dev/null @@ -1,32 +0,0 @@ -use std::io::{BufReader, BufWriter}; - -use super::*; - -#[test] -fn test_extension_use_srtp() -> Result<()> { - let raw_use_srtp = vec![0x00, 0x05, 0x00, 0x02, 0x00, 0x01, 0x00]; //0x00, 0x0e, - let parsed_use_srtp = ExtensionUseSrtp { - protection_profiles: vec![SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_80], - }; - - let mut raw = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(raw.as_mut()); - parsed_use_srtp.marshal(&mut writer)?; - } - - assert_eq!( - raw, raw_use_srtp, - "extensionUseSRTP marshal: got {raw:?}, want {raw_use_srtp:?}" - ); - - let mut reader = BufReader::new(raw.as_slice()); - let new_use_srtp = ExtensionUseSrtp::unmarshal(&mut reader)?; - - assert_eq!( - new_use_srtp, parsed_use_srtp, - "extensionUseSRTP unmarshal: got {new_use_srtp:?}, want {parsed_use_srtp:?}" - ); - - Ok(()) -} diff --git a/dtls/src/extension/mod.rs b/dtls/src/extension/mod.rs deleted file mode 100644 index 203003810..000000000 --- a/dtls/src/extension/mod.rs +++ /dev/null @@ -1,130 +0,0 @@ -pub mod extension_server_name; -pub mod extension_supported_elliptic_curves; -pub mod extension_supported_point_formats; -pub mod extension_supported_signature_algorithms; -pub mod extension_use_extended_master_secret; -pub mod extension_use_srtp; -pub mod renegotiation_info; - -use std::io::{Read, Write}; - -use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; -use extension_server_name::*; -use extension_supported_elliptic_curves::*; -use extension_supported_point_formats::*; -use extension_supported_signature_algorithms::*; -use extension_use_extended_master_secret::*; -use extension_use_srtp::*; - -use crate::error::*; -use crate::extension::renegotiation_info::ExtensionRenegotiationInfo; - -// https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum ExtensionValue { - ServerName = 0, - SupportedEllipticCurves = 10, - SupportedPointFormats = 11, - SupportedSignatureAlgorithms = 13, - UseSrtp = 14, - UseExtendedMasterSecret = 23, - RenegotiationInfo = 65281, - Unsupported, -} - -impl From for ExtensionValue { - fn from(val: u16) -> Self { - match val { - 0 => ExtensionValue::ServerName, - 10 => ExtensionValue::SupportedEllipticCurves, - 11 => ExtensionValue::SupportedPointFormats, - 13 => ExtensionValue::SupportedSignatureAlgorithms, - 14 => ExtensionValue::UseSrtp, - 23 => ExtensionValue::UseExtendedMasterSecret, - 65281 => ExtensionValue::RenegotiationInfo, - _ => ExtensionValue::Unsupported, - } - } -} - -#[derive(PartialEq, Eq, Debug, Clone)] -pub enum Extension { - ServerName(ExtensionServerName), - SupportedEllipticCurves(ExtensionSupportedEllipticCurves), - SupportedPointFormats(ExtensionSupportedPointFormats), - 
SupportedSignatureAlgorithms(ExtensionSupportedSignatureAlgorithms), - UseSrtp(ExtensionUseSrtp), - UseExtendedMasterSecret(ExtensionUseExtendedMasterSecret), - RenegotiationInfo(ExtensionRenegotiationInfo), -} - -impl Extension { - pub fn extension_value(&self) -> ExtensionValue { - match self { - Extension::ServerName(ext) => ext.extension_value(), - Extension::SupportedEllipticCurves(ext) => ext.extension_value(), - Extension::SupportedPointFormats(ext) => ext.extension_value(), - Extension::SupportedSignatureAlgorithms(ext) => ext.extension_value(), - Extension::UseSrtp(ext) => ext.extension_value(), - Extension::UseExtendedMasterSecret(ext) => ext.extension_value(), - Extension::RenegotiationInfo(ext) => ext.extension_value(), - } - } - - pub fn size(&self) -> usize { - let mut len = 2; - - len += match self { - Extension::ServerName(ext) => ext.size(), - Extension::SupportedEllipticCurves(ext) => ext.size(), - Extension::SupportedPointFormats(ext) => ext.size(), - Extension::SupportedSignatureAlgorithms(ext) => ext.size(), - Extension::UseSrtp(ext) => ext.size(), - Extension::UseExtendedMasterSecret(ext) => ext.size(), - Extension::RenegotiationInfo(ext) => ext.size(), - }; - - len - } - - pub fn marshal(&self, writer: &mut W) -> Result<()> { - writer.write_u16::(self.extension_value() as u16)?; - match self { - Extension::ServerName(ext) => ext.marshal(writer), - Extension::SupportedEllipticCurves(ext) => ext.marshal(writer), - Extension::SupportedPointFormats(ext) => ext.marshal(writer), - Extension::SupportedSignatureAlgorithms(ext) => ext.marshal(writer), - Extension::UseSrtp(ext) => ext.marshal(writer), - Extension::UseExtendedMasterSecret(ext) => ext.marshal(writer), - Extension::RenegotiationInfo(ext) => ext.marshal(writer), - } - } - - pub fn unmarshal(reader: &mut R) -> Result { - let extension_value: ExtensionValue = reader.read_u16::()?.into(); - match extension_value { - ExtensionValue::ServerName => Ok(Extension::ServerName( - ExtensionServerName::unmarshal(reader)?, - )), - ExtensionValue::SupportedEllipticCurves => Ok(Extension::SupportedEllipticCurves( - ExtensionSupportedEllipticCurves::unmarshal(reader)?, - )), - ExtensionValue::SupportedPointFormats => Ok(Extension::SupportedPointFormats( - ExtensionSupportedPointFormats::unmarshal(reader)?, - )), - ExtensionValue::SupportedSignatureAlgorithms => { - Ok(Extension::SupportedSignatureAlgorithms( - ExtensionSupportedSignatureAlgorithms::unmarshal(reader)?, - )) - } - ExtensionValue::UseSrtp => Ok(Extension::UseSrtp(ExtensionUseSrtp::unmarshal(reader)?)), - ExtensionValue::UseExtendedMasterSecret => Ok(Extension::UseExtendedMasterSecret( - ExtensionUseExtendedMasterSecret::unmarshal(reader)?, - )), - ExtensionValue::RenegotiationInfo => Ok(Extension::RenegotiationInfo( - ExtensionRenegotiationInfo::unmarshal(reader)?, - )), - _ => Err(Error::ErrInvalidExtensionType), - } - } -} diff --git a/dtls/src/extension/renegotiation_info.rs b/dtls/src/extension/renegotiation_info.rs deleted file mode 100644 index ab3d29ab8..000000000 --- a/dtls/src/extension/renegotiation_info.rs +++ /dev/null @@ -1,48 +0,0 @@ -#[cfg(test)] -mod renegotiation_info_test; - -use super::*; -use crate::error::Error::ErrInvalidPacketLength; - -const RENEGOTIATION_INFO_HEADER_SIZE: usize = 5; - -/// RenegotiationInfo allows a Client/Server to -/// communicate their renegotiation support -/// https://tools.ietf.org/html/rfc5746 -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct ExtensionRenegotiationInfo { - pub(crate) renegotiated_connection: u8, -} 
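// A minimal standalone sketch (not crate API) of how Extension::marshal above
// frames every extension: a u16 extension type from the IANA registry (e.g.
// 0 = server_name, 14 = use_srtp, 65281 = renegotiation_info) followed by the
// already length-prefixed body produced by the per-extension marshal.
fn frame_extension(extension_type: u16, body: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(2 + body.len());
    out.extend_from_slice(&extension_type.to_be_bytes());
    out.extend_from_slice(body);
    out
}

#[test]
fn sketch_extension_framing() {
    // renegotiation_info: type 0xff01, body = u16 length 1 + empty renegotiated_connection.
    let framed = frame_extension(65281, &[0x00, 0x01, 0x00]);
    assert_eq!(framed, vec![0xff, 0x01, 0x00, 0x01, 0x00]);
}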
- -impl ExtensionRenegotiationInfo { - // TypeValue returns the extension TypeValue - pub fn extension_value(&self) -> ExtensionValue { - ExtensionValue::RenegotiationInfo - } - - pub fn size(&self) -> usize { - 3 - } - - /// marshal encodes the extension - pub fn marshal(&self, writer: &mut W) -> Result<()> { - writer.write_u16::(1)?; //length - writer.write_u8(self.renegotiated_connection)?; - - Ok(writer.flush()?) - } - - /// Unmarshal populates the extension from encoded data - pub fn unmarshal(reader: &mut R) -> Result { - let l = reader.read_u16::()?; //length - if l != 1 { - return Err(ErrInvalidPacketLength); - } - - let renegotiated_connection = reader.read_u8()?; - - Ok(ExtensionRenegotiationInfo { - renegotiated_connection, - }) - } -} diff --git a/dtls/src/extension/renegotiation_info/renegotiation_info_test.rs b/dtls/src/extension/renegotiation_info/renegotiation_info_test.rs deleted file mode 100644 index 5266a94c3..000000000 --- a/dtls/src/extension/renegotiation_info/renegotiation_info_test.rs +++ /dev/null @@ -1,26 +0,0 @@ -use std::io::{BufReader, BufWriter}; - -use super::*; - -#[test] -fn test_renegotiation_info() -> Result<()> { - let extension = ExtensionRenegotiationInfo { - renegotiated_connection: 0, - }; - - let mut raw = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(raw.as_mut()); - extension.marshal(&mut writer)?; - } - - let mut reader = BufReader::new(raw.as_slice()); - let new_extension = ExtensionRenegotiationInfo::unmarshal(&mut reader)?; - - assert_eq!( - new_extension.renegotiated_connection, - extension.renegotiated_connection - ); - - Ok(()) -} diff --git a/dtls/src/flight/flight0.rs b/dtls/src/flight/flight0.rs deleted file mode 100644 index 834d7c098..000000000 --- a/dtls/src/flight/flight0.rs +++ /dev/null @@ -1,203 +0,0 @@ -use std::fmt; -use std::sync::atomic::Ordering; - -use async_trait::async_trait; -use rand::Rng; - -use super::flight2::*; -use super::*; -use crate::config::*; -use crate::conn::*; -use crate::error::Error; -use crate::extension::*; -use crate::handshake::*; -use crate::record_layer::record_layer_header::*; -use crate::*; - -#[derive(Debug, PartialEq)] -pub(crate) struct Flight0; - -impl fmt::Display for Flight0 { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Flight 0") - } -} - -#[async_trait] -impl Flight for Flight0 { - async fn parse( - &self, - _tx: &mut mpsc::Sender>, - state: &mut State, - cache: &HandshakeCache, - cfg: &HandshakeConfig, - ) -> Result, (Option, Option)> { - let (seq, msgs) = match cache - .full_pull_map( - 0, - &[HandshakeCachePullRule { - typ: HandshakeType::ClientHello, - epoch: cfg.initial_epoch, - is_client: true, - optional: false, - }], - ) - .await - { - Ok((seq, msgs)) => (seq, msgs), - Err(_) => return Err((None, None)), - }; - - state.handshake_recv_sequence = seq; - - if let Some(message) = msgs.get(&HandshakeType::ClientHello) { - // Validate type - let client_hello = match message { - HandshakeMessage::ClientHello(client_hello) => client_hello, - _ => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - None, - )) - } - }; - - if client_hello.version != PROTOCOL_VERSION1_2 { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::ProtocolVersion, - }), - Some(Error::ErrUnsupportedProtocolVersion), - )); - } - - state.remote_random = client_hello.random.clone(); - - if let Ok(id) = - 
find_matching_cipher_suite(&client_hello.cipher_suites, &cfg.local_cipher_suites) - { - if let Ok(cipher_suite) = cipher_suite_for_id(id) { - log::debug!( - "[handshake:{}] use cipher suite: {}", - srv_cli_str(state.is_client), - cipher_suite.to_string() - ); - let mut cs = state.cipher_suite.lock().await; - *cs = Some(cipher_suite); - } - } else { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InsufficientSecurity, - }), - Some(Error::ErrCipherSuiteNoIntersection), - )); - } - - for extension in &client_hello.extensions { - match extension { - Extension::SupportedEllipticCurves(e) => { - if e.elliptic_curves.is_empty() { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InsufficientSecurity, - }), - Some(Error::ErrNoSupportedEllipticCurves), - )); - } - state.named_curve = e.elliptic_curves[0]; - } - Extension::UseSrtp(e) => { - if let Ok(profile) = find_matching_srtp_profile( - &e.protection_profiles, - &cfg.local_srtp_protection_profiles, - ) { - state.srtp_protection_profile = profile; - } else { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InsufficientSecurity, - }), - Some(Error::ErrServerNoMatchingSrtpProfile), - )); - } - } - Extension::UseExtendedMasterSecret(_) => { - if cfg.extended_master_secret != ExtendedMasterSecretType::Disable { - state.extended_master_secret = true; - } - } - Extension::ServerName(e) => { - state.server_name.clone_from(&e.server_name); // remote server name - } - _ => {} - } - } - - if cfg.extended_master_secret == ExtendedMasterSecretType::Require - && !state.extended_master_secret - { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InsufficientSecurity, - }), - Some(Error::ErrServerRequiredButNoClientEms), - )); - } - - if state.local_keypair.is_none() { - state.local_keypair = match state.named_curve.generate_keypair() { - Ok(local_keypar) => Some(local_keypar), - Err(err) => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::IllegalParameter, - }), - Some(err), - )) - } - }; - } - - Ok(Box::new(Flight2 {})) - } else { - Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - None, - )) - } - } - - async fn generate( - &self, - state: &mut State, - _cache: &HandshakeCache, - _cfg: &HandshakeConfig, - ) -> Result, (Option, Option)> { - // Initialize - state.cookie = vec![0; COOKIE_LENGTH]; - rand::thread_rng().fill(state.cookie.as_mut_slice()); - - //TODO: figure out difference between golang's atom store and rust atom store - let zero_epoch = 0; - state.local_epoch.store(zero_epoch, Ordering::SeqCst); - state.remote_epoch.store(zero_epoch, Ordering::SeqCst); - - state.named_curve = DEFAULT_NAMED_CURVE; - state.local_random.populate(); - - Ok(vec![]) - } -} diff --git a/dtls/src/flight/flight1.rs b/dtls/src/flight/flight1.rs deleted file mode 100644 index 0cb37e62f..000000000 --- a/dtls/src/flight/flight1.rs +++ /dev/null @@ -1,193 +0,0 @@ -use std::fmt; -use std::sync::atomic::Ordering; - -use async_trait::async_trait; - -use super::flight3::*; -use super::*; -use crate::compression_methods::*; -use crate::config::*; -use crate::conn::*; -use crate::content::*; -use crate::curve::named_curve::*; -use crate::error::Error; -use crate::extension::extension_server_name::*; -use 
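// A minimal standalone sketch (not crate API) of the extended_master_secret
// policy the server applies above: with Require the handshake is aborted
// (ErrServerRequiredButNoClientEms) unless the client offered the extension;
// Request and Disable never abort at this point. The enum mirrors the crate's
// ExtendedMasterSecretType variants but is local to this sketch.
#[derive(Clone, Copy)]
enum ExtendedMasterSecretPolicy {
    Request,
    Require,
    Disable,
}

fn server_accepts_ems(policy: ExtendedMasterSecretPolicy, client_offered_ems: bool) -> bool {
    match policy {
        ExtendedMasterSecretPolicy::Require => client_offered_ems,
        ExtendedMasterSecretPolicy::Request | ExtendedMasterSecretPolicy::Disable => true,
    }
}

#[test]
fn sketch_ems_policy() {
    assert!(!server_accepts_ems(ExtendedMasterSecretPolicy::Require, false));
    assert!(server_accepts_ems(ExtendedMasterSecretPolicy::Require, true));
    assert!(server_accepts_ems(ExtendedMasterSecretPolicy::Request, false));
}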
crate::extension::extension_supported_elliptic_curves::*; -use crate::extension::extension_supported_point_formats::*; -use crate::extension::extension_supported_signature_algorithms::*; -use crate::extension::extension_use_extended_master_secret::*; -use crate::extension::extension_use_srtp::*; -use crate::extension::renegotiation_info::ExtensionRenegotiationInfo; -use crate::extension::*; -use crate::handshake::handshake_message_client_hello::*; -use crate::handshake::*; -use crate::record_layer::record_layer_header::*; -use crate::record_layer::*; - -#[derive(Debug, PartialEq)] -pub(crate) struct Flight1; - -impl fmt::Display for Flight1 { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Flight 1") - } -} - -#[async_trait] -impl Flight for Flight1 { - async fn parse( - &self, - tx: &mut mpsc::Sender>, - state: &mut State, - cache: &HandshakeCache, - cfg: &HandshakeConfig, - ) -> Result, (Option, Option)> { - // HelloVerifyRequest can be skipped by the server, - // so allow ServerHello during flight1 also - let (seq, msgs) = match cache - .full_pull_map( - state.handshake_recv_sequence, - &[ - HandshakeCachePullRule { - typ: HandshakeType::HelloVerifyRequest, - epoch: cfg.initial_epoch, - is_client: false, - optional: true, - }, - HandshakeCachePullRule { - typ: HandshakeType::ServerHello, - epoch: cfg.initial_epoch, - is_client: false, - optional: true, - }, - ], - ) - .await - { - // No valid message received. Keep reading - Ok((seq, msgs)) => (seq, msgs), - Err(_) => return Err((None, None)), - }; - - if msgs.contains_key(&HandshakeType::ServerHello) { - // Flight1 and flight2 were skipped. - // Parse as flight3. - let flight3 = Flight3 {}; - return flight3.parse(tx, state, cache, cfg).await; - } - - if let Some(message) = msgs.get(&HandshakeType::HelloVerifyRequest) { - // DTLS 1.2 clients must not assume that the server will use the protocol version - // specified in HelloVerifyRequest message. 
RFC 6347 Section 4.2.1 - let h = match message { - HandshakeMessage::HelloVerifyRequest(h) => h, - _ => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - None, - )) - } - }; - - if h.version != PROTOCOL_VERSION1_0 && h.version != PROTOCOL_VERSION1_2 { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::ProtocolVersion, - }), - Some(Error::ErrUnsupportedProtocolVersion), - )); - } - - state.cookie.clone_from(&h.cookie); - state.handshake_recv_sequence = seq; - Ok(Box::new(Flight3 {})) - } else { - Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - None, - )) - } - } - - async fn generate( - &self, - state: &mut State, - _cache: &HandshakeCache, - cfg: &HandshakeConfig, - ) -> Result, (Option, Option)> { - let zero_epoch = 0; - state.local_epoch.store(zero_epoch, Ordering::SeqCst); - state.remote_epoch.store(zero_epoch, Ordering::SeqCst); - - state.named_curve = DEFAULT_NAMED_CURVE; - state.cookie = vec![]; - state.local_random.populate(); - - let mut extensions = vec![ - Extension::SupportedSignatureAlgorithms(ExtensionSupportedSignatureAlgorithms { - signature_hash_algorithms: cfg.local_signature_schemes.clone(), - }), - Extension::RenegotiationInfo(ExtensionRenegotiationInfo { - renegotiated_connection: 0, - }), - ]; - - if cfg.local_psk_callback.is_none() { - extensions.extend_from_slice(&[ - Extension::SupportedEllipticCurves(ExtensionSupportedEllipticCurves { - elliptic_curves: vec![NamedCurve::P256, NamedCurve::X25519, NamedCurve::P384], - }), - Extension::SupportedPointFormats(ExtensionSupportedPointFormats { - point_formats: vec![ELLIPTIC_CURVE_POINT_FORMAT_UNCOMPRESSED], - }), - ]); - } - - if !cfg.local_srtp_protection_profiles.is_empty() { - extensions.push(Extension::UseSrtp(ExtensionUseSrtp { - protection_profiles: cfg.local_srtp_protection_profiles.clone(), - })); - } - - if cfg.extended_master_secret == ExtendedMasterSecretType::Request - || cfg.extended_master_secret == ExtendedMasterSecretType::Require - { - extensions.push(Extension::UseExtendedMasterSecret( - ExtensionUseExtendedMasterSecret { supported: true }, - )); - } - - if !cfg.server_name.is_empty() { - extensions.push(Extension::ServerName(ExtensionServerName { - server_name: cfg.server_name.clone(), - })); - } - - Ok(vec![Packet { - record: RecordLayer::new( - PROTOCOL_VERSION1_2, - 0, - Content::Handshake(Handshake::new(HandshakeMessage::ClientHello( - HandshakeMessageClientHello { - version: PROTOCOL_VERSION1_2, - random: state.local_random.clone(), - cookie: state.cookie.clone(), - - cipher_suites: cfg.local_cipher_suites.clone(), - compression_methods: default_compression_methods(), - extensions, - }, - ))), - ), - should_encrypt: false, - reset_local_sequence_number: false, - }]) - } -} diff --git a/dtls/src/flight/flight2.rs b/dtls/src/flight/flight2.rs deleted file mode 100644 index 9703c9aaa..000000000 --- a/dtls/src/flight/flight2.rs +++ /dev/null @@ -1,131 +0,0 @@ -use std::fmt; - -use async_trait::async_trait; - -use super::flight0::*; -use super::flight4::*; -use super::*; -use crate::content::*; -use crate::error::Error; -use crate::handshake::handshake_message_hello_verify_request::*; -use crate::handshake::*; -use crate::record_layer::record_layer_header::*; - -#[derive(Debug, PartialEq)] -pub(crate) struct Flight2; - -impl fmt::Display for Flight2 { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> 
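// A minimal standalone sketch (not crate API) of the HelloVerifyRequest handling
// above (RFC 6347 Section 4.2.1): the client accepts either DTLS 1.0 (254.255)
// or DTLS 1.2 (254.253) in the HelloVerifyRequest, adopts the server's cookie,
// and then resends its ClientHello; any other version yields a ProtocolVersion
// alert. The version tuples below are assumed wire encodings, not crate types.
const DTLS_1_0: (u8, u8) = (254, 255);
const DTLS_1_2: (u8, u8) = (254, 253);

fn adopt_hello_verify_cookie(
    version: (u8, u8),
    cookie: &[u8],
    state_cookie: &mut Vec<u8>,
) -> Result<(), &'static str> {
    if version != DTLS_1_0 && version != DTLS_1_2 {
        return Err("unsupported protocol version");
    }
    state_cookie.clear();
    state_cookie.extend_from_slice(cookie);
    Ok(())
}

#[test]
fn sketch_hello_verify_request() {
    let mut cookie = Vec::new();
    assert!(adopt_hello_verify_cookie(DTLS_1_2, b"srv-cookie", &mut cookie).is_ok());
    assert_eq!(cookie, b"srv-cookie");
    assert!(adopt_hello_verify_cookie((3, 3), b"x", &mut cookie).is_err());
}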
fmt::Result { - write!(f, "Flight 2") - } -} - -#[async_trait] -impl Flight for Flight2 { - fn has_retransmit(&self) -> bool { - false - } - - async fn parse( - &self, - tx: &mut mpsc::Sender>, - state: &mut State, - cache: &HandshakeCache, - cfg: &HandshakeConfig, - ) -> Result, (Option, Option)> { - let (seq, msgs) = match cache - .full_pull_map( - state.handshake_recv_sequence, - &[HandshakeCachePullRule { - typ: HandshakeType::ClientHello, - epoch: cfg.initial_epoch, - is_client: true, - optional: false, - }], - ) - .await - { - // No valid message received. Keep reading - Ok((seq, msgs)) => (seq, msgs), - - // Client may retransmit the first ClientHello when HelloVerifyRequest is dropped. - // Parse as flight 0 in this case. - Err(_) => return Flight0 {}.parse(tx, state, cache, cfg).await, - }; - - state.handshake_recv_sequence = seq; - - if let Some(message) = msgs.get(&HandshakeType::ClientHello) { - // Validate type - let client_hello = match message { - HandshakeMessage::ClientHello(client_hello) => client_hello, - _ => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - None, - )) - } - }; - - if client_hello.version != PROTOCOL_VERSION1_2 { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::ProtocolVersion, - }), - Some(Error::ErrUnsupportedProtocolVersion), - )); - } - - if client_hello.cookie.is_empty() { - return Err((None, None)); - } - - if state.cookie != client_hello.cookie { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::AccessDenied, - }), - Some(Error::ErrCookieMismatch), - )); - } - - Ok(Box::new(Flight4 {})) - } else { - Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - None, - )) - } - } - - async fn generate( - &self, - state: &mut State, - _cache: &HandshakeCache, - _cfg: &HandshakeConfig, - ) -> Result, (Option, Option)> { - state.handshake_send_sequence = 0; - Ok(vec![Packet { - record: RecordLayer::new( - PROTOCOL_VERSION1_2, - 0, - Content::Handshake(Handshake::new(HandshakeMessage::HelloVerifyRequest( - HandshakeMessageHelloVerifyRequest { - version: PROTOCOL_VERSION1_2, - cookie: state.cookie.clone(), - }, - ))), - ), - should_encrypt: false, - reset_local_sequence_number: false, - }]) - } -} diff --git a/dtls/src/flight/flight3.rs b/dtls/src/flight/flight3.rs deleted file mode 100644 index d84e3bdd0..000000000 --- a/dtls/src/flight/flight3.rs +++ /dev/null @@ -1,470 +0,0 @@ -use std::fmt; - -use async_trait::async_trait; -use log::*; - -use super::flight5::*; -use super::*; -use crate::cipher_suite::cipher_suite_for_id; -use crate::compression_methods::*; -use crate::config::*; -use crate::content::*; -use crate::curve::named_curve::*; -use crate::error::Error; -use crate::extension::extension_server_name::*; -use crate::extension::extension_supported_elliptic_curves::*; -use crate::extension::extension_supported_point_formats::*; -use crate::extension::extension_supported_signature_algorithms::*; -use crate::extension::extension_use_extended_master_secret::*; -use crate::extension::extension_use_srtp::*; -use crate::extension::renegotiation_info::ExtensionRenegotiationInfo; -use crate::extension::*; -use crate::handshake::handshake_message_client_hello::*; -use crate::handshake::handshake_message_server_key_exchange::*; -use crate::handshake::*; -use crate::prf::{prf_pre_master_secret, 
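// A minimal standalone sketch (not crate API) of the stateless-cookie check
// Flight2 performs above: an empty cookie means the ClientHello predates the
// HelloVerifyRequest and is simply ignored, while a non-empty mismatch is a
// hard AccessDenied/ErrCookieMismatch failure. The enum is illustrative.
#[derive(Debug, PartialEq)]
enum CookieCheck {
    KeepWaiting,
    Mismatch,
    Ok,
}

fn check_cookie(expected: &[u8], received: &[u8]) -> CookieCheck {
    if received.is_empty() {
        CookieCheck::KeepWaiting
    } else if expected != received {
        CookieCheck::Mismatch
    } else {
        CookieCheck::Ok
    }
}

#[test]
fn sketch_cookie_check() {
    assert_eq!(check_cookie(b"abc", b""), CookieCheck::KeepWaiting);
    assert_eq!(check_cookie(b"abc", b"abd"), CookieCheck::Mismatch);
    assert_eq!(check_cookie(b"abc", b"abc"), CookieCheck::Ok);
}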
prf_psk_pre_master_secret}; -use crate::record_layer::record_layer_header::*; -use crate::record_layer::*; -use crate::{find_matching_cipher_suite, find_matching_srtp_profile}; - -#[derive(Debug, PartialEq)] -pub(crate) struct Flight3; - -impl fmt::Display for Flight3 { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Flight 3") - } -} - -#[async_trait] -impl Flight for Flight3 { - async fn parse( - &self, - _tx: &mut mpsc::Sender>, - state: &mut State, - cache: &HandshakeCache, - cfg: &HandshakeConfig, - ) -> Result, (Option, Option)> { - // Clients may receive multiple HelloVerifyRequest messages with different cookies. - // Clients SHOULD handle this by sending a new ClientHello with a cookie in response - // to the new HelloVerifyRequest. RFC 6347 Section 4.2.1 - if let Ok((seq, msgs)) = cache - .full_pull_map( - state.handshake_recv_sequence, - &[HandshakeCachePullRule { - typ: HandshakeType::HelloVerifyRequest, - epoch: cfg.initial_epoch, - is_client: false, - optional: true, - }], - ) - .await - { - if let Some(message) = msgs.get(&HandshakeType::HelloVerifyRequest) { - // DTLS 1.2 clients must not assume that the server will use the protocol version - // specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1 - let h = match message { - HandshakeMessage::HelloVerifyRequest(h) => h, - _ => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - None, - )) - } - }; - - // DTLS 1.2 clients must not assume that the server will use the protocol version - // specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1 - if h.version != PROTOCOL_VERSION1_0 && h.version != PROTOCOL_VERSION1_2 { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::ProtocolVersion, - }), - Some(Error::ErrUnsupportedProtocolVersion), - )); - } - - state.cookie.clone_from(&h.cookie); - state.handshake_recv_sequence = seq; - return Ok(Box::new(Flight3 {}) as Box); - } - } - - let result = if cfg.local_psk_callback.is_some() { - cache - .full_pull_map( - state.handshake_recv_sequence, - &[ - HandshakeCachePullRule { - typ: HandshakeType::ServerHello, - epoch: cfg.initial_epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::ServerKeyExchange, - epoch: cfg.initial_epoch, - is_client: false, - optional: true, - }, - HandshakeCachePullRule { - typ: HandshakeType::ServerHelloDone, - epoch: cfg.initial_epoch, - is_client: false, - optional: false, - }, - ], - ) - .await - } else { - cache - .full_pull_map( - state.handshake_recv_sequence, - &[ - HandshakeCachePullRule { - typ: HandshakeType::ServerHello, - epoch: cfg.initial_epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::Certificate, - epoch: cfg.initial_epoch, - is_client: false, - optional: true, - }, - HandshakeCachePullRule { - typ: HandshakeType::ServerKeyExchange, - epoch: cfg.initial_epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::CertificateRequest, - epoch: cfg.initial_epoch, - is_client: false, - optional: true, - }, - HandshakeCachePullRule { - typ: HandshakeType::ServerHelloDone, - epoch: cfg.initial_epoch, - is_client: false, - optional: false, - }, - ], - ) - .await - }; - - let (seq, msgs) = match result { - Ok((seq, msgs)) => (seq, msgs), - Err(_) => return Err((None, None)), - }; - - state.handshake_recv_sequence = seq; - - if let Some(message) = 
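// A minimal standalone sketch (not crate API) of the cache pull rules Flight3
// builds above: which server-flight messages the client waits for depends on
// whether a PSK callback is configured. The (name, optional) pairs mirror the
// HandshakeCachePullRule entries; with PSK the ServerKeyExchange is only
// optional because it merely carries the identity hint.
fn expected_server_flight(psk: bool) -> Vec<(&'static str, bool)> {
    if psk {
        vec![
            ("ServerHello", false),
            ("ServerKeyExchange", true),
            ("ServerHelloDone", false),
        ]
    } else {
        vec![
            ("ServerHello", false),
            ("Certificate", true),
            ("ServerKeyExchange", false),
            ("CertificateRequest", true),
            ("ServerHelloDone", false),
        ]
    }
}

#[test]
fn sketch_expected_server_flight() {
    assert_eq!(expected_server_flight(true).len(), 3);
    assert_eq!(expected_server_flight(false).len(), 5);
}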
msgs.get(&HandshakeType::ServerHello) { - let h = match message { - HandshakeMessage::ServerHello(h) => h, - _ => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - None, - )) - } - }; - - if h.version != PROTOCOL_VERSION1_2 { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::ProtocolVersion, - }), - Some(Error::ErrUnsupportedProtocolVersion), - )); - } - - for extension in &h.extensions { - match extension { - Extension::UseSrtp(e) => { - let profile = match find_matching_srtp_profile( - &e.protection_profiles, - &cfg.local_srtp_protection_profiles, - ) { - Ok(profile) => profile, - Err(_) => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::IllegalParameter, - }), - Some(Error::ErrClientNoMatchingSrtpProfile), - )) - } - }; - state.srtp_protection_profile = profile; - } - Extension::UseExtendedMasterSecret(_) => { - if cfg.extended_master_secret != ExtendedMasterSecretType::Disable { - state.extended_master_secret = true; - } - } - _ => {} - }; - } - - if cfg.extended_master_secret == ExtendedMasterSecretType::Require - && !state.extended_master_secret - { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InsufficientSecurity, - }), - Some(Error::ErrClientRequiredButNoServerEms), - )); - } - if !cfg.local_srtp_protection_profiles.is_empty() - && state.srtp_protection_profile == SrtpProtectionProfile::Unsupported - { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InsufficientSecurity, - }), - Some(Error::ErrRequestedButNoSrtpExtension), - )); - } - if find_matching_cipher_suite(&[h.cipher_suite], &cfg.local_cipher_suites).is_err() { - debug!( - "[handshake:{}] use cipher suite: {}", - srv_cli_str(state.is_client), - h.cipher_suite - ); - - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InsufficientSecurity, - }), - Some(Error::ErrCipherSuiteNoIntersection), - )); - } - - let cipher_suite = match cipher_suite_for_id(h.cipher_suite) { - Ok(cipher_suite) => cipher_suite, - Err(_) => { - debug!( - "[handshake:{}] use cipher suite: {}", - srv_cli_str(state.is_client), - h.cipher_suite - ); - - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InsufficientSecurity, - }), - Some(Error::ErrInvalidCipherSuite), - )); - } - }; - - trace!( - "[handshake:{}] use cipher suite: {}", - srv_cli_str(state.is_client), - cipher_suite.to_string() - ); - { - let mut cs = state.cipher_suite.lock().await; - *cs = Some(cipher_suite); - } - state.remote_random = h.random.clone(); - } - - if let Some(message) = msgs.get(&HandshakeType::Certificate) { - let h = match message { - HandshakeMessage::Certificate(h) => h, - _ => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - None, - )) - } - }; - state.peer_certificates.clone_from(&h.certificate); - } - - if let Some(message) = msgs.get(&HandshakeType::ServerKeyExchange) { - let h = match message { - HandshakeMessage::ServerKeyExchange(h) => h, - _ => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - None, - )) - } - }; - - if let Err((alert, err)) = handle_server_key_exchange(state, cfg, h) { - return Err((alert, err)); - 
} - } - - if let Some(message) = msgs.get(&HandshakeType::CertificateRequest) { - match message { - HandshakeMessage::CertificateRequest(_) => {} - _ => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - None, - )) - } - }; - state.remote_requested_certificate = true; - } - - Ok(Box::new(Flight5 {}) as Box) - } - - async fn generate( - &self, - state: &mut State, - _cache: &HandshakeCache, - cfg: &HandshakeConfig, - ) -> Result, (Option, Option)> { - let mut extensions = vec![ - Extension::SupportedSignatureAlgorithms(ExtensionSupportedSignatureAlgorithms { - signature_hash_algorithms: cfg.local_signature_schemes.clone(), - }), - Extension::RenegotiationInfo(ExtensionRenegotiationInfo { - renegotiated_connection: 0, - }), - ]; - - if cfg.local_psk_callback.is_none() { - extensions.extend_from_slice(&[ - Extension::SupportedEllipticCurves(ExtensionSupportedEllipticCurves { - elliptic_curves: vec![NamedCurve::P256, NamedCurve::X25519, NamedCurve::P384], - }), - Extension::SupportedPointFormats(ExtensionSupportedPointFormats { - point_formats: vec![ELLIPTIC_CURVE_POINT_FORMAT_UNCOMPRESSED], - }), - ]); - } - - if !cfg.local_srtp_protection_profiles.is_empty() { - extensions.push(Extension::UseSrtp(ExtensionUseSrtp { - protection_profiles: cfg.local_srtp_protection_profiles.clone(), - })); - } - - if cfg.extended_master_secret == ExtendedMasterSecretType::Request - || cfg.extended_master_secret == ExtendedMasterSecretType::Require - { - extensions.push(Extension::UseExtendedMasterSecret( - ExtensionUseExtendedMasterSecret { supported: true }, - )); - } - - if !cfg.server_name.is_empty() { - extensions.push(Extension::ServerName(ExtensionServerName { - server_name: cfg.server_name.clone(), - })); - } - - Ok(vec![Packet { - record: RecordLayer::new( - PROTOCOL_VERSION1_2, - 0, - Content::Handshake(Handshake::new(HandshakeMessage::ClientHello( - HandshakeMessageClientHello { - version: PROTOCOL_VERSION1_2, - random: state.local_random.clone(), - cookie: state.cookie.clone(), - - cipher_suites: cfg.local_cipher_suites.clone(), - compression_methods: default_compression_methods(), - extensions, - }, - ))), - ), - should_encrypt: false, - reset_local_sequence_number: false, - }]) - } -} - -pub(crate) fn handle_server_key_exchange( - state: &mut State, - cfg: &HandshakeConfig, - h: &HandshakeMessageServerKeyExchange, -) -> Result<(), (Option, Option)> { - if let Some(local_psk_callback) = &cfg.local_psk_callback { - let psk = match local_psk_callback(&h.identity_hint) { - Ok(psk) => psk, - Err(err) => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - Some(err), - )) - } - }; - - state.identity_hint.clone_from(&h.identity_hint); - state.pre_master_secret = prf_psk_pre_master_secret(&psk); - } else { - let local_keypair = match h.named_curve.generate_keypair() { - Ok(local_keypair) => local_keypair, - Err(err) => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - Some(err), - )) - } - }; - - state.pre_master_secret = match prf_pre_master_secret( - &h.public_key, - &local_keypair.private_key, - local_keypair.curve, - ) { - Ok(pre_master_secret) => pre_master_secret, - Err(err) => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - Some(err), - )) - } - }; - - state.local_keypair = Some(local_keypair); 
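// A minimal standalone sketch (not crate API) of the plain-PSK pre_master_secret
// layout from RFC 4279 Section 2, which prf_psk_pre_master_secret above is
// assumed to produce: two length-prefixed halves, N zero octets followed by the
// PSK itself. The helper below is illustrative, not the crate's implementation.
fn psk_pre_master_secret(psk: &[u8]) -> Vec<u8> {
    let n = psk.len() as u16;
    let mut out = Vec::with_capacity(4 + 2 * psk.len());
    out.extend_from_slice(&n.to_be_bytes()); // other_secret length
    out.extend(std::iter::repeat(0u8).take(psk.len())); // other_secret = N zero octets
    out.extend_from_slice(&n.to_be_bytes()); // psk length
    out.extend_from_slice(psk);
    out
}

#[test]
fn sketch_psk_pre_master_secret() {
    assert_eq!(
        psk_pre_master_secret(&[0xAB, 0xCD]),
        vec![0x00, 0x02, 0x00, 0x00, 0x00, 0x02, 0xAB, 0xCD]
    );
}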
- } - - Ok(()) -} diff --git a/dtls/src/flight/flight4.rs b/dtls/src/flight/flight4.rs deleted file mode 100644 index 11a77ffe1..000000000 --- a/dtls/src/flight/flight4.rs +++ /dev/null @@ -1,852 +0,0 @@ -use std::fmt; -use std::io::BufWriter; - -use async_trait::async_trait; -use log::*; - -use super::flight6::*; -use super::*; -use crate::cipher_suite::*; -use crate::client_certificate_type::*; -use crate::compression_methods::*; -use crate::config::*; -use crate::content::*; -use crate::crypto::*; -use crate::curve::named_curve::*; -use crate::curve::*; -use crate::error::Error; -use crate::extension::extension_supported_elliptic_curves::*; -use crate::extension::extension_supported_point_formats::*; -use crate::extension::extension_use_extended_master_secret::*; -use crate::extension::extension_use_srtp::*; -use crate::extension::renegotiation_info::ExtensionRenegotiationInfo; -use crate::extension::*; -use crate::handshake::handshake_message_certificate::*; -use crate::handshake::handshake_message_certificate_request::*; -use crate::handshake::handshake_message_server_hello::*; -use crate::handshake::handshake_message_server_hello_done::*; -use crate::handshake::handshake_message_server_key_exchange::*; -use crate::handshake::*; -use crate::prf::*; -use crate::record_layer::record_layer_header::*; -use crate::record_layer::*; -use crate::signature_hash_algorithm::*; - -#[derive(Debug, PartialEq)] -pub(crate) struct Flight4; - -impl fmt::Display for Flight4 { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Flight 4") - } -} - -#[async_trait] -impl Flight for Flight4 { - async fn parse( - &self, - tx: &mut mpsc::Sender>, - state: &mut State, - cache: &HandshakeCache, - cfg: &HandshakeConfig, - ) -> Result, (Option, Option)> { - let (seq, msgs) = match cache - .full_pull_map( - state.handshake_recv_sequence, - &[ - HandshakeCachePullRule { - typ: HandshakeType::Certificate, - epoch: cfg.initial_epoch, - is_client: true, - optional: true, - }, - HandshakeCachePullRule { - typ: HandshakeType::ClientKeyExchange, - epoch: cfg.initial_epoch, - is_client: true, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::CertificateVerify, - epoch: cfg.initial_epoch, - is_client: true, - optional: true, - }, - ], - ) - .await - { - Ok((seq, msgs)) => (seq, msgs), - Err(_) => return Err((None, None)), - }; - - let client_key_exchange = if let Some(HandshakeMessage::ClientKeyExchange(h)) = - msgs.get(&HandshakeType::ClientKeyExchange) - { - h - } else { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - None, - )); - }; - - if let Some(message) = msgs.get(&HandshakeType::Certificate) { - let h = match message { - HandshakeMessage::Certificate(h) => h, - _ => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - None, - )) - } - }; - - state.peer_certificates.clone_from(&h.certificate); - trace!( - "[handshake] PeerCertificates4 {}", - state.peer_certificates.len() - ); - } - - if let Some(message) = msgs.get(&HandshakeType::CertificateVerify) { - let h = match message { - HandshakeMessage::CertificateVerify(h) => h, - _ => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - None, - )) - } - }; - - if state.peer_certificates.is_empty() { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: 
AlertDescription::NoCertificate, - }), - Some(Error::ErrCertificateVerifyNoCertificate), - )); - } - - let plain_text = cache - .pull_and_merge(&[ - HandshakeCachePullRule { - typ: HandshakeType::ClientHello, - epoch: cfg.initial_epoch, - is_client: true, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::ServerHello, - epoch: cfg.initial_epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::Certificate, - epoch: cfg.initial_epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::ServerKeyExchange, - epoch: cfg.initial_epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::CertificateRequest, - epoch: cfg.initial_epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::ServerHelloDone, - epoch: cfg.initial_epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::Certificate, - epoch: cfg.initial_epoch, - is_client: true, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::ClientKeyExchange, - epoch: cfg.initial_epoch, - is_client: true, - optional: false, - }, - ]) - .await; - - // Verify that the pair of hash algorithm and signature is listed. - let mut valid_signature_scheme = false; - for ss in &cfg.local_signature_schemes { - if ss.hash == h.algorithm.hash && ss.signature == h.algorithm.signature { - valid_signature_scheme = true; - break; - } - } - if !valid_signature_scheme { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InsufficientSecurity, - }), - Some(Error::ErrNoAvailableSignatureSchemes), - )); - } - - if let Err(err) = verify_certificate_verify( - &plain_text, - &h.algorithm, - &h.signature, - &state.peer_certificates, - cfg.insecure_verification, - ) { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::BadCertificate, - }), - Some(err), - )); - } - - let mut chains = vec![]; - let mut verified = false; - if cfg.client_auth as u8 >= ClientAuthType::VerifyClientCertIfGiven as u8 { - if let Some(client_cert_verifier) = &cfg.client_cert_verifier { - chains = - match verify_client_cert(&state.peer_certificates, client_cert_verifier) { - Ok(chains) => chains, - Err(err) => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::BadCertificate, - }), - Some(err), - )) - } - }; - } else { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::BadCertificate, - }), - Some(Error::ErrInvalidCertificate), - )); - } - - verified = true - } - if let Some(verify_peer_certificate) = &cfg.verify_peer_certificate { - if let Err(err) = verify_peer_certificate(&state.peer_certificates, &chains) { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::BadCertificate, - }), - Some(err), - )); - } - } - state.peer_certificates_verified = verified - } else if !state.peer_certificates.is_empty() { - // A certificate was received, but we haven't seen a CertificateVerify - // keep reading until we receive one - return Err((None, None)); - } - - { - let mut cipher_suite = state.cipher_suite.lock().await; - if let Some(cipher_suite) = &mut *cipher_suite { - if !cipher_suite.is_initialized() { - let mut server_random = vec![]; - { - let mut writer = BufWriter::<&mut 
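// A minimal standalone sketch (not crate API) of the CertificateVerify
// allow-list check above: the (hash, signature) pair the client used must
// appear in the locally configured signature schemes, otherwise the handshake
// fails with InsufficientSecurity/ErrNoAvailableSignatureSchemes. Plain u8
// pairs stand in for the crate's SignatureHashAlgorithm type.
fn signature_scheme_allowed(local: &[(u8, u8)], offered: (u8, u8)) -> bool {
    local.iter().any(|&scheme| scheme == offered)
}

#[test]
fn sketch_signature_scheme_allowed() {
    let local = [(0x04, 0x03), (0x05, 0x03), (0x06, 0x03)]; // SHA-256/384/512 with ECDSA
    assert!(signature_scheme_allowed(&local, (0x04, 0x03)));
    assert!(!signature_scheme_allowed(&local, (0x02, 0x01))); // e.g. SHA-1 with RSA
}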
Vec>::new(server_random.as_mut()); - let _ = state.local_random.marshal(&mut writer); - } - let mut client_random = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(client_random.as_mut()); - let _ = state.remote_random.marshal(&mut writer); - } - - let mut pre_master_secret = vec![]; - if let Some(local_psk_callback) = &cfg.local_psk_callback { - let psk = match local_psk_callback(&client_key_exchange.identity_hint) { - Ok(psk) => psk, - Err(err) => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - Some(err), - )) - } - }; - - state - .identity_hint - .clone_from(&client_key_exchange.identity_hint); - pre_master_secret = prf_psk_pre_master_secret(&psk); - } else if let Some(local_keypair) = &state.local_keypair { - pre_master_secret = match prf_pre_master_secret( - &client_key_exchange.public_key, - &local_keypair.private_key, - local_keypair.curve, - ) { - Ok(pre_master_secret) => pre_master_secret, - Err(err) => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::IllegalParameter, - }), - Some(err), - )) - } - }; - } - - if state.extended_master_secret { - let hf = cipher_suite.hash_func(); - let session_hash = - match cache.session_hash(hf, cfg.initial_epoch, &[]).await { - Ok(s) => s, - Err(err) => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - Some(err), - )) - } - }; - - state.master_secret = match prf_extended_master_secret( - &pre_master_secret, - &session_hash, - cipher_suite.hash_func(), - ) { - Ok(ms) => ms, - Err(err) => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - Some(err), - )) - } - }; - } else { - state.master_secret = match prf_master_secret( - &pre_master_secret, - &client_random, - &server_random, - cipher_suite.hash_func(), - ) { - Ok(ms) => ms, - Err(err) => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - Some(err), - )) - } - }; - } - - if let Err(err) = cipher_suite.init( - &state.master_secret, - &client_random, - &server_random, - false, - ) { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - Some(err), - )); - } - } - } - } - - // Now, encrypted packets can be handled - let (done_tx, mut done_rx) = mpsc::channel(1); - if let Err(err) = tx.send(done_tx).await { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - Some(Error::Other(err.to_string())), - )); - } - - done_rx.recv().await; - - let (seq, msgs) = match cache - .full_pull_map( - seq, - &[HandshakeCachePullRule { - typ: HandshakeType::Finished, - epoch: cfg.initial_epoch + 1, - is_client: true, - optional: false, - }], - ) - .await - { - Ok((seq, msgs)) => (seq, msgs), - // No valid message received. 
Keep reading - Err(_) => return Err((None, None)), - }; - - state.handshake_recv_sequence = seq; - - if let Some(HandshakeMessage::Finished(h)) = msgs.get(&HandshakeType::Finished) { - h - } else { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - None, - )); - }; - - match cfg.client_auth { - ClientAuthType::RequireAnyClientCert => { - trace!( - "{} peer_certificates.len() {}", - srv_cli_str(state.is_client), - state.peer_certificates.len(), - ); - if state.peer_certificates.is_empty() { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::NoCertificate, - }), - Some(Error::ErrClientCertificateRequired), - )); - } - } - ClientAuthType::VerifyClientCertIfGiven => { - if !state.peer_certificates.is_empty() && !state.peer_certificates_verified { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::BadCertificate, - }), - Some(Error::ErrClientCertificateNotVerified), - )); - } - } - ClientAuthType::RequireAndVerifyClientCert => { - if state.peer_certificates.is_empty() { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::NoCertificate, - }), - Some(Error::ErrClientCertificateRequired), - )); - } - if !state.peer_certificates_verified { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::BadCertificate, - }), - Some(Error::ErrClientCertificateNotVerified), - )); - } - } - ClientAuthType::NoClientCert | ClientAuthType::RequestClientCert => { - return Ok(Box::new(Flight6 {}) as Box); - } - } - - Ok(Box::new(Flight6 {}) as Box) - } - - async fn generate( - &self, - state: &mut State, - _cache: &HandshakeCache, - cfg: &HandshakeConfig, - ) -> Result, (Option, Option)> { - let mut extensions = vec![Extension::RenegotiationInfo(ExtensionRenegotiationInfo { - renegotiated_connection: 0, - })]; - if (cfg.extended_master_secret == ExtendedMasterSecretType::Request - || cfg.extended_master_secret == ExtendedMasterSecretType::Require) - && state.extended_master_secret - { - extensions.push(Extension::UseExtendedMasterSecret( - ExtensionUseExtendedMasterSecret { supported: true }, - )); - } - - if state.srtp_protection_profile != SrtpProtectionProfile::Unsupported { - extensions.push(Extension::UseSrtp(ExtensionUseSrtp { - protection_profiles: vec![state.srtp_protection_profile], - })); - } - - if cfg.local_psk_callback.is_none() { - extensions.extend_from_slice(&[ - Extension::SupportedEllipticCurves(ExtensionSupportedEllipticCurves { - elliptic_curves: vec![NamedCurve::P256, NamedCurve::X25519, NamedCurve::P384], - }), - Extension::SupportedPointFormats(ExtensionSupportedPointFormats { - point_formats: vec![ELLIPTIC_CURVE_POINT_FORMAT_UNCOMPRESSED], - }), - ]); - } - - let mut pkts = vec![Packet { - record: RecordLayer::new( - PROTOCOL_VERSION1_2, - 0, - Content::Handshake(Handshake::new(HandshakeMessage::ServerHello( - HandshakeMessageServerHello { - version: PROTOCOL_VERSION1_2, - random: state.local_random.clone(), - cipher_suite: { - let cipher_suite = state.cipher_suite.lock().await; - if let Some(cipher_suite) = &*cipher_suite { - cipher_suite.id() - } else { - CipherSuiteId::Unsupported - } - }, - compression_method: default_compression_methods().ids[0], - extensions, - }, - ))), - ), - should_encrypt: false, - reset_local_sequence_number: false, - }]; - - if cfg.local_psk_callback.is_none() { - let certificate = match 
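// A minimal standalone sketch (not crate API) of the client_auth policy matrix
// applied above once the client's Finished has arrived: the Require* variants
// demand a certificate, and the *Verify* variants additionally demand that
// verification succeeded. The enum mirrors the ClientAuthType variants used in
// this crate; the error strings are illustrative.
#[derive(Clone, Copy)]
enum ClientAuthType {
    NoClientCert,
    RequestClientCert,
    RequireAnyClientCert,
    VerifyClientCertIfGiven,
    RequireAndVerifyClientCert,
}

fn check_client_auth(policy: ClientAuthType, has_cert: bool, verified: bool) -> Result<(), &'static str> {
    match policy {
        ClientAuthType::RequireAnyClientCert if !has_cert => Err("client certificate required"),
        ClientAuthType::VerifyClientCertIfGiven if has_cert && !verified => {
            Err("client certificate not verified")
        }
        ClientAuthType::RequireAndVerifyClientCert if !has_cert => Err("client certificate required"),
        ClientAuthType::RequireAndVerifyClientCert if !verified => {
            Err("client certificate not verified")
        }
        _ => Ok(()),
    }
}

#[test]
fn sketch_client_auth_matrix() {
    assert!(check_client_auth(ClientAuthType::RequireAnyClientCert, false, false).is_err());
    assert!(check_client_auth(ClientAuthType::VerifyClientCertIfGiven, true, false).is_err());
    assert!(check_client_auth(ClientAuthType::RequireAndVerifyClientCert, true, true).is_ok());
    assert!(check_client_auth(ClientAuthType::RequestClientCert, false, false).is_ok());
    assert!(check_client_auth(ClientAuthType::NoClientCert, false, false).is_ok());
}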
cfg.get_certificate(&cfg.server_name) { - Ok(cert) => cert, - Err(err) => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::HandshakeFailure, - }), - Some(err), - )) - } - }; - - pkts.push(Packet { - record: RecordLayer::new( - PROTOCOL_VERSION1_2, - 0, - Content::Handshake(Handshake::new(HandshakeMessage::Certificate( - HandshakeMessageCertificate { - certificate: certificate - .certificate - .iter() - .map(|x| x.as_ref().to_owned()) - .collect(), - }, - ))), - ), - should_encrypt: false, - reset_local_sequence_number: false, - }); - - let mut server_random = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(server_random.as_mut()); - let _ = state.local_random.marshal(&mut writer); - } - let mut client_random = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(client_random.as_mut()); - let _ = state.remote_random.marshal(&mut writer); - } - - // Find compatible signature scheme - let signature_hash_algo = match select_signature_scheme( - &cfg.local_signature_schemes, - &certificate.private_key, - ) { - Ok(s) => s, - Err(err) => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InsufficientSecurity, - }), - Some(err), - )) - } - }; - - if let Some(local_keypair) = &state.local_keypair { - let signature = match generate_key_signature( - &client_random, - &server_random, - &local_keypair.public_key, - state.named_curve, - &certificate.private_key, /*, signature_hash_algo.hash*/ - ) { - Ok(s) => s, - Err(err) => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - Some(err), - )) - } - }; - - state.local_key_signature = signature; - - pkts.push(Packet { - record: RecordLayer::new( - PROTOCOL_VERSION1_2, - 0, - Content::Handshake(Handshake::new(HandshakeMessage::ServerKeyExchange( - HandshakeMessageServerKeyExchange { - identity_hint: vec![], - elliptic_curve_type: EllipticCurveType::NamedCurve, - named_curve: state.named_curve, - public_key: local_keypair.public_key.clone(), - algorithm: SignatureHashAlgorithm { - hash: signature_hash_algo.hash, - signature: signature_hash_algo.signature, - }, - signature: state.local_key_signature.clone(), - }, - ))), - ), - should_encrypt: false, - reset_local_sequence_number: false, - }); - } - - if cfg.client_auth as u8 > ClientAuthType::NoClientCert as u8 { - pkts.push(Packet { - record: RecordLayer::new( - PROTOCOL_VERSION1_2, - 0, - Content::Handshake(Handshake::new(HandshakeMessage::CertificateRequest( - HandshakeMessageCertificateRequest { - certificate_types: vec![ - ClientCertificateType::RsaSign, - ClientCertificateType::EcdsaSign, - ], - signature_hash_algorithms: cfg.local_signature_schemes.clone(), - }, - ))), - ), - should_encrypt: false, - reset_local_sequence_number: false, - }); - } - } else if let Some(local_psk_identity_hint) = &cfg.local_psk_identity_hint { - // To help the client in selecting which identity to use, the server - // can provide a "PSK identity hint" in the ServerKeyExchange message. - // If no hint is provided, the ServerKeyExchange message is omitted. 
- // - // https://tools.ietf.org/html/rfc4279#section-2 - pkts.push(Packet { - record: RecordLayer::new( - PROTOCOL_VERSION1_2, - 0, - Content::Handshake(Handshake::new(HandshakeMessage::ServerKeyExchange( - HandshakeMessageServerKeyExchange { - identity_hint: local_psk_identity_hint.clone(), - elliptic_curve_type: EllipticCurveType::Unsupported, - named_curve: NamedCurve::Unsupported, - public_key: vec![], - algorithm: SignatureHashAlgorithm { - hash: HashAlgorithm::Unsupported, - signature: SignatureAlgorithm::Unsupported, - }, - signature: vec![], - }, - ))), - ), - should_encrypt: false, - reset_local_sequence_number: false, - }); - } - - pkts.push(Packet { - record: RecordLayer::new( - PROTOCOL_VERSION1_2, - 0, - Content::Handshake(Handshake::new(HandshakeMessage::ServerHelloDone( - HandshakeMessageServerHelloDone {}, - ))), - ), - should_encrypt: false, - reset_local_sequence_number: false, - }); - - Ok(pkts) - } -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use tokio::sync::Mutex; - - use super::*; - use crate::error::Result; - - struct MockCipherSuite {} - - impl CipherSuite for MockCipherSuite { - fn to_string(&self) -> String { - "MockCipherSuite".into() - } - fn id(&self) -> CipherSuiteId { - unimplemented!(); - } - fn certificate_type(&self) -> ClientCertificateType { - unimplemented!(); - } - fn hash_func(&self) -> CipherSuiteHash { - unimplemented!(); - } - fn is_psk(&self) -> bool { - false - } - fn is_initialized(&self) -> bool { - panic!("is_initialized called with Certificate but not CertificateVerify"); - } - - // Generate the internal encryption state - fn init( - &mut self, - _master_secret: &[u8], - _client_random: &[u8], - _server_random: &[u8], - _is_client: bool, - ) -> Result<()> { - unimplemented!(); - } - - fn encrypt(&self, _pkt_rlh: &RecordLayerHeader, _raw: &[u8]) -> Result> { - unimplemented!(); - } - fn decrypt(&self, _input: &[u8]) -> Result> { - unimplemented!(); - } - } - - // Assert that if a client sends a certificate they must also send a `CertificateVerify` - // message. The `Flight4` must not interact with the `cipher_suite` if the `CertificateVerify` - // is missing. 
- #[tokio::test] - async fn test_flight4_process_certificateverify() { - let mut state = State { - cipher_suite: Arc::new(Mutex::new(Some(Box::new(MockCipherSuite {})))), - ..Default::default() - }; - - let raw_certificate = vec![ - 0x0b, 0x00, 0x01, 0x9b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x9b, 0x00, 0x01, - 0x98, 0x00, 0x01, 0x95, 0x30, 0x82, 0x01, 0x91, 0x30, 0x82, 0x01, 0x38, 0xa0, 0x03, - 0x02, 0x01, 0x02, 0x02, 0x11, 0x01, 0x65, 0x03, 0x3f, 0x4d, 0x0b, 0x9a, 0x62, 0x91, - 0xdb, 0x4d, 0x28, 0x2c, 0x1f, 0xd6, 0x73, 0x32, 0x30, 0x0a, 0x06, 0x08, 0x2a, 0x86, - 0x48, 0xce, 0x3d, 0x04, 0x03, 0x02, 0x30, 0x00, 0x30, 0x1e, 0x17, 0x0d, 0x32, 0x32, - 0x30, 0x35, 0x31, 0x35, 0x31, 0x38, 0x34, 0x33, 0x35, 0x35, 0x5a, 0x17, 0x0d, 0x32, - 0x32, 0x30, 0x36, 0x31, 0x35, 0x31, 0x38, 0x34, 0x33, 0x35, 0x35, 0x5a, 0x30, 0x00, - 0x30, 0x59, 0x30, 0x13, 0x06, 0x07, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x02, 0x01, 0x06, - 0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07, 0x03, 0x42, 0x00, 0x04, 0xc3, - 0xb7, 0x13, 0x1a, 0x0a, 0xfc, 0xd0, 0x82, 0xf8, 0x94, 0x5e, 0xc0, 0x77, 0x07, 0x81, - 0x28, 0xc9, 0xcb, 0x08, 0x84, 0x50, 0x6b, 0xf0, 0x22, 0xe8, 0x79, 0xb9, 0x15, 0x33, - 0xc4, 0x56, 0xa1, 0xd3, 0x1b, 0x24, 0xe3, 0x61, 0xbd, 0x4d, 0x65, 0x80, 0x6b, 0x5d, - 0x96, 0x48, 0xa2, 0x44, 0x9e, 0xce, 0xe8, 0x65, 0xd6, 0x3c, 0xe0, 0x9b, 0x6b, 0xa1, - 0x36, 0x34, 0xb2, 0x39, 0xe2, 0x03, 0x00, 0xa3, 0x81, 0x92, 0x30, 0x81, 0x8f, 0x30, - 0x0e, 0x06, 0x03, 0x55, 0x1d, 0x0f, 0x01, 0x01, 0xff, 0x04, 0x04, 0x03, 0x02, 0x02, - 0xa4, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x25, 0x04, 0x16, 0x30, 0x14, 0x06, 0x08, - 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x02, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, - 0x05, 0x07, 0x03, 0x01, 0x30, 0x0f, 0x06, 0x03, 0x55, 0x1d, 0x13, 0x01, 0x01, 0xff, - 0x04, 0x05, 0x30, 0x03, 0x01, 0x01, 0xff, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x0e, - 0x04, 0x16, 0x04, 0x14, 0xb1, 0x1a, 0xe3, 0xeb, 0x6f, 0x7c, 0xc3, 0x8f, 0xba, 0x6f, - 0x1c, 0xe8, 0xf0, 0x23, 0x08, 0x50, 0x8d, 0x3c, 0xea, 0x31, 0x30, 0x2e, 0x06, 0x03, - 0x55, 0x1d, 0x11, 0x01, 0x01, 0xff, 0x04, 0x24, 0x30, 0x22, 0x82, 0x20, 0x30, 0x30, - 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, - 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, - 0x30, 0x30, 0x30, 0x0a, 0x06, 0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x04, 0x03, 0x02, - 0x03, 0x47, 0x00, 0x30, 0x44, 0x02, 0x20, 0x06, 0x31, 0x43, 0xac, 0x03, 0x45, 0x79, - 0x3c, 0xd7, 0x5f, 0x6e, 0x6a, 0xf8, 0x0e, 0xfd, 0x35, 0x49, 0xee, 0x1b, 0xbc, 0x47, - 0xce, 0xe3, 0x39, 0xec, 0xe4, 0x62, 0xe1, 0x30, 0x1a, 0xa1, 0x89, 0x02, 0x20, 0x35, - 0xcd, 0x7a, 0x15, 0x68, 0x09, 0x50, 0x49, 0x9e, 0x3e, 0x05, 0xd7, 0xc2, 0x69, 0x3f, - 0x9c, 0x0c, 0x98, 0x92, 0x65, 0xec, 0xae, 0x44, 0xfe, 0xe5, 0x68, 0xb8, 0x09, 0x78, - 0x7f, 0x6b, 0x77, - ]; - - let raw_client_key_exchange = vec![ - 0x10, 0x00, 0x00, 0x21, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x21, 0x20, 0x96, - 0xed, 0x0c, 0xee, 0xf3, 0x11, 0xb1, 0x9d, 0x8b, 0x1c, 0x02, 0x7f, 0x06, 0x7c, 0x57, - 0x7a, 0x14, 0xa6, 0x41, 0xde, 0x63, 0x57, 0x9e, 0xcd, 0x34, 0x54, 0xba, 0x37, 0x4d, - 0x34, 0x15, 0x18, - ]; - - let mut cache = HandshakeCache::new(); - cache - .push(raw_certificate, 0, 0, HandshakeType::Certificate, true) - .await; - cache - .push( - raw_client_key_exchange, - 0, - 1, - HandshakeType::ClientKeyExchange, - true, - ) - .await; - - let cfg = HandshakeConfig::default(); - - let (mut tx, _rx) = mpsc::channel::>(1); - - let f = Flight4 {}; - let res = f.parse(&mut tx, &mut state, 
&cache, &cfg).await; - assert!(res.is_err()); - } -} diff --git a/dtls/src/flight/flight5.rs b/dtls/src/flight/flight5.rs deleted file mode 100644 index 264cd9e0a..000000000 --- a/dtls/src/flight/flight5.rs +++ /dev/null @@ -1,778 +0,0 @@ -use std::fmt; -use std::io::{BufReader, BufWriter}; - -use async_trait::async_trait; - -use super::flight3::*; -use super::*; -use crate::change_cipher_spec::ChangeCipherSpec; -use crate::content::*; -use crate::crypto::*; -use crate::curve::named_curve::*; -use crate::curve::*; -use crate::error::Error; -use crate::handshake::handshake_message_certificate::*; -use crate::handshake::handshake_message_certificate_verify::*; -use crate::handshake::handshake_message_client_key_exchange::*; -use crate::handshake::handshake_message_finished::*; -use crate::handshake::handshake_message_server_key_exchange::*; -use crate::handshake::*; -use crate::prf::*; -use crate::record_layer::record_layer_header::*; -use crate::record_layer::*; -use crate::signature_hash_algorithm::*; - -#[derive(Debug, PartialEq)] -pub(crate) struct Flight5; - -impl fmt::Display for Flight5 { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Flight 5") - } -} - -#[async_trait] -impl Flight for Flight5 { - fn is_last_recv_flight(&self) -> bool { - true - } - - async fn parse( - &self, - _tx: &mut mpsc::Sender>, - state: &mut State, - cache: &HandshakeCache, - cfg: &HandshakeConfig, - ) -> Result, (Option, Option)> { - let (_seq, msgs) = match cache - .full_pull_map( - state.handshake_recv_sequence, - &[HandshakeCachePullRule { - typ: HandshakeType::Finished, - epoch: cfg.initial_epoch + 1, - is_client: false, - optional: false, - }], - ) - .await - { - Ok((seq, msgs)) => (seq, msgs), - Err(_) => return Err((None, None)), - }; - - let finished = - if let Some(HandshakeMessage::Finished(h)) = msgs.get(&HandshakeType::Finished) { - h - } else { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - None, - )); - }; - - let plain_text = cache - .pull_and_merge(&[ - HandshakeCachePullRule { - typ: HandshakeType::ClientHello, - epoch: cfg.initial_epoch, - is_client: true, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::ServerHello, - epoch: cfg.initial_epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::Certificate, - epoch: cfg.initial_epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::ServerKeyExchange, - epoch: cfg.initial_epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::CertificateRequest, - epoch: cfg.initial_epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::ServerHelloDone, - epoch: cfg.initial_epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::Certificate, - epoch: cfg.initial_epoch, - is_client: true, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::ClientKeyExchange, - epoch: cfg.initial_epoch, - is_client: true, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::CertificateVerify, - epoch: cfg.initial_epoch, - is_client: true, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::Finished, - epoch: cfg.initial_epoch + 1, - is_client: true, - optional: false, - }, - ]) - .await; - - { - let cipher_suite = state.cipher_suite.lock().await; - if let 
Some(cipher_suite) = &*cipher_suite { - let expected_verify_data = match prf_verify_data_server( - &state.master_secret, - &plain_text, - cipher_suite.hash_func(), - ) { - Ok(d) => d, - Err(err) => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InsufficientSecurity, - }), - Some(err), - )) - } - }; - - if expected_verify_data != finished.verify_data { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::HandshakeFailure, - }), - Some(Error::ErrVerifyDataMismatch), - )); - } - } - } - - Ok(Box::new(Flight5 {})) - } - - async fn generate( - &self, - state: &mut State, - cache: &HandshakeCache, - cfg: &HandshakeConfig, - ) -> Result, (Option, Option)> { - let certificate = if !cfg.local_certificates.is_empty() { - let cert = match cfg.get_certificate(&cfg.server_name) { - Ok(cert) => cert, - Err(err) => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::HandshakeFailure, - }), - Some(err), - )) - } - }; - Some(cert) - } else { - None - }; - - let mut pkts = vec![]; - - if state.remote_requested_certificate { - pkts.push(Packet { - record: RecordLayer::new( - PROTOCOL_VERSION1_2, - 0, - Content::Handshake(Handshake::new(HandshakeMessage::Certificate( - HandshakeMessageCertificate { - certificate: if let Some(cert) = &certificate { - cert.certificate - .iter() - .map(|x| x.as_ref().to_owned()) - .collect() - } else { - vec![] - }, - }, - ))), - ), - should_encrypt: false, - reset_local_sequence_number: false, - }); - } - - let mut client_key_exchange = HandshakeMessageClientKeyExchange { - identity_hint: vec![], - public_key: vec![], - }; - if cfg.local_psk_callback.is_none() { - if let Some(local_keypair) = &state.local_keypair { - client_key_exchange - .public_key - .clone_from(&local_keypair.public_key); - } - } else if let Some(local_psk_identity_hint) = &cfg.local_psk_identity_hint { - client_key_exchange - .identity_hint - .clone_from(local_psk_identity_hint); - } - - pkts.push(Packet { - record: RecordLayer::new( - PROTOCOL_VERSION1_2, - 0, - Content::Handshake(Handshake::new(HandshakeMessage::ClientKeyExchange( - client_key_exchange, - ))), - ), - should_encrypt: false, - reset_local_sequence_number: false, - }); - - let server_key_exchange_data = cache - .pull_and_merge(&[HandshakeCachePullRule { - typ: HandshakeType::ServerKeyExchange, - epoch: cfg.initial_epoch, - is_client: false, - optional: false, - }]) - .await; - - let mut server_key_exchange = HandshakeMessageServerKeyExchange { - identity_hint: vec![], - elliptic_curve_type: EllipticCurveType::Unsupported, - named_curve: NamedCurve::Unsupported, - public_key: vec![], - algorithm: SignatureHashAlgorithm { - hash: HashAlgorithm::Unsupported, - signature: SignatureAlgorithm::Unsupported, - }, - signature: vec![], - }; - - // handshakeMessageServerKeyExchange is optional for PSK - if server_key_exchange_data.is_empty() { - if let Err((alert, err)) = handle_server_key_exchange(state, cfg, &server_key_exchange) - { - return Err((alert, err)); - } - } else { - let mut reader = BufReader::new(server_key_exchange_data.as_slice()); - let raw_handshake = match Handshake::unmarshal(&mut reader) { - Ok(h) => h, - Err(err) => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::UnexpectedMessage, - }), - Some(err), - )) - } - }; - - match raw_handshake.handshake_message { - HandshakeMessage::ServerKeyExchange(h) => 
server_key_exchange = h, - _ => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::UnexpectedMessage, - }), - Some(Error::ErrInvalidContentType), - )) - } - }; - } - - // Append not-yet-sent packets - let mut merged = vec![]; - let mut seq_pred = state.handshake_send_sequence as u16; - for p in &mut pkts { - let h = match &mut p.record.content { - Content::Handshake(h) => h, - _ => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - Some(Error::ErrInvalidContentType), - )) - } - }; - h.handshake_header.message_sequence = seq_pred; - seq_pred += 1; - - let mut raw = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(raw.as_mut()); - if let Err(err) = h.marshal(&mut writer) { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - Some(err), - )); - } - } - - merged.extend_from_slice(&raw); - } - - if let Err((alert, err)) = - initialize_cipher_suite(state, cache, cfg, &server_key_exchange, &merged).await - { - return Err((alert, err)); - } - - // If the client has sent a certificate with signing ability, a digitally-signed - // CertificateVerify message is sent to explicitly verify possession of the - // private key in the certificate. - if state.remote_requested_certificate && !cfg.local_certificates.is_empty() { - let mut plain_text = cache - .pull_and_merge(&[ - HandshakeCachePullRule { - typ: HandshakeType::ClientHello, - epoch: cfg.initial_epoch, - is_client: true, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::ServerHello, - epoch: cfg.initial_epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::Certificate, - epoch: cfg.initial_epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::ServerKeyExchange, - epoch: cfg.initial_epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::CertificateRequest, - epoch: cfg.initial_epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::ServerHelloDone, - epoch: cfg.initial_epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::Certificate, - epoch: cfg.initial_epoch, - is_client: true, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::ClientKeyExchange, - epoch: cfg.initial_epoch, - is_client: true, - optional: false, - }, - ]) - .await; - - plain_text.extend_from_slice(&merged); - - // Find compatible signature scheme - let signature_hash_algo = match select_signature_scheme( - &cfg.local_signature_schemes, - &certificate.as_ref().unwrap().private_key, - ) { - Ok(s) => s, - Err(err) => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InsufficientSecurity, - }), - Some(err), - )) - } - }; - - let cert_verify = match generate_certificate_verify( - &plain_text, - &certificate.as_ref().unwrap().private_key, /*, signature_hash_algo.hash*/ - ) { - Ok(cert) => cert, - Err(err) => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - Some(err), - )) - } - }; - state.local_certificates_verify = cert_verify; - - let mut p = Packet { - record: RecordLayer::new( - PROTOCOL_VERSION1_2, - 0, - 
Content::Handshake(Handshake::new(HandshakeMessage::CertificateVerify( - HandshakeMessageCertificateVerify { - algorithm: signature_hash_algo, - signature: state.local_certificates_verify.clone(), - }, - ))), - ), - should_encrypt: false, - reset_local_sequence_number: false, - }; - - let h = match &mut p.record.content { - Content::Handshake(h) => h, - _ => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - Some(Error::ErrInvalidContentType), - )) - } - }; - h.handshake_header.message_sequence = seq_pred; - - // seqPred++ // this is the last use of seqPred - - let mut raw = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(raw.as_mut()); - if let Err(err) = h.marshal(&mut writer) { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - Some(err), - )); - } - } - merged.extend_from_slice(&raw); - - pkts.push(p); - } - - pkts.push(Packet { - record: RecordLayer::new( - PROTOCOL_VERSION1_2, - 0, - Content::ChangeCipherSpec(ChangeCipherSpec {}), - ), - should_encrypt: false, - reset_local_sequence_number: false, - }); - - if state.local_verify_data.is_empty() { - let mut plain_text = cache - .pull_and_merge(&[ - HandshakeCachePullRule { - typ: HandshakeType::ClientHello, - epoch: cfg.initial_epoch, - is_client: true, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::ServerHello, - epoch: cfg.initial_epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::Certificate, - epoch: cfg.initial_epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::ServerKeyExchange, - epoch: cfg.initial_epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::CertificateRequest, - epoch: cfg.initial_epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::ServerHelloDone, - epoch: cfg.initial_epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::Certificate, - epoch: cfg.initial_epoch, - is_client: true, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::ClientKeyExchange, - epoch: cfg.initial_epoch, - is_client: true, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::CertificateVerify, - epoch: cfg.initial_epoch, - is_client: true, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::Finished, - epoch: cfg.initial_epoch + 1, - is_client: true, - optional: false, - }, - ]) - .await; - - plain_text.extend_from_slice(&merged); - - let cipher_suite = state.cipher_suite.lock().await; - if let Some(cipher_suite) = &*cipher_suite { - state.local_verify_data = match prf_verify_data_client( - &state.master_secret, - &plain_text, - cipher_suite.hash_func(), - ) { - Ok(data) => data, - Err(err) => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - Some(err), - )) - } - }; - } - } - - pkts.push(Packet { - record: RecordLayer::new( - PROTOCOL_VERSION1_2, - 1, - Content::Handshake(Handshake::new(HandshakeMessage::Finished( - HandshakeMessageFinished { - verify_data: state.local_verify_data.clone(), - }, - ))), - ), - should_encrypt: true, - reset_local_sequence_number: true, - }); - - Ok(pkts) - } -} -async fn initialize_cipher_suite( - state: &mut State, - cache: &HandshakeCache, - cfg: 
&HandshakeConfig, - h: &HandshakeMessageServerKeyExchange, - sending_plain_text: &[u8], -) -> Result<(), (Option, Option)> { - let mut cipher_suite = state.cipher_suite.lock().await; - - if let Some(cipher_suite) = &*cipher_suite { - if cipher_suite.is_initialized() { - return Ok(()); - } - } - - let mut client_random = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(client_random.as_mut()); - let _ = state.local_random.marshal(&mut writer); - } - let mut server_random = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(server_random.as_mut()); - let _ = state.remote_random.marshal(&mut writer); - } - - if let Some(cipher_suite) = &*cipher_suite { - if state.extended_master_secret { - let session_hash = match cache - .session_hash( - cipher_suite.hash_func(), - cfg.initial_epoch, - sending_plain_text, - ) - .await - { - Ok(s) => s, - Err(err) => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - Some(err), - )) - } - }; - - state.master_secret = match prf_extended_master_secret( - &state.pre_master_secret, - &session_hash, - cipher_suite.hash_func(), - ) { - Ok(m) => m, - Err(err) => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::IllegalParameter, - }), - Some(err), - )) - } - }; - } else { - state.master_secret = match prf_master_secret( - &state.pre_master_secret, - &client_random, - &server_random, - cipher_suite.hash_func(), - ) { - Ok(m) => m, - Err(err) => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - Some(err), - )) - } - }; - } - } - - if cfg.local_psk_callback.is_none() { - // Verify that the pair of hash algorithm and signiture is listed. 
- let mut valid_signature_scheme = false; - for ss in &cfg.local_signature_schemes { - if ss.hash == h.algorithm.hash && ss.signature == h.algorithm.signature { - valid_signature_scheme = true; - break; - } - } - if !valid_signature_scheme { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InsufficientSecurity, - }), - Some(Error::ErrNoAvailableSignatureSchemes), - )); - } - - let expected_msg = - value_key_message(&client_random, &server_random, &h.public_key, h.named_curve); - if let Err(err) = verify_key_signature( - &expected_msg, - &h.algorithm, - &h.signature, - &state.peer_certificates, - cfg.insecure_verification, - ) { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::BadCertificate, - }), - Some(err), - )); - } - - let mut chains = vec![]; - if !cfg.insecure_skip_verify { - chains = match verify_server_cert( - &state.peer_certificates, - &cfg.server_cert_verifier, - &cfg.server_name, - ) { - Ok(chains) => chains, - Err(err) => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::BadCertificate, - }), - Some(err), - )) - } - } - } - if let Some(verify_peer_certificate) = &cfg.verify_peer_certificate { - if let Err(err) = verify_peer_certificate(&state.peer_certificates, &chains) { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::BadCertificate, - }), - Some(err), - )); - } - } - } - - if let Some(cipher_suite) = &mut *cipher_suite { - if let Err(err) = - cipher_suite.init(&state.master_secret, &client_random, &server_random, true) - { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - Some(err), - )); - } - } - - Ok(()) -} diff --git a/dtls/src/flight/flight6.rs b/dtls/src/flight/flight6.rs deleted file mode 100644 index 1e9b00362..000000000 --- a/dtls/src/flight/flight6.rs +++ /dev/null @@ -1,190 +0,0 @@ -use std::fmt; - -use async_trait::async_trait; - -use super::*; -use crate::change_cipher_spec::*; -use crate::content::*; -use crate::handshake::handshake_message_finished::*; -use crate::handshake::*; -use crate::prf::*; -use crate::record_layer::record_layer_header::*; - -#[derive(Debug, PartialEq)] -pub(crate) struct Flight6; - -impl fmt::Display for Flight6 { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Flight 6") - } -} - -#[async_trait] -impl Flight for Flight6 { - fn is_last_send_flight(&self) -> bool { - true - } - - async fn parse( - &self, - _tx: &mut mpsc::Sender>, - state: &mut State, - cache: &HandshakeCache, - cfg: &HandshakeConfig, - ) -> Result, (Option, Option)> { - let (_, msgs) = match cache - .full_pull_map( - state.handshake_recv_sequence - 1, - &[HandshakeCachePullRule { - typ: HandshakeType::Finished, - epoch: cfg.initial_epoch + 1, - is_client: true, - optional: false, - }], - ) - .await - { - Ok((seq, msgs)) => (seq, msgs), - // No valid message received. Keep reading - Err(_) => return Err((None, None)), - }; - - if let Some(message) = msgs.get(&HandshakeType::Finished) { - match message { - HandshakeMessage::Finished(_) => {} - _ => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - None, - )) - } - }; - } - - // Other party retransmitted the last flight. 
- Ok(Box::new(Flight6 {})) - } - - async fn generate( - &self, - state: &mut State, - cache: &HandshakeCache, - cfg: &HandshakeConfig, - ) -> Result, (Option, Option)> { - let mut pkts = vec![Packet { - record: RecordLayer::new( - PROTOCOL_VERSION1_2, - 0, - Content::ChangeCipherSpec(ChangeCipherSpec {}), - ), - should_encrypt: false, - reset_local_sequence_number: false, - }]; - - if state.local_verify_data.is_empty() { - let plain_text = cache - .pull_and_merge(&[ - HandshakeCachePullRule { - typ: HandshakeType::ClientHello, - epoch: cfg.initial_epoch, - is_client: true, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::ServerHello, - epoch: cfg.initial_epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::Certificate, - epoch: cfg.initial_epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::ServerKeyExchange, - epoch: cfg.initial_epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::CertificateRequest, - epoch: cfg.initial_epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::ServerHelloDone, - epoch: cfg.initial_epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::Certificate, - epoch: cfg.initial_epoch, - is_client: true, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::ClientKeyExchange, - epoch: cfg.initial_epoch, - is_client: true, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::CertificateVerify, - epoch: cfg.initial_epoch, - is_client: true, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::Finished, - epoch: cfg.initial_epoch + 1, - is_client: true, - optional: false, - }, - ]) - .await; - - let cipher_suite = state.cipher_suite.lock().await; - if let Some(cipher_suite) = &*cipher_suite { - state.local_verify_data = match prf_verify_data_server( - &state.master_secret, - &plain_text, - cipher_suite.hash_func(), - ) { - Ok(data) => data, - Err(err) => { - return Err(( - Some(Alert { - alert_level: AlertLevel::Fatal, - alert_description: AlertDescription::InternalError, - }), - Some(err), - )) - } - }; - } - } - - pkts.push(Packet { - record: RecordLayer::new( - PROTOCOL_VERSION1_2, - 1, - Content::Handshake(Handshake::new(HandshakeMessage::Finished( - HandshakeMessageFinished { - verify_data: state.local_verify_data.clone(), - }, - ))), - ), - should_encrypt: true, - reset_local_sequence_number: true, - }); - - Ok(pkts) - } -} diff --git a/dtls/src/flight/mod.rs b/dtls/src/flight/mod.rs deleted file mode 100644 index 8e6b41e6f..000000000 --- a/dtls/src/flight/mod.rs +++ /dev/null @@ -1,87 +0,0 @@ -pub(crate) mod flight0; -pub(crate) mod flight1; -pub(crate) mod flight2; -pub(crate) mod flight3; -pub(crate) mod flight4; -pub(crate) mod flight5; -pub(crate) mod flight6; - -use std::fmt; - -use async_trait::async_trait; -use tokio::sync::mpsc; - -use crate::alert::*; -use crate::error::Error; -use crate::handshake::handshake_cache::*; -use crate::handshaker::*; -use crate::record_layer::*; -use crate::state::*; - -/* - DTLS messages are grouped into a series of message flights, according - to the diagrams below. Although each Flight of messages may consist - of a number of messages, they should be viewed as monolithic for the - purpose of timeout and retransmission. 
- https://tools.ietf.org/html/rfc4347#section-4.2.4 - Client Server - ------ ------ - Waiting Flight 0 - - ClientHello --------> Flight 1 - - <------- HelloVerifyRequest Flight 2 - - ClientHello --------> Flight 3 - - ServerHello \ - Certificate* \ - ServerKeyExchange* Flight 4 - CertificateRequest* / - <-------- ServerHelloDone / - - Certificate* \ - ClientKeyExchange \ - CertificateVerify* Flight 5 - [ChangeCipherSpec] / - Finished --------> / - - [ChangeCipherSpec] \ Flight 6 - <-------- Finished / - -*/ - -#[derive(Clone, Debug)] -pub(crate) struct Packet { - pub(crate) record: RecordLayer, - pub(crate) should_encrypt: bool, - pub(crate) reset_local_sequence_number: bool, -} - -#[async_trait] -pub(crate) trait Flight: fmt::Display + fmt::Debug { - fn is_last_send_flight(&self) -> bool { - false - } - fn is_last_recv_flight(&self) -> bool { - false - } - fn has_retransmit(&self) -> bool { - true - } - - async fn parse( - &self, - tx: &mut mpsc::Sender>, - state: &mut State, - cache: &HandshakeCache, - cfg: &HandshakeConfig, - ) -> Result, (Option, Option)>; - - async fn generate( - &self, - state: &mut State, - cache: &HandshakeCache, - cfg: &HandshakeConfig, - ) -> Result, (Option, Option)>; -} diff --git a/dtls/src/fragment_buffer/fragment_buffer_test.rs b/dtls/src/fragment_buffer/fragment_buffer_test.rs deleted file mode 100644 index 0e9090809..000000000 --- a/dtls/src/fragment_buffer/fragment_buffer_test.rs +++ /dev/null @@ -1,174 +0,0 @@ -use super::*; - -#[test] -fn test_fragment_buffer() -> Result<()> { - let tests = vec![ - ( - "Single Fragment", - vec![vec![ - 0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, - 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00, - ]], - vec![vec![ - 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, - 0x00, - ]], - 0, - ), - ( - "Single Fragment Epoch 3", - vec![vec![ - 0x16, 0xfe, 0xff, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, - 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00, - ]], - vec![vec![ - 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, - 0x00, - ]], - 3, - ), - ( - "Multiple Fragments", - vec![ - vec![ - 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, - 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, - 0x01, 0x02, 0x03, 0x04, - ], - vec![ - 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, - 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x05, 0x05, - 0x06, 0x07, 0x08, 0x09, - ], - vec![ - 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, - 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x00, 0x05, 0x0A, - 0x0B, 0x0C, 0x0D, 0x0E, - ], - ], - vec![vec![ - 0x0b, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x01, - 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, - ]], - 0, - ), - ( - "Multiple Unordered Fragments", - vec![ - vec![ - 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, - 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, - 0x01, 0x02, 0x03, 0x04, - ], - vec![ - 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, - 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x00, 0x05, 0x0A, - 0x0B, 0x0C, 0x0D, 0x0E, - ], - vec![ - 0x16, 0xfe, 0xfd, 
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x81, - 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x05, 0x05, - 0x06, 0x07, 0x08, 0x09, - ], - ], - vec![vec![ - 0x0b, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x01, - 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, - ]], - 0, - ), - ( - "Multiple Handshakes in Single Fragment", - vec![vec![ - 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, - 0x30, /* record header */ - 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, - 0x01, 0x01, /*handshake msg 1*/ - 0x03, 0x00, 0x00, 0x04, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, - 0x01, 0x01, /*handshake msg 2*/ - 0x03, 0x00, 0x00, 0x04, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, - 0x01, 0x01, /*handshake msg 3*/ - ]], - vec![ - vec![ - 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, - 0xff, 0x01, 0x01, - ], - vec![ - 0x03, 0x00, 0x00, 0x04, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, - 0xff, 0x01, 0x01, - ], - vec![ - 0x03, 0x00, 0x00, 0x04, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, - 0xff, 0x01, 0x01, - ], - ], - 0, - ), - // Ensure zero length fragments don't cause an infinite recursive loop which in turn causes - // a stack overflow. - ( - "Zero length fragment", - vec![vec![ - 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d, 0x00, - 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - ]], - vec![vec![ - 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, - ]], - 0, - ), - ]; - - for (name, inputs, expects, expected_epoch) in tests { - let mut fragment_buffer = FragmentBuffer::new(); - for frag in inputs { - let status = fragment_buffer.push(&frag)?; - assert!( - status, - "fragment_buffer didn't accept fragments for '{name}'" - ); - } - - for expected in expects { - let (out, epoch) = fragment_buffer.pop()?; - assert_eq!( - out, expected, - "fragment_buffer '{name}' push/pop: got {out:?}, want {expected:?}" - ); - - assert_eq!( - epoch, expected_epoch, - "fragment_buffer returned wrong epoch: got {epoch}, want {expected_epoch}" - ); - } - - let result = fragment_buffer.pop(); - assert!( - result.is_err(), - "fragment_buffer popped single buffer multiple times for '{name}'" - ); - } - - Ok(()) -} - -#[test] -fn test_fragment_buffer_overflow() -> Result<()> { - let mut fragment_buffer = FragmentBuffer::new(); - - fragment_buffer.push(&[ - 0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, 0x00, - 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00, - ])?; - - let big_buffer = vec![0; 2_000_000]; - let result = fragment_buffer.push(&big_buffer); - - assert!( - result.is_err(), - "Pushing a buffer of size 2MB should have caused FragmentBuffer::push to return an error" - ); - - Ok(()) -} diff --git a/dtls/src/fragment_buffer/mod.rs b/dtls/src/fragment_buffer/mod.rs deleted file mode 100644 index d375bc2c2..000000000 --- a/dtls/src/fragment_buffer/mod.rs +++ /dev/null @@ -1,160 +0,0 @@ -#[cfg(test)] -mod fragment_buffer_test; - -use std::collections::HashMap; -use std::io::{BufWriter, Cursor}; - -use crate::content::*; -use crate::error::*; -use crate::handshake::handshake_header::*; -use crate::record_layer::record_layer_header::*; - -// 2 mb max buffer size -const FRAGMENT_BUFFER_MAX_SIZE: usize = 2_000_000; - -pub(crate) struct Fragment 
{ - record_layer_header: RecordLayerHeader, - handshake_header: HandshakeHeader, - data: Vec, -} - -pub(crate) struct FragmentBuffer { - // map of MessageSequenceNumbers that hold slices of fragments - cache: HashMap>, - - current_message_sequence_number: u16, -} - -impl FragmentBuffer { - pub fn new() -> Self { - FragmentBuffer { - cache: HashMap::new(), - current_message_sequence_number: 0, - } - } - - // Attempts to push a DTLS packet to the FragmentBuffer - // when it returns true it means the FragmentBuffer has inserted and the buffer shouldn't be handled - // when an error returns it is fatal, and the DTLS connection should be stopped - pub fn push(&mut self, mut buf: &[u8]) -> Result { - let current_size = self.size(); - if current_size + buf.len() >= FRAGMENT_BUFFER_MAX_SIZE { - return Err(Error::ErrFragmentBufferOverflow { - new_size: current_size + buf.len(), - max_size: FRAGMENT_BUFFER_MAX_SIZE, - }); - } - - let mut reader = Cursor::new(buf); - let record_layer_header = RecordLayerHeader::unmarshal(&mut reader)?; - - // Fragment isn't a handshake, we don't need to handle it - if record_layer_header.content_type != ContentType::Handshake { - return Ok(false); - } - - buf = &buf[RECORD_LAYER_HEADER_SIZE..]; - while !buf.is_empty() { - let mut reader = Cursor::new(buf); - let handshake_header = HandshakeHeader::unmarshal(&mut reader)?; - - self.cache - .entry(handshake_header.message_sequence) - .or_default(); - - // end index should be the length of handshake header but if the handshake - // was fragmented, we should keep them all - let mut end = HANDSHAKE_HEADER_LENGTH + handshake_header.length as usize; - if end > buf.len() { - end = buf.len(); - } - - // Discard all headers, when rebuilding the packet we will re-build - let data = buf[HANDSHAKE_HEADER_LENGTH..end].to_vec(); - - if let Some(x) = self.cache.get_mut(&handshake_header.message_sequence) { - x.push(Fragment { - record_layer_header, - handshake_header, - data, - }); - } - buf = &buf[end..]; - } - - Ok(true) - } - - pub fn pop(&mut self) -> Result<(Vec, u16)> { - let seq_num = self.current_message_sequence_number; - if !self.cache.contains_key(&seq_num) { - return Err(Error::ErrEmptyFragment); - } - - let (content, epoch) = if let Some(frags) = self.cache.get_mut(&seq_num) { - let mut raw_message = vec![]; - // Recursively collect up - if !append_message(0, frags, &mut raw_message) { - return Err(Error::ErrEmptyFragment); - } - - let mut first_header = frags[0].handshake_header; - first_header.fragment_offset = 0; - first_header.fragment_length = first_header.length; - - let mut raw_header = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(raw_header.as_mut()); - if first_header.marshal(&mut writer).is_err() { - return Err(Error::ErrEmptyFragment); - } - } - - let message_epoch = frags[0].record_layer_header.epoch; - - raw_header.extend_from_slice(&raw_message); - - (raw_header, message_epoch) - } else { - return Err(Error::ErrEmptyFragment); - }; - - self.cache.remove(&seq_num); - self.current_message_sequence_number += 1; - - Ok((content, epoch)) - } - - fn size(&self) -> usize { - self.cache - .values() - .map(|fragment| fragment.iter().map(|f| f.data.len()).sum::()) - .sum() - } -} - -fn append_message(target_offset: u32, frags: &[Fragment], raw_message: &mut Vec) -> bool { - for f in frags { - if f.handshake_header.fragment_offset == target_offset { - let fragment_end = - f.handshake_header.fragment_offset + f.handshake_header.fragment_length; - - // NB: Order here is important, the 
`f.handshake_header.fragment_length != 0` - // MUST come before the recursive call. - if fragment_end != f.handshake_header.length - && f.handshake_header.fragment_length != 0 - && !append_message(fragment_end, frags, raw_message) - { - return false; - } - - let mut message = vec![]; - message.extend_from_slice(&f.data); - message.extend_from_slice(raw_message); - *raw_message = message; - return true; - } - } - - false -} diff --git a/dtls/src/handshake/handshake_cache.rs b/dtls/src/handshake/handshake_cache.rs deleted file mode 100644 index 38667bb28..000000000 --- a/dtls/src/handshake/handshake_cache.rs +++ /dev/null @@ -1,241 +0,0 @@ -#[cfg(test)] -mod handshake_cache_test; - -use std::collections::HashMap; -use std::io::BufReader; -use std::sync::Arc; - -use sha2::{Digest, Sha256}; -use tokio::sync::Mutex; - -use crate::cipher_suite::*; -use crate::handshake::*; - -#[derive(Clone, Debug)] -pub(crate) struct HandshakeCacheItem { - typ: HandshakeType, - is_client: bool, - epoch: u16, - message_sequence: u16, - data: Vec, -} - -#[derive(Copy, Clone, Debug)] -pub(crate) struct HandshakeCachePullRule { - pub(crate) typ: HandshakeType, - pub(crate) epoch: u16, - pub(crate) is_client: bool, - pub(crate) optional: bool, -} - -#[derive(Clone)] -pub(crate) struct HandshakeCache { - cache: Arc>>, -} - -impl HandshakeCache { - pub(crate) fn new() -> Self { - HandshakeCache { - cache: Arc::new(Mutex::new(vec![])), - } - } - - pub(crate) async fn push( - &mut self, - data: Vec, - epoch: u16, - message_sequence: u16, - typ: HandshakeType, - is_client: bool, - ) -> bool { - let mut cache = self.cache.lock().await; - - for i in &*cache { - if i.message_sequence == message_sequence && i.is_client == is_client { - return false; - } - } - - cache.push(HandshakeCacheItem { - typ, - is_client, - epoch, - message_sequence, - data, - }); - - true - } - - // returns a list handshakes that match the requested rules - // the list will contain null entries for rules that can't be satisfied - // multiple entries may match a rule, but only the last match is returned (ie ClientHello with cookies) - pub(crate) async fn pull(&self, rules: &[HandshakeCachePullRule]) -> Vec { - let cache = self.cache.lock().await; - - let mut out = vec![]; - for r in rules { - let mut item: Option = None; - for c in &*cache { - if c.typ == r.typ && c.is_client == r.is_client && c.epoch == r.epoch { - if let Some(x) = &item { - if x.message_sequence < c.message_sequence { - item = Some(c.clone()); - } - } else { - item = Some(c.clone()); - } - } - } - - if let Some(c) = item { - out.push(c); - } - } - - out - } - - // full_pull_map pulls all handshakes between rules[0] to rules[len(rules)-1] as map. - pub(crate) async fn full_pull_map( - &self, - start_seq: isize, - rules: &[HandshakeCachePullRule], - ) -> Result<(isize, HashMap)> { - let cache = self.cache.lock().await; - - let mut ci = HashMap::new(); - for r in rules { - let mut item: Option = None; - for c in &*cache { - if c.typ == r.typ && c.is_client == r.is_client && c.epoch == r.epoch { - if let Some(x) = &item { - if x.message_sequence < c.message_sequence { - item = Some(c.clone()); - } - } else { - item = Some(c.clone()); - } - } - } - if !r.optional && item.is_none() { - // Missing mandatory message. 
- return Err(Error::Other("Missing mandatory message".to_owned())); - } - - if let Some(c) = item { - ci.insert(r.typ, c); - } - } - - let mut out = HashMap::new(); - let mut seq = start_seq; - for r in rules { - let t = r.typ; - if let Some(i) = ci.get(&t) { - let mut reader = BufReader::new(i.data.as_slice()); - let raw_handshake = Handshake::unmarshal(&mut reader)?; - if seq as u16 != raw_handshake.handshake_header.message_sequence { - // There is a gap. Some messages are not arrived. - return Err(Error::Other( - "There is a gap. Some messages are not arrived.".to_owned(), - )); - } - seq += 1; - out.insert(t, raw_handshake.handshake_message); - } - } - - Ok((seq, out)) - } - - // pull_and_merge calls pull and then merges the results, ignoring any null entries - pub(crate) async fn pull_and_merge(&self, rules: &[HandshakeCachePullRule]) -> Vec { - let mut merged = vec![]; - - for p in &self.pull(rules).await { - merged.extend_from_slice(&p.data); - } - - merged - } - - // session_hash returns the session hash for Extended Master Secret support - // https://tools.ietf.org/html/draft-ietf-tls-session-hash-06#section-4 - pub(crate) async fn session_hash( - &self, - hf: CipherSuiteHash, - epoch: u16, - additional: &[u8], - ) -> Result> { - let mut merged = vec![]; - - // Order defined by https://tools.ietf.org/html/rfc5246#section-7.3 - let handshake_buffer = self - .pull(&[ - HandshakeCachePullRule { - typ: HandshakeType::ClientHello, - epoch, - is_client: true, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::ServerHello, - epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::Certificate, - epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::ServerKeyExchange, - epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::CertificateRequest, - epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::ServerHelloDone, - epoch, - is_client: false, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::Certificate, - epoch, - is_client: true, - optional: false, - }, - HandshakeCachePullRule { - typ: HandshakeType::ClientKeyExchange, - epoch, - is_client: true, - optional: false, - }, - ]) - .await; - - for p in &handshake_buffer { - merged.extend_from_slice(&p.data); - } - - merged.extend_from_slice(additional); - - let mut hasher = match hf { - CipherSuiteHash::Sha256 => Sha256::new(), - }; - hasher.update(&merged); - let result = hasher.finalize(); - - Ok(result.as_slice().to_vec()) - } -} diff --git a/dtls/src/handshake/handshake_cache/handshake_cache_test.rs b/dtls/src/handshake/handshake_cache/handshake_cache_test.rs deleted file mode 100644 index b17391b18..000000000 --- a/dtls/src/handshake/handshake_cache/handshake_cache_test.rs +++ /dev/null @@ -1,658 +0,0 @@ -use super::*; - -#[tokio::test] -async fn test_handshake_cache_single_push() -> Result<()> { - let tests = vec![ - ( - "Single Push", - vec![HandshakeCacheItem { - typ: 0.into(), - is_client: true, - epoch: 0, - message_sequence: 0, - data: vec![0x00], - }], - vec![HandshakeCachePullRule { - typ: 0.into(), - epoch: 0, - is_client: true, - optional: false, - }], - vec![0x00], - ), - ( - "Multi Push", - vec![ - HandshakeCacheItem { - typ: 0.into(), - is_client: true, - epoch: 0, - message_sequence: 0, - data: vec![0x00], - }, - HandshakeCacheItem { - typ: 1.into(), - is_client: true, - epoch: 0, - message_sequence: 
1, - data: vec![0x01], - }, - HandshakeCacheItem { - typ: 2.into(), - is_client: true, - epoch: 0, - message_sequence: 2, - data: vec![0x02], - }, - ], - vec![ - HandshakeCachePullRule { - typ: 0.into(), - epoch: 0, - is_client: true, - optional: false, - }, - HandshakeCachePullRule { - typ: 1.into(), - epoch: 0, - is_client: true, - optional: false, - }, - HandshakeCachePullRule { - typ: 2.into(), - epoch: 0, - is_client: true, - optional: false, - }, - ], - vec![0x00, 0x01, 0x02], - ), - ( - "Multi Push, Rules set order", - vec![ - HandshakeCacheItem { - typ: 2.into(), - is_client: true, - epoch: 0, - message_sequence: 2, - data: vec![0x02], - }, - HandshakeCacheItem { - typ: 0.into(), - is_client: true, - epoch: 0, - message_sequence: 0, - data: vec![0x00], - }, - HandshakeCacheItem { - typ: 1.into(), - is_client: true, - epoch: 0, - message_sequence: 1, - data: vec![0x01], - }, - ], - vec![ - HandshakeCachePullRule { - typ: 0.into(), - epoch: 0, - is_client: true, - optional: false, - }, - HandshakeCachePullRule { - typ: 1.into(), - epoch: 0, - is_client: true, - optional: false, - }, - HandshakeCachePullRule { - typ: 2.into(), - epoch: 0, - is_client: true, - optional: false, - }, - ], - vec![0x00, 0x01, 0x02], - ), - ( - "Multi Push, Dupe Seqnum", - vec![ - HandshakeCacheItem { - typ: 0.into(), - is_client: true, - epoch: 0, - message_sequence: 0, - data: vec![0x00], - }, - HandshakeCacheItem { - typ: 1.into(), - is_client: true, - epoch: 0, - message_sequence: 1, - data: vec![0x01], - }, - HandshakeCacheItem { - typ: 1.into(), - is_client: true, - epoch: 0, - message_sequence: 1, - data: vec![0x01], - }, - ], - vec![ - HandshakeCachePullRule { - typ: 0.into(), - epoch: 0, - is_client: true, - optional: false, - }, - HandshakeCachePullRule { - typ: 1.into(), - epoch: 0, - is_client: true, - optional: false, - }, - ], - vec![0x00, 0x01], - ), - ( - "Multi Push, Dupe Seqnum Client/Server", - vec![ - HandshakeCacheItem { - typ: 0.into(), - is_client: true, - epoch: 0, - message_sequence: 0, - data: vec![0x00], - }, - HandshakeCacheItem { - typ: 1.into(), - is_client: true, - epoch: 0, - message_sequence: 1, - data: vec![0x01], - }, - HandshakeCacheItem { - typ: 1.into(), - is_client: false, - epoch: 0, - message_sequence: 1, - data: vec![0x02], - }, - ], - vec![ - HandshakeCachePullRule { - typ: 0.into(), - epoch: 0, - is_client: true, - optional: false, - }, - HandshakeCachePullRule { - typ: 1.into(), - epoch: 0, - is_client: true, - optional: false, - }, - HandshakeCachePullRule { - typ: 1.into(), - epoch: 0, - is_client: false, - optional: false, - }, - ], - vec![0x00, 0x01, 0x02], - ), - ( - "Multi Push, Dupe Seqnum with Unique HandshakeType", - vec![ - HandshakeCacheItem { - typ: 1.into(), - is_client: true, - epoch: 0, - message_sequence: 0, - data: vec![0x00], - }, - HandshakeCacheItem { - typ: 2.into(), - is_client: true, - epoch: 0, - message_sequence: 1, - data: vec![0x01], - }, - HandshakeCacheItem { - typ: 3.into(), - is_client: false, - epoch: 0, - message_sequence: 0, - data: vec![0x02], - }, - ], - vec![ - HandshakeCachePullRule { - typ: 1.into(), - epoch: 0, - is_client: true, - optional: false, - }, - HandshakeCachePullRule { - typ: 2.into(), - epoch: 0, - is_client: true, - optional: false, - }, - HandshakeCachePullRule { - typ: 3.into(), - epoch: 0, - is_client: false, - optional: false, - }, - ], - vec![0x00, 0x01, 0x02], - ), - ( - "Multi Push, Wrong epoch", - vec![ - HandshakeCacheItem { - typ: 1.into(), - is_client: true, - epoch: 0, - message_sequence: 0, - data: 
vec![0x00], - }, - HandshakeCacheItem { - typ: 2.into(), - is_client: true, - epoch: 1, - message_sequence: 1, - data: vec![0x01], - }, - HandshakeCacheItem { - typ: 2.into(), - is_client: true, - epoch: 0, - message_sequence: 2, - data: vec![0x11], - }, - HandshakeCacheItem { - typ: 3.into(), - is_client: false, - epoch: 0, - message_sequence: 0, - data: vec![0x02], - }, - HandshakeCacheItem { - typ: 3.into(), - is_client: false, - epoch: 1, - message_sequence: 0, - data: vec![0x12], - }, - HandshakeCacheItem { - typ: 3.into(), - is_client: false, - epoch: 2, - message_sequence: 0, - data: vec![0x12], - }, - ], - vec![ - HandshakeCachePullRule { - typ: 1.into(), - epoch: 0, - is_client: true, - optional: false, - }, - HandshakeCachePullRule { - typ: 2.into(), - epoch: 1, - is_client: true, - optional: false, - }, - HandshakeCachePullRule { - typ: 3.into(), - epoch: 0, - is_client: false, - optional: false, - }, - ], - vec![0x00, 0x01, 0x02], - ), - ]; - - for (name, inputs, rules, expected) in tests { - let mut h = HandshakeCache::new(); - for i in inputs { - h.push(i.data, i.epoch, i.message_sequence, i.typ, i.is_client) - .await; - } - let verify_data = h.pull_and_merge(&rules).await; - assert_eq!( - verify_data, expected, - "handshakeCache '{name}' exp:{expected:?} actual {verify_data:?}", - ); - } - - Ok(()) -} - -#[tokio::test] -async fn test_handshake_cache_session_hash() -> Result<()> { - let tests = vec![ - ( - "Standard Handshake", - vec![ - HandshakeCacheItem { - typ: HandshakeType::ClientHello, - is_client: true, - epoch: 0, - message_sequence: 0, - data: vec![0x00], - }, - HandshakeCacheItem { - typ: HandshakeType::ServerHello, - is_client: false, - epoch: 0, - message_sequence: 1, - data: vec![0x01], - }, - HandshakeCacheItem { - typ: HandshakeType::Certificate, - is_client: false, - epoch: 0, - message_sequence: 2, - data: vec![0x02], - }, - HandshakeCacheItem { - typ: HandshakeType::ServerKeyExchange, - is_client: false, - epoch: 0, - message_sequence: 3, - data: vec![0x03], - }, - HandshakeCacheItem { - typ: HandshakeType::ServerHelloDone, - is_client: false, - epoch: 0, - message_sequence: 4, - data: vec![0x04], - }, - HandshakeCacheItem { - typ: HandshakeType::ClientKeyExchange, - is_client: true, - epoch: 0, - message_sequence: 5, - data: vec![0x05], - }, - ], - vec![ - 0x17, 0xe8, 0x8d, 0xb1, 0x87, 0xaf, 0xd6, 0x2c, 0x16, 0xe5, 0xde, 0xbf, 0x3e, 0x65, - 0x27, 0xcd, 0x00, 0x6b, 0xc0, 0x12, 0xbc, 0x90, 0xb5, 0x1a, 0x81, 0x0c, 0xd8, 0x0c, - 0x2d, 0x51, 0x1f, 0x43, - ], - ), - ( - "Handshake With Client Cert Request", - vec![ - HandshakeCacheItem { - typ: HandshakeType::ClientHello, - is_client: true, - epoch: 0, - message_sequence: 0, - data: vec![0x00], - }, - HandshakeCacheItem { - typ: HandshakeType::ServerHello, - is_client: false, - epoch: 0, - message_sequence: 1, - data: vec![0x01], - }, - HandshakeCacheItem { - typ: HandshakeType::Certificate, - is_client: false, - epoch: 0, - message_sequence: 2, - data: vec![0x02], - }, - HandshakeCacheItem { - typ: HandshakeType::ServerKeyExchange, - is_client: false, - epoch: 0, - message_sequence: 3, - data: vec![0x03], - }, - HandshakeCacheItem { - typ: HandshakeType::CertificateRequest, - is_client: false, - epoch: 0, - message_sequence: 4, - data: vec![0x04], - }, - HandshakeCacheItem { - typ: HandshakeType::ServerHelloDone, - is_client: false, - epoch: 0, - message_sequence: 5, - data: vec![0x05], - }, - HandshakeCacheItem { - typ: HandshakeType::ClientKeyExchange, - is_client: true, - epoch: 0, - message_sequence: 6, - 
data: vec![0x06], - }, - ], - vec![ - 0x57, 0x35, 0x5a, 0xc3, 0x30, 0x3c, 0x14, 0x8f, 0x11, 0xae, 0xf7, 0xcb, 0x17, 0x94, - 0x56, 0xb9, 0x23, 0x2c, 0xde, 0x33, 0xa8, 0x18, 0xdf, 0xda, 0x2c, 0x2f, 0xcb, 0x93, - 0x25, 0x74, 0x9a, 0x6b, - ], - ), - ( - "Handshake Ignores after ClientKeyExchange", - vec![ - HandshakeCacheItem { - typ: HandshakeType::ClientHello, - is_client: true, - epoch: 0, - message_sequence: 0, - data: vec![0x00], - }, - HandshakeCacheItem { - typ: HandshakeType::ServerHello, - is_client: false, - epoch: 0, - message_sequence: 1, - data: vec![0x01], - }, - HandshakeCacheItem { - typ: HandshakeType::Certificate, - is_client: false, - epoch: 0, - message_sequence: 2, - data: vec![0x02], - }, - HandshakeCacheItem { - typ: HandshakeType::ServerKeyExchange, - is_client: false, - epoch: 0, - message_sequence: 3, - data: vec![0x03], - }, - HandshakeCacheItem { - typ: HandshakeType::CertificateRequest, - is_client: false, - epoch: 0, - message_sequence: 4, - data: vec![0x04], - }, - HandshakeCacheItem { - typ: HandshakeType::ServerHelloDone, - is_client: false, - epoch: 0, - message_sequence: 5, - data: vec![0x05], - }, - HandshakeCacheItem { - typ: HandshakeType::ClientKeyExchange, - is_client: true, - epoch: 0, - message_sequence: 6, - data: vec![0x06], - }, - HandshakeCacheItem { - typ: HandshakeType::CertificateVerify, - is_client: true, - epoch: 0, - message_sequence: 7, - data: vec![0x07], - }, - HandshakeCacheItem { - typ: HandshakeType::Finished, - is_client: true, - epoch: 1, - message_sequence: 7, - data: vec![0x08], - }, - HandshakeCacheItem { - typ: HandshakeType::Finished, - is_client: false, - epoch: 1, - message_sequence: 7, - data: vec![0x09], - }, - ], - vec![ - 0x57, 0x35, 0x5a, 0xc3, 0x30, 0x3c, 0x14, 0x8f, 0x11, 0xae, 0xf7, 0xcb, 0x17, 0x94, - 0x56, 0xb9, 0x23, 0x2c, 0xde, 0x33, 0xa8, 0x18, 0xdf, 0xda, 0x2c, 0x2f, 0xcb, 0x93, - 0x25, 0x74, 0x9a, 0x6b, - ], - ), - ( - "Handshake Ignores wrong epoch", - vec![ - HandshakeCacheItem { - typ: HandshakeType::ClientHello, - is_client: true, - epoch: 0, - message_sequence: 0, - data: vec![0x00], - }, - HandshakeCacheItem { - typ: HandshakeType::ServerHello, - is_client: false, - epoch: 0, - message_sequence: 1, - data: vec![0x01], - }, - HandshakeCacheItem { - typ: HandshakeType::Certificate, - is_client: false, - epoch: 0, - message_sequence: 2, - data: vec![0x02], - }, - HandshakeCacheItem { - typ: HandshakeType::ServerKeyExchange, - is_client: false, - epoch: 0, - message_sequence: 3, - data: vec![0x03], - }, - HandshakeCacheItem { - typ: HandshakeType::CertificateRequest, - is_client: false, - epoch: 0, - message_sequence: 4, - data: vec![0x04], - }, - HandshakeCacheItem { - typ: HandshakeType::ServerHelloDone, - is_client: false, - epoch: 0, - message_sequence: 5, - data: vec![0x05], - }, - HandshakeCacheItem { - typ: HandshakeType::ClientKeyExchange, - is_client: true, - epoch: 0, - message_sequence: 6, - data: vec![0x06], - }, - HandshakeCacheItem { - typ: HandshakeType::CertificateVerify, - is_client: true, - epoch: 0, - message_sequence: 7, - data: vec![0x07], - }, - HandshakeCacheItem { - typ: HandshakeType::Finished, - is_client: true, - epoch: 0, - message_sequence: 7, - data: vec![0xf0], - }, - HandshakeCacheItem { - typ: HandshakeType::Finished, - is_client: false, - epoch: 0, - message_sequence: 7, - data: vec![0xf1], - }, - HandshakeCacheItem { - typ: HandshakeType::Finished, - is_client: true, - epoch: 1, - message_sequence: 7, - data: vec![0x08], - }, - HandshakeCacheItem { - typ: 
HandshakeType::Finished, - is_client: false, - epoch: 1, - message_sequence: 7, - data: vec![0x09], - }, - HandshakeCacheItem { - typ: HandshakeType::Finished, - is_client: true, - epoch: 0, - message_sequence: 7, - data: vec![0xf0], - }, - HandshakeCacheItem { - typ: HandshakeType::Finished, - is_client: false, - epoch: 0, - message_sequence: 7, - data: vec![0xf1], - }, - ], - vec![ - 0x57, 0x35, 0x5a, 0xc3, 0x30, 0x3c, 0x14, 0x8f, 0x11, 0xae, 0xf7, 0xcb, 0x17, 0x94, - 0x56, 0xb9, 0x23, 0x2c, 0xde, 0x33, 0xa8, 0x18, 0xdf, 0xda, 0x2c, 0x2f, 0xcb, 0x93, - 0x25, 0x74, 0x9a, 0x6b, - ], - ), - ]; - - for (name, inputs, expected) in tests { - let mut h = HandshakeCache::new(); - for i in inputs { - h.push(i.data, i.epoch, i.message_sequence, i.typ, i.is_client) - .await; - } - - let verify_data = h.session_hash(CipherSuiteHash::Sha256, 0, &[]).await?; - - assert_eq!( - verify_data, expected, - "handshakeCacheSessionHassh '{name}' exp: {expected:?} actual {verify_data:?}" - ); - } - - Ok(()) -} diff --git a/dtls/src/handshake/handshake_header.rs b/dtls/src/handshake/handshake_header.rs deleted file mode 100644 index 70c7610a6..000000000 --- a/dtls/src/handshake/handshake_header.rs +++ /dev/null @@ -1,50 +0,0 @@ -use std::io::{Read, Write}; - -use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; - -use super::*; - -// msg_len for Handshake messages assumes an extra 12 bytes for -// sequence, Fragment and version information -pub(crate) const HANDSHAKE_HEADER_LENGTH: usize = 12; - -#[derive(Copy, Clone, PartialEq, Eq, Debug, Default)] -pub struct HandshakeHeader { - pub(crate) handshake_type: HandshakeType, - pub(crate) length: u32, // uint24 in spec - pub(crate) message_sequence: u16, - pub(crate) fragment_offset: u32, // uint24 in spec - pub(crate) fragment_length: u32, // uint24 in spec -} - -impl HandshakeHeader { - pub fn size(&self) -> usize { - 1 + 3 + 2 + 3 + 3 - } - - pub fn marshal(&self, writer: &mut W) -> Result<()> { - writer.write_u8(self.handshake_type as u8)?; - writer.write_u24::(self.length)?; - writer.write_u16::(self.message_sequence)?; - writer.write_u24::(self.fragment_offset)?; - writer.write_u24::(self.fragment_length)?; - - Ok(writer.flush()?) 
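        // Descriptive note: as written above, the DTLS handshake header is always
        // HANDSHAKE_HEADER_LENGTH (12) bytes on the wire:
        //   handshake_type   - 1 byte
        //   length           - 3 bytes (uint24, big-endian)
        //   message_sequence - 2 bytes
        //   fragment_offset  - 3 bytes (uint24)
        //   fragment_length  - 3 bytes (uint24)
        // which is exactly what `size()` returns (1 + 3 + 2 + 3 + 3).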
- } - - pub fn unmarshal(reader: &mut R) -> Result { - let handshake_type = reader.read_u8()?.into(); - let length = reader.read_u24::()?; - let message_sequence = reader.read_u16::()?; - let fragment_offset = reader.read_u24::()?; - let fragment_length = reader.read_u24::()?; - - Ok(HandshakeHeader { - handshake_type, - length, - message_sequence, - fragment_offset, - fragment_length, - }) - } -} diff --git a/dtls/src/handshake/handshake_message_certificate.rs b/dtls/src/handshake/handshake_message_certificate.rs deleted file mode 100644 index c4e0034e4..000000000 --- a/dtls/src/handshake/handshake_message_certificate.rs +++ /dev/null @@ -1,70 +0,0 @@ -use std::io::{Read, Write}; - -use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; - -use super::*; - -#[cfg(test)] -mod handshake_message_certificate_test; - -const HANDSHAKE_MESSAGE_CERTIFICATE_LENGTH_FIELD_SIZE: usize = 3; - -#[derive(PartialEq, Eq, Debug, Clone)] -pub struct HandshakeMessageCertificate { - pub(crate) certificate: Vec>, -} - -impl HandshakeMessageCertificate { - pub fn handshake_type(&self) -> HandshakeType { - HandshakeType::Certificate - } - - pub fn size(&self) -> usize { - let mut len = 3; - - for r in &self.certificate { - len += HANDSHAKE_MESSAGE_CERTIFICATE_LENGTH_FIELD_SIZE + r.len(); - } - - len - } - - pub fn marshal(&self, writer: &mut W) -> Result<()> { - let mut payload_size = 0; - for r in &self.certificate { - payload_size += HANDSHAKE_MESSAGE_CERTIFICATE_LENGTH_FIELD_SIZE + r.len(); - } - - // Total Payload Size - writer.write_u24::(payload_size as u32)?; - - for r in &self.certificate { - // Certificate Length - writer.write_u24::(r.len() as u32)?; - - // Certificate body - writer.write_all(r)?; - } - - Ok(writer.flush()?) - } - - pub fn unmarshal(reader: &mut R) -> Result { - let mut certificate: Vec> = vec![]; - - let payload_size = reader.read_u24::()? as usize; - let mut offset = 0; - while offset < payload_size { - let certificate_len = reader.read_u24::()? 
as usize; - offset += HANDSHAKE_MESSAGE_CERTIFICATE_LENGTH_FIELD_SIZE; - - let mut buf = vec![0; certificate_len]; - reader.read_exact(&mut buf)?; - offset += certificate_len; - - certificate.push(buf); - } - - Ok(HandshakeMessageCertificate { certificate }) - } -} diff --git a/dtls/src/handshake/handshake_message_certificate/handshake_message_certificate_test.rs b/dtls/src/handshake/handshake_message_certificate/handshake_message_certificate_test.rs deleted file mode 100644 index 582d16df8..000000000 --- a/dtls/src/handshake/handshake_message_certificate/handshake_message_certificate_test.rs +++ /dev/null @@ -1,79 +0,0 @@ -use std::io::{BufReader, BufWriter}; - -use super::*; - -#[test] -fn test_handshake_message_certificate() -> Result<()> { - let raw_certificate = vec![ - 0x00, 0x01, 0x8c, 0x00, 0x01, 0x89, 0x30, 0x82, 0x01, 0x85, 0x30, 0x82, 0x01, 0x2b, 0x02, - 0x14, 0x7d, 0x00, 0xcf, 0x07, 0xfc, 0xe2, 0xb6, 0xb8, 0x3f, 0x72, 0xeb, 0x11, 0x36, 0x1b, - 0xf6, 0x39, 0xf1, 0x3c, 0x33, 0x41, 0x30, 0x0a, 0x06, 0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, - 0x04, 0x03, 0x02, 0x30, 0x45, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, - 0x02, 0x41, 0x55, 0x31, 0x13, 0x30, 0x11, 0x06, 0x03, 0x55, 0x04, 0x08, 0x0c, 0x0a, 0x53, - 0x6f, 0x6d, 0x65, 0x2d, 0x53, 0x74, 0x61, 0x74, 0x65, 0x31, 0x21, 0x30, 0x1f, 0x06, 0x03, - 0x55, 0x04, 0x0a, 0x0c, 0x18, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x20, 0x57, - 0x69, 0x64, 0x67, 0x69, 0x74, 0x73, 0x20, 0x50, 0x74, 0x79, 0x20, 0x4c, 0x74, 0x64, 0x30, - 0x1e, 0x17, 0x0d, 0x31, 0x38, 0x31, 0x30, 0x32, 0x35, 0x30, 0x38, 0x35, 0x31, 0x31, 0x32, - 0x5a, 0x17, 0x0d, 0x31, 0x39, 0x31, 0x30, 0x32, 0x35, 0x30, 0x38, 0x35, 0x31, 0x31, 0x32, - 0x5a, 0x30, 0x45, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, 0x02, 0x41, - 0x55, 0x31, 0x13, 0x30, 0x11, 0x06, 0x03, 0x55, 0x04, 0x08, 0x0c, 0x0a, 0x53, 0x6f, 0x6d, - 0x65, 0x2d, 0x53, 0x74, 0x61, 0x74, 0x65, 0x31, 0x21, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x04, - 0x0a, 0x0c, 0x18, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x20, 0x57, 0x69, 0x64, - 0x67, 0x69, 0x74, 0x73, 0x20, 0x50, 0x74, 0x79, 0x20, 0x4c, 0x74, 0x64, 0x30, 0x59, 0x30, - 0x13, 0x06, 0x07, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x02, 0x01, 0x06, 0x08, 0x2a, 0x86, 0x48, - 0xce, 0x3d, 0x03, 0x01, 0x07, 0x03, 0x42, 0x00, 0x04, 0xf9, 0xb1, 0x62, 0xd6, 0x07, 0xae, - 0xc3, 0x36, 0x34, 0xf5, 0xa3, 0x09, 0x39, 0x86, 0xe7, 0x3b, 0x59, 0xf7, 0x4a, 0x1d, 0xf4, - 0x97, 0x4f, 0x91, 0x40, 0x56, 0x1b, 0x3d, 0x6c, 0x5a, 0x38, 0x10, 0x15, 0x58, 0xf5, 0xa4, - 0xcc, 0xdf, 0xd5, 0xf5, 0x4a, 0x35, 0x40, 0x0f, 0x9f, 0x54, 0xb7, 0xe9, 0xe2, 0xae, 0x63, - 0x83, 0x6a, 0x4c, 0xfc, 0xc2, 0x5f, 0x78, 0xa0, 0xbb, 0x46, 0x54, 0xa4, 0xda, 0x30, 0x0a, - 0x06, 0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x04, 0x03, 0x02, 0x03, 0x48, 0x00, 0x30, 0x45, - 0x02, 0x20, 0x47, 0x1a, 0x5f, 0x58, 0x2a, 0x74, 0x33, 0x6d, 0xed, 0xac, 0x37, 0x21, 0xfa, - 0x76, 0x5a, 0x4d, 0x78, 0x68, 0x1a, 0xdd, 0x80, 0xa4, 0xd4, 0xb7, 0x7f, 0x7d, 0x78, 0xb3, - 0xfb, 0xf3, 0x95, 0xfb, 0x02, 0x21, 0x00, 0xc0, 0x73, 0x30, 0xda, 0x2b, 0xc0, 0x0c, 0x9e, - 0xb2, 0x25, 0x0d, 0x46, 0xb0, 0xbc, 0x66, 0x7f, 0x71, 0x66, 0xbf, 0x16, 0xb3, 0x80, 0x78, - 0xd0, 0x0c, 0xef, 0xcc, 0xf5, 0xc1, 0x15, 0x0f, 0x58, - ]; - - let mut reader = BufReader::new(raw_certificate.as_slice()); - let c = HandshakeMessageCertificate::unmarshal(&mut reader)?; - //TODO: add x509 parse - // certificate, err := x509.ParseCertificate(c.certificate[0]) - // if err != nil { - // t.Error(err) - // } - // copyCertificatePrivateMembers(certificate, 
parsedCertificate) - // if !reflect.DeepEqual(certificate, parsedCertificate) { - // t.Errorf("handshakeMessageCertificate unmarshal: got %#v, want %#v", c, parsedCertificate) - // } - - let mut raw = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(raw.as_mut()); - c.marshal(&mut writer)?; - } - assert_eq!( - raw, raw_certificate, - "handshakeMessageCertificate marshal: got {raw:?}, want {raw_certificate:?}" - ); - - Ok(()) -} - -#[test] -fn test_empty_handshake_message_certificate() -> Result<()> { - let raw_certificate = vec![0x00, 0x00, 0x00]; - - let expected_certificate = HandshakeMessageCertificate { - certificate: vec![], - }; - - let mut reader = BufReader::new(raw_certificate.as_slice()); - let c = HandshakeMessageCertificate::unmarshal(&mut reader)?; - - assert_eq!( - c, expected_certificate, - "handshakeMessageCertificate unmarshal: got {c:?}, want {expected_certificate:?}", - ); - - Ok(()) -} diff --git a/dtls/src/handshake/handshake_message_certificate_request.rs b/dtls/src/handshake/handshake_message_certificate_request.rs deleted file mode 100644 index f48cecbb6..000000000 --- a/dtls/src/handshake/handshake_message_certificate_request.rs +++ /dev/null @@ -1,77 +0,0 @@ -#[cfg(test)] -mod handshake_message_certificate_request_test; - -use std::io::{Read, Write}; - -use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; - -use super::*; -use crate::client_certificate_type::*; -use crate::signature_hash_algorithm::*; - -/* -A non-anonymous server can optionally request a certificate from -the client, if appropriate for the selected cipher suite. This -message, if sent, will immediately follow the ServerKeyExchange -message (if it is sent; otherwise, this message follows the -server's Certificate message). -*/ -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct HandshakeMessageCertificateRequest { - pub(crate) certificate_types: Vec, - pub(crate) signature_hash_algorithms: Vec, -} - -const HANDSHAKE_MESSAGE_CERTIFICATE_REQUEST_MIN_LENGTH: usize = 5; - -impl HandshakeMessageCertificateRequest { - pub fn handshake_type(&self) -> HandshakeType { - HandshakeType::CertificateRequest - } - - pub fn size(&self) -> usize { - 1 + self.certificate_types.len() + 2 + self.signature_hash_algorithms.len() * 2 + 2 - } - - pub fn marshal(&self, writer: &mut W) -> Result<()> { - writer.write_u8(self.certificate_types.len() as u8)?; - for v in &self.certificate_types { - writer.write_u8(*v as u8)?; - } - - writer.write_u16::(2 * self.signature_hash_algorithms.len() as u16)?; - for v in &self.signature_hash_algorithms { - writer.write_u8(v.hash as u8)?; - writer.write_u8(v.signature as u8)?; - } - - writer.write_all(&[0x00, 0x00])?; // Distinguished Names Length - - Ok(writer.flush()?) 
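        // Descriptive note: the writes above give CertificateRequest the layout that
        // `size()` counts (1 + certificate_types.len() + 2 + 2 * signature_hash_algorithms.len() + 2):
        //   certificate_types count        - 1 byte
        //   certificate_types              - 1 byte each
        //   signature/hash pair bytes      - 2 bytes (big-endian, = 2 * pair count)
        //   (hash, signature) pairs        - 2 bytes each
        //   distinguished names length     - 2 bytes (always 0x0000 here)
        // For example, the vector in the accompanying test module decodes as:
        //   02                -> two certificate types
        //   01 40             -> RsaSign, EcdsaSign
        //   00 0C             -> 12 bytes of signature/hash pairs (6 pairs)
        //   04 03 04 01 05 03 05 01 06 01 02 01
        //                     -> (Sha256,Ecdsa) (Sha256,Rsa) (Sha384,Ecdsa)
        //                        (Sha384,Rsa) (Sha512,Rsa) (Sha1,Rsa)
        //   00 00             -> empty distinguished names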
- } - - pub fn unmarshal(reader: &mut R) -> Result { - let certificate_types_length = reader.read_u8()?; - - let mut certificate_types = vec![]; - for _ in 0..certificate_types_length { - let cert_type = reader.read_u8()?.into(); - certificate_types.push(cert_type); - } - - let signature_hash_algorithms_length = reader.read_u16::()?; - - let mut signature_hash_algorithms = vec![]; - for _ in (0..signature_hash_algorithms_length).step_by(2) { - let hash = reader.read_u8()?.into(); - let signature = reader.read_u8()?.into(); - - signature_hash_algorithms.push(SignatureHashAlgorithm { hash, signature }); - } - - Ok(HandshakeMessageCertificateRequest { - certificate_types, - signature_hash_algorithms, - }) - } -} diff --git a/dtls/src/handshake/handshake_message_certificate_request/handshake_message_certificate_request_test.rs b/dtls/src/handshake/handshake_message_certificate_request/handshake_message_certificate_request_test.rs deleted file mode 100644 index f261cde5b..000000000 --- a/dtls/src/handshake/handshake_message_certificate_request/handshake_message_certificate_request_test.rs +++ /dev/null @@ -1,64 +0,0 @@ -use std::io::{BufReader, BufWriter}; - -use super::*; -use crate::signature_hash_algorithm::*; - -#[test] -fn test_handshake_message_certificate_request() -> Result<()> { - let raw_certificate_request = vec![ - 0x02, 0x01, 0x40, 0x00, 0x0C, 0x04, 0x03, 0x04, 0x01, 0x05, 0x03, 0x05, 0x01, 0x06, 0x01, - 0x02, 0x01, 0x00, 0x00, - ]; - - let parsed_certificate_request = HandshakeMessageCertificateRequest { - certificate_types: vec![ - ClientCertificateType::RsaSign, - ClientCertificateType::EcdsaSign, - ], - signature_hash_algorithms: vec![ - SignatureHashAlgorithm { - hash: HashAlgorithm::Sha256, - signature: SignatureAlgorithm::Ecdsa, - }, - SignatureHashAlgorithm { - hash: HashAlgorithm::Sha256, - signature: SignatureAlgorithm::Rsa, - }, - SignatureHashAlgorithm { - hash: HashAlgorithm::Sha384, - signature: SignatureAlgorithm::Ecdsa, - }, - SignatureHashAlgorithm { - hash: HashAlgorithm::Sha384, - signature: SignatureAlgorithm::Rsa, - }, - SignatureHashAlgorithm { - hash: HashAlgorithm::Sha512, - signature: SignatureAlgorithm::Rsa, - }, - SignatureHashAlgorithm { - hash: HashAlgorithm::Sha1, - signature: SignatureAlgorithm::Rsa, - }, - ], - }; - - let mut reader = BufReader::new(raw_certificate_request.as_slice()); - let c = HandshakeMessageCertificateRequest::unmarshal(&mut reader)?; - assert_eq!( - c, parsed_certificate_request, - "parsedCertificateRequest unmarshal: got {c:?}, want {parsed_certificate_request:?}" - ); - - let mut raw = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(raw.as_mut()); - c.marshal(&mut writer)?; - } - assert_eq!( - raw, raw_certificate_request, - "parsedCertificateRequest marshal: got {raw:?}, want {raw_certificate_request:?}" - ); - - Ok(()) -} diff --git a/dtls/src/handshake/handshake_message_certificate_verify.rs b/dtls/src/handshake/handshake_message_certificate_verify.rs deleted file mode 100644 index 4b535ce19..000000000 --- a/dtls/src/handshake/handshake_message_certificate_verify.rs +++ /dev/null @@ -1,52 +0,0 @@ -#[cfg(test)] -mod handshake_message_certificate_verify_test; - -use std::io::{Read, Write}; - -use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; - -use super::*; -use crate::signature_hash_algorithm::*; - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct HandshakeMessageCertificateVerify { - pub(crate) algorithm: SignatureHashAlgorithm, - pub(crate) signature: Vec, -} - -const 
HANDSHAKE_MESSAGE_CERTIFICATE_VERIFY_MIN_LENGTH: usize = 4; - -impl HandshakeMessageCertificateVerify { - pub fn handshake_type(&self) -> HandshakeType { - HandshakeType::CertificateVerify - } - - pub fn size(&self) -> usize { - 1 + 1 + 2 + self.signature.len() - } - - pub fn marshal(&self, writer: &mut W) -> Result<()> { - writer.write_u8(self.algorithm.hash as u8)?; - writer.write_u8(self.algorithm.signature as u8)?; - writer.write_u16::(self.signature.len() as u16)?; - writer.write_all(&self.signature)?; - - Ok(writer.flush()?) - } - - pub fn unmarshal(reader: &mut R) -> Result { - let hash_algorithm = reader.read_u8()?.into(); - let signature_algorithm = reader.read_u8()?.into(); - let signature_length = reader.read_u16::()? as usize; - let mut signature = vec![0; signature_length]; - reader.read_exact(&mut signature)?; - - Ok(HandshakeMessageCertificateVerify { - algorithm: SignatureHashAlgorithm { - hash: hash_algorithm, - signature: signature_algorithm, - }, - signature, - }) - } -} diff --git a/dtls/src/handshake/handshake_message_certificate_verify/handshake_message_certificate_verify_test.rs b/dtls/src/handshake/handshake_message_certificate_verify/handshake_message_certificate_verify_test.rs deleted file mode 100644 index b52cebc5b..000000000 --- a/dtls/src/handshake/handshake_message_certificate_verify/handshake_message_certificate_verify_test.rs +++ /dev/null @@ -1,40 +0,0 @@ -use std::io::{BufReader, BufWriter}; - -use super::*; - -#[test] -fn test_handshake_message_certificate_request() -> Result<()> { - let raw_certificate_verify = vec![ - 0x04, 0x03, 0x00, 0x47, 0x30, 0x45, 0x02, 0x20, 0x6b, 0x63, 0x17, 0xad, 0xbe, 0xb7, 0x7b, - 0x0f, 0x86, 0x73, 0x39, 0x1e, 0xba, 0xb3, 0x50, 0x9c, 0xce, 0x9c, 0xe4, 0x8b, 0xe5, 0x13, - 0x07, 0x59, 0x18, 0x1f, 0xe5, 0xa0, 0x2b, 0xca, 0xa6, 0xad, 0x02, 0x21, 0x00, 0xd3, 0xb5, - 0x01, 0xbe, 0x87, 0x6c, 0x04, 0xa1, 0xdc, 0x28, 0xaa, 0x5f, 0xf7, 0x1e, 0x9c, 0xc0, 0x1e, - 0x00, 0x2c, 0xe5, 0x94, 0xbb, 0x03, 0x0e, 0xf1, 0xcb, 0x28, 0x22, 0x33, 0x23, 0x88, 0xad, - ]; - let parsed_certificate_verify = HandshakeMessageCertificateVerify { - algorithm: SignatureHashAlgorithm { - hash: raw_certificate_verify[0].into(), - signature: raw_certificate_verify[1].into(), - }, - signature: raw_certificate_verify[4..].to_vec(), - }; - - let mut reader = BufReader::new(raw_certificate_verify.as_slice()); - let c = HandshakeMessageCertificateVerify::unmarshal(&mut reader)?; - assert_eq!( - c, parsed_certificate_verify, - "handshakeMessageCertificate unmarshal: got {c:?}, want {parsed_certificate_verify:?}" - ); - - let mut raw = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(raw.as_mut()); - c.marshal(&mut writer)?; - } - assert_eq!( - raw, raw_certificate_verify, - "handshakeMessageCertificateVerify marshal: got {raw:?}, want {raw_certificate_verify:?}" - ); - - Ok(()) -} diff --git a/dtls/src/handshake/handshake_message_client_hello.rs b/dtls/src/handshake/handshake_message_client_hello.rs deleted file mode 100644 index f98340e03..000000000 --- a/dtls/src/handshake/handshake_message_client_hello.rs +++ /dev/null @@ -1,196 +0,0 @@ -#[cfg(test)] -mod handshake_message_client_hello_test; - -use std::fmt; -use std::io::{BufReader, BufWriter}; - -use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; - -use super::handshake_random::*; -use super::*; -use crate::cipher_suite::*; -use crate::compression_methods::*; -use crate::extension::*; -use crate::record_layer::record_layer_header::*; - -/* -When a client first connects to a server it is required to 
send -the client hello as its first message. The client can also send a -client hello in response to a hello request or on its own -initiative in order to renegotiate the security parameters in an -existing connection. -*/ -#[derive(Clone)] -pub struct HandshakeMessageClientHello { - pub(crate) version: ProtocolVersion, - pub(crate) random: HandshakeRandom, - pub(crate) cookie: Vec, - - pub(crate) cipher_suites: Vec, - pub(crate) compression_methods: CompressionMethods, - pub(crate) extensions: Vec, -} - -impl PartialEq for HandshakeMessageClientHello { - fn eq(&self, other: &Self) -> bool { - if !(self.version == other.version - && self.random == other.random - && self.cookie == other.cookie - && self.compression_methods == other.compression_methods - && self.extensions == other.extensions - && self.cipher_suites.len() == other.cipher_suites.len()) - { - return false; - } - - for i in 0..self.cipher_suites.len() { - if self.cipher_suites[i] != other.cipher_suites[i] { - return false; - } - } - - true - } -} - -impl fmt::Debug for HandshakeMessageClientHello { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut cipher_suites_str = String::new(); - for cipher_suite in &self.cipher_suites { - cipher_suites_str += &cipher_suite.to_string(); - cipher_suites_str += " "; - } - let s = [ - format!("version: {:?} random: {:?}", self.version, self.random), - format!("cookie: {:?}", self.cookie), - format!("cipher_suites: {cipher_suites_str:?}"), - format!("compression_methods: {:?}", self.compression_methods), - format!("extensions: {:?}", self.extensions), - ]; - write!(f, "{}", s.join(" ")) - } -} - -const HANDSHAKE_MESSAGE_CLIENT_HELLO_VARIABLE_WIDTH_START: usize = 34; - -impl HandshakeMessageClientHello { - pub fn handshake_type(&self) -> HandshakeType { - HandshakeType::ClientHello - } - - pub fn size(&self) -> usize { - let mut len = 0; - - len += 2; // version.major+minor - len += self.random.size(); - - // SessionID - len += 1; - - len += 1 + self.cookie.len(); - - len += 2 + 2 * self.cipher_suites.len(); - - len += self.compression_methods.size(); - - len += 2; - for extension in &self.extensions { - len += extension.size(); - } - - len - } - - pub fn marshal(&self, writer: &mut W) -> Result<()> { - if self.cookie.len() > 255 { - return Err(Error::ErrCookieTooLong); - } - - writer.write_u8(self.version.major)?; - writer.write_u8(self.version.minor)?; - self.random.marshal(writer)?; - - // SessionID - writer.write_u8(0x00)?; - - writer.write_u8(self.cookie.len() as u8)?; - writer.write_all(&self.cookie)?; - - writer.write_u16::(2 * self.cipher_suites.len() as u16)?; - for cipher_suite in &self.cipher_suites { - writer.write_u16::(*cipher_suite as u16)?; - } - - self.compression_methods.marshal(writer)?; - - let mut extension_buffer = vec![]; - { - let mut extension_writer = BufWriter::<&mut Vec>::new(extension_buffer.as_mut()); - for extension in &self.extensions { - extension.marshal(&mut extension_writer)?; - } - } - - writer.write_u16::(extension_buffer.len() as u16)?; - writer.write_all(&extension_buffer)?; - - Ok(writer.flush()?) - } - - pub fn unmarshal(reader: &mut R) -> Result { - let major = reader.read_u8()?; - let minor = reader.read_u8()?; - let random = HandshakeRandom::unmarshal(reader)?; - - // Session ID - reader.read_u8()?; - - let cookie_len = reader.read_u8()? as usize; - let mut cookie = vec![0; cookie_len]; - reader.read_exact(&mut cookie)?; - - let cipher_suites_len = reader.read_u16::()? 
as usize / 2; - let mut cipher_suites = vec![]; - for _ in 0..cipher_suites_len { - let id: CipherSuiteId = reader.read_u16::()?.into(); - //let cipher_suite = cipher_suite_for_id(id)?; - cipher_suites.push(id); - } - - let compression_methods = CompressionMethods::unmarshal(reader)?; - let mut extensions = vec![]; - - let extension_buffer_len = reader.read_u16::()? as usize; - let mut extension_buffer = vec![0u8; extension_buffer_len]; - reader.read_exact(&mut extension_buffer)?; - - let mut offset = 0; - while offset < extension_buffer_len { - let mut extension_reader = BufReader::new(&extension_buffer[offset..]); - if let Ok(extension) = Extension::unmarshal(&mut extension_reader) { - extensions.push(extension); - } else { - log::warn!( - "Unsupported Extension Type {} {}", - extension_buffer[offset], - extension_buffer[offset + 1] - ); - } - - let extension_len = - u16::from_be_bytes([extension_buffer[offset + 2], extension_buffer[offset + 3]]) - as usize; - offset += 4 + extension_len; - } - - Ok(HandshakeMessageClientHello { - version: ProtocolVersion { major, minor }, - random, - cookie, - - cipher_suites, - compression_methods, - extensions, - }) - } -} diff --git a/dtls/src/handshake/handshake_message_client_hello/handshake_message_client_hello_test.rs b/dtls/src/handshake/handshake_message_client_hello/handshake_message_client_hello_test.rs deleted file mode 100644 index ba4260315..000000000 --- a/dtls/src/handshake/handshake_message_client_hello/handshake_message_client_hello_test.rs +++ /dev/null @@ -1,75 +0,0 @@ -use std::io::{BufReader, BufWriter}; -use std::time::{Duration, SystemTime}; - -use super::*; -use crate::curve::named_curve::*; -use crate::extension::extension_supported_elliptic_curves::*; - -#[test] -fn test_handshake_message_client_hello() -> Result<()> { - let raw_client_hello = vec![ - 0xfe, 0xfd, 0xb6, 0x2f, 0xce, 0x5c, 0x42, 0x54, 0xff, 0x86, 0xe1, 0x24, 0x41, 0x91, 0x42, - 0x62, 0x15, 0xad, 0x16, 0xc9, 0x15, 0x8d, 0x95, 0x71, 0x8a, 0xbb, 0x22, 0xd7, 0x47, 0xec, - 0xd8, 0x3d, 0xdc, 0x4b, 0x00, 0x14, 0xe6, 0x14, 0x3a, 0x1b, 0x04, 0xea, 0x9e, 0x7a, 0x14, - 0xd6, 0x6c, 0x57, 0xd0, 0x0e, 0x32, 0x85, 0x76, 0x18, 0xde, 0xd8, 0x00, 0x04, 0xc0, 0x2b, - 0xc0, 0x0a, 0x01, 0x00, 0x00, 0x08, 0x00, 0x0a, 0x00, 0x04, 0x00, 0x02, 0x00, 0x1d, - ]; - - let gmt_unix_time = if let Some(unix_time) = - SystemTime::UNIX_EPOCH.checked_add(Duration::new(3056586332u64, 0)) - { - unix_time - } else { - SystemTime::UNIX_EPOCH - }; - let parsed_client_hello = HandshakeMessageClientHello { - version: ProtocolVersion { - major: 0xFE, - minor: 0xFD, - }, - random: HandshakeRandom { - gmt_unix_time, - random_bytes: [ - 0x42, 0x54, 0xff, 0x86, 0xe1, 0x24, 0x41, 0x91, 0x42, 0x62, 0x15, 0xad, 0x16, 0xc9, - 0x15, 0x8d, 0x95, 0x71, 0x8a, 0xbb, 0x22, 0xd7, 0x47, 0xec, 0xd8, 0x3d, 0xdc, 0x4b, - ], - }, - cookie: vec![ - 0xe6, 0x14, 0x3a, 0x1b, 0x04, 0xea, 0x9e, 0x7a, 0x14, 0xd6, 0x6c, 0x57, 0xd0, 0x0e, - 0x32, 0x85, 0x76, 0x18, 0xde, 0xd8, - ], - cipher_suites: vec![ - CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Gcm_Sha256, - CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_256_Cbc_Sha, - //Box::::default(), - //Box::::default(), - ], - compression_methods: CompressionMethods { - ids: vec![CompressionMethodId::Null], - }, - extensions: vec![Extension::SupportedEllipticCurves( - ExtensionSupportedEllipticCurves { - elliptic_curves: vec![NamedCurve::X25519], - }, - )], - }; - - let mut reader = BufReader::new(raw_client_hello.as_slice()); - let c = HandshakeMessageClientHello::unmarshal(&mut reader)?; - 
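    // Descriptive note: raw_client_hello above decodes as:
    //   fe fd                   -> ProtocolVersion { major: 0xFE, minor: 0xFD } (DTLS 1.2)
    //   b6 2f ce 5c             -> gmt_unix_time = 3056586332 seconds since the Unix epoch
    //   next 28 bytes           -> random_bytes
    //   00                      -> session_id length (empty)
    //   14 + 20 bytes           -> cookie length and cookie
    //   00 04 c0 2b c0 0a       -> 4 bytes of cipher suites: 0xC02B, 0xC00A
    //   01 00                   -> one compression method: Null
    //   00 08                   -> 8 bytes of extensions
    //   00 0a 00 04 00 02 00 1d -> supported elliptic curves extension carrying X25519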
assert_eq!( - c, parsed_client_hello, - "handshakeMessageClientHello unmarshal: got {c:?}, want {parsed_client_hello:?}" - ); - - let mut raw = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(raw.as_mut()); - c.marshal(&mut writer)?; - } - assert_eq!( - raw, raw_client_hello, - "handshakeMessageClientHello marshal: got {raw:?}, want {raw_client_hello:?}" - ); - - Ok(()) -} diff --git a/dtls/src/handshake/handshake_message_client_key_exchange.rs b/dtls/src/handshake/handshake_message_client_key_exchange.rs deleted file mode 100644 index 4c027ff8b..000000000 --- a/dtls/src/handshake/handshake_message_client_key_exchange.rs +++ /dev/null @@ -1,70 +0,0 @@ -#[cfg(test)] -mod handshake_message_client_key_exchange_test; - -use std::io::{Read, Write}; - -use byteorder::{BigEndian, WriteBytesExt}; - -use super::*; - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct HandshakeMessageClientKeyExchange { - pub(crate) identity_hint: Vec, - pub(crate) public_key: Vec, -} - -impl HandshakeMessageClientKeyExchange { - pub fn handshake_type(&self) -> HandshakeType { - HandshakeType::ClientKeyExchange - } - - pub fn size(&self) -> usize { - if !self.public_key.is_empty() { - 1 + self.public_key.len() - } else { - 2 + self.identity_hint.len() - } - } - - pub fn marshal(&self, writer: &mut W) -> Result<()> { - if (!self.identity_hint.is_empty() && !self.public_key.is_empty()) - || (self.identity_hint.is_empty() && self.public_key.is_empty()) - { - return Err(Error::ErrInvalidClientKeyExchange); - } - - if !self.public_key.is_empty() { - writer.write_u8(self.public_key.len() as u8)?; - writer.write_all(&self.public_key)?; - } else { - writer.write_u16::(self.identity_hint.len() as u16)?; - writer.write_all(&self.identity_hint)?; - } - - Ok(writer.flush()?) - } - - pub fn unmarshal(reader: &mut R) -> Result { - let mut data = vec![]; - reader.read_to_end(&mut data)?; - - // If parsed as PSK return early and only populate PSK Identity Hint - let psk_length = ((data[0] as u16) << 8) | data[1] as u16; - if data.len() == psk_length as usize + 2 { - return Ok(HandshakeMessageClientKeyExchange { - identity_hint: data[2..].to_vec(), - public_key: vec![], - }); - } - - let public_key_length = data[0] as usize; - if data.len() != public_key_length + 1 { - return Err(Error::ErrBufferTooSmall); - } - - Ok(HandshakeMessageClientKeyExchange { - identity_hint: vec![], - public_key: data[1..].to_vec(), - }) - } -} diff --git a/dtls/src/handshake/handshake_message_client_key_exchange/handshake_message_client_key_exchange_test.rs b/dtls/src/handshake/handshake_message_client_key_exchange/handshake_message_client_key_exchange_test.rs deleted file mode 100644 index 0b9d6c178..000000000 --- a/dtls/src/handshake/handshake_message_client_key_exchange/handshake_message_client_key_exchange_test.rs +++ /dev/null @@ -1,35 +0,0 @@ -use std::io::{BufReader, BufWriter}; - -use super::*; - -#[test] -fn test_handshake_message_client_key_exchange() -> Result<()> { - let raw_client_key_exchange = vec![ - 0x20, 0x26, 0x78, 0x4a, 0x78, 0x70, 0xc1, 0xf9, 0x71, 0xea, 0x50, 0x4a, 0xb5, 0xbb, 0x00, - 0x76, 0x02, 0x05, 0xda, 0xf7, 0xd0, 0x3f, 0xe3, 0xf7, 0x4e, 0x8a, 0x14, 0x6f, 0xb7, 0xe0, - 0xc0, 0xff, 0x54, - ]; - let parsed_client_key_exchange = HandshakeMessageClientKeyExchange { - identity_hint: vec![], - public_key: raw_client_key_exchange[1..].to_vec(), - }; - - let mut reader = BufReader::new(raw_client_key_exchange.as_slice()); - let c = HandshakeMessageClientKeyExchange::unmarshal(&mut reader)?; - assert_eq!( - c, 
parsed_client_key_exchange, - "parsedCertificateRequest unmarshal: got {c:?}, want {parsed_client_key_exchange:?}" - ); - - let mut raw = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(raw.as_mut()); - c.marshal(&mut writer)?; - } - assert_eq!( - raw, raw_client_key_exchange, - "handshakeMessageClientKeyExchange marshal: got {raw:?}, want {raw_client_key_exchange:?}" - ); - - Ok(()) -} diff --git a/dtls/src/handshake/handshake_message_finished.rs b/dtls/src/handshake/handshake_message_finished.rs deleted file mode 100644 index d20feb6e9..000000000 --- a/dtls/src/handshake/handshake_message_finished.rs +++ /dev/null @@ -1,34 +0,0 @@ -#[cfg(test)] -mod handshake_message_finished_test; - -use std::io::{Read, Write}; - -use super::*; - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct HandshakeMessageFinished { - pub(crate) verify_data: Vec, -} - -impl HandshakeMessageFinished { - pub fn handshake_type(&self) -> HandshakeType { - HandshakeType::Finished - } - - pub fn size(&self) -> usize { - self.verify_data.len() - } - - pub fn marshal(&self, writer: &mut W) -> Result<()> { - writer.write_all(&self.verify_data)?; - - Ok(writer.flush()?) - } - - pub fn unmarshal(reader: &mut R) -> Result { - let mut verify_data: Vec = vec![]; - reader.read_to_end(&mut verify_data)?; - - Ok(HandshakeMessageFinished { verify_data }) - } -} diff --git a/dtls/src/handshake/handshake_message_finished/handshake_message_finished_test.rs b/dtls/src/handshake/handshake_message_finished/handshake_message_finished_test.rs deleted file mode 100644 index 7980ee909..000000000 --- a/dtls/src/handshake/handshake_message_finished/handshake_message_finished_test.rs +++ /dev/null @@ -1,32 +0,0 @@ -use std::io::{BufReader, BufWriter}; - -use super::*; - -#[test] -fn test_handshake_message_finished() -> Result<()> { - let raw_finished = vec![ - 0x01, 0x01, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, - ]; - let parsed_finished = HandshakeMessageFinished { - verify_data: raw_finished.clone(), - }; - - let mut reader = BufReader::new(raw_finished.as_slice()); - let c = HandshakeMessageFinished::unmarshal(&mut reader)?; - assert_eq!( - c, parsed_finished, - "handshakeMessageFinished unmarshal: got {c:?}, want {parsed_finished:?}" - ); - - let mut raw = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(raw.as_mut()); - c.marshal(&mut writer)?; - } - assert_eq!( - raw, raw_finished, - "handshakeMessageFinished marshal: got {raw:?}, want {raw_finished:?}" - ); - - Ok(()) -} diff --git a/dtls/src/handshake/handshake_message_hello_verify_request.rs b/dtls/src/handshake/handshake_message_hello_verify_request.rs deleted file mode 100644 index 1d738d9cb..000000000 --- a/dtls/src/handshake/handshake_message_hello_verify_request.rs +++ /dev/null @@ -1,72 +0,0 @@ -#[cfg(test)] -mod handshake_message_hello_verify_request_test; - -use std::io::{Read, Write}; - -use byteorder::{ReadBytesExt, WriteBytesExt}; - -use super::*; -use crate::record_layer::record_layer_header::*; - -/* - The definition of HelloVerifyRequest is as follows: - - struct { - ProtocolVersion server_version; - opaque cookie<0..2^8-1>; - } HelloVerifyRequest; - - The HelloVerifyRequest message type is hello_verify_request(3). - - When the client sends its ClientHello message to the server, the server - MAY respond with a HelloVerifyRequest message. This message contains - a stateless cookie generated using the technique of [PHOTURIS]. The - client MUST retransmit the ClientHello with the cookie added. 
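 On the wire the message is simply server_version (two bytes) followed by a
 one-byte cookie length and the cookie itself; the marshal implementation
 below therefore rejects cookies longer than 255 bytes with
 Error::ErrCookieTooLong.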
- - https://tools.ietf.org/html/rfc6347#section-4.2.1 -*/ -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct HandshakeMessageHelloVerifyRequest { - pub(crate) version: ProtocolVersion, - pub(crate) cookie: Vec, -} - -impl HandshakeMessageHelloVerifyRequest { - pub fn handshake_type(&self) -> HandshakeType { - HandshakeType::HelloVerifyRequest - } - - pub fn size(&self) -> usize { - 1 + 1 + 1 + self.cookie.len() - } - - pub fn marshal(&self, writer: &mut W) -> Result<()> { - if self.cookie.len() > 255 { - return Err(Error::ErrCookieTooLong); - } - - writer.write_u8(self.version.major)?; - writer.write_u8(self.version.minor)?; - writer.write_u8(self.cookie.len() as u8)?; - writer.write_all(&self.cookie)?; - - Ok(writer.flush()?) - } - - pub fn unmarshal(reader: &mut R) -> Result { - let major = reader.read_u8()?; - let minor = reader.read_u8()?; - let cookie_length = reader.read_u8()?; - let mut cookie = vec![]; - reader.read_to_end(&mut cookie)?; - - if cookie.len() < cookie_length as usize { - return Err(Error::ErrBufferTooSmall); - } - - Ok(HandshakeMessageHelloVerifyRequest { - version: ProtocolVersion { major, minor }, - cookie, - }) - } -} diff --git a/dtls/src/handshake/handshake_message_hello_verify_request/handshake_message_hello_verify_request_test.rs b/dtls/src/handshake/handshake_message_hello_verify_request/handshake_message_hello_verify_request_test.rs deleted file mode 100644 index 5a25ce008..000000000 --- a/dtls/src/handshake/handshake_message_hello_verify_request/handshake_message_hello_verify_request_test.rs +++ /dev/null @@ -1,40 +0,0 @@ -use std::io::{BufReader, BufWriter}; - -use super::*; - -#[test] -fn test_handshake_message_hello_verify_request() -> Result<()> { - let raw_hello_verify_request = vec![ - 0xfe, 0xff, 0x14, 0x25, 0xfb, 0xee, 0xb3, 0x7c, 0x95, 0xcf, 0x00, 0xeb, 0xad, 0xe2, 0xef, - 0xc7, 0xfd, 0xbb, 0xed, 0xf7, 0x1f, 0x6c, 0xcd, - ]; - let parsed_hello_verify_request = HandshakeMessageHelloVerifyRequest { - version: ProtocolVersion { - major: 0xFE, - minor: 0xFF, - }, - cookie: vec![ - 0x25, 0xfb, 0xee, 0xb3, 0x7c, 0x95, 0xcf, 0x00, 0xeb, 0xad, 0xe2, 0xef, 0xc7, 0xfd, - 0xbb, 0xed, 0xf7, 0x1f, 0x6c, 0xcd, - ], - }; - - let mut reader = BufReader::new(raw_hello_verify_request.as_slice()); - let c = HandshakeMessageHelloVerifyRequest::unmarshal(&mut reader)?; - assert_eq!( - c, parsed_hello_verify_request, - "parsed_hello_verify_request unmarshal: got {c:?}, want {parsed_hello_verify_request:?}" - ); - - let mut raw = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(raw.as_mut()); - c.marshal(&mut writer)?; - } - assert_eq!( - raw, raw_hello_verify_request, - "parsed_hello_verify_request marshal: got {raw:?}, want {raw_hello_verify_request:?}" - ); - - Ok(()) -} diff --git a/dtls/src/handshake/handshake_message_server_hello.rs b/dtls/src/handshake/handshake_message_server_hello.rs deleted file mode 100644 index 6be32e458..000000000 --- a/dtls/src/handshake/handshake_message_server_hello.rs +++ /dev/null @@ -1,151 +0,0 @@ -#[cfg(test)] -mod handshake_message_server_hello_test; - -use std::fmt; -use std::io::{BufReader, BufWriter}; - -use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; - -use super::handshake_random::*; -use super::*; -use crate::cipher_suite::*; -use crate::compression_methods::*; -use crate::extension::*; -use crate::record_layer::record_layer_header::*; - -/* -The server will send this message in response to a ClientHello -message when it was able to find an acceptable set of algorithms. 
-If it cannot find such a match, it will respond with a handshake -failure alert. -https://tools.ietf.org/html/rfc5246#section-7.4.1.3 -*/ -#[derive(Clone)] -pub struct HandshakeMessageServerHello { - pub(crate) version: ProtocolVersion, - pub(crate) random: HandshakeRandom, - - pub(crate) cipher_suite: CipherSuiteId, - pub(crate) compression_method: CompressionMethodId, - pub(crate) extensions: Vec, -} - -impl PartialEq for HandshakeMessageServerHello { - fn eq(&self, other: &Self) -> bool { - self.version == other.version - && self.random == other.random - && self.compression_method == other.compression_method - && self.extensions == other.extensions - && self.cipher_suite == other.cipher_suite - } -} - -impl fmt::Debug for HandshakeMessageServerHello { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = [ - format!("version: {:?} random: {:?}", self.version, self.random), - format!("cipher_suites: {:?}", self.cipher_suite), - format!("compression_method: {:?}", self.compression_method), - format!("extensions: {:?}", self.extensions), - ]; - write!(f, "{}", s.join(" ")) - } -} - -impl HandshakeMessageServerHello { - pub fn handshake_type(&self) -> HandshakeType { - HandshakeType::ServerHello - } - - pub fn size(&self) -> usize { - let mut len = 2 + self.random.size(); - - // SessionID - len += 1; - - len += 2; - - len += 1; - - len += 2; - for extension in &self.extensions { - len += extension.size(); - } - - len - } - - pub fn marshal(&self, writer: &mut W) -> Result<()> { - writer.write_u8(self.version.major)?; - writer.write_u8(self.version.minor)?; - self.random.marshal(writer)?; - - // SessionID - writer.write_u8(0x00)?; - - writer.write_u16::(self.cipher_suite as u16)?; - - writer.write_u8(self.compression_method as u8)?; - - let mut extension_buffer = vec![]; - { - let mut extension_writer = BufWriter::<&mut Vec>::new(extension_buffer.as_mut()); - for extension in &self.extensions { - extension.marshal(&mut extension_writer)?; - } - } - - writer.write_u16::(extension_buffer.len() as u16)?; - writer.write_all(&extension_buffer)?; - - Ok(writer.flush()?) - } - - pub fn unmarshal(reader: &mut R) -> Result { - let major = reader.read_u8()?; - let minor = reader.read_u8()?; - let random = HandshakeRandom::unmarshal(reader)?; - - // Session ID - let session_id_len = reader.read_u8()? as usize; - let mut session_id_buffer = vec![0u8; session_id_len]; - reader.read_exact(&mut session_id_buffer)?; - - let cipher_suite: CipherSuiteId = reader.read_u16::()?.into(); - - let compression_method = reader.read_u8()?.into(); - let mut extensions = vec![]; - - let extension_buffer_len = reader.read_u16::()? 
as usize; - let mut extension_buffer = vec![0u8; extension_buffer_len]; - reader.read_exact(&mut extension_buffer)?; - - let mut offset = 0; - while offset < extension_buffer_len { - let mut extension_reader = BufReader::new(&extension_buffer[offset..]); - if let Ok(extension) = Extension::unmarshal(&mut extension_reader) { - extensions.push(extension); - } else { - log::warn!( - "Unsupported Extension Type {} {}", - extension_buffer[offset], - extension_buffer[offset + 1] - ); - } - - let extension_len = - u16::from_be_bytes([extension_buffer[offset + 2], extension_buffer[offset + 3]]) - as usize; - offset += 4 + extension_len; - } - - Ok(HandshakeMessageServerHello { - version: ProtocolVersion { major, minor }, - random, - - cipher_suite, - compression_method, - extensions, - }) - } -} diff --git a/dtls/src/handshake/handshake_message_server_hello/handshake_message_server_hello_test.rs b/dtls/src/handshake/handshake_message_server_hello/handshake_message_server_hello_test.rs deleted file mode 100644 index ba906e06f..000000000 --- a/dtls/src/handshake/handshake_message_server_hello/handshake_message_server_hello_test.rs +++ /dev/null @@ -1,56 +0,0 @@ -use std::io::{BufReader, BufWriter}; -use std::time::{Duration, SystemTime}; - -use super::*; - -#[test] -fn test_handshake_message_server_hello() -> Result<()> { - let raw_server_hello = vec![ - 0xfe, 0xfd, 0x21, 0x63, 0x32, 0x21, 0x81, 0x0e, 0x98, 0x6c, 0x85, 0x3d, 0xa4, 0x39, 0xaf, - 0x5f, 0xd6, 0x5c, 0xcc, 0x20, 0x7f, 0x7c, 0x78, 0xf1, 0x5f, 0x7e, 0x1c, 0xb7, 0xa1, 0x1e, - 0xcf, 0x63, 0x84, 0x28, 0x00, 0xc0, 0x2b, 0x00, 0x00, 0x00, - ]; - - let gmt_unix_time = if let Some(unix_time) = - SystemTime::UNIX_EPOCH.checked_add(Duration::new(560149025u64, 0)) - { - unix_time - } else { - SystemTime::UNIX_EPOCH - }; - let parsed_server_hello = HandshakeMessageServerHello { - version: ProtocolVersion { - major: 0xFE, - minor: 0xFD, - }, - random: HandshakeRandom { - gmt_unix_time, - random_bytes: [ - 0x81, 0x0e, 0x98, 0x6c, 0x85, 0x3d, 0xa4, 0x39, 0xaf, 0x5f, 0xd6, 0x5c, 0xcc, 0x20, - 0x7f, 0x7c, 0x78, 0xf1, 0x5f, 0x7e, 0x1c, 0xb7, 0xa1, 0x1e, 0xcf, 0x63, 0x84, 0x28, - ], - }, - cipher_suite: CipherSuiteId::Tls_Ecdhe_Ecdsa_With_Aes_128_Gcm_Sha256, - compression_method: CompressionMethodId::Null, - extensions: vec![], - }; - - let mut reader = BufReader::new(raw_server_hello.as_slice()); - let c = HandshakeMessageServerHello::unmarshal(&mut reader)?; - assert_eq!( - c, parsed_server_hello, - "handshakeMessageServerHello unmarshal: got {c:?}, want {parsed_server_hello:?}" - ); - - let mut raw = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(raw.as_mut()); - c.marshal(&mut writer)?; - } - assert_eq!( - raw, raw_server_hello, - "handshakeMessageServerHello marshal: got {raw:?}, want {raw_server_hello:?}" - ); - - Ok(()) -} diff --git a/dtls/src/handshake/handshake_message_server_hello_done.rs b/dtls/src/handshake/handshake_message_server_hello_done.rs deleted file mode 100644 index bae58f218..000000000 --- a/dtls/src/handshake/handshake_message_server_hello_done.rs +++ /dev/null @@ -1,27 +0,0 @@ -#[cfg(test)] -mod handshake_message_server_hello_done_test; - -use std::io::{Read, Write}; - -use super::*; - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct HandshakeMessageServerHelloDone; - -impl HandshakeMessageServerHelloDone { - pub fn handshake_type(&self) -> HandshakeType { - HandshakeType::ServerHelloDone - } - - pub fn size(&self) -> usize { - 0 - } - - pub fn marshal(&self, _writer: &mut W) -> Result<()> { - Ok(()) - } - - pub fn 
unmarshal(_reader: &mut R) -> Result { - Ok(HandshakeMessageServerHelloDone {}) - } -} diff --git a/dtls/src/handshake/handshake_message_server_hello_done/handshake_message_server_hello_done_test.rs b/dtls/src/handshake/handshake_message_server_hello_done/handshake_message_server_hello_done_test.rs deleted file mode 100644 index b9d3b8359..000000000 --- a/dtls/src/handshake/handshake_message_server_hello_done/handshake_message_server_hello_done_test.rs +++ /dev/null @@ -1,28 +0,0 @@ -use std::io::{BufReader, BufWriter}; - -use super::*; - -#[test] -fn test_handshake_message_server_hello_done() -> Result<()> { - let raw_server_hello_done = vec![]; - let parsed_server_hello_done = HandshakeMessageServerHelloDone {}; - - let mut reader = BufReader::new(raw_server_hello_done.as_slice()); - let c = HandshakeMessageServerHelloDone::unmarshal(&mut reader)?; - assert_eq!( - c, parsed_server_hello_done, - "handshakeMessageServerHelloDone unmarshal: got {c:?}, want {parsed_server_hello_done:?}" - ); - - let mut raw = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(raw.as_mut()); - c.marshal(&mut writer)?; - } - assert_eq!( - raw, raw_server_hello_done, - "handshakeMessageServerHelloDone marshal: got {raw:?}, want {raw_server_hello_done:?}" - ); - - Ok(()) -} diff --git a/dtls/src/handshake/handshake_message_server_key_exchange.rs b/dtls/src/handshake/handshake_message_server_key_exchange.rs deleted file mode 100644 index e84384275..000000000 --- a/dtls/src/handshake/handshake_message_server_key_exchange.rs +++ /dev/null @@ -1,133 +0,0 @@ -#[cfg(test)] -mod handshake_message_server_key_exchange_test; - -use std::io::{Read, Write}; - -use byteorder::{BigEndian, WriteBytesExt}; - -use super::*; -use crate::curve::named_curve::*; -use crate::curve::*; -use crate::signature_hash_algorithm::*; - -// Structure supports ECDH and PSK -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct HandshakeMessageServerKeyExchange { - pub(crate) identity_hint: Vec, - - pub(crate) elliptic_curve_type: EllipticCurveType, - pub(crate) named_curve: NamedCurve, - pub(crate) public_key: Vec, - pub(crate) algorithm: SignatureHashAlgorithm, - pub(crate) signature: Vec, -} - -impl HandshakeMessageServerKeyExchange { - pub fn handshake_type(&self) -> HandshakeType { - HandshakeType::ServerKeyExchange - } - - pub fn size(&self) -> usize { - if !self.identity_hint.is_empty() { - 2 + self.identity_hint.len() - } else { - 1 + 2 + 1 + self.public_key.len() + 2 + 2 + self.signature.len() - } - } - - pub fn marshal(&self, writer: &mut W) -> Result<()> { - if !self.identity_hint.is_empty() { - writer.write_u16::(self.identity_hint.len() as u16)?; - writer.write_all(&self.identity_hint)?; - return Ok(writer.flush()?); - } - - writer.write_u8(self.elliptic_curve_type as u8)?; - writer.write_u16::(self.named_curve as u16)?; - - writer.write_u8(self.public_key.len() as u8)?; - writer.write_all(&self.public_key)?; - - writer.write_u8(self.algorithm.hash as u8)?; - writer.write_u8(self.algorithm.signature as u8)?; - - writer.write_u16::(self.signature.len() as u16)?; - writer.write_all(&self.signature)?; - - Ok(writer.flush()?) 
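        // Descriptive note: the two encodings `marshal` can emit, mirrored by `unmarshal` below:
        //   PSK:  identity hint length (2 bytes, big-endian) + identity hint
        //   ECDH: curve type (1) + named curve (2) + public key length (1) + public key
        //         + hash algorithm (1) + signature algorithm (1)
        //         + signature length (2) + signature
        // `unmarshal` tells them apart by treating the first two bytes as a PSK identity
        // hint length: if the whole message is exactly that length plus the two length
        // bytes it is parsed as PSK, otherwise as an ECDH key exchange.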
- } - - pub fn unmarshal(reader: &mut R) -> Result { - let mut data = vec![]; - reader.read_to_end(&mut data)?; - - // If parsed as PSK return early and only populate PSK Identity Hint - let psk_length = ((data[0] as u16) << 8) | data[1] as u16; - if data.len() == psk_length as usize + 2 { - return Ok(HandshakeMessageServerKeyExchange { - identity_hint: data[2..].to_vec(), - - elliptic_curve_type: EllipticCurveType::Unsupported, - named_curve: NamedCurve::Unsupported, - public_key: vec![], - algorithm: SignatureHashAlgorithm { - hash: HashAlgorithm::Unsupported, - signature: SignatureAlgorithm::Unsupported, - }, - signature: vec![], - }); - } - - let elliptic_curve_type = data[0].into(); - if data[1..].len() < 2 { - return Err(Error::ErrBufferTooSmall); - } - - let named_curve = (((data[1] as u16) << 8) | data[2] as u16).into(); - if data.len() < 4 { - return Err(Error::ErrBufferTooSmall); - } - - let public_key_length = data[3] as usize; - let mut offset = 4 + public_key_length; - if data.len() < offset { - return Err(Error::ErrBufferTooSmall); - } - let public_key = data[4..offset].to_vec(); - if data.len() <= offset { - return Err(Error::ErrBufferTooSmall); - } - - let hash_algorithm = data[offset].into(); - offset += 1; - if data.len() <= offset { - return Err(Error::ErrBufferTooSmall); - } - - let signature_algorithm = data[offset].into(); - offset += 1; - if data.len() < offset + 2 { - return Err(Error::ErrBufferTooSmall); - } - - let signature_length = (((data[offset] as u16) << 8) | data[offset + 1] as u16) as usize; - offset += 2; - if data.len() < offset + signature_length { - return Err(Error::ErrBufferTooSmall); - } - let signature = data[offset..offset + signature_length].to_vec(); - - Ok(HandshakeMessageServerKeyExchange { - identity_hint: vec![], - - elliptic_curve_type, - named_curve, - public_key, - algorithm: SignatureHashAlgorithm { - hash: hash_algorithm, - signature: signature_algorithm, - }, - signature, - }) - } -} diff --git a/dtls/src/handshake/handshake_message_server_key_exchange/handshake_message_server_key_exchange_test.rs b/dtls/src/handshake/handshake_message_server_key_exchange/handshake_message_server_key_exchange_test.rs deleted file mode 100644 index 4fd7adf53..000000000 --- a/dtls/src/handshake/handshake_message_server_key_exchange/handshake_message_server_key_exchange_test.rs +++ /dev/null @@ -1,50 +0,0 @@ -use std::io::{BufReader, BufWriter}; - -use super::*; - -#[test] -fn test_handshake_message_server_key_exchange() -> Result<()> { - let raw_server_key_exchange = vec![ - 0x03, 0x00, 0x1d, 0x41, 0x04, 0x0c, 0xb9, 0xa3, 0xb9, 0x90, 0x71, 0x35, 0x4a, 0x08, 0x66, - 0xaf, 0xd6, 0x88, 0x58, 0x29, 0x69, 0x98, 0xf1, 0x87, 0x0f, 0xb5, 0xa8, 0xcd, 0x92, 0xf6, - 0x2b, 0x08, 0x0c, 0xd4, 0x16, 0x5b, 0xcc, 0x81, 0xf2, 0x58, 0x91, 0x8e, 0x62, 0xdf, 0xc1, - 0xec, 0x72, 0xe8, 0x47, 0x24, 0x42, 0x96, 0xb8, 0x7b, 0xee, 0xe7, 0x0d, 0xdc, 0x44, 0xec, - 0xf3, 0x97, 0x6b, 0x1b, 0x45, 0x28, 0xac, 0x3f, 0x35, 0x02, 0x03, 0x00, 0x47, 0x30, 0x45, - 0x02, 0x21, 0x00, 0xb2, 0x0b, 0x22, 0x95, 0x3d, 0x56, 0x57, 0x6a, 0x3f, 0x85, 0x30, 0x6f, - 0x55, 0xc3, 0xf4, 0x24, 0x1b, 0x21, 0x07, 0xe5, 0xdf, 0xba, 0x24, 0x02, 0x68, 0x95, 0x1f, - 0x6e, 0x13, 0xbd, 0x9f, 0xaa, 0x02, 0x20, 0x49, 0x9c, 0x9d, 0xdf, 0x84, 0x60, 0x33, 0x27, - 0x96, 0x9e, 0x58, 0x6d, 0x72, 0x13, 0xe7, 0x3a, 0xe8, 0xdf, 0x43, 0x75, 0xc7, 0xb9, 0x37, - 0x6e, 0x90, 0xe5, 0x3b, 0x81, 0xd4, 0xda, 0x68, 0xcd, - ]; - let parsed_server_key_exchange = HandshakeMessageServerKeyExchange { - identity_hint: vec![], - 
elliptic_curve_type: EllipticCurveType::NamedCurve, - named_curve: NamedCurve::X25519, - public_key: raw_server_key_exchange[4..69].to_vec(), - algorithm: SignatureHashAlgorithm { - hash: HashAlgorithm::Sha1, - signature: SignatureAlgorithm::Ecdsa, - }, - - signature: raw_server_key_exchange[73..144].to_vec(), - }; - - let mut reader = BufReader::new(raw_server_key_exchange.as_slice()); - let c = HandshakeMessageServerKeyExchange::unmarshal(&mut reader)?; - assert_eq!( - c, parsed_server_key_exchange, - "handshakeMessageServerKeyExchange unmarshal: got {c:?}, want {parsed_server_key_exchange:?}" - ); - - let mut raw = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(raw.as_mut()); - c.marshal(&mut writer)?; - } - assert_eq!( - raw, raw_server_key_exchange, - "handshakeMessageServerKeyExchange marshal: got {raw:?}, want {raw_server_key_exchange:?}" - ); - - Ok(()) -} diff --git a/dtls/src/handshake/handshake_random.rs b/dtls/src/handshake/handshake_random.rs deleted file mode 100644 index 4ff468a45..000000000 --- a/dtls/src/handshake/handshake_random.rs +++ /dev/null @@ -1,67 +0,0 @@ -use std::io::{self, Read, Write}; -use std::time::{Duration, SystemTime}; - -use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; -use rand::Rng; - -pub const RANDOM_BYTES_LENGTH: usize = 28; -pub const HANDSHAKE_RANDOM_LENGTH: usize = RANDOM_BYTES_LENGTH + 4; - -// https://tools.ietf.org/html/rfc4346#section-7.4.1.2 -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct HandshakeRandom { - pub gmt_unix_time: SystemTime, - pub random_bytes: [u8; RANDOM_BYTES_LENGTH], -} - -impl Default for HandshakeRandom { - fn default() -> Self { - HandshakeRandom { - gmt_unix_time: SystemTime::UNIX_EPOCH, - random_bytes: [0u8; RANDOM_BYTES_LENGTH], - } - } -} - -impl HandshakeRandom { - pub fn size(&self) -> usize { - 4 + RANDOM_BYTES_LENGTH - } - - pub fn marshal(&self, writer: &mut W) -> io::Result<()> { - let secs = match self.gmt_unix_time.duration_since(SystemTime::UNIX_EPOCH) { - Ok(d) => d.as_secs() as u32, - Err(_) => 0, - }; - writer.write_u32::(secs)?; - writer.write_all(&self.random_bytes)?; - - writer.flush() - } - - pub fn unmarshal(reader: &mut R) -> io::Result { - let secs = reader.read_u32::()?; - let gmt_unix_time = if let Some(unix_time) = - SystemTime::UNIX_EPOCH.checked_add(Duration::new(secs as u64, 0)) - { - unix_time - } else { - SystemTime::UNIX_EPOCH - }; - - let mut random_bytes = [0u8; RANDOM_BYTES_LENGTH]; - reader.read_exact(&mut random_bytes)?; - - Ok(HandshakeRandom { - gmt_unix_time, - random_bytes, - }) - } - - // populate fills the HandshakeRandom with random values - // may be called multiple times - pub fn populate(&mut self) { - self.gmt_unix_time = SystemTime::now(); - rand::thread_rng().fill(&mut self.random_bytes); - } -} diff --git a/dtls/src/handshake/handshake_test.rs b/dtls/src/handshake/handshake_test.rs deleted file mode 100644 index a61da45c6..000000000 --- a/dtls/src/handshake/handshake_test.rs +++ /dev/null @@ -1,70 +0,0 @@ -use std::io::{BufReader, BufWriter}; -use std::time::{Duration, SystemTime}; - -use super::*; -use crate::compression_methods::*; -use crate::handshake::handshake_message_client_hello::*; -use crate::handshake::handshake_random::HandshakeRandom; -use crate::record_layer::record_layer_header::ProtocolVersion; - -#[test] -fn test_handshake_message() -> Result<()> { - let raw_handshake_message = vec![ - 0x01, 0x00, 0x00, 0x29, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x29, 0xfe, 0xfd, 0xb6, - 0x2f, 0xce, 0x5c, 0x42, 0x54, 0xff, 0x86, 0xe1, 0x24, 
0x41, 0x91, 0x42, 0x62, 0x15, 0xad, - 0x16, 0xc9, 0x15, 0x8d, 0x95, 0x71, 0x8a, 0xbb, 0x22, 0xd7, 0x47, 0xec, 0xd8, 0x3d, 0xdc, - 0x4b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - ]; - let parsed_handshake = Handshake { - handshake_header: HandshakeHeader { - handshake_type: HandshakeType::ClientHello, - length: 0x29, - message_sequence: 0, - fragment_offset: 0, - fragment_length: 0x29, - }, - handshake_message: HandshakeMessage::ClientHello(HandshakeMessageClientHello { - version: ProtocolVersion { - major: 0xFE, - minor: 0xFD, - }, - random: HandshakeRandom { - gmt_unix_time: if let Some(unix_time) = - SystemTime::UNIX_EPOCH.checked_add(Duration::new(3056586332u64, 0)) - { - unix_time - } else { - SystemTime::UNIX_EPOCH - }, - random_bytes: [ - 0x42, 0x54, 0xff, 0x86, 0xe1, 0x24, 0x41, 0x91, 0x42, 0x62, 0x15, 0xad, 0x16, - 0xc9, 0x15, 0x8d, 0x95, 0x71, 0x8a, 0xbb, 0x22, 0xd7, 0x47, 0xec, 0xd8, 0x3d, - 0xdc, 0x4b, - ], - }, - cookie: vec![], - cipher_suites: vec![], - compression_methods: CompressionMethods { ids: vec![] }, - extensions: vec![], - }), - }; - - let mut reader = BufReader::new(raw_handshake_message.as_slice()); - let h = Handshake::unmarshal(&mut reader)?; - assert_eq!( - h, parsed_handshake, - "handshakeMessageClientHello unmarshal: got {h:?}, want {parsed_handshake:?}" - ); - - let mut raw = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(raw.as_mut()); - h.marshal(&mut writer)?; - } - assert_eq!( - raw, raw_handshake_message, - "handshakeMessageClientHello marshal: got {raw:?}, want {raw_handshake_message:?}" - ); - - Ok(()) -} diff --git a/dtls/src/handshake/mod.rs b/dtls/src/handshake/mod.rs deleted file mode 100644 index a71f44f34..000000000 --- a/dtls/src/handshake/mod.rs +++ /dev/null @@ -1,238 +0,0 @@ -pub mod handshake_cache; -pub mod handshake_header; -pub mod handshake_message_certificate; -pub mod handshake_message_certificate_request; -pub mod handshake_message_certificate_verify; -pub mod handshake_message_client_hello; -pub mod handshake_message_client_key_exchange; -pub mod handshake_message_finished; -pub mod handshake_message_hello_verify_request; -pub mod handshake_message_server_hello; -pub mod handshake_message_server_hello_done; -pub mod handshake_message_server_key_exchange; -pub mod handshake_random; - -#[cfg(test)] -mod handshake_test; - -use std::fmt; -use std::io::{Read, Write}; - -use handshake_header::*; -use handshake_message_certificate::*; -use handshake_message_certificate_request::*; -use handshake_message_certificate_verify::*; -use handshake_message_client_hello::*; -use handshake_message_client_key_exchange::*; -use handshake_message_finished::*; -use handshake_message_hello_verify_request::*; -use handshake_message_server_hello::*; -use handshake_message_server_hello_done::*; -use handshake_message_server_key_exchange::*; - -use super::content::*; -use super::error::*; - -// https://tools.ietf.org/html/rfc5246#section-7.4 -#[derive(Default, Copy, Clone, Debug, PartialEq, Eq, Hash)] -pub enum HandshakeType { - HelloRequest = 0, - ClientHello = 1, - ServerHello = 2, - HelloVerifyRequest = 3, - Certificate = 11, - ServerKeyExchange = 12, - CertificateRequest = 13, - ServerHelloDone = 14, - CertificateVerify = 15, - ClientKeyExchange = 16, - Finished = 20, - #[default] - Invalid, -} - -impl fmt::Display for HandshakeType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self { - HandshakeType::HelloRequest => write!(f, "HelloRequest"), - HandshakeType::ClientHello => write!(f, "ClientHello"), - 
HandshakeType::ServerHello => write!(f, "ServerHello"), - HandshakeType::HelloVerifyRequest => write!(f, "HelloVerifyRequest"), - HandshakeType::Certificate => write!(f, "Certificate"), - HandshakeType::ServerKeyExchange => write!(f, "ServerKeyExchange"), - HandshakeType::CertificateRequest => write!(f, "CertificateRequest"), - HandshakeType::ServerHelloDone => write!(f, "ServerHelloDone"), - HandshakeType::CertificateVerify => write!(f, "CertificateVerify"), - HandshakeType::ClientKeyExchange => write!(f, "ClientKeyExchange"), - HandshakeType::Finished => write!(f, "Finished"), - HandshakeType::Invalid => write!(f, "Invalid"), - } - } -} - -impl From for HandshakeType { - fn from(val: u8) -> Self { - match val { - 0 => HandshakeType::HelloRequest, - 1 => HandshakeType::ClientHello, - 2 => HandshakeType::ServerHello, - 3 => HandshakeType::HelloVerifyRequest, - 11 => HandshakeType::Certificate, - 12 => HandshakeType::ServerKeyExchange, - 13 => HandshakeType::CertificateRequest, - 14 => HandshakeType::ServerHelloDone, - 15 => HandshakeType::CertificateVerify, - 16 => HandshakeType::ClientKeyExchange, - 20 => HandshakeType::Finished, - _ => HandshakeType::Invalid, - } - } -} - -#[derive(PartialEq, Debug, Clone)] -pub enum HandshakeMessage { - //HelloRequest(errNotImplemented), - ClientHello(HandshakeMessageClientHello), - ServerHello(HandshakeMessageServerHello), - HelloVerifyRequest(HandshakeMessageHelloVerifyRequest), - Certificate(HandshakeMessageCertificate), - ServerKeyExchange(HandshakeMessageServerKeyExchange), - CertificateRequest(HandshakeMessageCertificateRequest), - ServerHelloDone(HandshakeMessageServerHelloDone), - CertificateVerify(HandshakeMessageCertificateVerify), - ClientKeyExchange(HandshakeMessageClientKeyExchange), - Finished(HandshakeMessageFinished), -} - -impl HandshakeMessage { - pub fn handshake_type(&self) -> HandshakeType { - match self { - HandshakeMessage::ClientHello(msg) => msg.handshake_type(), - HandshakeMessage::ServerHello(msg) => msg.handshake_type(), - HandshakeMessage::HelloVerifyRequest(msg) => msg.handshake_type(), - HandshakeMessage::Certificate(msg) => msg.handshake_type(), - HandshakeMessage::ServerKeyExchange(msg) => msg.handshake_type(), - HandshakeMessage::CertificateRequest(msg) => msg.handshake_type(), - HandshakeMessage::ServerHelloDone(msg) => msg.handshake_type(), - HandshakeMessage::CertificateVerify(msg) => msg.handshake_type(), - HandshakeMessage::ClientKeyExchange(msg) => msg.handshake_type(), - HandshakeMessage::Finished(msg) => msg.handshake_type(), - } - } - - pub fn size(&self) -> usize { - match self { - HandshakeMessage::ClientHello(msg) => msg.size(), - HandshakeMessage::ServerHello(msg) => msg.size(), - HandshakeMessage::HelloVerifyRequest(msg) => msg.size(), - HandshakeMessage::Certificate(msg) => msg.size(), - HandshakeMessage::ServerKeyExchange(msg) => msg.size(), - HandshakeMessage::CertificateRequest(msg) => msg.size(), - HandshakeMessage::ServerHelloDone(msg) => msg.size(), - HandshakeMessage::CertificateVerify(msg) => msg.size(), - HandshakeMessage::ClientKeyExchange(msg) => msg.size(), - HandshakeMessage::Finished(msg) => msg.size(), - } - } - - pub fn marshal(&self, writer: &mut W) -> Result<()> { - match self { - HandshakeMessage::ClientHello(msg) => msg.marshal(writer)?, - HandshakeMessage::ServerHello(msg) => msg.marshal(writer)?, - HandshakeMessage::HelloVerifyRequest(msg) => msg.marshal(writer)?, - HandshakeMessage::Certificate(msg) => msg.marshal(writer)?, - HandshakeMessage::ServerKeyExchange(msg) => 
msg.marshal(writer)?, - HandshakeMessage::CertificateRequest(msg) => msg.marshal(writer)?, - HandshakeMessage::ServerHelloDone(msg) => msg.marshal(writer)?, - HandshakeMessage::CertificateVerify(msg) => msg.marshal(writer)?, - HandshakeMessage::ClientKeyExchange(msg) => msg.marshal(writer)?, - HandshakeMessage::Finished(msg) => msg.marshal(writer)?, - } - - Ok(()) - } -} - -// The handshake protocol is responsible for selecting a cipher spec and -// generating a master secret, which together comprise the primary -// cryptographic parameters associated with a secure session. The -// handshake protocol can also optionally authenticate parties who have -// certificates signed by a trusted certificate authority. -// https://tools.ietf.org/html/rfc5246#section-7.3 -#[derive(PartialEq, Debug, Clone)] -pub struct Handshake { - pub(crate) handshake_header: HandshakeHeader, - pub(crate) handshake_message: HandshakeMessage, -} - -impl Handshake { - pub fn new(handshake_message: HandshakeMessage) -> Self { - Handshake { - handshake_header: HandshakeHeader { - handshake_type: handshake_message.handshake_type(), - length: handshake_message.size() as u32, - message_sequence: 0, - fragment_offset: 0, - fragment_length: handshake_message.size() as u32, - }, - handshake_message, - } - } - - pub fn content_type(&self) -> ContentType { - ContentType::Handshake - } - - pub fn size(&self) -> usize { - self.handshake_header.size() + self.handshake_message.size() - } - - pub fn marshal(&self, writer: &mut W) -> Result<()> { - self.handshake_header.marshal(writer)?; - self.handshake_message.marshal(writer)?; - Ok(()) - } - - pub fn unmarshal(reader: &mut R) -> Result { - let handshake_header = HandshakeHeader::unmarshal(reader)?; - - let handshake_message = match handshake_header.handshake_type { - HandshakeType::ClientHello => { - HandshakeMessage::ClientHello(HandshakeMessageClientHello::unmarshal(reader)?) - } - HandshakeType::ServerHello => { - HandshakeMessage::ServerHello(HandshakeMessageServerHello::unmarshal(reader)?) - } - HandshakeType::HelloVerifyRequest => HandshakeMessage::HelloVerifyRequest( - HandshakeMessageHelloVerifyRequest::unmarshal(reader)?, - ), - HandshakeType::Certificate => { - HandshakeMessage::Certificate(HandshakeMessageCertificate::unmarshal(reader)?) - } - HandshakeType::ServerKeyExchange => HandshakeMessage::ServerKeyExchange( - HandshakeMessageServerKeyExchange::unmarshal(reader)?, - ), - HandshakeType::CertificateRequest => HandshakeMessage::CertificateRequest( - HandshakeMessageCertificateRequest::unmarshal(reader)?, - ), - HandshakeType::ServerHelloDone => HandshakeMessage::ServerHelloDone( - HandshakeMessageServerHelloDone::unmarshal(reader)?, - ), - HandshakeType::CertificateVerify => HandshakeMessage::CertificateVerify( - HandshakeMessageCertificateVerify::unmarshal(reader)?, - ), - HandshakeType::ClientKeyExchange => HandshakeMessage::ClientKeyExchange( - HandshakeMessageClientKeyExchange::unmarshal(reader)?, - ), - HandshakeType::Finished => { - HandshakeMessage::Finished(HandshakeMessageFinished::unmarshal(reader)?) 
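// Illustrative sketch (assumes crate-internal access to the types above):
// Handshake::new() always builds an *unfragmented* header, i.e. the header
// length and fragment length both equal the message size and the fragment
// offset is zero, with the message sequence left at 0 until the flight is sent.
fn assert_unfragmented(msg: HandshakeMessage) {
    let h = Handshake::new(msg);
    assert_eq!(h.handshake_header.fragment_offset, 0);
    assert_eq!(h.handshake_header.length, h.handshake_header.fragment_length);
    assert_eq!(h.handshake_header.length as usize, h.handshake_message.size());
}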
- } - _ => return Err(Error::ErrNotImplemented), - }; - - Ok(Handshake { - handshake_header, - handshake_message, - }) - } -} diff --git a/dtls/src/handshaker.rs b/dtls/src/handshaker.rs deleted file mode 100644 index b6e1a9e2b..000000000 --- a/dtls/src/handshaker.rs +++ /dev/null @@ -1,429 +0,0 @@ -use std::collections::HashMap; -use std::fmt; -use std::sync::Arc; - -use log::*; - -use crate::cipher_suite::*; -use crate::config::*; -use crate::conn::*; -use crate::content::*; -use crate::crypto::*; -use crate::error::*; -use crate::extension::extension_use_srtp::*; -use crate::signature_hash_algorithm::*; - -use rustls::client::danger::ServerCertVerifier; -use rustls::pki_types::CertificateDer; -use rustls::server::danger::ClientCertVerifier; - -//use std::io::BufWriter; - -// [RFC6347 Section-4.2.4] -// +-----------+ -// +---> | PREPARING | <--------------------+ -// | +-----------+ | -// | | | -// | | Buffer next flight | -// | | | -// | \|/ | -// | +-----------+ | -// | | SENDING |<------------------+ | Send -// | +-----------+ | | HelloRequest -// Receive | | | | -// next | | Send flight | | or -// flight | +--------+ | | -// | | | Set retransmit timer | | Receive -// | | \|/ | | HelloRequest -// | | +-----------+ | | Send -// +--)--| WAITING |-------------------+ | ClientHello -// | | +-----------+ Timer expires | | -// | | | | | -// | | +------------------------+ | -// Receive | | Send Read retransmit | -// last | | last | -// flight | | flight | -// | | | -// \|/\|/ | -// +-----------+ | -// | FINISHED | -------------------------------+ -// +-----------+ -// | /|\ -// | | -// +---+ -// Read retransmit -// Retransmit last flight - -#[derive(Copy, Clone, PartialEq)] -pub(crate) enum HandshakeState { - Errored, - Preparing, - Sending, - Waiting, - Finished, -} - -impl fmt::Display for HandshakeState { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self { - HandshakeState::Errored => write!(f, "Errored"), - HandshakeState::Preparing => write!(f, "Preparing"), - HandshakeState::Sending => write!(f, "Sending"), - HandshakeState::Waiting => write!(f, "Waiting"), - HandshakeState::Finished => write!(f, "Finished"), - } - } -} - -pub(crate) type VerifyPeerCertificateFn = - Arc], &[CertificateDer<'static>]) -> Result<()>) + Send + Sync>; - -pub(crate) struct HandshakeConfig { - pub(crate) local_psk_callback: Option, - pub(crate) local_psk_identity_hint: Option>, - pub(crate) local_cipher_suites: Vec, // Available CipherSuites - pub(crate) local_signature_schemes: Vec, // Available signature schemes - pub(crate) extended_master_secret: ExtendedMasterSecretType, // Policy for the Extended Master Support extension - pub(crate) local_srtp_protection_profiles: Vec, // Available SRTPProtectionProfiles, if empty no SRTP support - pub(crate) server_name: String, - pub(crate) client_auth: ClientAuthType, // If we are a client should we request a client certificate - pub(crate) local_certificates: Vec, - pub(crate) name_to_certificate: HashMap, - pub(crate) insecure_skip_verify: bool, - pub(crate) insecure_verification: bool, - pub(crate) verify_peer_certificate: Option, - pub(crate) server_cert_verifier: Arc, - pub(crate) client_cert_verifier: Option>, - pub(crate) retransmit_interval: tokio::time::Duration, - pub(crate) initial_epoch: u16, - //log logging.LeveledLogger - //mu sync.Mutex -} - -pub fn gen_self_signed_root_cert() -> rustls::RootCertStore { - let mut certs = rustls::RootCertStore::empty(); - certs - .add( - rcgen::generate_simple_self_signed(vec![]) - .unwrap() - 
.cert - .der() - .to_owned(), - ) - .unwrap(); - certs -} - -impl Default for HandshakeConfig { - fn default() -> Self { - HandshakeConfig { - local_psk_callback: None, - local_psk_identity_hint: None, - local_cipher_suites: vec![], - local_signature_schemes: vec![], - extended_master_secret: ExtendedMasterSecretType::Disable, - local_srtp_protection_profiles: vec![], - server_name: String::new(), - client_auth: ClientAuthType::NoClientCert, - local_certificates: vec![], - name_to_certificate: HashMap::new(), - insecure_skip_verify: false, - insecure_verification: false, - verify_peer_certificate: None, - server_cert_verifier: rustls::client::WebPkiServerVerifier::builder(Arc::new( - gen_self_signed_root_cert(), - )) - .build() - .unwrap(), - client_cert_verifier: None, - retransmit_interval: tokio::time::Duration::from_secs(0), - initial_epoch: 0, - } - } -} - -impl HandshakeConfig { - pub(crate) fn get_certificate(&self, server_name: &str) -> Result { - //TODO - /*if self.name_to_certificate.is_empty() { - let mut name_to_certificate = HashMap::new(); - for cert in &self.local_certificates { - if let Ok((_rem, x509_cert)) = x509_parser::parse_x509_der(&cert.certificate) { - if let Some(a) = x509_cert.tbs_certificate.subject.iter_common_name().next() { - let common_name = match a.attr_value.as_str() { - Ok(cn) => cn.to_lowercase(), - Err(err) => return Err(Error::new(err.to_string())), - }; - name_to_certificate.insert(common_name, cert.clone()); - } - if let Some((_, sans)) = x509_cert.tbs_certificate.subject_alternative_name() { - for gn in &sans.general_names { - match gn { - x509_parser::extensions::GeneralName::DNSName(san) => { - let san = san.to_lowercase(); - name_to_certificate.insert(san, cert.clone()); - } - _ => {} - } - } - } - } else { - continue; - } - } - self.name_to_certificate = name_to_certificate; - }*/ - - if self.local_certificates.is_empty() { - return Err(Error::ErrNoCertificates); - } - - if self.local_certificates.len() == 1 { - // There's only one choice, so no point doing any work. - return Ok(self.local_certificates[0].clone()); - } - - if server_name.is_empty() { - return Ok(self.local_certificates[0].clone()); - } - - let lower = server_name.to_lowercase(); - let name = lower.trim_end_matches('.'); - - if let Some(cert) = self.name_to_certificate.get(name) { - return Ok(cert.clone()); - } - - // try replacing labels in the name with wildcards until we get a - // match. - let mut labels: Vec<&str> = name.split_terminator('.').collect(); - for i in 0..labels.len() { - labels[i] = "*"; - let candidate = labels.join("."); - if let Some(cert) = self.name_to_certificate.get(&candidate) { - return Ok(cert.clone()); - } - } - - // If nothing matches, return the first certificate. 
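// Illustrative trace of the wildcard fallback in get_certificate() above:
// labels are overwritten with "*" from left to right (without being restored),
// and each intermediate join is looked up in name_to_certificate. For
// "a.b.example.com" the candidates are therefore
// "*.b.example.com", "*.*.example.com", "*.*.*.com", "*.*.*.*".
fn wildcard_candidates(name: &str) -> Vec<String> {
    let mut labels: Vec<&str> = name.split_terminator('.').collect();
    let mut out = Vec::with_capacity(labels.len());
    for i in 0..labels.len() {
        labels[i] = "*";
        out.push(labels.join("."));
    }
    out
}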
- Ok(self.local_certificates[0].clone()) - } -} - -pub(crate) fn srv_cli_str(is_client: bool) -> String { - if is_client { - return "client".to_owned(); - } - "server".to_owned() -} - -impl DTLSConn { - pub(crate) async fn handshake(&mut self, mut state: HandshakeState) -> Result<()> { - loop { - trace!( - "[handshake:{}] {}: {}", - srv_cli_str(self.state.is_client), - self.current_flight.to_string(), - state.to_string() - ); - - if state == HandshakeState::Finished && !self.is_handshake_completed_successfully() { - self.set_handshake_completed_successfully(); - self.handshake_done_tx.take(); // drop it by take - return Ok(()); - } - - state = match state { - HandshakeState::Preparing => self.prepare().await?, - HandshakeState::Sending => self.send().await?, - HandshakeState::Waiting => self.wait().await?, - HandshakeState::Finished => self.finish().await?, - _ => return Err(Error::ErrInvalidFsmTransition), - }; - } - } - - async fn prepare(&mut self) -> Result { - self.flights = None; - - // Prepare flights - self.retransmit = self.current_flight.has_retransmit(); - - let result = self - .current_flight - .generate(&mut self.state, &self.cache, &self.cfg) - .await; - - match result { - Err((a, mut err)) => { - if let Some(a) = a { - let alert_err = self.notify(a.alert_level, a.alert_description).await; - - if let Err(alert_err) = alert_err { - if err.is_some() { - err = Some(alert_err); - } - } - } - if let Some(err) = err { - return Err(err); - } - } - Ok(pkts) => { - /*if !pkts.is_empty() { - let mut s = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(s.as_mut()); - pkts[0].record.content.marshal(&mut writer)?; - } - trace!( - "[handshake:{}] {}: {:?}", - srv_cli_str(self.state.is_client), - self.current_flight.to_string(), - s, - ); - }*/ - self.flights = Some(pkts) - } - }; - - let epoch = self.cfg.initial_epoch; - let mut next_epoch = epoch; - if let Some(pkts) = &mut self.flights { - for p in pkts { - p.record.record_layer_header.epoch += epoch; - if p.record.record_layer_header.epoch > next_epoch { - next_epoch = p.record.record_layer_header.epoch; - } - if let Content::Handshake(h) = &mut p.record.content { - h.handshake_header.message_sequence = self.state.handshake_send_sequence as u16; - self.state.handshake_send_sequence += 1; - } - } - } - if epoch != next_epoch { - trace!( - "[handshake:{}] -> changeCipherSpec (epoch: {})", - srv_cli_str(self.state.is_client), - next_epoch - ); - self.set_local_epoch(next_epoch); - } - - Ok(HandshakeState::Sending) - } - async fn send(&mut self) -> Result { - // Send flights - if let Some(pkts) = self.flights.clone() { - self.write_packets(pkts).await?; - } - - if self.current_flight.is_last_send_flight() { - Ok(HandshakeState::Finished) - } else { - Ok(HandshakeState::Waiting) - } - } - async fn wait(&mut self) -> Result { - let retransmit_timer = tokio::time::sleep(self.cfg.retransmit_interval); - tokio::pin!(retransmit_timer); - - loop { - tokio::select! 
{ - done = self.handshake_rx.recv() =>{ - if done.is_none() { - trace!("[handshake:{}] {} handshake_tx is dropped", srv_cli_str(self.state.is_client), self.current_flight.to_string()); - return Err(Error::ErrAlertFatalOrClose); - } - - //trace!("[handshake:{}] {} received handshake_rx", srv_cli_str(self.state.is_client), self.current_flight.to_string()); - let result = self.current_flight.parse(&mut self.handle_queue_tx, &mut self.state, &self.cache, &self.cfg).await; - drop(done); - match result { - Err((alert, mut err)) => { - trace!("[handshake:{}] {} result alert:{:?}, err:{:?}", - srv_cli_str(self.state.is_client), - self.current_flight.to_string(), - alert, - err); - - if let Some(alert) = alert { - let alert_err = self.notify(alert.alert_level, alert.alert_description).await; - - if let Err(alert_err) = alert_err { - if err.is_some() { - err = Some(alert_err); - } - } - } - if let Some(err) = err { - return Err(err); - } - } - Ok(next_flight) => { - trace!("[handshake:{}] {} -> {}", srv_cli_str(self.state.is_client), self.current_flight.to_string(), next_flight.to_string()); - if next_flight.is_last_recv_flight() && self.current_flight.to_string() == next_flight.to_string() { - return Ok(HandshakeState::Finished); - } - self.current_flight = next_flight; - return Ok(HandshakeState::Preparing); - } - }; - } - - _ = retransmit_timer.as_mut() =>{ - trace!("[handshake:{}] {} retransmit_timer", srv_cli_str(self.state.is_client), self.current_flight.to_string()); - - if !self.retransmit { - return Ok(HandshakeState::Waiting); - } - return Ok(HandshakeState::Sending); - } - - /*_ = self.done_rx.recv() => { - return Err(Error::new("done_rx recv".to_owned())); - }*/ - } - } - } - async fn finish(&mut self) -> Result { - let retransmit_timer = tokio::time::sleep(self.cfg.retransmit_interval); - - tokio::select! 
{ - done = self.handshake_rx.recv() =>{ - if done.is_none() { - trace!("[handshake:{}] {} handshake_tx is dropped", srv_cli_str(self.state.is_client), self.current_flight.to_string()); - return Err(Error::ErrAlertFatalOrClose); - } - let result = self.current_flight.parse(&mut self.handle_queue_tx, &mut self.state, &self.cache, &self.cfg).await; - drop(done); - match result { - Err((alert, mut err)) => { - if let Some(alert) = alert { - let alert_err = self.notify(alert.alert_level, alert.alert_description).await; - if let Err(alert_err) = alert_err { - if err.is_some() { - err = Some(alert_err); - } - } - } - if let Some(err) = err { - return Err(err); - } - } - Ok(_) => { - retransmit_timer.await; - // Retransmit last flight - return Ok(HandshakeState::Sending); - } - }; - } - - /*_ = self.done_rx.recv() => { - return Err(Error::new("done_rx recv".to_owned())); - }*/ - } - - Ok(HandshakeState::Finished) - } -} diff --git a/dtls/src/lib.rs b/dtls/src/lib.rs deleted file mode 100644 index fd75b6bc8..000000000 --- a/dtls/src/lib.rs +++ /dev/null @@ -1,57 +0,0 @@ -#![warn(rust_2018_idioms)] -#![allow(dead_code)] - -pub mod alert; -pub mod application_data; -pub mod change_cipher_spec; -pub mod cipher_suite; -pub mod client_certificate_type; -pub mod compression_methods; -pub mod config; -pub mod conn; -pub mod content; -pub mod crypto; -pub mod curve; -mod error; -pub mod extension; -pub mod flight; -pub mod fragment_buffer; -pub mod handshake; -pub mod handshaker; -pub mod listener; -pub mod prf; -pub mod record_layer; -pub mod signature_hash_algorithm; -pub mod state; - -use cipher_suite::*; -pub use error::Error; -use extension::extension_use_srtp::SrtpProtectionProfile; - -pub(crate) fn find_matching_srtp_profile( - a: &[SrtpProtectionProfile], - b: &[SrtpProtectionProfile], -) -> Result { - for a_profile in a { - for b_profile in b { - if a_profile == b_profile { - return Ok(*a_profile); - } - } - } - Err(()) -} - -pub(crate) fn find_matching_cipher_suite( - a: &[CipherSuiteId], - b: &[CipherSuiteId], -) -> Result { - for a_suite in a { - for b_suite in b { - if a_suite == b_suite { - return Ok(*a_suite); - } - } - } - Err(()) -} diff --git a/dtls/src/listener.rs b/dtls/src/listener.rs deleted file mode 100644 index 0dc6257ce..000000000 --- a/dtls/src/listener.rs +++ /dev/null @@ -1,95 +0,0 @@ -use std::future::Future; -use std::io::BufReader; -use std::net::SocketAddr; -use std::pin::Pin; -use std::sync::Arc; - -use async_trait::async_trait; -use tokio::net::ToSocketAddrs; -use util::conn::conn_udp_listener::*; -use util::conn::*; - -use crate::config::*; -use crate::conn::DTLSConn; -use crate::content::ContentType; -use crate::error::Result; -use crate::record_layer::record_layer_header::RecordLayerHeader; -use crate::record_layer::unpack_datagram; - -/// Listen creates a DTLS listener -pub async fn listen(laddr: A, config: Config) -> Result { - validate_config(false, &config)?; - - let mut lc = ListenConfig { - accept_filter: Some(Box::new( - |packet: &[u8]| -> Pin + Send + 'static>> { - let pkts = match unpack_datagram(packet) { - Ok(pkts) => { - if pkts.is_empty() { - return Box::pin(async { false }); - } - pkts - } - Err(_) => return Box::pin(async { false }), - }; - - let mut reader = BufReader::new(pkts[0].as_slice()); - match RecordLayerHeader::unmarshal(&mut reader) { - Ok(h) => { - let content_type = h.content_type; - Box::pin(async move { content_type == ContentType::Handshake }) - } - Err(_) => Box::pin(async { false }), - } - }, - )), - ..Default::default() - }; - - let 
parent = Arc::new(lc.listen(laddr).await?); - Ok(DTLSListener { parent, config }) -} - -/// DTLSListener represents a DTLS listener -pub struct DTLSListener { - parent: Arc, - config: Config, -} - -impl DTLSListener { - /// creates a DTLS listener which accepts connections from an inner Listener. - pub fn new(parent: Arc, config: Config) -> Result { - validate_config(false, &config)?; - - Ok(DTLSListener { parent, config }) - } -} - -type UtilResult = std::result::Result; - -#[async_trait] -impl Listener for DTLSListener { - /// Accept waits for and returns the next connection to the listener. - /// You have to either close or read on all connection that are created. - /// Connection handshake will timeout using ConnectContextMaker in the Config. - /// If you want to specify the timeout duration, set ConnectContextMaker. - async fn accept(&self) -> UtilResult<(Arc, SocketAddr)> { - let (conn, raddr) = self.parent.accept().await?; - let dtls_conn = DTLSConn::new(conn, self.config.clone(), false, None) - .await - .map_err(util::Error::from_std)?; - Ok((Arc::new(dtls_conn), raddr)) - } - - /// Close closes the listener. - /// Any blocked Accept operations will be unblocked and return errors. - /// Already Accepted connections are not closed. - async fn close(&self) -> UtilResult<()> { - self.parent.close().await - } - - /// Addr returns the listener's network address. - async fn addr(&self) -> UtilResult { - self.parent.addr().await - } -} diff --git a/dtls/src/prf/mod.rs b/dtls/src/prf/mod.rs deleted file mode 100644 index e38cd454e..000000000 --- a/dtls/src/prf/mod.rs +++ /dev/null @@ -1,315 +0,0 @@ -#[cfg(test)] -mod prf_test; - -use std::convert::TryInto; -use std::fmt; - -use hmac::{Hmac, Mac}; -use sha1::Sha1; -use sha2::{Digest, Sha256}; - -type HmacSha256 = Hmac; -type HmacSha1 = Hmac; - -use crate::cipher_suite::CipherSuiteHash; -use crate::content::ContentType; -use crate::curve::named_curve::*; -use crate::error::*; -use crate::record_layer::record_layer_header::ProtocolVersion; - -pub(crate) const PRF_MASTER_SECRET_LABEL: &str = "master secret"; -pub(crate) const PRF_EXTENDED_MASTER_SECRET_LABEL: &str = "extended master secret"; -pub(crate) const PRF_KEY_EXPANSION_LABEL: &str = "key expansion"; -pub(crate) const PRF_VERIFY_DATA_CLIENT_LABEL: &str = "client finished"; -pub(crate) const PRF_VERIFY_DATA_SERVER_LABEL: &str = "server finished"; - -#[derive(PartialEq, Debug, Clone)] -pub(crate) struct EncryptionKeys { - pub(crate) master_secret: Vec, - pub(crate) client_mac_key: Vec, - pub(crate) server_mac_key: Vec, - pub(crate) client_write_key: Vec, - pub(crate) server_write_key: Vec, - pub(crate) client_write_iv: Vec, - pub(crate) server_write_iv: Vec, -} - -impl fmt::Display for EncryptionKeys { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut out = "EncryptionKeys:\n".to_string(); - - out += format!("- master_secret: {:?}\n", self.master_secret).as_str(); - out += format!("- client_mackey: {:?}\n", self.client_mac_key).as_str(); - out += format!("- server_mackey: {:?}\n", self.server_mac_key).as_str(); - out += format!("- client_write_key: {:?}\n", self.client_write_key).as_str(); - out += format!("- server_write_key: {:?}\n", self.server_write_key).as_str(); - out += format!("- client_write_iv: {:?}\n", self.client_write_iv).as_str(); - out += format!("- server_write_iv: {:?}\n", self.server_write_iv).as_str(); - - write!(f, "{out}") - } -} - -// The premaster secret is formed as follows: if the PSK is N octets -// long, concatenate a uint16 with the value 
N, N zero octets, a second -// uint16 with the value N, and the PSK itself. -// -// https://tools.ietf.org/html/rfc4279#section-2 -pub(crate) fn prf_psk_pre_master_secret(psk: &[u8]) -> Vec { - let psk_len = psk.len(); - - let mut out = vec![0u8; 2 + psk_len + 2]; - - out.extend_from_slice(psk); - let be = (psk_len as u16).to_be_bytes(); - out[..2].copy_from_slice(&be); - out[2 + psk_len..2 + psk_len + 2].copy_from_slice(&be); - - out -} - -pub(crate) fn prf_pre_master_secret( - public_key: &[u8], - private_key: &NamedCurvePrivateKey, - curve: NamedCurve, -) -> Result> { - match curve { - NamedCurve::P256 => elliptic_curve_pre_master_secret(public_key, private_key, curve), - NamedCurve::P384 => elliptic_curve_pre_master_secret(public_key, private_key, curve), - NamedCurve::X25519 => elliptic_curve_pre_master_secret(public_key, private_key, curve), - _ => Err(Error::ErrInvalidNamedCurve), - } -} - -fn elliptic_curve_pre_master_secret( - public_key: &[u8], - private_key: &NamedCurvePrivateKey, - curve: NamedCurve, -) -> Result> { - match curve { - NamedCurve::P256 => { - let pub_key = p256::EncodedPoint::from_bytes(public_key)?; - let public = p256::PublicKey::from_sec1_bytes(pub_key.as_ref())?; - if let NamedCurvePrivateKey::EphemeralSecretP256(secret) = private_key { - return Ok(secret.diffie_hellman(&public).raw_secret_bytes().to_vec()); - } - } - NamedCurve::P384 => { - let pub_key = p384::EncodedPoint::from_bytes(public_key)?; - let public = p384::PublicKey::from_sec1_bytes(pub_key.as_ref())?; - if let NamedCurvePrivateKey::EphemeralSecretP384(secret) = private_key { - return Ok(secret.diffie_hellman(&public).raw_secret_bytes().to_vec()); - } - } - NamedCurve::X25519 => { - if public_key.len() != 32 { - return Err(Error::Other("Public key is not 32 len".into())); - } - let pub_key: [u8; 32] = public_key.try_into().unwrap(); - let public = x25519_dalek::PublicKey::from(pub_key); - if let NamedCurvePrivateKey::StaticSecretX25519(secret) = private_key { - return Ok(secret.diffie_hellman(&public).as_bytes().to_vec()); - } - } - _ => return Err(Error::ErrInvalidNamedCurve), - } - Err(Error::ErrNamedCurveAndPrivateKeyMismatch) -} - -// This PRF with the SHA-256 hash function is used for all cipher suites -// defined in this document and in TLS documents published prior to this -// document when TLS 1.2 is negotiated. New cipher suites MUST explicitly -// specify a PRF and, in general, SHOULD use the TLS PRF with SHA-256 or a -// stronger standard hash function. -// -// P_hash(secret, seed) = HMAC_hash(secret, A(1) + seed) + -// HMAC_hash(secret, A(2) + seed) + -// HMAC_hash(secret, A(3) + seed) + ... -// -// A() is defined as: -// -// A(0) = seed -// A(i) = HMAC_hash(secret, A(i-1)) -// -// P_hash can be iterated as many times as necessary to produce the -// required quantity of data. For example, if P_SHA256 is being used to -// create 80 bytes of data, it will have to be iterated three times -// (through A(3)), creating 96 bytes of output data; the last 16 bytes -// of the final iteration will then be discarded, leaving 80 bytes of -// output data. -// -// https://tools.ietf.org/html/rfc4346w -fn hmac_sha(h: CipherSuiteHash, key: &[u8], data: &[u8]) -> Result> { - let mut mac = match h { - CipherSuiteHash::Sha256 => { - HmacSha256::new_from_slice(key).map_err(|e| Error::Other(e.to_string()))? 
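// Standalone sketch of the RFC 4279 layout that prf_psk_pre_master_secret()
// builds above: uint16 N, N zero octets, uint16 N, then the PSK itself.
// A 3-byte PSK [0xaa, 0xbb, 0xcc] therefore becomes
// 00 03 | 00 00 00 | 00 03 | aa bb cc.
fn psk_pre_master_secret_layout(psk: &[u8]) -> Vec<u8> {
    let n = (psk.len() as u16).to_be_bytes();
    let mut out = Vec::with_capacity(4 + 2 * psk.len());
    out.extend_from_slice(&n);            // uint16 N
    out.resize(out.len() + psk.len(), 0); // N zero octets ("other_secret")
    out.extend_from_slice(&n);            // uint16 N again
    out.extend_from_slice(psk);           // the PSK itself
    out
}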
- } - }; - mac.update(data); - let result = mac.finalize(); - let code_bytes = result.into_bytes(); - Ok(code_bytes.to_vec()) -} - -pub(crate) fn prf_p_hash( - secret: &[u8], - seed: &[u8], - requested_length: usize, - h: CipherSuiteHash, -) -> Result> { - let mut last_round = seed.to_vec(); - let mut out = vec![]; - - let iterations = ((requested_length as f64) / (h.size() as f64)).ceil() as usize; - for _ in 0..iterations { - last_round = hmac_sha(h, secret, &last_round)?; - - let mut last_round_seed = last_round.clone(); - last_round_seed.extend_from_slice(seed); - let with_secret = hmac_sha(h, secret, &last_round_seed)?; - - out.extend_from_slice(&with_secret); - } - - Ok(out[..requested_length].to_vec()) -} - -pub(crate) fn prf_extended_master_secret( - pre_master_secret: &[u8], - session_hash: &[u8], - h: CipherSuiteHash, -) -> Result> { - let mut seed = PRF_EXTENDED_MASTER_SECRET_LABEL.as_bytes().to_vec(); - seed.extend_from_slice(session_hash); - prf_p_hash(pre_master_secret, &seed, 48, h) -} - -pub(crate) fn prf_master_secret( - pre_master_secret: &[u8], - client_random: &[u8], - server_random: &[u8], - h: CipherSuiteHash, -) -> Result> { - let mut seed = PRF_MASTER_SECRET_LABEL.as_bytes().to_vec(); - seed.extend_from_slice(client_random); - seed.extend_from_slice(server_random); - prf_p_hash(pre_master_secret, &seed, 48, h) -} - -pub(crate) fn prf_encryption_keys( - master_secret: &[u8], - client_random: &[u8], - server_random: &[u8], - prf_mac_len: usize, - prf_key_len: usize, - prf_iv_len: usize, - h: CipherSuiteHash, -) -> Result { - let mut seed = PRF_KEY_EXPANSION_LABEL.as_bytes().to_vec(); - seed.extend_from_slice(server_random); - seed.extend_from_slice(client_random); - - let material = prf_p_hash( - master_secret, - &seed, - (2 * prf_mac_len) + (2 * prf_key_len) + (2 * prf_iv_len), - h, - )?; - let mut key_material = &material[..]; - - let client_mac_key = key_material[..prf_mac_len].to_vec(); - key_material = &key_material[prf_mac_len..]; - - let server_mac_key = key_material[..prf_mac_len].to_vec(); - key_material = &key_material[prf_mac_len..]; - - let client_write_key = key_material[..prf_key_len].to_vec(); - key_material = &key_material[prf_key_len..]; - - let server_write_key = key_material[..prf_key_len].to_vec(); - key_material = &key_material[prf_key_len..]; - - let client_write_iv = key_material[..prf_iv_len].to_vec(); - key_material = &key_material[prf_iv_len..]; - - let server_write_iv = key_material[..prf_iv_len].to_vec(); - - Ok(EncryptionKeys { - master_secret: master_secret.to_vec(), - client_mac_key, - server_mac_key, - client_write_key, - server_write_key, - client_write_iv, - server_write_iv, - }) -} - -pub(crate) fn prf_verify_data( - master_secret: &[u8], - handshake_bodies: &[u8], - label: &str, - h: CipherSuiteHash, -) -> Result> { - let mut hasher = match h { - CipherSuiteHash::Sha256 => Sha256::new(), - }; - hasher.update(handshake_bodies); - let result = hasher.finalize(); - let mut seed = label.as_bytes().to_vec(); - seed.extend_from_slice(&result); - - prf_p_hash(master_secret, &seed, 12, h) -} - -pub(crate) fn prf_verify_data_client( - master_secret: &[u8], - handshake_bodies: &[u8], - h: CipherSuiteHash, -) -> Result> { - prf_verify_data( - master_secret, - handshake_bodies, - PRF_VERIFY_DATA_CLIENT_LABEL, - h, - ) -} - -pub(crate) fn prf_verify_data_server( - master_secret: &[u8], - handshake_bodies: &[u8], - h: CipherSuiteHash, -) -> Result> { - prf_verify_data( - master_secret, - handshake_bodies, - PRF_VERIFY_DATA_SERVER_LABEL, - h, - ) 
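// Worked example of the key-material slicing in prf_encryption_keys() above,
// with the parameters used in the tests below (prf_mac_len = 0,
// prf_key_len = 16, prf_iv_len = 4): P_hash must supply
// 2*0 + 2*16 + 2*4 = 40 bytes, consumed strictly in this order.
fn split_key_block_16_4(material: &[u8; 40]) -> (&[u8], &[u8], &[u8], &[u8]) {
    // client_mac_key and server_mac_key are empty here, so they are skipped.
    let (client_write_key, rest) = material.split_at(16);
    let (server_write_key, rest) = rest.split_at(16);
    let (client_write_iv, server_write_iv) = rest.split_at(4);
    (client_write_key, server_write_key, client_write_iv, server_write_iv)
}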
-} - -// compute the MAC using HMAC-SHA1 -pub(crate) fn prf_mac( - epoch: u16, - sequence_number: u64, - content_type: ContentType, - protocol_version: ProtocolVersion, - payload: &[u8], - key: &[u8], -) -> Result> { - let mut hmac = HmacSha1::new_from_slice(key).map_err(|e| Error::Other(e.to_string()))?; - - let mut msg = vec![0u8; 13]; - msg[..2].copy_from_slice(&epoch.to_be_bytes()); - msg[2..8].copy_from_slice(&sequence_number.to_be_bytes()[2..]); - msg[8] = content_type as u8; - msg[9] = protocol_version.major; - msg[10] = protocol_version.minor; - msg[11..].copy_from_slice(&(payload.len() as u16).to_be_bytes()); - - hmac.update(&msg); - hmac.update(payload); - let result = hmac.finalize(); - - Ok(result.into_bytes().to_vec()) -} diff --git a/dtls/src/prf/prf_test.rs b/dtls/src/prf/prf_test.rs deleted file mode 100644 index 33d5241f2..000000000 --- a/dtls/src/prf/prf_test.rs +++ /dev/null @@ -1,261 +0,0 @@ -use super::*; -use crate::cipher_suite::CipherSuiteHash; - -#[test] -fn test_pre_master_secret() -> Result<()> { - let private_key: [u8; 32] = [ - 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, - 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, - 0x3e, 0x3f, - ]; - let private_key = - NamedCurvePrivateKey::StaticSecretX25519(x25519_dalek::StaticSecret::from(private_key)); - let public_key = [ - 0x9f, 0xd7, 0xad, 0x6d, 0xcf, 0xf4, 0x29, 0x8d, 0xd3, 0xf9, 0x6d, 0x5b, 0x1b, 0x2a, 0xf9, - 0x10, 0xa0, 0x53, 0x5b, 0x14, 0x88, 0xd7, 0xf8, 0xfa, 0xbb, 0x34, 0x9a, 0x98, 0x28, 0x80, - 0xb6, 0x15, - ]; - - let expected_pre_master_secret = vec![ - 0xdf, 0x4a, 0x29, 0x1b, 0xaa, 0x1e, 0xb7, 0xcf, 0xa6, 0x93, 0x4b, 0x29, 0xb4, 0x74, 0xba, - 0xad, 0x26, 0x97, 0xe2, 0x9f, 0x1f, 0x92, 0x0d, 0xcc, 0x77, 0xc8, 0xa0, 0xa0, 0x88, 0x44, - 0x76, 0x24, - ]; - - let pre_master_secret = prf_pre_master_secret(&public_key, &private_key, NamedCurve::X25519)?; - - assert_eq!( - expected_pre_master_secret, pre_master_secret, - "PremasterSecret exp: {expected_pre_master_secret:?} actual: {pre_master_secret:?}" - ); - - Ok(()) -} - -#[test] -fn test_master_secret() -> Result<()> { - let pre_master_secret = vec![ - 0xdf, 0x4a, 0x29, 0x1b, 0xaa, 0x1e, 0xb7, 0xcf, 0xa6, 0x93, 0x4b, 0x29, 0xb4, 0x74, 0xba, - 0xad, 0x26, 0x97, 0xe2, 0x9f, 0x1f, 0x92, 0x0d, 0xcc, 0x77, 0xc8, 0xa0, 0xa0, 0x88, 0x44, - 0x76, 0x24, - ]; - let client_random = vec![ - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, - 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, - 0x1e, 0x1f, - ]; - let server_random = vec![ - 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, - 0x7f, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, - 0x8e, 0x8f, - ]; - let expected_master_secret = vec![ - 0x91, 0x6a, 0xbf, 0x9d, 0xa5, 0x59, 0x73, 0xe1, 0x36, 0x14, 0xae, 0x0a, 0x3f, 0x5d, 0x3f, - 0x37, 0xb0, 0x23, 0xba, 0x12, 0x9a, 0xee, 0x02, 0xcc, 0x91, 0x34, 0x33, 0x81, 0x27, 0xcd, - 0x70, 0x49, 0x78, 0x1c, 0x8e, 0x19, 0xfc, 0x1e, 0xb2, 0xa7, 0x38, 0x7a, 0xc0, 0x6a, 0xe2, - 0x37, 0x34, 0x4c, - ]; - - let master_secret = prf_master_secret( - &pre_master_secret, - &client_random, - &server_random, - CipherSuiteHash::Sha256, - )?; - - assert_eq!( - expected_master_secret, master_secret, - "master_secret exp: {expected_master_secret:?} actual: {master_secret:?}" - ); - - Ok(()) -} - -#[test] -fn test_encryption_keys() -> Result<()> { - 
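// Standalone sketch of the 13-byte pseudo-header that prf_mac() above feeds
// into HMAC-SHA1 ahead of the payload:
//   epoch(2) || sequence_number(6, uint48) || content_type(1) ||
//   version_major(1) || version_minor(1) || payload_length(2), big-endian.
fn mac_pseudo_header(
    epoch: u16,
    sequence_number: u64,
    content_type: u8,
    major: u8,
    minor: u8,
    payload_len: u16,
) -> [u8; 13] {
    let mut hdr = [0u8; 13];
    hdr[..2].copy_from_slice(&epoch.to_be_bytes());
    hdr[2..8].copy_from_slice(&sequence_number.to_be_bytes()[2..]); // low 48 bits
    hdr[8] = content_type;
    hdr[9] = major;
    hdr[10] = minor;
    hdr[11..].copy_from_slice(&payload_len.to_be_bytes());
    hdr
}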
let client_random = vec![ - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, - 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, - 0x1e, 0x1f, - ]; - let server_random = vec![ - 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, - 0x7f, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, - 0x8e, 0x8f, - ]; - let master_secret = vec![ - 0x91, 0x6a, 0xbf, 0x9d, 0xa5, 0x59, 0x73, 0xe1, 0x36, 0x14, 0xae, 0x0a, 0x3f, 0x5d, 0x3f, - 0x37, 0xb0, 0x23, 0xba, 0x12, 0x9a, 0xee, 0x02, 0xcc, 0x91, 0x34, 0x33, 0x81, 0x27, 0xcd, - 0x70, 0x49, 0x78, 0x1c, 0x8e, 0x19, 0xfc, 0x1e, 0xb2, 0xa7, 0x38, 0x7a, 0xc0, 0x6a, 0xe2, - 0x37, 0x34, 0x4c, - ]; - - let expected_encryption_keys = EncryptionKeys { - master_secret: master_secret.clone(), - client_mac_key: vec![], - server_mac_key: vec![], - client_write_key: vec![ - 0x1b, 0x7d, 0x11, 0x7c, 0x7d, 0x5f, 0x69, 0x0b, 0xc2, 0x63, 0xca, 0xe8, 0xef, 0x60, - 0xaf, 0x0f, - ], - server_write_key: vec![ - 0x18, 0x78, 0xac, 0xc2, 0x2a, 0xd8, 0xbd, 0xd8, 0xc6, 0x01, 0xa6, 0x17, 0x12, 0x6f, - 0x63, 0x54, - ], - client_write_iv: vec![0x0e, 0xb2, 0x09, 0x06], - server_write_iv: vec![0xf7, 0x81, 0xfa, 0xd2], - }; - - let keys = prf_encryption_keys( - &master_secret, - &client_random, - &server_random, - 0, - 16, - 4, - CipherSuiteHash::Sha256, - )?; - - assert_eq!( - expected_encryption_keys, keys, - "master_secret exp: {expected_encryption_keys:?} actual: {keys:?}", - ); - - Ok(()) -} - -#[test] -fn test_verify_data() -> Result<()> { - let client_hello = vec![ - 0x01, 0x00, 0x00, 0xa1, 0x03, 0x03, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, - 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, - 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x00, 0x00, 0x20, 0xcc, 0xa8, 0xcc, 0xa9, - 0xc0, 0x2f, 0xc0, 0x30, 0xc0, 0x2b, 0xc0, 0x2c, 0xc0, 0x13, 0xc0, 0x09, 0xc0, 0x14, 0xc0, - 0x0a, 0x00, 0x9c, 0x00, 0x9d, 0x00, 0x2f, 0x00, 0x35, 0xc0, 0x12, 0x00, 0x0a, 0x01, 0x00, - 0x00, 0x58, 0x00, 0x00, 0x00, 0x18, 0x00, 0x16, 0x00, 0x00, 0x13, 0x65, 0x78, 0x61, 0x6d, - 0x70, 0x6c, 0x65, 0x2e, 0x75, 0x6c, 0x66, 0x68, 0x65, 0x69, 0x6d, 0x2e, 0x6e, 0x65, 0x74, - 0x00, 0x05, 0x00, 0x05, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x0a, 0x00, 0x08, - 0x00, 0x1d, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19, 0x00, 0x0b, 0x00, 0x02, 0x01, 0x00, 0x00, - 0x0d, 0x00, 0x12, 0x00, 0x10, 0x04, 0x01, 0x04, 0x03, 0x05, 0x01, 0x05, 0x03, 0x06, 0x01, - 0x06, 0x03, 0x02, 0x01, 0x02, 0x03, 0xff, 0x01, 0x00, 0x01, 0x00, 0x00, 0x12, 0x00, 0x00, - ]; - let server_hello = vec![ - 0x02, 0x00, 0x00, 0x2d, 0x03, 0x03, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, - 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, - 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, 0x00, 0xc0, 0x13, 0x00, 0x00, 0x05, 0xff, - 0x01, 0x00, 0x01, 0x00, - ]; - let server_certificate = vec![ - 0x0b, 0x00, 0x03, 0x2b, 0x00, 0x03, 0x28, 0x00, 0x03, 0x25, 0x30, 0x82, 0x03, 0x21, 0x30, - 0x82, 0x02, 0x09, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02, 0x08, 0x15, 0x5a, 0x92, 0xad, 0xc2, - 0x04, 0x8f, 0x90, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, - 0x0b, 0x05, 0x00, 0x30, 0x22, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, - 0x02, 0x55, 0x53, 0x31, 0x13, 0x30, 0x11, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, 0x0a, 0x45, - 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x20, 0x43, 0x41, 
0x30, 0x1e, 0x17, 0x0d, 0x31, 0x38, - 0x31, 0x30, 0x30, 0x35, 0x30, 0x31, 0x33, 0x38, 0x31, 0x37, 0x5a, 0x17, 0x0d, 0x31, 0x39, - 0x31, 0x30, 0x30, 0x35, 0x30, 0x31, 0x33, 0x38, 0x31, 0x37, 0x5a, 0x30, 0x2b, 0x31, 0x0b, - 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, 0x02, 0x55, 0x53, 0x31, 0x1c, 0x30, 0x1a, - 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, 0x13, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, - 0x75, 0x6c, 0x66, 0x68, 0x65, 0x69, 0x6d, 0x2e, 0x6e, 0x65, 0x74, 0x30, 0x82, 0x01, 0x22, - 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01, 0x05, 0x00, - 0x03, 0x82, 0x01, 0x0f, 0x00, 0x30, 0x82, 0x01, 0x0a, 0x02, 0x82, 0x01, 0x01, 0x00, 0xc4, - 0x80, 0x36, 0x06, 0xba, 0xe7, 0x47, 0x6b, 0x08, 0x94, 0x04, 0xec, 0xa7, 0xb6, 0x91, 0x04, - 0x3f, 0xf7, 0x92, 0xbc, 0x19, 0xee, 0xfb, 0x7d, 0x74, 0xd7, 0xa8, 0x0d, 0x00, 0x1e, 0x7b, - 0x4b, 0x3a, 0x4a, 0xe6, 0x0f, 0xe8, 0xc0, 0x71, 0xfc, 0x73, 0xe7, 0x02, 0x4c, 0x0d, 0xbc, - 0xf4, 0xbd, 0xd1, 0x1d, 0x39, 0x6b, 0xba, 0x70, 0x46, 0x4a, 0x13, 0xe9, 0x4a, 0xf8, 0x3d, - 0xf3, 0xe1, 0x09, 0x59, 0x54, 0x7b, 0xc9, 0x55, 0xfb, 0x41, 0x2d, 0xa3, 0x76, 0x52, 0x11, - 0xe1, 0xf3, 0xdc, 0x77, 0x6c, 0xaa, 0x53, 0x37, 0x6e, 0xca, 0x3a, 0xec, 0xbe, 0xc3, 0xaa, - 0xb7, 0x3b, 0x31, 0xd5, 0x6c, 0xb6, 0x52, 0x9c, 0x80, 0x98, 0xbc, 0xc9, 0xe0, 0x28, 0x18, - 0xe2, 0x0b, 0xf7, 0xf8, 0xa0, 0x3a, 0xfd, 0x17, 0x04, 0x50, 0x9e, 0xce, 0x79, 0xbd, 0x9f, - 0x39, 0xf1, 0xea, 0x69, 0xec, 0x47, 0x97, 0x2e, 0x83, 0x0f, 0xb5, 0xca, 0x95, 0xde, 0x95, - 0xa1, 0xe6, 0x04, 0x22, 0xd5, 0xee, 0xbe, 0x52, 0x79, 0x54, 0xa1, 0xe7, 0xbf, 0x8a, 0x86, - 0xf6, 0x46, 0x6d, 0x0d, 0x9f, 0x16, 0x95, 0x1a, 0x4c, 0xf7, 0xa0, 0x46, 0x92, 0x59, 0x5c, - 0x13, 0x52, 0xf2, 0x54, 0x9e, 0x5a, 0xfb, 0x4e, 0xbf, 0xd7, 0x7a, 0x37, 0x95, 0x01, 0x44, - 0xe4, 0xc0, 0x26, 0x87, 0x4c, 0x65, 0x3e, 0x40, 0x7d, 0x7d, 0x23, 0x07, 0x44, 0x01, 0xf4, - 0x84, 0xff, 0xd0, 0x8f, 0x7a, 0x1f, 0xa0, 0x52, 0x10, 0xd1, 0xf4, 0xf0, 0xd5, 0xce, 0x79, - 0x70, 0x29, 0x32, 0xe2, 0xca, 0xbe, 0x70, 0x1f, 0xdf, 0xad, 0x6b, 0x4b, 0xb7, 0x11, 0x01, - 0xf4, 0x4b, 0xad, 0x66, 0x6a, 0x11, 0x13, 0x0f, 0xe2, 0xee, 0x82, 0x9e, 0x4d, 0x02, 0x9d, - 0xc9, 0x1c, 0xdd, 0x67, 0x16, 0xdb, 0xb9, 0x06, 0x18, 0x86, 0xed, 0xc1, 0xba, 0x94, 0x21, - 0x02, 0x03, 0x01, 0x00, 0x01, 0xa3, 0x52, 0x30, 0x50, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x1d, - 0x0f, 0x01, 0x01, 0xff, 0x04, 0x04, 0x03, 0x02, 0x05, 0xa0, 0x30, 0x1d, 0x06, 0x03, 0x55, - 0x1d, 0x25, 0x04, 0x16, 0x30, 0x14, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, - 0x02, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x01, 0x30, 0x1f, 0x06, 0x03, - 0x55, 0x1d, 0x23, 0x04, 0x18, 0x30, 0x16, 0x80, 0x14, 0x89, 0x4f, 0xde, 0x5b, 0xcc, 0x69, - 0xe2, 0x52, 0xcf, 0x3e, 0xa3, 0x00, 0xdf, 0xb1, 0x97, 0xb8, 0x1d, 0xe1, 0xc1, 0x46, 0x30, - 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x0b, 0x05, 0x00, 0x03, - 0x82, 0x01, 0x01, 0x00, 0x59, 0x16, 0x45, 0xa6, 0x9a, 0x2e, 0x37, 0x79, 0xe4, 0xf6, 0xdd, - 0x27, 0x1a, 0xba, 0x1c, 0x0b, 0xfd, 0x6c, 0xd7, 0x55, 0x99, 0xb5, 0xe7, 0xc3, 0x6e, 0x53, - 0x3e, 0xff, 0x36, 0x59, 0x08, 0x43, 0x24, 0xc9, 0xe7, 0xa5, 0x04, 0x07, 0x9d, 0x39, 0xe0, - 0xd4, 0x29, 0x87, 0xff, 0xe3, 0xeb, 0xdd, 0x09, 0xc1, 0xcf, 0x1d, 0x91, 0x44, 0x55, 0x87, - 0x0b, 0x57, 0x1d, 0xd1, 0x9b, 0xdf, 0x1d, 0x24, 0xf8, 0xbb, 0x9a, 0x11, 0xfe, 0x80, 0xfd, - 0x59, 0x2b, 0xa0, 0x39, 0x8c, 0xde, 0x11, 0xe2, 0x65, 0x1e, 0x61, 0x8c, 0xe5, 0x98, 0xfa, - 0x96, 0xe5, 0x37, 0x2e, 0xef, 0x3d, 0x24, 0x8a, 0xfd, 0xe1, 0x74, 0x63, 0xeb, 0xbf, 0xab, - 0xb8, 0xe4, 0xd1, 
0xab, 0x50, 0x2a, 0x54, 0xec, 0x00, 0x64, 0xe9, 0x2f, 0x78, 0x19, 0x66, - 0x0d, 0x3f, 0x27, 0xcf, 0x20, 0x9e, 0x66, 0x7f, 0xce, 0x5a, 0xe2, 0xe4, 0xac, 0x99, 0xc7, - 0xc9, 0x38, 0x18, 0xf8, 0xb2, 0x51, 0x07, 0x22, 0xdf, 0xed, 0x97, 0xf3, 0x2e, 0x3e, 0x93, - 0x49, 0xd4, 0xc6, 0x6c, 0x9e, 0xa6, 0x39, 0x6d, 0x74, 0x44, 0x62, 0xa0, 0x6b, 0x42, 0xc6, - 0xd5, 0xba, 0x68, 0x8e, 0xac, 0x3a, 0x01, 0x7b, 0xdd, 0xfc, 0x8e, 0x2c, 0xfc, 0xad, 0x27, - 0xcb, 0x69, 0xd3, 0xcc, 0xdc, 0xa2, 0x80, 0x41, 0x44, 0x65, 0xd3, 0xae, 0x34, 0x8c, 0xe0, - 0xf3, 0x4a, 0xb2, 0xfb, 0x9c, 0x61, 0x83, 0x71, 0x31, 0x2b, 0x19, 0x10, 0x41, 0x64, 0x1c, - 0x23, 0x7f, 0x11, 0xa5, 0xd6, 0x5c, 0x84, 0x4f, 0x04, 0x04, 0x84, 0x99, 0x38, 0x71, 0x2b, - 0x95, 0x9e, 0xd6, 0x85, 0xbc, 0x5c, 0x5d, 0xd6, 0x45, 0xed, 0x19, 0x90, 0x94, 0x73, 0x40, - 0x29, 0x26, 0xdc, 0xb4, 0x0e, 0x34, 0x69, 0xa1, 0x59, 0x41, 0xe8, 0xe2, 0xcc, 0xa8, 0x4b, - 0xb6, 0x08, 0x46, 0x36, 0xa0, - ]; - let server_key_exchange = vec![ - 0x0c, 0x00, 0x01, 0x28, 0x03, 0x00, 0x1d, 0x20, 0x9f, 0xd7, 0xad, 0x6d, 0xcf, 0xf4, 0x29, - 0x8d, 0xd3, 0xf9, 0x6d, 0x5b, 0x1b, 0x2a, 0xf9, 0x10, 0xa0, 0x53, 0x5b, 0x14, 0x88, 0xd7, - 0xf8, 0xfa, 0xbb, 0x34, 0x9a, 0x98, 0x28, 0x80, 0xb6, 0x15, 0x04, 0x01, 0x01, 0x00, 0x04, - 0x02, 0xb6, 0x61, 0xf7, 0xc1, 0x91, 0xee, 0x59, 0xbe, 0x45, 0x37, 0x66, 0x39, 0xbd, 0xc3, - 0xd4, 0xbb, 0x81, 0xe1, 0x15, 0xca, 0x73, 0xc8, 0x34, 0x8b, 0x52, 0x5b, 0x0d, 0x23, 0x38, - 0xaa, 0x14, 0x46, 0x67, 0xed, 0x94, 0x31, 0x02, 0x14, 0x12, 0xcd, 0x9b, 0x84, 0x4c, 0xba, - 0x29, 0x93, 0x4a, 0xaa, 0xcc, 0xe8, 0x73, 0x41, 0x4e, 0xc1, 0x1c, 0xb0, 0x2e, 0x27, 0x2d, - 0x0a, 0xd8, 0x1f, 0x76, 0x7d, 0x33, 0x07, 0x67, 0x21, 0xf1, 0x3b, 0xf3, 0x60, 0x20, 0xcf, - 0x0b, 0x1f, 0xd0, 0xec, 0xb0, 0x78, 0xde, 0x11, 0x28, 0xbe, 0xba, 0x09, 0x49, 0xeb, 0xec, - 0xe1, 0xa1, 0xf9, 0x6e, 0x20, 0x9d, 0xc3, 0x6e, 0x4f, 0xff, 0xd3, 0x6b, 0x67, 0x3a, 0x7d, - 0xdc, 0x15, 0x97, 0xad, 0x44, 0x08, 0xe4, 0x85, 0xc4, 0xad, 0xb2, 0xc8, 0x73, 0x84, 0x12, - 0x49, 0x37, 0x25, 0x23, 0x80, 0x9e, 0x43, 0x12, 0xd0, 0xc7, 0xb3, 0x52, 0x2e, 0xf9, 0x83, - 0xca, 0xc1, 0xe0, 0x39, 0x35, 0xff, 0x13, 0xa8, 0xe9, 0x6b, 0xa6, 0x81, 0xa6, 0x2e, 0x40, - 0xd3, 0xe7, 0x0a, 0x7f, 0xf3, 0x58, 0x66, 0xd3, 0xd9, 0x99, 0x3f, 0x9e, 0x26, 0xa6, 0x34, - 0xc8, 0x1b, 0x4e, 0x71, 0x38, 0x0f, 0xcd, 0xd6, 0xf4, 0xe8, 0x35, 0xf7, 0x5a, 0x64, 0x09, - 0xc7, 0xdc, 0x2c, 0x07, 0x41, 0x0e, 0x6f, 0x87, 0x85, 0x8c, 0x7b, 0x94, 0xc0, 0x1c, 0x2e, - 0x32, 0xf2, 0x91, 0x76, 0x9e, 0xac, 0xca, 0x71, 0x64, 0x3b, 0x8b, 0x98, 0xa9, 0x63, 0xdf, - 0x0a, 0x32, 0x9b, 0xea, 0x4e, 0xd6, 0x39, 0x7e, 0x8c, 0xd0, 0x1a, 0x11, 0x0a, 0xb3, 0x61, - 0xac, 0x5b, 0xad, 0x1c, 0xcd, 0x84, 0x0a, 0x6c, 0x8a, 0x6e, 0xaa, 0x00, 0x1a, 0x9d, 0x7d, - 0x87, 0xdc, 0x33, 0x18, 0x64, 0x35, 0x71, 0x22, 0x6c, 0x4d, 0xd2, 0xc2, 0xac, 0x41, 0xfb, - ]; - let server_hello_done = vec![0x0e, 0x00, 0x00, 0x00]; - let client_key_exchange = vec![ - 0x10, 0x00, 0x00, 0x21, 0x20, 0x35, 0x80, 0x72, 0xd6, 0x36, 0x58, 0x80, 0xd1, 0xae, 0xea, - 0x32, 0x9a, 0xdf, 0x91, 0x21, 0x38, 0x38, 0x51, 0xed, 0x21, 0xa2, 0x8e, 0x3b, 0x75, 0xe9, - 0x65, 0xd0, 0xd2, 0xcd, 0x16, 0x62, 0x54, - ]; - - let mut final_msg = vec![]; - final_msg.extend_from_slice(&client_hello); - final_msg.extend_from_slice(&server_hello); - final_msg.extend_from_slice(&server_certificate); - final_msg.extend_from_slice(&server_key_exchange); - final_msg.extend_from_slice(&server_hello_done); - final_msg.extend_from_slice(&client_key_exchange); - - let master_secret = vec![ - 0x91, 0x6a, 0xbf, 0x9d, 0xa5, 0x59, 0x73, 0xe1, 
0x36, 0x14, 0xae, 0x0a, 0x3f, 0x5d, 0x3f, - 0x37, 0xb0, 0x23, 0xba, 0x12, 0x9a, 0xee, 0x02, 0xcc, 0x91, 0x34, 0x33, 0x81, 0x27, 0xcd, - 0x70, 0x49, 0x78, 0x1c, 0x8e, 0x19, 0xfc, 0x1e, 0xb2, 0xa7, 0x38, 0x7a, 0xc0, 0x6a, 0xe2, - 0x37, 0x34, 0x4c, - ]; - - let expected_verify_data = vec![ - 0xcf, 0x91, 0x96, 0x26, 0xf1, 0x36, 0x0c, 0x53, 0x6a, 0xaa, 0xd7, 0x3a, - ]; - - let verify_data = prf_verify_data_client(&master_secret, &final_msg, CipherSuiteHash::Sha256)?; - - assert_eq!( - expected_verify_data, verify_data, - "verify_data exp: {expected_verify_data:?} actual: {verify_data:?}" - ); - - Ok(()) -} diff --git a/dtls/src/record_layer/mod.rs b/dtls/src/record_layer/mod.rs deleted file mode 100644 index 8be337db7..000000000 --- a/dtls/src/record_layer/mod.rs +++ /dev/null @@ -1,107 +0,0 @@ -pub mod record_layer_header; - -#[cfg(test)] -mod record_layer_test; - -use std::io::{Read, Write}; - -use record_layer_header::*; - -use super::content::*; -use super::error::*; -use crate::alert::Alert; -use crate::application_data::ApplicationData; -use crate::change_cipher_spec::ChangeCipherSpec; -use crate::handshake::Handshake; - -/* - The TLS Record Layer which handles all data transport. - The record layer is assumed to sit directly on top of some - reliable transport such as TCP. The record layer can carry four types of content: - - 1. Handshake messagesโ€”used for algorithm negotiation and key establishment. - 2. ChangeCipherSpec messagesโ€”really part of the handshake but technically a separate kind of message. - 3. Alert messagesโ€”used to signal that errors have occurred - 4. Application layer data - - The DTLS record layer is extremely similar to that of TLS 1.1. The - only change is the inclusion of an explicit sequence number in the - record. This sequence number allows the recipient to correctly - verify the TLS MAC. - https://tools.ietf.org/html/rfc4347#section-4.1 -*/ -#[derive(Debug, Clone, PartialEq)] -pub struct RecordLayer { - pub record_layer_header: RecordLayerHeader, - pub content: Content, -} - -impl RecordLayer { - pub fn new(protocol_version: ProtocolVersion, epoch: u16, content: Content) -> Self { - RecordLayer { - record_layer_header: RecordLayerHeader { - content_type: content.content_type(), - protocol_version, - epoch, - sequence_number: 0, - content_len: content.size() as u16, - }, - content, - } - } - - pub fn marshal(&self, writer: &mut W) -> Result<()> { - self.record_layer_header.marshal(writer)?; - self.content.marshal(writer)?; - Ok(()) - } - - pub fn unmarshal(reader: &mut R) -> Result { - let record_layer_header = RecordLayerHeader::unmarshal(reader)?; - let content = match record_layer_header.content_type { - ContentType::Alert => Content::Alert(Alert::unmarshal(reader)?), - ContentType::ApplicationData => { - Content::ApplicationData(ApplicationData::unmarshal(reader)?) - } - ContentType::ChangeCipherSpec => { - Content::ChangeCipherSpec(ChangeCipherSpec::unmarshal(reader)?) - } - ContentType::Handshake => Content::Handshake(Handshake::unmarshal(reader)?), - _ => return Err(Error::Other("Invalid Content Type".to_owned())), - }; - - Ok(RecordLayer { - record_layer_header, - content, - }) - } -} - -// Note that as with TLS, multiple handshake messages may be placed in -// the same DTLS record, provided that there is room and that they are -// part of the same flight. Thus, there are two acceptable ways to pack -// two DTLS messages into the same datagram: in the same record or in -// separate records. 
-// https://tools.ietf.org/html/rfc6347#section-4.2.3 -pub(crate) fn unpack_datagram(buf: &[u8]) -> Result>> { - let mut out = vec![]; - - let mut offset = 0; - while buf.len() != offset { - if buf.len() - offset <= RECORD_LAYER_HEADER_SIZE { - return Err(Error::ErrInvalidPacketLength); - } - - let pkt_len = RECORD_LAYER_HEADER_SIZE - + (((buf[offset + RECORD_LAYER_HEADER_SIZE - 2] as usize) << 8) - | buf[offset + RECORD_LAYER_HEADER_SIZE - 1] as usize); - if offset + pkt_len > buf.len() { - return Err(Error::ErrInvalidPacketLength); - } - - out.push(buf[offset..offset + pkt_len].to_vec()); - offset += pkt_len - } - - Ok(out) -} diff --git a/dtls/src/record_layer/record_layer_header.rs b/dtls/src/record_layer/record_layer_header.rs deleted file mode 100644 index 571168248..000000000 --- a/dtls/src/record_layer/record_layer_header.rs +++ /dev/null @@ -1,90 +0,0 @@ -use std::io::{Read, Write}; - -use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; - -use crate::content::*; -use crate::error::*; - -pub const RECORD_LAYER_HEADER_SIZE: usize = 13; -pub const MAX_SEQUENCE_NUMBER: u64 = 0x0000FFFFFFFFFFFF; - -pub const DTLS1_2MAJOR: u8 = 0xfe; -pub const DTLS1_2MINOR: u8 = 0xfd; - -pub const DTLS1_0MAJOR: u8 = 0xfe; -pub const DTLS1_0MINOR: u8 = 0xff; - -// VERSION_DTLS12 is the DTLS version in the same style as -// VersionTLSXX from crypto/tls -pub const VERSION_DTLS12: u16 = 0xfefd; - -pub const PROTOCOL_VERSION1_0: ProtocolVersion = ProtocolVersion { - major: DTLS1_0MAJOR, - minor: DTLS1_0MINOR, -}; -pub const PROTOCOL_VERSION1_2: ProtocolVersion = ProtocolVersion { - major: DTLS1_2MAJOR, - minor: DTLS1_2MINOR, -}; - -// https://tools.ietf.org/html/rfc4346#section-6.2.1 -#[derive(Copy, Clone, PartialEq, Eq, Debug, Default)] -pub struct ProtocolVersion { - pub major: u8, - pub minor: u8, -} - -#[derive(Copy, Clone, PartialEq, Eq, Debug, Default)] -pub struct RecordLayerHeader { - pub content_type: ContentType, - pub protocol_version: ProtocolVersion, - pub epoch: u16, - pub sequence_number: u64, // uint48 in spec - pub content_len: u16, -} - -impl RecordLayerHeader { - pub fn marshal(&self, writer: &mut W) -> Result<()> { - if self.sequence_number > MAX_SEQUENCE_NUMBER { - return Err(Error::ErrSequenceNumberOverflow); - } - - writer.write_u8(self.content_type as u8)?; - writer.write_u8(self.protocol_version.major)?; - writer.write_u8(self.protocol_version.minor)?; - writer.write_u16::(self.epoch)?; - - let be: [u8; 8] = self.sequence_number.to_be_bytes(); - writer.write_all(&be[2..])?; // uint48 in spec - - writer.write_u16::(self.content_len)?; - - Ok(writer.flush()?) 
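// Minimal sketch of the uint48 sequence-number handling used by
// RecordLayerHeader::marshal()/unmarshal(): only the low six bytes of the
// big-endian u64 go on the wire, which is why values above
// MAX_SEQUENCE_NUMBER are rejected before marshalling.
fn encode_u48(sequence_number: u64) -> [u8; 6] {
    let be = sequence_number.to_be_bytes();
    [be[2], be[3], be[4], be[5], be[6], be[7]]
}
fn decode_u48(wire: [u8; 6]) -> u64 {
    let mut be = [0u8; 8];
    be[2..].copy_from_slice(&wire);
    u64::from_be_bytes(be)
}
// Round trip: decode_u48(encode_u48(0x0000_1234_5678_9abc)) == 0x0000_1234_5678_9abc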
- } - - pub fn unmarshal(reader: &mut R) -> Result { - let content_type = reader.read_u8()?.into(); - let major = reader.read_u8()?; - let minor = reader.read_u8()?; - let epoch = reader.read_u16::()?; - - // SequenceNumber is stored as uint48, make into uint64 - let mut be: [u8; 8] = [0u8; 8]; - reader.read_exact(&mut be[2..])?; - let sequence_number = u64::from_be_bytes(be); - - let protocol_version = ProtocolVersion { major, minor }; - if protocol_version != PROTOCOL_VERSION1_0 && protocol_version != PROTOCOL_VERSION1_2 { - return Err(Error::ErrUnsupportedProtocolVersion); - } - let content_len = reader.read_u16::()?; - - Ok(RecordLayerHeader { - content_type, - protocol_version, - epoch, - sequence_number, - content_len, - }) - } -} diff --git a/dtls/src/record_layer/record_layer_test.rs b/dtls/src/record_layer/record_layer_test.rs deleted file mode 100644 index 3e644e503..000000000 --- a/dtls/src/record_layer/record_layer_test.rs +++ /dev/null @@ -1,118 +0,0 @@ -use std::io::{BufReader, BufWriter}; - -use super::record_layer_header::*; -use super::*; -use crate::change_cipher_spec::ChangeCipherSpec; - -#[test] -fn test_udp_decode() -> Result<()> { - let tests = vec![ - ( - "Change Cipher Spec, single packet", - vec![ - 0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0x01, 0x01, - ], - vec![vec![ - 0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0x01, 0x01, - ]], - None, - ), - ( - "Change Cipher Spec, multi packet", - vec![ - 0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0x01, 0x01, - 0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x13, 0x00, 0x01, 0x01, - ], - vec![ - vec![ - 0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0x01, - 0x01, - ], - vec![ - 0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x13, 0x00, 0x01, - 0x01, - ], - ], - None, - ), - ( - "Invalid packet length", - vec![0x14, 0xfe], - vec![], - Some(Error::ErrInvalidPacketLength), - ), - ( - "Packet declared invalid length", - vec![ - 0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0xFF, 0x01, - ], - vec![], - Some(Error::ErrInvalidPacketLength), - ), - ]; - - for (name, data, wanted, wanted_err) in tests { - let dtls_pkts = unpack_datagram(&data); - if let Some(err) = wanted_err { - if let Err(dtls) = dtls_pkts { - assert_eq!(err.to_string(), dtls.to_string()); - } else { - panic!("something wrong for {name} when wanted_err is Some"); - } - } else if let Ok(pkts) = dtls_pkts { - assert_eq!( - wanted, pkts, - "{name} UDP decode: got {pkts:?}, want {wanted:?}", - ); - } else { - panic!("something wrong for {name} when wanted_err is None"); - } - } - - Ok(()) -} - -#[test] -fn test_record_layer_round_trip() -> Result<()> { - let tests = vec![( - "Change Cipher Spec, single packet", - vec![ - 0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0x01, 0x01, - ], - RecordLayer { - record_layer_header: RecordLayerHeader { - content_type: ContentType::ChangeCipherSpec, - protocol_version: ProtocolVersion { - major: 0xfe, - minor: 0xff, - }, - epoch: 0, - sequence_number: 18, - content_len: 1, - }, - content: Content::ChangeCipherSpec(ChangeCipherSpec {}), - }, - )]; - - for (name, data, want) in tests { - let mut reader = BufReader::new(data.as_slice()); - let r = RecordLayer::unmarshal(&mut reader)?; - - assert_eq!( - want, r, - "{name} recordLayer.unmarshal: got {r:?}, want {want:?}" - ); - - let mut data2 = vec![]; - { - let mut writer = BufWriter::<&mut 
Vec>::new(data2.as_mut()); - r.marshal(&mut writer)?; - } - assert_eq!( - data, data2, - "{name} recordLayer.marshal: got {data2:?}, want {data:?}" - ); - } - - Ok(()) -} diff --git a/dtls/src/signature_hash_algorithm/mod.rs b/dtls/src/signature_hash_algorithm/mod.rs deleted file mode 100644 index 41e5dab34..000000000 --- a/dtls/src/signature_hash_algorithm/mod.rs +++ /dev/null @@ -1,215 +0,0 @@ -#[cfg(test)] -mod signature_hash_algorithm_test; - -use std::fmt; - -use crate::crypto::*; -use crate::error::*; - -// HashAlgorithm is used to indicate the hash algorithm used -// https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-18 -// Supported hash hash algorithms -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum HashAlgorithm { - Md2 = 0, // Blacklisted - Md5 = 1, // Blacklisted - Sha1 = 2, // Blacklisted - Sha224 = 3, - Sha256 = 4, - Sha384 = 5, - Sha512 = 6, - Ed25519 = 8, - Unsupported, -} - -impl From for HashAlgorithm { - fn from(val: u8) -> Self { - match val { - 0 => HashAlgorithm::Md2, - 1 => HashAlgorithm::Md5, - 2 => HashAlgorithm::Sha1, - 3 => HashAlgorithm::Sha224, - 4 => HashAlgorithm::Sha256, - 5 => HashAlgorithm::Sha384, - 6 => HashAlgorithm::Sha512, - 8 => HashAlgorithm::Ed25519, - _ => HashAlgorithm::Unsupported, - } - } -} - -impl fmt::Display for HashAlgorithm { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self { - HashAlgorithm::Md2 => write!(f, "md2"), - HashAlgorithm::Md5 => write!(f, "md5"), // [RFC3279] - HashAlgorithm::Sha1 => write!(f, "sha-1"), // [RFC3279] - HashAlgorithm::Sha224 => write!(f, "sha-224"), // [RFC4055] - HashAlgorithm::Sha256 => write!(f, "sha-256"), // [RFC4055] - HashAlgorithm::Sha384 => write!(f, "sha-384"), // [RFC4055] - HashAlgorithm::Sha512 => write!(f, "sha-512"), // [RFC4055] - HashAlgorithm::Ed25519 => write!(f, "null"), // [RFC4055] - _ => write!(f, "unknown or unsupported hash algorithm"), - } - } -} - -impl HashAlgorithm { - pub(crate) fn insecure(&self) -> bool { - matches!( - *self, - HashAlgorithm::Md2 | HashAlgorithm::Md5 | HashAlgorithm::Sha1 - ) - } - - pub(crate) fn invalid(&self) -> bool { - matches!(*self, HashAlgorithm::Md2) - } -} - -// https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-16 -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum SignatureAlgorithm { - Rsa = 1, - Ecdsa = 3, - Ed25519 = 7, - Unsupported, -} - -impl From for SignatureAlgorithm { - fn from(val: u8) -> Self { - match val { - 1 => SignatureAlgorithm::Rsa, - 3 => SignatureAlgorithm::Ecdsa, - 7 => SignatureAlgorithm::Ed25519, - _ => SignatureAlgorithm::Unsupported, - } - } -} - -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub struct SignatureHashAlgorithm { - pub hash: HashAlgorithm, - pub signature: SignatureAlgorithm, -} - -impl SignatureHashAlgorithm { - // is_compatible checks that given private key is compatible with the signature scheme. 
- pub(crate) fn is_compatible(&self, private_key: &CryptoPrivateKey) -> bool { - match &private_key.kind { - CryptoPrivateKeyKind::Ed25519(_) => self.signature == SignatureAlgorithm::Ed25519, - CryptoPrivateKeyKind::Ecdsa256(_) => self.signature == SignatureAlgorithm::Ecdsa, - CryptoPrivateKeyKind::Rsa256(_) => self.signature == SignatureAlgorithm::Rsa, - } - } -} - -pub(crate) fn default_signature_schemes() -> Vec { - vec![ - SignatureHashAlgorithm { - hash: HashAlgorithm::Sha256, - signature: SignatureAlgorithm::Ecdsa, - }, - SignatureHashAlgorithm { - hash: HashAlgorithm::Sha384, - signature: SignatureAlgorithm::Ecdsa, - }, - SignatureHashAlgorithm { - hash: HashAlgorithm::Sha512, - signature: SignatureAlgorithm::Ecdsa, - }, - SignatureHashAlgorithm { - hash: HashAlgorithm::Sha256, - signature: SignatureAlgorithm::Rsa, - }, - SignatureHashAlgorithm { - hash: HashAlgorithm::Sha384, - signature: SignatureAlgorithm::Rsa, - }, - SignatureHashAlgorithm { - hash: HashAlgorithm::Sha512, - signature: SignatureAlgorithm::Rsa, - }, - SignatureHashAlgorithm { - hash: HashAlgorithm::Ed25519, - signature: SignatureAlgorithm::Ed25519, - }, - ] -} - -// select Signature Scheme returns most preferred and compatible scheme. -pub(crate) fn select_signature_scheme( - sigs: &[SignatureHashAlgorithm], - private_key: &CryptoPrivateKey, -) -> Result { - for ss in sigs { - if ss.is_compatible(private_key) { - return Ok(*ss); - } - } - - Err(Error::ErrNoAvailableSignatureSchemes) -} - -// SignatureScheme identifies a signature algorithm supported by TLS. See -// RFC 8446, Section 4.2.3. -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum SignatureScheme { - // RSASSA-PKCS1-v1_5 algorithms. - Pkcs1WithSha256 = 0x0401, - Pkcs1WithSha384 = 0x0501, - Pkcs1WithSha512 = 0x0601, - - // RSASSA-PSS algorithms with public key OID rsaEncryption. - PssWithSha256 = 0x0804, - PssWithSha384 = 0x0805, - PssWithSha512 = 0x0806, - - // ECDSA algorithms. Only constrained to a specific curve in TLS 1.3. - EcdsaWithP256AndSha256 = 0x0403, - EcdsaWithP384AndSha384 = 0x0503, - EcdsaWithP521AndSha512 = 0x0603, - - // EdDSA algorithms. - Ed25519 = 0x0807, - - // Legacy signature and hash algorithms for TLS 1.2. - Pkcs1WithSha1 = 0x0201, - EcdsaWithSha1 = 0x0203, -} - -// parse_signature_schemes translates []tls.SignatureScheme to []signatureHashAlgorithm. -// It returns default signature scheme list if no SignatureScheme is passed. 
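// Worked example of the u16 decomposition performed below: the high byte is
// the HashAlgorithm and the low byte is the SignatureAlgorithm, so
// SignatureScheme::EcdsaWithP256AndSha256 (0x0403) splits into
// HashAlgorithm::Sha256 (0x04) and SignatureAlgorithm::Ecdsa (0x03).
fn split_signature_scheme_sketch() {
    let ss = SignatureScheme::EcdsaWithP256AndSha256 as u16; // 0x0403
    let hash = HashAlgorithm::from(((ss >> 8) & 0xFF) as u8);
    let sig = SignatureAlgorithm::from((ss & 0xFF) as u8);
    assert_eq!(hash, HashAlgorithm::Sha256);
    assert_eq!(sig, SignatureAlgorithm::Ecdsa);
}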
-pub(crate) fn parse_signature_schemes( - sigs: &[u16], - insecure_hashes: bool, -) -> Result> { - if sigs.is_empty() { - return Ok(default_signature_schemes()); - } - - let mut out = vec![]; - for ss in sigs { - let sig: SignatureAlgorithm = ((*ss & 0xFF) as u8).into(); - if sig == SignatureAlgorithm::Unsupported { - return Err(Error::ErrInvalidSignatureAlgorithm); - } - let h: HashAlgorithm = (((*ss >> 8) & 0xFF) as u8).into(); - if h == HashAlgorithm::Unsupported || h.invalid() { - return Err(Error::ErrInvalidHashAlgorithm); - } - if h.insecure() && !insecure_hashes { - continue; - } - out.push(SignatureHashAlgorithm { - hash: h, - signature: sig, - }) - } - - if out.is_empty() { - Err(Error::ErrNoAvailableSignatureSchemes) - } else { - Ok(out) - } -} diff --git a/dtls/src/signature_hash_algorithm/signature_hash_algorithm_test.rs b/dtls/src/signature_hash_algorithm/signature_hash_algorithm_test.rs deleted file mode 100644 index 775ef3277..000000000 --- a/dtls/src/signature_hash_algorithm/signature_hash_algorithm_test.rs +++ /dev/null @@ -1,141 +0,0 @@ -use super::*; - -#[test] -fn test_parse_signature_schemes() -> Result<()> { - let tests = vec![ - ( - "Translate", - vec![ - SignatureScheme::EcdsaWithP256AndSha256 as u16, - SignatureScheme::EcdsaWithP384AndSha384 as u16, - SignatureScheme::EcdsaWithP521AndSha512 as u16, - SignatureScheme::Pkcs1WithSha256 as u16, - SignatureScheme::Pkcs1WithSha384 as u16, - SignatureScheme::Pkcs1WithSha512 as u16, - ], - vec![ - SignatureHashAlgorithm { - hash: HashAlgorithm::Sha256, - signature: SignatureAlgorithm::Ecdsa, - }, - SignatureHashAlgorithm { - hash: HashAlgorithm::Sha384, - signature: SignatureAlgorithm::Ecdsa, - }, - SignatureHashAlgorithm { - hash: HashAlgorithm::Sha512, - signature: SignatureAlgorithm::Ecdsa, - }, - SignatureHashAlgorithm { - hash: HashAlgorithm::Sha256, - signature: SignatureAlgorithm::Rsa, - }, - SignatureHashAlgorithm { - hash: HashAlgorithm::Sha384, - signature: SignatureAlgorithm::Rsa, - }, - SignatureHashAlgorithm { - hash: HashAlgorithm::Sha512, - signature: SignatureAlgorithm::Rsa, - }, - ], - false, - None, - ), - ( - "InvalidSignatureAlgorithm", - vec![ - SignatureScheme::EcdsaWithP256AndSha256 as u16, // Valid - 0x04FF, // Invalid: unknown signature with SHA-256 - ], - vec![], - false, - Some(Error::ErrInvalidSignatureAlgorithm), - ), - ( - "InvalidHashAlgorithm", - vec![ - SignatureScheme::EcdsaWithP256AndSha256 as u16, // Valid - 0x0003, // Invalid: ECDSA with MD2 - ], - vec![], - false, - Some(Error::ErrInvalidHashAlgorithm), - ), - ( - "InsecureHashAlgorithmDenied", - vec![ - SignatureScheme::EcdsaWithP256AndSha256 as u16, // Valid - SignatureScheme::EcdsaWithSha1 as u16, // Insecure - ], - vec![SignatureHashAlgorithm { - hash: HashAlgorithm::Sha256, - signature: SignatureAlgorithm::Ecdsa, - }], - false, - None, - ), - ( - "InsecureHashAlgorithmAllowed", - vec![ - SignatureScheme::EcdsaWithP256AndSha256 as u16, // Valid - SignatureScheme::EcdsaWithSha1 as u16, // Insecure - ], - vec![ - SignatureHashAlgorithm { - hash: HashAlgorithm::Sha256, - signature: SignatureAlgorithm::Ecdsa, - }, - SignatureHashAlgorithm { - hash: HashAlgorithm::Sha1, - signature: SignatureAlgorithm::Ecdsa, - }, - ], - true, - None, - ), - ( - "OnlyInsecureHashAlgorithm", - vec![ - SignatureScheme::EcdsaWithSha1 as u16, // Insecure - ], - vec![], - false, - Some(Error::ErrNoAvailableSignatureSchemes), - ), - ( - "Translate", - vec![SignatureScheme::Ed25519 as u16], - vec![SignatureHashAlgorithm { - hash: HashAlgorithm::Ed25519, - 
signature: SignatureAlgorithm::Ed25519, - }], - false, - None, - ), - ]; - - for (name, inputs, expected, insecure_hashes, want_err) in tests { - let output = parse_signature_schemes(&inputs, insecure_hashes); - if let Some(err) = want_err { - if let Err(output_err) = output { - assert_eq!( - err.to_string(), - output_err.to_string(), - "Expected error: {err:?}, got: {output_err:?}" - ); - } else { - panic!("expect err, but got non-err for {name}"); - } - } else if let Ok(output_val) = output { - assert_eq!( - expected, output_val, - "Expected signatureHashAlgorithm:\n{expected:?}\ngot:\n{output_val:?}", - ); - } else { - panic!("expect non-err, but got err for {name}"); - } - } - - Ok(()) -} diff --git a/dtls/src/state.rs b/dtls/src/state.rs deleted file mode 100644 index 2d8fbc864..000000000 --- a/dtls/src/state.rs +++ /dev/null @@ -1,306 +0,0 @@ -use std::io::{BufWriter, Cursor}; -use std::marker::{Send, Sync}; -use std::sync::atomic::Ordering; -use std::sync::Arc; - -use async_trait::async_trait; -use portable_atomic::AtomicU16; -use serde::{Deserialize, Serialize}; -use tokio::sync::Mutex; -use util::{KeyingMaterialExporter, KeyingMaterialExporterError}; - -use super::cipher_suite::*; -use super::conn::*; -use super::curve::named_curve::*; -use super::extension::extension_use_srtp::SrtpProtectionProfile; -use super::handshake::handshake_random::*; -use super::prf::*; -use crate::error::*; - -// State holds the dtls connection state and implements both encoding.BinaryMarshaler and encoding.BinaryUnmarshaler -pub struct State { - pub(crate) local_epoch: Arc, - pub(crate) remote_epoch: Arc, - pub(crate) local_sequence_number: Arc>>, // uint48 - pub(crate) local_random: HandshakeRandom, - pub(crate) remote_random: HandshakeRandom, - pub(crate) master_secret: Vec, - pub(crate) cipher_suite: Arc>>>, // nil if a cipher_suite hasn't been chosen - - pub(crate) srtp_protection_profile: SrtpProtectionProfile, // Negotiated srtp_protection_profile - pub peer_certificates: Vec>, - pub identity_hint: Vec, - - pub(crate) is_client: bool, - - pub(crate) pre_master_secret: Vec, - pub(crate) extended_master_secret: bool, - - pub(crate) named_curve: NamedCurve, - pub(crate) local_keypair: Option, - pub(crate) cookie: Vec, - pub(crate) handshake_send_sequence: isize, - pub(crate) handshake_recv_sequence: isize, - pub(crate) server_name: String, - pub(crate) remote_requested_certificate: bool, // Did we get a CertificateRequest - pub(crate) local_certificates_verify: Vec, // cache CertificateVerify - pub(crate) local_verify_data: Vec, // cached VerifyData - pub(crate) local_key_signature: Vec, // cached keySignature - pub(crate) peer_certificates_verified: bool, - //pub(crate) replay_detector: Vec>, -} - -#[derive(Serialize, Deserialize, PartialEq, Debug)] -struct SerializedState { - local_epoch: u16, - remote_epoch: u16, - local_random: [u8; HANDSHAKE_RANDOM_LENGTH], - remote_random: [u8; HANDSHAKE_RANDOM_LENGTH], - cipher_suite_id: u16, - master_secret: Vec, - sequence_number: u64, - srtp_protection_profile: u16, - peer_certificates: Vec>, - identity_hint: Vec, - is_client: bool, -} - -impl Default for State { - fn default() -> Self { - State { - local_epoch: Arc::new(AtomicU16::new(0)), - remote_epoch: Arc::new(AtomicU16::new(0)), - local_sequence_number: Arc::new(Mutex::new(vec![])), - local_random: HandshakeRandom::default(), - remote_random: HandshakeRandom::default(), - master_secret: vec![], - cipher_suite: Arc::new(Mutex::new(None)), // nil if a cipher_suite hasn't been chosen - - 
srtp_protection_profile: SrtpProtectionProfile::Unsupported, // Negotiated srtp_protection_profile - peer_certificates: vec![], - identity_hint: vec![], - - is_client: false, - - pre_master_secret: vec![], - extended_master_secret: false, - - named_curve: NamedCurve::Unsupported, - local_keypair: None, - cookie: vec![], - handshake_send_sequence: 0, - handshake_recv_sequence: 0, - server_name: "".to_string(), - remote_requested_certificate: false, // Did we get a CertificateRequest - local_certificates_verify: vec![], // cache CertificateVerify - local_verify_data: vec![], // cached VerifyData - local_key_signature: vec![], // cached keySignature - peer_certificates_verified: false, - //replay_detector: vec![], - } - } -} - -impl State { - pub(crate) async fn clone(&self) -> Self { - let mut state = State::default(); - - if let Ok(serialized) = self.serialize().await { - let _ = state.deserialize(&serialized).await; - } - - state - } - - async fn serialize(&self) -> Result { - let mut local_rand = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(local_rand.as_mut()); - self.local_random.marshal(&mut writer)?; - } - let mut remote_rand = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(remote_rand.as_mut()); - self.remote_random.marshal(&mut writer)?; - } - - let mut local_random = [0u8; HANDSHAKE_RANDOM_LENGTH]; - let mut remote_random = [0u8; HANDSHAKE_RANDOM_LENGTH]; - - local_random.copy_from_slice(&local_rand); - remote_random.copy_from_slice(&remote_rand); - - let local_epoch = self.local_epoch.load(Ordering::SeqCst); - let remote_epoch = self.remote_epoch.load(Ordering::SeqCst); - let sequence_number = { - let lsn = self.local_sequence_number.lock().await; - lsn[local_epoch as usize] - }; - let cipher_suite_id = { - let cipher_suite = self.cipher_suite.lock().await; - match &*cipher_suite { - Some(cipher_suite) => cipher_suite.id() as u16, - None => return Err(Error::ErrCipherSuiteUnset), - } - }; - - Ok(SerializedState { - local_epoch, - remote_epoch, - local_random, - remote_random, - cipher_suite_id, - master_secret: self.master_secret.clone(), - sequence_number, - srtp_protection_profile: self.srtp_protection_profile as u16, - peer_certificates: self.peer_certificates.clone(), - identity_hint: self.identity_hint.clone(), - is_client: self.is_client, - }) - } - - async fn deserialize(&mut self, serialized: &SerializedState) -> Result<()> { - // Set epoch values - self.local_epoch - .store(serialized.local_epoch, Ordering::SeqCst); - self.remote_epoch - .store(serialized.remote_epoch, Ordering::SeqCst); - { - let mut lsn = self.local_sequence_number.lock().await; - while lsn.len() <= serialized.local_epoch as usize { - lsn.push(0); - } - lsn[serialized.local_epoch as usize] = serialized.sequence_number; - } - - // Set random values - let mut reader = Cursor::new(&serialized.local_random); - self.local_random = HandshakeRandom::unmarshal(&mut reader)?; - - let mut reader = Cursor::new(&serialized.remote_random); - self.remote_random = HandshakeRandom::unmarshal(&mut reader)?; - - self.is_client = serialized.is_client; - - // Set master secret - self.master_secret.clone_from(&serialized.master_secret); - - // Set cipher suite - self.cipher_suite = Arc::new(Mutex::new(Some(cipher_suite_for_id( - serialized.cipher_suite_id.into(), - )?))); - - self.srtp_protection_profile = serialized.srtp_protection_profile.into(); - - // Set remote certificate - self.peer_certificates - .clone_from(&serialized.peer_certificates); - 
self.identity_hint.clone_from(&serialized.identity_hint); - - Ok(()) - } - - pub async fn init_cipher_suite(&mut self) -> Result<()> { - let mut cipher_suite = self.cipher_suite.lock().await; - if let Some(cipher_suite) = &mut *cipher_suite { - if cipher_suite.is_initialized() { - return Ok(()); - } - - let mut local_random = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(local_random.as_mut()); - self.local_random.marshal(&mut writer)?; - } - let mut remote_random = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(remote_random.as_mut()); - self.remote_random.marshal(&mut writer)?; - } - - if self.is_client { - cipher_suite.init(&self.master_secret, &local_random, &remote_random, true) - } else { - cipher_suite.init(&self.master_secret, &remote_random, &local_random, false) - } - } else { - Err(Error::ErrCipherSuiteUnset) - } - } - - // marshal_binary is a binary.BinaryMarshaler.marshal_binary implementation - pub async fn marshal_binary(&self) -> Result> { - let serialized = self.serialize().await?; - - match bincode::serialize(&serialized) { - Ok(enc) => Ok(enc), - Err(err) => Err(Error::Other(err.to_string())), - } - } - - // unmarshal_binary is a binary.BinaryUnmarshaler.unmarshal_binary implementation - pub async fn unmarshal_binary(&mut self, data: &[u8]) -> Result<()> { - let serialized: SerializedState = match bincode::deserialize(data) { - Ok(dec) => dec, - Err(err) => return Err(Error::Other(err.to_string())), - }; - self.deserialize(&serialized).await?; - self.init_cipher_suite().await?; - - Ok(()) - } -} - -#[async_trait] -impl KeyingMaterialExporter for State { - /// export_keying_material returns length bytes of exported key material in a new - /// slice as defined in RFC 5705. - /// This allows protocols to use DTLS for key establishment, but - /// then use some of the keying material for their own purposes - async fn export_keying_material( - &self, - label: &str, - context: &[u8], - length: usize, - ) -> std::result::Result, KeyingMaterialExporterError> { - use KeyingMaterialExporterError::*; - - if self.local_epoch.load(Ordering::SeqCst) == 0 { - return Err(HandshakeInProgress); - } else if !context.is_empty() { - return Err(ContextUnsupported); - } else if INVALID_KEYING_LABELS.contains(&label) { - return Err(ReservedExportKeyingMaterial); - } - - let mut local_random = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(local_random.as_mut()); - self.local_random.marshal(&mut writer)?; - } - let mut remote_random = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(remote_random.as_mut()); - self.remote_random.marshal(&mut writer)?; - } - - let mut seed = label.as_bytes().to_vec(); - if self.is_client { - seed.extend_from_slice(&local_random); - seed.extend_from_slice(&remote_random); - } else { - seed.extend_from_slice(&remote_random); - seed.extend_from_slice(&local_random); - } - - let cipher_suite = self.cipher_suite.lock().await; - if let Some(cipher_suite) = &*cipher_suite { - match prf_p_hash(&self.master_secret, &seed, length, cipher_suite.hash_func()) { - Ok(v) => Ok(v), - Err(err) => Err(Hash(err.to_string())), - } - } else { - Err(CipherSuiteUnset) - } - } -} diff --git a/examples/.gitignore b/examples/.gitignore deleted file mode 100644 index 81561ed32..000000000 --- a/examples/.gitignore +++ /dev/null @@ -1,11 +0,0 @@ -# Generated by Cargo -# will have compiled files and executables -/target/ -/.idea/ - -# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries -# More information here 
https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html -Cargo.lock - -# These are backup files generated by rustfmt -**/*.rs.bk diff --git a/examples/Cargo.toml b/examples/Cargo.toml deleted file mode 100644 index 04733ac77..000000000 --- a/examples/Cargo.toml +++ /dev/null @@ -1,153 +0,0 @@ -[package] -name = "examples" -version = "0.5.0" -authors = ["Rain Liu "] -edition = "2021" -description = "Examples of WebRTC.rs stack" -license = "MIT OR Apache-2.0" -documentation = "https://docs.rs/examples" -homepage = "https://webrtc.rs" -repository = "https://github.com/webrtc-rs/webrtc/tree/master/examples" - -[dependencies] - - -[dev-dependencies] -webrtc = { path = "../webrtc" } - -tokio = { version = "1.32.0", features = ["full"] } -env_logger = "0.10" -clap = "3" -hyper = { version = "0.14.27", features = ["full"] } -signal = { path = "examples/signal" } -tokio-util = { version = "0.7", features = ["codec"] } -anyhow = "1" -chrono = "0.4.28" -log = "0.4" -serde = { version = "1", features = ["derive"] } -serde_json = "1" -bytes = "1" -lazy_static = "1" -rand = "0.8" - -memchr = "2.1.1" - -[[example]] -name = "rc-cycle" -path = "examples/rc-cycle/rc-cycle.rs" -bench = false - -[[example]] -name = "broadcast" -path = "examples/broadcast/broadcast.rs" -bench = false - -[[example]] -name = "data-channels" -path = "examples/data-channels/data-channels.rs" -bench = false - -[[example]] -name = "data-channels-close" -path = "examples/data-channels-close/data-channels-close.rs" -bench = false - -[[example]] -name = "data-channels-create" -path = "examples/data-channels-create/data-channels-create.rs" -bench = false - -[[example]] -name = "data-channels-detach" -path = "examples/data-channels-detach/data-channels-detach.rs" -bench = false - -[[example]] -name = "data-channels-detach-create" -path = "examples/data-channels-detach-create/data-channels-detach-create.rs" -bench = false - -[[example]] -name = "data-channels-flow-control" -path = "examples/data-channels-flow-control/data-channels-flow-control.rs" -bench = false - -[[example]] -name = "insertable-streams" -path = "examples/insertable-streams/insertable-streams.rs" -bench = false - -[[example]] -name = "play-from-disk-vpx" -path = "examples/play-from-disk-vpx/play-from-disk-vpx.rs" -bench = false - -[[example]] -name = "play-from-disk-h264" -path = "examples/play-from-disk-h264/play-from-disk-h264.rs" -bench = false - -[[example]] -name = "play-from-disk-hevc" -path = "examples/play-from-disk-hevc/play-from-disk-hevc.rs" -bench = false - -[[example]] -name = "play-from-disk-renegotiation" -path = "examples/play-from-disk-renegotiation/play-from-disk-renegotiation.rs" -bench = false - -[[example]] -name = "reflect" -path = "examples/reflect/reflect.rs" -bench = false - -[[example]] -name = "rtp-forwarder" -path = "examples/rtp-forwarder/rtp-forwarder.rs" -bench = false - -[[example]] -name = "rtp-to-webrtc" -path = "examples/rtp-to-webrtc/rtp-to-webrtc.rs" -bench = false - -[[example]] -name = "save-to-disk-vpx" -path = "examples/save-to-disk-vpx/save-to-disk-vpx.rs" -bench = false - -[[example]] -name = "save-to-disk-h264" -path = "examples/save-to-disk-h264/save-to-disk-h264.rs" -bench = false - -[[example]] -name = "simulcast" -path = "examples/simulcast/simulcast.rs" -bench = false - -[[example]] -name = "swap-tracks" -path = "examples/swap-tracks/swap-tracks.rs" -bench = false - -[[example]] -name = "ortc" -path = "examples/ortc/ortc.rs" -bench = false - -[[example]] -name = "offer" -path = "examples/offer-answer/offer.rs" 
-bench = false - -[[example]] -name = "answer" -path = "examples/offer-answer/answer.rs" -bench = false - -[[example]] -name = "ice-restart" -path = "examples/ice-restart/ice-restart.rs" -bench = false diff --git a/examples/LICENSE-APACHE b/examples/LICENSE-APACHE deleted file mode 100644 index 16fe87b06..000000000 --- a/examples/LICENSE-APACHE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - -1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." 
- - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - -2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - -3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - -4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. 
- - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - -5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - -6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - -8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - -END OF TERMS AND CONDITIONS - -APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. 
We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - -Copyright [yyyy] [name of copyright owner] - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. diff --git a/examples/LICENSE-MIT b/examples/LICENSE-MIT deleted file mode 100644 index e11d93bef..000000000 --- a/examples/LICENSE-MIT +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2021 WebRTC.rs - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/examples/README.md b/examples/README.md deleted file mode 100644 index 173a7a724..000000000 --- a/examples/README.md +++ /dev/null @@ -1,30 +0,0 @@ -

-# WebRTC.rs
-
-[logo and badge images; badges: License: MIT/Apache 2.0, Discord]
-
-Examples of WebRTC.rs stack. Rewrite Pion Examples in Rust

diff --git a/examples/codecov.yml b/examples/codecov.yml deleted file mode 100644 index bb738da7e..000000000 --- a/examples/codecov.yml +++ /dev/null @@ -1,23 +0,0 @@ -codecov: - require_ci_to_pass: yes - max_report_age: off - token: ec7b9766-689c-46bf-99fe-6c8e9971d20b - -coverage: - precision: 2 - round: down - range: 50..90 - status: - project: - default: - enabled: no - threshold: 0.2 - if_not_found: success - patch: - default: - enabled: no - if_not_found: success - changes: - default: - enabled: no - if_not_found: success diff --git a/examples/doc/webrtc.rs.png b/examples/doc/webrtc.rs.png deleted file mode 100644 index 7bf0dda2a..000000000 Binary files a/examples/doc/webrtc.rs.png and /dev/null differ diff --git a/examples/examples/README.md b/examples/examples/README.md deleted file mode 100644 index 4382f6635..000000000 --- a/examples/examples/README.md +++ /dev/null @@ -1,30 +0,0 @@ -

-# Examples

- -All examples are ported from [Pion](https://github.com/pion/webrtc/tree/master/examples#readme). Please check [Pion Examples](https://github.com/pion/webrtc/tree/master/examples#readme) for more details: - -#### Media API -- [x] [Reflect](reflect): The reflect example demonstrates how to have webrtc-rs send back to the user exactly what it receives using the same PeerConnection. -- [x] [Play from Disk VPx](play-from-disk-vpx): The play-from-disk-vp8 example demonstrates how to send VP8/VP9 video to your browser from a file saved to disk. -- [x] [Play from Disk H264](play-from-disk-h264): The play-from-disk-h264 example demonstrates how to send H264 video to your browser from a file saved to disk. -- [x] [Play from Disk Renegotiation](play-from-disk-renegotiation): The play-from-disk-renegotiation example is an extension of the play-from-disk example, but demonstrates how you can add/remove video tracks from an already negotiated PeerConnection. -- [x] [Insertable Streams](insertable-streams): The insertable-streams example demonstrates how webrtc-rs can be used to send E2E encrypted video and decrypt via insertable streams in the browser. -- [x] [Save to Disk VPx](save-to-disk-vpx): The save-to-disk example shows how to record your webcam and save the footage (VP8/VP9 for video, Opus for audio) to disk on the server side. -- [x] [Save to Disk H264](save-to-disk-h264): The save-to-disk example shows how to record your webcam and save the footage (H264 for video, Opus for audio) to disk on the server side. -- [x] [Broadcast](broadcast): The broadcast example demonstrates how to broadcast a video to multiple peers. A broadcaster uploads the video once and the server forwards it to all other peers. -- [x] [RTP Forwarder](rtp-forwarder): The rtp-forwarder example demonstrates how to forward your audio/video streams using RTP. -- [x] [RTP to WebRTC](rtp-to-webrtc): The rtp-to-webrtc example demonstrates how to take RTP packets sent to a webrtc-rs process into your browser. -- [x] [Simulcast](simulcast): The simulcast example demonstrates how to accept and demux 1 Track that contains 3 Simulcast streams. It then returns the media as 3 independent Tracks back to the sender. -- [x] [Swap Tracks](swap-tracks): The swap-tracks demonstrates how to swap multiple incoming tracks on a single outgoing track. - -#### Data Channel API -- [x] [Data Channels](data-channels): The data-channels example shows how you can send/recv DataChannel messages from a web browser. -- [x] [Data Channels Create](data-channels-create): Example data-channels-create shows how you can send/recv DataChannel messages from a web browser. The difference with the data-channels example is that the data channel is initialized from the server side in this example. -- [x] [Data Channels Close](data-channels-close): Example data-channels-close is a variant of data-channels that allow playing with the life cycle of data channels. -- [x] [Data Channels Detach](data-channels-detach): The data-channels-detach example shows how you can send/recv DataChannel messages using the underlying DataChannel implementation directly. This provides a more idiomatic way of interacting with Data Channels. -- [x] [Data Channels Detach Create](data-channels-detach-create): Example data-channels-detach-create shows how you can send/recv DataChannel messages using the underlying DataChannel implementation directly. This provides a more idiomatic way of interacting with Data Channels. 
The difference with the data-channels-detach example is that the data channel is initialized in this example. -- [x] [Data Channels Flow Control](data-channels-flow-control): Example data-channels-flow-control shows how to use flow control. -- [x] [ORTC](ortc): Example ortc shows how to use the ORTC API for DataChannel communication. -- [x] [Offer Answer](offer-answer): Example offer-answer is an example of two webrtc-rs or pion instances communicating directly! -- [x] [ICE Restart](ice-restart): The ice-restart demonstrates webrtc-rs ICE Restart abilities. diff --git a/examples/examples/broadcast/README.md b/examples/examples/broadcast/README.md deleted file mode 100644 index 21a003dc2..000000000 --- a/examples/examples/broadcast/README.md +++ /dev/null @@ -1,47 +0,0 @@ -# broadcast - -broadcast is a WebRTC.rs application that demonstrates how to broadcast a video to many peers, while only requiring the broadcaster to upload once. - -This could serve as the building block to building conferencing software, and other applications where publishers are bandwidth constrained. - -## Instructions - -### Build broadcast - -```shell -cargo build --example broadcast -``` - -### Open broadcast example page - -[jsfiddle.net](https://jsfiddle.net/1jc4go7v/) You should see two buttons 'Publish a Broadcast' and 'Join a Broadcast' - -### Run Broadcast - -#### Linux/macOS - -Run `broadcast` - -### Start a publisher - -* Click `Publish a Broadcast` -* Copy the string in the first input labelled `Browser base64 Session Description` -* Run `curl localhost:8080/sdp -d "$BROWSER_OFFER"`. `$BROWSER_OFFER` is the value you copied in the last step. -* The `broadcast` terminal application will respond with an answer, paste this into the second input field in your browser. -* Press `Start Session` -* The connection state will be printed in the terminal and under `logs` in the browser. - -### Join the broadcast - -* Click `Join a Broadcast` -* Copy the string in the first input labelled `Browser base64 Session Description` -* Run `curl localhost:8080/sdp -d "$BROWSER_OFFER"`. `$BROWSER_OFFER` is the value you copied in the last step. -* The `broadcast` terminal application will respond with an answer, paste this into the second input field in your browser. -* Press `Start Session` -* The connection state will be printed in the terminal and under `logs` in the browser. - -You can change the listening port using `-port 8011` - -You can `Join the broadcast` as many times as you want. The `broadcast` application is relaying all traffic, so your browser only has to upload once. - -Congrats, you have used WebRTC.rs! 
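The fan-out this README describes is implemented in `broadcast.rs` below by copying RTP from the publisher's remote track onto one shared local track. A trimmed sketch of that copy loop, mirroring the `read_rtp`/`write_rtp` calls in the example (the `TrackRemote` import path is assumed from the `webrtc` crate layout; error handling is reduced to a plain break, whereas the example tolerates `ErrClosedPipe`):

```rust
use std::sync::Arc;

use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP;
use webrtc::track::track_local::TrackLocalWriter;
use webrtc::track::track_remote::TrackRemote;

/// Forward every RTP packet from the publisher's track to the shared local track.
/// Viewers that add `local` to their own PeerConnection all receive the same stream.
async fn relay_rtp(remote: Arc<TrackRemote>, local: Arc<TrackLocalStaticRTP>) {
    while let Ok((rtp, _attributes)) = remote.read_rtp().await {
        if local.write_rtp(&rtp).await.is_err() {
            break; // stop on write errors (the real example keeps going on ErrClosedPipe)
        }
    }
}
```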
diff --git a/examples/examples/broadcast/broadcast.rs b/examples/examples/broadcast/broadcast.rs deleted file mode 100644 index 2626863d4..000000000 --- a/examples/examples/broadcast/broadcast.rs +++ /dev/null @@ -1,306 +0,0 @@ -use std::io::Write; -use std::sync::Arc; - -use anyhow::Result; -use clap::{AppSettings, Arg, Command}; -use tokio::time::Duration; -use webrtc::api::interceptor_registry::register_default_interceptors; -use webrtc::api::media_engine::MediaEngine; -use webrtc::api::APIBuilder; -use webrtc::ice_transport::ice_server::RTCIceServer; -use webrtc::interceptor::registry::Registry; -use webrtc::peer_connection::configuration::RTCConfiguration; -use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; -use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; -use webrtc::rtcp::payload_feedbacks::picture_loss_indication::PictureLossIndication; -use webrtc::rtp_transceiver::rtp_codec::RTPCodecType; -use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP; -use webrtc::track::track_local::{TrackLocal, TrackLocalWriter}; -use webrtc::Error; - -#[tokio::main] -async fn main() -> Result<()> { - let mut app = Command::new("broadcast") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of broadcast.") - .setting(AppSettings::DeriveDisplayOrder) - .subcommand_negates_reqs(true) - .arg( - Arg::new("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::new("debug") - .long("debug") - .short('d') - .help("Prints debug log information"), - ) - .arg( - Arg::new("port") - .takes_value(true) - .default_value("8080") - .long("port") - .help("http server port."), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let debug = matches.is_present("debug"); - if debug { - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init(); - } - - let port = matches.value_of("port").unwrap().parse::()?; - let mut sdp_chan_rx = signal::http_sdp_server(port).await; - - // Wait for the offer - println!("wait for the offer from http_sdp_server\n"); - let line = sdp_chan_rx.recv().await.unwrap(); - let desc_data = signal::decode(line.as_str())?; - let offer = serde_json::from_str::(&desc_data)?; - //println!("Receive offer from http_sdp_server:\n{:?}", offer); - - // Everything below is the WebRTC-rs API! Thanks for using it โค๏ธ. - - // Create a MediaEngine object to configure the supported codec - let mut m = MediaEngine::default(); - - m.register_default_codecs()?; - - // Create a InterceptorRegistry. This is the user configurable RTP/RTCP Pipeline. - // This provides NACKs, RTCP Reports and other features. If you use `webrtc.NewPeerConnection` - // this is enabled by default. If you are manually managing You MUST create a InterceptorRegistry - // for each PeerConnection. 
- let mut registry = Registry::new(); - - // Use the default set of Interceptors - registry = register_default_interceptors(registry, &mut m)?; - - // Create the API object with the MediaEngine - let api = APIBuilder::new() - .with_media_engine(m) - .with_interceptor_registry(registry) - .build(); - - // Prepare the configuration - let config = RTCConfiguration { - ice_servers: vec![RTCIceServer { - urls: vec!["stun:stun.l.google.com:19302".to_owned()], - ..Default::default() - }], - ..Default::default() - }; - - // Create a new RTCPeerConnection - let peer_connection = Arc::new(api.new_peer_connection(config).await?); - - // Allow us to receive 1 video track - peer_connection - .add_transceiver_from_kind(RTPCodecType::Video, None) - .await?; - - let (local_track_chan_tx, mut local_track_chan_rx) = - tokio::sync::mpsc::channel::>(1); - - let local_track_chan_tx = Arc::new(local_track_chan_tx); - // Set a handler for when a new remote track starts, this handler copies inbound RTP packets, - // replaces the SSRC and sends them back - let pc = Arc::downgrade(&peer_connection); - peer_connection.on_track(Box::new(move |track, _, _| { - // Send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval - // This is a temporary fix until we implement incoming RTCP events, then we would push a PLI only when a viewer requests it - let media_ssrc = track.ssrc(); - let pc2 = pc.clone(); - tokio::spawn(async move { - let mut result = Result::::Ok(0); - while result.is_ok() { - let timeout = tokio::time::sleep(Duration::from_secs(3)); - tokio::pin!(timeout); - - tokio::select! { - _ = timeout.as_mut() =>{ - if let Some(pc) = pc2.upgrade(){ - result = pc.write_rtcp(&[Box::new(PictureLossIndication{ - sender_ssrc: 0, - media_ssrc, - })]).await.map_err(Into::into); - }else{ - break; - } - } - }; - } - }); - - let local_track_chan_tx2 = Arc::clone(&local_track_chan_tx); - tokio::spawn(async move { - // Create Track that we send video back to browser on - let local_track = Arc::new(TrackLocalStaticRTP::new( - track.codec().capability, - "video".to_owned(), - "webrtc-rs".to_owned(), - )); - let _ = local_track_chan_tx2.send(Arc::clone(&local_track)).await; - - // Read RTP packets being sent to webrtc-rs - while let Ok((rtp, _)) = track.read_rtp().await { - if let Err(err) = local_track.write_rtp(&rtp).await { - if Error::ErrClosedPipe != err { - print!("output track write_rtp got error: {err} and break"); - break; - } else { - print!("output track write_rtp got error: {err}"); - } - } - } - }); - - Box::pin(async {}) - })); - - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - Box::pin(async {}) - })); - - // Set the remote SessionDescription - peer_connection.set_remote_description(offer).await?; - - // Create an answer - let answer = peer_connection.create_answer(None).await?; - - // Create channel that is blocked until ICE Gathering is complete - let mut gather_complete = peer_connection.gathering_complete_promise().await; - - // Sets the LocalDescription, and starts our UDP listeners - peer_connection.set_local_description(answer).await?; - - // Block until ICE Gathering is complete, disabling trickle ICE - // we do this because we only can exchange one signaling message - // in a production application you should exchange ICE Candidates via OnICECandidate - let _ = 
gather_complete.recv().await; - - // Output the answer in base64 so we can paste it in browser - if let Some(local_desc) = peer_connection.local_description().await { - let json_str = serde_json::to_string(&local_desc)?; - let b64 = signal::encode(&json_str); - println!("{b64}"); - } else { - println!("generate local_description failed!"); - } - - if let Some(local_track) = local_track_chan_rx.recv().await { - loop { - println!("\nCurl an base64 SDP to start sendonly peer connection"); - - let line = sdp_chan_rx.recv().await.unwrap(); - let desc_data = signal::decode(line.as_str())?; - let recv_only_offer = serde_json::from_str::(&desc_data)?; - - // Create a MediaEngine object to configure the supported codec - let mut m = MediaEngine::default(); - - m.register_default_codecs()?; - - // Create a InterceptorRegistry. This is the user configurable RTP/RTCP Pipeline. - // This provides NACKs, RTCP Reports and other features. If you use `webrtc.NewPeerConnection` - // this is enabled by default. If you are manually managing You MUST create a InterceptorRegistry - // for each PeerConnection. - let mut registry = Registry::new(); - - // Use the default set of Interceptors - registry = register_default_interceptors(registry, &mut m)?; - - // Create the API object with the MediaEngine - let api = APIBuilder::new() - .with_media_engine(m) - .with_interceptor_registry(registry) - .build(); - - // Prepare the configuration - let config = RTCConfiguration { - ice_servers: vec![RTCIceServer { - urls: vec!["stun:stun.l.google.com:19302".to_owned()], - ..Default::default() - }], - ..Default::default() - }; - - // Create a new RTCPeerConnection - let peer_connection = Arc::new(api.new_peer_connection(config).await?); - - let rtp_sender = peer_connection - .add_track(Arc::clone(&local_track) as Arc) - .await?; - - // Read incoming RTCP packets - // Before these packets are returned they are processed by interceptors. For things - // like NACK this needs to be called. 
- tokio::spawn(async move { - let mut rtcp_buf = vec![0u8; 1500]; - while let Ok((_, _)) = rtp_sender.read(&mut rtcp_buf).await {} - Result::<()>::Ok(()) - }); - - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new( - move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - Box::pin(async {}) - }, - )); - - // Set the remote SessionDescription - peer_connection - .set_remote_description(recv_only_offer) - .await?; - - // Create an answer - let answer = peer_connection.create_answer(None).await?; - - // Create channel that is blocked until ICE Gathering is complete - let mut gather_complete = peer_connection.gathering_complete_promise().await; - - // Sets the LocalDescription, and starts our UDP listeners - peer_connection.set_local_description(answer).await?; - - // Block until ICE Gathering is complete, disabling trickle ICE - // we do this because we only can exchange one signaling message - // in a production application you should exchange ICE Candidates via OnICECandidate - let _ = gather_complete.recv().await; - - if let Some(local_desc) = peer_connection.local_description().await { - let json_str = serde_json::to_string(&local_desc)?; - let b64 = signal::encode(&json_str); - println!("{b64}"); - } else { - println!("generate local_description failed!"); - } - } - } - - Ok(()) -} diff --git a/examples/examples/data-channels-close/README.md b/examples/examples/data-channels-close/README.md deleted file mode 100644 index 7deb096ea..000000000 --- a/examples/examples/data-channels-close/README.md +++ /dev/null @@ -1,2 +0,0 @@ -# data-channels-close -data-channels-close is a variant of the data-channels example that allow playing with the life cycle of data channels. 
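The life-cycle piece this example exercises is small: register an `on_close` callback and then call `close()` on the channel, which is what `data-channels-close.rs` below does once its send counter runs out. A minimal sketch of just that step (channel setup omitted; paths as imported in the example):

```rust
use std::sync::Arc;

use webrtc::data_channel::RTCDataChannel;

/// Announce the close event, then shut the channel down explicitly.
async fn close_channel(dc: Arc<RTCDataChannel>) -> Result<(), webrtc::Error> {
    dc.on_close(Box::new(|| {
        println!("data channel closed");
        Box::pin(async {})
    }));
    dc.close().await
}
```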
diff --git a/examples/examples/data-channels-close/data-channels-close.rs b/examples/examples/data-channels-close/data-channels-close.rs deleted file mode 100644 index ce76b605f..000000000 --- a/examples/examples/data-channels-close/data-channels-close.rs +++ /dev/null @@ -1,240 +0,0 @@ -use std::io::Write; -use std::sync::atomic::{AtomicI32, Ordering}; -use std::sync::Arc; - -use anyhow::Result; -use clap::{AppSettings, Arg, Command}; -use tokio::sync::Mutex; -use tokio::time::Duration; -use webrtc::api::interceptor_registry::register_default_interceptors; -use webrtc::api::media_engine::MediaEngine; -use webrtc::api::APIBuilder; -use webrtc::data_channel::data_channel_message::DataChannelMessage; -use webrtc::data_channel::RTCDataChannel; -use webrtc::ice_transport::ice_server::RTCIceServer; -use webrtc::interceptor::registry::Registry; -use webrtc::peer_connection::configuration::RTCConfiguration; -use webrtc::peer_connection::math_rand_alpha; -use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; -use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; - -#[tokio::main] -async fn main() -> Result<()> { - let mut app = Command::new("data-channels-close") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of Data-Channels-Close.") - .setting(AppSettings::DeriveDisplayOrder) - .subcommand_negates_reqs(true) - .arg( - Arg::new("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::new("debug") - .long("debug") - .short('d') - .help("Prints debug log information"), - ) - .arg( - Arg::new("close-after") - .takes_value(true) - .default_value("5") - .long("close-after") - .help("Close data channel after sending X times."), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let close_after = Arc::new(AtomicI32::new( - matches - .value_of("close-after") - .unwrap() - .to_owned() - .parse::()?, - )); - let debug = matches.is_present("debug"); - if debug { - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init(); - } - - // Everything below is the WebRTC-rs API! Thanks for using it โค๏ธ. - - // Create a MediaEngine object to configure the supported codec - let mut m = MediaEngine::default(); - - // Register default codecs - m.register_default_codecs()?; - - // Create a InterceptorRegistry. This is the user configurable RTP/RTCP Pipeline. - // This provides NACKs, RTCP Reports and other features. If you use `webrtc.NewPeerConnection` - // this is enabled by default. If you are manually managing You MUST create a InterceptorRegistry - // for each PeerConnection. 
- let mut registry = Registry::new(); - - // Use the default set of Interceptors - registry = register_default_interceptors(registry, &mut m)?; - - // Create the API object with the MediaEngine - let api = APIBuilder::new() - .with_media_engine(m) - .with_interceptor_registry(registry) - .build(); - - // Prepare the configuration - let config = RTCConfiguration { - ice_servers: vec![RTCIceServer { - urls: vec!["stun:stun.l.google.com:19302".to_owned()], - ..Default::default() - }], - ..Default::default() - }; - - // Create a new RTCPeerConnection - let peer_connection = Arc::new(api.new_peer_connection(config).await?); - - let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - - if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. - println!("Peer Connection has gone to failed exiting"); - let _ = done_tx.try_send(()); - } - - Box::pin(async {}) - })); - - // Register data channel creation handling - peer_connection - .on_data_channel(Box::new(move |d: Arc| { - let d_label = d.label().to_owned(); - let d_id = d.id(); - println!("New DataChannel {d_label} {d_id}"); - - let close_after2 = Arc::clone(&close_after); - - // Register channel opening handling - Box::pin(async move { - let d2 = Arc::clone(&d); - let d_label2 = d_label.clone(); - let d_id2 = d_id; - d.on_open(Box::new(move || { - println!("Data channel '{d_label2}'-'{d_id2}' open. Random messages will now be sent to any connected DataChannels every 5 seconds"); - let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - let done_tx = Arc::new(Mutex::new(Some(done_tx))); - Box::pin(async move { - d2.on_close(Box::new(move || { - println!("Data channel '{d_label2}'-'{d_id2}' closed."); - let done_tx2 = Arc::clone(&done_tx); - Box::pin(async move{ - let mut done = done_tx2.lock().await; - done.take(); - }) - })); - - let mut result = Result::::Ok(0); - while result.is_ok() { - let timeout = tokio::time::sleep(Duration::from_secs(5)); - tokio::pin!(timeout); - - tokio::select! { - _ = done_rx.recv() => { - break; - } - _ = timeout.as_mut() =>{ - let message = math_rand_alpha(15); - println!("Sending '{message}'"); - result = d2.send_text(message).await.map_err(Into::into); - - let cnt = close_after2.fetch_sub(1, Ordering::SeqCst); - if cnt <= 0 { - println!("Sent times out. 
Closing data channel '{}'-'{}'.", d2.label(), d2.id()); - let _ = d2.close().await; - break; - } - } - }; - } - }) - })); - - // Register text message handling - d.on_message(Box::new(move |msg: DataChannelMessage| { - let msg_str = String::from_utf8(msg.data.to_vec()).unwrap(); - println!("Message from DataChannel '{d_label}': '{msg_str}'"); - Box::pin(async {}) - })); - }) - })); - - // Wait for the offer to be pasted - let line = signal::must_read_stdin()?; - let desc_data = signal::decode(line.as_str())?; - let offer = serde_json::from_str::(&desc_data)?; - - // Set the remote SessionDescription - peer_connection.set_remote_description(offer).await?; - - // Create an answer - let answer = peer_connection.create_answer(None).await?; - - // Create channel that is blocked until ICE Gathering is complete - let mut gather_complete = peer_connection.gathering_complete_promise().await; - - // Sets the LocalDescription, and starts our UDP listeners - peer_connection.set_local_description(answer).await?; - - // Block until ICE Gathering is complete, disabling trickle ICE - // we do this because we only can exchange one signaling message - // in a production application you should exchange ICE Candidates via OnICECandidate - let _ = gather_complete.recv().await; - - // Output the answer in base64 so we can paste it in browser - if let Some(local_desc) = peer_connection.local_description().await { - let json_str = serde_json::to_string(&local_desc)?; - let b64 = signal::encode(&json_str); - println!("{b64}"); - } else { - println!("generate local_description failed!"); - } - - println!("Press ctrl-c to stop"); - tokio::select! { - _ = done_rx.recv() => { - println!("received done signal!"); - } - _ = tokio::signal::ctrl_c() => { - println!(); - } - }; - - peer_connection.close().await?; - - Ok(()) -} diff --git a/examples/examples/data-channels-create/README.md b/examples/examples/data-channels-create/README.md deleted file mode 100644 index 786670b37..000000000 --- a/examples/examples/data-channels-create/README.md +++ /dev/null @@ -1,38 +0,0 @@ -# data-channels-create - -data-channels-create is a WebRTC.rs application that shows how you can send/recv DataChannel messages from a web browser. The difference with the data-channels example is that the datachannel is initialized from the WebRTC.rs side in this example. - -## Instructions - -### Build data-channels-create - -```shell -cargo build --example data-channels-create -``` - -### Open data-channels-create example page - -[jsfiddle.net](https://jsfiddle.net/swgxrp94/20/) - -### Run data-channels-create - -Just run `data-channels-create`. - -### Input data-channels-create's SessionDescription into your browser - -Copy the text that `data-channels-create` just emitted and copy into first text area of the jsfiddle. - -### Hit 'Start Session' in jsfiddle - -Hit the 'Start Session' button in the browser. You should see `have-remote-offer` below the `Send Message` button. - -### Input browser's SessionDescription into data-channels-create - -Meanwhile text has appeared in the second text area of the jsfiddle. Copy the text and paste it into `data-channels-create` and hit ENTER. -In the browser you'll now see `connected` as the connection is created. If everything worked you should see `New DataChannel data`. - -Now you can put whatever you want in the `Message` textarea, and when you hit `Send Message` it should appear in your terminal! - -WebRTC.rs will send random messages every 5 seconds that will appear in your browser. 
- -Congrats, you have used WebRTC.rs! diff --git a/examples/examples/data-channels-create/data-channels-create.rs b/examples/examples/data-channels-create/data-channels-create.rs deleted file mode 100644 index 87d3c0789..000000000 --- a/examples/examples/data-channels-create/data-channels-create.rs +++ /dev/null @@ -1,194 +0,0 @@ -use std::io::Write; -use std::sync::Arc; - -use anyhow::Result; -use clap::{AppSettings, Arg, Command}; -use tokio::time::Duration; -use webrtc::api::interceptor_registry::register_default_interceptors; -use webrtc::api::media_engine::MediaEngine; -use webrtc::api::APIBuilder; -use webrtc::data_channel::data_channel_message::DataChannelMessage; -use webrtc::ice_transport::ice_server::RTCIceServer; -use webrtc::interceptor::registry::Registry; -use webrtc::peer_connection::configuration::RTCConfiguration; -use webrtc::peer_connection::math_rand_alpha; -use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; -use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; - -#[tokio::main] -async fn main() -> Result<()> { - let mut app = Command::new("data-channels-create") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of Data-Channels-Create.") - .setting(AppSettings::DeriveDisplayOrder) - .subcommand_negates_reqs(true) - .arg( - Arg::new("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::new("debug") - .long("debug") - .short('d') - .help("Prints debug log information"), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let debug = matches.is_present("debug"); - if debug { - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init(); - } - - // Everything below is the WebRTC-rs API! Thanks for using it โค๏ธ. - - // Create a MediaEngine object to configure the supported codec - let mut m = MediaEngine::default(); - - // Register default codecs - m.register_default_codecs()?; - - // Create a InterceptorRegistry. This is the user configurable RTP/RTCP Pipeline. - // This provides NACKs, RTCP Reports and other features. If you use `webrtc.NewPeerConnection` - // this is enabled by default. If you are manually managing You MUST create a InterceptorRegistry - // for each PeerConnection. 
- let mut registry = Registry::new(); - - // Use the default set of Interceptors - registry = register_default_interceptors(registry, &mut m)?; - - // Create the API object with the MediaEngine - let api = APIBuilder::new() - .with_media_engine(m) - .with_interceptor_registry(registry) - .build(); - - // Prepare the configuration - let config = RTCConfiguration { - ice_servers: vec![RTCIceServer { - urls: vec!["stun:stun.l.google.com:19302".to_owned()], - ..Default::default() - }], - ..Default::default() - }; - - // Create a new RTCPeerConnection - let peer_connection = Arc::new(api.new_peer_connection(config).await?); - - // Create a datachannel with label 'data' - let data_channel = peer_connection.create_data_channel("data", None).await?; - - let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - - if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. - println!("Peer Connection has gone to failed exiting"); - let _ = done_tx.try_send(()); - } - - Box::pin(async {}) - })); - - // Register channel opening handling - let d1 = Arc::clone(&data_channel); - data_channel.on_open(Box::new(move || { - println!("Data channel '{}'-'{}' open. Random messages will now be sent to any connected DataChannels every 5 seconds", d1.label(), d1.id()); - - let d2 = Arc::clone(&d1); - Box::pin(async move { - let mut result = Result::::Ok(0); - while result.is_ok() { - let timeout = tokio::time::sleep(Duration::from_secs(5)); - tokio::pin!(timeout); - - tokio::select! 
{ - _ = timeout.as_mut() =>{ - let message = math_rand_alpha(15); - println!("Sending '{message}'"); - result = d2.send_text(message).await.map_err(Into::into); - } - }; - } - }) - })); - - // Register text message handling - let d_label = data_channel.label().to_owned(); - data_channel.on_message(Box::new(move |msg: DataChannelMessage| { - let msg_str = String::from_utf8(msg.data.to_vec()).unwrap(); - println!("Message from DataChannel '{d_label}': '{msg_str}'"); - Box::pin(async {}) - })); - - // Create an offer to send to the browser - let offer = peer_connection.create_offer(None).await?; - - // Create channel that is blocked until ICE Gathering is complete - let mut gather_complete = peer_connection.gathering_complete_promise().await; - - // Sets the LocalDescription, and starts our UDP listeners - peer_connection.set_local_description(offer).await?; - - // Block until ICE Gathering is complete, disabling trickle ICE - // we do this because we only can exchange one signaling message - // in a production application you should exchange ICE Candidates via OnICECandidate - let _ = gather_complete.recv().await; - - // Output the answer in base64 so we can paste it in browser - if let Some(local_desc) = peer_connection.local_description().await { - let json_str = serde_json::to_string(&local_desc)?; - let b64 = signal::encode(&json_str); - println!("{b64}"); - } else { - println!("generate local_description failed!"); - } - - // Wait for the answer to be pasted - let line = signal::must_read_stdin()?; - let desc_data = signal::decode(line.as_str())?; - let answer = serde_json::from_str::(&desc_data)?; - - // Apply the answer as the remote description - peer_connection.set_remote_description(answer).await?; - - println!("Press ctrl-c to stop"); - tokio::select! { - _ = done_rx.recv() => { - println!("received done signal!"); - } - _ = tokio::signal::ctrl_c() => { - println!(); - } - }; - - peer_connection.close().await?; - - Ok(()) -} diff --git a/examples/examples/data-channels-detach-create/README.md b/examples/examples/data-channels-detach-create/README.md deleted file mode 100644 index 928b1024a..000000000 --- a/examples/examples/data-channels-detach-create/README.md +++ /dev/null @@ -1,16 +0,0 @@ -# data-channels-detach-create - -data-channels-detach-create is an example that shows how you can detach a data channel. -This allows direct access the the underlying [webrtc-rs/data](https://github.com/webrtc-rs/data). - -The example mirrors the data-channels-create example. - -## Install - -```shell -cargo build --example data-channels-detach-create -``` - -## Usage - -The example can be used in the same way as the [Data Channels Create](data-channels-create) example. 
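The detach flow is easy to lose in the full listings below, so here is a condensed sketch of the three pieces these deleted examples combine: opting in on a `SettingEngine` before the API is built, calling `detach()` from the channel's `on_open` handler, and driving the raw channel from spawned tasks. The individual calls mirror the deleted sources in this diff; the helper names (`detachable_api`, `detach_on_open`) and the placeholder payload are illustrative only, and the builder is trimmed (the real examples also register a `MediaEngine` and an interceptor registry).

```rust
use std::sync::Arc;

use webrtc::api::setting_engine::SettingEngine;
use webrtc::api::{APIBuilder, API};
use webrtc::data::data_channel::DataChannel;
use webrtc::data_channel::RTCDataChannel;

/// Build an API whose data channels may be detached. Trimmed down: the real
/// examples also register a MediaEngine and an interceptor registry here.
fn detachable_api() -> API {
    let mut setting_engine = SettingEngine::default();
    // Detached channels bypass the on_message callback API; mixing both
    // styles on one channel is not supported.
    setting_engine.detach_data_channels();
    APIBuilder::new().with_setting_engine(setting_engine).build()
}

/// Once the channel opens, detach it and hand the raw half to reader/writer
/// tasks (the deleted read_loop/write_loop play the same roles).
fn detach_on_open(dc: Arc<RTCDataChannel>) {
    let dc2 = Arc::clone(&dc);
    dc.on_open(Box::new(move || {
        let dc2 = Arc::clone(&dc2);
        Box::pin(async move {
            let raw: Arc<DataChannel> = match dc2.detach().await {
                Ok(raw) => raw,
                Err(err) => {
                    eprintln!("data channel detach got err: {err}");
                    return;
                }
            };

            // Reader task: raw, message-sized reads straight off the channel.
            let reader = Arc::clone(&raw);
            tokio::spawn(async move {
                let mut buf = vec![0u8; 1500];
                while let Ok(n) = reader.read(&mut buf).await {
                    println!("read {n} bytes from the detached channel");
                }
            });

            // Writer task: raw writes, no DataChannelMessage framing needed.
            tokio::spawn(async move {
                let _ = raw.write(&bytes::Bytes::from_static(b"ping")).await;
            });
        })
    }));
}
```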
diff --git a/examples/examples/data-channels-detach-create/data-channels-detach-create.rs b/examples/examples/data-channels-detach-create/data-channels-detach-create.rs deleted file mode 100644 index 2c6301790..000000000 --- a/examples/examples/data-channels-detach-create/data-channels-detach-create.rs +++ /dev/null @@ -1,241 +0,0 @@ -use std::io::Write; -use std::sync::Arc; - -use anyhow::Result; -use bytes::Bytes; -use clap::{AppSettings, Arg, Command}; -use tokio::time::Duration; -use webrtc::api::interceptor_registry::register_default_interceptors; -use webrtc::api::media_engine::MediaEngine; -use webrtc::api::setting_engine::SettingEngine; -use webrtc::api::APIBuilder; -use webrtc::ice_transport::ice_server::RTCIceServer; -use webrtc::interceptor::registry::Registry; -use webrtc::peer_connection::configuration::RTCConfiguration; -use webrtc::peer_connection::math_rand_alpha; -use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; -use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; - -const MESSAGE_SIZE: usize = 1500; - -#[tokio::main] -async fn main() -> Result<()> { - let mut app = Command::new("data-channels-detach-create") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of Data-Channels-Detach-Create.") - .setting(AppSettings::DeriveDisplayOrder) - .subcommand_negates_reqs(true) - .arg( - Arg::new("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::new("debug") - .long("debug") - .short('d') - .help("Prints debug log information"), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let debug = matches.is_present("debug"); - if debug { - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init(); - } - - // Everything below is the WebRTC-rs API! Thanks for using it โค๏ธ. - - // Create a MediaEngine object to configure the supported codec - let mut m = MediaEngine::default(); - - // Register default codecs - m.register_default_codecs()?; - - // Create a InterceptorRegistry. This is the user configurable RTP/RTCP Pipeline. - // This provides NACKs, RTCP Reports and other features. If you use `webrtc.NewPeerConnection` - // this is enabled by default. If you are manually managing You MUST create a InterceptorRegistry - // for each PeerConnection. - let mut registry = Registry::new(); - - // Use the default set of Interceptors - registry = register_default_interceptors(registry, &mut m)?; - - // Since this behavior diverges from the WebRTC API it has to be - // enabled using a settings engine. Mixing both detached and the - // OnMessage DataChannel API is not supported. 
- - // Create a SettingEngine and enable Detach - let mut s = SettingEngine::default(); - s.detach_data_channels(); - - // Create the API object with the MediaEngine - let api = APIBuilder::new() - .with_media_engine(m) - .with_interceptor_registry(registry) - .with_setting_engine(s) - .build(); - - // Prepare the configuration - let config = RTCConfiguration { - ice_servers: vec![RTCIceServer { - urls: vec!["stun:stun.l.google.com:19302".to_owned()], - ..Default::default() - }], - ..Default::default() - }; - - // Create a new RTCPeerConnection - let peer_connection = Arc::new(api.new_peer_connection(config).await?); - - // Create a datachannel with label 'data' - let data_channel = peer_connection.create_data_channel("data", None).await?; - - let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - - if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. - println!("Peer Connection has gone to failed exiting"); - let _ = done_tx.try_send(()); - } - - Box::pin(async {}) - })); - - // Register channel opening handling - let d = Arc::clone(&data_channel); - data_channel.on_open(Box::new(move || { - println!("Data channel '{}'-'{}' open.", d.label(), d.id()); - - let d2 = Arc::clone(&d); - Box::pin(async move { - let raw = match d2.detach().await { - Ok(raw) => raw, - Err(err) => { - println!("data channel detach got err: {err}"); - return; - } - }; - - // Handle reading from the data channel - let r = Arc::clone(&raw); - tokio::spawn(async move { - let _ = read_loop(r).await; - }); - - // Handle writing to the data channel - tokio::spawn(async move { - let _ = write_loop(raw).await; - }); - }) - })); - - // Create an offer to send to the browser - let offer = peer_connection.create_offer(None).await?; - - // Create channel that is blocked until ICE Gathering is complete - let mut gather_complete = peer_connection.gathering_complete_promise().await; - - // Sets the LocalDescription, and starts our UDP listeners - peer_connection.set_local_description(offer).await?; - - // Block until ICE Gathering is complete, disabling trickle ICE - // we do this because we only can exchange one signaling message - // in a production application you should exchange ICE Candidates via OnICECandidate - let _ = gather_complete.recv().await; - - // Output the offer in base64 so we can paste it in browser - if let Some(local_desc) = peer_connection.local_description().await { - let json_str = serde_json::to_string(&local_desc)?; - let b64 = signal::encode(&json_str); - println!("{b64}"); - } else { - println!("generate local_description failed!"); - } - - // Wait for the answer to be pasted - let line = signal::must_read_stdin()?; - let desc_data = signal::decode(line.as_str())?; - let answer = serde_json::from_str::(&desc_data)?; - - // Apply the answer as the remote description - peer_connection.set_remote_description(answer).await?; - - println!("Press ctrl-c to stop"); - tokio::select! 
{ - _ = done_rx.recv() => { - println!("received done signal!"); - } - _ = tokio::signal::ctrl_c() => { - println!(); - } - }; - - peer_connection.close().await?; - - Ok(()) -} - -// read_loop shows how to read from the datachannel directly -async fn read_loop(d: Arc) -> Result<()> { - let mut buffer = vec![0u8; MESSAGE_SIZE]; - loop { - let n = match d.read(&mut buffer).await { - Ok(n) => n, - Err(err) => { - println!("Datachannel closed; Exit the read_loop: {err}"); - return Ok(()); - } - }; - - println!( - "Message from DataChannel: {}", - String::from_utf8(buffer[..n].to_vec())? - ); - } -} - -// write_loop shows how to write to the datachannel directly -async fn write_loop(d: Arc) -> Result<()> { - let mut result = Result::::Ok(0); - while result.is_ok() { - let timeout = tokio::time::sleep(Duration::from_secs(5)); - tokio::pin!(timeout); - - tokio::select! { - _ = timeout.as_mut() =>{ - let message = math_rand_alpha(15); - println!("Sending '{message}'"); - result = d.write(&Bytes::from(message)).await.map_err(Into::into); - } - }; - } - - Ok(()) -} diff --git a/examples/examples/data-channels-detach/README.md b/examples/examples/data-channels-detach/README.md deleted file mode 100644 index a181a9737..000000000 --- a/examples/examples/data-channels-detach/README.md +++ /dev/null @@ -1,42 +0,0 @@ -# data-channels - -data-channels is a WebRTC.rs application that shows how you can send/recv DataChannel messages from a web browser - -## Instructions - -### Build data-channels-detach - -```shell -cargo build --example data-channels-detach -``` - -### Open data-channels-detach example page - -[jsfiddle.net](https://jsfiddle.net/9tsx15mg/90/) - -### Run data-channels-detach, with your browsers SessionDescription as stdin - -In the jsfiddle the top textarea is your browser's session description, copy that and: - -#### Linux/macOS - -Run `echo $BROWSER_SDP | ./target/debug/examples/data-channels-detach` - -#### Windows - -1. Paste the SessionDescription into a file. -1. Run `./target/debug/examples/data-channels-detach < my_file` - -### Input data-channels-detach's SessionDescription into your browser - -Copy the text that `data-channels` just emitted and copy into second text area - -### Hit 'Start Session' in jsfiddle - -Under Start Session you should see 'Checking' as it starts connecting. If everything worked you should see `New DataChannel foo 1` - -Now you can put whatever you want in the `Message` textarea, and when you hit `Send Message` it should appear in your terminal! - -WebRTC.rs will send random messages every 5 seconds that will appear in your browser. - -Congrats, you have used WebRTC.rs! 
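Every stdin-driven example above shares the same copy/paste signaling round trip: read one base64 line, decode it to JSON, apply it as the remote description, then print the local description (after ICE gathering) as base64 for pasting back into the browser. A condensed sketch of that round trip follows, reusing the examples' own `signal` helper crate exactly as the deleted sources do; the function name `answer_pasted_offer` is illustrative.

```rust
use std::sync::Arc;

use anyhow::Result;
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
use webrtc::peer_connection::RTCPeerConnection;

/// Answer a pasted, base64-encoded browser offer and print the answer blob.
async fn answer_pasted_offer(pc: Arc<RTCPeerConnection>) -> Result<()> {
    // Read one line from stdin and undo the base64 + JSON encoding.
    let line = signal::must_read_stdin()?;
    let desc_data = signal::decode(line.as_str())?;
    let offer = serde_json::from_str::<RTCSessionDescription>(&desc_data)?;

    pc.set_remote_description(offer).await?;
    let answer = pc.create_answer(None).await?;

    // Block until ICE gathering finishes so the single printed blob carries
    // every candidate (these examples do not trickle ICE).
    let mut gather_complete = pc.gathering_complete_promise().await;
    pc.set_local_description(answer).await?;
    let _ = gather_complete.recv().await;

    if let Some(local_desc) = pc.local_description().await {
        let json_str = serde_json::to_string(&local_desc)?;
        println!("{}", signal::encode(&json_str));
    }
    Ok(())
}
```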
diff --git a/examples/examples/data-channels-detach/data-channels-detach.rs b/examples/examples/data-channels-detach/data-channels-detach.rs deleted file mode 100644 index 7c90a294a..000000000 --- a/examples/examples/data-channels-detach/data-channels-detach.rs +++ /dev/null @@ -1,249 +0,0 @@ -use std::io::Write; -use std::sync::Arc; - -use anyhow::Result; -use bytes::Bytes; -use clap::{AppSettings, Arg, Command}; -use tokio::time::Duration; -use webrtc::api::interceptor_registry::register_default_interceptors; -use webrtc::api::media_engine::MediaEngine; -use webrtc::api::setting_engine::SettingEngine; -use webrtc::api::APIBuilder; -use webrtc::data_channel::RTCDataChannel; -use webrtc::ice_transport::ice_server::RTCIceServer; -use webrtc::interceptor::registry::Registry; -use webrtc::peer_connection::configuration::RTCConfiguration; -use webrtc::peer_connection::math_rand_alpha; -use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; -use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; - -const MESSAGE_SIZE: usize = 1500; - -#[tokio::main] -async fn main() -> Result<()> { - let mut app = Command::new("data-channels-detach") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of Data-Channels-Detach.") - .setting(AppSettings::DeriveDisplayOrder) - .subcommand_negates_reqs(true) - .arg( - Arg::new("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::new("debug") - .long("debug") - .short('d') - .help("Prints debug log information"), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let debug = matches.is_present("debug"); - if debug { - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init(); - } - - // Everything below is the WebRTC-rs API! Thanks for using it โค๏ธ. - - // Create a MediaEngine object to configure the supported codec - let mut m = MediaEngine::default(); - - // Register default codecs - m.register_default_codecs()?; - - // Create a InterceptorRegistry. This is the user configurable RTP/RTCP Pipeline. - // This provides NACKs, RTCP Reports and other features. If you use `webrtc.NewPeerConnection` - // this is enabled by default. If you are manually managing You MUST create a InterceptorRegistry - // for each PeerConnection. - let mut registry = Registry::new(); - - // Use the default set of Interceptors - registry = register_default_interceptors(registry, &mut m)?; - - // Since this behavior diverges from the WebRTC API it has to be - // enabled using a settings engine. Mixing both detached and the - // OnMessage DataChannel API is not supported. 
- - // Create a SettingEngine and enable Detach - let mut s = SettingEngine::default(); - s.detach_data_channels(); - - // Create the API object with the MediaEngine - let api = APIBuilder::new() - .with_media_engine(m) - .with_interceptor_registry(registry) - .with_setting_engine(s) - .build(); - - // Prepare the configuration - let config = RTCConfiguration { - ice_servers: vec![RTCIceServer { - urls: vec!["stun:stun.l.google.com:19302".to_owned()], - ..Default::default() - }], - ..Default::default() - }; - - // Create a new RTCPeerConnection - let peer_connection = Arc::new(api.new_peer_connection(config).await?); - - let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - - if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. - println!("Peer Connection has gone to failed exiting"); - let _ = done_tx.try_send(()); - } - - Box::pin(async {}) - })); - - // Register data channel creation handling - peer_connection.on_data_channel(Box::new(move |d: Arc| { - let d_label = d.label().to_owned(); - let d_id = d.id(); - println!("New DataChannel {d_label} {d_id}"); - - // Register channel opening handling - Box::pin(async move { - let d2 = Arc::clone(&d); - let d_label2 = d_label.clone(); - let d_id2 = d_id; - d.on_open(Box::new(move || { - println!("Data channel '{d_label2}'-'{d_id2}' open."); - - Box::pin(async move { - let raw = match d2.detach().await { - Ok(raw) => raw, - Err(err) => { - println!("data channel detach got err: {err}"); - return; - } - }; - - // Handle reading from the data channel - let r = Arc::clone(&raw); - tokio::spawn(async move { - let _ = read_loop(r).await; - }); - - // Handle writing to the data channel - tokio::spawn(async move { - let _ = write_loop(raw).await; - }); - }) - })); - }) - })); - - // Wait for the offer to be pasted - let line = signal::must_read_stdin()?; - let desc_data = signal::decode(line.as_str())?; - let offer = serde_json::from_str::(&desc_data)?; - - // Set the remote SessionDescription - peer_connection.set_remote_description(offer).await?; - - // Create an answer - let answer = peer_connection.create_answer(None).await?; - - // Create channel that is blocked until ICE Gathering is complete - let mut gather_complete = peer_connection.gathering_complete_promise().await; - - // Sets the LocalDescription, and starts our UDP listeners - peer_connection.set_local_description(answer).await?; - - // Block until ICE Gathering is complete, disabling trickle ICE - // we do this because we only can exchange one signaling message - // in a production application you should exchange ICE Candidates via OnICECandidate - let _ = gather_complete.recv().await; - - // Output the answer in base64 so we can paste it in browser - if let Some(local_desc) = peer_connection.local_description().await { - let json_str = serde_json::to_string(&local_desc)?; - let b64 = signal::encode(&json_str); - println!("{b64}"); - } else { - println!("generate local_description failed!"); - } - - 
println!("Press ctrl-c to stop"); - tokio::select! { - _ = done_rx.recv() => { - println!("received done signal!"); - } - _ = tokio::signal::ctrl_c() => { - println!(); - } - }; - - peer_connection.close().await?; - - Ok(()) -} - -// read_loop shows how to read from the datachannel directly -async fn read_loop(d: Arc) -> Result<()> { - let mut buffer = vec![0u8; MESSAGE_SIZE]; - loop { - let n = match d.read(&mut buffer).await { - Ok(n) => n, - Err(err) => { - println!("Datachannel closed; Exit the read_loop: {err}"); - return Ok(()); - } - }; - - println!( - "Message from DataChannel: {}", - String::from_utf8(buffer[..n].to_vec())? - ); - } -} - -// write_loop shows how to write to the datachannel directly -async fn write_loop(d: Arc) -> Result<()> { - let mut result = Result::::Ok(0); - while result.is_ok() { - let timeout = tokio::time::sleep(Duration::from_secs(5)); - tokio::pin!(timeout); - - tokio::select! { - _ = timeout.as_mut() =>{ - let message = math_rand_alpha(15); - println!("Sending '{message}'"); - result = d.write(&Bytes::from(message)).await.map_err(Into::into); - } - }; - } - - Ok(()) -} diff --git a/examples/examples/data-channels-flow-control/README.md b/examples/examples/data-channels-flow-control/README.md deleted file mode 100644 index d9111eee6..000000000 --- a/examples/examples/data-channels-flow-control/README.md +++ /dev/null @@ -1,64 +0,0 @@ -# data-channels-flow-control - -This example demonstrates how to use the following property / methods. - -* pub async fn buffered_amount(&self) -> usize -* pub async fn set_buffered_amount_low_threshold(&self, th: usize) -* pub async fn buffered_amount_low_threshold(&self) -> usize -* pub async fn on_buffered_amount_low(&self, f: OnBufferedAmountLowFn) - -These methods are equivalent to that of JavaScript WebRTC API. -See for more details. - -## When do we need it? - -Send or SendText methods are called on DataChannel to send data to the connected peer. -The methods return immediately, but it does not mean the data was actually sent onto -the wire. Instead, it is queued in a buffer until it actually gets sent out to the wire. - -When you have a large amount of data to send, it is an application's responsibility to -control the buffered amount in order not to indefinitely grow the buffer size to eventually -exhaust the memory. - -The rate you wish to send data might be much higher than the rate the data channel can -actually send to the peer over the Internet. The above properties/methods help your -application to pace the amount of data to be pushed into the data channel. - -## How to run the example code - -The demo code implements two endpoints (requester and responder) in it. - -```plain - signaling messages - +----------------------------------------+ - | | - v v - +---------------+ +---------------+ - | | data | | - | requester |----------------------->| responder | - |:PeerConnection| |:PeerConnection| - +---------------+ +---------------+ -``` - -First requester and responder will exchange signaling message to establish a peer-to-peer -connection, and data channel (label: "data"). - -Once the data channel is successfully opened, requester will start sending a series of -1024-byte packets to responder, until you kill the process by Ctrl+ะก. 
- -Here's how to run the code: - -```shell -$ cargo run --release --example data-channels-flow-control - Finished release [optimized] target(s) in 0.36s - Running `target\release\examples\data-channels-flow-control.exe` - -Throughput is about 127.060 Mbps -Throughput is about 122.091 Mbps -Throughput is about 120.630 Mbps -Throughput is about 120.105 Mbps -Throughput is about 119.873 Mbps -Throughput is about 118.890 Mbps -Throughput is about 118.525 Mbps -Throughput is about 118.614 Mbps -``` diff --git a/examples/examples/data-channels-flow-control/data-channels-flow-control.rs b/examples/examples/data-channels-flow-control/data-channels-flow-control.rs deleted file mode 100644 index ba1999d79..000000000 --- a/examples/examples/data-channels-flow-control/data-channels-flow-control.rs +++ /dev/null @@ -1,245 +0,0 @@ -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; -use std::time::{Duration, SystemTime}; - -use bytes::Bytes; -use webrtc::api::interceptor_registry::register_default_interceptors; -use webrtc::api::media_engine::MediaEngine; -use webrtc::api::APIBuilder; -use webrtc::data_channel::data_channel_init::RTCDataChannelInit; -use webrtc::ice_transport::ice_candidate::RTCIceCandidate; -use webrtc::ice_transport::ice_server::RTCIceServer; -use webrtc::interceptor::registry::Registry; -use webrtc::peer_connection::configuration::RTCConfiguration; -use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; -use webrtc::peer_connection::RTCPeerConnection; - -const BUFFERED_AMOUNT_LOW_THRESHOLD: usize = 512 * 1024; // 512 KB -const MAX_BUFFERED_AMOUNT: usize = 1024 * 1024; // 1 MB - -async fn create_peer_connection() -> anyhow::Result { - // Create unique MediaEngine, - // as MediaEngine must not be shared between PeerConnections - let mut media_engine = MediaEngine::default(); - - media_engine.register_default_codecs()?; - - let mut interceptor_registry = Registry::new(); - - interceptor_registry = register_default_interceptors(interceptor_registry, &mut media_engine)?; - - // Create API that bundles the global functions of the WebRTC API - let api = APIBuilder::new() - .with_media_engine(media_engine) - .with_interceptor_registry(interceptor_registry) - .build(); - - let ice_servers = vec![RTCIceServer { - ..Default::default() - }]; - - let config = RTCConfiguration { - ice_servers, - ..Default::default() - }; - - Ok(api.new_peer_connection(config).await?) 
-} - -async fn create_requester() -> anyhow::Result { - // Create a peer connection first - let pc = create_peer_connection().await?; - - // Data transmission requires a data channel, so prepare to create one - let options = Some(RTCDataChannelInit { - ordered: Some(false), - max_retransmits: Some(0u16), - ..Default::default() - }); - - // Create a data channel to send data over a peer connection - let dc = pc.create_data_channel("data", options).await?; - - // Use mpsc channel to send and receive a signal when more data can be sent - let (more_can_be_sent, mut maybe_more_can_be_sent) = tokio::sync::mpsc::channel(1); - - // Get a shared pointer to the data channel - let shared_dc = dc.clone(); - dc.on_open(Box::new(|| { - Box::pin(async move { - // This callback shouldn't be blocked for a long time, so we spawn our handler - tokio::spawn(async move { - let buf = Bytes::from_static(&[0u8; 1024]); - - loop { - if shared_dc.send(&buf).await.is_err() { - break; - } - - let buffered_amount = shared_dc.buffered_amount().await; - - if buffered_amount + buf.len() > MAX_BUFFERED_AMOUNT { - // Wait for the signal that more can be sent - let _ = maybe_more_can_be_sent.recv().await; - } - } - }); - }) - })); - - dc.set_buffered_amount_low_threshold(BUFFERED_AMOUNT_LOW_THRESHOLD) - .await; - - dc.on_buffered_amount_low(Box::new(move || { - let more_can_be_sent = more_can_be_sent.clone(); - - Box::pin(async move { - // Send a signal that more can be sent - more_can_be_sent.send(()).await.unwrap(); - }) - })) - .await; - - Ok(pc) -} - -async fn create_responder() -> anyhow::Result { - // Create a peer connection first - let pc = create_peer_connection().await?; - - // Set a data channel handler so that we can receive data - pc.on_data_channel(Box::new(move |dc| { - Box::pin(async move { - let total_bytes_received = Arc::new(AtomicUsize::new(0)); - - let shared_total_bytes_received = total_bytes_received.clone(); - dc.on_open(Box::new(move || { - Box::pin(async { - // This callback shouldn't be blocked for a long time, so we spawn our handler - tokio::spawn(async move { - let start = SystemTime::now(); - - tokio::time::sleep(Duration::from_secs(1)).await; - println!(); - - loop { - let total_bytes_received = - shared_total_bytes_received.load(Ordering::Relaxed); - - let elapsed = SystemTime::now().duration_since(start); - let bps = - (total_bytes_received * 8) as f64 / elapsed.unwrap().as_secs_f64(); - - println!( - "Throughput is about {:.03} Mbps", - bps / (1024 * 1024) as f64 - ); - tokio::time::sleep(Duration::from_secs(1)).await; - } - }); - }) - })); - - dc.on_message(Box::new(move |msg| { - let total_bytes_received = total_bytes_received.clone(); - - Box::pin(async move { - total_bytes_received.fetch_add(msg.data.len(), Ordering::Relaxed); - }) - })); - }) - })); - - Ok(pc) -} - -#[tokio::main] -async fn main() -> anyhow::Result<()> { - env_logger::init(); - - let requester = Arc::new(create_requester().await?); - let responder = Arc::new(create_responder().await?); - - let maybe_requester = Arc::downgrade(&requester); - responder.on_ice_candidate(Box::new(move |candidate: Option| { - let maybe_requester = maybe_requester.clone(); - - Box::pin(async move { - if let Some(candidate) = candidate { - if let Ok(candidate) = candidate.to_json() { - if let Some(requester) = maybe_requester.upgrade() { - if let Err(err) = requester.add_ice_candidate(candidate).await { - log::warn!("{}", err); - } - } - } - } - }) - })); - - let maybe_responder = Arc::downgrade(&responder); - 
requester.on_ice_candidate(Box::new(move |candidate: Option| { - let maybe_responder = maybe_responder.clone(); - - Box::pin(async move { - if let Some(candidate) = candidate { - if let Ok(candidate) = candidate.to_json() { - if let Some(responder) = maybe_responder.upgrade() { - if let Err(err) = responder.add_ice_candidate(candidate).await { - log::warn!("{}", err); - } - } - } - } - }) - })); - - let (fault, mut reqs_fault) = tokio::sync::mpsc::channel(1); - requester.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - let fault = fault.clone(); - - Box::pin(async move { - if s == RTCPeerConnectionState::Failed { - fault.send(()).await.unwrap(); - } - }) - })); - - let (fault, mut resp_fault) = tokio::sync::mpsc::channel(1); - responder.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - let fault = fault.clone(); - - Box::pin(async move { - if s == RTCPeerConnectionState::Failed { - fault.send(()).await.unwrap(); - } - }) - })); - - let reqs = requester.create_offer(None).await?; - - requester.set_local_description(reqs.clone()).await?; - responder.set_remote_description(reqs).await?; - - let resp = responder.create_answer(None).await?; - - responder.set_local_description(resp.clone()).await?; - requester.set_remote_description(resp).await?; - - tokio::select! { - _ = tokio::signal::ctrl_c() => {} - _ = reqs_fault.recv() => { - log::error!("Requester's peer connection failed...") - } - _ = resp_fault.recv() => { - log::error!("Responder's peer connection failed..."); - } - } - - requester.close().await?; - responder.close().await?; - - println!(); - - Ok(()) -} diff --git a/examples/examples/data-channels/README.md b/examples/examples/data-channels/README.md deleted file mode 100644 index 6b4019cb9..000000000 --- a/examples/examples/data-channels/README.md +++ /dev/null @@ -1,42 +0,0 @@ -# data-channels - -data-channels is a WebRTC.rs application that shows how you can send/recv DataChannel messages from a web browser - -## Instructions - -### Build data-channels - -```shell -cargo build --example data-channels -``` - -### Open data-channels example page - -[jsfiddle.net](https://jsfiddle.net/9tsx15mg/90/) - -### Run data-channels, with your browsers SessionDescription as stdin - -In the jsfiddle the top textarea is your browser's session description, copy that and: - -#### Linux/macOS - -Run `echo $BROWSER_SDP | ./target/debug/examples/data-channels` - -#### Windows - -1. Paste the SessionDescription into a file. -1. Run `./target/debug/examples/data-channels < my_file` - -### Input data-channels's SessionDescription into your browser - -Copy the text that `data-channels` just emitted and copy into second text area - -### Hit 'Start Session' in jsfiddle - -Under Start Session you should see 'Checking' as it starts connecting. If everything worked you should see `New DataChannel foo 1` - -Now you can put whatever you want in the `Message` textarea, and when you hit `Send Message` it should appear in your terminal! - -WebRTC.rs will send random messages every 5 seconds that will appear in your browser. - -Congrats, you have used WebRTC.rs! 
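The "random messages every 5 seconds" the README mentions come from a small loop the example installs in `on_open`: a 5-second sleep races a shutdown signal, and each tick sends a fresh random string until `send_text` errors. A trimmed sketch of that loop; the standalone function shape and the `done_rx` parameter are illustrative, while the individual calls follow the deleted source.

```rust
use std::sync::Arc;

use tokio::sync::mpsc::Receiver;
use tokio::time::Duration;
use webrtc::data_channel::RTCDataChannel;
use webrtc::peer_connection::math_rand_alpha;

/// Send a random 15-character message every 5 seconds until the channel
/// errors out or a shutdown signal arrives.
async fn random_message_loop(dc: Arc<RTCDataChannel>, mut done_rx: Receiver<()>) {
    let mut result = Result::<usize, webrtc::Error>::Ok(0);
    while result.is_ok() {
        let timeout = tokio::time::sleep(Duration::from_secs(5));
        tokio::pin!(timeout);

        tokio::select! {
            _ = done_rx.recv() => {
                break;
            }
            _ = timeout.as_mut() => {
                let message = math_rand_alpha(15);
                println!("Sending '{message}'");
                result = dc.send_text(message).await;
            }
        }
    }
}
```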
diff --git a/examples/examples/data-channels/data-channels.rs b/examples/examples/data-channels/data-channels.rs deleted file mode 100644 index 708f20342..000000000 --- a/examples/examples/data-channels/data-channels.rs +++ /dev/null @@ -1,207 +0,0 @@ -use std::io::Write; -use std::sync::Arc; - -use anyhow::Result; -use clap::{AppSettings, Arg, Command}; -use tokio::time::Duration; -use webrtc::api::interceptor_registry::register_default_interceptors; -use webrtc::api::media_engine::MediaEngine; -use webrtc::api::APIBuilder; -use webrtc::data_channel::data_channel_message::DataChannelMessage; -use webrtc::data_channel::RTCDataChannel; -use webrtc::ice_transport::ice_server::RTCIceServer; -use webrtc::interceptor::registry::Registry; -use webrtc::peer_connection::configuration::RTCConfiguration; -use webrtc::peer_connection::math_rand_alpha; -use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; -use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; - -#[tokio::main] -async fn main() -> Result<()> { - let mut app = Command::new("data-channels") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of Data-Channels.") - .setting(AppSettings::DeriveDisplayOrder) - .subcommand_negates_reqs(true) - .arg( - Arg::new("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::new("debug") - .long("debug") - .short('d') - .help("Prints debug log information"), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let debug = matches.is_present("debug"); - if debug { - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init(); - } - - // Everything below is the WebRTC-rs API! Thanks for using it โค๏ธ. - - // Create a MediaEngine object to configure the supported codec - let mut m = MediaEngine::default(); - - // Register default codecs - m.register_default_codecs()?; - - // Create a InterceptorRegistry. This is the user configurable RTP/RTCP Pipeline. - // This provides NACKs, RTCP Reports and other features. If you use `webrtc.NewPeerConnection` - // this is enabled by default. If you are manually managing You MUST create a InterceptorRegistry - // for each PeerConnection. 
- let mut registry = Registry::new(); - - // Use the default set of Interceptors - registry = register_default_interceptors(registry, &mut m)?; - - // Create the API object with the MediaEngine - let api = APIBuilder::new() - .with_media_engine(m) - .with_interceptor_registry(registry) - .build(); - - // Prepare the configuration - let config = RTCConfiguration { - ice_servers: vec![RTCIceServer { - urls: vec!["stun:stun.l.google.com:19302".to_owned()], - ..Default::default() - }], - ..Default::default() - }; - - // Create a new RTCPeerConnection - let peer_connection = Arc::new(api.new_peer_connection(config).await?); - - let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - - if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. - println!("Peer Connection has gone to failed exiting"); - let _ = done_tx.try_send(()); - } - - Box::pin(async {}) - })); - - // Register data channel creation handling - peer_connection - .on_data_channel(Box::new(move |d: Arc| { - let d_label = d.label().to_owned(); - let d_id = d.id(); - println!("New DataChannel {d_label} {d_id}"); - - // Register channel opening handling - Box::pin(async move { - let d2 = Arc::clone(&d); - let d_label2 = d_label.clone(); - let d_id2 = d_id; - d.on_close(Box::new(move || { - println!("Data channel closed"); - Box::pin(async {}) - })); - - d.on_open(Box::new(move || { - println!("Data channel '{d_label2}'-'{d_id2}' open. Random messages will now be sent to any connected DataChannels every 5 seconds"); - - Box::pin(async move { - let mut result = Result::::Ok(0); - while result.is_ok() { - let timeout = tokio::time::sleep(Duration::from_secs(5)); - tokio::pin!(timeout); - - tokio::select! 
{ - _ = timeout.as_mut() =>{ - let message = math_rand_alpha(15); - println!("Sending '{message}'"); - result = d2.send_text(message).await.map_err(Into::into); - } - }; - } - }) - })); - - // Register text message handling - d.on_message(Box::new(move |msg: DataChannelMessage| { - let msg_str = String::from_utf8(msg.data.to_vec()).unwrap(); - println!("Message from DataChannel '{d_label}': '{msg_str}'"); - Box::pin(async {}) - })); - }) - })); - - // Wait for the offer to be pasted - let line = signal::must_read_stdin()?; - let desc_data = signal::decode(line.as_str())?; - let offer = serde_json::from_str::(&desc_data)?; - - // Set the remote SessionDescription - peer_connection.set_remote_description(offer).await?; - - // Create an answer - let answer = peer_connection.create_answer(None).await?; - - // Create channel that is blocked until ICE Gathering is complete - let mut gather_complete = peer_connection.gathering_complete_promise().await; - - // Sets the LocalDescription, and starts our UDP listeners - peer_connection.set_local_description(answer).await?; - - // Block until ICE Gathering is complete, disabling trickle ICE - // we do this because we only can exchange one signaling message - // in a production application you should exchange ICE Candidates via OnICECandidate - let _ = gather_complete.recv().await; - - // Output the answer in base64 so we can paste it in browser - if let Some(local_desc) = peer_connection.local_description().await { - let json_str = serde_json::to_string(&local_desc)?; - let b64 = signal::encode(&json_str); - println!("{b64}"); - } else { - println!("generate local_description failed!"); - } - - println!("Press ctrl-c to stop"); - tokio::select! { - _ = done_rx.recv() => { - println!("received done signal!"); - } - _ = tokio::signal::ctrl_c() => { - println!(); - } - }; - - peer_connection.close().await?; - - Ok(()) -} diff --git a/examples/examples/ice-restart/README.md b/examples/examples/ice-restart/README.md deleted file mode 100644 index c3d707b7f..000000000 --- a/examples/examples/ice-restart/README.md +++ /dev/null @@ -1,25 +0,0 @@ -# ice-restart -ice-restart demonstrates WebRTC.rs ICE Restart abilities. - -## Instructions - -### Build ice-restart -```shell -cargo build --example ice-restart -``` - -### Run ice-restart -```shell -cargo run --example ice-restart -``` - -### Open the Web UI -Open [http://localhost:8080](http://localhost:8080). This will automatically start a PeerConnection. This page will now prints stats about the PeerConnection -and allow you to do an ICE Restart at anytime. - -* `ICE Restart` is the button that causes a new offer to be made with `iceRestart: true`. -* `ICE Connection States` will contain all the connection states the PeerConnection moves through. -* `ICE Selected Pairs` will print the selected pair every 3 seconds. Note how the uFrag/uPwd/Port change everytime you start the Restart process. -* `Inbound DataChannel Messages` containing the current time sent by the Pion process every 3 seconds. - -Congrats, you have used WebRTC.rs! 
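In this demo the restart is requested from the browser (a new offer created with `iceRestart: true`), and the Rust side simply answers it through `/doSignaling`. For comparison, here is a hypothetical sketch of triggering the same thing from the Rust side, assuming webrtc-rs exposes the browser flag as `RTCOfferOptions { ice_restart: true, .. }` in `webrtc::peer_connection::offer_answer_options`; verify the exact options type against the crate version in use, as this snippet is not part of the deleted example.

```rust
use std::sync::Arc;

use anyhow::Result;
use webrtc::peer_connection::offer_answer_options::RTCOfferOptions;
use webrtc::peer_connection::RTCPeerConnection;

/// Hypothetical Rust-side ICE restart: re-offer with the ice_restart flag
/// set, then run the usual offer/answer exchange with the remote peer.
async fn restart_ice(pc: Arc<RTCPeerConnection>) -> Result<()> {
    let offer = pc
        .create_offer(Some(RTCOfferOptions {
            ice_restart: true, // equivalent to { iceRestart: true } in the browser
            ..Default::default()
        }))
        .await?;
    pc.set_local_description(offer).await?;
    // ...send the new offer to the peer and apply its answer as usual.
    Ok(())
}
```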
diff --git a/examples/examples/ice-restart/ice-restart.rs b/examples/examples/ice-restart/ice-restart.rs deleted file mode 100644 index 11fe51d28..000000000 --- a/examples/examples/ice-restart/ice-restart.rs +++ /dev/null @@ -1,258 +0,0 @@ -use std::io::Write; -use std::net::SocketAddr; -use std::str::FromStr; -use std::sync::Arc; - -use anyhow::Result; -use clap::{AppSettings, Arg, Command}; -use hyper::service::{make_service_fn, service_fn}; -use hyper::{Body, Method, Request, Response, Server, StatusCode}; -use tokio::sync::Mutex; -use tokio::time::Duration; -use tokio_util::codec::{BytesCodec, FramedRead}; -use webrtc::api::interceptor_registry::register_default_interceptors; -use webrtc::api::media_engine::MediaEngine; -use webrtc::api::APIBuilder; -use webrtc::data_channel::RTCDataChannel; -use webrtc::ice_transport::ice_connection_state::RTCIceConnectionState; -use webrtc::interceptor::registry::Registry; -use webrtc::peer_connection::configuration::RTCConfiguration; -use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; -use webrtc::peer_connection::RTCPeerConnection; - -#[macro_use] -extern crate lazy_static; - -lazy_static! { - static ref PEER_CONNECTION_MUTEX: Arc>>> = - Arc::new(Mutex::new(None)); -} - -static INDEX: &str = "examples/examples/ice-restart/index.html"; -static NOTFOUND: &[u8] = b"Not Found"; - -/// HTTP status code 404 -fn not_found() -> Response { - Response::builder() - .status(StatusCode::NOT_FOUND) - .body(NOTFOUND.into()) - .unwrap() -} - -async fn simple_file_send(filename: &str) -> Result, hyper::Error> { - // Serve a file by asynchronously reading it by chunks using tokio-util crate. - - if let Ok(file) = tokio::fs::File::open(filename).await { - let stream = FramedRead::new(file, BytesCodec::new()); - let body = Body::wrap_stream(stream); - return Ok(Response::new(body)); - } - - Ok(not_found()) -} - -// HTTP Listener to get ICE Credentials/Candidate from remote Peer -async fn remote_handler(req: Request) -> Result, hyper::Error> { - match (req.method(), req.uri().path()) { - (&Method::GET, "/") | (&Method::GET, "/index.html") => simple_file_send(INDEX).await, - - (&Method::POST, "/doSignaling") => do_signaling(req).await, - - // Return the 404 Not Found for other routes. - _ => { - let mut not_found = Response::default(); - *not_found.status_mut() = StatusCode::NOT_FOUND; - Ok(not_found) - } - } -} - -// do_signaling exchanges all state of the local PeerConnection and is called -// every time a video is added or removed -async fn do_signaling(req: Request) -> Result, hyper::Error> { - let pc = { - let mut peer_connection = PEER_CONNECTION_MUTEX.lock().await; - if let Some(pc) = &*peer_connection { - Arc::clone(pc) - } else { - // Create a MediaEngine object to configure the supported codec - let mut m = MediaEngine::default(); - - match m.register_default_codecs() { - Ok(_) => {} - Err(err) => panic!("{}", err), - }; - - // Create a InterceptorRegistry. This is the user configurable RTP/RTCP Pipeline. - // This provides NACKs, RTCP Reports and other features. If you use `webrtc.NewPeerConnection` - // this is enabled by default. If you are manually managing You MUST create a InterceptorRegistry - // for each PeerConnection. 
- let mut registry = Registry::new(); - - // Use the default set of Interceptors - registry = match register_default_interceptors(registry, &mut m) { - Ok(r) => r, - Err(err) => panic!("{}", err), - }; - - // Create the API object with the MediaEngine - let api = APIBuilder::new() - .with_media_engine(m) - .with_interceptor_registry(registry) - .build(); - - // Create a new RTCPeerConnection - let pc = match api.new_peer_connection(RTCConfiguration::default()).await { - Ok(p) => p, - Err(err) => panic!("{}", err), - }; - let pc = Arc::new(pc); - - // Set the handler for ICE connection state - // This will notify you when the peer has connected/disconnected - pc.on_ice_connection_state_change(Box::new( - |connection_state: RTCIceConnectionState| { - println!("ICE Connection State has changed: {connection_state}"); - Box::pin(async {}) - }, - )); - - // Send the current time via a DataChannel to the remote peer every 3 seconds - pc.on_data_channel(Box::new(|d: Arc| { - Box::pin(async move { - let d2 = Arc::clone(&d); - d.on_open(Box::new(move || { - Box::pin(async move { - while d2 - .send_text(format!("{:?}", tokio::time::Instant::now())) - .await - .is_ok() - { - tokio::time::sleep(Duration::from_secs(3)).await; - } - }) - })); - }) - })); - - *peer_connection = Some(Arc::clone(&pc)); - pc - } - }; - - let sdp_str = match std::str::from_utf8(&hyper::body::to_bytes(req.into_body()).await?) { - Ok(s) => s.to_owned(), - Err(err) => panic!("{}", err), - }; - let offer = match serde_json::from_str::(&sdp_str) { - Ok(s) => s, - Err(err) => panic!("{}", err), - }; - - if let Err(err) = pc.set_remote_description(offer).await { - panic!("{}", err); - } - - // Create channel that is blocked until ICE Gathering is complete - let mut gather_complete = pc.gathering_complete_promise().await; - - // Create an answer - let answer = match pc.create_answer(None).await { - Ok(answer) => answer, - Err(err) => panic!("{}", err), - }; - - // Sets the LocalDescription, and starts our UDP listeners - if let Err(err) = pc.set_local_description(answer).await { - panic!("{}", err); - } - - // Block until ICE Gathering is complete, disabling trickle ICE - // we do this because we only can exchange one signaling message - // in a production application you should exchange ICE Candidates via OnICECandidate - let _ = gather_complete.recv().await; - - let payload = if let Some(local_desc) = pc.local_description().await { - match serde_json::to_string(&local_desc) { - Ok(p) => p, - Err(err) => panic!("{}", err), - } - } else { - panic!("generate local_description failed!"); - }; - - let mut response = match Response::builder() - .header("content-type", "application/json") - .body(Body::from(payload)) - { - Ok(res) => res, - Err(err) => panic!("{}", err), - }; - - *response.status_mut() = StatusCode::OK; - Ok(response) -} - -#[tokio::main] -async fn main() -> Result<()> { - let mut app = Command::new("ice-restart") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of ice-restart.") - .setting(AppSettings::DeriveDisplayOrder) - .subcommand_negates_reqs(true) - .arg( - Arg::new("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::new("debug") - .long("debug") - .short('d') - .help("Prints debug log information"), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let debug = matches.is_present("debug"); - if debug { - env_logger::Builder::new() - .format(|buf, 
record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init(); - } - - tokio::spawn(async move { - println!("Open http://localhost:8080 to access this demo"); - - let addr = SocketAddr::from_str("0.0.0.0:8080").unwrap(); - let service = - make_service_fn(|_| async { Ok::<_, hyper::Error>(service_fn(remote_handler)) }); - let server = Server::bind(&addr).serve(service); - // Run this server for... forever! - if let Err(e) = server.await { - eprintln!("server error: {e}"); - } - }); - - println!("Press ctrl-c to stop"); - tokio::signal::ctrl_c().await.unwrap(); - - Ok(()) -} diff --git a/examples/examples/ice-restart/index.html b/examples/examples/ice-restart/index.html deleted file mode 100644 index 99eccffe0..000000000 --- a/examples/examples/ice-restart/index.html +++ /dev/null @@ -1,82 +0,0 @@ - - - ice-restart - - - -
- [index.html markup lost in extraction; recoverable content: the page offers an "ICE Restart" button and output sections headed "ICE Connection States", "ICE Selected Pairs", and "Inbound DataChannel Messages", driven by the page's signaling script.]
- - - - diff --git a/examples/examples/insertable-streams/README.md b/examples/examples/insertable-streams/README.md deleted file mode 100644 index 55c0f9efb..000000000 --- a/examples/examples/insertable-streams/README.md +++ /dev/null @@ -1,52 +0,0 @@ -# insertable-streams - -insertable-streams demonstrates how to use insertable streams with WebRTC.rs. -This example modifies the video with a single-byte XOR cipher before sending, and then -decrypts in Javascript. - -insertable-streams allows the browser to process encoded video. You could implement -E2E encryption, add metadata or insert a completely different video feed! - -## Instructions - -### Create IVF named `output.ivf` that contains a VP8 track - -```shell -ffmpeg -i $INPUT_FILE -g 30 output.ivf -``` - -### Build insertable-streams - -```shell -cargo build --example insertable-streams -``` - -### Open insertable-streams example page - -[jsfiddle.net](https://jsfiddle.net/uqr80Lak/) you should see two text-areas and a 'Start Session' button. You will also have a 'Decrypt' checkbox. -When unchecked the browser will not decrypt the incoming video stream, so it will stop playing or display certificates. - -### Run insertable-streams with your browsers SessionDescription as stdin - -The `output.ivf` you created should be in the same directory as `insertable-streams`. In the jsfiddle the top textarea is your browser, copy that and: - -#### Linux/macOS - -Run `echo $BROWSER_SDP | ./target/debug/examples/insertable-streams` - -#### Windows - -1. Paste the SessionDescription into a file. -1. Run `./target/debug/examples/insertable-streams < my_file` - -### Input insertable-streams's SessionDescription into your browser - -Copy the text that `insertable-streams` just emitted and copy into second text area - -### Hit 'Start Session' in jsfiddle, enjoy your video! - -A video should start playing in your browser above the input boxes. `insertable-streams` will exit when the file reaches the end. - -To stop decrypting the stream uncheck the box and the video will not be viewable. - -Congrats, you have used WebRTC.rs! 
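The "encryption" in insertable-streams is deliberately trivial: each parsed IVF frame is XORed byte-by-byte with a fixed key before `write_sample`, and the jsfiddle's insertable-streams transform applies the same XOR to undo it. A minimal sketch of the sending-side transform, with the key and the `write_sample` call taken from the deleted source; the standalone function shape is illustrative.

```rust
use tokio::time::Duration;
use webrtc::media::Sample;
use webrtc::track::track_local::track_local_static_sample::TrackLocalStaticSample;

const CIPHER_KEY: u8 = 0xAA;

/// XOR every byte of an encoded frame, then hand it to the local track.
/// The receiver applies the same XOR to recover the original payload.
async fn send_xored_frame(
    video_track: &TrackLocalStaticSample,
    mut frame: bytes::BytesMut,
) -> Result<(), webrtc::Error> {
    for b in &mut frame[..] {
        *b ^= CIPHER_KEY; // single-byte XOR "cipher"
    }
    video_track
        .write_sample(&Sample {
            data: frame.freeze(),
            duration: Duration::from_secs(1),
            ..Default::default()
        })
        .await
}
```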
diff --git a/examples/examples/insertable-streams/insertable-streams.rs b/examples/examples/insertable-streams/insertable-streams.rs deleted file mode 100644 index 4e45499fa..000000000 --- a/examples/examples/insertable-streams/insertable-streams.rs +++ /dev/null @@ -1,269 +0,0 @@ -use std::fs::File; -use std::io::{BufReader, Write}; -use std::path::Path; -use std::sync::Arc; - -use anyhow::Result; -use clap::{AppSettings, Arg, Command}; -use tokio::sync::Notify; -use tokio::time::Duration; -use webrtc::api::interceptor_registry::register_default_interceptors; -use webrtc::api::media_engine::{MediaEngine, MIME_TYPE_VP8}; -use webrtc::api::APIBuilder; -use webrtc::ice_transport::ice_connection_state::RTCIceConnectionState; -use webrtc::ice_transport::ice_server::RTCIceServer; -use webrtc::interceptor::registry::Registry; -use webrtc::media::io::ivf_reader::IVFReader; -use webrtc::media::Sample; -use webrtc::peer_connection::configuration::RTCConfiguration; -use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; -use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; -use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability; -use webrtc::track::track_local::track_local_static_sample::TrackLocalStaticSample; -use webrtc::track::track_local::TrackLocal; -use webrtc::Error; - -const CIPHER_KEY: u8 = 0xAA; - -#[tokio::main] -async fn main() -> Result<()> { - let mut app = Command::new("insertable-streams") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of insertable-streams.") - .setting(AppSettings::DeriveDisplayOrder) - .subcommand_negates_reqs(true) - .arg( - Arg::new("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::new("debug") - .long("debug") - .short('d') - .help("Prints debug log information"), - ) - .arg( - Arg::new("video") - .required_unless_present("FULLHELP") - .takes_value(true) - .short('v') - .long("video") - .help("Video file to be streaming."), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let debug = matches.is_present("debug"); - if debug { - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init(); - } - - let video_file = matches.value_of("video").unwrap(); - if !Path::new(video_file).exists() { - return Err(Error::new(format!("video file: '{video_file}' not exist")).into()); - } - - // Everything below is the WebRTC-rs API! Thanks for using it โค๏ธ. - - // Create a MediaEngine object to configure the supported codec - let mut m = MediaEngine::default(); - - m.register_default_codecs()?; - - // Create a InterceptorRegistry. This is the user configurable RTP/RTCP Pipeline. - // This provides NACKs, RTCP Reports and other features. If you use `webrtc.NewPeerConnection` - // this is enabled by default. If you are manually managing You MUST create a InterceptorRegistry - // for each PeerConnection. 
- let mut registry = Registry::new(); - - // Use the default set of Interceptors - registry = register_default_interceptors(registry, &mut m)?; - - // Create the API object with the MediaEngine - let api = APIBuilder::new() - .with_media_engine(m) - .with_interceptor_registry(registry) - .build(); - - // Prepare the configuration - let config = RTCConfiguration { - ice_servers: vec![RTCIceServer { - urls: vec!["stun:stun.l.google.com:19302".to_owned()], - ..Default::default() - }], - ..Default::default() - }; - - // Create a new RTCPeerConnection - let peer_connection = Arc::new(api.new_peer_connection(config).await?); - - let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - let video_done_tx = done_tx.clone(); - - // Create a video track - let video_track = Arc::new(TrackLocalStaticSample::new( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - "video".to_owned(), - "webrtc-rs".to_owned(), - )); - - // Add this newly created track to the PeerConnection - let rtp_sender = peer_connection - .add_track(Arc::clone(&video_track) as Arc) - .await?; - - // Read incoming RTCP packets - // Before these packets are returned they are processed by interceptors. For things - // like NACK this needs to be called. - tokio::spawn(async move { - let mut rtcp_buf = vec![0u8; 1500]; - while let Ok((_, _)) = rtp_sender.read(&mut rtcp_buf).await {} - Result::<()>::Ok(()) - }); - - let notify_tx = Arc::new(Notify::new()); - let notify_video = notify_tx.clone(); - - let video_file_name = video_file.to_owned(); - tokio::spawn(async move { - // Open a IVF file and start reading using our IVFReader - let file = File::open(video_file_name)?; - let reader = BufReader::new(file); - let (mut ivf, header) = IVFReader::new(reader)?; - - // Wait for connection established - notify_video.notified().await; - - println!("play video from disk file output.ivf"); - - // Send our video file frame at a time. Pace our sending so we send it at the same speed it should be played back as. - // This isn't required since the video is timestamped, but we will such much higher loss if we send all at once. - let sleep_time = Duration::from_millis( - ((1000 * header.timebase_numerator) / header.timebase_denominator) as u64, - ); - loop { - let mut frame = match ivf.parse_next_frame() { - Ok((frame, _)) => frame, - Err(err) => { - println!("All video frames parsed and sent: {err}"); - break; - } - }; - - // Encrypt video using XOR Cipher - for b in &mut frame[..] 
{ - *b ^= CIPHER_KEY; - } - - tokio::time::sleep(sleep_time).await; - - video_track - .write_sample(&Sample { - data: frame.freeze(), - duration: Duration::from_secs(1), - ..Default::default() - }) - .await?; - } - - let _ = video_done_tx.try_send(()); - - Result::<()>::Ok(()) - }); - - // Set the handler for ICE connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_ice_connection_state_change(Box::new( - move |connection_state: RTCIceConnectionState| { - println!("Connection State has changed {connection_state}"); - if connection_state == RTCIceConnectionState::Connected { - notify_tx.notify_waiters(); - } - Box::pin(async {}) - }, - )); - - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - - if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. - println!("Peer Connection has gone to failed exiting"); - let _ = done_tx.try_send(()); - } - - Box::pin(async {}) - })); - - // Wait for the offer to be pasted - let line = signal::must_read_stdin()?; - let desc_data = signal::decode(line.as_str())?; - let offer = serde_json::from_str::(&desc_data)?; - - // Set the remote SessionDescription - peer_connection.set_remote_description(offer).await?; - - // Create an answer - let answer = peer_connection.create_answer(None).await?; - - // Create channel that is blocked until ICE Gathering is complete - let mut gather_complete = peer_connection.gathering_complete_promise().await; - - // Sets the LocalDescription, and starts our UDP listeners - peer_connection.set_local_description(answer).await?; - - // Block until ICE Gathering is complete, disabling trickle ICE - // we do this because we only can exchange one signaling message - // in a production application you should exchange ICE Candidates via OnICECandidate - let _ = gather_complete.recv().await; - - // Output the answer in base64 so we can paste it in browser - if let Some(local_desc) = peer_connection.local_description().await { - let json_str = serde_json::to_string(&local_desc)?; - let b64 = signal::encode(&json_str); - println!("{b64}"); - } else { - println!("generate local_description failed!"); - } - - println!("Press ctrl-c to stop"); - tokio::select! { - _ = done_rx.recv() => { - println!("received done signal!"); - } - _ = tokio::signal::ctrl_c() => { - println!(); - } - }; - - peer_connection.close().await?; - - Ok(()) -} diff --git a/examples/examples/offer-answer/README.md b/examples/examples/offer-answer/README.md deleted file mode 100644 index 5a95fc002..000000000 --- a/examples/examples/offer-answer/README.md +++ /dev/null @@ -1,24 +0,0 @@ -# offer-answer - -offer-answer is an example of two webrtc-rs or pion instances communicating directly! - -The SDP offer and answer are exchanged automatically over HTTP. -The `answer` side acts like a HTTP server and should therefore be ran first. 
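
For orientation, a minimal sketch of the signaling exchange, assuming hyper 0.14 with its client feature (the HTTP client these examples use). The hypothetical `post_json` helper below simply mirrors the requests that `offer.rs` and `answer.rs` make to each other's `/sdp` and `/candidate` endpoints; the full sources follow.

```rust
use hyper::{Body, Client, Method, Request};

// Hypothetical helper for illustration only; it mirrors the POSTs made in
// offer.rs / answer.rs. `addr` is the peer's HTTP address, e.g. "localhost:60000".
async fn post_json(addr: &str, path: &str, payload: String) -> Result<(), hyper::Error> {
    let req = Request::builder()
        .method(Method::POST)
        .uri(format!("http://{addr}{path}")) // path is "/sdp" or "/candidate"
        .header("content-type", "application/json; charset=utf-8")
        .body(Body::from(payload))
        .expect("valid request");
    // A plain 200 OK response is all either side expects back.
    let _resp = Client::new().request(req).await?;
    Ok(())
}
```
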
- -## Instructions - -First run `answer`: - -```shell -cargo build --example answer -./target/debug/examples/answer -``` - -Next, run `offer`: - -```shell -cargo build --example offer -./target/debug/examples/offer -``` - -You should see them connect and start to exchange messages. diff --git a/examples/examples/offer-answer/answer.rs b/examples/examples/offer-answer/answer.rs deleted file mode 100644 index 1896011ae..000000000 --- a/examples/examples/offer-answer/answer.rs +++ /dev/null @@ -1,391 +0,0 @@ -use std::io::Write; -use std::net::SocketAddr; -use std::str::FromStr; -use std::sync::Arc; - -use anyhow::Result; -use clap::{AppSettings, Arg, Command}; -use hyper::service::{make_service_fn, service_fn}; -use hyper::{Body, Client, Method, Request, Response, Server, StatusCode}; -use tokio::sync::Mutex; -use tokio::time::Duration; -use webrtc::api::interceptor_registry::register_default_interceptors; -use webrtc::api::media_engine::MediaEngine; -use webrtc::api::APIBuilder; -use webrtc::data_channel::data_channel_message::DataChannelMessage; -use webrtc::data_channel::RTCDataChannel; -use webrtc::ice_transport::ice_candidate::{RTCIceCandidate, RTCIceCandidateInit}; -use webrtc::ice_transport::ice_server::RTCIceServer; -use webrtc::interceptor::registry::Registry; -use webrtc::peer_connection::configuration::RTCConfiguration; -use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; -use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; -use webrtc::peer_connection::{math_rand_alpha, RTCPeerConnection}; - -#[macro_use] -extern crate lazy_static; - -lazy_static! { - static ref PEER_CONNECTION_MUTEX: Arc>>> = - Arc::new(Mutex::new(None)); - static ref PENDING_CANDIDATES: Arc>> = Arc::new(Mutex::new(vec![])); - static ref ADDRESS: Arc> = Arc::new(Mutex::new(String::new())); -} - -async fn signal_candidate(addr: &str, c: &RTCIceCandidate) -> Result<()> { - /*println!( - "signal_candidate Post candidate to {}", - format!("http://{}/candidate", addr) - );*/ - let payload = c.to_json()?.candidate; - let req = match Request::builder() - .method(Method::POST) - .uri(format!("http://{addr}/candidate")) - .header("content-type", "application/json; charset=utf-8") - .body(Body::from(payload)) - { - Ok(req) => req, - Err(err) => { - println!("{err}"); - return Err(err.into()); - } - }; - - let _resp = match Client::new().request(req).await { - Ok(resp) => resp, - Err(err) => { - println!("{err}"); - return Err(err.into()); - } - }; - //println!("signal_candidate Response: {}", resp.status()); - - Ok(()) -} - -// HTTP Listener to get ICE Credentials/Candidate from remote Peer -async fn remote_handler(req: Request) -> Result, hyper::Error> { - let pc = { - let pcm = PEER_CONNECTION_MUTEX.lock().await; - pcm.clone().unwrap() - }; - let addr = { - let addr = ADDRESS.lock().await; - addr.clone() - }; - - match (req.method(), req.uri().path()) { - // A HTTP handler that allows the other WebRTC-rs or Pion instance to send us ICE candidates - // This allows us to add ICE candidates faster, we don't have to wait for STUN or TURN - // candidates which may be slower - (&Method::POST, "/candidate") => { - //println!("remote_handler receive from /candidate"); - let candidate = - match std::str::from_utf8(&hyper::body::to_bytes(req.into_body()).await?) 
{ - Ok(s) => s.to_owned(), - Err(err) => panic!("{}", err), - }; - - if let Err(err) = pc - .add_ice_candidate(RTCIceCandidateInit { - candidate, - ..Default::default() - }) - .await - { - panic!("{}", err); - } - - let mut response = Response::new(Body::empty()); - *response.status_mut() = StatusCode::OK; - Ok(response) - } - - // A HTTP handler that processes a SessionDescription given to us from the other WebRTC-rs or Pion process - (&Method::POST, "/sdp") => { - //println!("remote_handler receive from /sdp"); - let sdp_str = match std::str::from_utf8(&hyper::body::to_bytes(req.into_body()).await?) - { - Ok(s) => s.to_owned(), - Err(err) => panic!("{}", err), - }; - let sdp = match serde_json::from_str::(&sdp_str) { - Ok(s) => s, - Err(err) => panic!("{}", err), - }; - - if let Err(err) = pc.set_remote_description(sdp).await { - panic!("{}", err); - } - - // Create an answer to send to the other process - let answer = match pc.create_answer(None).await { - Ok(a) => a, - Err(err) => panic!("{}", err), - }; - - /*println!( - "remote_handler Post answer to {}", - format!("http://{}/sdp", addr) - );*/ - - // Send our answer to the HTTP server listening in the other process - let payload = match serde_json::to_string(&answer) { - Ok(p) => p, - Err(err) => panic!("{}", err), - }; - - let req = match Request::builder() - .method(Method::POST) - .uri(format!("http://{addr}/sdp")) - .header("content-type", "application/json; charset=utf-8") - .body(Body::from(payload)) - { - Ok(req) => req, - Err(err) => panic!("{}", err), - }; - - let _resp = match Client::new().request(req).await { - Ok(resp) => resp, - Err(err) => { - println!("{err}"); - return Err(err); - } - }; - //println!("remote_handler Response: {}", resp.status()); - - // Sets the LocalDescription, and starts our UDP listeners - if let Err(err) = pc.set_local_description(answer).await { - panic!("{}", err); - } - - { - let cs = PENDING_CANDIDATES.lock().await; - for c in &*cs { - if let Err(err) = signal_candidate(&addr, c).await { - panic!("{}", err); - } - } - } - - let mut response = Response::new(Body::empty()); - *response.status_mut() = StatusCode::OK; - Ok(response) - } - // Return the 404 Not Found for other routes. 
- _ => { - let mut not_found = Response::default(); - *not_found.status_mut() = StatusCode::NOT_FOUND; - Ok(not_found) - } - } -} - -#[tokio::main] -async fn main() -> Result<()> { - let mut app = Command::new("Answer") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of WebRTC-rs Answer.") - .setting(AppSettings::DeriveDisplayOrder) - .subcommand_negates_reqs(true) - .arg( - Arg::new("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::new("debug") - .long("debug") - .short('d') - .help("Prints debug log information"), - ) - .arg( - Arg::new("offer-address") - .takes_value(true) - .default_value("localhost:50000") - .long("offer-address") - .help("Address that the Offer HTTP server is hosted on."), - ) - .arg( - Arg::new("answer-address") - .takes_value(true) - .default_value("0.0.0.0:60000") - .long("answer-address") - .help("Address that the Answer HTTP server is hosted on."), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let debug = matches.is_present("debug"); - if debug { - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init(); - } - - let offer_addr = matches.value_of("offer-address").unwrap().to_owned(); - let answer_addr = matches.value_of("answer-address").unwrap().to_owned(); - - { - let mut oa = ADDRESS.lock().await; - oa.clone_from(&offer_addr); - } - - // Prepare the configuration - let config = RTCConfiguration { - ice_servers: vec![RTCIceServer { - urls: vec!["stun:stun.l.google.com:19302".to_owned()], - ..Default::default() - }], - ..Default::default() - }; - - // Create a MediaEngine object to configure the supported codec - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - - let mut registry = Registry::new(); - - // Use the default set of Interceptors - registry = register_default_interceptors(registry, &mut m)?; - - // Create the API object with the MediaEngine - let api = APIBuilder::new() - .with_media_engine(m) - .with_interceptor_registry(registry) - .build(); - - // Create a new RTCPeerConnection - let peer_connection = Arc::new(api.new_peer_connection(config).await?); - - // When an ICE candidate is available send to the other Pion instance - // the other Pion instance will add this candidate by calling AddICECandidate - let pc = Arc::downgrade(&peer_connection); - let pending_candidates2 = Arc::clone(&PENDING_CANDIDATES); - let addr2 = offer_addr.clone(); - peer_connection.on_ice_candidate(Box::new(move |c: Option| { - //println!("on_ice_candidate {:?}", c); - - let pc2 = pc.clone(); - let pending_candidates3 = Arc::clone(&pending_candidates2); - let addr3 = addr2.clone(); - Box::pin(async move { - if let Some(c) = c { - if let Some(pc) = pc2.upgrade() { - let desc = pc.remote_description().await; - if desc.is_none() { - let mut cs = pending_candidates3.lock().await; - cs.push(c); - } else if let Err(err) = signal_candidate(&addr3, &c).await { - panic!("{}", err); - } - } - } - }) - })); - - println!("Listening on http://{answer_addr}"); - { - let mut pcm = PEER_CONNECTION_MUTEX.lock().await; - *pcm = Some(Arc::clone(&peer_connection)); - } - - tokio::spawn(async move { - let addr = 
SocketAddr::from_str(&answer_addr).unwrap(); - let service = - make_service_fn(|_| async { Ok::<_, hyper::Error>(service_fn(remote_handler)) }); - let server = Server::bind(&addr).serve(service); - // Run this server for... forever! - if let Err(e) = server.await { - eprintln!("server error: {e}"); - } - }); - - let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - - if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. - println!("Peer Connection has gone to failed exiting"); - let _ = done_tx.try_send(()); - } - - Box::pin(async {}) - })); - - // Register data channel creation handling - peer_connection.on_data_channel(Box::new(move |d: Arc| { - let d_label = d.label().to_owned(); - let d_id = d.id(); - println!("New DataChannel {d_label} {d_id}"); - - Box::pin(async move{ - // Register channel opening handling - let d2 = Arc::clone(&d); - let d_label2 = d_label.clone(); - let d_id2 = d_id; - d.on_open(Box::new(move || { - println!("Data channel '{d_label2}'-'{d_id2}' open. Random messages will now be sent to any connected DataChannels every 5 seconds"); - Box::pin(async move { - let mut result = Result::::Ok(0); - while result.is_ok() { - let timeout = tokio::time::sleep(Duration::from_secs(5)); - tokio::pin!(timeout); - - tokio::select! { - _ = timeout.as_mut() =>{ - let message = math_rand_alpha(15); - println!("Sending '{message}'"); - result = d2.send_text(message).await.map_err(Into::into); - } - }; - } - }) - })); - - // Register text message handling - d.on_message(Box::new(move |msg: DataChannelMessage| { - let msg_str = String::from_utf8(msg.data.to_vec()).unwrap(); - println!("Message from DataChannel '{d_label}': '{msg_str}'"); - Box::pin(async{}) - })); - }) - })); - - println!("Press ctrl-c to stop"); - tokio::select! 
{ - _ = done_rx.recv() => { - println!("received done signal!"); - } - _ = tokio::signal::ctrl_c() => { - println!(); - } - }; - - peer_connection.close().await?; - - Ok(()) -} diff --git a/examples/examples/offer-answer/offer.rs b/examples/examples/offer-answer/offer.rs deleted file mode 100644 index 7755bca57..000000000 --- a/examples/examples/offer-answer/offer.rs +++ /dev/null @@ -1,377 +0,0 @@ -use std::io::Write; -use std::net::SocketAddr; -use std::str::FromStr; -use std::sync::Arc; - -use anyhow::Result; -use clap::{AppSettings, Arg, Command}; -use hyper::service::{make_service_fn, service_fn}; -use hyper::{Body, Client, Method, Request, Response, Server, StatusCode}; -use tokio::sync::Mutex; -use tokio::time::Duration; -use webrtc::api::interceptor_registry::register_default_interceptors; -use webrtc::api::media_engine::MediaEngine; -use webrtc::api::APIBuilder; -use webrtc::data_channel::data_channel_message::DataChannelMessage; -use webrtc::ice_transport::ice_candidate::{RTCIceCandidate, RTCIceCandidateInit}; -use webrtc::ice_transport::ice_server::RTCIceServer; -use webrtc::interceptor::registry::Registry; -use webrtc::peer_connection::configuration::RTCConfiguration; -use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; -use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; -use webrtc::peer_connection::{math_rand_alpha, RTCPeerConnection}; - -#[macro_use] -extern crate lazy_static; - -lazy_static! { - static ref PEER_CONNECTION_MUTEX: Arc>>> = - Arc::new(Mutex::new(None)); - static ref PENDING_CANDIDATES: Arc>> = Arc::new(Mutex::new(vec![])); - static ref ADDRESS: Arc> = Arc::new(Mutex::new(String::new())); -} - -async fn signal_candidate(addr: &str, c: &RTCIceCandidate) -> Result<()> { - /*println!( - "signal_candidate Post candidate to {}", - format!("http://{}/candidate", addr) - );*/ - let payload = c.to_json()?.candidate; - let req = match Request::builder() - .method(Method::POST) - .uri(format!("http://{addr}/candidate")) - .header("content-type", "application/json; charset=utf-8") - .body(Body::from(payload)) - { - Ok(req) => req, - Err(err) => { - println!("{err}"); - return Err(err.into()); - } - }; - - let _resp = match Client::new().request(req).await { - Ok(resp) => resp, - Err(err) => { - println!("{err}"); - return Err(err.into()); - } - }; - //println!("signal_candidate Response: {}", resp.status()); - - Ok(()) -} - -// HTTP Listener to get ICE Credentials/Candidate from remote Peer -async fn remote_handler(req: Request) -> Result, hyper::Error> { - let pc = { - let pcm = PEER_CONNECTION_MUTEX.lock().await; - pcm.clone().unwrap() - }; - let addr = { - let addr = ADDRESS.lock().await; - addr.clone() - }; - - match (req.method(), req.uri().path()) { - // A HTTP handler that allows the other WebRTC-rs or Pion instance to send us ICE candidates - // This allows us to add ICE candidates faster, we don't have to wait for STUN or TURN - // candidates which may be slower - (&Method::POST, "/candidate") => { - //println!("remote_handler receive from /candidate"); - let candidate = - match std::str::from_utf8(&hyper::body::to_bytes(req.into_body()).await?) 
{ - Ok(s) => s.to_owned(), - Err(err) => panic!("{}", err), - }; - - if let Err(err) = pc - .add_ice_candidate(RTCIceCandidateInit { - candidate, - ..Default::default() - }) - .await - { - panic!("{}", err); - } - - let mut response = Response::new(Body::empty()); - *response.status_mut() = StatusCode::OK; - Ok(response) - } - - // A HTTP handler that processes a SessionDescription given to us from the other WebRTC-rs or Pion process - (&Method::POST, "/sdp") => { - //println!("remote_handler receive from /sdp"); - let sdp_str = match std::str::from_utf8(&hyper::body::to_bytes(req.into_body()).await?) - { - Ok(s) => s.to_owned(), - Err(err) => panic!("{}", err), - }; - let sdp = match serde_json::from_str::(&sdp_str) { - Ok(s) => s, - Err(err) => panic!("{}", err), - }; - - if let Err(err) = pc.set_remote_description(sdp).await { - panic!("{}", err); - } - - { - let cs = PENDING_CANDIDATES.lock().await; - for c in &*cs { - if let Err(err) = signal_candidate(&addr, c).await { - panic!("{}", err); - } - } - } - - let mut response = Response::new(Body::empty()); - *response.status_mut() = StatusCode::OK; - Ok(response) - } - // Return the 404 Not Found for other routes. - _ => { - let mut not_found = Response::default(); - *not_found.status_mut() = StatusCode::NOT_FOUND; - Ok(not_found) - } - } -} - -#[tokio::main] -async fn main() -> Result<()> { - let mut app = Command::new("Offer") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of WebRTC-rs Offer.") - .setting(AppSettings::DeriveDisplayOrder) - .subcommand_negates_reqs(true) - .arg( - Arg::new("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::new("debug") - .long("debug") - .short('d') - .help("Prints debug log information"), - ) - .arg( - Arg::new("offer-address") - .takes_value(true) - .default_value("0.0.0.0:50000") - .long("offer-address") - .help("Address that the Offer HTTP server is hosted on."), - ) - .arg( - Arg::new("answer-address") - .takes_value(true) - .default_value("localhost:60000") - .long("answer-address") - .help("Address that the Answer HTTP server is hosted on."), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let debug = matches.is_present("debug"); - if debug { - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init(); - } - - let offer_addr = matches.value_of("offer-address").unwrap().to_owned(); - let answer_addr = matches.value_of("answer-address").unwrap().to_owned(); - - { - let mut oa = ADDRESS.lock().await; - oa.clone_from(&answer_addr); - } - - // Prepare the configuration - let config = RTCConfiguration { - ice_servers: vec![RTCIceServer { - urls: vec!["stun:stun.l.google.com:19302".to_owned()], - ..Default::default() - }], - ..Default::default() - }; - - // Create a MediaEngine object to configure the supported codec - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - - let mut registry = Registry::new(); - - // Use the default set of Interceptors - registry = register_default_interceptors(registry, &mut m)?; - - // Create the API object with the MediaEngine - let api = APIBuilder::new() - .with_media_engine(m) - .with_interceptor_registry(registry) - .build(); - - // 
Create a new RTCPeerConnection - let peer_connection = Arc::new(api.new_peer_connection(config).await?); - - // When an ICE candidate is available send to the other Pion instance - // the other Pion instance will add this candidate by calling AddICECandidate - let pc = Arc::downgrade(&peer_connection); - let pending_candidates2 = Arc::clone(&PENDING_CANDIDATES); - let addr2 = answer_addr.clone(); - peer_connection.on_ice_candidate(Box::new(move |c: Option| { - //println!("on_ice_candidate {:?}", c); - - let pc2 = pc.clone(); - let pending_candidates3 = Arc::clone(&pending_candidates2); - let addr3 = addr2.clone(); - Box::pin(async move { - if let Some(c) = c { - if let Some(pc) = pc2.upgrade() { - let desc = pc.remote_description().await; - if desc.is_none() { - let mut cs = pending_candidates3.lock().await; - cs.push(c); - } else if let Err(err) = signal_candidate(&addr3, &c).await { - panic!("{}", err); - } - } - } - }) - })); - - println!("Listening on http://{offer_addr}"); - { - let mut pcm = PEER_CONNECTION_MUTEX.lock().await; - *pcm = Some(Arc::clone(&peer_connection)); - } - - tokio::spawn(async move { - let addr = SocketAddr::from_str(&offer_addr).unwrap(); - let service = - make_service_fn(|_| async { Ok::<_, hyper::Error>(service_fn(remote_handler)) }); - let server = Server::bind(&addr).serve(service); - // Run this server for... forever! - if let Err(e) = server.await { - eprintln!("server error: {e}"); - } - }); - - // Create a datachannel with label 'data' - let data_channel = peer_connection.create_data_channel("data", None).await?; - - let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - - if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. - println!("Peer Connection has gone to failed exiting"); - let _ = done_tx.try_send(()); - } - - Box::pin(async {}) - })); - - // Register channel opening handling - let d1 = Arc::clone(&data_channel); - data_channel.on_open(Box::new(move || { - println!("Data channel '{}'-'{}' open. Random messages will now be sent to any connected DataChannels every 5 seconds", d1.label(), d1.id()); - - let d2 = Arc::clone(&d1); - Box::pin(async move { - let mut result = Result::::Ok(0); - while result.is_ok() { - let timeout = tokio::time::sleep(Duration::from_secs(5)); - tokio::pin!(timeout); - - tokio::select! 
{ - _ = timeout.as_mut() =>{ - let message = math_rand_alpha(15); - println!("Sending '{message}'"); - result = d2.send_text(message).await.map_err(Into::into); - } - }; - } - }) - })); - - // Register text message handling - let d_label = data_channel.label().to_owned(); - data_channel.on_message(Box::new(move |msg: DataChannelMessage| { - let msg_str = String::from_utf8(msg.data.to_vec()).unwrap(); - println!("Message from DataChannel '{d_label}': '{msg_str}'"); - Box::pin(async {}) - })); - - // Create an offer to send to the other process - let offer = peer_connection.create_offer(None).await?; - - // Send our offer to the HTTP server listening in the other process - let payload = match serde_json::to_string(&offer) { - Ok(p) => p, - Err(err) => panic!("{}", err), - }; - - // Sets the LocalDescription, and starts our UDP listeners - // Note: this will start the gathering of ICE candidates - peer_connection.set_local_description(offer).await?; - - //println!("Post: {}", format!("http://{}/sdp", answer_addr)); - let req = match Request::builder() - .method(Method::POST) - .uri(format!("http://{answer_addr}/sdp")) - .header("content-type", "application/json; charset=utf-8") - .body(Body::from(payload)) - { - Ok(req) => req, - Err(err) => panic!("{}", err), - }; - - let _resp = match Client::new().request(req).await { - Ok(resp) => resp, - Err(err) => { - println!("{err}"); - return Err(err.into()); - } - }; - //println!("Response: {}", resp.status()); - - println!("Press ctrl-c to stop"); - tokio::select! { - _ = done_rx.recv() => { - println!("received done signal!"); - } - _ = tokio::signal::ctrl_c() => { - println!(); - } - }; - - peer_connection.close().await?; - - Ok(()) -} diff --git a/examples/examples/ortc/README.md b/examples/examples/ortc/README.md deleted file mode 100644 index 06237de67..000000000 --- a/examples/examples/ortc/README.md +++ /dev/null @@ -1,34 +0,0 @@ -# ortc - -ortc demonstrates WebRTC.rs's [ORTC](https://ortc.org/) capabilities. Instead of using the Session Description Protocol -to configure and communicate ORTC provides APIs. Users then can implement signaling with whatever protocol they wish. -ORTC can then be used to implement WebRTC. A ORTC implementation can parse/emit Session Description and act as a WebRTC -implementation. - -In this example we have defined a simple JSON based signaling protocol. - -## Instructions - -### Build ortc - -```shell -cargo build --example ortc -``` - -### Run first client as offerer - -`ortc --offer` this will emit a base64 message. Copy this message to your clipboard. - -## Run the second client as answerer - -Run the second client. This should be launched with the message you copied in the previous step as stdin. - -`echo BASE64_MESSAGE_YOU_COPIED | ortc` - -### Enjoy - -If everything worked you will see `Data channel 'Foo'-'' open.` in each terminal. - -Each client will send random messages every 5 seconds that will appear in the terminal - -Congrats, you have used WebRTC.rs! 
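
The "simple JSON based signaling protocol" mentioned above is a single serde-serialized struct, base64-encoded and passed between the two processes over stdin/stdout. A sketch of its shape, taken from the `Signal` struct in `ortc.rs` below (module paths copied from that file's imports):

```rust
use serde::{Deserialize, Serialize};
use webrtc::dtls_transport::dtls_parameters::DTLSParameters;
use webrtc::ice_transport::ice_candidate::RTCIceCandidate;
use webrtc::ice_transport::ice_parameters::RTCIceParameters;
use webrtc::sctp_transport::sctp_transport_capabilities::SCTPTransportCapabilities;

/// Everything one ORTC endpoint needs from its peer before it can start
/// the ICE, DTLS and SCTP transports.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Signal {
    #[serde(rename = "iceCandidates")]
    ice_candidates: Vec<RTCIceCandidate>,
    #[serde(rename = "iceParameters")]
    ice_parameters: RTCIceParameters,
    #[serde(rename = "dtlsParameters")]
    dtls_parameters: DTLSParameters,
    #[serde(rename = "sctpCapabilities")]
    sctp_capabilities: SCTPTransportCapabilities,
}
```
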
diff --git a/examples/examples/ortc/ortc.rs b/examples/examples/ortc/ortc.rs deleted file mode 100644 index f4d2fa439..000000000 --- a/examples/examples/ortc/ortc.rs +++ /dev/null @@ -1,278 +0,0 @@ -use std::io::Write; -use std::sync::Arc; - -use anyhow::Result; -use clap::{AppSettings, Arg, Command}; -use serde::{Deserialize, Serialize}; -use tokio::sync::Notify; -use tokio::time::Duration; -use webrtc::api::APIBuilder; -use webrtc::data_channel::data_channel_message::DataChannelMessage; -use webrtc::data_channel::data_channel_parameters::DataChannelParameters; -use webrtc::data_channel::RTCDataChannel; -use webrtc::dtls_transport::dtls_parameters::DTLSParameters; -use webrtc::ice_transport::ice_candidate::RTCIceCandidate; -use webrtc::ice_transport::ice_gatherer::RTCIceGatherOptions; -use webrtc::ice_transport::ice_parameters::RTCIceParameters; -use webrtc::ice_transport::ice_role::RTCIceRole; -use webrtc::ice_transport::ice_server::RTCIceServer; -use webrtc::peer_connection::math_rand_alpha; -use webrtc::sctp_transport::sctp_transport_capabilities::SCTPTransportCapabilities; - -#[tokio::main] -async fn main() -> Result<()> { - let mut app = Command::new("ortc") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of ORTC.") - .setting(AppSettings::DeriveDisplayOrder) - .subcommand_negates_reqs(true) - .arg( - Arg::new("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::new("debug") - .long("debug") - .short('d') - .help("Prints debug log information"), - ) - .arg( - Arg::new("offer") - .long("offer") - .help("Act as the offerer if set."), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let is_offer = matches.is_present("offer"); - let debug = matches.is_present("debug"); - if debug { - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init(); - } - - // Everything below is the Pion WebRTC (ORTC) API! Thanks for using it โค๏ธ. 
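    // High-level flow of the ORTC setup below (a summary of the existing calls,
    // not additional API): create an ICE gatherer, wrap it in an ICE transport,
    // stack a DTLS transport on the ICE transport and an SCTP transport on DTLS,
    // gather candidates, exchange the JSON `Signal` blob with the peer, then call
    // ice.start / dtls.start / sctp.start, and finally open a data channel.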
- - // Prepare ICE gathering options - let ice_options = RTCIceGatherOptions { - ice_servers: vec![RTCIceServer { - urls: vec!["stun:stun.l.google.com:19302".to_owned()], - ..Default::default() - }], - ..Default::default() - }; - - // Create an API object - let api = APIBuilder::new().build(); - - // Create the ICE gatherer - let gatherer = Arc::new(api.new_ice_gatherer(ice_options)?); - - // Construct the ICE transport - let ice = Arc::new(api.new_ice_transport(Arc::clone(&gatherer))); - - // Construct the DTLS transport - let dtls = Arc::new(api.new_dtls_transport(Arc::clone(&ice), vec![])?); - - // Construct the SCTP transport - let sctp = Arc::new(api.new_sctp_transport(Arc::clone(&dtls))?); - - let done = Arc::new(Notify::new()); - let done_answer = done.clone(); - let done_offer = done.clone(); - - // Handle incoming data channels - sctp.on_data_channel(Box::new(move |d: Arc| { - let d_label = d.label().to_owned(); - let d_id = d.id(); - println!("New DataChannel {d_label} {d_id}"); - - let done_answer1 = done_answer.clone(); - // Register the handlers - Box::pin(async move { - // no need to downgrade this to Weak, since on_open is FnOnce callback - let d2 = Arc::clone(&d); - let done_answer2 = done_answer1.clone(); - d.on_open(Box::new(move || { - Box::pin(async move { - tokio::select! { - _ = done_answer2.notified() => { - println!("received done_answer signal!"); - } - _ = handle_on_open(d2) => {} - }; - - println!("exit data answer"); - }) - })); - - // Register text message handling - d.on_message(Box::new(move |msg: DataChannelMessage| { - let msg_str = String::from_utf8(msg.data.to_vec()).unwrap(); - println!("Message from DataChannel '{d_label}': '{msg_str}'"); - Box::pin(async {}) - })); - }) - })); - - let (gather_finished_tx, mut gather_finished_rx) = tokio::sync::mpsc::channel::<()>(1); - let mut gather_finished_tx = Some(gather_finished_tx); - gatherer.on_local_candidate(Box::new(move |c: Option| { - if c.is_none() { - gather_finished_tx.take(); - } - Box::pin(async {}) - })); - - // Gather candidates - gatherer.gather().await?; - - let _ = gather_finished_rx.recv().await; - - let ice_candidates = gatherer.get_local_candidates().await?; - - let ice_parameters = gatherer.get_local_parameters().await?; - - let dtls_parameters = dtls.get_local_parameters()?; - - let sctp_capabilities = sctp.get_capabilities(); - - let local_signal = Signal { - ice_candidates, - ice_parameters, - dtls_parameters, - sctp_capabilities, - }; - - // Exchange the information - let json_str = serde_json::to_string(&local_signal)?; - let b64 = signal::encode(&json_str); - println!("{b64}"); - - let line = signal::must_read_stdin()?; - let json_str = signal::decode(line.as_str())?; - let remote_signal = serde_json::from_str::(&json_str)?; - - let ice_role = if is_offer { - RTCIceRole::Controlling - } else { - RTCIceRole::Controlled - }; - - ice.set_remote_candidates(&remote_signal.ice_candidates) - .await?; - - // Start the ICE transport - ice.start(&remote_signal.ice_parameters, Some(ice_role)) - .await?; - - // Start the DTLS transport - dtls.start(remote_signal.dtls_parameters).await?; - - // Start the SCTP transport - sctp.start(remote_signal.sctp_capabilities).await?; - - // Construct the data channel as the offerer - if is_offer { - let id = 1u16; - - let dc_params = DataChannelParameters { - label: "Foo".to_owned(), - negotiated: Some(id), - ..Default::default() - }; - - let d = Arc::new(api.new_data_channel(Arc::clone(&sctp), dc_params).await?); - - // Register the handlers - // 
channel.OnOpen(handleOnOpen(channel)) // TODO: OnOpen on handle ChannelAck - // Temporary alternative - - // no need to downgrade this to Weak - let d2 = Arc::clone(&d); - tokio::spawn(async move { - tokio::select! { - _ = done_offer.notified() => { - println!("received done_offer signal!"); - } - _ = handle_on_open(d2) => {} - }; - - println!("exit data offer"); - }); - - let d_label = d.label().to_owned(); - d.on_message(Box::new(move |msg: DataChannelMessage| { - let msg_str = String::from_utf8(msg.data.to_vec()).unwrap(); - println!("Message from DataChannel '{d_label}': '{msg_str}'"); - Box::pin(async {}) - })); - } - - println!("Press ctrl-c to stop"); - tokio::signal::ctrl_c().await.unwrap(); - done.notify_waiters(); - - sctp.stop().await?; - dtls.stop().await?; - ice.stop().await?; - - Ok(()) -} - -// Signal is used to exchange signaling info. -// This is not part of the ORTC spec. You are free -// to exchange this information any way you want. -#[derive(Debug, Clone, Serialize, Deserialize)] -struct Signal { - #[serde(rename = "iceCandidates")] - ice_candidates: Vec, // `json:"iceCandidates"` - - #[serde(rename = "iceParameters")] - ice_parameters: RTCIceParameters, // `json:"iceParameters"` - - #[serde(rename = "dtlsParameters")] - dtls_parameters: DTLSParameters, // `json:"dtlsParameters"` - - #[serde(rename = "sctpCapabilities")] - sctp_capabilities: SCTPTransportCapabilities, // `json:"sctpCapabilities"` -} - -async fn handle_on_open(d: Arc) -> Result<()> { - println!("Data channel '{}'-'{}' open. Random messages will now be sent to any connected DataChannels every 5 seconds", d.label(), d.id()); - - let mut result = Result::::Ok(0); - while result.is_ok() { - let timeout = tokio::time::sleep(Duration::from_secs(5)); - tokio::pin!(timeout); - - tokio::select! { - _ = timeout.as_mut() =>{ - let message = math_rand_alpha(15); - println!("Sending '{message}'"); - result = d.send_text(message).await.map_err(Into::into); - } - }; - } - - Ok(()) -} diff --git a/examples/examples/play-from-disk-h264/README.md b/examples/examples/play-from-disk-h264/README.md deleted file mode 100644 index fa8e53fd6..000000000 --- a/examples/examples/play-from-disk-h264/README.md +++ /dev/null @@ -1,45 +0,0 @@ -# play-from-disk-h264 - -play-from-disk-h264 demonstrates how to send h264 video and/or audio to your browser from files saved to disk. - -## Instructions - -### Create IVF named `output.264` that contains a H264 track and/or `output.ogg` that contains a Opus track - -```shell -ffmpeg -i $INPUT_FILE -an -c:v libx264 -bsf:v h264_mp4toannexb -b:v 2M -max_delay 0 -bf 0 output.h264 -ffmpeg -i $INPUT_FILE -c:a libopus -page_duration 20000 -vn output.ogg -``` - -### Build play-from-disk-h264 - -```shell -cargo build --example play-from-disk-h264 -``` - -### Open play-from-disk-h264 example page - -[jsfiddle.net](https://jsfiddle.net/9s10amwL/) you should see two text-areas and a 'Start Session' button - -### Run play-from-disk-h264 with your browsers SessionDescription as stdin - -The `output.h264` you created should be in the same directory as `play-from-disk-h264`. In the jsfiddle the top textarea is your browser, copy that and: - -#### Linux/macOS - -Run `echo $BROWSER_SDP | ./target/debug/examples/play-from-disk-h264 -v examples/test-data/output.h264 -a examples/test-data/output.ogg` - -#### Windows - -1. Paste the SessionDescription into a file. -1. 
Run `./target/debug/examples/play-from-disk-h264 -v examples/test-data/output.h264 -a examples/test-data/output.ogg < my_file` - -### Input play-from-disk-h264's SessionDescription into your browser - -Copy the text that `play-from-disk-h264` just emitted and copy into second text area - -### Hit 'Start Session' in jsfiddle, enjoy your video! - -A video should start playing in your browser above the input boxes. `play-from-disk-h264` will exit when the file reaches the end - -Congrats, you have used WebRTC.rs! diff --git a/examples/examples/play-from-disk-h264/play-from-disk-h264.rs b/examples/examples/play-from-disk-h264/play-from-disk-h264.rs deleted file mode 100644 index 94e770185..000000000 --- a/examples/examples/play-from-disk-h264/play-from-disk-h264.rs +++ /dev/null @@ -1,361 +0,0 @@ -use std::fs::File; -use std::io::{BufReader, Write}; -use std::path::Path; -use std::sync::Arc; - -use anyhow::Result; -use clap::{AppSettings, Arg, Command}; -use tokio::sync::Notify; -use tokio::time::Duration; -use webrtc::api::interceptor_registry::register_default_interceptors; -use webrtc::api::media_engine::{MediaEngine, MIME_TYPE_H264, MIME_TYPE_OPUS}; -use webrtc::api::APIBuilder; -use webrtc::ice_transport::ice_connection_state::RTCIceConnectionState; -use webrtc::ice_transport::ice_server::RTCIceServer; -use webrtc::interceptor::registry::Registry; -use webrtc::media::io::h264_reader::H264Reader; -use webrtc::media::io::ogg_reader::OggReader; -use webrtc::media::Sample; -use webrtc::peer_connection::configuration::RTCConfiguration; -use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; -use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; -use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability; -use webrtc::track::track_local::track_local_static_sample::TrackLocalStaticSample; -use webrtc::track::track_local::TrackLocal; -use webrtc::Error; - -const OGG_PAGE_DURATION: Duration = Duration::from_millis(20); - -#[tokio::main] -async fn main() -> Result<()> { - let mut app = Command::new("play-from-disk-h264") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of play-from-disk-h264.") - .setting(AppSettings::DeriveDisplayOrder) - .subcommand_negates_reqs(true) - .arg( - Arg::new("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::new("debug") - .long("debug") - .short('d') - .help("Prints debug log information"), - ) - .arg( - Arg::new("video") - .required_unless_present("FULLHELP") - .takes_value(true) - .short('v') - .long("video") - .help("Video file to be streaming."), - ) - .arg( - Arg::new("audio") - .takes_value(true) - .short('a') - .long("audio") - .help("Audio file to be streaming."), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let debug = matches.is_present("debug"); - if debug { - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init(); - } - - let video_file = matches.value_of("video"); - let audio_file = matches.value_of("audio"); - - if let Some(video_path) = &video_file { - if !Path::new(video_path).exists() { - return Err(Error::new(format!("video file: '{video_path}' not exist")).into()); - } - } - if let 
Some(audio_path) = &audio_file { - if !Path::new(audio_path).exists() { - return Err(Error::new(format!("audio file: '{audio_path}' not exist")).into()); - } - } - - // Everything below is the WebRTC-rs API! Thanks for using it โค๏ธ. - - // Create a MediaEngine object to configure the supported codec - let mut m = MediaEngine::default(); - - m.register_default_codecs()?; - - // Create a InterceptorRegistry. This is the user configurable RTP/RTCP Pipeline. - // This provides NACKs, RTCP Reports and other features. If you use `webrtc.NewPeerConnection` - // this is enabled by default. If you are manually managing You MUST create a InterceptorRegistry - // for each PeerConnection. - let mut registry = Registry::new(); - - // Use the default set of Interceptors - registry = register_default_interceptors(registry, &mut m)?; - - // Create the API object with the MediaEngine - let api = APIBuilder::new() - .with_media_engine(m) - .with_interceptor_registry(registry) - .build(); - - // Prepare the configuration - let config = RTCConfiguration { - ice_servers: vec![RTCIceServer { - urls: vec!["stun:stun.l.google.com:19302".to_owned()], - ..Default::default() - }], - ..Default::default() - }; - - // Create a new RTCPeerConnection - let peer_connection = Arc::new(api.new_peer_connection(config).await?); - - let notify_tx = Arc::new(Notify::new()); - let notify_video = notify_tx.clone(); - let notify_audio = notify_tx.clone(); - - let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - let video_done_tx = done_tx.clone(); - let audio_done_tx = done_tx.clone(); - - if let Some(video_file) = video_file { - // Create a video track - let video_track = Arc::new(TrackLocalStaticSample::new( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_H264.to_owned(), - ..Default::default() - }, - "video".to_owned(), - "webrtc-rs".to_owned(), - )); - - // Add this newly created track to the PeerConnection - let rtp_sender = peer_connection - .add_track(Arc::clone(&video_track) as Arc) - .await?; - - // Read incoming RTCP packets - // Before these packets are returned they are processed by interceptors. For things - // like NACK this needs to be called. 
- tokio::spawn(async move { - let mut rtcp_buf = vec![0u8; 1500]; - while let Ok((_, _)) = rtp_sender.read(&mut rtcp_buf).await {} - Result::<()>::Ok(()) - }); - - let video_file_name = video_file.to_owned(); - tokio::spawn(async move { - // Open a H264 file and start reading using our H264Reader - let file = File::open(&video_file_name)?; - let reader = BufReader::new(file); - let mut h264 = H264Reader::new(reader, 1_048_576); - - // Wait for connection established - notify_video.notified().await; - - println!("play video from disk file {video_file_name}"); - - // It is important to use a time.Ticker instead of time.Sleep because - // * avoids accumulating skew, just calling time.Sleep didn't compensate for the time spent parsing the data - // * works around latency issues with Sleep - let mut ticker = tokio::time::interval(Duration::from_millis(33)); - loop { - let nal = match h264.next_nal() { - Ok(nal) => nal, - Err(err) => { - println!("All video frames parsed and sent: {err}"); - break; - } - }; - - /*println!( - "PictureOrderCount={}, ForbiddenZeroBit={}, RefIdc={}, UnitType={}, data={}", - nal.picture_order_count, - nal.forbidden_zero_bit, - nal.ref_idc, - nal.unit_type, - nal.data.len() - );*/ - - video_track - .write_sample(&Sample { - data: nal.data.freeze(), - duration: Duration::from_secs(1), - ..Default::default() - }) - .await?; - - let _ = ticker.tick().await; - } - - let _ = video_done_tx.try_send(()); - - Result::<()>::Ok(()) - }); - } - - if let Some(audio_file) = audio_file { - // Create a audio track - let audio_track = Arc::new(TrackLocalStaticSample::new( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_OPUS.to_owned(), - ..Default::default() - }, - "audio".to_owned(), - "webrtc-rs".to_owned(), - )); - - // Add this newly created track to the PeerConnection - let rtp_sender = peer_connection - .add_track(Arc::clone(&audio_track) as Arc) - .await?; - - // Read incoming RTCP packets - // Before these packets are returned they are processed by interceptors. For things - // like NACK this needs to be called. - tokio::spawn(async move { - let mut rtcp_buf = vec![0u8; 1500]; - while let Ok((_, _)) = rtp_sender.read(&mut rtcp_buf).await {} - Result::<()>::Ok(()) - }); - - let audio_file_name = audio_file.to_owned(); - tokio::spawn(async move { - // Open a IVF file and start reading using our IVFReader - let file = File::open(audio_file_name)?; - let reader = BufReader::new(file); - // Open on oggfile in non-checksum mode. 
- let (mut ogg, _) = OggReader::new(reader, true)?; - - // Wait for connection established - notify_audio.notified().await; - - println!("play audio from disk file output.ogg"); - - // It is important to use a time.Ticker instead of time.Sleep because - // * avoids accumulating skew, just calling time.Sleep didn't compensate for the time spent parsing the data - // * works around latency issues with Sleep - let mut ticker = tokio::time::interval(OGG_PAGE_DURATION); - - // Keep track of last granule, the difference is the amount of samples in the buffer - let mut last_granule: u64 = 0; - while let Ok((page_data, page_header)) = ogg.parse_next_page() { - // The amount of samples is the difference between the last and current timestamp - let sample_count = page_header.granule_position - last_granule; - last_granule = page_header.granule_position; - let sample_duration = Duration::from_millis(sample_count * 1000 / 48000); - - audio_track - .write_sample(&Sample { - data: page_data.freeze(), - duration: sample_duration, - ..Default::default() - }) - .await?; - - let _ = ticker.tick().await; - } - - let _ = audio_done_tx.try_send(()); - - Result::<()>::Ok(()) - }); - } - - // Set the handler for ICE connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_ice_connection_state_change(Box::new( - move |connection_state: RTCIceConnectionState| { - println!("Connection State has changed {connection_state}"); - if connection_state == RTCIceConnectionState::Connected { - notify_tx.notify_waiters(); - } - Box::pin(async {}) - }, - )); - - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - - if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. 
- println!("Peer Connection has gone to failed exiting"); - let _ = done_tx.try_send(()); - } - - Box::pin(async {}) - })); - - // Wait for the offer to be pasted - let line = signal::must_read_stdin()?; - let desc_data = signal::decode(line.as_str())?; - let offer = serde_json::from_str::(&desc_data)?; - - // Set the remote SessionDescription - peer_connection.set_remote_description(offer).await?; - - // Create an answer - let answer = peer_connection.create_answer(None).await?; - - // Create channel that is blocked until ICE Gathering is complete - let mut gather_complete = peer_connection.gathering_complete_promise().await; - - // Sets the LocalDescription, and starts our UDP listeners - peer_connection.set_local_description(answer).await?; - - // Block until ICE Gathering is complete, disabling trickle ICE - // we do this because we only can exchange one signaling message - // in a production application you should exchange ICE Candidates via OnICECandidate - let _ = gather_complete.recv().await; - - // Output the answer in base64 so we can paste it in browser - if let Some(local_desc) = peer_connection.local_description().await { - let json_str = serde_json::to_string(&local_desc)?; - let b64 = signal::encode(&json_str); - println!("{b64}"); - } else { - println!("generate local_description failed!"); - } - - println!("Press ctrl-c to stop"); - tokio::select! { - _ = done_rx.recv() => { - println!("received done signal!"); - } - _ = tokio::signal::ctrl_c() => { - println!(); - } - }; - - peer_connection.close().await?; - - Ok(()) -} diff --git a/examples/examples/play-from-disk-hevc/README.md b/examples/examples/play-from-disk-hevc/README.md deleted file mode 100644 index 16fa2cad5..000000000 --- a/examples/examples/play-from-disk-hevc/README.md +++ /dev/null @@ -1,29 +0,0 @@ -# play-from-disk-hevc - -play-from-disk-hevc demonstrates how to send hevc video and/or audio to your browser from files saved to disk. 
- -## Instructions - -### Create IVF named `output.265` that contains a hevc track and/or `output.ogg` that contains a Opus track - -```shell -ffmpeg -i $INPUT_FILE -an -c:v libx265 -bsf:v hevc_mp4toannexb -b:v 2M -max_delay 0 -bf 0 output.265 -ffmpeg -i $INPUT_FILE -c:a libopus -page_duration 20000 -vn output.ogg -``` - -### Build/Run play-from-disk-hevc - -```shell -cargo run --example play-from-disk-hevc -``` - -### Result and Output -In the shell you opened, you should see from std that rtp of hevc get received and parsed - -After all is done, an `xx.output` file should be created at the same directory of the src video file - -Congrats, you have sent and received the hevc stream - -## Notes -- Maybe you will need to install libx265/opus for your ffmepg -- Please update the stun server to the best match, google maybe slow/unaccessable in some certain region/circumstance diff --git a/examples/examples/play-from-disk-hevc/play-from-disk-hevc.rs b/examples/examples/play-from-disk-hevc/play-from-disk-hevc.rs deleted file mode 100644 index c2174e211..000000000 --- a/examples/examples/play-from-disk-hevc/play-from-disk-hevc.rs +++ /dev/null @@ -1,599 +0,0 @@ -use anyhow::Result; -use bytes::BytesMut; -use clap::{AppSettings, Arg, Command}; -use std::fs::File; -use std::io::{BufReader, Read, Write}; -use std::path::Path; -use std::sync::{Arc, Weak}; -use tokio::sync::{mpsc, Notify}; -use tokio::time::Duration; -use webrtc::api::interceptor_registry::register_default_interceptors; -use webrtc::api::media_engine::{MediaEngine, MIME_TYPE_HEVC, MIME_TYPE_OPUS}; -use webrtc::api::APIBuilder; -use webrtc::ice_transport::ice_connection_state::RTCIceConnectionState; -use webrtc::ice_transport::ice_server::RTCIceServer; -use webrtc::interceptor::registry::Registry; -use webrtc::media::io::ogg_reader::OggReader; -use webrtc::media::Sample; -use webrtc::peer_connection::configuration::RTCConfiguration; -use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; -use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; -use webrtc::peer_connection::RTCPeerConnection; -use webrtc::rtcp::payload_feedbacks::picture_loss_indication::PictureLossIndication; -use webrtc::rtp::codecs::h264::ANNEXB_NALUSTART_CODE; -use webrtc::rtp::codecs::h265::{H265NALUHeader, H265Packet, H265Payload, UnitType}; -use webrtc::rtp::packetizer::Depacketizer; -use webrtc::rtp_transceiver::rtp_codec::{RTCRtpCodecCapability, RTPCodecType}; -use webrtc::rtp_transceiver::rtp_receiver::RTCRtpReceiver; -use webrtc::rtp_transceiver::rtp_transceiver_direction::RTCRtpTransceiverDirection; -use webrtc::rtp_transceiver::{RTCRtpTransceiver, RTCRtpTransceiverInit}; -use webrtc::track::track_local::track_local_static_sample::TrackLocalStaticSample; -use webrtc::track::track_local::TrackLocal; -use webrtc::track::track_remote::TrackRemote; -use webrtc::Error; - -const OGG_PAGE_DURATION: Duration = Duration::from_millis(20); - -#[tokio::main] -async fn main() -> Result<()> { - let mut app = Command::new("play-from-disk-hevc") - .version("0.1.0") - .author("RobinShi ") - .about("An example of play-from-disk-hevc.") - .setting(AppSettings::DeriveDisplayOrder) - .subcommand_negates_reqs(true) - .arg( - Arg::new("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::new("debug") - .long("debug") - .short('d') - .help("Prints debug log information"), - ) - .arg( - Arg::new("video") - .required_unless_present("FULLHELP") - .takes_value(true) - .short('v') - .long("video") 
- .help("Video file to be streaming."), - ) - .arg( - Arg::new("audio") - .takes_value(true) - .short('a') - .long("audio") - .help("Audio file to be streaming."), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let debug = matches.is_present("debug"); - if debug { - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init(); - } - - let video_file = matches.value_of("video"); - let audio_file = matches.value_of("audio"); - - if let Some(video_path) = &video_file { - if !Path::new(video_path).exists() { - return Err(Error::new(format!("video file: '{video_path}' not exist")).into()); - } - } - if let Some(audio_path) = &audio_file { - if !Path::new(audio_path).exists() { - return Err(Error::new(format!("audio file: '{audio_path}' not exist")).into()); - } - } - let video_file = video_file.map(|v| v.to_owned()).unwrap(); - let audio_file = audio_file.map(|v| v.to_owned()).unwrap(); - - let video_file1 = video_file.clone(); - let (offer_sdr, mut offer_rcv) = mpsc::channel::(10); - let (answer_sdr, answer_rcv) = mpsc::channel::(10); - tokio::spawn(async move { - if let Err(e) = offer_worker(video_file1, audio_file, offer_sdr, answer_rcv).await { - println!("[Speaker] Error: {:?}", e); - } - }); - // Create a MediaEngine object to configure the supported codec - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - // Create a InterceptorRegistry. This is the user configurable RTP/RTCP Pipeline. - // This provides NACKs, RTCP Reports and other features. If you use `webrtc.NewPeerConnection` - // this is enabled by default. If you are manually managing You MUST create a InterceptorRegistry - // for each PeerConnection. 
- let mut registry = Registry::new(); - // Use the default set of Interceptors - registry = register_default_interceptors(registry, &mut m)?; - // Create the API object with the MediaEngine - let api = APIBuilder::new() - .with_media_engine(m) - .with_interceptor_registry(registry) - .build(); - // Prepare the configuration - let config = RTCConfiguration { - ice_servers: vec![RTCIceServer { - urls: vec!["stun:stun.l.google.com:19302".to_owned()], - ..Default::default() - }], - ..Default::default() - }; - // Create a new RTCPeerConnection - let peer_connection = Arc::new(api.new_peer_connection(config).await?); - peer_connection - .add_transceiver_from_kind( - RTPCodecType::Audio, - Some(RTCRtpTransceiverInit { - direction: RTCRtpTransceiverDirection::Sendrecv, - send_encodings: vec![], - }), - ) - .await?; - peer_connection - .add_transceiver_from_kind( - RTPCodecType::Video, - Some(RTCRtpTransceiverInit { - direction: RTCRtpTransceiverDirection::Sendrecv, - send_encodings: vec![], - }), - ) - .await?; - let pc1 = Arc::downgrade(&peer_connection); - let close_notify = Arc::new(Notify::new()); - let notify1 = close_notify.clone(); - peer_connection.on_track(Box::new( - move |track: Arc, - _receiver: Arc, - _tranceiver: Arc| { - let media_ssrc = track.ssrc(); - let pc2 = pc1.clone(); - let kind = track.kind(); - let notify2 = notify1.clone(); - println!("[Listener] track codec {:?}", track.codec()); - if kind == RTPCodecType::Video { - tokio::spawn(async move { - let mut ticker = tokio::time::interval(Duration::from_secs(2)); - while let Some(pc3) = pc2.upgrade() { - if peer_closed(&pc3) { - break; - } - if pc3 - .write_rtcp(&[Box::new(PictureLossIndication { - sender_ssrc: 0, - media_ssrc, - })]) - .await - .is_err() - { - break; - } - let _ = ticker.tick().await; - } - println!("[Listener] closing {kind} pli thread"); - }); - } - - let pc2 = pc1.clone(); - let video_file1 = video_file.clone(); - match kind { - RTPCodecType::Video => { - tokio::spawn(async move { - let mut pck = H265Packet::default(); - let mut fdata = BytesMut::new(); - loop { - let timeout = tokio::time::sleep(Duration::from_secs(4)); - tokio::pin!(timeout); - tokio::select! 
{ - _ = timeout.as_mut() => { - break; - } - m = track.read_rtp() => { - println!("rtp readed"); - if let Ok((p, _)) = m { - let data = pck.depacketize(&p.payload).unwrap(); - match pck.payload() { - H265Payload::H265PACIPacket(p) => { - println!("[Listener] paci {:?}", p.payload_header()); - } - H265Payload::H265SingleNALUnitPacket(p) => { - println!( - "[Listener] single len {:?} type {:?}", - p.payload().len(), - p.payload_header().nalu_type() - ); - fdata.extend_from_slice(&[0x00, 0x00, 0x00, 0x01]); - fdata.extend_from_slice(&data); - } - H265Payload::H265AggregationPacket(p) => { - if let Some(uf) = p.first_unit() { - println!( - "[Listener] aggr first nal len {} type {:?}", - uf.nal_unit().len(), - UnitType::for_id((uf.nal_unit()[0] & 0b0111_1110) >> 1) - ); - fdata.extend_from_slice(&[0x00, 0x00, 0x00, 0x01]); - fdata.extend_from_slice(&uf.nal_unit()); - } - for ou in p.other_units() { - println!( - "[Listener] aggr other nal len {} type {:?}", - ou.nal_unit().len(), - UnitType::for_id((ou.nal_unit()[0] & 0b0111_1110) >> 1) - ); - fdata.extend_from_slice(&[0x00, 0x00, 0x00, 0x01]); - fdata.extend_from_slice(&ou.nal_unit()); - } - } - H265Payload::H265FragmentationUnitPacket(p) => { - println!( - "[Listener] fu nal header {:?} data4 {:?}, nal_type {:?}", - p.fu_header(), - &data[0..4], - p.fu_header().fu_type(), - ); - if p.fu_header().s() { - let nal_type = (p.fu_header().fu_type() << 1) & 0b0111_1110; - fdata.extend_from_slice(&[0x00, 0x00, 0x00, 0x01]); - fdata.extend_from_slice(&[nal_type, 0x01]); - } - fdata.extend_from_slice(&p.payload()); - if p.fu_header().e() { - println!("[Listener] fu nal collected"); - } - } - } - } else if weak_peer_closed(&pc2) { - println!("peer abnormally closed"); - break; - } - } - } - } - let mut file = std::fs::File::create(format!("{video_file1}.output")).unwrap(); - let _ = file.write_all(&fdata); - println!("[Listener] closing video read thread"); - notify2.notify_waiters(); - }); - } - RTPCodecType::Audio => { - tokio::spawn(async move { - loop { - let timeout = tokio::time::sleep(Duration::from_secs(4)); - tokio::pin!(timeout); - tokio::select! { - _ = timeout.as_mut() => { - break; - } - m = track.read_rtp() => { - if m.is_err() && weak_peer_closed(&pc2) { - break; - } - } - } - } - println!("[Listener] closing audio read thread"); - notify2.notify_waiters(); - }); - } - _ => {} - } - Box::pin(async {}) - }, - )); - let notify1 = close_notify.clone(); - peer_connection.on_ice_connection_state_change(Box::new( - move |connection_state: RTCIceConnectionState| { - println!("[Listener] session state changed {connection_state}",); - if connection_state == RTCIceConnectionState::Closed - || connection_state == RTCIceConnectionState::Failed - { - notify1.notify_waiters(); - } - Box::pin(async {}) - }, - )); - - println!("[Listener] waiting for offer"); - let timeout = tokio::time::sleep(Duration::from_secs(60)); - tokio::pin!(timeout); - let offer = tokio::select! 
{ - _ = timeout.as_mut() => {panic!("wait offer failed")} - sdp = offer_rcv.recv() => {sdp.unwrap()} - }; - peer_connection.set_remote_description(offer).await?; - let answer = peer_connection.create_answer(None).await?; - let mut gather_complete = peer_connection.gathering_complete_promise().await; - peer_connection.set_local_description(answer).await?; - let _ = gather_complete.recv().await; - - println!("[Listener] offer set, sending answer"); - if let Some(answer) = peer_connection.local_description().await { - let _ = answer_sdr.send(answer).await; - } - - println!("[Listener] answer sent, await quit event"); - let timeout = tokio::time::sleep(Duration::from_secs(60)); - tokio::pin!(timeout); - tokio::select! { - _ = timeout.as_mut() => {} - _ = close_notify.notified() => {} - } - let _ = peer_connection.close().await; - println!("[Listener] closing peer"); - - Ok(()) -} - -async fn offer_worker( - video_file: String, - audio_file: String, - offer_sdr: mpsc::Sender, - mut answer_rcv: mpsc::Receiver, -) -> Result<()> { - let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - let video_done_tx = done_tx.clone(); - let audio_done_tx = done_tx.clone(); - - // Prepare the configuration - let config = RTCConfiguration { - ice_servers: vec![RTCIceServer { - urls: vec!["stun:stun.l.google.com:19302".to_owned()], - ..Default::default() - }], - ..Default::default() - }; - - // Create a MediaEngine object to configure the supported codec - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - - let mut registry = Registry::new(); - - // Use the default set of Interceptors - registry = register_default_interceptors(registry, &mut m)?; - - // Create the API object with the MediaEngine - let api = APIBuilder::new() - .with_media_engine(m) - .with_interceptor_registry(registry) - .build(); - - // Create a new RTCPeerConnection - let peer_connection = Arc::new(api.new_peer_connection(config).await?); - let notify_connect = Arc::new(Notify::new()); - - let local_video_track = Arc::new(TrackLocalStaticSample::new( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_HEVC.to_owned(), - ..Default::default() - }, - "video".to_owned(), - "webrtc-rs".to_owned(), - )); - let video_rtp_sender = peer_connection - .add_track(Arc::clone(&local_video_track) as Arc) - .await?; - tokio::spawn(async move { - let mut rtcp_buf = vec![0u8; 1500]; - while let Ok((_, _)) = video_rtp_sender.read(&mut rtcp_buf).await {} - Result::<()>::Ok(()) - }); - let notify1 = notify_connect.clone(); - tokio::spawn(async move { - let mut buf = vec![]; - let mut file = File::open(&video_file).unwrap(); - let _ = file.read_to_end(&mut buf); - let mut data = BytesMut::from_iter(buf); - - let list = memchr::memmem::find_iter(&data, &ANNEXB_NALUSTART_CODE); - let mut data_list = vec![]; - let mut idxs = list.into_iter().collect::>(); - idxs.reverse(); - for i in idxs { - let nal_data = data.split_off(i); - // let payload_header = H265NALUHeader::new(nal_data[4], nal_data[5]); - // let payload_nalu_type = payload_header.nalu_type(); - // let nalu_type = UnitType::for_id(payload_nalu_type).unwrap_or(UnitType::IGNORE); - data_list.insert(0, nal_data); - } - - let timeout = tokio::time::sleep(Duration::from_secs(10)); - tokio::pin!(timeout); - tokio::select! 
{ - _ = timeout.as_mut() => {return;} - _ = notify1.notified()=> {} - }; - println!("[Speaker] play video from disk file"); - let mut ticker = tokio::time::interval(Duration::from_millis(33)); - loop { - if data_list.is_empty() { - break; - } - let nal_data = data_list.remove(0); - let payload_header = H265NALUHeader::new(nal_data[4], nal_data[5]); - let payload_nalu_type = payload_header.nalu_type(); - let nalu_type = UnitType::for_id(payload_nalu_type).unwrap_or(UnitType::IGNORE); - if let Err(e) = local_video_track - .write_sample(&Sample { - data: nal_data.freeze(), - duration: Duration::from_secs(1), - ..Default::default() - }) - .await - { - println!("[Speaker] sending video err {e}"); - } - - if nalu_type != UnitType::VPS - || nalu_type != UnitType::SPS - || nalu_type != UnitType::PPS - || nalu_type != UnitType::SEI - { - let _ = ticker.tick().await; - } - } - let _ = video_done_tx.try_send(()); - }); - - let local_audio_track = Arc::new(TrackLocalStaticSample::new( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_OPUS.to_owned(), - ..Default::default() - }, - "audio".to_owned(), - "webrtc-rs".to_owned(), - )); - let audio_rtp_sender = peer_connection - .add_track(Arc::clone(&local_audio_track) as Arc) - .await?; - tokio::spawn(async move { - let mut rtcp_buf = vec![0u8; 1500]; - while let Ok((_, _)) = audio_rtp_sender.read(&mut rtcp_buf).await {} - Result::<()>::Ok(()) - }); - let notify1 = notify_connect.clone(); - tokio::spawn(async move { - // Open a IVF file and start reading using our IVFReader - let file = File::open(&audio_file)?; - let reader = BufReader::new(file); - // Open on oggfile in non-checksum mode. - let (mut ogg, _) = OggReader::new(reader, true)?; - // Wait for connection established - notify1.notified().await; - println!("[Speaker] play audio from disk file output.ogg"); - // It is important to use a time.Ticker instead of time.Sleep because - // * avoids accumulating skew, just calling time.Sleep didn't compensate for the time spent parsing the data - // * works around latency issues with Sleep - let mut ticker = tokio::time::interval(OGG_PAGE_DURATION); - // Keep track of last granule, the difference is the amount of samples in the buffer - let mut last_granule: u64 = 0; - while let Ok((page_data, page_header)) = ogg.parse_next_page() { - // The amount of samples is the difference between the last and current timestamp - let sample_count = page_header.granule_position - last_granule; - last_granule = page_header.granule_position; - let sample_duration = Duration::from_millis(sample_count * 1000 / 48000); - if let Err(e) = local_audio_track - .write_sample(&Sample { - data: page_data.freeze(), - duration: sample_duration, - ..Default::default() - }) - .await - { - println!("[Speaker] sending audio err {e}"); - } - let _ = ticker.tick().await; - } - let _ = audio_done_tx.try_send(()); - Result::<()>::Ok(()) - }); - - let notify1 = notify_connect.clone(); - peer_connection.on_ice_connection_state_change(Box::new( - move |connection_state: RTCIceConnectionState| { - println!("[Speaker] session state changed {connection_state}",); - if connection_state == RTCIceConnectionState::Connected { - notify1.notify_waiters(); - } - Box::pin(async {}) - }, - )); - // let pc = Arc::downgrade(&peer_connection); - // let mut candidates = Arc::new(Mutex::new(vec![])); - // let candidates1 = candidates.clone(); - // let notify_gather = Arc::new(Notify::new()); - // let notify1 = notify_gather.clone(); - // peer_connection.on_ice_candidate(Box::new(move |c: Option| { - // let pc2 
= pc.clone(); - // let pending_candidates3 = Arc::clone(&pending_candidates2); - // Box::pin(async move { - // if let Some(c) = c { - // candidates1.lock().await.push(c); - // } else { - // notify1.notify_waiters(); - // } - // }) - // })); - let offer = peer_connection.create_offer(None).await?; - // Create channel that is blocked until ICE Gathering is complete - let mut gather_complete = peer_connection.gathering_complete_promise().await; - // Sets the LocalDescription, and starts our UDP listeners - peer_connection.set_local_description(offer).await?; - let _ = gather_complete.recv().await; - - if let Some(sdp) = peer_connection.local_description().await { - let _ = offer_sdr.send(sdp).await; - } - println!("[Speaker] offer sent, waiting for answer"); - let answer = answer_rcv.recv().await.unwrap(); - peer_connection.set_remote_description(answer).await?; - println!("[Speaker] answer received, wait for quit event"); - - let timeout = tokio::time::sleep(Duration::from_secs(30)); - tokio::pin!(timeout); - tokio::select! { - _ = timeout.as_mut() => {} - _ = done_rx.recv() => {} - } - peer_connection.close().await?; - println!("[Speaker] closing peer"); - Ok(()) -} - -pub fn peer_closed(conn: &Arc) -> bool { - let state = conn.connection_state(); - state == RTCPeerConnectionState::Closed || state == RTCPeerConnectionState::Failed -} - -pub fn weak_peer_closed(conn: &Weak) -> bool { - let mut result = false; - if let Some(pc3) = conn.upgrade() { - if peer_closed(&pc3) { - result = true; - } - } else { - result = true - } - result -} - -// #[derive(Clone, Debug)] -// pub struct Nal { -// pub type_: UnitType, -// pub data: Vec, -// } - -// impl Nal { -// pub fn new(data: Vec) -> Result { -// Ok(Self { -// type_: Self::nal_unit_type(&data)?, -// data, -// }) -// } -// pub fn nal_unit_type(data: &[u8]) -> Result { -// UnitType::for_id((data[0] & 0b0111_1110) >> 1) -// } -// } diff --git a/examples/examples/play-from-disk-renegotiation/README.md b/examples/examples/play-from-disk-renegotiation/README.md deleted file mode 100644 index c18fb3d48..000000000 --- a/examples/examples/play-from-disk-renegotiation/README.md +++ /dev/null @@ -1,30 +0,0 @@ -# play-from-disk-renegotiation - -play-from-disk-renegotiation demonstrates WebRTC.rs's renegotiation abilities. - -For a simpler example of playing a file from disk we also have [examples/play-from-disk](/examples/play-from-disk) - -## Instructions - -### Build play-from-disk-renegotiation - -```shell -cargo build --example play-from-disk-renegotiation -``` - -### Create IVF named `output.ivf` that contains a VP8 track - -```shell -ffmpeg -i $INPUT_FILE -g 30 output.ivf -``` - -### Run play-from-disk-renegotiation - -The `output.ivf` you created should be in the same directory as `play-from-disk-renegotiation`. - - -### Open the Web UI - -Open [http://localhost:8080](http://localhost:8080) and you should have a `Add Track` and `Remove Track` button. Press these to add as many tracks as you want, or to remove as many as you wish. - -Congrats, you have used WebRTC.rs! diff --git a/examples/examples/play-from-disk-renegotiation/index.html b/examples/examples/play-from-disk-renegotiation/index.html deleted file mode 100644 index 77130feaf..000000000 --- a/examples/examples/play-from-disk-renegotiation/index.html +++ /dev/null @@ -1,80 +0,0 @@ - - - play-from-disk-renegotiation - - - - -
[deleted index.html body not recoverable: extraction stripped the markup, leaving only the "Video" and "Logs" section labels from the page that hosts the Add Track / Remove Track controls]
- - - - diff --git a/examples/examples/play-from-disk-renegotiation/play-from-disk-renegotiation.rs b/examples/examples/play-from-disk-renegotiation/play-from-disk-renegotiation.rs deleted file mode 100644 index 503cd863d..000000000 --- a/examples/examples/play-from-disk-renegotiation/play-from-disk-renegotiation.rs +++ /dev/null @@ -1,410 +0,0 @@ -use std::fs::File; -use std::io::{BufReader, Write}; -use std::net::SocketAddr; -use std::path::Path; -use std::str::FromStr; -use std::sync::Arc; - -use anyhow::Result; -use clap::{AppSettings, Arg, Command}; -use hyper::service::{make_service_fn, service_fn}; -use hyper::{Body, Method, Request, Response, Server, StatusCode}; -use tokio::sync::Mutex; -use tokio::time::Duration; -use tokio_util::codec::{BytesCodec, FramedRead}; -use webrtc::api::interceptor_registry::register_default_interceptors; -use webrtc::api::media_engine::{MediaEngine, MIME_TYPE_VP8}; -use webrtc::api::APIBuilder; -use webrtc::ice_transport::ice_server::RTCIceServer; -use webrtc::interceptor::registry::Registry; -use webrtc::media::io::ivf_reader::IVFReader; -use webrtc::media::Sample; -use webrtc::peer_connection::configuration::RTCConfiguration; -use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; -use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; -use webrtc::peer_connection::RTCPeerConnection; -use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability; -use webrtc::track::track_local::track_local_static_sample::TrackLocalStaticSample; -use webrtc::track::track_local::TrackLocal; -use webrtc::Error; - -#[macro_use] -extern crate lazy_static; - -lazy_static! { - static ref PEER_CONNECTION_MUTEX: Arc>>> = - Arc::new(Mutex::new(None)); - static ref VIDEO_FILE: Arc>> = Arc::new(Mutex::new(None)); -} - -static INDEX: &str = "examples/examples/play-from-disk-renegotiation/index.html"; -static NOTFOUND: &[u8] = b"Not Found"; - -/// HTTP status code 404 -fn not_found() -> Response { - Response::builder() - .status(StatusCode::NOT_FOUND) - .body(NOTFOUND.into()) - .unwrap() -} - -async fn simple_file_send(filename: &str) -> Result, hyper::Error> { - // Serve a file by asynchronously reading it by chunks using tokio-util crate. - - if let Ok(file) = tokio::fs::File::open(filename).await { - let stream = FramedRead::new(file, BytesCodec::new()); - let body = Body::wrap_stream(stream); - return Ok(Response::new(body)); - } - - Ok(not_found()) -} - -// HTTP Listener to get ICE Credentials/Candidate from remote Peer -async fn remote_handler(req: Request) -> Result, hyper::Error> { - let pc = { - let pcm = PEER_CONNECTION_MUTEX.lock().await; - pcm.clone().unwrap() - }; - - match (req.method(), req.uri().path()) { - (&Method::GET, "/") | (&Method::GET, "/index.html") => simple_file_send(INDEX).await, - - (&Method::POST, "/createPeerConnection") => create_peer_connection(&pc, req).await, - - (&Method::POST, "/addVideo") => add_video(&pc, req).await, - - (&Method::POST, "/removeVideo") => remove_video(&pc, req).await, - - // Return the 404 Not Found for other routes. - _ => { - let mut not_found = Response::default(); - *not_found.status_mut() = StatusCode::NOT_FOUND; - Ok(not_found) - } - } -} - -// do_signaling exchanges all state of the local PeerConnection and is called -// every time a video is added or removed -async fn do_signaling( - pc: &Arc, - req: Request, -) -> Result, hyper::Error> { - let sdp_str = match std::str::from_utf8(&hyper::body::to_bytes(req.into_body()).await?) 
{ - Ok(s) => s.to_owned(), - Err(err) => panic!("{}", err), - }; - let offer = match serde_json::from_str::(&sdp_str) { - Ok(s) => s, - Err(err) => panic!("{}", err), - }; - - if let Err(err) = pc.set_remote_description(offer).await { - panic!("{}", err); - } - - // Create channel that is blocked until ICE Gathering is complete - let mut gather_complete = pc.gathering_complete_promise().await; - - // Create an answer - let answer = match pc.create_answer(None).await { - Ok(answer) => answer, - Err(err) => panic!("{}", err), - }; - - // Sets the LocalDescription, and starts our UDP listeners - if let Err(err) = pc.set_local_description(answer).await { - panic!("{}", err); - } - - // Block until ICE Gathering is complete, disabling trickle ICE - // we do this because we only can exchange one signaling message - // in a production application you should exchange ICE Candidates via OnICECandidate - let _ = gather_complete.recv().await; - - let payload = if let Some(local_desc) = pc.local_description().await { - match serde_json::to_string(&local_desc) { - Ok(p) => p, - Err(err) => panic!("{}", err), - } - } else { - panic!("generate local_description failed!"); - }; - - let mut response = match Response::builder() - .header("content-type", "application/json") - .body(Body::from(payload)) - { - Ok(res) => res, - Err(err) => panic!("{}", err), - }; - - *response.status_mut() = StatusCode::OK; - Ok(response) -} - -// Add a single video track -async fn create_peer_connection( - pc: &Arc, - r: Request, -) -> Result, hyper::Error> { - if pc.connection_state() != RTCPeerConnectionState::New { - panic!( - "create_peer_connection called in non-new state ({})", - pc.connection_state() - ); - } - - println!("PeerConnection has been created"); - do_signaling(pc, r).await -} - -// Add a single video track -async fn add_video( - pc: &Arc, - r: Request, -) -> Result, hyper::Error> { - let video_track = Arc::new(TrackLocalStaticSample::new( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - format!("video-{}", rand::random::()), - format!("video-{}", rand::random::()), - )); - - let rtp_sender = match pc - .add_track(Arc::clone(&video_track) as Arc) - .await - { - Ok(rtp_sender) => rtp_sender, - Err(err) => panic!("{}", err), - }; - - // Read incoming RTCP packets - // Before these packets are returned they are processed by interceptors. For things - // like NACK this needs to be called. 
- tokio::spawn(async move { - let mut rtcp_buf = vec![0u8; 1500]; - while let Ok((_, _)) = rtp_sender.read(&mut rtcp_buf).await {} - Result::<()>::Ok(()) - }); - - let video_file = { - let vf = VIDEO_FILE.lock().await; - vf.clone() - }; - - if let Some(video_file) = video_file { - tokio::spawn(async move { - let _ = write_video_to_track(video_file, video_track).await; - }); - } - - println!("Video track has been added"); - do_signaling(pc, r).await -} - -// Remove a single sender -async fn remove_video( - pc: &Arc, - r: Request, -) -> Result, hyper::Error> { - let senders = pc.get_senders().await; - if !senders.is_empty() { - if let Err(err) = pc.remove_track(&senders[0]).await { - panic!("{}", err); - } - } - - println!("Video track has been removed"); - do_signaling(pc, r).await -} - -#[tokio::main] -async fn main() -> Result<()> { - let mut app = Command::new("play-from-disk-renegotiation") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of play-from-disk-renegotiation.") - .setting(AppSettings::DeriveDisplayOrder) - .subcommand_negates_reqs(true) - .arg( - Arg::new("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::new("debug") - .long("debug") - .short('d') - .help("Prints debug log information"), - ) - .arg( - Arg::new("video") - .required_unless_present("FULLHELP") - .takes_value(true) - .short('v') - .long("video") - .help("Video file to be streaming."), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let debug = matches.is_present("debug"); - if debug { - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init(); - } - - let video_file = matches.value_of("video"); - - if let Some(video_file) = video_file { - if !Path::new(video_file).exists() { - return Err(Error::new(format!("video file: '{video_file}' not exist")).into()); - } - let mut vf = VIDEO_FILE.lock().await; - *vf = Some(video_file.to_owned()); - } - - // Everything below is the WebRTC-rs API! Thanks for using it โค๏ธ. - - // Create a MediaEngine object to configure the supported codec - let mut m = MediaEngine::default(); - - m.register_default_codecs()?; - - // Create a InterceptorRegistry. This is the user configurable RTP/RTCP Pipeline. - // This provides NACKs, RTCP Reports and other features. If you use `webrtc.NewPeerConnection` - // this is enabled by default. If you are manually managing You MUST create a InterceptorRegistry - // for each PeerConnection. 
- let mut registry = Registry::new(); - - // Use the default set of Interceptors - registry = register_default_interceptors(registry, &mut m)?; - - // Create the API object with the MediaEngine - let api = APIBuilder::new() - .with_media_engine(m) - .with_interceptor_registry(registry) - .build(); - - // Prepare the configuration - let config = RTCConfiguration { - ice_servers: vec![RTCIceServer { - urls: vec!["stun:stun.l.google.com:19302".to_owned()], - ..Default::default() - }], - ..Default::default() - }; - - // Create a new RTCPeerConnection - let peer_connection = Arc::new(api.new_peer_connection(config).await?); - - let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - - if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. - println!("Peer Connection has gone to failed exiting"); - let _ = done_tx.try_send(()); - } - - Box::pin(async {}) - })); - - { - let mut pcm = PEER_CONNECTION_MUTEX.lock().await; - *pcm = Some(Arc::clone(&peer_connection)); - } - - tokio::spawn(async move { - println!("Open http://localhost:8080 to access this demo"); - - let addr = SocketAddr::from_str("0.0.0.0:8080").unwrap(); - let service = - make_service_fn(|_| async { Ok::<_, hyper::Error>(service_fn(remote_handler)) }); - let server = Server::bind(&addr).serve(service); - // Run this server for... forever! - if let Err(e) = server.await { - eprintln!("server error: {e}"); - } - }); - - println!("Press ctrl-c to stop"); - tokio::select! { - _ = done_rx.recv() => { - println!("received done signal!"); - } - _ = tokio::signal::ctrl_c() => { - println!(); - } - }; - - peer_connection.close().await?; - - Ok(()) -} - -// Read a video file from disk and write it to a webrtc.Track -// When the video has been completely read this exits without error -async fn write_video_to_track(video_file: String, t: Arc) -> Result<()> { - println!("play video from disk file {video_file}"); - - // Open a IVF file and start reading using our IVFReader - let file = File::open(video_file)?; - let reader = BufReader::new(file); - let (mut ivf, header) = IVFReader::new(reader)?; - - // It is important to use a time.Ticker instead of time.Sleep because - // * avoids accumulating skew, just calling time.Sleep didn't compensate for the time spent parsing the data - // * works around latency issues with Sleep - // Send our video file frame at a time. Pace our sending so we send it at the same speed it should be played back as. - // This isn't required since the video is timestamped, but we will such much higher loss if we send all at once. 
- let sleep_time = Duration::from_millis( - ((1000 * header.timebase_numerator) / header.timebase_denominator) as u64, - ); - let mut ticker = tokio::time::interval(sleep_time); - loop { - let frame = match ivf.parse_next_frame() { - Ok((frame, _)) => frame, - Err(err) => { - println!("All video frames parsed and sent: {err}"); - return Err(err.into()); - } - }; - - t.write_sample(&Sample { - data: frame.freeze(), - duration: Duration::from_secs(1), - ..Default::default() - }) - .await?; - - let _ = ticker.tick().await; - } -} diff --git a/examples/examples/play-from-disk-vpx/README.md b/examples/examples/play-from-disk-vpx/README.md deleted file mode 100644 index 05bc968f8..000000000 --- a/examples/examples/play-from-disk-vpx/README.md +++ /dev/null @@ -1,48 +0,0 @@ -# play-from-disk-vpx - -play-from-disk-vpx demonstrates how to send vp8/vp8 video and/or audio to your browser from files saved to disk. - -## Instructions - -### Create IVF named `output_vp8.ivf` or `output_vp9.ivf` that contains a VP8/VP9 track and/or `output.ogg` that contains a Opus track - -```shell -ffmpeg -i $INPUT_FILE -g 30 output_vp8.ivf -ffmpeg -i $INPUT_FILE -g 30 -c libvpx-vp9 output_vp9.ivf -ffmpeg -i $INPUT_FILE -map 0:a -c:a dca -ac 2 -c:a libopus -page_duration 20000 -vn output.ogg -``` - -### Build play-from-disk-vpx - -```shell -cargo build --example play-from-disk-vpx -``` - -### Open play-from-disk-vpx example page - -[jsfiddle.net](https://jsfiddle.net/9s10amwL/) you should see two text-areas and a 'Start Session' button - -### Run play-from-disk-vpx with your browsers SessionDescription as stdin - -The `output_vp8.ivf`/`output_vp9.ivf` you created should be in the same directory as `play-from-disk-vpx`. In the jsfiddle the top textarea is your browser, copy that and: - -#### Linux/macOS - -1. Run `echo $BROWSER_SDP | ./target/debug/examples/play-from-disk-vpx -v examples/test-data/output_vp8.ivf -a examples/test-data/output.ogg` -2. Run `echo $BROWSER_SDP | ./target/debug/examples/play-from-disk-vpx -v examples/test-data/output_vp9.ivf -a examples/test-data/output.ogg --vp9` - -#### Windows - -1. Paste the SessionDescription into a file. -2. Run `./target/debug/examples/play-from-disk-vpx -v examples/test-data/output_vp8.ivf -a examples/test-data/output.ogg < my_file` -3. Run `./target/debug/examples/play-from-disk-vpx -v examples/test-data/output_vp9.ivf -a examples/test-data/output.ogg --vp9 < my_file` - -### Input play-from-disk-vpx's SessionDescription into your browser - -Copy the text that `play-from-disk-vpx` just emitted and copy into second text area - -### Hit 'Start Session' in jsfiddle, enjoy your video! - -A video should start playing in your browser above the input boxes. `play-from-disk-vpx` will exit when the file reaches the end - -Congrats, you have used WebRTC.rs! 
diff --git a/examples/examples/play-from-disk-vpx/play-from-disk-vpx.rs b/examples/examples/play-from-disk-vpx/play-from-disk-vpx.rs deleted file mode 100644 index d3a578dc8..000000000 --- a/examples/examples/play-from-disk-vpx/play-from-disk-vpx.rs +++ /dev/null @@ -1,372 +0,0 @@ -use std::fs::File; -use std::io::{BufReader, Write}; -use std::path::Path; -use std::sync::Arc; - -use anyhow::Result; -use clap::{AppSettings, Arg, Command}; -use tokio::sync::Notify; -use tokio::time::Duration; -use webrtc::api::interceptor_registry::register_default_interceptors; -use webrtc::api::media_engine::{MediaEngine, MIME_TYPE_OPUS, MIME_TYPE_VP8, MIME_TYPE_VP9}; -use webrtc::api::APIBuilder; -use webrtc::ice_transport::ice_connection_state::RTCIceConnectionState; -use webrtc::ice_transport::ice_server::RTCIceServer; -use webrtc::interceptor::registry::Registry; -use webrtc::media::io::ivf_reader::IVFReader; -use webrtc::media::io::ogg_reader::OggReader; -use webrtc::media::Sample; -use webrtc::peer_connection::configuration::RTCConfiguration; -use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; -use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; -use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability; -use webrtc::track::track_local::track_local_static_sample::TrackLocalStaticSample; -use webrtc::track::track_local::TrackLocal; -use webrtc::Error; - -const OGG_PAGE_DURATION: Duration = Duration::from_millis(20); - -#[tokio::main] -async fn main() -> Result<()> { - let mut app = Command::new("play-from-disk-vpx") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of play-from-disk-vpx.") - .setting(AppSettings::DeriveDisplayOrder) - .subcommand_negates_reqs(true) - .arg( - Arg::new("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::new("debug") - .long("debug") - .short('d') - .help("Prints debug log information"), - ) - .arg( - Arg::new("video") - .required_unless_present("FULLHELP") - .takes_value(true) - .short('v') - .long("video") - .help("Video file to be streaming."), - ) - .arg( - Arg::new("audio") - .takes_value(true) - .short('a') - .long("audio") - .help("Audio file to be streaming."), - ) - .arg( - Arg::new("vp9") - .long("vp9") - .help("Save VP9 to disk. Default: VP8"), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let debug = matches.is_present("debug"); - if debug { - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init(); - } - - let is_vp9 = matches.is_present("vp9"); - let video_file = matches.value_of("video"); - let audio_file = matches.value_of("audio"); - - if let Some(video_path) = &video_file { - if !Path::new(video_path).exists() { - return Err(Error::new(format!("video file: '{video_path}' not exist")).into()); - } - } - if let Some(audio_path) = &audio_file { - if !Path::new(audio_path).exists() { - return Err(Error::new(format!("audio file: '{audio_path}' not exist")).into()); - } - } - - // Everything below is the WebRTC-rs API! Thanks for using it โค๏ธ. 
- - // Create a MediaEngine object to configure the supported codec - let mut m = MediaEngine::default(); - - m.register_default_codecs()?; - - // Create a InterceptorRegistry. This is the user configurable RTP/RTCP Pipeline. - // This provides NACKs, RTCP Reports and other features. If you use `webrtc.NewPeerConnection` - // this is enabled by default. If you are manually managing You MUST create a InterceptorRegistry - // for each PeerConnection. - let mut registry = Registry::new(); - - // Use the default set of Interceptors - registry = register_default_interceptors(registry, &mut m)?; - - // Create the API object with the MediaEngine - let api = APIBuilder::new() - .with_media_engine(m) - .with_interceptor_registry(registry) - .build(); - - // Prepare the configuration - let config = RTCConfiguration { - ice_servers: vec![RTCIceServer { - urls: vec!["stun:stun.l.google.com:19302".to_owned()], - ..Default::default() - }], - ..Default::default() - }; - - // Create a new RTCPeerConnection - let peer_connection = Arc::new(api.new_peer_connection(config).await?); - - let notify_tx = Arc::new(Notify::new()); - let notify_video = notify_tx.clone(); - let notify_audio = notify_tx.clone(); - - let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - let video_done_tx = done_tx.clone(); - let audio_done_tx = done_tx.clone(); - - if let Some(video_file) = video_file { - // Create a video track - let video_track = Arc::new(TrackLocalStaticSample::new( - RTCRtpCodecCapability { - mime_type: if is_vp9 { - MIME_TYPE_VP9.to_owned() - } else { - MIME_TYPE_VP8.to_owned() - }, - ..Default::default() - }, - "video".to_owned(), - "webrtc-rs".to_owned(), - )); - - // Add this newly created track to the PeerConnection - let rtp_sender = peer_connection - .add_track(Arc::clone(&video_track) as Arc) - .await?; - - // Read incoming RTCP packets - // Before these packets are returned they are processed by interceptors. For things - // like NACK this needs to be called. - tokio::spawn(async move { - let mut rtcp_buf = vec![0u8; 1500]; - while let Ok((_, _)) = rtp_sender.read(&mut rtcp_buf).await {} - Result::<()>::Ok(()) - }); - - let video_file_name = video_file.to_owned(); - tokio::spawn(async move { - // Open a IVF file and start reading using our IVFReader - let file = File::open(&video_file_name)?; - let reader = BufReader::new(file); - let (mut ivf, header) = IVFReader::new(reader)?; - - // Wait for connection established - notify_video.notified().await; - - println!("play video from disk file {video_file_name}"); - - // It is important to use a time.Ticker instead of time.Sleep because - // * avoids accumulating skew, just calling time.Sleep didn't compensate for the time spent parsing the data - // * works around latency issues with Sleep - // Send our video file frame at a time. Pace our sending so we send it at the same speed it should be played back as. - // This isn't required since the video is timestamped, but we will such much higher loss if we send all at once. 
- let sleep_time = Duration::from_millis( - ((1000 * header.timebase_numerator) / header.timebase_denominator) as u64, - ); - let mut ticker = tokio::time::interval(sleep_time); - loop { - let frame = match ivf.parse_next_frame() { - Ok((frame, _)) => frame, - Err(err) => { - println!("All video frames parsed and sent: {err}"); - break; - } - }; - - video_track - .write_sample(&Sample { - data: frame.freeze(), - duration: Duration::from_secs(1), - ..Default::default() - }) - .await?; - - let _ = ticker.tick().await; - } - - let _ = video_done_tx.try_send(()); - - Result::<()>::Ok(()) - }); - } - - if let Some(audio_file) = audio_file { - // Create a audio track - let audio_track = Arc::new(TrackLocalStaticSample::new( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_OPUS.to_owned(), - ..Default::default() - }, - "audio".to_owned(), - "webrtc-rs".to_owned(), - )); - - // Add this newly created track to the PeerConnection - let rtp_sender = peer_connection - .add_track(Arc::clone(&audio_track) as Arc) - .await?; - - // Read incoming RTCP packets - // Before these packets are returned they are processed by interceptors. For things - // like NACK this needs to be called. - tokio::spawn(async move { - let mut rtcp_buf = vec![0u8; 1500]; - while let Ok((_, _)) = rtp_sender.read(&mut rtcp_buf).await {} - Result::<()>::Ok(()) - }); - - let audio_file_name = audio_file.to_owned(); - tokio::spawn(async move { - // Open a IVF file and start reading using our IVFReader - let file = File::open(audio_file_name)?; - let reader = BufReader::new(file); - // Open on oggfile in non-checksum mode. - let (mut ogg, _) = match OggReader::new(reader, true) { - Ok(tup) => tup, - Err(err) => { - println!("error while opening audio file output.ogg: {err}"); - return Err(err.into()); - } - }; - // Wait for connection established - notify_audio.notified().await; - - println!("play audio from disk file output.ogg"); - - // It is important to use a time.Ticker instead of time.Sleep because - // * avoids accumulating skew, just calling time.Sleep didn't compensate for the time spent parsing the data - // * works around latency issues with Sleep - let mut ticker = tokio::time::interval(OGG_PAGE_DURATION); - - // Keep track of last granule, the difference is the amount of samples in the buffer - let mut last_granule: u64 = 0; - while let Ok((page_data, page_header)) = ogg.parse_next_page() { - // The amount of samples is the difference between the last and current timestamp - let sample_count = page_header.granule_position - last_granule; - last_granule = page_header.granule_position; - let sample_duration = Duration::from_millis(sample_count * 1000 / 48000); - - audio_track - .write_sample(&Sample { - data: page_data.freeze(), - duration: sample_duration, - ..Default::default() - }) - .await?; - - let _ = ticker.tick().await; - } - - let _ = audio_done_tx.try_send(()); - - Result::<()>::Ok(()) - }); - } - - // Set the handler for ICE connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_ice_connection_state_change(Box::new( - move |connection_state: RTCIceConnectionState| { - println!("Connection State has changed {connection_state}"); - if connection_state == RTCIceConnectionState::Connected { - notify_tx.notify_waiters(); - } - Box::pin(async {}) - }, - )); - - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - 
println!("Peer Connection State has changed: {s}"); - - if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. - println!("Peer Connection has gone to failed exiting"); - let _ = done_tx.try_send(()); - } - - Box::pin(async {}) - })); - - // Wait for the offer to be pasted - let line = signal::must_read_stdin()?; - let desc_data = signal::decode(line.as_str())?; - let offer = serde_json::from_str::(&desc_data)?; - - // Set the remote SessionDescription - peer_connection.set_remote_description(offer).await?; - - // Create an answer - let answer = peer_connection.create_answer(None).await?; - - // Create channel that is blocked until ICE Gathering is complete - let mut gather_complete = peer_connection.gathering_complete_promise().await; - - // Sets the LocalDescription, and starts our UDP listeners - peer_connection.set_local_description(answer).await?; - - // Block until ICE Gathering is complete, disabling trickle ICE - // we do this because we only can exchange one signaling message - // in a production application you should exchange ICE Candidates via OnICECandidate - let _ = gather_complete.recv().await; - - // Output the answer in base64 so we can paste it in browser - if let Some(local_desc) = peer_connection.local_description().await { - let json_str = serde_json::to_string(&local_desc)?; - let b64 = signal::encode(&json_str); - println!("{b64}"); - } else { - println!("generate local_description failed!"); - } - - println!("Press ctrl-c to stop"); - tokio::select! { - _ = done_rx.recv() => { - println!("received done signal!"); - } - _ = tokio::signal::ctrl_c() => { - println!(); - } - }; - - peer_connection.close().await?; - - Ok(()) -} diff --git a/examples/examples/rc-cycle/rc-cycle.rs b/examples/examples/rc-cycle/rc-cycle.rs deleted file mode 100644 index 88ba15e3d..000000000 --- a/examples/examples/rc-cycle/rc-cycle.rs +++ /dev/null @@ -1,33 +0,0 @@ -use std::cell::RefCell; -use std::rc::Rc; - -#[derive(Clone)] -struct Cycle { - cell: RefCell>>, -} - -impl Drop for Cycle { - fn drop(&mut self) { - println!("freed"); - } -} - -#[tokio::main] -async fn main() { - let cycle = Rc::new(Cycle { - cell: RefCell::new(None), - }); - *cycle.cell.borrow_mut() = Some(cycle.clone()); -} - -// use nightly rust -// RUSTFLAGS="-Z sanitizer=leak" cargo build --example rc-cycle -// ./target/debug/example/rc-cycle -// ================================================================= -// ==1457719==ERROR: LeakSanitizer: detected memory leaks -// -// Direct leak of 32 byte(s) in 1 object(s) allocated from: -// #0 0x55d4688e1b58 in malloc /rustc/llvm/src/llvm-project/compiler-rt/lib/lsan/lsan_interceptors.cpp:56:3 -// #1 0x55d4689db6cb in alloc::alloc::alloc::h1ab42fe6949393de /rustc/e269e6bf47f40c9046cd44ab787881d700099252/library/alloc/src/alloc.rs:86:14 -// -// SUMMARY: LeakSanitizer: 32 byte(s) leaked in 1 allocation(s). diff --git a/examples/examples/reflect/README.md b/examples/examples/reflect/README.md deleted file mode 100644 index dea22a1e7..000000000 --- a/examples/examples/reflect/README.md +++ /dev/null @@ -1,38 +0,0 @@ -# reflect - -reflect demonstrates how with one PeerConnection you can send video to webrtc-rs and have the packets sent back. 
This example could be easily extended to do server side processing. - -## Instructions - -### Build reflect - -```shell -cargo build --example reflect -``` - -### Open reflect example page - -[jsfiddle.net](https://jsfiddle.net/9jgukzt1/) you should see two text-areas and a 'Start Session' button. - -### Run reflect, with your browsers SessionDescription as stdin - -In the jsfiddle the top textarea is your browser, copy that and: - -#### Linux/macOS - -Run `echo $BROWSER_SDP | ./target/debug/examples/reflect -a -v` - -#### Windows - -1. Paste the SessionDescription into a file. -1. Run `./target/debug/examples/reflect -a -v < my_file` - -### Input reflect's SessionDescription into your browser - -Copy the text that `reflect` just emitted and copy into second text area - -### Hit 'Start Session' in jsfiddle, enjoy your video! - -Your browser should send video to webrtc-rs, and then it will be relayed right back to you. - -Congrats, you have used WebRTC.rs! diff --git a/examples/examples/reflect/reflect.rs b/examples/examples/reflect/reflect.rs deleted file mode 100644 index c01875e42..000000000 --- a/examples/examples/reflect/reflect.rs +++ /dev/null @@ -1,326 +0,0 @@ -use std::collections::HashMap; -use std::io::Write; -use std::sync::Arc; - -use anyhow::Result; -use clap::{AppSettings, Arg, Command}; -use tokio::time::Duration; -use webrtc::api::interceptor_registry::register_default_interceptors; -use webrtc::api::media_engine::{MediaEngine, MIME_TYPE_OPUS, MIME_TYPE_VP8}; -use webrtc::api::APIBuilder; -use webrtc::ice_transport::ice_server::RTCIceServer; -use webrtc::interceptor::registry::Registry; -use webrtc::peer_connection::configuration::RTCConfiguration; -use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; -use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; -use webrtc::rtcp::payload_feedbacks::picture_loss_indication::PictureLossIndication; -use webrtc::rtp_transceiver::rtp_codec::{ - RTCRtpCodecCapability, RTCRtpCodecParameters, RTPCodecType, -}; -use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP; -use webrtc::track::track_local::{TrackLocal, TrackLocalWriter}; - -#[tokio::main] -async fn main() -> Result<()> { - let mut app = Command::new("reflect") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of how to send back to the user exactly what it receives using the same PeerConnection.") - .setting(AppSettings::DeriveDisplayOrder) - .subcommand_negates_reqs(true) - .arg( - Arg::new("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::new("debug") - .long("debug") - .short('d') - .help("Prints debug log information"), - ).arg( - Arg::new("audio") - .long("audio") - .short('a') - .help("Enable audio reflect"), - ).arg( - Arg::new("video") - .long("video") - .short('v') - .help("Enable video reflect"), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let audio = matches.is_present("audio"); - let video = matches.is_present("video"); - if !audio && !video { - println!("one of audio or video must be enabled"); - std::process::exit(0); - } - let debug = matches.is_present("debug"); - if debug { - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) 
- }) - .filter(None, log::LevelFilter::Trace) - .init(); - } - - // Everything below is the WebRTC-rs API! Thanks for using it โค๏ธ. - - // Create a MediaEngine object to configure the supported codec - let mut m = MediaEngine::default(); - - // Setup the codecs you want to use. - if audio { - m.register_codec( - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_OPUS.to_owned(), - ..Default::default() - }, - payload_type: 120, - ..Default::default() - }, - RTPCodecType::Audio, - )?; - } - - // We'll use a VP8 and Opus but you can also define your own - if video { - m.register_codec( - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: "".to_owned(), - rtcp_feedback: vec![], - }, - payload_type: 96, - ..Default::default() - }, - RTPCodecType::Video, - )?; - } - - // Create a InterceptorRegistry. This is the user configurable RTP/RTCP Pipeline. - // This provides NACKs, RTCP Reports and other features. If you use `webrtc.NewPeerConnection` - // this is enabled by default. If you are manually managing You MUST create a InterceptorRegistry - // for each PeerConnection. - let mut registry = Registry::new(); - - // Use the default set of Interceptors - registry = register_default_interceptors(registry, &mut m)?; - - // Create the API object with the MediaEngine - let api = APIBuilder::new() - .with_media_engine(m) - .with_interceptor_registry(registry) - .build(); - - // Prepare the configuration - let config = RTCConfiguration { - ice_servers: vec![RTCIceServer { - urls: vec!["stun:stun.l.google.com:19302".to_owned()], - ..Default::default() - }], - ..Default::default() - }; - - // Create a new RTCPeerConnection - let peer_connection = Arc::new(api.new_peer_connection(config).await?); - let mut output_tracks = HashMap::new(); - let mut media = vec![]; - if audio { - media.push("audio"); - } - if video { - media.push("video"); - }; - for s in media { - let output_track = Arc::new(TrackLocalStaticRTP::new( - RTCRtpCodecCapability { - mime_type: if s == "video" { - MIME_TYPE_VP8.to_owned() - } else { - MIME_TYPE_OPUS.to_owned() - }, - ..Default::default() - }, - format!("track-{s}"), - "webrtc-rs".to_owned(), - )); - - // Add this newly created track to the PeerConnection - let rtp_sender = peer_connection - .add_track(Arc::clone(&output_track) as Arc) - .await?; - - // Read incoming RTCP packets - // Before these packets are returned they are processed by interceptors. For things - // like NACK this needs to be called. 
- let m = s.to_owned(); - tokio::spawn(async move { - let mut rtcp_buf = vec![0u8; 1500]; - while let Ok((_, _)) = rtp_sender.read(&mut rtcp_buf).await {} - println!("{m} rtp_sender.read loop exit"); - Result::<()>::Ok(()) - }); - - output_tracks.insert(s.to_owned(), output_track); - } - - // Wait for the offer to be pasted - let line = signal::must_read_stdin()?; - let desc_data = signal::decode(line.as_str())?; - let offer = serde_json::from_str::(&desc_data)?; - - // Set the remote SessionDescription - peer_connection.set_remote_description(offer).await?; - - // Set a handler for when a new remote track starts, this handler copies inbound RTP packets, - // replaces the SSRC and sends them back - let pc = Arc::downgrade(&peer_connection); - peer_connection.on_track(Box::new(move |track, _, _| { - // Send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval - // This is a temporary fix until we implement incoming RTCP events, then we would push a PLI only when a viewer requests it - let media_ssrc = track.ssrc(); - - if track.kind() == RTPCodecType::Video { - let pc2 = pc.clone(); - tokio::spawn(async move { - let mut result = Result::::Ok(0); - while result.is_ok() { - let timeout = tokio::time::sleep(Duration::from_secs(3)); - tokio::pin!(timeout); - - tokio::select! { - _ = timeout.as_mut() =>{ - if let Some(pc) = pc2.upgrade(){ - result = pc.write_rtcp(&[Box::new(PictureLossIndication{ - sender_ssrc: 0, - media_ssrc, - })]).await.map_err(Into::into); - }else{ - break; - } - } - }; - } - }); - } - - let kind = if track.kind() == RTPCodecType::Audio { - "audio" - } else { - "video" - }; - let output_track = if let Some(output_track) = output_tracks.get(kind) { - Arc::clone(output_track) - } else { - println!("output_track not found for type = {kind}"); - return Box::pin(async {}); - }; - - let output_track2 = Arc::clone(&output_track); - tokio::spawn(async move { - println!( - "Track has started, of type {}: {}", - track.payload_type(), - track.codec().capability.mime_type - ); - // Read RTP packets being sent to webrtc-rs - while let Ok((rtp, _)) = track.read_rtp().await { - if let Err(err) = output_track2.write_rtp(&rtp).await { - println!("output track write_rtp got error: {err}"); - break; - } - } - - println!( - "on_track finished, of type {}: {}", - track.payload_type(), - track.codec().capability.mime_type - ); - }); - - Box::pin(async {}) - })); - - let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - - if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. 
- println!("Peer Connection has gone to failed exiting"); - let _ = done_tx.try_send(()); - } - - Box::pin(async {}) - })); - - // Create an answer - let answer = peer_connection.create_answer(None).await?; - - // Create channel that is blocked until ICE Gathering is complete - let mut gather_complete = peer_connection.gathering_complete_promise().await; - - // Sets the LocalDescription, and starts our UDP listeners - peer_connection.set_local_description(answer).await?; - - // Block until ICE Gathering is complete, disabling trickle ICE - // we do this because we only can exchange one signaling message - // in a production application you should exchange ICE Candidates via OnICECandidate - let _ = gather_complete.recv().await; - - // Output the answer in base64 so we can paste it in browser - if let Some(local_desc) = peer_connection.local_description().await { - let json_str = serde_json::to_string(&local_desc)?; - let b64 = signal::encode(&json_str); - println!("{b64}"); - } else { - println!("generate local_description failed!"); - } - - println!("Press ctrl-c to stop"); - //let timeout = tokio::time::sleep(Duration::from_secs(20)); - //tokio::pin!(timeout); - - tokio::select! { - //_ = timeout.as_mut() => { - // println!("received timeout signal!"); - //} - _ = done_rx.recv() => { - println!("received done signal!"); - } - _ = tokio::signal::ctrl_c() => { - println!(); - } - }; - - peer_connection.close().await?; - - Ok(()) -} diff --git a/examples/examples/rtp-forwarder/README.md b/examples/examples/rtp-forwarder/README.md deleted file mode 100644 index 04fb0bf45..000000000 --- a/examples/examples/rtp-forwarder/README.md +++ /dev/null @@ -1,52 +0,0 @@ -# rtp-forwarder - -rtp-forwarder is a simple application that shows how to forward your webcam/microphone via RTP using WebRTC.rs. - -## Instructions - -### Build rtp-forwarder - -```shell -cargo build --example rtp-forwarder -``` - -### Open rtp-forwarder example page - -[jsfiddle.net](https://jsfiddle.net/1qva2zd8/) you should see your Webcam, two text-areas and a 'Start Session' button - -### Run rtp-forwarder, with your browsers SessionDescription as stdin - -In the jsfiddle the top textarea is your browser, copy that and: - -#### Linux/macOS - -Run `echo $BROWSER_SDP | ./target/debug/examples/rtp-forwarder` - -#### Windows - -1. Paste the SessionDescription into a file. -1. Run `./target/debug/examples/rtp-forwarder < my_file` - -### Input rtp-forwarder's SessionDescription into your browser - -Copy the text that `rtp-forwarder` just emitted and copy into second text area - -### Hit 'Start Session' in jsfiddle and enjoy your RTP forwarded stream! - -You can run any of these commands at anytime. The media is live/stateless, you can switch commands without restarting Pion. - -#### VLC - -Open `rtp-forwarder.sdp` with VLC and enjoy your live video! - -#### ffmpeg/ffprobe - -Run `ffprobe -i rtp-forwarder.sdp -protocol_whitelist file,udp,rtp` to get more details about your streams - -Run `ffplay -i rtp-forwarder.sdp -protocol_whitelist file,udp,rtp` to play your streams - -You can add `-fflags nobuffer` to lower the latency. You will have worse playback in networks with jitter. - -#### Twitch/RTMP - -`ffmpeg -protocol_whitelist file,udp,rtp -i rtp-forwarder.sdp -c:v libx264 -preset veryfast -b:v 3000k -maxrate 3000k -bufsize 6000k -pix_fmt yuv420p -g 50 -c:a aac -b:a 160k -ac 2 -ar 44100 -f flv rtmp://live.twitch.tv/app/$STREAM_KEY` Make sure to replace `$STREAM_KEY` at the end of the URL first. 
diff --git a/examples/examples/rtp-forwarder/rtp-forwarder.rs b/examples/examples/rtp-forwarder/rtp-forwarder.rs deleted file mode 100644 index c6e548ae3..000000000 --- a/examples/examples/rtp-forwarder/rtp-forwarder.rs +++ /dev/null @@ -1,321 +0,0 @@ -use std::collections::HashMap; -use std::io::Write; -use std::sync::Arc; - -use anyhow::Result; -use clap::{AppSettings, Arg, Command}; -use tokio::net::UdpSocket; -use tokio::time::Duration; -use webrtc::api::interceptor_registry::register_default_interceptors; -use webrtc::api::media_engine::{MediaEngine, MIME_TYPE_OPUS, MIME_TYPE_VP8}; -use webrtc::api::APIBuilder; -use webrtc::ice_transport::ice_connection_state::RTCIceConnectionState; -use webrtc::ice_transport::ice_server::RTCIceServer; -use webrtc::interceptor::registry::Registry; -use webrtc::peer_connection::configuration::RTCConfiguration; -use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; -use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; -use webrtc::rtcp::payload_feedbacks::picture_loss_indication::PictureLossIndication; -use webrtc::rtp_transceiver::rtp_codec::{ - RTCRtpCodecCapability, RTCRtpCodecParameters, RTPCodecType, -}; -use webrtc::util::{Conn, Marshal}; - -#[derive(Clone)] -struct UdpConn { - conn: Arc, - payload_type: u8, -} - -#[tokio::main] -async fn main() -> Result<()> { - let mut app = Command::new("rtp-forwarder") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of rtp-forwarder.") - .setting(AppSettings::DeriveDisplayOrder) - .subcommand_negates_reqs(true) - .arg( - Arg::new("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::new("debug") - .long("debug") - .short('d') - .help("Prints debug log information"), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let debug = matches.is_present("debug"); - if debug { - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init(); - } - - // Everything below is the WebRTC-rs API! Thanks for using it โค๏ธ. - - // Create a MediaEngine object to configure the supported codec - let mut m = MediaEngine::default(); - - // Setup the codecs you want to use. - // We'll use a VP8 and Opus but you can also define your own - m.register_codec( - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: "".to_owned(), - rtcp_feedback: vec![], - }, - payload_type: 96, - ..Default::default() - }, - RTPCodecType::Video, - )?; - - m.register_codec( - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_OPUS.to_owned(), - clock_rate: 48000, - channels: 2, - sdp_fmtp_line: "".to_owned(), - rtcp_feedback: vec![], - }, - payload_type: 111, - ..Default::default() - }, - RTPCodecType::Audio, - )?; - - // Create a InterceptorRegistry. This is the user configurable RTP/RTCP Pipeline. - // This provides NACKs, RTCP Reports and other features. If you use `webrtc.NewPeerConnection` - // this is enabled by default. If you are manually managing You MUST create a InterceptorRegistry - // for each PeerConnection. 
- let mut registry = Registry::new(); - - // Use the default set of Interceptors - registry = register_default_interceptors(registry, &mut m)?; - - // Create the API object with the MediaEngine - let api = APIBuilder::new() - .with_media_engine(m) - .with_interceptor_registry(registry) - .build(); - - // Prepare the configuration - let config = RTCConfiguration { - ice_servers: vec![RTCIceServer { - urls: vec!["stun:stun.l.google.com:19302".to_owned()], - ..Default::default() - }], - ..Default::default() - }; - - // Create a new RTCPeerConnection - let peer_connection = Arc::new(api.new_peer_connection(config).await?); - - // Allow us to receive 1 audio track, and 1 video track - peer_connection - .add_transceiver_from_kind(RTPCodecType::Audio, None) - .await?; - peer_connection - .add_transceiver_from_kind(RTPCodecType::Video, None) - .await?; - - // Prepare udp conns - // Also update incoming packets with expected PayloadType, the browser may use - // a different value. We have to modify so our stream matches what rtp-forwarder.sdp expects - let mut udp_conns = HashMap::new(); - udp_conns.insert( - "audio".to_owned(), - UdpConn { - conn: { - let sock = UdpSocket::bind("127.0.0.1:0").await?; - sock.connect(format!("127.0.0.1:{}", 4000)).await?; - Arc::new(sock) - }, - payload_type: 111, - }, - ); - udp_conns.insert( - "video".to_owned(), - UdpConn { - conn: { - let sock = UdpSocket::bind("127.0.0.1:0").await?; - sock.connect(format!("127.0.0.1:{}", 4002)).await?; - Arc::new(sock) - }, - payload_type: 96, - }, - ); - - // Set a handler for when a new remote track starts, this handler will forward data to - // our UDP listeners. - // In your application this is where you would handle/process audio/video - let pc = Arc::downgrade(&peer_connection); - peer_connection.on_track(Box::new(move |track, _, _| { - // Retrieve udp connection - let c = if let Some(c) = udp_conns.get(&track.kind().to_string()) { - c.clone() - } else { - return Box::pin(async {}); - }; - - // Send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval - let media_ssrc = track.ssrc(); - let pc2 = pc.clone(); - tokio::spawn(async move { - let mut result = Result::::Ok(0); - while result.is_ok() { - let timeout = tokio::time::sleep(Duration::from_secs(3)); - tokio::pin!(timeout); - - tokio::select! { - _ = timeout.as_mut() =>{ - if let Some(pc) = pc2.upgrade(){ - result = pc.write_rtcp(&[Box::new(PictureLossIndication{ - sender_ssrc: 0, - media_ssrc, - })]).await.map_err(Into::into); - }else{ - break; - } - } - }; - } - }); - - tokio::spawn(async move { - let mut b = vec![0u8; 1500]; - while let Ok((mut rtp_packet, _)) = track.read(&mut b).await { - // Update the PayloadType - rtp_packet.header.payload_type = c.payload_type; - - // Marshal into original buffer with updated PayloadType - - let n = rtp_packet.marshal_to(&mut b)?; - - // Write - if let Err(err) = c.conn.send(&b[..n]).await { - // For this particular example, third party applications usually timeout after a short - // amount of time during which the user doesn't have enough time to provide the answer - // to the browser. - // That's why, for this particular example, the user first needs to provide the answer - // to the browser then open the third party application. 
Therefore we must not kill - // the forward on "connection refused" errors - //if opError, ok := err.(*net.OpError); ok && opError.Err.Error() == "write: connection refused" { - // continue - //} - //panic(err) - if err.to_string().contains("Connection refused") { - continue; - } else { - println!("conn send err: {err}"); - break; - } - } - } - - Result::<()>::Ok(()) - }); - - Box::pin(async {}) - })); - - // Set the handler for ICE connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_ice_connection_state_change(Box::new( - move |connection_state: RTCIceConnectionState| { - println!("Connection State has changed {connection_state}"); - if connection_state == RTCIceConnectionState::Connected { - println!("Ctrl+C the remote client to stop the demo"); - } - Box::pin(async {}) - }, - )); - - let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - - if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. - println!("Peer Connection has gone to failed exiting: Done forwarding"); - let _ = done_tx.try_send(()); - } - - Box::pin(async {}) - })); - - // Wait for the offer to be pasted - let line = signal::must_read_stdin()?; - let desc_data = signal::decode(line.as_str())?; - let offer = serde_json::from_str::(&desc_data)?; - - // Set the remote SessionDescription - peer_connection.set_remote_description(offer).await?; - - // Create an answer - let answer = peer_connection.create_answer(None).await?; - - // Create channel that is blocked until ICE Gathering is complete - let mut gather_complete = peer_connection.gathering_complete_promise().await; - - // Sets the LocalDescription, and starts our UDP listeners - peer_connection.set_local_description(answer).await?; - - // Block until ICE Gathering is complete, disabling trickle ICE - // we do this because we only can exchange one signaling message - // in a production application you should exchange ICE Candidates via OnICECandidate - let _ = gather_complete.recv().await; - - // Output the answer in base64 so we can paste it in browser - if let Some(local_desc) = peer_connection.local_description().await { - let json_str = serde_json::to_string(&local_desc)?; - let b64 = signal::encode(&json_str); - println!("{b64}"); - } else { - println!("generate local_description failed!"); - } - - println!("Press ctrl-c to stop"); - tokio::select! 
{ - _ = done_rx.recv() => { - println!("received done signal!"); - } - _ = tokio::signal::ctrl_c() => { - println!(); - } - }; - - peer_connection.close().await?; - - Ok(()) -} diff --git a/examples/examples/rtp-forwarder/rtp-forwarder.sdp b/examples/examples/rtp-forwarder/rtp-forwarder.sdp deleted file mode 100644 index bb0367c11..000000000 --- a/examples/examples/rtp-forwarder/rtp-forwarder.sdp +++ /dev/null @@ -1,9 +0,0 @@ -v=0 -o=- 0 0 IN IP4 127.0.0.1 -s=WebRTC.rs -c=IN IP4 127.0.0.1 -t=0 0 -m=audio 4000 RTP/AVP 111 -a=rtpmap:111 OPUS/48000/2 -m=video 4002 RTP/AVP 96 -a=rtpmap:96 VP8/90000 \ No newline at end of file diff --git a/examples/examples/rtp-to-webrtc/README.md b/examples/examples/rtp-to-webrtc/README.md deleted file mode 100644 index 2ccb2cb7f..000000000 --- a/examples/examples/rtp-to-webrtc/README.md +++ /dev/null @@ -1,56 +0,0 @@ -# rtp-to-webrtc - -rtp-to-webrtc demonstrates how to consume a RTP stream video UDP, and then send to a WebRTC client. - -With this example we have pre-made GStreamer and ffmpeg pipelines, but you can use any tool you like! - -## Instructions - -### Build rtp-to-webrtc - -```shell -cargo build --example rtp-to-webrtc -``` - -### Open jsfiddle example page - -[jsfiddle.net](https://jsfiddle.net/z7ms3u5r/) you should see two text-areas and a 'Start Session' button - -### Run rtp-to-webrtc with your browsers SessionDescription as stdin - -In the jsfiddle the top textarea is your browser's SessionDescription, copy that and: - -#### Linux/macOS - -Run `echo $BROWSER_SDP | ./target/debug/examples/rtp-to-webrtc` - -#### Windows - -1. Paste the SessionDescription into a file. -1. Run `./target/debug/examples/rtp-to-webrtc < my_file` - -### Send RTP to listening socket - -You can use any software to send VP8 packets to port 5004. We also have the pre made examples below - -#### GStreamer - -```shell -gst-launch-1.0 videotestsrc ! video/x-raw,width=640,height=480,format=I420 ! vp8enc error-resilient=partitions keyframe-max-dist=10 auto-alt-ref=true cpu-used=5 deadline=1 ! rtpvp8pay ! udpsink host=127.0.0.1 port=5004 -``` - -#### ffmpeg - -```shell -ffmpeg -re -f lavfi -i testsrc=size=640x480:rate=30 -vcodec libvpx -cpu-used 5 -deadline 1 -g 10 -error-resilient 1 -auto-alt-ref 1 -f rtp rtp://127.0.0.1:5004?pkt_size=1200 -``` - -### Input rtp-to-webrtc's SessionDescription into your browser - -Copy the text that `rtp-to-webrtc` just emitted and copy into second text area - -### Hit 'Start Session' in jsfiddle, enjoy your video! - -A video should start playing in your browser above the input boxes. - -Congrats, you have used WebRTC.rs! 
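The stdin/stdout signaling used throughout these examples is nothing more than a base64-encoded JSON `RTCSessionDescription`; the deleted `signal` helper crate further down wraps exactly this. A minimal sketch of the decode step, assuming the `base64` 0.21 and `serde_json` crates are available (this is an illustration, not the helper's exact code):

```rust
use base64::prelude::BASE64_STANDARD;
use base64::Engine;

// Sketch only: the real examples deserialize into
// `webrtc::peer_connection::sdp::session_description::RTCSessionDescription`;
// a generic `serde_json::Value` keeps this snippet dependency-light.
fn decode_session_description(b64: &str) -> anyhow::Result<serde_json::Value> {
    // Undo the base64 layer produced by the helper's `encode`.
    let json_bytes = BASE64_STANDARD.decode(b64.trim())?;
    let json_str = String::from_utf8(json_bytes)?;
    Ok(serde_json::from_str(&json_str)?)
}
```

This is why `echo $BROWSER_SDP | ./target/debug/examples/rtp-to-webrtc` works: the program only needs to read one base64 line from stdin and print one back.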
diff --git a/examples/examples/rtp-to-webrtc/rtp-to-webrtc.rs b/examples/examples/rtp-to-webrtc/rtp-to-webrtc.rs deleted file mode 100644 index ddcb0555f..000000000 --- a/examples/examples/rtp-to-webrtc/rtp-to-webrtc.rs +++ /dev/null @@ -1,220 +0,0 @@ -use std::io::Write; -use std::sync::Arc; - -use anyhow::Result; -use clap::{AppSettings, Arg, Command}; -use tokio::net::UdpSocket; -use webrtc::api::interceptor_registry::register_default_interceptors; -use webrtc::api::media_engine::{MediaEngine, MIME_TYPE_VP8}; -use webrtc::api::APIBuilder; -use webrtc::ice_transport::ice_connection_state::RTCIceConnectionState; -use webrtc::ice_transport::ice_server::RTCIceServer; -use webrtc::interceptor::registry::Registry; -use webrtc::peer_connection::configuration::RTCConfiguration; -use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; -use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; -use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability; -use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP; -use webrtc::track::track_local::{TrackLocal, TrackLocalWriter}; -use webrtc::Error; - -#[tokio::main] -async fn main() -> Result<()> { - let mut app = Command::new("rtp-forwarder") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of rtp-forwarder.") - .setting(AppSettings::DeriveDisplayOrder) - .subcommand_negates_reqs(true) - .arg( - Arg::new("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::new("debug") - .long("debug") - .short('d') - .help("Prints debug log information"), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let debug = matches.is_present("debug"); - if debug { - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init(); - } - - // Everything below is the WebRTC-rs API! Thanks for using it โค๏ธ. - - // Create a MediaEngine object to configure the supported codec - let mut m = MediaEngine::default(); - - m.register_default_codecs()?; - - // Create a InterceptorRegistry. This is the user configurable RTP/RTCP Pipeline. - // This provides NACKs, RTCP Reports and other features. If you use `webrtc.NewPeerConnection` - // this is enabled by default. If you are manually managing You MUST create a InterceptorRegistry - // for each PeerConnection. 
- let mut registry = Registry::new(); - - // Use the default set of Interceptors - registry = register_default_interceptors(registry, &mut m)?; - - // Create the API object with the MediaEngine - let api = APIBuilder::new() - .with_media_engine(m) - .with_interceptor_registry(registry) - .build(); - - // Prepare the configuration - let config = RTCConfiguration { - ice_servers: vec![RTCIceServer { - urls: vec!["stun:stun.l.google.com:19302".to_owned()], - ..Default::default() - }], - ..Default::default() - }; - - // Create a new RTCPeerConnection - let peer_connection = Arc::new(api.new_peer_connection(config).await?); - - // Create Track that we send video back to browser on - let video_track = Arc::new(TrackLocalStaticRTP::new( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - "video".to_owned(), - "webrtc-rs".to_owned(), - )); - - // Add this newly created track to the PeerConnection - let rtp_sender = peer_connection - .add_track(Arc::clone(&video_track) as Arc) - .await?; - - // Read incoming RTCP packets - // Before these packets are returned they are processed by interceptors. For things - // like NACK this needs to be called. - tokio::spawn(async move { - let mut rtcp_buf = vec![0u8; 1500]; - while let Ok((_, _)) = rtp_sender.read(&mut rtcp_buf).await {} - Result::<()>::Ok(()) - }); - - let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - - let done_tx1 = done_tx.clone(); - // Set the handler for ICE connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_ice_connection_state_change(Box::new( - move |connection_state: RTCIceConnectionState| { - println!("Connection State has changed {connection_state}"); - if connection_state == RTCIceConnectionState::Failed { - let _ = done_tx1.try_send(()); - } - Box::pin(async {}) - }, - )); - - let done_tx2 = done_tx.clone(); - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - - if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. 
- println!("Peer Connection has gone to failed exiting: Done forwarding"); - let _ = done_tx2.try_send(()); - } - - Box::pin(async {}) - })); - - // Wait for the offer to be pasted - let line = signal::must_read_stdin()?; - let desc_data = signal::decode(line.as_str())?; - let offer = serde_json::from_str::(&desc_data)?; - - // Set the remote SessionDescription - peer_connection.set_remote_description(offer).await?; - - // Create an answer - let answer = peer_connection.create_answer(None).await?; - - // Create channel that is blocked until ICE Gathering is complete - let mut gather_complete = peer_connection.gathering_complete_promise().await; - - // Sets the LocalDescription, and starts our UDP listeners - peer_connection.set_local_description(answer).await?; - - // Block until ICE Gathering is complete, disabling trickle ICE - // we do this because we only can exchange one signaling message - // in a production application you should exchange ICE Candidates via OnICECandidate - let _ = gather_complete.recv().await; - - // Output the answer in base64 so we can paste it in browser - if let Some(local_desc) = peer_connection.local_description().await { - let json_str = serde_json::to_string(&local_desc)?; - let b64 = signal::encode(&json_str); - println!("{b64}"); - } else { - println!("generate local_description failed!"); - } - - // Open a UDP Listener for RTP Packets on port 5004 - let listener = UdpSocket::bind("127.0.0.1:5004").await?; - - let done_tx3 = done_tx.clone(); - // Read RTP packets forever and send them to the WebRTC Client - tokio::spawn(async move { - let mut inbound_rtp_packet = vec![0u8; 1600]; // UDP MTU - while let Ok((n, _)) = listener.recv_from(&mut inbound_rtp_packet).await { - if let Err(err) = video_track.write(&inbound_rtp_packet[..n]).await { - if Error::ErrClosedPipe == err { - // The peerConnection has been closed. - } else { - println!("video_track write err: {err}"); - } - let _ = done_tx3.try_send(()); - return; - } - } - }); - - println!("Press ctrl-c to stop"); - tokio::select! { - _ = done_rx.recv() => { - println!("received done signal!"); - } - _ = tokio::signal::ctrl_c() => { - println!(); - } - }; - - peer_connection.close().await?; - - Ok(()) -} diff --git a/examples/examples/save-to-disk-h264/README.md b/examples/examples/save-to-disk-h264/README.md deleted file mode 100644 index f18102cc1..000000000 --- a/examples/examples/save-to-disk-h264/README.md +++ /dev/null @@ -1,38 +0,0 @@ -# save-to-disk-h264 - -save-to-disk-h264 is a simple application that shows how to record your webcam/microphone using WebRTC.rs and save H264 and Opus to disk. - -## Instructions - -### Build save-to-disk-h264 - -```shell -cargo build --example save-to-disk-h264 -``` - -### Open save-to-disk example page - -[jsfiddle.net](https://jsfiddle.net/vfmcg8rk/1/) you should see your Webcam, two text-areas and a 'Start Session' button - -### Run save-to-disk-h264, with your browsers SessionDescription as stdin - -In the jsfiddle the top textarea is your browser, copy that and: - -#### Linux/macOS - -Run `echo $BROWSER_SDP | ./target/debug/examples/save-to-disk-h264` - -#### Windows - -1. Paste the SessionDescription into a file. -1. Run `./target/debug/examples/save-to-disk-h264 < my_file` - -### Input save-to-disk-h264's SessionDescription into your browser - -Copy the text that `save-to-disk-h264` just emitted and copy into second text area - -### Hit 'Start Session' in jsfiddle, wait, close jsfiddle, enjoy your video! 
- -In the folder you ran `save-to-disk-h264` you should now have a file `output.h264` play with your video player of choice! - -Congrats, you have used WebRTC.rs! diff --git a/examples/examples/save-to-disk-h264/save-to-disk-h264.rs b/examples/examples/save-to-disk-h264/save-to-disk-h264.rs deleted file mode 100644 index 238320b37..000000000 --- a/examples/examples/save-to-disk-h264/save-to-disk-h264.rs +++ /dev/null @@ -1,318 +0,0 @@ -use std::fs::File; -use std::io::Write; -use std::sync::Arc; - -use anyhow::Result; -use clap::{AppSettings, Arg, Command}; -use tokio::sync::{Mutex, Notify}; -use tokio::time::Duration; -use webrtc::api::interceptor_registry::register_default_interceptors; -use webrtc::api::media_engine::{MediaEngine, MIME_TYPE_H264, MIME_TYPE_OPUS}; -use webrtc::api::APIBuilder; -use webrtc::ice_transport::ice_connection_state::RTCIceConnectionState; -use webrtc::ice_transport::ice_server::RTCIceServer; -use webrtc::interceptor::registry::Registry; -use webrtc::media::io::h264_writer::H264Writer; -use webrtc::media::io::ogg_writer::OggWriter; -use webrtc::peer_connection::configuration::RTCConfiguration; -use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; -use webrtc::rtcp::payload_feedbacks::picture_loss_indication::PictureLossIndication; -use webrtc::rtp_transceiver::rtp_codec::{ - RTCRtpCodecCapability, RTCRtpCodecParameters, RTPCodecType, -}; -use webrtc::track::track_remote::TrackRemote; - -async fn save_to_disk( - writer: Arc>, - track: Arc, - notify: Arc, -) -> Result<()> { - loop { - tokio::select! { - result = track.read_rtp() => { - if let Ok((rtp_packet, _)) = result { - let mut w = writer.lock().await; - w.write_rtp(&rtp_packet)?; - }else{ - println!("file closing begin after read_rtp error"); - let mut w = writer.lock().await; - if let Err(err) = w.close() { - println!("file close err: {err}"); - } - println!("file closing end after read_rtp error"); - return Ok(()); - } - } - _ = notify.notified() => { - println!("file closing begin after notified"); - let mut w = writer.lock().await; - if let Err(err) = w.close() { - println!("file close err: {err}"); - } - println!("file closing end after notified"); - return Ok(()); - } - } - } -} - -#[tokio::main] -async fn main() -> Result<()> { - let mut app = Command::new("save-to-disk-h264") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of save-to-disk-h264.") - .setting(AppSettings::DeriveDisplayOrder) - .subcommand_negates_reqs(true) - .arg( - Arg::new("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::new("debug") - .long("debug") - .short('d') - .help("Prints debug log information"), - ) - .arg( - Arg::new("video") - .required_unless_present("FULLHELP") - .takes_value(true) - .short('v') - .long("video") - .help("Video file to be streaming."), - ) - .arg( - Arg::new("audio") - .required_unless_present("FULLHELP") - .takes_value(true) - .short('a') - .long("audio") - .help("Audio file to be streaming."), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let debug = matches.is_present("debug"); - if debug { - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init(); - } - - 
let video_file = matches.value_of("video").unwrap(); - let audio_file = matches.value_of("audio").unwrap(); - - let h264_writer: Arc> = - Arc::new(Mutex::new(H264Writer::new(File::create(video_file)?))); - let ogg_writer: Arc> = Arc::new(Mutex::new( - OggWriter::new(File::create(audio_file)?, 48000, 2)?, - )); - - // Everything below is the WebRTC-rs API! Thanks for using it โค๏ธ. - - // Create a MediaEngine object to configure the supported codec - let mut m = MediaEngine::default(); - - // Setup the codecs you want to use. - // We'll use a H264 and Opus but you can also define your own - m.register_codec( - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_H264.to_owned(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: "".to_owned(), - rtcp_feedback: vec![], - }, - payload_type: 102, - ..Default::default() - }, - RTPCodecType::Video, - )?; - - m.register_codec( - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_OPUS.to_owned(), - clock_rate: 48000, - channels: 2, - sdp_fmtp_line: "".to_owned(), - rtcp_feedback: vec![], - }, - payload_type: 111, - ..Default::default() - }, - RTPCodecType::Audio, - )?; - - // Create a InterceptorRegistry. This is the user configurable RTP/RTCP Pipeline. - // This provides NACKs, RTCP Reports and other features. If you use `webrtc.NewPeerConnection` - // this is enabled by default. If you are manually managing You MUST create a InterceptorRegistry - // for each PeerConnection. - let mut registry = Registry::new(); - - // Use the default set of Interceptors - registry = register_default_interceptors(registry, &mut m)?; - - // Create the API object with the MediaEngine - let api = APIBuilder::new() - .with_media_engine(m) - .with_interceptor_registry(registry) - .build(); - - // Prepare the configuration - let config = RTCConfiguration { - ice_servers: vec![RTCIceServer { - urls: vec!["stun:stun.l.google.com:19302".to_owned()], - ..Default::default() - }], - ..Default::default() - }; - - // Create a new RTCPeerConnection - let peer_connection = Arc::new(api.new_peer_connection(config).await?); - - // Allow us to receive 1 audio track, and 1 video track - peer_connection - .add_transceiver_from_kind(RTPCodecType::Audio, None) - .await?; - peer_connection - .add_transceiver_from_kind(RTPCodecType::Video, None) - .await?; - - let notify_tx = Arc::new(Notify::new()); - let notify_rx = notify_tx.clone(); - - // Set a handler for when a new remote track starts, this handler saves buffers to disk as - // an ivf file, since we could have multiple video tracks we provide a counter. - // In your application this is where you would handle/process video - let pc = Arc::downgrade(&peer_connection); - peer_connection.on_track(Box::new(move |track, _, _| { - // Send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval - let media_ssrc = track.ssrc(); - let pc2 = pc.clone(); - tokio::spawn(async move { - let mut result = Result::::Ok(0); - while result.is_ok() { - let timeout = tokio::time::sleep(Duration::from_secs(3)); - tokio::pin!(timeout); - - tokio::select! 
{ - _ = timeout.as_mut() =>{ - if let Some(pc) = pc2.upgrade(){ - result = pc.write_rtcp(&[Box::new(PictureLossIndication{ - sender_ssrc: 0, - media_ssrc, - })]).await.map_err(Into::into); - }else { - break; - } - } - }; - } - }); - - let notify_rx2 = Arc::clone(¬ify_rx); - let h264_writer2 = Arc::clone(&h264_writer); - let ogg_writer2 = Arc::clone(&ogg_writer); - Box::pin(async move { - let codec = track.codec(); - let mime_type = codec.capability.mime_type.to_lowercase(); - if mime_type == MIME_TYPE_OPUS.to_lowercase() { - println!("Got Opus track, saving to disk as output.opus (48 kHz, 2 channels)"); - tokio::spawn(async move { - let _ = save_to_disk(ogg_writer2, track, notify_rx2).await; - }); - } else if mime_type == MIME_TYPE_H264.to_lowercase() { - println!("Got h264 track, saving to disk as output.h264"); - tokio::spawn(async move { - let _ = save_to_disk(h264_writer2, track, notify_rx2).await; - }); - } - }) - })); - - let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - - // Set the handler for ICE connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_ice_connection_state_change(Box::new( - move |connection_state: RTCIceConnectionState| { - println!("Connection State has changed {connection_state}"); - - if connection_state == RTCIceConnectionState::Connected { - println!("Ctrl+C the remote client to stop the demo"); - } else if connection_state == RTCIceConnectionState::Failed { - notify_tx.notify_waiters(); - - println!("Done writing media files"); - - let _ = done_tx.try_send(()); - } - Box::pin(async {}) - }, - )); - - // Wait for the offer to be pasted - let line = signal::must_read_stdin()?; - let desc_data = signal::decode(line.as_str())?; - let offer = serde_json::from_str::(&desc_data)?; - - // Set the remote SessionDescription - peer_connection.set_remote_description(offer).await?; - - // Create an answer - let answer = peer_connection.create_answer(None).await?; - - // Create channel that is blocked until ICE Gathering is complete - let mut gather_complete = peer_connection.gathering_complete_promise().await; - - // Sets the LocalDescription, and starts our UDP listeners - peer_connection.set_local_description(answer).await?; - - // Block until ICE Gathering is complete, disabling trickle ICE - // we do this because we only can exchange one signaling message - // in a production application you should exchange ICE Candidates via OnICECandidate - let _ = gather_complete.recv().await; - - // Output the answer in base64 so we can paste it in browser - if let Some(local_desc) = peer_connection.local_description().await { - let json_str = serde_json::to_string(&local_desc)?; - let b64 = signal::encode(&json_str); - println!("{b64}"); - } else { - println!("generate local_description failed!"); - } - - println!("Press ctrl-c to stop"); - tokio::select! { - _ = done_rx.recv() => { - println!("received done signal!"); - } - _ = tokio::signal::ctrl_c() => { - println!(); - } - }; - - peer_connection.close().await?; - - Ok(()) -} diff --git a/examples/examples/save-to-disk-vpx/README.md b/examples/examples/save-to-disk-vpx/README.md deleted file mode 100644 index 2e18057f7..000000000 --- a/examples/examples/save-to-disk-vpx/README.md +++ /dev/null @@ -1,38 +0,0 @@ -# save-to-disk-vpx - -save-to-disk-vpx is a simple application that shows how to record your webcam/microphone using WebRTC.rs and save VP8/VP9 and Opus to disk. 
- -## Instructions - -### Build save-to-disk-vpx - -```shell -cargo build --example save-to-disk-vpx -``` - -### Open save-to-disk-vpx example page - -[jsfiddle.net](https://jsfiddle.net/vfmcg8rk/1/) you should see your Webcam, two text-areas and a 'Start Session' button - -### Run save-to-disk-vpx, with your browsers SessionDescription as stdin - -In the jsfiddle the top textarea is your browser, copy that and: - -#### Linux/macOS - -Run `echo $BROWSER_SDP | ./target/debug/examples/save-to-disk-vpx` - -#### Windows - -1. Paste the SessionDescription into a file. -1. Run `./target/debug/examples/save-to-disk-vpx < my_file` - -### Input save-to-disk-vpx's SessionDescription into your browser - -Copy the text that `save-to-disk-vpx` just emitted and copy into second text area - -### Hit 'Start Session' in jsfiddle, wait, close jsfiddle, enjoy your video! - -In the folder you ran `save-to-disk-vpx` you should now have a file `output_vpx.ivf` play with your video player of choice! - -Congrats, you have used WebRTC.rs! diff --git a/examples/examples/save-to-disk-vpx/save-to-disk-vpx.rs b/examples/examples/save-to-disk-vpx/save-to-disk-vpx.rs deleted file mode 100644 index 2b02986b6..000000000 --- a/examples/examples/save-to-disk-vpx/save-to-disk-vpx.rs +++ /dev/null @@ -1,348 +0,0 @@ -use std::fs::File; -use std::io::Write; -use std::sync::Arc; - -use anyhow::Result; -use clap::{AppSettings, Arg, Command}; -use tokio::sync::{Mutex, Notify}; -use tokio::time::Duration; -use webrtc::api::interceptor_registry::register_default_interceptors; -use webrtc::api::media_engine::{MediaEngine, MIME_TYPE_OPUS, MIME_TYPE_VP8, MIME_TYPE_VP9}; -use webrtc::api::APIBuilder; -use webrtc::ice_transport::ice_connection_state::RTCIceConnectionState; -use webrtc::ice_transport::ice_server::RTCIceServer; -use webrtc::interceptor::registry::Registry; -use webrtc::media::io::ivf_reader::IVFFileHeader; -use webrtc::media::io::ivf_writer::IVFWriter; -use webrtc::media::io::ogg_writer::OggWriter; -use webrtc::peer_connection::configuration::RTCConfiguration; -use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; -use webrtc::rtcp::payload_feedbacks::picture_loss_indication::PictureLossIndication; -use webrtc::rtp_transceiver::rtp_codec::{ - RTCRtpCodecCapability, RTCRtpCodecParameters, RTPCodecType, -}; -use webrtc::track::track_remote::TrackRemote; - -async fn save_to_disk( - writer: Arc>, - track: Arc, - notify: Arc, -) -> Result<()> { - loop { - tokio::select! 
{ - result = track.read_rtp() => { - if let Ok((rtp_packet, _)) = result { - let mut w = writer.lock().await; - w.write_rtp(&rtp_packet)?; - }else{ - println!("file closing begin after read_rtp error"); - let mut w = writer.lock().await; - if let Err(err) = w.close() { - println!("file close err: {err}"); - } - println!("file closing end after read_rtp error"); - return Ok(()); - } - } - _ = notify.notified() => { - println!("file closing begin after notified"); - let mut w = writer.lock().await; - if let Err(err) = w.close() { - println!("file close err: {err}"); - } - println!("file closing end after notified"); - return Ok(()); - } - } - } -} - -#[tokio::main] -async fn main() -> Result<()> { - let mut app = Command::new("save-to-disk-vpx") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of save-to-disk-vpx.") - .setting(AppSettings::DeriveDisplayOrder) - .subcommand_negates_reqs(true) - .arg( - Arg::new("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::new("debug") - .long("debug") - .short('d') - .help("Prints debug log information"), - ) - .arg( - Arg::new("vp9") - .long("vp9") - .help("Save VP9 to disk. Default: VP8"), - ) - .arg( - Arg::new("video") - .required_unless_present("FULLHELP") - .takes_value(true) - .short('v') - .long("video") - .help("Video file to be streaming."), - ) - .arg( - Arg::new("audio") - .required_unless_present("FULLHELP") - .takes_value(true) - .short('a') - .long("audio") - .help("Audio file to be streaming."), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let debug = matches.is_present("debug"); - if debug { - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init(); - } - - let is_vp9 = matches.is_present("vp9"); - let video_file = matches.value_of("video").unwrap(); - let audio_file = matches.value_of("audio").unwrap(); - - let ivf_writer: Arc> = - Arc::new(Mutex::new(IVFWriter::new( - File::create(video_file)?, - &IVFFileHeader { - signature: *b"DKIF", // 0-3 - version: 0, // 4-5 - header_size: 32, // 6-7 - four_cc: if is_vp9 { *b"VP90" } else { *b"VP80" }, // 8-11 - width: 640, // 12-13 - height: 480, // 14-15 - timebase_denominator: 30, // 16-19 - timebase_numerator: 1, // 20-23 - num_frames: 900, // 24-27 - unused: 0, // 28-31 - }, - )?)); - let ogg_writer: Arc> = Arc::new(Mutex::new( - OggWriter::new(File::create(audio_file)?, 48000, 2)?, - )); - - // Everything below is the WebRTC-rs API! Thanks for using it โค๏ธ. - - // Create a MediaEngine object to configure the supported codec - let mut m = MediaEngine::default(); - - // Setup the codecs you want to use. 
- // We'll use a VP8/VP9 and Opus but you can also define your own - m.register_codec( - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: if is_vp9 { - MIME_TYPE_VP9.to_owned() - } else { - MIME_TYPE_VP8.to_owned() - }, - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: "".to_owned(), - rtcp_feedback: vec![], - }, - payload_type: if is_vp9 { 98 } else { 96 }, - ..Default::default() - }, - RTPCodecType::Video, - )?; - - m.register_codec( - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_OPUS.to_owned(), - clock_rate: 48000, - channels: 2, - sdp_fmtp_line: "".to_owned(), - rtcp_feedback: vec![], - }, - payload_type: 111, - ..Default::default() - }, - RTPCodecType::Audio, - )?; - - // Create a InterceptorRegistry. This is the user configurable RTP/RTCP Pipeline. - // This provides NACKs, RTCP Reports and other features. If you use `webrtc.NewPeerConnection` - // this is enabled by default. If you are manually managing You MUST create a InterceptorRegistry - // for each PeerConnection. - let mut registry = Registry::new(); - - // Use the default set of Interceptors - registry = register_default_interceptors(registry, &mut m)?; - - // Create the API object with the MediaEngine - let api = APIBuilder::new() - .with_media_engine(m) - .with_interceptor_registry(registry) - .build(); - - // Prepare the configuration - let config = RTCConfiguration { - ice_servers: vec![RTCIceServer { - urls: vec!["stun:stun.l.google.com:19302".to_owned()], - ..Default::default() - }], - ..Default::default() - }; - - // Create a new RTCPeerConnection - let peer_connection = Arc::new(api.new_peer_connection(config).await?); - - // Allow us to receive 1 audio track, and 1 video track - peer_connection - .add_transceiver_from_kind(RTPCodecType::Audio, None) - .await?; - peer_connection - .add_transceiver_from_kind(RTPCodecType::Video, None) - .await?; - - let notify_tx = Arc::new(Notify::new()); - let notify_rx = notify_tx.clone(); - - // Set a handler for when a new remote track starts, this handler saves buffers to disk as - // an ivf file, since we could have multiple video tracks we provide a counter. - // In your application this is where you would handle/process video - let pc = Arc::downgrade(&peer_connection); - peer_connection.on_track(Box::new(move |track, _, _| { - // Send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval - let media_ssrc = track.ssrc(); - let pc2 = pc.clone(); - tokio::spawn(async move { - let mut result = Result::::Ok(0); - while result.is_ok() { - let timeout = tokio::time::sleep(Duration::from_secs(3)); - tokio::pin!(timeout); - - tokio::select! 
{ - _ = timeout.as_mut() =>{ - if let Some(pc) = pc2.upgrade(){ - result = pc.write_rtcp(&[Box::new(PictureLossIndication{ - sender_ssrc: 0, - media_ssrc, - })]).await.map_err(Into::into); - }else{ - break; - } - } - }; - } - }); - - let notify_rx2 = Arc::clone(¬ify_rx); - let ivf_writer2 = Arc::clone(&ivf_writer); - let ogg_writer2 = Arc::clone(&ogg_writer); - Box::pin(async move { - let codec = track.codec(); - let mime_type = codec.capability.mime_type.to_lowercase(); - if mime_type == MIME_TYPE_OPUS.to_lowercase() { - println!("Got Opus track, saving to disk as output.opus (48 kHz, 2 channels)"); - tokio::spawn(async move { - let _ = save_to_disk(ogg_writer2, track, notify_rx2).await; - }); - } else if mime_type == MIME_TYPE_VP8.to_lowercase() - || mime_type == MIME_TYPE_VP9.to_lowercase() - { - println!( - "Got {} track, saving to disk as output.ivf", - if is_vp9 { "VP9" } else { "VP8" } - ); - tokio::spawn(async move { - let _ = save_to_disk(ivf_writer2, track, notify_rx2).await; - }); - } - }) - })); - - let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - - // Set the handler for ICE connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_ice_connection_state_change(Box::new( - move |connection_state: RTCIceConnectionState| { - println!("Connection State has changed {connection_state}"); - - if connection_state == RTCIceConnectionState::Connected { - println!("Ctrl+C the remote client to stop the demo"); - } else if connection_state == RTCIceConnectionState::Failed { - notify_tx.notify_waiters(); - - println!("Done writing media files"); - - let _ = done_tx.try_send(()); - } - Box::pin(async {}) - }, - )); - - // Wait for the offer to be pasted - let line = signal::must_read_stdin()?; - let desc_data = signal::decode(line.as_str())?; - let offer = serde_json::from_str::(&desc_data)?; - - // Set the remote SessionDescription - peer_connection.set_remote_description(offer).await?; - - // Create an answer - let answer = peer_connection.create_answer(None).await?; - - // Create channel that is blocked until ICE Gathering is complete - let mut gather_complete = peer_connection.gathering_complete_promise().await; - - // Sets the LocalDescription, and starts our UDP listeners - peer_connection.set_local_description(answer).await?; - - // Block until ICE Gathering is complete, disabling trickle ICE - // we do this because we only can exchange one signaling message - // in a production application you should exchange ICE Candidates via OnICECandidate - let _ = gather_complete.recv().await; - - // Output the answer in base64 so we can paste it in browser - if let Some(local_desc) = peer_connection.local_description().await { - let json_str = serde_json::to_string(&local_desc)?; - let b64 = signal::encode(&json_str); - println!("{b64}"); - } else { - println!("generate local_description failed!"); - } - - println!("Press ctrl-c to stop"); - tokio::select! 
{ - _ = done_rx.recv() => { - println!("received done signal!"); - } - _ = tokio::signal::ctrl_c() => { - println!(); - } - }; - - peer_connection.close().await?; - - Ok(()) -} diff --git a/examples/examples/signal/Cargo.toml b/examples/examples/signal/Cargo.toml deleted file mode 100644 index 3c71bb789..000000000 --- a/examples/examples/signal/Cargo.toml +++ /dev/null @@ -1,11 +0,0 @@ -[package] -name = "signal" -version = "0.1.0" -edition = "2021" - -[dependencies] -tokio = { version = "1.32.0", features = ["full"] } -anyhow = "1" -base64 = "0.21" -lazy_static = "1" -hyper = { version = "0.14.27", features = ["full"] } diff --git a/examples/examples/signal/src/lib.rs b/examples/examples/signal/src/lib.rs deleted file mode 100644 index c6267b848..000000000 --- a/examples/examples/signal/src/lib.rs +++ /dev/null @@ -1,149 +0,0 @@ -#![warn(rust_2018_idioms)] -#![allow(dead_code)] - -use std::net::SocketAddr; -use std::str::FromStr; -use std::sync::Arc; - -use anyhow::Result; -use base64::prelude::BASE64_STANDARD; -use base64::Engine; -use hyper::service::{make_service_fn, service_fn}; -use hyper::{Body, Method, Request, Response, Server, StatusCode}; -use tokio::sync::{mpsc, Mutex}; - -#[macro_use] -extern crate lazy_static; - -lazy_static! { - static ref SDP_CHAN_TX_MUTEX: Arc>>> = - Arc::new(Mutex::new(None)); -} - -// HTTP Listener to get sdp -async fn remote_handler(req: Request) -> Result, hyper::Error> { - match (req.method(), req.uri().path()) { - // A HTTP handler that processes a SessionDescription given to us from the other WebRTC-rs or Pion process - (&Method::POST, "/sdp") => { - //println!("remote_handler receive from /sdp"); - let sdp_str = match std::str::from_utf8(&hyper::body::to_bytes(req.into_body()).await?) - { - Ok(s) => s.to_owned(), - Err(err) => panic!("{}", err), - }; - - { - let sdp_chan_tx = SDP_CHAN_TX_MUTEX.lock().await; - if let Some(tx) = &*sdp_chan_tx { - let _ = tx.send(sdp_str).await; - } - } - - let mut response = Response::new(Body::empty()); - *response.status_mut() = StatusCode::OK; - Ok(response) - } - // Return the 404 Not Found for other routes. - _ => { - let mut not_found = Response::default(); - *not_found.status_mut() = StatusCode::NOT_FOUND; - Ok(not_found) - } - } -} - -/// http_sdp_server starts a HTTP Server that consumes SDPs -pub async fn http_sdp_server(port: u16) -> mpsc::Receiver { - let (sdp_chan_tx, sdp_chan_rx) = mpsc::channel::(1); - { - let mut tx = SDP_CHAN_TX_MUTEX.lock().await; - *tx = Some(sdp_chan_tx); - } - - tokio::spawn(async move { - let addr = SocketAddr::from_str(&format!("0.0.0.0:{port}")).unwrap(); - let service = - make_service_fn(|_| async { Ok::<_, hyper::Error>(service_fn(remote_handler)) }); - let server = Server::bind(&addr).serve(service); - // Run this server for... forever! - if let Err(e) = server.await { - eprintln!("server error: {e}"); - } - }); - - sdp_chan_rx -} - -/// must_read_stdin blocks until input is received from stdin -#[allow(clippy::assigning_clones)] -pub fn must_read_stdin() -> Result { - let mut line = String::new(); - - std::io::stdin().read_line(&mut line)?; - line = line.trim().to_owned(); - println!(); - - Ok(line) -} - -// Allows compressing offer/answer to bypass terminal input limits. 
-// const COMPRESS: bool = false; - -/// encode encodes the input in base64 -/// It can optionally zip the input before encoding -pub fn encode(b: &str) -> String { - //if COMPRESS { - // b = zip(b) - //} - - BASE64_STANDARD.encode(b) -} - -/// decode decodes the input from base64 -/// It can optionally unzip the input after decoding -pub fn decode(s: &str) -> Result { - let b = BASE64_STANDARD.decode(s)?; - - //if COMPRESS { - // b = unzip(b) - //} - - let s = String::from_utf8(b)?; - Ok(s) -} -/* -func zip(in []byte) []byte { - var b bytes.Buffer - gz := gzip.NewWriter(&b) - _, err := gz.Write(in) - if err != nil { - panic(err) - } - err = gz.Flush() - if err != nil { - panic(err) - } - err = gz.Close() - if err != nil { - panic(err) - } - return b.Bytes() -} - -func unzip(in []byte) []byte { - var b bytes.Buffer - _, err := b.Write(in) - if err != nil { - panic(err) - } - r, err := gzip.NewReader(&b) - if err != nil { - panic(err) - } - res, err := ioutil.ReadAll(r) - if err != nil { - panic(err) - } - return res -} -*/ diff --git a/examples/examples/simulcast/README.md b/examples/examples/simulcast/README.md deleted file mode 100644 index 31962490c..000000000 --- a/examples/examples/simulcast/README.md +++ /dev/null @@ -1,42 +0,0 @@ -# simulcast - -demonstrates of how to handle incoming track with multiple simulcast rtp streams and show all them back. - -The browser will not send higher quality streams unless it has the available bandwidth. You can look at -the bandwidth estimation in `chrome://webrtc-internals`. It is under `VideoBwe` when `Read Stats From: Legacy non-Standard` -is selected. - -## Instructions - -### Build simulcast - -```shell -cargo build --example simulcast -``` - -### Open simulcast example page - -[jsfiddle.net](https://jsfiddle.net/rxk4bftc) you should see two text-areas and a 'Start Session' button. - -### Run simulcast, with your browsers SessionDescription as stdin - -In the jsfiddle the top textarea is your browser, copy that and: - -#### Linux/macOS - -Run `echo $BROWSER_SDP | ./target/debug/examples/simulcast` - -#### Windows - -1. Paste the SessionDescription into a file. -1. Run `./target/debug/examples/simulcast < my_file` - -### Input simulcast's SessionDescription into your browser - -Copy the text that `simulcast` just emitted and copy into second text area - -### Hit 'Start Session' in jsfiddle, enjoy your video! - -Your browser should send a simulcast track to WebRTC.rs, and then all 3 incoming streams will be relayed back. - -Congrats, you have used WebRTC.rs! 
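Before the simulcast example itself, it helps to see the shape of the dispatch it performs: the browser sends three encodings of the same source, distinguished by RTP stream id (`q`, `h`, `f`), and each one is echoed back on its own local track. A minimal sketch of that rid-keyed lookup follows, with a placeholder `OutputTrack` standing in for `TrackLocalStaticRTP` purely to keep the snippet self-contained:

```rust
use std::collections::HashMap;
use std::sync::Arc;

// Stand-in for TrackLocalStaticRTP; used only so this sketch compiles on its own.
struct OutputTrack;

/// Each simulcast layer carries a rid; the example keys its three
/// loopback tracks by the same strings and forwards packets accordingly.
fn pick_output_track(
    rid: &str,
    output_tracks: &HashMap<String, Arc<OutputTrack>>,
) -> Option<Arc<OutputTrack>> {
    output_tracks.get(rid).cloned()
}

fn main() {
    let mut output_tracks = HashMap::new();
    for layer in ["q", "h", "f"] {
        output_tracks.insert(layer.to_owned(), Arc::new(OutputTrack));
    }
    // A track arriving with rid "h" is relayed on the "h" output track.
    assert!(pick_output_track("h", &output_tracks).is_some());
}
```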
diff --git a/examples/examples/simulcast/simulcast.rs b/examples/examples/simulcast/simulcast.rs deleted file mode 100644 index f47c7b620..000000000 --- a/examples/examples/simulcast/simulcast.rs +++ /dev/null @@ -1,266 +0,0 @@ -use std::collections::HashMap; -use std::io::Write; -use std::sync::Arc; - -use anyhow::Result; -use clap::{AppSettings, Arg, Command}; -use tokio::time::Duration; -use webrtc::api::interceptor_registry::register_default_interceptors; -use webrtc::api::media_engine::{MediaEngine, MIME_TYPE_VP8}; -use webrtc::api::APIBuilder; -use webrtc::ice_transport::ice_server::RTCIceServer; -use webrtc::interceptor::registry::Registry; -use webrtc::peer_connection::configuration::RTCConfiguration; -use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; -use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; -use webrtc::rtcp::payload_feedbacks::picture_loss_indication::PictureLossIndication; -use webrtc::rtp_transceiver::rtp_codec::{ - RTCRtpCodecCapability, RTCRtpHeaderExtensionCapability, RTPCodecType, -}; -use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP; -use webrtc::track::track_local::{TrackLocal, TrackLocalWriter}; -use webrtc::Error; - -#[tokio::main] -async fn main() -> Result<()> { - let mut app = Command::new("simulcast") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of simulcast.") - .setting(AppSettings::DeriveDisplayOrder) - .subcommand_negates_reqs(true) - .arg( - Arg::new("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::new("debug") - .long("debug") - .short('d') - .help("Prints debug log information"), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let debug = matches.is_present("debug"); - if debug { - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init(); - } - - // Everything below is the WebRTC-rs API! Thanks for using it โค๏ธ. - - // Create a MediaEngine object to configure the supported codec - let mut m = MediaEngine::default(); - - m.register_default_codecs()?; - - // Enable Extension Headers needed for Simulcast - for extension in [ - "urn:ietf:params:rtp-hdrext:sdes:mid", - "urn:ietf:params:rtp-hdrext:sdes:rtp-stream-id", - "urn:ietf:params:rtp-hdrext:sdes:repaired-rtp-stream-id", - ] { - m.register_header_extension( - RTCRtpHeaderExtensionCapability { - uri: extension.to_owned(), - }, - RTPCodecType::Video, - None, - )?; - } - // Create a InterceptorRegistry. This is the user configurable RTP/RTCP Pipeline. - // This provides NACKs, RTCP Reports and other features. If you use `webrtc.NewPeerConnection` - // this is enabled by default. If you are manually managing You MUST create a InterceptorRegistry - // for each PeerConnection. 
- let mut registry = Registry::new(); - - // Use the default set of Interceptors - registry = register_default_interceptors(registry, &mut m)?; - - // Create the API object with the MediaEngine - let api = APIBuilder::new() - .with_media_engine(m) - .with_interceptor_registry(registry) - .build(); - - // Prepare the configuration - let config = RTCConfiguration { - ice_servers: vec![RTCIceServer { - urls: vec!["stun:stun.l.google.com:19302".to_owned()], - ..Default::default() - }], - ..Default::default() - }; - - // Create a new RTCPeerConnection - let peer_connection = Arc::new(api.new_peer_connection(config).await?); - - // Create Track that we send video back to browser on - let mut output_tracks = HashMap::new(); - for s in ["q", "h", "f"] { - let output_track = Arc::new(TrackLocalStaticRTP::new( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - format!("video_{s}"), - format!("webrtc-rs_{s}"), - )); - - // Add this newly created track to the PeerConnection - let rtp_sender = peer_connection - .add_track(Arc::clone(&output_track) as Arc) - .await?; - - // Read incoming RTCP packets - // Before these packets are returned they are processed by interceptors. For things - // like NACK this needs to be called. - tokio::spawn(async move { - let mut rtcp_buf = vec![0u8; 1500]; - while let Ok((_, _)) = rtp_sender.read(&mut rtcp_buf).await {} - Result::<()>::Ok(()) - }); - - output_tracks.insert(s.to_owned(), output_track); - } - - // Wait for the offer to be pasted - let line = signal::must_read_stdin()?; - let desc_data = signal::decode(line.as_str())?; - let offer = serde_json::from_str::(&desc_data)?; - - // Set the remote SessionDescription - peer_connection.set_remote_description(offer).await?; - - // Set a handler for when a new remote track starts - let pc = Arc::downgrade(&peer_connection); - peer_connection.on_track(Box::new(move |track, _, _| { - println!("Track has started"); - - let rid = track.rid().to_owned(); - let output_track = if let Some(output_track) = output_tracks.get(&rid) { - Arc::clone(output_track) - } else { - println!("output_track not found for rid = {rid}"); - return Box::pin(async {}); - }; - - // Start reading from all the streams and sending them to the related output track - let media_ssrc = track.ssrc(); - let pc2 = pc.clone(); - tokio::spawn(async move { - let mut result = Result::::Ok(0); - while result.is_ok() { - println!("Sending pli for stream with rid: {rid}, ssrc: {media_ssrc}"); - - let timeout = tokio::time::sleep(Duration::from_secs(3)); - tokio::pin!(timeout); - - tokio::select! 
{ - _ = timeout.as_mut() =>{ - if let Some(pc) = pc2.upgrade(){ - result = pc.write_rtcp(&[Box::new(PictureLossIndication{ - sender_ssrc: 0, - media_ssrc, - })]).await.map_err(Into::into); - }else{ - break; - } - } - }; - } - }); - - tokio::spawn(async move { - // Read RTP packets being sent to webrtc-rs - println!("enter track loop {}", track.rid()); - while let Ok((rtp, _)) = track.read_rtp().await { - if let Err(err) = output_track.write_rtp(&rtp).await { - if Error::ErrClosedPipe != err { - println!("output track write_rtp got error: {err} and break"); - break; - } else { - println!("output track write_rtp got error: {err}"); - } - } - } - println!("exit track loop {}", track.rid()); - }); - - Box::pin(async {}) - })); - - let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - - if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. - println!("Peer Connection has gone to failed exiting"); - let _ = done_tx.try_send(()); - } - - Box::pin(async {}) - })); - - // Create an answer - let answer = peer_connection.create_answer(None).await?; - - // Create channel that is blocked until ICE Gathering is complete - let mut gather_complete = peer_connection.gathering_complete_promise().await; - - // Sets the LocalDescription, and starts our UDP listeners - peer_connection.set_local_description(answer).await?; - - // Block until ICE Gathering is complete, disabling trickle ICE - // we do this because we only can exchange one signaling message - // in a production application you should exchange ICE Candidates via OnICECandidate - let _ = gather_complete.recv().await; - - // Output the answer in base64 so we can paste it in browser - if let Some(local_desc) = peer_connection.local_description().await { - let json_str = serde_json::to_string(&local_desc)?; - let b64 = signal::encode(&json_str); - println!("{b64}"); - } else { - println!("generate local_description failed!"); - } - - println!("Press ctrl-c to stop"); - tokio::select! { - _ = done_rx.recv() => { - println!("received done signal!"); - } - _ = tokio::signal::ctrl_c() => { - println!(); - } - }; - - peer_connection.close().await?; - - Ok(()) -} diff --git a/examples/examples/swap-tracks/README.md b/examples/examples/swap-tracks/README.md deleted file mode 100644 index 1157ae206..000000000 --- a/examples/examples/swap-tracks/README.md +++ /dev/null @@ -1,38 +0,0 @@ -# swap-tracks - -swap-tracks demonstrates how to swap multiple incoming tracks on a single outgoing track. - -## Instructions - -### Build swap-tracks - -```shell -cargo build --example swap-tracks -``` - -### Open swap-tracks example page - -[jsfiddle.net](https://jsfiddle.net/dzc17fga/) you should see two text-areas and a 'Start Session' button. - -### Run swap-tracks, with your browsers SessionDescription as stdin - -In the jsfiddle the top textarea is your browser, copy that and: - -#### Linux/macOS - -Run `echo $BROWSER_SDP | ./target/debug/examples/swap-tracks` - -#### Windows - -1. 
Paste the SessionDescription into a file. -1. Run `./target/debug/examples/swap-tracks < my_file` - -### Input swap-tracks's SessionDescription into your browser - -Copy the text that `swap-tracks` just emitted and copy into second text area - -### Hit 'Start Session' in jsfiddle, enjoy your video! - -Your browser should send streams to webrtc-rs, and then a stream will be relayed back, changing every 5 seconds. - -Congrats, you have used WebRTC.rs! diff --git a/examples/examples/swap-tracks/swap-tracks.rs b/examples/examples/swap-tracks/swap-tracks.rs deleted file mode 100644 index db8cd25d7..000000000 --- a/examples/examples/swap-tracks/swap-tracks.rs +++ /dev/null @@ -1,315 +0,0 @@ -use std::io::Write; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; - -use anyhow::Result; -use clap::{AppSettings, Arg, Command}; -use tokio::time::Duration; -use webrtc::api::interceptor_registry::register_default_interceptors; -use webrtc::api::media_engine::{MediaEngine, MIME_TYPE_VP8}; -use webrtc::api::APIBuilder; -use webrtc::ice_transport::ice_server::RTCIceServer; -use webrtc::interceptor::registry::Registry; -use webrtc::peer_connection::configuration::RTCConfiguration; -use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; -use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; -use webrtc::rtcp::payload_feedbacks::picture_loss_indication::PictureLossIndication; -use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability; -use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP; -use webrtc::track::track_local::{TrackLocal, TrackLocalWriter}; -use webrtc::Error; - -#[tokio::main] -async fn main() -> Result<()> { - let mut app = Command::new("swap-tracks") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of swap-tracks.") - .setting(AppSettings::DeriveDisplayOrder) - .subcommand_negates_reqs(true) - .arg( - Arg::new("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::new("debug") - .long("debug") - .short('d') - .help("Prints debug log information"), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let debug = matches.is_present("debug"); - if debug { - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init(); - } - - // Everything below is the WebRTC-rs API! Thanks for using it โค๏ธ. - - // Create a MediaEngine object to configure the supported codec - let mut m = MediaEngine::default(); - - // Setup the codecs you want to use. - m.register_default_codecs()?; - - // Create a InterceptorRegistry. This is the user configurable RTP/RTCP Pipeline. - // This provides NACKs, RTCP Reports and other features. If you use `webrtc.NewPeerConnection` - // this is enabled by default. If you are manually managing You MUST create a InterceptorRegistry - // for each PeerConnection. 
- let mut registry = Registry::new(); - - // Use the default set of Interceptors - registry = register_default_interceptors(registry, &mut m)?; - - // Create the API object with the MediaEngine - let api = APIBuilder::new() - .with_media_engine(m) - .with_interceptor_registry(registry) - .build(); - - // Prepare the configuration - let config = RTCConfiguration { - ice_servers: vec![RTCIceServer { - urls: vec!["stun:stun.l.google.com:19302".to_owned()], - ..Default::default() - }], - ..Default::default() - }; - - // Create a new RTCPeerConnection - let peer_connection = Arc::new(api.new_peer_connection(config).await?); - - let output_track = Arc::new(TrackLocalStaticRTP::new( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - "video".to_owned(), - "webrtc-rs".to_owned(), - )); - - // Add this newly created track to the PeerConnection - let rtp_sender = peer_connection - .add_track(Arc::clone(&output_track) as Arc) - .await?; - - // Read incoming RTCP packets - // Before these packets are returned they are processed by interceptors. For things - // like NACK this needs to be called. - tokio::spawn(async move { - let mut rtcp_buf = vec![0u8; 1500]; - while let Ok((_, _)) = rtp_sender.read(&mut rtcp_buf).await {} - Result::<()>::Ok(()) - }); - - // Wait for the offer to be pasted - let line = signal::must_read_stdin()?; - let desc_data = signal::decode(line.as_str())?; - let offer = serde_json::from_str::(&desc_data)?; - - // Set the remote SessionDescription - peer_connection.set_remote_description(offer).await?; - - // Which track is currently being handled - let curr_track = Arc::new(AtomicUsize::new(0)); - // The total number of tracks - let track_count = Arc::new(AtomicUsize::new(0)); - // The channel of packets with a bit of buffer - let (packets_tx, mut packets_rx) = - tokio::sync::mpsc::channel::(60); - let packets_tx = Arc::new(packets_tx); - - // Set a handler for when a new remote track starts, this handler copies inbound RTP packets, - // replaces the SSRC and sends them back - let pc = Arc::downgrade(&peer_connection); - let curr_track1 = Arc::clone(&curr_track); - let track_count1 = Arc::clone(&track_count); - peer_connection.on_track(Box::new(move |track, _, _| { - let track_num = track_count1.fetch_add(1, Ordering::SeqCst); - - let curr_track2 = Arc::clone(&curr_track1); - let pc2 = pc.clone(); - let packets_tx2 = Arc::clone(&packets_tx); - tokio::spawn(async move { - println!( - "Track has started, of type {}: {}", - track.payload_type(), - track.codec().capability.mime_type - ); - - let mut last_timestamp = 0; - let mut is_curr_track = false; - while let Ok((mut rtp, _)) = track.read_rtp().await { - // Change the timestamp to only be the delta - let old_timestamp = rtp.header.timestamp; - if last_timestamp == 0 { - rtp.header.timestamp = 0 - } else { - rtp.header.timestamp -= last_timestamp; - } - last_timestamp = old_timestamp; - - // Check if this is the current track - if curr_track2.load(Ordering::SeqCst) == track_num { - // If just switched to this track, send PLI to get picture refresh - if !is_curr_track { - is_curr_track = true; - if let Some(pc) = pc2.upgrade() { - if let Err(err) = pc - .write_rtcp(&[Box::new(PictureLossIndication { - sender_ssrc: 0, - media_ssrc: track.ssrc(), - })]) - .await - { - println!("write_rtcp err: {err}"); - } - } else { - break; - } - } - let _ = packets_tx2.send(rtp).await; - } else { - is_curr_track = false; - } - } - - println!( - "Track has ended, of type {}: {}", - track.payload_type(), - 
track.codec().capability.mime_type - ); - }); - - Box::pin(async {}) - })); - - let (connected_tx, mut connected_rx) = tokio::sync::mpsc::channel(1); - let (done_tx, mut done_rx) = tokio::sync::mpsc::channel(1); - - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - if s == RTCPeerConnectionState::Connected { - let _ = connected_tx.try_send(()); - } else if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. - let _ = done_tx.try_send(()); - } - Box::pin(async move {}) - })); - - // Create an answer - let answer = peer_connection.create_answer(None).await?; - - // Create channel that is blocked until ICE Gathering is complete - let mut gather_complete = peer_connection.gathering_complete_promise().await; - - // Sets the LocalDescription, and starts our UDP listeners - peer_connection.set_local_description(answer).await?; - - // Block until ICE Gathering is complete, disabling trickle ICE - // we do this because we only can exchange one signaling message - // in a production application you should exchange ICE Candidates via OnICECandidate - let _ = gather_complete.recv().await; - - // Output the answer in base64 so we can paste it in browser - if let Some(local_desc) = peer_connection.local_description().await { - let json_str = serde_json::to_string(&local_desc)?; - let b64 = signal::encode(&json_str); - println!("{b64}"); - } else { - println!("generate local_description failed!"); - } - - // Asynchronously take all packets in the channel and write them out to our - // track - tokio::spawn(async move { - let mut curr_timestamp = 0; - let mut i = 0; - while let Some(mut packet) = packets_rx.recv().await { - // Timestamp on the packet is really a diff, so add it to current - curr_timestamp += packet.header.timestamp; - packet.header.timestamp = curr_timestamp; - // Keep an increasing sequence number - packet.header.sequence_number = i; - // Write out the packet, ignoring closed pipe if nobody is listening - if let Err(err) = output_track.write_rtp(&packet).await { - if Error::ErrClosedPipe == err { - // The peerConnection has been closed. - return; - } else { - panic!("{}", err); - } - } - i += 1; - } - }); - - // Wait for connection, then rotate the track every 5s - println!("Waiting for connection"); - tokio::select! { - _ = connected_rx.recv() =>{ - loop { - println!("Press ctrl-c to stop, or waiting 5 seconds then changing..."); - let timeout = tokio::time::sleep(Duration::from_secs(5)); - tokio::pin!(timeout); - - tokio::select! 
{ - _ = timeout.as_mut() => { - // We haven't gotten any tracks yet - if track_count.load(Ordering::SeqCst) == 0 { - continue; - } - - if curr_track.load(Ordering::SeqCst) == track_count.load(Ordering::SeqCst) - 1 { - curr_track.store(0, Ordering::SeqCst); - } else { - curr_track.fetch_add(1, Ordering::SeqCst); - } - println!( - "Switched to track {}", - curr_track.load(Ordering::SeqCst) + 1, - ); - } - _ = done_rx.recv() => { - println!("received done signal!"); - break; - } - _ = tokio::signal::ctrl_c() => { - println!(); - break; - } - }; - } - } - _ = done_rx.recv() => {} - }; - - peer_connection.close().await?; - - Ok(()) -} diff --git a/examples/examples/test-data/output.h264 b/examples/examples/test-data/output.h264 deleted file mode 100644 index 03555e80a..000000000 Binary files a/examples/examples/test-data/output.h264 and /dev/null differ diff --git a/examples/examples/test-data/output.ogg b/examples/examples/test-data/output.ogg deleted file mode 100644 index e824455c6..000000000 Binary files a/examples/examples/test-data/output.ogg and /dev/null differ diff --git a/examples/examples/test-data/output_vp8.ivf b/examples/examples/test-data/output_vp8.ivf deleted file mode 100644 index 28bf0022e..000000000 Binary files a/examples/examples/test-data/output_vp8.ivf and /dev/null differ diff --git a/examples/examples/test-data/output_vp9.ivf b/examples/examples/test-data/output_vp9.ivf deleted file mode 100644 index 03607b526..000000000 Binary files a/examples/examples/test-data/output_vp9.ivf and /dev/null differ diff --git a/examples/src/lib.rs b/examples/src/lib.rs deleted file mode 100644 index 31e1bb209..000000000 --- a/examples/src/lib.rs +++ /dev/null @@ -1,7 +0,0 @@ -#[cfg(test)] -mod tests { - #[test] - fn it_works() { - assert_eq!(2 + 2, 4); - } -} diff --git a/ice/.gitignore b/ice/.gitignore deleted file mode 100644 index 81561ed32..000000000 --- a/ice/.gitignore +++ /dev/null @@ -1,11 +0,0 @@ -# Generated by Cargo -# will have compiled files and executables -/target/ -/.idea/ - -# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries -# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html -Cargo.lock - -# These are backup files generated by rustfmt -**/*.rs.bk diff --git a/ice/CHANGELOG.md b/ice/CHANGELOG.md deleted file mode 100644 index d42dfb896..000000000 --- a/ice/CHANGELOG.md +++ /dev/null @@ -1,49 +0,0 @@ -# webrtc-ice changelog - -## Unreleased - -### Breaking changes - -* remove non used `MulticastDnsMode::Unspecified` variant [#404](https://github.com/webrtc-rs/webrtc/pull/404): - -## v0.9.0 - -* Increased minimum support rust version to `1.60.0`. - -### Breaking changes - -* Make functions non-async [#338](https://github.com/webrtc-rs/webrtc/pull/338): - - `Agent`: - - `get_bytes_received`; - - `get_bytes_sent`; - - `on_connection_state_change`; - - `on_selected_candidate_pair_change`; - - `on_candidate`; - - `add_remote_candidate`; - - `gather_candidates`. - - `unmarshal_candidate`; - - `CandidateHostConfig::new_candidate_host`; - - `CandidatePeerReflexiveConfig::new_candidate_peer_reflexive`; - - `CandidateRelayConfig::new_candidate_relay`; - - `CandidateServerReflexiveConfig::new_candidate_server_reflexive`; - - `Candidate`: - - `addr`; - - `set_ip`. - -## v0.8.2 - -* Add IP filter to ICE `AgentConfig` [#306](https://github.com/webrtc-rs/webrtc/pull/306) and [#318](https://github.com/webrtc-rs/webrtc/pull/318). -* Add `rust-version` at 1.57.0 to `Cargo.toml`. 
This was already the minimum version so does not constitute a change. - -## v0.8.1 - -This release was released in error and contains no changes from 0.8.0. - -## v0.8.0 - -* Increased min version of `log` dependency to `0.4.16`. [#250 Fix log at ^0.4.16 to make tests compile](https://github.com/webrtc-rs/webrtc/pull/250) by [@k0nserv](https://github.com/k0nserv). -* Increased serde's minimum version to 1.0.102 [#243 Fixes for cargo minimal-versions](https://github.com/webrtc-rs/webrtc/pull/243) contributed by [algesten](https://github.com/algesten) - -## Prior to 0.8.0 - -Before 0.8.0 there was no changelog, previous changes are sometimes, but not always, available in the [GitHub Releases](https://github.com/webrtc-rs/ice/releases). diff --git a/ice/Cargo.toml b/ice/Cargo.toml deleted file mode 100644 index 84916213b..000000000 --- a/ice/Cargo.toml +++ /dev/null @@ -1,57 +0,0 @@ -[package] -name = "webrtc-ice" -version = "0.11.0" -authors = ["Rain Liu "] -edition = "2021" -description = "A pure Rust implementation of ICE" -license = "MIT OR Apache-2.0" -documentation = "https://docs.rs/webrtc-ice" -homepage = "https://webrtc.rs" -repository = "https://github.com/webrtc-rs/webrtc/tree/master/ice" - -[dependencies] -util = { version = "0.9.0", path = "../util", package = "webrtc-util", default-features = false, features = ["conn", "vnet", "sync"] } -turn = { version = "0.8.0", path = "../turn" } -stun = { version = "0.6.0", path = "../stun" } -mdns = { version = "0.7.0", path = "../mdns", package = "webrtc-mdns" } - -arc-swap = "1" -async-trait = "0.1" -crc = "3" -log = "0.4" -rand = "0.8" -serde = { version = "1", features = ["derive"] } -serde_json = "1" -thiserror = "1" -tokio = { version = "1.32.0", features = [ - "fs", - "io-util", - "io-std", - "macros", - "net", - "parking_lot", - "rt", - "rt-multi-thread", - "sync", - "time", -] } -url = "2" -uuid = { version = "1", features = ["v4"] } -waitgroup = "0.1" -portable-atomic = "1.6" - -[dev-dependencies] -tokio-test = "0.4" -regex = "1.9.5" -env_logger = "0.10" -chrono = "0.4.28" -ipnet = "2" -clap = "3" -lazy_static = "1" -hyper = { version = "0.14.27", features = ["full"] } -sha1 = "0.10" - -[[example]] -name = "ping_pong" -path = "examples/ping_pong.rs" -bench = false diff --git a/ice/LICENSE-APACHE b/ice/LICENSE-APACHE deleted file mode 100644 index 16fe87b06..000000000 --- a/ice/LICENSE-APACHE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - -1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. 
- - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - -2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - -3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. 
If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - -4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - -5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - -6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. 
Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - -8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - -END OF TERMS AND CONDITIONS - -APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - -Copyright [yyyy] [name of copyright owner] - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
diff --git a/ice/LICENSE-MIT b/ice/LICENSE-MIT deleted file mode 100644 index e11d93bef..000000000 --- a/ice/LICENSE-MIT +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2021 WebRTC.rs - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/ice/README.md b/ice/README.md deleted file mode 100644 index 73786a11b..000000000 --- a/ice/README.md +++ /dev/null @@ -1,30 +0,0 @@ -

- WebRTC.rs
-
- [badges: License: MIT/Apache 2.0, Discord]
-
- A pure Rust implementation of ICE. Rewrite Pion ICE in Rust

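The deleted README above describes the crate in a single line. As an illustrative sketch (not part of the original README), the snippet below shows roughly how an agent is constructed and host candidates are gathered; it uses only calls that also appear in `ice/examples/ping_pong.rs` later in this changeset (`Agent::new`, `get_local_user_credentials`, `on_candidate`, `gather_candidates`, `close`), and the exchange of credentials and candidates with a remote peer is intentionally omitted.

```rust
// Hedged sketch: a minimal webrtc-ice agent setup, assuming the API exercised
// by ice/examples/ping_pong.rs in this same changeset. Signaling of the local
// ufrag/pwd and of gathered candidates to the remote peer is omitted.
use std::sync::Arc;

use webrtc_ice::agent::agent_config::AgentConfig;
use webrtc_ice::agent::Agent;
use webrtc_ice::candidate::Candidate;
use webrtc_ice::network_type::NetworkType;
use webrtc_ice::Error;

#[tokio::main]
async fn main() -> Result<(), Error> {
    // Gather UDP/IPv4 host candidates with an otherwise-default configuration.
    let agent = Agent::new(AgentConfig {
        network_types: vec![NetworkType::Udp4],
        ..Default::default()
    })
    .await?;

    // Local credentials would normally be sent to the remote peer out of band.
    let (ufrag, pwd) = agent.get_local_user_credentials().await;
    println!("local ufrag={ufrag} pwd={pwd}");

    // Print each gathered candidate; a real application would signal it too.
    agent.on_candidate(Box::new(
        move |c: Option<Arc<dyn Candidate + Send + Sync>>| {
            Box::pin(async move {
                if let Some(c) = c {
                    println!("candidate: {}", c.marshal());
                }
            })
        },
    ));

    agent.gather_candidates()?;

    // A real application would now call `dial`/`accept` with the remote
    // credentials; here the agent is simply shut down again.
    let _ = agent.close().await;
    Ok(())
}
```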
diff --git a/ice/codecov.yml b/ice/codecov.yml deleted file mode 100644 index 99d83b7f7..000000000 --- a/ice/codecov.yml +++ /dev/null @@ -1,23 +0,0 @@ -codecov: - require_ci_to_pass: yes - max_report_age: off - token: 4bd5cec1-2807-4cd6-8430-d5f3efe32ce0 - -coverage: - precision: 2 - round: down - range: 50..90 - status: - project: - default: - enabled: no - threshold: 0.2 - if_not_found: success - patch: - default: - enabled: no - if_not_found: success - changes: - default: - enabled: no - if_not_found: success diff --git a/ice/doc/webrtc.rs.png b/ice/doc/webrtc.rs.png deleted file mode 100644 index 7bf0dda2a..000000000 Binary files a/ice/doc/webrtc.rs.png and /dev/null differ diff --git a/ice/examples/ping_pong.rs b/ice/examples/ping_pong.rs deleted file mode 100644 index 9620383b1..000000000 --- a/ice/examples/ping_pong.rs +++ /dev/null @@ -1,424 +0,0 @@ -use std::io; -use std::sync::Arc; -use std::time::Duration; - -use clap::{App, AppSettings, Arg}; -use hyper::service::{make_service_fn, service_fn}; -use hyper::{Body, Client, Method, Request, Response, Server, StatusCode}; -use ice::agent::agent_config::AgentConfig; -use ice::agent::Agent; -use ice::candidate::candidate_base::*; -use ice::candidate::*; -use ice::network_type::*; -use ice::state::*; -use ice::udp_network::UDPNetwork; -use ice::Error; -use rand::{thread_rng, Rng}; -use tokio::net::UdpSocket; -use tokio::sync::{mpsc, watch, Mutex}; -use util::Conn; -use webrtc_ice as ice; - -#[macro_use] -extern crate lazy_static; - -type SenderType = Arc>>; -type ReceiverType = Arc>>; - -lazy_static! { - // ErrUnknownType indicates an error with Unknown info. - static ref REMOTE_AUTH_CHANNEL: (SenderType, ReceiverType ) = { - let (tx, rx) = mpsc::channel::(3); - (Arc::new(Mutex::new(tx)), Arc::new(Mutex::new(rx))) - }; - - static ref REMOTE_CAND_CHANNEL: (SenderType, ReceiverType) = { - let (tx, rx) = mpsc::channel::(10); - (Arc::new(Mutex::new(tx)), Arc::new(Mutex::new(rx))) - }; -} - -// HTTP Listener to get ICE Credentials/Candidate from remote Peer -async fn remote_handler(req: Request) -> Result, hyper::Error> { - //println!("received {:?}", req); - match (req.method(), req.uri().path()) { - (&Method::POST, "/remoteAuth") => { - let full_body = - match std::str::from_utf8(&hyper::body::to_bytes(req.into_body()).await?) { - Ok(s) => s.to_owned(), - Err(err) => panic!("{}", err), - }; - let tx = REMOTE_AUTH_CHANNEL.0.lock().await; - //println!("body: {:?}", full_body); - let _ = tx.send(full_body).await; - - let mut response = Response::new(Body::empty()); - *response.status_mut() = StatusCode::OK; - Ok(response) - } - - (&Method::POST, "/remoteCandidate") => { - let full_body = - match std::str::from_utf8(&hyper::body::to_bytes(req.into_body()).await?) { - Ok(s) => s.to_owned(), - Err(err) => panic!("{}", err), - }; - let tx = REMOTE_CAND_CHANNEL.0.lock().await; - //println!("body: {:?}", full_body); - let _ = tx.send(full_body).await; - - let mut response = Response::new(Body::empty()); - *response.status_mut() = StatusCode::OK; - Ok(response) - } - - // Return the 404 Not Found for other routes. 
- _ => { - let mut not_found = Response::default(); - *not_found.status_mut() = StatusCode::NOT_FOUND; - Ok(not_found) - } - } -} - -// Controlled Agent: -// cargo run --color=always --package webrtc-ice --example ping_pong -// Controlling Agent: -// cargo run --color=always --package webrtc-ice --example ping_pong -- --controlling - -#[tokio::main] -async fn main() -> Result<(), Error> { - env_logger::init(); - // .format(|buf, record| { - // writeln!( - // buf, - // "{}:{} [{}] {} - {}", - // record.file().unwrap_or("unknown"), - // record.line().unwrap_or(0), - // record.level(), - // chrono::Local::now().format("%H:%M:%S.%6f"), - // record.args() - // ) - // }) - // .filter(None, log::LevelFilter::Trace) - // .init(); - - let mut app = App::new("ICE Demo") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of ICE") - .setting(AppSettings::DeriveDisplayOrder) - .setting(AppSettings::SubcommandsNegateReqs) - .arg( - Arg::with_name("use-mux") - .takes_value(false) - .long("use-mux") - .short('m') - .help("Use a muxed UDP connection over a single listening port"), - ) - .arg( - Arg::with_name("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::with_name("controlling") - .takes_value(false) - .long("controlling") - .help("is ICE Agent controlling"), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let is_controlling = matches.is_present("controlling"); - let use_mux = matches.is_present("use-mux"); - - let (local_http_port, remote_http_port) = if is_controlling { - (9000, 9001) - } else { - (9001, 9000) - }; - - let (weak_conn, weak_agent) = { - let (done_tx, done_rx) = watch::channel(()); - - println!("Listening on http://localhost:{local_http_port}"); - let mut done_http_server = done_rx.clone(); - tokio::spawn(async move { - let addr = ([0, 0, 0, 0], local_http_port).into(); - let service = - make_service_fn(|_| async { Ok::<_, hyper::Error>(service_fn(remote_handler)) }); - let server = Server::bind(&addr).serve(service); - tokio::select! { - _ = done_http_server.changed() => { - println!("receive cancel http server!"); - } - result = server => { - // Run this server for... forever! 
- if let Err(e) = result { - eprintln!("server error: {e}"); - } - println!("exit http server!"); - } - }; - }); - - if is_controlling { - println!("Local Agent is controlling"); - } else { - println!("Local Agent is controlled"); - }; - println!("Press 'Enter' when both processes have started"); - let mut input = String::new(); - let _ = io::stdin().read_line(&mut input)?; - - let udp_network = if use_mux { - use ice::udp_mux::*; - let port = if is_controlling { 4000 } else { 4001 }; - - let udp_socket = UdpSocket::bind(("0.0.0.0", port)).await?; - let udp_mux = UDPMuxDefault::new(UDPMuxParams::new(udp_socket)); - - UDPNetwork::Muxed(udp_mux) - } else { - UDPNetwork::Ephemeral(Default::default()) - }; - - let ice_agent = Arc::new( - Agent::new(AgentConfig { - network_types: vec![NetworkType::Udp4], - udp_network, - ..Default::default() - }) - .await?, - ); - - let client = Arc::new(Client::new()); - - // When we have gathered a new ICE Candidate send it to the remote peer - let client2 = Arc::clone(&client); - ice_agent.on_candidate(Box::new( - move |c: Option>| { - let client3 = Arc::clone(&client2); - Box::pin(async move { - if let Some(c) = c { - println!("posting remoteCandidate with {}", c.marshal()); - - let req = match Request::builder() - .method(Method::POST) - .uri(format!( - "http://localhost:{remote_http_port}/remoteCandidate" - )) - .body(Body::from(c.marshal())) - { - Ok(req) => req, - Err(err) => { - println!("{err}"); - return; - } - }; - let resp = match client3.request(req).await { - Ok(resp) => resp, - Err(err) => { - println!("{err}"); - return; - } - }; - println!("Response from remoteCandidate: {}", resp.status()); - } - }) - }, - )); - - let (ice_done_tx, mut ice_done_rx) = mpsc::channel::<()>(1); - // When ICE Connection state has change print to stdout - ice_agent.on_connection_state_change(Box::new(move |c: ConnectionState| { - println!("ICE Connection State has changed: {c}"); - if c == ConnectionState::Failed { - let _ = ice_done_tx.try_send(()); - } - Box::pin(async move {}) - })); - - // Get the local auth details and send to remote peer - let (local_ufrag, local_pwd) = ice_agent.get_local_user_credentials().await; - - println!("posting remoteAuth with {local_ufrag}:{local_pwd}"); - let req = match Request::builder() - .method(Method::POST) - .uri(format!("http://localhost:{remote_http_port}/remoteAuth")) - .body(Body::from(format!("{local_ufrag}:{local_pwd}"))) - { - Ok(req) => req, - Err(err) => return Err(Error::Other(format!("{err}"))), - }; - let resp = match client.request(req).await { - Ok(resp) => resp, - Err(err) => return Err(Error::Other(format!("{err}"))), - }; - println!("Response from remoteAuth: {}", resp.status()); - - let (remote_ufrag, remote_pwd) = { - let mut rx = REMOTE_AUTH_CHANNEL.1.lock().await; - if let Some(s) = rx.recv().await { - println!("received: {s}"); - let fields: Vec = s.split(':').map(|s| s.to_string()).collect(); - (fields[0].clone(), fields[1].clone()) - } else { - panic!("rx.recv() empty"); - } - }; - println!("remote_ufrag: {remote_ufrag}, remote_pwd: {remote_pwd}"); - - let ice_agent2 = Arc::clone(&ice_agent); - let mut done_cand = done_rx.clone(); - tokio::spawn(async move { - let mut rx = REMOTE_CAND_CHANNEL.1.lock().await; - loop { - tokio::select! 
{ - _ = done_cand.changed() => { - println!("receive cancel remote cand!"); - break; - } - result = rx.recv() => { - if let Some(s) = result { - if let Ok(c) = unmarshal_candidate(&s) { - println!("add_remote_candidate: {c}"); - let c: Arc = Arc::new(c); - let _ = ice_agent2.add_remote_candidate(&c); - }else{ - println!("unmarshal_candidate error!"); - break; - } - }else{ - println!("REMOTE_CAND_CHANNEL done!"); - break; - } - } - }; - } - }); - - ice_agent.gather_candidates()?; - println!("Connecting..."); - - let (_cancel_tx, cancel_rx) = mpsc::channel(1); - // Start the ICE Agent. One side must be controlled, and the other must be controlling - let conn: Arc = if is_controlling { - ice_agent.dial(cancel_rx, remote_ufrag, remote_pwd).await? - } else { - ice_agent - .accept(cancel_rx, remote_ufrag, remote_pwd) - .await? - }; - - let weak_conn = Arc::downgrade(&conn); - - // Send messages in a loop to the remote peer - let conn_tx = Arc::clone(&conn); - let mut done_send = done_rx.clone(); - tokio::spawn(async move { - const RANDOM_STRING: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; - loop { - tokio::time::sleep(Duration::from_secs(3)).await; - - let val: String = (0..15) - .map(|_| { - let idx = thread_rng().gen_range(0..RANDOM_STRING.len()); - RANDOM_STRING[idx] as char - }) - .collect(); - - tokio::select! { - _ = done_send.changed() => { - println!("receive cancel ice send!"); - break; - } - result = conn_tx.send(val.as_bytes()) => { - if let Err(err) = result { - eprintln!("conn_tx send error: {err}"); - break; - }else{ - println!("Sent: '{val}'"); - } - } - }; - } - }); - - let mut done_recv = done_rx.clone(); - tokio::spawn(async move { - // Receive messages in a loop from the remote peer - let mut buf = vec![0u8; 1500]; - loop { - tokio::select! { - _ = done_recv.changed() => { - println!("receive cancel ice recv!"); - break; - } - result = conn.recv(&mut buf) => { - match result { - Ok(n) => { - println!("Received: '{}'", std::str::from_utf8(&buf[..n]).unwrap()); - } - Err(err) => { - eprintln!("conn_tx send error: {err}"); - break; - } - }; - } - }; - } - }); - - println!("Press ctrl-c to stop"); - /*let d = if is_controlling { - Duration::from_secs(500) - } else { - Duration::from_secs(5) - }; - let timeout = tokio::time::sleep(d); - tokio::pin!(timeout);*/ - - tokio::select! 
{ - /*_ = timeout.as_mut() => { - println!("received timeout signal!"); - let _ = done_tx.send(()); - }*/ - _ = ice_done_rx.recv() => { - println!("ice_done_rx"); - let _ = done_tx.send(()); - } - _ = tokio::signal::ctrl_c() => { - println!(); - let _ = done_tx.send(()); - } - }; - - let _ = ice_agent.close().await; - - (weak_conn, Arc::downgrade(&ice_agent)) - }; - - let mut int = tokio::time::interval(Duration::from_secs(1)); - loop { - int.tick().await; - println!( - "weak_conn: weak count = {}, strong count = {}, weak_agent: weak count = {}, strong count = {}", - weak_conn.weak_count(), - weak_conn.strong_count(), - weak_agent.weak_count(), - weak_agent.strong_count(), - ); - if weak_conn.strong_count() == 0 && weak_agent.strong_count() == 0 { - break; - } - } - - Ok(()) -} diff --git a/ice/src/agent/agent_config.rs b/ice/src/agent/agent_config.rs deleted file mode 100644 index 9c1c6cd4d..000000000 --- a/ice/src/agent/agent_config.rs +++ /dev/null @@ -1,255 +0,0 @@ -use std::net::IpAddr; -use std::time::Duration; - -use util::vnet::net::*; - -use super::*; -use crate::error::*; -use crate::mdns::*; -use crate::network_type::*; -use crate::udp_network::UDPNetwork; -use crate::url::*; - -/// The interval at which the agent performs candidate checks in the connecting phase. -pub(crate) const DEFAULT_CHECK_INTERVAL: Duration = Duration::from_millis(200); - -/// The interval used to keep candidates alive. -pub(crate) const DEFAULT_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(2); - -/// The default time till an Agent transitions disconnected. -pub(crate) const DEFAULT_DISCONNECTED_TIMEOUT: Duration = Duration::from_secs(5); - -/// The default time till an Agent transitions to failed after disconnected. -pub(crate) const DEFAULT_FAILED_TIMEOUT: Duration = Duration::from_secs(25); - -/// Wait time before nominating a host candidate. -pub(crate) const DEFAULT_HOST_ACCEPTANCE_MIN_WAIT: Duration = Duration::from_secs(0); - -/// Wait time before nominating a srflx candidate. -pub(crate) const DEFAULT_SRFLX_ACCEPTANCE_MIN_WAIT: Duration = Duration::from_millis(500); - -/// Wait time before nominating a prflx candidate. -pub(crate) const DEFAULT_PRFLX_ACCEPTANCE_MIN_WAIT: Duration = Duration::from_millis(1000); - -/// Wait time before nominating a relay candidate. -pub(crate) const DEFAULT_RELAY_ACCEPTANCE_MIN_WAIT: Duration = Duration::from_millis(2000); - -/// Max binding request before considering a pair failed. -pub(crate) const DEFAULT_MAX_BINDING_REQUESTS: u16 = 7; - -/// The number of bytes that can be buffered before we start to error. -pub(crate) const MAX_BUFFER_SIZE: usize = 1000 * 1000; // 1MB - -/// Wait time before binding requests can be deleted. -pub(crate) const MAX_BINDING_REQUEST_TIMEOUT: Duration = Duration::from_millis(4000); - -pub(crate) fn default_candidate_types() -> Vec { - vec![ - CandidateType::Host, - CandidateType::ServerReflexive, - CandidateType::Relay, - ] -} - -pub type InterfaceFilterFn = Box bool) + Send + Sync>; -pub type IpFilterFn = Box bool) + Send + Sync>; - -/// Collects the arguments to `ice::Agent` construction into a single structure, for -/// future-proofness of the interface. -#[derive(Default)] -pub struct AgentConfig { - pub urls: Vec, - - /// Controls how the UDP network stack works. - /// See [`UDPNetwork`] - pub udp_network: UDPNetwork, - - /// It is used to perform connectivity checks. 
The values MUST be unguessable, with at least - /// 128 bits of random number generator output used to generate the password, and at least 24 - /// bits of output to generate the username fragment. - pub local_ufrag: String, - /// It is used to perform connectivity checks. The values MUST be unguessable, with at least - /// 128 bits of random number generator output used to generate the password, and at least 24 - /// bits of output to generate the username fragment. - pub local_pwd: String, - - /// Controls mDNS behavior for the ICE agent. - pub multicast_dns_mode: MulticastDnsMode, - - /// Controls the hostname for this agent. If none is specified a random one will be generated. - pub multicast_dns_host_name: String, - - /// Control mDNS destination address - pub multicast_dns_dest_addr: String, - - /// Defaults to 5 seconds when this property is nil. - /// If the duration is 0, the ICE Agent will never go to disconnected. - pub disconnected_timeout: Option, - - /// Defaults to 25 seconds when this property is nil. - /// If the duration is 0, we will never go to failed. - pub failed_timeout: Option, - - /// Determines how often should we send ICE keepalives (should be less then connectiontimeout - /// above) when this is nil, it defaults to 10 seconds. - /// A keepalive interval of 0 means we never send keepalive packets - pub keepalive_interval: Option, - - /// An optional configuration for disabling or enabling support for specific network types. - pub network_types: Vec, - - /// An optional configuration for disabling or enabling support for specific candidate types. - pub candidate_types: Vec, - - //LoggerFactory logging.LoggerFactory - /// Controls how often our internal task loop runs when in the connecting state. - /// Only useful for testing. - pub check_interval: Duration, - - /// The max amount of binding requests the agent will send over a candidate pair for validation - /// or nomination, if after max_binding_requests the candidate is yet to answer a binding - /// request or a nomination we set the pair as failed. - pub max_binding_requests: Option, - - pub is_controlling: bool, - - /// lite agents do not perform connectivity check and only provide host candidates. - pub lite: bool, - - /// It is used along with nat1to1ips to specify which candidate type the 1:1 NAT IP addresses - /// should be mapped to. If unspecified or CandidateTypeHost, nat1to1ips are used to replace - /// host candidate IPs. If CandidateTypeServerReflexive, it will insert a srflx candidate (as - /// if it was derived from a STUN server) with its port number being the one for the actual host - /// candidate. Other values will result in an error. - pub nat_1to1_ip_candidate_type: CandidateType, - - /// Contains a list of public IP addresses that are to be used as a host candidate or srflx - /// candidate. This is used typically for servers that are behind 1:1 D-NAT (e.g. AWS EC2 - /// instances) and to eliminate the need of server reflexisive candidate gathering. - pub nat_1to1_ips: Vec, - - /// Specify a minimum wait time before selecting host candidates. - pub host_acceptance_min_wait: Option, - /// Specify a minimum wait time before selecting srflx candidates. - pub srflx_acceptance_min_wait: Option, - /// Specify a minimum wait time before selecting prflx candidates. - pub prflx_acceptance_min_wait: Option, - /// Specify a minimum wait time before selecting relay candidates. 
- pub relay_acceptance_min_wait: Option, - - /// Net is the our abstracted network interface for internal development purpose only - /// (see (github.com/pion/transport/vnet)[github.com/pion/transport/vnet]). - pub net: Option>, - - /// A function that you can use in order to whitelist or blacklist the interfaces which are - /// used to gather ICE candidates. - pub interface_filter: Arc>, - - /// A function that you can use in order to whitelist or blacklist - /// the ips which are used to gather ICE candidates. - pub ip_filter: Arc>, - - /// Controls if self-signed certificates are accepted when connecting to TURN servers via TLS or - /// DTLS. - pub insecure_skip_verify: bool, -} - -impl AgentConfig { - /// Populates an agent and falls back to defaults if fields are unset. - pub(crate) fn init_with_defaults(&self, a: &mut AgentInternal) { - if let Some(max_binding_requests) = self.max_binding_requests { - a.max_binding_requests = max_binding_requests; - } else { - a.max_binding_requests = DEFAULT_MAX_BINDING_REQUESTS; - } - - if let Some(host_acceptance_min_wait) = self.host_acceptance_min_wait { - a.host_acceptance_min_wait = host_acceptance_min_wait; - } else { - a.host_acceptance_min_wait = DEFAULT_HOST_ACCEPTANCE_MIN_WAIT; - } - - if let Some(srflx_acceptance_min_wait) = self.srflx_acceptance_min_wait { - a.srflx_acceptance_min_wait = srflx_acceptance_min_wait; - } else { - a.srflx_acceptance_min_wait = DEFAULT_SRFLX_ACCEPTANCE_MIN_WAIT; - } - - if let Some(prflx_acceptance_min_wait) = self.prflx_acceptance_min_wait { - a.prflx_acceptance_min_wait = prflx_acceptance_min_wait; - } else { - a.prflx_acceptance_min_wait = DEFAULT_PRFLX_ACCEPTANCE_MIN_WAIT; - } - - if let Some(relay_acceptance_min_wait) = self.relay_acceptance_min_wait { - a.relay_acceptance_min_wait = relay_acceptance_min_wait; - } else { - a.relay_acceptance_min_wait = DEFAULT_RELAY_ACCEPTANCE_MIN_WAIT; - } - - if let Some(disconnected_timeout) = self.disconnected_timeout { - a.disconnected_timeout = disconnected_timeout; - } else { - a.disconnected_timeout = DEFAULT_DISCONNECTED_TIMEOUT; - } - - if let Some(failed_timeout) = self.failed_timeout { - a.failed_timeout = failed_timeout; - } else { - a.failed_timeout = DEFAULT_FAILED_TIMEOUT; - } - - if let Some(keepalive_interval) = self.keepalive_interval { - a.keepalive_interval = keepalive_interval; - } else { - a.keepalive_interval = DEFAULT_KEEPALIVE_INTERVAL; - } - - if self.check_interval == Duration::from_secs(0) { - a.check_interval = DEFAULT_CHECK_INTERVAL; - } else { - a.check_interval = self.check_interval; - } - } - - pub(crate) fn init_ext_ip_mapping( - &self, - mdns_mode: MulticastDnsMode, - candidate_types: &[CandidateType], - ) -> Result> { - if let Some(ext_ip_mapper) = - ExternalIpMapper::new(self.nat_1to1_ip_candidate_type, &self.nat_1to1_ips)? 
- { - if ext_ip_mapper.candidate_type == CandidateType::Host { - if mdns_mode == MulticastDnsMode::QueryAndGather { - return Err(Error::ErrMulticastDnsWithNat1to1IpMapping); - } - let mut candi_host_enabled = false; - for candi_type in candidate_types { - if *candi_type == CandidateType::Host { - candi_host_enabled = true; - break; - } - } - if !candi_host_enabled { - return Err(Error::ErrIneffectiveNat1to1IpMappingHost); - } - } else if ext_ip_mapper.candidate_type == CandidateType::ServerReflexive { - let mut candi_srflx_enabled = false; - for candi_type in candidate_types { - if *candi_type == CandidateType::ServerReflexive { - candi_srflx_enabled = true; - break; - } - } - if !candi_srflx_enabled { - return Err(Error::ErrIneffectiveNat1to1IpMappingSrflx); - } - } - - Ok(Some(ext_ip_mapper)) - } else { - Ok(None) - } - } -} diff --git a/ice/src/agent/agent_gather.rs b/ice/src/agent/agent_gather.rs deleted file mode 100644 index 311d1cf14..000000000 --- a/ice/src/agent/agent_gather.rs +++ /dev/null @@ -1,892 +0,0 @@ -use std::net::{Ipv4Addr, Ipv6Addr}; -use std::str::FromStr; -use std::sync::Arc; - -use util::vnet::net::*; -use util::Conn; -use waitgroup::WaitGroup; - -use super::*; -use crate::candidate::candidate_base::CandidateBaseConfig; -use crate::candidate::candidate_host::CandidateHostConfig; -use crate::candidate::candidate_relay::CandidateRelayConfig; -use crate::candidate::candidate_server_reflexive::CandidateServerReflexiveConfig; -use crate::candidate::*; -use crate::error::*; -use crate::network_type::*; -use crate::udp_network::UDPNetwork; -use crate::url::{ProtoType, SchemeType, Url}; -use crate::util::*; - -const STUN_GATHER_TIMEOUT: Duration = Duration::from_secs(5); - -pub(crate) struct GatherCandidatesInternalParams { - pub(crate) udp_network: UDPNetwork, - pub(crate) candidate_types: Vec, - pub(crate) urls: Vec, - pub(crate) network_types: Vec, - pub(crate) mdns_mode: MulticastDnsMode, - pub(crate) mdns_name: String, - pub(crate) net: Arc, - pub(crate) interface_filter: Arc>, - pub(crate) ip_filter: Arc>, - pub(crate) ext_ip_mapper: Arc>, - pub(crate) agent_internal: Arc, - pub(crate) gathering_state: Arc, - pub(crate) chan_candidate_tx: ChanCandidateTx, -} - -struct GatherCandidatesLocalParams { - udp_network: UDPNetwork, - network_types: Vec, - mdns_mode: MulticastDnsMode, - mdns_name: String, - interface_filter: Arc>, - ip_filter: Arc>, - ext_ip_mapper: Arc>, - net: Arc, - agent_internal: Arc, -} - -struct GatherCandidatesLocalUDPMuxParams { - network_types: Vec, - interface_filter: Arc>, - ip_filter: Arc>, - ext_ip_mapper: Arc>, - net: Arc, - agent_internal: Arc, - udp_mux: Arc, -} - -struct GatherCandidatesSrflxMappedParasm { - network_types: Vec, - port_max: u16, - port_min: u16, - ext_ip_mapper: Arc>, - net: Arc, - agent_internal: Arc, -} - -struct GatherCandidatesSrflxParams { - urls: Vec, - network_types: Vec, - port_max: u16, - port_min: u16, - net: Arc, - agent_internal: Arc, -} - -impl Agent { - pub(crate) async fn gather_candidates_internal(params: GatherCandidatesInternalParams) { - Self::set_gathering_state( - ¶ms.chan_candidate_tx, - ¶ms.gathering_state, - GatheringState::Gathering, - ) - .await; - - let wg = WaitGroup::new(); - - for t in ¶ms.candidate_types { - match t { - CandidateType::Host => { - let local_params = GatherCandidatesLocalParams { - udp_network: params.udp_network.clone(), - network_types: params.network_types.clone(), - mdns_mode: params.mdns_mode, - mdns_name: params.mdns_name.clone(), - interface_filter: 
Arc::clone(¶ms.interface_filter), - ip_filter: Arc::clone(¶ms.ip_filter), - ext_ip_mapper: Arc::clone(¶ms.ext_ip_mapper), - net: Arc::clone(¶ms.net), - agent_internal: Arc::clone(¶ms.agent_internal), - }; - - let w = wg.worker(); - tokio::spawn(async move { - let _d = w; - - Self::gather_candidates_local(local_params).await; - }); - } - CandidateType::ServerReflexive => { - let ephemeral_config = match ¶ms.udp_network { - UDPNetwork::Ephemeral(e) => e, - // No server reflexive for muxxed connections - UDPNetwork::Muxed(_) => continue, - }; - - let srflx_params = GatherCandidatesSrflxParams { - urls: params.urls.clone(), - network_types: params.network_types.clone(), - port_max: ephemeral_config.port_max(), - port_min: ephemeral_config.port_min(), - net: Arc::clone(¶ms.net), - agent_internal: Arc::clone(¶ms.agent_internal), - }; - let w1 = wg.worker(); - tokio::spawn(async move { - let _d = w1; - - Self::gather_candidates_srflx(srflx_params).await; - }); - if let Some(ext_ip_mapper) = &*params.ext_ip_mapper { - if ext_ip_mapper.candidate_type == CandidateType::ServerReflexive { - let srflx_mapped_params = GatherCandidatesSrflxMappedParasm { - network_types: params.network_types.clone(), - port_max: ephemeral_config.port_max(), - port_min: ephemeral_config.port_min(), - ext_ip_mapper: Arc::clone(¶ms.ext_ip_mapper), - net: Arc::clone(¶ms.net), - agent_internal: Arc::clone(¶ms.agent_internal), - }; - let w2 = wg.worker(); - tokio::spawn(async move { - let _d = w2; - - Self::gather_candidates_srflx_mapped(srflx_mapped_params).await; - }); - } - } - } - CandidateType::Relay => { - let urls = params.urls.clone(); - let net = Arc::clone(¶ms.net); - let agent_internal = Arc::clone(¶ms.agent_internal); - let w = wg.worker(); - tokio::spawn(async move { - let _d = w; - - Self::gather_candidates_relay(urls, net, agent_internal).await; - }); - } - _ => {} - } - } - - // Block until all STUN and TURN URLs have been gathered (or timed out) - wg.wait().await; - - Self::set_gathering_state( - ¶ms.chan_candidate_tx, - ¶ms.gathering_state, - GatheringState::Complete, - ) - .await; - } - - async fn set_gathering_state( - chan_candidate_tx: &ChanCandidateTx, - gathering_state: &Arc, - new_state: GatheringState, - ) { - if GatheringState::from(gathering_state.load(Ordering::SeqCst)) != new_state - && new_state == GatheringState::Complete - { - let cand_tx = chan_candidate_tx.lock().await; - if let Some(tx) = &*cand_tx { - let _ = tx.send(None).await; - } - } - - gathering_state.store(new_state as u8, Ordering::SeqCst); - } - - async fn gather_candidates_local(params: GatherCandidatesLocalParams) { - let GatherCandidatesLocalParams { - udp_network, - network_types, - mdns_mode, - mdns_name, - interface_filter, - ip_filter, - ext_ip_mapper, - net, - agent_internal, - } = params; - - // If we wanna use UDP mux, do so - // FIXME: We still need to support TCP in combination with this option - if let UDPNetwork::Muxed(udp_mux) = udp_network { - let result = Self::gather_candidates_local_udp_mux(GatherCandidatesLocalUDPMuxParams { - network_types, - interface_filter, - ip_filter, - ext_ip_mapper, - net, - agent_internal, - udp_mux, - }) - .await; - - if let Err(err) = result { - log::error!("Failed to gather local candidates using UDP mux: {}", err); - } - - return; - } - - let ips = local_interfaces(&net, &interface_filter, &ip_filter, &network_types).await; - for ip in ips { - let mut mapped_ip = ip; - - if mdns_mode != MulticastDnsMode::QueryAndGather && ext_ip_mapper.is_some() { - if let Some(ext_ip_mapper2) = 
ext_ip_mapper.as_ref() { - if ext_ip_mapper2.candidate_type == CandidateType::Host { - if let Ok(mi) = ext_ip_mapper2.find_external_ip(&ip.to_string()) { - mapped_ip = mi; - } else { - log::warn!( - "[{}]: 1:1 NAT mapping is enabled but no external IP is found for {}", - agent_internal.get_name(), - ip - ); - } - } - } - } - - let address = if mdns_mode == MulticastDnsMode::QueryAndGather { - mdns_name.clone() - } else { - mapped_ip.to_string() - }; - - //TODO: for network in networks - let network = UDP.to_owned(); - if let UDPNetwork::Ephemeral(ephemeral_config) = &udp_network { - /*TODO:switch network { - case tcp: - // Handle ICE TCP passive mode - - a.log.Debugf("GetConn by ufrag: %s\n", a.localUfrag) - conn, err = a.tcpMux.GetConnByUfrag(a.localUfrag) - if err != nil { - if !errors.Is(err, ErrTCPMuxNotInitialized) { - a.log.Warnf("error getting tcp conn by ufrag: %s %s %s\n", network, ip, a.localUfrag) - } - continue - } - port = conn.LocalAddr().(*net.TCPAddr).Port - tcpType = TCPTypePassive - // is there a way to verify that the listen address is even - // accessible from the current interface. - case udp:*/ - - let conn: Arc = match listen_udp_in_port_range( - &net, - ephemeral_config.port_max(), - ephemeral_config.port_min(), - SocketAddr::new(ip, 0), - ) - .await - { - Ok(conn) => conn, - Err(err) => { - log::warn!( - "[{}]: could not listen {} {}: {}", - agent_internal.get_name(), - network, - ip, - err - ); - continue; - } - }; - - let port = match conn.local_addr() { - Ok(addr) => addr.port(), - Err(err) => { - log::warn!( - "[{}]: could not get local addr: {}", - agent_internal.get_name(), - err - ); - continue; - } - }; - - let host_config = CandidateHostConfig { - base_config: CandidateBaseConfig { - network: network.clone(), - address, - port, - component: COMPONENT_RTP, - conn: Some(conn), - ..CandidateBaseConfig::default() - }, - ..CandidateHostConfig::default() - }; - - let candidate: Arc = - match host_config.new_candidate_host() { - Ok(candidate) => { - if mdns_mode == MulticastDnsMode::QueryAndGather { - if let Err(err) = candidate.set_ip(&ip) { - log::warn!( - "[{}]: Failed to create host candidate: {} {} {}: {:?}", - agent_internal.get_name(), - network, - mapped_ip, - port, - err - ); - continue; - } - } - Arc::new(candidate) - } - Err(err) => { - log::warn!( - "[{}]: Failed to create host candidate: {} {} {}: {}", - agent_internal.get_name(), - network, - mapped_ip, - port, - err - ); - continue; - } - }; - - { - if let Err(err) = agent_internal.add_candidate(&candidate).await { - if let Err(close_err) = candidate.close().await { - log::warn!( - "[{}]: Failed to close candidate: {}", - agent_internal.get_name(), - close_err - ); - } - log::warn!( - "[{}]: Failed to append to localCandidates and run onCandidateHdlr: {}", - agent_internal.get_name(), - err - ); - } - } - } - } - } - - async fn gather_candidates_local_udp_mux( - params: GatherCandidatesLocalUDPMuxParams, - ) -> Result<()> { - let GatherCandidatesLocalUDPMuxParams { - network_types, - interface_filter, - ip_filter, - ext_ip_mapper, - net, - agent_internal, - udp_mux, - } = params; - - // Filter out non UDP network types - let relevant_network_types: Vec<_> = - network_types.into_iter().filter(|n| n.is_udp()).collect(); - - let udp_mux = Arc::clone(&udp_mux); - - let local_ips = - local_interfaces(&net, &interface_filter, &ip_filter, &relevant_network_types).await; - - let candidate_ips: Vec = ext_ip_mapper - .as_ref() // Arc - .as_ref() // Option - .and_then(|mapper| { - if mapper.candidate_type != 
CandidateType::Host { - return None; - } - - Some( - local_ips - .iter() - .filter_map(|ip| match mapper.find_external_ip(&ip.to_string()) { - Ok(ip) => Some(ip), - Err(err) => { - log::warn!( - "1:1 NAT mapping is enabled but not external IP is found for {}: {}", - ip, - err - ); - None - } - }) - .collect(), - ) - }) - .unwrap_or_else(|| local_ips.iter().copied().collect()); - - if candidate_ips.is_empty() { - return Err(Error::ErrCandidateIpNotFound); - } - - let ufrag = { - let ufrag_pwd = agent_internal.ufrag_pwd.lock().await; - - ufrag_pwd.local_ufrag.clone() - }; - - let conn = udp_mux.get_conn(&ufrag).await?; - let port = conn.local_addr()?.port(); - - for candidate_ip in candidate_ips { - let host_config = CandidateHostConfig { - base_config: CandidateBaseConfig { - network: UDP.to_owned(), - address: candidate_ip.to_string(), - port, - conn: Some(conn.clone()), - component: COMPONENT_RTP, - ..Default::default() - }, - tcp_type: TcpType::Unspecified, - }; - - let candidate: Arc = - Arc::new(host_config.new_candidate_host()?); - - agent_internal.add_candidate(&candidate).await?; - } - - Ok(()) - } - - async fn gather_candidates_srflx_mapped(params: GatherCandidatesSrflxMappedParasm) { - let GatherCandidatesSrflxMappedParasm { - network_types, - port_max, - port_min, - ext_ip_mapper, - net, - agent_internal, - } = params; - - let wg = WaitGroup::new(); - - for network_type in network_types { - if network_type.is_tcp() { - continue; - } - - let network = network_type.to_string(); - let net2 = Arc::clone(&net); - let agent_internal2 = Arc::clone(&agent_internal); - let ext_ip_mapper2 = Arc::clone(&ext_ip_mapper); - - let w = wg.worker(); - tokio::spawn(async move { - let _d = w; - - let conn: Arc = match listen_udp_in_port_range( - &net2, - port_max, - port_min, - if network_type.is_ipv4() { - SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0) - } else { - SocketAddr::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0).into(), 0) - }, - ) - .await - { - Ok(conn) => conn, - Err(err) => { - log::warn!( - "[{}]: Failed to listen {}: {}", - agent_internal2.get_name(), - network, - err - ); - return Ok(()); - } - }; - - let laddr = conn.local_addr()?; - let mapped_ip = { - if let Some(ext_ip_mapper3) = &*ext_ip_mapper2 { - match ext_ip_mapper3.find_external_ip(&laddr.ip().to_string()) { - Ok(ip) => ip, - Err(err) => { - log::warn!( - "[{}]: 1:1 NAT mapping is enabled but no external IP is found for {}: {}", - agent_internal2.get_name(), - laddr, - err - ); - return Ok(()); - } - } - } else { - log::error!( - "[{}]: ext_ip_mapper is None in gather_candidates_srflx_mapped", - agent_internal2.get_name(), - ); - return Ok(()); - } - }; - - let srflx_config = CandidateServerReflexiveConfig { - base_config: CandidateBaseConfig { - network: network.clone(), - address: mapped_ip.to_string(), - port: laddr.port(), - component: COMPONENT_RTP, - conn: Some(conn), - ..CandidateBaseConfig::default() - }, - rel_addr: laddr.ip().to_string(), - rel_port: laddr.port(), - }; - - let candidate: Arc = - match srflx_config.new_candidate_server_reflexive() { - Ok(candidate) => Arc::new(candidate), - Err(err) => { - log::warn!( - "[{}]: Failed to create server reflexive candidate: {} {} {}: {}", - agent_internal2.get_name(), - network, - mapped_ip, - laddr.port(), - err - ); - return Ok(()); - } - }; - - { - if let Err(err) = agent_internal2.add_candidate(&candidate).await { - if let Err(close_err) = candidate.close().await { - log::warn!( - "[{}]: Failed to close candidate: {}", - agent_internal2.get_name(), - 
close_err - ); - } - log::warn!( - "[{}]: Failed to append to localCandidates and run onCandidateHdlr: {}", - agent_internal2.get_name(), - err - ); - } - } - - Result::<()>::Ok(()) - }); - } - - wg.wait().await; - } - - async fn gather_candidates_srflx(params: GatherCandidatesSrflxParams) { - let GatherCandidatesSrflxParams { - urls, - network_types, - port_max, - port_min, - net, - agent_internal, - } = params; - - let wg = WaitGroup::new(); - for network_type in network_types { - if network_type.is_tcp() { - continue; - } - - for url in &urls { - let network = network_type.to_string(); - let is_ipv4 = network_type.is_ipv4(); - let url = url.clone(); - let net2 = Arc::clone(&net); - let agent_internal2 = Arc::clone(&agent_internal); - - let w = wg.worker(); - tokio::spawn(async move { - let _d = w; - - let host_port = format!("{}:{}", url.host, url.port); - let server_addr = match net2.resolve_addr(is_ipv4, &host_port).await { - Ok(addr) => addr, - Err(err) => { - log::warn!( - "[{}]: failed to resolve stun host: {}: {}", - agent_internal2.get_name(), - host_port, - err - ); - return Ok(()); - } - }; - - let conn: Arc = match listen_udp_in_port_range( - &net2, - port_max, - port_min, - if is_ipv4 { - SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0) - } else { - SocketAddr::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0).into(), 0) - }, - ) - .await - { - Ok(conn) => conn, - Err(err) => { - log::warn!( - "[{}]: Failed to listen for {}: {}", - agent_internal2.get_name(), - server_addr, - err - ); - return Ok(()); - } - }; - - let xoraddr = - match get_xormapped_addr(&conn, server_addr, STUN_GATHER_TIMEOUT).await { - Ok(xoraddr) => xoraddr, - Err(err) => { - log::warn!( - "[{}]: could not get server reflexive address {} {}: {}", - agent_internal2.get_name(), - network, - url, - err - ); - return Ok(()); - } - }; - - let (ip, port) = (xoraddr.ip, xoraddr.port); - - let laddr = conn.local_addr()?; - let srflx_config = CandidateServerReflexiveConfig { - base_config: CandidateBaseConfig { - network: network.clone(), - address: ip.to_string(), - port, - component: COMPONENT_RTP, - conn: Some(conn), - ..CandidateBaseConfig::default() - }, - rel_addr: laddr.ip().to_string(), - rel_port: laddr.port(), - }; - - let candidate: Arc = - match srflx_config.new_candidate_server_reflexive() { - Ok(candidate) => Arc::new(candidate), - Err(err) => { - log::warn!( - "[{}]: Failed to create server reflexive candidate: {} {} {}: {:?}", - agent_internal2.get_name(), - network, - ip, - port, - err - ); - return Ok(()); - } - }; - - { - if let Err(err) = agent_internal2.add_candidate(&candidate).await { - if let Err(close_err) = candidate.close().await { - log::warn!( - "[{}]: Failed to close candidate: {}", - agent_internal2.get_name(), - close_err - ); - } - log::warn!( - "[{}]: Failed to append to localCandidates and run onCandidateHdlr: {}", - agent_internal2.get_name(), - err - ); - } - } - - Result::<()>::Ok(()) - }); - } - } - - wg.wait().await; - } - - pub(crate) async fn gather_candidates_relay( - urls: Vec, - net: Arc, - agent_internal: Arc, - ) { - let wg = WaitGroup::new(); - - for url in urls { - if url.scheme != SchemeType::Turn && url.scheme != SchemeType::Turns { - continue; - } - if url.username.is_empty() { - log::error!( - "[{}]:Failed to gather relay candidates: {:?}", - agent_internal.get_name(), - Error::ErrUsernameEmpty - ); - return; - } - if url.password.is_empty() { - log::error!( - "[{}]: Failed to gather relay candidates: {:?}", - agent_internal.get_name(), - Error::ErrPasswordEmpty - ); 
- return; - } - - let network = NetworkType::Udp4.to_string(); - let net2 = Arc::clone(&net); - let agent_internal2 = Arc::clone(&agent_internal); - - let w = wg.worker(); - tokio::spawn(async move { - let _d = w; - - let turn_server_addr = format!("{}:{}", url.host, url.port); - - let (loc_conn, rel_addr, rel_port) = - if url.proto == ProtoType::Udp && url.scheme == SchemeType::Turn { - let loc_conn = match net2.bind(SocketAddr::from_str("0.0.0.0:0")?).await { - Ok(c) => c, - Err(err) => { - log::warn!( - "[{}]: Failed to listen due to error: {}", - agent_internal2.get_name(), - err - ); - return Ok(()); - } - }; - - let local_addr = loc_conn.local_addr()?; - let rel_addr = local_addr.ip().to_string(); - let rel_port = local_addr.port(); - (loc_conn, rel_addr, rel_port) - /*TODO: case url.proto == ProtoType::UDP && url.scheme == SchemeType::TURNS{ - case a.proxyDialer != nil && url.Proto == ProtoTypeTCP && (url.Scheme == SchemeTypeTURN || url.Scheme == SchemeTypeTURNS): - case url.Proto == ProtoTypeTCP && url.Scheme == SchemeTypeTURN: - case url.Proto == ProtoTypeTCP && url.Scheme == SchemeTypeTURNS:*/ - } else { - log::warn!( - "[{}]: Unable to handle URL in gather_candidates_relay {}", - agent_internal2.get_name(), - url - ); - return Ok(()); - }; - - let cfg = turn::client::ClientConfig { - stun_serv_addr: String::new(), - turn_serv_addr: turn_server_addr.clone(), - username: url.username, - password: url.password, - realm: String::new(), - software: String::new(), - rto_in_ms: 0, - conn: loc_conn, - vnet: Some(Arc::clone(&net2)), - }; - let client = match turn::client::Client::new(cfg).await { - Ok(client) => Arc::new(client), - Err(err) => { - log::warn!( - "[{}]: Failed to build new turn.Client {} {}\n", - agent_internal2.get_name(), - turn_server_addr, - err - ); - return Ok(()); - } - }; - if let Err(err) = client.listen().await { - let _ = client.close().await; - log::warn!( - "[{}]: Failed to listen on turn.Client {} {}", - agent_internal2.get_name(), - turn_server_addr, - err - ); - return Ok(()); - } - - let relay_conn: Arc = match client.allocate().await { - Ok(conn) => Arc::new(conn), - Err(err) => { - let _ = client.close().await; - log::warn!( - "[{}]: Failed to allocate on turn.Client {} {}", - agent_internal2.get_name(), - turn_server_addr, - err - ); - return Ok(()); - } - }; - - let raddr = relay_conn.local_addr()?; - let relay_config = CandidateRelayConfig { - base_config: CandidateBaseConfig { - network: network.clone(), - address: raddr.ip().to_string(), - port: raddr.port(), - component: COMPONENT_RTP, - conn: Some(Arc::clone(&relay_conn)), - ..CandidateBaseConfig::default() - }, - rel_addr, - rel_port, - relay_client: Some(Arc::clone(&client)), - }; - - let candidate: Arc = - match relay_config.new_candidate_relay() { - Ok(candidate) => Arc::new(candidate), - Err(err) => { - let _ = relay_conn.close().await; - let _ = client.close().await; - log::warn!( - "[{}]: Failed to create relay candidate: {} {}: {}", - agent_internal2.get_name(), - network, - raddr, - err - ); - return Ok(()); - } - }; - - { - if let Err(err) = agent_internal2.add_candidate(&candidate).await { - if let Err(close_err) = candidate.close().await { - log::warn!( - "[{}]: Failed to close candidate: {}", - agent_internal2.get_name(), - close_err - ); - } - log::warn!( - "[{}]: Failed to append to localCandidates and run onCandidateHdlr: {}", - agent_internal2.get_name(), - err - ); - } - } - - Result::<()>::Ok(()) - }); - } - - wg.wait().await; - } -} diff --git 
a/ice/src/agent/agent_gather_test.rs b/ice/src/agent/agent_gather_test.rs deleted file mode 100644 index ed90f097a..000000000 --- a/ice/src/agent/agent_gather_test.rs +++ /dev/null @@ -1,490 +0,0 @@ -use std::str::FromStr; - -use ipnet::IpNet; -use tokio::net::UdpSocket; -use util::vnet::*; - -use super::agent_vnet_test::*; -use super::*; -use crate::udp_mux::{UDPMuxDefault, UDPMuxParams}; -use crate::util::*; - -#[tokio::test] -async fn test_vnet_gather_no_local_ip_address() -> Result<()> { - let vnet = Arc::new(net::Net::new(Some(net::NetConfig::default()))); - - let a = Agent::new(AgentConfig { - net: Some(Arc::clone(&vnet)), - ..Default::default() - }) - .await?; - - let local_ips = local_interfaces( - &vnet, - &a.interface_filter, - &a.ip_filter, - &[NetworkType::Udp4], - ) - .await; - assert!(local_ips.is_empty(), "should return no local IP"); - - a.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_vnet_gather_dynamic_ip_address() -> Result<()> { - let cider = "1.2.3.0/24"; - let ipnet = IpNet::from_str(cider).map_err(|e| Error::Other(e.to_string()))?; - - let r = Arc::new(Mutex::new(router::Router::new(router::RouterConfig { - cidr: cider.to_owned(), - ..Default::default() - })?)); - let nw = Arc::new(net::Net::new(Some(net::NetConfig::default()))); - connect_net2router(&nw, &r).await?; - - let a = Agent::new(AgentConfig { - net: Some(Arc::clone(&nw)), - ..Default::default() - }) - .await?; - - let local_ips = - local_interfaces(&nw, &a.interface_filter, &a.ip_filter, &[NetworkType::Udp4]).await; - assert!(!local_ips.is_empty(), "should have one local IP"); - - for ip in &local_ips { - if ip.is_loopback() { - panic!("should not return loopback IP"); - } - if !ipnet.contains(ip) { - panic!("{ip} should be contained in the CIDR {ipnet}"); - } - } - - a.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_vnet_gather_listen_udp() -> Result<()> { - let cider = "1.2.3.0/24"; - let r = Arc::new(Mutex::new(router::Router::new(router::RouterConfig { - cidr: cider.to_owned(), - ..Default::default() - })?)); - let nw = Arc::new(net::Net::new(Some(net::NetConfig::default()))); - connect_net2router(&nw, &r).await?; - - let a = Agent::new(AgentConfig { - net: Some(Arc::clone(&nw)), - ..Default::default() - }) - .await?; - - let local_ips = - local_interfaces(&nw, &a.interface_filter, &a.ip_filter, &[NetworkType::Udp4]).await; - assert!(!local_ips.is_empty(), "should have one local IP"); - - for ip in local_ips { - let _ = listen_udp_in_port_range(&nw, 0, 0, SocketAddr::new(ip, 0)).await?; - - let result = listen_udp_in_port_range(&nw, 4999, 5000, SocketAddr::new(ip, 0)).await; - assert!( - result.is_err(), - "listenUDP with invalid port range did not return ErrPort" - ); - - let conn = listen_udp_in_port_range(&nw, 5000, 5000, SocketAddr::new(ip, 0)).await?; - let port = conn.local_addr()?.port(); - assert_eq!( - port, 5000, - "listenUDP with port restriction of 5000 listened on incorrect port ({port})" - ); - } - - a.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_vnet_gather_with_nat_1to1_as_host_candidates() -> Result<()> { - let external_ip0 = "1.2.3.4"; - let external_ip1 = "1.2.3.5"; - let local_ip0 = "10.0.0.1"; - let local_ip1 = "10.0.0.2"; - let map0 = format!("{external_ip0}/{local_ip0}"); - let map1 = format!("{external_ip1}/{local_ip1}"); - - let wan = Arc::new(Mutex::new(router::Router::new(router::RouterConfig { - cidr: "1.2.3.0/24".to_owned(), - ..Default::default() - })?)); - - let lan = 
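One detail worth calling out from the listen test above: listen_udp_in_port_range takes the maximum before the minimum (net, port_max, port_min, addr), a max below the min is rejected, and max == min pins the bind to that exact port. A minimal sketch under the same assumptions (crate-internal helper, vnet Net, test-style prelude):

async fn pinned_port_example(nw: Arc<net::Net>) -> Result<()> {
    let any_addr = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0);

    // (0, 0) means "no restriction": any ephemeral port will do.
    let _free = listen_udp_in_port_range(&nw, 0, 0, any_addr).await?;

    // max == min pins the bind: the socket must come back on port 5000.
    let pinned = listen_udp_in_port_range(&nw, 5000, 5000, any_addr).await?;
    assert_eq!(pinned.local_addr()?.port(), 5000);
    Ok(())
}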
Arc::new(Mutex::new(router::Router::new(router::RouterConfig { - cidr: "10.0.0.0/24".to_owned(), - static_ips: vec![map0.clone(), map1.clone()], - nat_type: Some(nat::NatType { - mode: nat::NatMode::Nat1To1, - ..Default::default() - }), - ..Default::default() - })?)); - - connect_router2router(&lan, &wan).await?; - - let nw = Arc::new(net::Net::new(Some(net::NetConfig { - static_ips: vec![local_ip0.to_owned(), local_ip1.to_owned()], - ..Default::default() - }))); - - connect_net2router(&nw, &lan).await?; - - let a = Agent::new(AgentConfig { - network_types: vec![NetworkType::Udp4], - nat_1to1_ips: vec![map0.clone(), map1.clone()], - net: Some(Arc::clone(&nw)), - ..Default::default() - }) - .await?; - - let (done_tx, mut done_rx) = mpsc::channel::<()>(1); - let done_tx = Arc::new(Mutex::new(Some(done_tx))); - a.on_candidate(Box::new( - move |c: Option>| { - let done_tx_clone = Arc::clone(&done_tx); - Box::pin(async move { - if c.is_none() { - let mut tx = done_tx_clone.lock().await; - tx.take(); - } - }) - }, - )); - - a.gather_candidates()?; - - log::debug!("wait for gathering is done..."); - let _ = done_rx.recv().await; - log::debug!("gathering is done"); - - let candidates = a.get_local_candidates().await?; - assert_eq!(candidates.len(), 2, "There must be two candidates"); - - let mut laddrs = vec![]; - for candi in &candidates { - if let Some(conn) = candi.get_conn() { - let laddr = conn.local_addr()?; - assert_eq!( - candi.port(), - laddr.port(), - "Unexpected candidate port: {}", - candi.port() - ); - laddrs.push(laddr); - } - } - - if candidates[0].address() == external_ip0 { - assert_eq!( - candidates[1].address(), - external_ip1, - "Unexpected candidate IP: {}", - candidates[1].address() - ); - assert_eq!( - laddrs[0].ip().to_string(), - local_ip0, - "Unexpected listen IP: {}", - laddrs[0].ip() - ); - assert_eq!( - laddrs[1].ip().to_string(), - local_ip1, - "Unexpected listen IP: {}", - laddrs[1].ip() - ); - } else if candidates[0].address() == external_ip1 { - assert_eq!( - candidates[1].address(), - external_ip0, - "Unexpected candidate IP: {}", - candidates[1].address() - ); - assert_eq!( - laddrs[0].ip().to_string(), - local_ip1, - "Unexpected listen IP: {}", - laddrs[0].ip(), - ); - assert_eq!( - laddrs[1].ip().to_string(), - local_ip0, - "Unexpected listen IP: {}", - laddrs[1].ip(), - ) - } - - a.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_vnet_gather_with_nat_1to1_as_srflx_candidates() -> Result<()> { - let wan = Arc::new(Mutex::new(router::Router::new(router::RouterConfig { - cidr: "1.2.3.0/24".to_owned(), - ..Default::default() - })?)); - - let lan = Arc::new(Mutex::new(router::Router::new(router::RouterConfig { - cidr: "10.0.0.0/24".to_owned(), - static_ips: vec!["1.2.3.4/10.0.0.1".to_owned()], - nat_type: Some(nat::NatType { - mode: nat::NatMode::Nat1To1, - ..Default::default() - }), - ..Default::default() - })?)); - - connect_router2router(&lan, &wan).await?; - - let nw = Arc::new(net::Net::new(Some(net::NetConfig { - static_ips: vec!["10.0.0.1".to_owned()], - ..Default::default() - }))); - - connect_net2router(&nw, &lan).await?; - - let a = Agent::new(AgentConfig { - network_types: vec![NetworkType::Udp4], - nat_1to1_ips: vec!["1.2.3.4".to_owned()], - nat_1to1_ip_candidate_type: CandidateType::ServerReflexive, - net: Some(nw), - ..Default::default() - }) - .await?; - - let (done_tx, mut done_rx) = mpsc::channel::<()>(1); - let done_tx = Arc::new(Mutex::new(Some(done_tx))); - a.on_candidate(Box::new( - move |c: Option>| { - let done_tx_clone = 
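For the 1:1 NAT cases, each entry in nat_1to1_ips is either a bare external IP or an "external/local" mapping like the map0/map1 strings above. With the default candidate type the external address replaces the host candidate's address while the socket keeps listening on the local IP; setting nat_1to1_ip_candidate_type to CandidateType::ServerReflexive (used in the next test) advertises the mapped IP as an extra srflx candidate instead. A condensed sketch, same test-style assumptions, placeholder addresses:

async fn nat_1to1_example(nw: Arc<net::Net>) -> Result<()> {
    let a = Agent::new(AgentConfig {
        network_types: vec![NetworkType::Udp4],
        // "advertise 1.2.3.4, keep listening on 10.0.0.1" (placeholder addresses)
        nat_1to1_ips: vec!["1.2.3.4/10.0.0.1".to_owned()],
        // nat_1to1_ip_candidate_type: CandidateType::ServerReflexive, // srflx flavour
        net: Some(nw),
        ..Default::default()
    })
    .await?;
    a.close().await?;
    Ok(())
}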
Arc::clone(&done_tx); - Box::pin(async move { - if c.is_none() { - let mut tx = done_tx_clone.lock().await; - tx.take(); - } - }) - }, - )); - - a.gather_candidates()?; - - log::debug!("wait for gathering is done..."); - let _ = done_rx.recv().await; - log::debug!("gathering is done"); - - let candidates = a.get_local_candidates().await?; - assert_eq!(candidates.len(), 2, "There must be two candidates"); - - let mut candi_host = None; - let mut candi_srflx = None; - - for candidate in candidates { - match candidate.candidate_type() { - CandidateType::Host => { - candi_host = Some(candidate); - } - CandidateType::ServerReflexive => { - candi_srflx = Some(candidate); - } - _ => { - panic!("Unexpected candidate type"); - } - } - } - - assert!(candi_host.is_some(), "should not be nil"); - assert_eq!("10.0.0.1", candi_host.unwrap().address(), "should match"); - assert!(candi_srflx.is_some(), "should not be nil"); - assert_eq!("1.2.3.4", candi_srflx.unwrap().address(), "should match"); - - a.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_vnet_gather_with_interface_filter() -> Result<()> { - let r = Arc::new(Mutex::new(router::Router::new(router::RouterConfig { - cidr: "1.2.3.0/24".to_owned(), - ..Default::default() - })?)); - let nw = Arc::new(net::Net::new(Some(net::NetConfig::default()))); - connect_net2router(&nw, &r).await?; - - //"InterfaceFilter should exclude the interface" - { - let a = Agent::new(AgentConfig { - net: Some(Arc::clone(&nw)), - interface_filter: Arc::new(Some(Box::new(|_: &str| -> bool { - //assert_eq!("eth0", interface_name); - false - }))), - ..Default::default() - }) - .await?; - - let local_ips = - local_interfaces(&nw, &a.interface_filter, &a.ip_filter, &[NetworkType::Udp4]).await; - assert!( - local_ips.is_empty(), - "InterfaceFilter should have excluded everything" - ); - - a.close().await?; - } - - //"InterfaceFilter should not exclude the interface" - { - let a = Agent::new(AgentConfig { - net: Some(Arc::clone(&nw)), - interface_filter: Arc::new(Some(Box::new(|interface_name: &str| -> bool { - "eth0" == interface_name - }))), - ..Default::default() - }) - .await?; - - let local_ips = - local_interfaces(&nw, &a.interface_filter, &a.ip_filter, &[NetworkType::Udp4]).await; - assert_eq!( - local_ips.len(), - 1, - "InterfaceFilter should not have excluded everything" - ); - - a.close().await?; - } - - Ok(()) -} - -#[tokio::test] -async fn test_vnet_gather_turn_connection_leak() -> Result<()> { - let turn_server_url = Url { - scheme: SchemeType::Turn, - host: VNET_STUN_SERVER_IP.to_owned(), - port: VNET_STUN_SERVER_PORT, - username: "user".to_owned(), - password: "pass".to_owned(), - proto: ProtoType::Udp, - }; - - // buildVNet with a Symmetric NATs for both LANs - let nat_type = nat::NatType { - mapping_behavior: nat::EndpointDependencyType::EndpointAddrPortDependent, - filtering_behavior: nat::EndpointDependencyType::EndpointAddrPortDependent, - ..Default::default() - }; - - let v = build_vnet(nat_type, nat_type).await?; - - let cfg0 = AgentConfig { - urls: vec![turn_server_url.clone()], - network_types: supported_network_types(), - multicast_dns_mode: MulticastDnsMode::Disabled, - nat_1to1_ips: vec![VNET_GLOBAL_IPA.to_owned()], - net: Some(Arc::clone(&v.net0)), - ..Default::default() - }; - - let a_agent = Agent::new(cfg0).await?; - - { - let agent_internal = Arc::clone(&a_agent.internal); - Agent::gather_candidates_relay( - vec![turn_server_url.clone()], - Arc::clone(&v.net0), - agent_internal, - ) - .await; - } - - // Assert relay conn leak on 
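The interface filter exercised above is a plain predicate over the interface name stored in the config: return false to drop every candidate gathered from that interface, true to keep it (a similarly shaped ip_filter is threaded through local_interfaces alongside it). A sketch mirroring the test:

async fn filtered_agent_example(nw: Arc<net::Net>) -> Result<()> {
    let a = Agent::new(AgentConfig {
        net: Some(nw),
        // Keep only "eth0"; any interface for which this returns false is skipped.
        interface_filter: Arc::new(Some(Box::new(|ifc_name: &str| -> bool {
            ifc_name == "eth0"
        }))),
        ..Default::default()
    })
    .await?;
    a.close().await?;
    Ok(())
}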
close. - a_agent.close().await?; - v.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_vnet_gather_muxed_udp() -> Result<()> { - let udp_socket = UdpSocket::bind("0.0.0.0:0").await?; - let udp_mux = UDPMuxDefault::new(UDPMuxParams::new(udp_socket)); - - let lan = Arc::new(Mutex::new(router::Router::new(router::RouterConfig { - cidr: "10.0.0.0/24".to_owned(), - nat_type: Some(nat::NatType { - mode: nat::NatMode::Nat1To1, - ..Default::default() - }), - ..Default::default() - })?)); - - let nw = Arc::new(net::Net::new(Some(net::NetConfig { - static_ips: vec!["10.0.0.1".to_owned()], - ..Default::default() - }))); - - connect_net2router(&nw, &lan).await?; - - let a = Agent::new(AgentConfig { - network_types: vec![NetworkType::Udp4], - nat_1to1_ips: vec!["1.2.3.4".to_owned()], - net: Some(nw), - udp_network: UDPNetwork::Muxed(udp_mux), - ..Default::default() - }) - .await?; - - let (done_tx, mut done_rx) = mpsc::channel::<()>(1); - let done_tx = Arc::new(Mutex::new(Some(done_tx))); - a.on_candidate(Box::new( - move |c: Option>| { - let done_tx_clone = Arc::clone(&done_tx); - Box::pin(async move { - if c.is_none() { - let mut tx = done_tx_clone.lock().await; - tx.take(); - } - }) - }, - )); - - a.gather_candidates()?; - - log::debug!("wait for gathering is done..."); - let _ = done_rx.recv().await; - log::debug!("gathering is done"); - - let candidates = a.get_local_candidates().await?; - assert_eq!(candidates.len(), 1, "There must be a single candidate"); - - let candi = &candidates[0]; - let laddr = candi.get_conn().unwrap().local_addr()?; - assert_eq!(candi.address(), "1.2.3.4"); - assert_eq!( - candi.port(), - laddr.port(), - "Unexpected candidate port: {}", - candi.port() - ); - - Ok(()) -} diff --git a/ice/src/agent/agent_internal.rs b/ice/src/agent/agent_internal.rs deleted file mode 100644 index dc943b5b4..000000000 --- a/ice/src/agent/agent_internal.rs +++ /dev/null @@ -1,1198 +0,0 @@ -use portable_atomic::{AtomicBool, AtomicU64}; - -use arc_swap::ArcSwapOption; -use util::sync::Mutex as SyncMutex; - -use super::agent_transport::*; -use super::*; -use crate::candidate::candidate_base::CandidateBaseConfig; -use crate::candidate::candidate_peer_reflexive::CandidatePeerReflexiveConfig; -use crate::util::*; - -pub type ChanCandidateTx = - Arc>>>>>; - -#[derive(Default)] -pub(crate) struct UfragPwd { - pub(crate) local_ufrag: String, - pub(crate) local_pwd: String, - pub(crate) remote_ufrag: String, - pub(crate) remote_pwd: String, -} - -pub struct AgentInternal { - // State owned by the taskLoop - pub(crate) on_connected_tx: Mutex>>, - pub(crate) on_connected_rx: Mutex>>, - - // State for closing - pub(crate) done_tx: Mutex>>, - // force candidate to be contacted immediately (instead of waiting for task ticker) - pub(crate) force_candidate_contact_tx: mpsc::Sender, - pub(crate) done_and_force_candidate_contact_rx: - Mutex, mpsc::Receiver)>>, - - pub(crate) chan_candidate_tx: ChanCandidateTx, - pub(crate) chan_candidate_pair_tx: Mutex>>, - pub(crate) chan_state_tx: Mutex>>, - - pub(crate) on_connection_state_change_hdlr: ArcSwapOption>, - pub(crate) on_selected_candidate_pair_change_hdlr: - ArcSwapOption>, - pub(crate) on_candidate_hdlr: ArcSwapOption>, - - pub(crate) tie_breaker: AtomicU64, - pub(crate) is_controlling: AtomicBool, - pub(crate) lite: AtomicBool, - - pub(crate) start_time: SyncMutex, - pub(crate) nominated_pair: Mutex>>, - - pub(crate) connection_state: AtomicU8, //ConnectionState, - - pub(crate) started_ch_tx: Mutex>>, - - pub(crate) ufrag_pwd: Mutex, - - 
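With a muxed UDP network every component shares one pre-bound socket, so gathering yields a single host candidate whose port is the mux socket's port, which is what the assertions in the muxed test above check. A condensed setup under the same assumptions:

async fn muxed_agent_example() -> Result<()> {
    // One socket for everything; the agent multiplexes its ICE traffic over it.
    let udp_socket = UdpSocket::bind("0.0.0.0:0").await?;
    let udp_mux = UDPMuxDefault::new(UDPMuxParams::new(udp_socket));

    let a = Agent::new(AgentConfig {
        network_types: vec![NetworkType::Udp4],
        udp_network: UDPNetwork::Muxed(udp_mux),
        ..Default::default()
    })
    .await?;
    a.close().await?;
    Ok(())
}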
pub(crate) local_candidates: Mutex>>>, - pub(crate) remote_candidates: - Mutex>>>, - - // LRU of outbound Binding request Transaction IDs - pub(crate) pending_binding_requests: Mutex>, - - pub(crate) agent_conn: Arc, - - // the following variables won't be changed after init_with_defaults() - pub(crate) insecure_skip_verify: bool, - pub(crate) max_binding_requests: u16, - pub(crate) host_acceptance_min_wait: Duration, - pub(crate) srflx_acceptance_min_wait: Duration, - pub(crate) prflx_acceptance_min_wait: Duration, - pub(crate) relay_acceptance_min_wait: Duration, - // How long connectivity checks can fail before the ICE Agent - // goes to disconnected - pub(crate) disconnected_timeout: Duration, - // How long connectivity checks can fail before the ICE Agent - // goes to failed - pub(crate) failed_timeout: Duration, - // How often should we send keepalive packets? - // 0 means never - pub(crate) keepalive_interval: Duration, - // How often should we run our internal taskLoop to check for state changes when connecting - pub(crate) check_interval: Duration, -} - -impl AgentInternal { - pub(super) fn new(config: &AgentConfig) -> (Self, ChanReceivers) { - let (chan_state_tx, chan_state_rx) = mpsc::channel(1); - let (chan_candidate_tx, chan_candidate_rx) = mpsc::channel(1); - let (chan_candidate_pair_tx, chan_candidate_pair_rx) = mpsc::channel(1); - let (on_connected_tx, on_connected_rx) = mpsc::channel(1); - let (done_tx, done_rx) = mpsc::channel(1); - let (force_candidate_contact_tx, force_candidate_contact_rx) = mpsc::channel(1); - let (started_ch_tx, _) = broadcast::channel(1); - - let ai = AgentInternal { - on_connected_tx: Mutex::new(Some(on_connected_tx)), - on_connected_rx: Mutex::new(Some(on_connected_rx)), - - done_tx: Mutex::new(Some(done_tx)), - force_candidate_contact_tx, - done_and_force_candidate_contact_rx: Mutex::new(Some(( - done_rx, - force_candidate_contact_rx, - ))), - - chan_candidate_tx: Arc::new(Mutex::new(Some(chan_candidate_tx))), - chan_candidate_pair_tx: Mutex::new(Some(chan_candidate_pair_tx)), - chan_state_tx: Mutex::new(Some(chan_state_tx)), - - on_connection_state_change_hdlr: ArcSwapOption::empty(), - on_selected_candidate_pair_change_hdlr: ArcSwapOption::empty(), - on_candidate_hdlr: ArcSwapOption::empty(), - - tie_breaker: AtomicU64::new(rand::random::()), - is_controlling: AtomicBool::new(config.is_controlling), - lite: AtomicBool::new(config.lite), - - start_time: SyncMutex::new(Instant::now()), - nominated_pair: Mutex::new(None), - - connection_state: AtomicU8::new(ConnectionState::New as u8), - - insecure_skip_verify: config.insecure_skip_verify, - - started_ch_tx: Mutex::new(Some(started_ch_tx)), - - //won't change after init_with_defaults() - max_binding_requests: 0, - host_acceptance_min_wait: Duration::from_secs(0), - srflx_acceptance_min_wait: Duration::from_secs(0), - prflx_acceptance_min_wait: Duration::from_secs(0), - relay_acceptance_min_wait: Duration::from_secs(0), - - // How long connectivity checks can fail before the ICE Agent - // goes to disconnected - disconnected_timeout: Duration::from_secs(0), - - // How long connectivity checks can fail before the ICE Agent - // goes to failed - failed_timeout: Duration::from_secs(0), - - // How often should we send keepalive packets? 
- // 0 means never - keepalive_interval: Duration::from_secs(0), - - // How often should we run our internal taskLoop to check for state changes when connecting - check_interval: Duration::from_secs(0), - - ufrag_pwd: Mutex::new(UfragPwd::default()), - - local_candidates: Mutex::new(HashMap::new()), - remote_candidates: Mutex::new(HashMap::new()), - - // LRU of outbound Binding request Transaction IDs - pending_binding_requests: Mutex::new(vec![]), - - // AgentConn - agent_conn: Arc::new(AgentConn::new()), - }; - - let chan_receivers = ChanReceivers { - chan_state_rx, - chan_candidate_rx, - chan_candidate_pair_rx, - }; - (ai, chan_receivers) - } - pub(crate) async fn start_connectivity_checks( - self: &Arc, - is_controlling: bool, - remote_ufrag: String, - remote_pwd: String, - ) -> Result<()> { - { - let started_ch_tx = self.started_ch_tx.lock().await; - if started_ch_tx.is_none() { - return Err(Error::ErrMultipleStart); - } - } - - log::debug!( - "Started agent: isControlling? {}, remoteUfrag: {}, remotePwd: {}", - is_controlling, - remote_ufrag, - remote_pwd - ); - self.set_remote_credentials(remote_ufrag, remote_pwd) - .await?; - self.is_controlling.store(is_controlling, Ordering::SeqCst); - self.start().await; - { - let mut started_ch_tx = self.started_ch_tx.lock().await; - started_ch_tx.take(); - } - - self.update_connection_state(ConnectionState::Checking) - .await; - - self.request_connectivity_check(); - - self.connectivity_checks().await; - - Ok(()) - } - - async fn contact( - &self, - last_connection_state: &mut ConnectionState, - checking_duration: &mut Instant, - ) { - if self.connection_state.load(Ordering::SeqCst) == ConnectionState::Failed as u8 { - // The connection is currently failed so don't send any checks - // In the future it may be restarted though - *last_connection_state = self.connection_state.load(Ordering::SeqCst).into(); - return; - } - if self.connection_state.load(Ordering::SeqCst) == ConnectionState::Checking as u8 { - // We have just entered checking for the first time so update our checking timer - if *last_connection_state as u8 != self.connection_state.load(Ordering::SeqCst) { - *checking_duration = Instant::now(); - } - - // We have been in checking longer then Disconnect+Failed timeout, set the connection to Failed - if Instant::now() - .checked_duration_since(*checking_duration) - .unwrap_or_else(|| Duration::from_secs(0)) - > self.disconnected_timeout + self.failed_timeout - { - self.update_connection_state(ConnectionState::Failed).await; - *last_connection_state = self.connection_state.load(Ordering::SeqCst).into(); - return; - } - } - - self.contact_candidates().await; - - *last_connection_state = self.connection_state.load(Ordering::SeqCst).into(); - } - - async fn connectivity_checks(self: &Arc) { - const ZERO_DURATION: Duration = Duration::from_secs(0); - let mut last_connection_state = ConnectionState::Unspecified; - let mut checking_duration = Instant::now(); - let (check_interval, keepalive_interval, disconnected_timeout, failed_timeout) = ( - self.check_interval, - self.keepalive_interval, - self.disconnected_timeout, - self.failed_timeout, - ); - - let done_and_force_candidate_contact_rx = { - let mut done_and_force_candidate_contact_rx = - self.done_and_force_candidate_contact_rx.lock().await; - done_and_force_candidate_contact_rx.take() - }; - - if let Some((mut done_rx, mut force_candidate_contact_rx)) = - done_and_force_candidate_contact_rx - { - let ai = Arc::clone(self); - tokio::spawn(async move { - loop { - let mut interval = 
DEFAULT_CHECK_INTERVAL; - - let mut update_interval = |x: Duration| { - if x != ZERO_DURATION && (interval == ZERO_DURATION || interval > x) { - interval = x; - } - }; - - match last_connection_state { - ConnectionState::New | ConnectionState::Checking => { - // While connecting, check candidates more frequently - update_interval(check_interval); - } - ConnectionState::Connected | ConnectionState::Disconnected => { - update_interval(keepalive_interval); - } - _ => {} - }; - // Ensure we run our task loop as quickly as the minimum of our various configured timeouts - update_interval(disconnected_timeout); - update_interval(failed_timeout); - - let t = tokio::time::sleep(interval); - tokio::pin!(t); - - tokio::select! { - _ = t.as_mut() => { - ai.contact(&mut last_connection_state, &mut checking_duration).await; - }, - _ = force_candidate_contact_rx.recv() => { - ai.contact(&mut last_connection_state, &mut checking_duration).await; - }, - _ = done_rx.recv() => { - return; - } - } - } - }); - } - } - - pub(crate) async fn update_connection_state(&self, new_state: ConnectionState) { - if self.connection_state.load(Ordering::SeqCst) != new_state as u8 { - // Connection has gone to failed, release all gathered candidates - if new_state == ConnectionState::Failed { - self.delete_all_candidates().await; - } - - log::info!( - "[{}]: Setting new connection state: {}", - self.get_name(), - new_state - ); - self.connection_state - .store(new_state as u8, Ordering::SeqCst); - - // Call handler after finishing current task since we may be holding the agent lock - // and the handler may also require it - { - let chan_state_tx = self.chan_state_tx.lock().await; - if let Some(tx) = &*chan_state_tx { - let _ = tx.send(new_state).await; - } - } - } - } - - pub(crate) async fn set_selected_pair(&self, p: Option>) { - log::trace!( - "[{}]: Set selected candidate pair: {:?}", - self.get_name(), - p - ); - - if let Some(p) = p { - p.nominated.store(true, Ordering::SeqCst); - self.agent_conn.selected_pair.store(Some(p)); - - self.update_connection_state(ConnectionState::Connected) - .await; - - // Notify when the selected pair changes - { - let chan_candidate_pair_tx = self.chan_candidate_pair_tx.lock().await; - if let Some(tx) = &*chan_candidate_pair_tx { - let _ = tx.send(()).await; - } - } - - // Signal connected - { - let mut on_connected_tx = self.on_connected_tx.lock().await; - on_connected_tx.take(); - } - } else { - self.agent_conn.selected_pair.store(None); - } - } - - pub(crate) async fn ping_all_candidates(&self) { - log::trace!("[{}]: pinging all candidates", self.get_name(),); - - let mut pairs: Vec<( - Arc, - Arc, - )> = vec![]; - - { - let mut checklist = self.agent_conn.checklist.lock().await; - if checklist.is_empty() { - log::warn!( - "[{}]: pingAllCandidates called with no candidate pairs. 
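The task loop above recomputes its sleep interval on every pass: start from DEFAULT_CHECK_INTERVAL and shrink it to the smallest non-zero timer that applies in the current state (check_interval while new/checking, keepalive_interval while connected/disconnected, and always the disconnected and failed timeouts). The selection rule on its own, as a plain helper:

use std::time::Duration;

/// Start from `default`; any non-zero timer that is smaller (or any non-zero
/// timer at all if the current value is zero) replaces it.
fn smallest_non_zero(default: Duration, timers: &[Duration]) -> Duration {
    let mut interval = default;
    for &t in timers {
        if t != Duration::ZERO && (interval == Duration::ZERO || interval > t) {
            interval = t;
        }
    }
    interval
}

While checking, for example, the sleep becomes smallest_non_zero(DEFAULT_CHECK_INTERVAL, &[check_interval, disconnected_timeout, failed_timeout]).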
Connection is not possible yet.", - self.get_name(), - ); - } - for p in &mut *checklist { - let p_state = p.state.load(Ordering::SeqCst); - if p_state == CandidatePairState::Waiting as u8 { - p.state - .store(CandidatePairState::InProgress as u8, Ordering::SeqCst); - } else if p_state != CandidatePairState::InProgress as u8 { - continue; - } - - if p.binding_request_count.load(Ordering::SeqCst) > self.max_binding_requests { - log::trace!( - "[{}]: max requests reached for pair {}, marking it as failed", - self.get_name(), - p - ); - p.state - .store(CandidatePairState::Failed as u8, Ordering::SeqCst); - } else { - p.binding_request_count.fetch_add(1, Ordering::SeqCst); - let local = p.local.clone(); - let remote = p.remote.clone(); - pairs.push((local, remote)); - } - } - } - - for (local, remote) in pairs { - self.ping_candidate(&local, &remote).await; - } - } - - pub(crate) async fn add_pair( - &self, - local: Arc, - remote: Arc, - ) { - let p = Arc::new(CandidatePair::new( - local, - remote, - self.is_controlling.load(Ordering::SeqCst), - )); - let mut checklist = self.agent_conn.checklist.lock().await; - checklist.push(p); - } - - pub(crate) async fn find_pair( - &self, - local: &Arc, - remote: &Arc, - ) -> Option> { - let checklist = self.agent_conn.checklist.lock().await; - for p in &*checklist { - if p.local.equal(&**local) && p.remote.equal(&**remote) { - return Some(p.clone()); - } - } - None - } - - /// Checks if the selected pair is (still) valid. - /// Note: the caller should hold the agent lock. - pub(crate) async fn validate_selected_pair(&self) -> bool { - let (valid, disconnected_time) = { - let selected_pair = self.agent_conn.selected_pair.load(); - (*selected_pair).as_ref().map_or_else( - || (false, Duration::from_secs(0)), - |selected_pair| { - let disconnected_time = SystemTime::now() - .duration_since(selected_pair.remote.last_received()) - .unwrap_or_else(|_| Duration::from_secs(0)); - (true, disconnected_time) - }, - ) - }; - - if valid { - // Only allow transitions to failed if a.failedTimeout is non-zero - let mut total_time_to_failure = self.failed_timeout; - if total_time_to_failure != Duration::from_secs(0) { - total_time_to_failure += self.disconnected_timeout; - } - - if total_time_to_failure != Duration::from_secs(0) - && disconnected_time > total_time_to_failure - { - self.update_connection_state(ConnectionState::Failed).await; - } else if self.disconnected_timeout != Duration::from_secs(0) - && disconnected_time > self.disconnected_timeout - { - self.update_connection_state(ConnectionState::Disconnected) - .await; - } else { - self.update_connection_state(ConnectionState::Connected) - .await; - } - } - - valid - } - - /// Sends STUN Binding Indications to the selected pair. - /// if no packet has been sent on that pair in the last keepaliveInterval. - /// Note: the caller should hold the agent lock. 
- pub(crate) async fn check_keepalive(&self) { - let (local, remote) = { - let selected_pair = self.agent_conn.selected_pair.load(); - (*selected_pair) - .as_ref() - .map_or((None, None), |selected_pair| { - ( - Some(selected_pair.local.clone()), - Some(selected_pair.remote.clone()), - ) - }) - }; - - if let (Some(local), Some(remote)) = (local, remote) { - let last_sent = SystemTime::now() - .duration_since(local.last_sent()) - .unwrap_or_else(|_| Duration::from_secs(0)); - - let last_received = SystemTime::now() - .duration_since(remote.last_received()) - .unwrap_or_else(|_| Duration::from_secs(0)); - - if (self.keepalive_interval != Duration::from_secs(0)) - && ((last_sent > self.keepalive_interval) - || (last_received > self.keepalive_interval)) - { - // we use binding request instead of indication to support refresh consent schemas - // see https://tools.ietf.org/html/rfc7675 - self.ping_candidate(&local, &remote).await; - } - } - } - - fn request_connectivity_check(&self) { - let _ = self.force_candidate_contact_tx.try_send(true); - } - - /// Assumes you are holding the lock (must be execute using a.run). - pub(crate) async fn add_remote_candidate(&self, c: &Arc) { - let network_type = c.network_type(); - - { - let mut remote_candidates = self.remote_candidates.lock().await; - if let Some(cands) = remote_candidates.get(&network_type) { - for cand in cands { - if cand.equal(&**c) { - return; - } - } - } - - if let Some(cands) = remote_candidates.get_mut(&network_type) { - cands.push(c.clone()); - } else { - remote_candidates.insert(network_type, vec![c.clone()]); - } - } - - let mut local_cands = vec![]; - { - let local_candidates = self.local_candidates.lock().await; - if let Some(cands) = local_candidates.get(&network_type) { - local_cands.clone_from(cands); - } - } - - for cand in local_cands { - self.add_pair(cand, c.clone()).await; - } - - self.request_connectivity_check(); - } - - pub(crate) async fn add_candidate( - self: &Arc, - c: &Arc, - ) -> Result<()> { - let initialized_ch = { - let started_ch_tx = self.started_ch_tx.lock().await; - (*started_ch_tx).as_ref().map(|tx| tx.subscribe()) - }; - - self.start_candidate(c, initialized_ch).await; - - let network_type = c.network_type(); - { - let mut local_candidates = self.local_candidates.lock().await; - if let Some(cands) = local_candidates.get(&network_type) { - for cand in cands { - if cand.equal(&**c) { - if let Err(err) = c.close().await { - log::warn!( - "[{}]: Failed to close duplicate candidate: {}", - self.get_name(), - err - ); - } - //TODO: why return? 
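validate_selected_pair and check_keepalive above boil down to two time comparisons against the configured intervals. Pulled out as standalone helpers (illustrative only; ConnectionState is the crate enum used throughout this file):

use std::time::{Duration, SystemTime};

/// Connection state implied by how long the remote side of the selected pair
/// has been silent. A zero failed_timeout disables the Failed transition.
fn state_for_idle(
    idle: Duration,
    disconnected_timeout: Duration,
    failed_timeout: Duration,
) -> ConnectionState {
    let mut time_to_failure = failed_timeout;
    if time_to_failure != Duration::ZERO {
        time_to_failure += disconnected_timeout;
    }
    if time_to_failure != Duration::ZERO && idle > time_to_failure {
        ConnectionState::Failed
    } else if disconnected_timeout != Duration::ZERO && idle > disconnected_timeout {
        ConnectionState::Disconnected
    } else {
        ConnectionState::Connected
    }
}

/// Ping the selected pair when either direction has been idle longer than the
/// keepalive interval; a zero interval disables keepalives entirely.
fn needs_keepalive(interval: Duration, last_sent: SystemTime, last_received: SystemTime) -> bool {
    let idle = |t| SystemTime::now().duration_since(t).unwrap_or(Duration::ZERO);
    interval != Duration::ZERO && (idle(last_sent) > interval || idle(last_received) > interval)
}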
- return Ok(()); - } - } - } - - if let Some(cands) = local_candidates.get_mut(&network_type) { - cands.push(c.clone()); - } else { - local_candidates.insert(network_type, vec![c.clone()]); - } - } - - let mut remote_cands = vec![]; - { - let remote_candidates = self.remote_candidates.lock().await; - if let Some(cands) = remote_candidates.get(&network_type) { - remote_cands.clone_from(cands); - } - } - - for cand in remote_cands { - self.add_pair(c.clone(), cand).await; - } - - self.request_connectivity_check(); - { - let chan_candidate_tx = self.chan_candidate_tx.lock().await; - if let Some(tx) = &*chan_candidate_tx { - let _ = tx.send(Some(c.clone())).await; - } - } - - Ok(()) - } - - pub(crate) async fn close(&self) -> Result<()> { - { - let mut done_tx = self.done_tx.lock().await; - if done_tx.is_none() { - return Err(Error::ErrClosed); - } - done_tx.take(); - }; - self.delete_all_candidates().await; - { - let mut started_ch_tx = self.started_ch_tx.lock().await; - started_ch_tx.take(); - } - - self.agent_conn.buffer.close().await; - - self.update_connection_state(ConnectionState::Closed).await; - - { - let mut chan_candidate_tx = self.chan_candidate_tx.lock().await; - chan_candidate_tx.take(); - } - { - let mut chan_candidate_pair_tx = self.chan_candidate_pair_tx.lock().await; - chan_candidate_pair_tx.take(); - } - { - let mut chan_state_tx = self.chan_state_tx.lock().await; - chan_state_tx.take(); - } - - self.agent_conn.done.store(true, Ordering::SeqCst); - - Ok(()) - } - - /// Remove all candidates. - /// This closes any listening sockets and removes both the local and remote candidate lists. - /// - /// This is used for restarts, failures and on close. - pub(crate) async fn delete_all_candidates(&self) { - { - let mut local_candidates = self.local_candidates.lock().await; - for cs in local_candidates.values_mut() { - for c in cs { - if let Err(err) = c.close().await { - log::warn!( - "[{}]: Failed to close candidate {}: {}", - self.get_name(), - c, - err - ); - } - } - } - local_candidates.clear(); - } - - { - let mut remote_candidates = self.remote_candidates.lock().await; - for cs in remote_candidates.values_mut() { - for c in cs { - if let Err(err) = c.close().await { - log::warn!( - "[{}]: Failed to close candidate {}: {}", - self.get_name(), - c, - err - ); - } - } - } - remote_candidates.clear(); - } - } - - pub(crate) async fn find_remote_candidate( - &self, - network_type: NetworkType, - addr: SocketAddr, - ) -> Option> { - let (ip, port) = (addr.ip(), addr.port()); - - let remote_candidates = self.remote_candidates.lock().await; - if let Some(cands) = remote_candidates.get(&network_type) { - for c in cands { - if c.address() == ip.to_string() && c.port() == port { - return Some(c.clone()); - } - } - } - None - } - - pub(crate) async fn send_binding_request( - &self, - m: &Message, - local: &Arc, - remote: &Arc, - ) { - log::trace!( - "[{}]: ping STUN from {} to {}", - self.get_name(), - local, - remote - ); - - self.invalidate_pending_binding_requests(Instant::now()) - .await; - { - let mut pending_binding_requests = self.pending_binding_requests.lock().await; - pending_binding_requests.push(BindingRequest { - timestamp: Instant::now(), - transaction_id: m.transaction_id, - destination: remote.addr(), - is_use_candidate: m.contains(ATTR_USE_CANDIDATE), - }); - } - - self.send_stun(m, local, remote).await; - } - - pub(crate) async fn send_binding_success( - &self, - m: &Message, - local: &Arc, - remote: &Arc, - ) { - let addr = remote.addr(); - let (ip, port) = (addr.ip(), 
addr.port()); - let local_pwd = { - let ufrag_pwd = self.ufrag_pwd.lock().await; - ufrag_pwd.local_pwd.clone() - }; - - let (out, result) = { - let mut out = Message::new(); - let result = out.build(&[ - Box::new(m.clone()), - Box::new(BINDING_SUCCESS), - Box::new(XorMappedAddress { ip, port }), - Box::new(MessageIntegrity::new_short_term_integrity(local_pwd)), - Box::new(FINGERPRINT), - ]); - (out, result) - }; - - if let Err(err) = result { - log::warn!( - "[{}]: Failed to handle inbound ICE from: {} to: {} error: {}", - self.get_name(), - local, - remote, - err - ); - } else { - self.send_stun(&out, local, remote).await; - } - } - - /// Removes pending binding requests that are over `maxBindingRequestTimeout` old Let HTO be the - /// transaction timeout, which SHOULD be 2*RTT if RTT is known or 500 ms otherwise. - /// - /// reference: (IETF ref-8445)[https://tools.ietf.org/html/rfc8445#appendix-B.1]. - pub(crate) async fn invalidate_pending_binding_requests(&self, filter_time: Instant) { - let mut pending_binding_requests = self.pending_binding_requests.lock().await; - let initial_size = pending_binding_requests.len(); - - let mut temp = vec![]; - for binding_request in pending_binding_requests.drain(..) { - if filter_time - .checked_duration_since(binding_request.timestamp) - .map(|duration| duration < MAX_BINDING_REQUEST_TIMEOUT) - .unwrap_or(true) - { - temp.push(binding_request); - } - } - - *pending_binding_requests = temp; - let bind_requests_removed = initial_size - pending_binding_requests.len(); - if bind_requests_removed > 0 { - log::trace!( - "[{}]: Discarded {} binding requests because they expired", - self.get_name(), - bind_requests_removed - ); - } - } - - /// Assert that the passed `TransactionID` is in our `pendingBindingRequests` and returns the - /// destination, If the bindingRequest was valid remove it from our pending cache. - pub(crate) async fn handle_inbound_binding_success( - &self, - id: TransactionId, - ) -> Option { - self.invalidate_pending_binding_requests(Instant::now()) - .await; - - let mut pending_binding_requests = self.pending_binding_requests.lock().await; - for i in 0..pending_binding_requests.len() { - if pending_binding_requests[i].transaction_id == id { - let valid_binding_request = pending_binding_requests.remove(i); - return Some(valid_binding_request); - } - } - None - } - - /// Processes STUN traffic from a remote candidate. 
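The expiry pass above keeps only pending binding requests younger than MAX_BINDING_REQUEST_TIMEOUT (the doc comment cites RFC 8445 B.1's HTO guidance: 2*RTT if known, otherwise 500 ms) and logs how many it dropped. The same keep-rule over just the timestamps, using Vec::retain and std's Instant rather than the drain-and-rebuild with tokio's Instant above (illustrative):

use std::time::{Duration, Instant};

/// Keep a timestamp if it is younger than `max_age` (or, defensively, if it
/// lies in the future); return how many entries were discarded.
fn expire_old(timestamps: &mut Vec<Instant>, now: Instant, max_age: Duration) -> usize {
    let before = timestamps.len();
    timestamps.retain(|t| {
        now.checked_duration_since(*t)
            .map(|age| age < max_age)
            .unwrap_or(true)
    });
    before - timestamps.len()
}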
- pub(crate) async fn handle_inbound( - &self, - m: &mut Message, - local: &Arc, - remote: SocketAddr, - ) { - if m.typ.method != METHOD_BINDING - || !(m.typ.class == CLASS_SUCCESS_RESPONSE - || m.typ.class == CLASS_REQUEST - || m.typ.class == CLASS_INDICATION) - { - log::trace!( - "[{}]: unhandled STUN from {} to {} class({}) method({})", - self.get_name(), - remote, - local, - m.typ.class, - m.typ.method - ); - return; - } - - if self.is_controlling.load(Ordering::SeqCst) { - if m.contains(ATTR_ICE_CONTROLLING) { - log::debug!( - "[{}]: inbound isControlling && a.isControlling == true", - self.get_name(), - ); - return; - } else if m.contains(ATTR_USE_CANDIDATE) { - log::debug!( - "[{}]: useCandidate && a.isControlling == true", - self.get_name(), - ); - return; - } - } else if m.contains(ATTR_ICE_CONTROLLED) { - log::debug!( - "[{}]: inbound isControlled && a.isControlling == false", - self.get_name(), - ); - return; - } - - let mut remote_candidate = self - .find_remote_candidate(local.network_type(), remote) - .await; - if m.typ.class == CLASS_SUCCESS_RESPONSE { - { - let ufrag_pwd = self.ufrag_pwd.lock().await; - if let Err(err) = - assert_inbound_message_integrity(m, ufrag_pwd.remote_pwd.as_bytes()) - { - log::warn!( - "[{}]: discard message from ({}), {}", - self.get_name(), - remote, - err - ); - return; - } - } - - if let Some(rc) = &remote_candidate { - self.handle_success_response(m, local, rc, remote).await; - } else { - log::warn!( - "[{}]: discard success message from ({}), no such remote", - self.get_name(), - remote - ); - return; - } - } else if m.typ.class == CLASS_REQUEST { - { - let ufrag_pwd = self.ufrag_pwd.lock().await; - let username = - ufrag_pwd.local_ufrag.clone() + ":" + ufrag_pwd.remote_ufrag.as_str(); - if let Err(err) = assert_inbound_username(m, &username) { - log::warn!( - "[{}]: discard message from ({}), {}", - self.get_name(), - remote, - err - ); - return; - } else if let Err(err) = - assert_inbound_message_integrity(m, ufrag_pwd.local_pwd.as_bytes()) - { - log::warn!( - "[{}]: discard message from ({}), {}", - self.get_name(), - remote, - err - ); - return; - } - } - - if remote_candidate.is_none() { - let (ip, port, network_type) = (remote.ip(), remote.port(), NetworkType::Udp4); - - let prflx_candidate_config = CandidatePeerReflexiveConfig { - base_config: CandidateBaseConfig { - network: network_type.to_string(), - address: ip.to_string(), - port, - component: local.component(), - ..CandidateBaseConfig::default() - }, - rel_addr: "".to_owned(), - rel_port: 0, - }; - - match prflx_candidate_config.new_candidate_peer_reflexive() { - Ok(prflx_candidate) => remote_candidate = Some(Arc::new(prflx_candidate)), - Err(err) => { - log::error!( - "[{}]: Failed to create new remote prflx candidate ({})", - self.get_name(), - err - ); - return; - } - }; - - log::debug!( - "[{}]: adding a new peer-reflexive candidate: {} ", - self.get_name(), - remote - ); - if let Some(rc) = &remote_candidate { - self.add_remote_candidate(rc).await; - } - } - - log::trace!( - "[{}]: inbound STUN (Request) from {} to {}", - self.get_name(), - remote, - local - ); - - if let Some(rc) = &remote_candidate { - self.handle_binding_request(m, local, rc).await; - } - } - - if let Some(rc) = remote_candidate { - rc.seen(false); - } - } - - /// Processes non STUN traffic from a remote candidate, and returns true if it is an actual - /// remote candidate. 
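One easy-to-miss asymmetry in handle_inbound above: success responses (replies to our own checks) are verified against the remote password, while inbound requests are verified against the local password plus a username of the form "local_ufrag:remote_ufrag"; indications carry no integrity check on this path. Condensed into a helper (illustrative, not crate API; the stun items are the ones this module already imports):

fn integrity_key_for<'a>(
    class: MessageClass,
    local_pwd: &'a str,
    remote_pwd: &'a str,
) -> Option<&'a str> {
    if class == CLASS_SUCCESS_RESPONSE {
        Some(remote_pwd) // responses to checks we sent
    } else if class == CLASS_REQUEST {
        Some(local_pwd) // checks the peer sends to us
    } else {
        None // indications: no integrity check here
    }
}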
- pub(crate) async fn validate_non_stun_traffic( - &self, - local: &Arc, - remote: SocketAddr, - ) -> bool { - self.find_remote_candidate(local.network_type(), remote) - .await - .map_or(false, |remote_candidate| { - remote_candidate.seen(false); - true - }) - } - - /// Sets the credentials of the remote agent. - pub(crate) async fn set_remote_credentials( - &self, - remote_ufrag: String, - remote_pwd: String, - ) -> Result<()> { - if remote_ufrag.is_empty() { - return Err(Error::ErrRemoteUfragEmpty); - } else if remote_pwd.is_empty() { - return Err(Error::ErrRemotePwdEmpty); - } - - let mut ufrag_pwd = self.ufrag_pwd.lock().await; - ufrag_pwd.remote_ufrag = remote_ufrag; - ufrag_pwd.remote_pwd = remote_pwd; - Ok(()) - } - - pub(crate) async fn send_stun( - &self, - msg: &Message, - local: &Arc, - remote: &Arc, - ) { - if let Err(err) = local.write_to(&msg.raw, &**remote).await { - log::trace!( - "[{}]: failed to send STUN message: {}", - self.get_name(), - err - ); - } - } - - /// Runs the candidate using the provided connection. - async fn start_candidate( - self: &Arc, - candidate: &Arc, - initialized_ch: Option>, - ) { - let (closed_ch_tx, closed_ch_rx) = broadcast::channel(1); - { - let closed_ch = candidate.get_closed_ch(); - let mut closed = closed_ch.lock().await; - *closed = Some(closed_ch_tx); - } - - let cand = Arc::clone(candidate); - if let Some(conn) = candidate.get_conn() { - let conn = Arc::clone(conn); - let addr = candidate.addr(); - let ai = Arc::clone(self); - tokio::spawn(async move { - let _ = ai - .recv_loop(cand, closed_ch_rx, initialized_ch, conn, addr) - .await; - }); - } else { - log::error!("[{}]: Can't start due to conn is_none", self.get_name(),); - } - } - - pub(super) fn start_on_connection_state_change_routine( - self: &Arc, - mut chan_state_rx: mpsc::Receiver, - mut chan_candidate_rx: mpsc::Receiver>>, - mut chan_candidate_pair_rx: mpsc::Receiver<()>, - ) { - let ai = Arc::clone(self); - tokio::spawn(async move { - // CandidatePair and ConnectionState are usually changed at once. - // Blocking one by the other one causes deadlock. - while chan_candidate_pair_rx.recv().await.is_some() { - if let (Some(cb), Some(p)) = ( - &*ai.on_selected_candidate_pair_change_hdlr.load(), - &*ai.agent_conn.selected_pair.load(), - ) { - let mut f = cb.lock().await; - f(&p.local, &p.remote).await; - } - } - }); - - let ai = Arc::clone(self); - tokio::spawn(async move { - loop { - tokio::select! { - opt_state = chan_state_rx.recv() => { - if let Some(s) = opt_state { - if let Some(handler) = &*ai.on_connection_state_change_hdlr.load() { - let mut f = handler.lock().await; - f(s).await; - } - } else { - while let Some(c) = chan_candidate_rx.recv().await { - if let Some(handler) = &*ai.on_candidate_hdlr.load() { - let mut f = handler.lock().await; - f(c).await; - } - } - break; - } - }, - opt_cand = chan_candidate_rx.recv() => { - if let Some(c) = opt_cand { - if let Some(handler) = &*ai.on_candidate_hdlr.load() { - let mut f = handler.lock().await; - f(c).await; - } - } else { - while let Some(s) = chan_state_rx.recv().await { - if let Some(handler) = &*ai.on_connection_state_change_hdlr.load() { - let mut f = handler.lock().await; - f(s).await; - } - } - break; - } - } - } - } - }); - } - - async fn recv_loop( - self: &Arc, - candidate: Arc, - mut closed_ch_rx: broadcast::Receiver<()>, - initialized_ch: Option>, - conn: Arc, - addr: SocketAddr, - ) -> Result<()> { - if let Some(mut initialized_ch) = initialized_ch { - tokio::select! 
{ - _ = initialized_ch.recv() => {} - _ = closed_ch_rx.recv() => return Err(Error::ErrClosed), - } - } - - let mut buffer = vec![0_u8; RECEIVE_MTU]; - let mut n; - let mut src_addr; - loop { - tokio::select! { - result = conn.recv_from(&mut buffer) => { - match result { - Ok((num, src)) => { - n = num; - src_addr = src; - } - Err(err) => return Err(Error::Other(err.to_string())), - } - }, - _ = closed_ch_rx.recv() => return Err(Error::ErrClosed), - } - - self.handle_inbound_candidate_msg(&candidate, &buffer[..n], src_addr, addr) - .await; - } - } - - async fn handle_inbound_candidate_msg( - self: &Arc, - c: &Arc, - buf: &[u8], - src_addr: SocketAddr, - addr: SocketAddr, - ) { - if stun::message::is_message(buf) { - let mut m = Message { - raw: vec![], - ..Message::default() - }; - // Explicitly copy raw buffer so Message can own the memory. - m.raw.extend_from_slice(buf); - - if let Err(err) = m.decode() { - log::warn!( - "[{}]: Failed to handle decode ICE from {} to {}: {}", - self.get_name(), - addr, - src_addr, - err - ); - } else { - self.handle_inbound(&mut m, c, src_addr).await; - } - } else if !self.validate_non_stun_traffic(c, src_addr).await { - log::warn!( - "[{}]: Discarded message, not a valid remote candidate", - self.get_name(), - //c.addr().await //from {} - ); - } else if let Err(err) = self.agent_conn.buffer.write(buf).await { - // NOTE This will return packetio.ErrFull if the buffer ever manages to fill up. - log::warn!("[{}]: failed to write packet: {}", self.get_name(), err); - } - } - - pub(crate) fn get_name(&self) -> &str { - if self.is_controlling.load(Ordering::SeqCst) { - "controlling" - } else { - "controlled" - } - } -} diff --git a/ice/src/agent/agent_selector.rs b/ice/src/agent/agent_selector.rs deleted file mode 100644 index b7e05fc40..000000000 --- a/ice/src/agent/agent_selector.rs +++ /dev/null @@ -1,545 +0,0 @@ -use std::net::SocketAddr; -use std::sync::atomic::Ordering; -use std::sync::Arc; - -use async_trait::async_trait; -use stun::agent::*; -use stun::attributes::*; -use stun::fingerprint::*; -use stun::integrity::*; -use stun::message::*; -use stun::textattrs::*; -use tokio::time::{Duration, Instant}; - -use crate::agent::agent_internal::*; -use crate::candidate::*; -use crate::control::*; -use crate::priority::*; -use crate::use_candidate::*; - -#[async_trait] -trait ControllingSelector { - async fn start(&self); - async fn contact_candidates(&self); - async fn ping_candidate( - &self, - local: &Arc, - remote: &Arc, - ); - async fn handle_success_response( - &self, - m: &Message, - local: &Arc, - remote: &Arc, - remote_addr: SocketAddr, - ); - async fn handle_binding_request( - &self, - m: &Message, - local: &Arc, - remote: &Arc, - ); -} - -#[async_trait] -trait ControlledSelector { - async fn start(&self); - async fn contact_candidates(&self); - async fn ping_candidate( - &self, - local: &Arc, - remote: &Arc, - ); - async fn handle_success_response( - &self, - m: &Message, - local: &Arc, - remote: &Arc, - remote_addr: SocketAddr, - ); - async fn handle_binding_request( - &self, - m: &Message, - local: &Arc, - remote: &Arc, - ); -} - -impl AgentInternal { - fn is_nominatable(&self, c: &Arc) -> bool { - let start_time = *self.start_time.lock(); - match c.candidate_type() { - CandidateType::Host => { - Instant::now() - .checked_duration_since(start_time) - .unwrap_or_else(|| Duration::from_secs(0)) - .as_nanos() - > self.host_acceptance_min_wait.as_nanos() - } - CandidateType::ServerReflexive => { - Instant::now() - .checked_duration_since(start_time) 
- .unwrap_or_else(|| Duration::from_secs(0)) - .as_nanos() - > self.srflx_acceptance_min_wait.as_nanos() - } - CandidateType::PeerReflexive => { - Instant::now() - .checked_duration_since(start_time) - .unwrap_or_else(|| Duration::from_secs(0)) - .as_nanos() - > self.prflx_acceptance_min_wait.as_nanos() - } - CandidateType::Relay => { - Instant::now() - .checked_duration_since(start_time) - .unwrap_or_else(|| Duration::from_secs(0)) - .as_nanos() - > self.relay_acceptance_min_wait.as_nanos() - } - CandidateType::Unspecified => { - log::error!( - "is_nominatable invalid candidate type {}", - c.candidate_type() - ); - false - } - } - } - - async fn nominate_pair(&self) { - let result = { - let nominated_pair = self.nominated_pair.lock().await; - if let Some(pair) = &*nominated_pair { - // The controlling agent MUST include the USE-CANDIDATE attribute in - // order to nominate a candidate pair (Section 8.1.1). The controlled - // agent MUST NOT include the USE-CANDIDATE attribute in a Binding - // request. - - let (msg, result) = { - let ufrag_pwd = self.ufrag_pwd.lock().await; - let username = - ufrag_pwd.remote_ufrag.clone() + ":" + ufrag_pwd.local_ufrag.as_str(); - let mut msg = Message::new(); - let result = msg.build(&[ - Box::new(BINDING_REQUEST), - Box::new(TransactionId::new()), - Box::new(Username::new(ATTR_USERNAME, username)), - Box::::default(), - Box::new(AttrControlling(self.tie_breaker.load(Ordering::SeqCst))), - Box::new(PriorityAttr(pair.local.priority())), - Box::new(MessageIntegrity::new_short_term_integrity( - ufrag_pwd.remote_pwd.clone(), - )), - Box::new(FINGERPRINT), - ]); - (msg, result) - }; - - if let Err(err) = result { - log::error!("{}", err); - None - } else { - log::trace!( - "ping STUN (nominate candidate pair from {} to {}", - pair.local, - pair.remote - ); - let local = pair.local.clone(); - let remote = pair.remote.clone(); - Some((msg, local, remote)) - } - } else { - None - } - }; - - if let Some((msg, local, remote)) = result { - self.send_binding_request(&msg, &local, &remote).await; - } - } - - pub(crate) async fn start(&self) { - if self.is_controlling.load(Ordering::SeqCst) { - ControllingSelector::start(self).await; - } else { - ControlledSelector::start(self).await; - } - } - - pub(crate) async fn contact_candidates(&self) { - if self.is_controlling.load(Ordering::SeqCst) { - ControllingSelector::contact_candidates(self).await; - } else { - ControlledSelector::contact_candidates(self).await; - } - } - - pub(crate) async fn ping_candidate( - &self, - local: &Arc, - remote: &Arc, - ) { - if self.is_controlling.load(Ordering::SeqCst) { - ControllingSelector::ping_candidate(self, local, remote).await; - } else { - ControlledSelector::ping_candidate(self, local, remote).await; - } - } - - pub(crate) async fn handle_success_response( - &self, - m: &Message, - local: &Arc, - remote: &Arc, - remote_addr: SocketAddr, - ) { - if self.is_controlling.load(Ordering::SeqCst) { - ControllingSelector::handle_success_response(self, m, local, remote, remote_addr).await; - } else { - ControlledSelector::handle_success_response(self, m, local, remote, remote_addr).await; - } - } - - pub(crate) async fn handle_binding_request( - &self, - m: &Message, - local: &Arc, - remote: &Arc, - ) { - if self.is_controlling.load(Ordering::SeqCst) { - ControllingSelector::handle_binding_request(self, m, local, remote).await; - } else { - ControlledSelector::handle_binding_request(self, m, local, remote).await; - } - } -} - -#[async_trait] -impl ControllingSelector for AgentInternal { 
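is_nominatable above gates nomination on how long gathering has been running (one acceptance wait per candidate type), and nominate_pair then sends an ordinary connectivity check with USE-CANDIDATE added. A reconstruction of that build as a standalone helper, assuming the nomination attribute type is UseCandidateAttr from crate::use_candidate and using the stun items this file already imports:

fn build_nomination_request(
    local_ufrag: &str,
    remote_ufrag: &str,
    remote_pwd: String,
    tie_breaker: u64,
    local_priority: u32,
) -> std::result::Result<Message, stun::Error> {
    // USERNAME is "remote:local" from the controlling agent's point of view.
    let username = format!("{remote_ufrag}:{local_ufrag}");
    let mut msg = Message::new();
    msg.build(&[
        Box::new(BINDING_REQUEST),
        Box::new(TransactionId::new()),
        Box::new(Username::new(ATTR_USERNAME, username)),
        Box::<UseCandidateAttr>::default(), // the nomination marker
        Box::new(AttrControlling(tie_breaker)), // role claim + tie-breaker
        Box::new(PriorityAttr(local_priority)),
        Box::new(MessageIntegrity::new_short_term_integrity(remote_pwd)),
        Box::new(FINGERPRINT),
    ])?;
    Ok(msg)
}

The controlled side builds the same message with AttrControlled and without USE-CANDIDATE, as its ping_candidate below shows.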
- async fn start(&self) { - { - let mut nominated_pair = self.nominated_pair.lock().await; - *nominated_pair = None; - } - *self.start_time.lock() = Instant::now(); - } - - async fn contact_candidates(&self) { - // A lite selector should not contact candidates - if self.lite.load(Ordering::SeqCst) { - // This only happens if both peers are lite. See RFC 8445 S6.1.1 and S6.2 - log::trace!("now falling back to full agent"); - } - - let nominated_pair_is_some = { - let nominated_pair = self.nominated_pair.lock().await; - nominated_pair.is_some() - }; - - if self.agent_conn.get_selected_pair().is_some() { - if self.validate_selected_pair().await { - log::trace!("[{}]: checking keepalive", self.get_name()); - self.check_keepalive().await; - } - } else if nominated_pair_is_some { - self.nominate_pair().await; - } else { - let has_nominated_pair = - if let Some(p) = self.agent_conn.get_best_valid_candidate_pair().await { - self.is_nominatable(&p.local) && self.is_nominatable(&p.remote) - } else { - false - }; - - if has_nominated_pair { - if let Some(p) = self.agent_conn.get_best_valid_candidate_pair().await { - log::trace!( - "Nominatable pair found, nominating ({}, {})", - p.local.to_string(), - p.remote.to_string() - ); - p.nominated.store(true, Ordering::SeqCst); - { - let mut nominated_pair = self.nominated_pair.lock().await; - *nominated_pair = Some(p); - } - } - - self.nominate_pair().await; - } else { - self.ping_all_candidates().await; - } - } - } - - async fn ping_candidate( - &self, - local: &Arc, - remote: &Arc, - ) { - let (msg, result) = { - let ufrag_pwd = self.ufrag_pwd.lock().await; - let username = ufrag_pwd.remote_ufrag.clone() + ":" + ufrag_pwd.local_ufrag.as_str(); - let mut msg = Message::new(); - let result = msg.build(&[ - Box::new(BINDING_REQUEST), - Box::new(TransactionId::new()), - Box::new(Username::new(ATTR_USERNAME, username)), - Box::new(AttrControlling(self.tie_breaker.load(Ordering::SeqCst))), - Box::new(PriorityAttr(local.priority())), - Box::new(MessageIntegrity::new_short_term_integrity( - ufrag_pwd.remote_pwd.clone(), - )), - Box::new(FINGERPRINT), - ]); - (msg, result) - }; - - if let Err(err) = result { - log::error!("{}", err); - } else { - self.send_binding_request(&msg, local, remote).await; - } - } - - async fn handle_success_response( - &self, - m: &Message, - local: &Arc, - remote: &Arc, - remote_addr: SocketAddr, - ) { - if let Some(pending_request) = self.handle_inbound_binding_success(m.transaction_id).await { - let transaction_addr = pending_request.destination; - - // Assert that NAT is not symmetric - // https://tools.ietf.org/html/rfc8445#section-7.2.5.2.1 - if transaction_addr != remote_addr { - log::debug!("discard message: transaction source and destination does not match expected({}), actual({})", transaction_addr, remote); - return; - } - - log::trace!( - "inbound STUN (SuccessResponse) from {} to {}", - remote, - local - ); - let selected_pair_is_none = self.agent_conn.get_selected_pair().is_none(); - - if let Some(p) = self.find_pair(local, remote).await { - p.state - .store(CandidatePairState::Succeeded as u8, Ordering::SeqCst); - log::trace!( - "Found valid candidate pair: {}, p.state: {}, isUseCandidate: {}, {}", - p, - p.state.load(Ordering::SeqCst), - pending_request.is_use_candidate, - selected_pair_is_none - ); - if pending_request.is_use_candidate && selected_pair_is_none { - self.set_selected_pair(Some(Arc::clone(&p))).await; - } - } else { - // This shouldn't happen - log::error!("Success response from invalid candidate pair"); 
- } - } else { - log::warn!( - "discard message from ({}), unknown TransactionID 0x{:?}", - remote, - m.transaction_id - ); - } - } - - async fn handle_binding_request( - &self, - m: &Message, - local: &Arc, - remote: &Arc, - ) { - self.send_binding_success(m, local, remote).await; - log::trace!("controllingSelector: sendBindingSuccess"); - - if let Some(p) = self.find_pair(local, remote).await { - let nominated_pair_is_none = { - let nominated_pair = self.nominated_pair.lock().await; - nominated_pair.is_none() - }; - - log::trace!( - "controllingSelector: after findPair {}, p.state: {}, {}", - p, - p.state.load(Ordering::SeqCst), - nominated_pair_is_none, - //self.agent_conn.get_selected_pair().await.is_none() //, {} - ); - if p.state.load(Ordering::SeqCst) == CandidatePairState::Succeeded as u8 - && nominated_pair_is_none - && self.agent_conn.get_selected_pair().is_none() - { - if let Some(best_pair) = self.agent_conn.get_best_available_candidate_pair().await { - log::trace!( - "controllingSelector: getBestAvailableCandidatePair {}", - best_pair - ); - if best_pair == p - && self.is_nominatable(&p.local) - && self.is_nominatable(&p.remote) - { - log::trace!("The candidate ({}, {}) is the best candidate available, marking it as nominated", - p.local, p.remote); - { - let mut nominated_pair = self.nominated_pair.lock().await; - *nominated_pair = Some(p); - } - self.nominate_pair().await; - } - } else { - log::trace!("No best pair available"); - } - } - } else { - log::trace!("controllingSelector: addPair"); - self.add_pair(local.clone(), remote.clone()).await; - } - } -} - -#[async_trait] -impl ControlledSelector for AgentInternal { - async fn start(&self) {} - - async fn contact_candidates(&self) { - // A lite selector should not contact candidates - if self.lite.load(Ordering::SeqCst) { - self.validate_selected_pair().await; - } else if self.agent_conn.get_selected_pair().is_some() { - if self.validate_selected_pair().await { - log::trace!("[{}]: checking keepalive", self.get_name()); - self.check_keepalive().await; - } - } else { - self.ping_all_candidates().await; - } - } - - async fn ping_candidate( - &self, - local: &Arc, - remote: &Arc, - ) { - let (msg, result) = { - let ufrag_pwd = self.ufrag_pwd.lock().await; - let username = ufrag_pwd.remote_ufrag.clone() + ":" + ufrag_pwd.local_ufrag.as_str(); - let mut msg = Message::new(); - let result = msg.build(&[ - Box::new(BINDING_REQUEST), - Box::new(TransactionId::new()), - Box::new(Username::new(ATTR_USERNAME, username)), - Box::new(AttrControlled(self.tie_breaker.load(Ordering::SeqCst))), - Box::new(PriorityAttr(local.priority())), - Box::new(MessageIntegrity::new_short_term_integrity( - ufrag_pwd.remote_pwd.clone(), - )), - Box::new(FINGERPRINT), - ]); - (msg, result) - }; - - if let Err(err) = result { - log::error!("{}", err); - } else { - self.send_binding_request(&msg, local, remote).await; - } - } - - async fn handle_success_response( - &self, - m: &Message, - local: &Arc, - remote: &Arc, - remote_addr: SocketAddr, - ) { - // https://tools.ietf.org/html/rfc8445#section-7.3.1.5 - // If the controlled agent does not accept the request from the - // controlling agent, the controlled agent MUST reject the nomination - // request with an appropriate error code response (e.g., 400) - // [RFC5389]. 
- - if let Some(pending_request) = self.handle_inbound_binding_success(m.transaction_id).await { - let transaction_addr = pending_request.destination; - - // Assert that NAT is not symmetric - // https://tools.ietf.org/html/rfc8445#section-7.2.5.2.1 - if transaction_addr != remote_addr { - log::debug!("discard message: transaction source and destination does not match expected({}), actual({})", transaction_addr, remote); - return; - } - - log::trace!( - "inbound STUN (SuccessResponse) from {} to {}", - remote, - local - ); - - if let Some(p) = self.find_pair(local, remote).await { - p.state - .store(CandidatePairState::Succeeded as u8, Ordering::SeqCst); - log::trace!("Found valid candidate pair: {}", p); - } else { - // This shouldn't happen - log::error!("Success response from invalid candidate pair"); - } - } else { - log::warn!( - "discard message from ({}), unknown TransactionID 0x{:?}", - remote, - m.transaction_id - ); - } - } - - async fn handle_binding_request( - &self, - m: &Message, - local: &Arc<dyn Candidate + Send + Sync>, - remote: &Arc<dyn Candidate + Send + Sync>, - ) { - if self.find_pair(local, remote).await.is_none() { - self.add_pair(local.clone(), remote.clone()).await; - } - - if let Some(p) = self.find_pair(local, remote).await { - let use_candidate = m.contains(ATTR_USE_CANDIDATE); - if use_candidate { - // https://tools.ietf.org/html/rfc8445#section-7.3.1.5 - - if p.state.load(Ordering::SeqCst) == CandidatePairState::Succeeded as u8 { - // If the state of this pair is Succeeded, it means that the check - // previously sent by this pair produced a successful response and - // generated a valid pair (Section 7.2.5.3.2). The agent sets the - // nominated flag value of the valid pair to true. - if self.agent_conn.get_selected_pair().is_none() { - self.set_selected_pair(Some(Arc::clone(&p))).await; - } - self.send_binding_success(m, local, remote).await; - } else { - // If the received Binding request triggered a new check to be - // enqueued in the triggered-check queue (Section 7.3.1.4), once the - // check is sent and if it generates a successful response, and - // generates a valid pair, the agent sets the nominated flag of the - // pair to true. If the request fails (Section 7.2.5.2), the agent - // MUST remove the candidate pair from the valid list, set the - // candidate pair state to Failed, and set the checklist state to - // Failed. - self.ping_candidate(local, remote).await; - } - } else { - self.send_binding_success(m, local, remote).await; - self.ping_candidate(local, remote).await; - } - } - } -} diff --git a/ice/src/agent/agent_stats.rs b/ice/src/agent/agent_stats.rs deleted file mode 100644 index 27ad3cc31..000000000 --- a/ice/src/agent/agent_stats.rs +++ /dev/null @@ -1,283 +0,0 @@ -use std::sync::atomic::Ordering; - -use tokio::time::Instant; - -use crate::agent::agent_internal::AgentInternal; -use crate::candidate::{CandidatePairState, CandidateType}; -use crate::network_type::NetworkType; - -/// Contains ICE candidate pair statistics. -pub struct CandidatePairStats { - /// The timestamp associated with this struct. - pub timestamp: Instant, - - /// The id of the local candidate. - pub local_candidate_id: String, - - /// The id of the remote candidate. - pub remote_candidate_id: String, - - /// The state of the checklist for the local and remote candidates in a pair. - pub state: CandidatePairState, - - /// It is true when this valid pair should be used for media, - /// if it is the highest-priority one amongst those whose nominated flag is set.
- pub nominated: bool, - - /// The total number of packets sent on this candidate pair. - pub packets_sent: u32, - - /// The total number of packets received on this candidate pair. - pub packets_received: u32, - - /// The total number of payload bytes sent on this candidate pair not including headers or - /// padding. - pub bytes_sent: u64, - - /// The total number of payload bytes received on this candidate pair not including headers or - /// padding. - pub bytes_received: u64, - - /// The timestamp at which the last packet was sent on this particular candidate pair, excluding - /// STUN packets. - pub last_packet_sent_timestamp: Instant, - - /// The timestamp at which the last packet was received on this particular candidate pair, - /// excluding STUN packets. - pub last_packet_received_timestamp: Instant, - - /// The timestamp at which the first STUN request was sent on this particular candidate pair. - pub first_request_timestamp: Instant, - - /// The timestamp at which the last STUN request was sent on this particular candidate pair. - /// The average interval between two consecutive connectivity checks sent can be calculated with - /// (last_request_timestamp - first_request_timestamp) / requests_sent. - pub last_request_timestamp: Instant, - - /// Timestamp at which the last STUN response was received on this particular candidate pair. - pub last_response_timestamp: Instant, - - /// The sum of all round trip time measurements in seconds since the beginning of the session, - /// based on STUN connectivity check responses (responses_received), including those that reply - /// to requests that are sent in order to verify consent. The average round trip time can be - /// computed from total_round_trip_time by dividing it by responses_received. - pub total_round_trip_time: f64, - - /// The latest round trip time measured in seconds, computed from both STUN connectivity checks, - /// including those that are sent for consent verification. - pub current_round_trip_time: f64, - - /// It is calculated by the underlying congestion control by combining the available bitrate for - /// all the outgoing RTP streams using this candidate pair. The bitrate measurement does not - /// count the size of the IP or other transport layers like TCP or UDP. It is similar to the - /// TIAS defined in RFC 3890, i.e., it is measured in bits per second and the bitrate is - /// calculated over a 1 second window. - pub available_outgoing_bitrate: f64, - - /// It is calculated by the underlying congestion control by combining the available bitrate for - /// all the incoming RTP streams using this candidate pair. The bitrate measurement does not - /// count the size of the IP or other transport layers like TCP or UDP. It is similar to the - /// TIAS defined in RFC 3890, i.e., it is measured in bits per second and the bitrate is - /// calculated over a 1 second window. - pub available_incoming_bitrate: f64, - - /// The number of times the circuit breaker is triggered for this particular 5-tuple, - /// ceasing transmission. - pub circuit_breaker_trigger_count: u32, - - /// The total number of connectivity check requests received (including retransmissions). - /// It is impossible for the receiver to tell whether the request was sent in order to check - /// connectivity or check consent, so all connectivity checks requests are counted here. - pub requests_received: u64, - - /// The total number of connectivity check requests sent (not including retransmissions). 
- pub requests_sent: u64, - - /// The total number of connectivity check responses received. - pub responses_received: u64, - - /// The total number of connectivity check responses sent. Since we cannot distinguish - /// connectivity check requests and consent requests, all responses are counted. - pub responses_sent: u64, - - /// The total number of connectivity check request retransmissions received. - pub retransmissions_received: u64, - - /// The total number of connectivity check request retransmissions sent. - pub retransmissions_sent: u64, - - /// The total number of consent requests sent. - pub consent_requests_sent: u64, - - /// The timestamp at which the latest valid STUN binding response expired. - pub consent_expired_timestamp: Instant, -} - -impl Default for CandidatePairStats { - fn default() -> Self { - Self { - timestamp: Instant::now(), - local_candidate_id: String::new(), - remote_candidate_id: String::new(), - state: CandidatePairState::default(), - nominated: false, - packets_sent: 0, - packets_received: 0, - bytes_sent: 0, - bytes_received: 0, - last_packet_sent_timestamp: Instant::now(), - last_packet_received_timestamp: Instant::now(), - first_request_timestamp: Instant::now(), - last_request_timestamp: Instant::now(), - last_response_timestamp: Instant::now(), - total_round_trip_time: 0.0, - current_round_trip_time: 0.0, - available_outgoing_bitrate: 0.0, - available_incoming_bitrate: 0.0, - circuit_breaker_trigger_count: 0, - requests_received: 0, - requests_sent: 0, - responses_received: 0, - responses_sent: 0, - retransmissions_received: 0, - retransmissions_sent: 0, - consent_requests_sent: 0, - consent_expired_timestamp: Instant::now(), - } - } -} - -/// Contains ICE candidate statistics related to the `ICETransport` objects. -#[derive(Debug, Clone)] -pub struct CandidateStats { - // The timestamp associated with this struct. - pub timestamp: Instant, - - /// The candidate id. - pub id: String, - - /// The type of network interface used by the base of a local candidate (the address the ICE - /// agent sends from). Only present for local candidates; it's not possible to know what type of - /// network interface a remote candidate is using. - /// - /// Note: This stat only tells you about the network interface used by the first "hop"; it's - /// possible that a connection will be bottlenecked by another type of network. For example, - /// when using Wi-Fi tethering, the networkType of the relevant candidate would be "wifi", even - /// when the next hop is over a cellular connection. - pub network_type: NetworkType, - - /// The IP address of the candidate, allowing for IPv4 addresses and IPv6 addresses, but fully - /// qualified domain names (FQDNs) are not allowed. - pub ip: String, - - /// The port number of the candidate. - pub port: u16, - - /// The `Type` field of the ICECandidate. - pub candidate_type: CandidateType, - - /// The `priority` field of the ICECandidate. - pub priority: u32, - - /// The url of the TURN or STUN server that translated this IP address. - /// It is the url address surfaced in a PeerConnectionICEEvent. - pub url: String, - - /// The protocol used by the endpoint to communicate with the TURN server. This is only present - /// for local candidates. Valid values for the TURN url protocol are udp, tcp, and tls. - pub relay_protocol: String, - - /// It is true if the candidate has been deleted/freed.
For host candidates, this means that any - /// network resources (typically a socket) associated with the candidate have been released. For - /// TURN candidates, this means the TURN allocation is no longer active. - /// - /// Only defined for local candidates. For remote candidates, this property is not applicable. - pub deleted: bool, -} - -impl Default for CandidateStats { - fn default() -> Self { - Self { - timestamp: Instant::now(), - id: String::new(), - network_type: NetworkType::default(), - ip: String::new(), - port: 0, - candidate_type: CandidateType::default(), - priority: 0, - url: String::new(), - relay_protocol: String::new(), - deleted: false, - } - } -} - -impl AgentInternal { - /// Returns a list of candidate pair stats. - pub(crate) async fn get_candidate_pairs_stats(&self) -> Vec<CandidatePairStats> { - let checklist = self.agent_conn.checklist.lock().await; - let mut res = Vec::with_capacity(checklist.len()); - for cp in &*checklist { - let stat = CandidatePairStats { - timestamp: Instant::now(), - local_candidate_id: cp.local.id(), - remote_candidate_id: cp.remote.id(), - state: cp.state.load(Ordering::SeqCst).into(), - nominated: cp.nominated.load(Ordering::SeqCst), - ..CandidatePairStats::default() - }; - res.push(stat); - } - res - } - - /// Returns a list of local candidates stats. - pub(crate) async fn get_local_candidates_stats(&self) -> Vec<CandidateStats> { - let local_candidates = self.local_candidates.lock().await; - let mut res = Vec::with_capacity(local_candidates.len()); - for (network_type, local_candidates) in &*local_candidates { - for c in local_candidates { - let stat = CandidateStats { - timestamp: Instant::now(), - id: c.id(), - network_type: *network_type, - ip: c.address(), - port: c.port(), - candidate_type: c.candidate_type(), - priority: c.priority(), - // URL string - relay_protocol: "udp".to_owned(), - // Deleted bool - ..CandidateStats::default() - }; - res.push(stat); - } - } - res - } - - /// Returns a list of remote candidates stats.
- pub(crate) async fn get_remote_candidates_stats(&self) -> Vec { - let remote_candidates = self.remote_candidates.lock().await; - let mut res = Vec::with_capacity(remote_candidates.len()); - for (network_type, remote_candidates) in &*remote_candidates { - for c in remote_candidates { - let stat = CandidateStats { - timestamp: Instant::now(), - id: c.id(), - network_type: *network_type, - ip: c.address(), - port: c.port(), - candidate_type: c.candidate_type(), - priority: c.priority(), - // URL string - relay_protocol: "udp".to_owned(), - // Deleted bool - ..CandidateStats::default() - }; - res.push(stat); - } - } - res - } -} diff --git a/ice/src/agent/agent_test.rs b/ice/src/agent/agent_test.rs deleted file mode 100644 index 0af395a6f..000000000 --- a/ice/src/agent/agent_test.rs +++ /dev/null @@ -1,2203 +0,0 @@ -use std::net::Ipv4Addr; -use std::ops::Sub; -use std::str::FromStr; - -use async_trait::async_trait; -use stun::message::*; -use stun::textattrs::Username; -use util::vnet::*; -use util::Conn; -use waitgroup::{WaitGroup, Worker}; - -use super::agent_vnet_test::*; -use super::*; -use crate::agent::agent_transport_test::pipe; -use crate::candidate::candidate_base::*; -use crate::candidate::candidate_host::*; -use crate::candidate::candidate_peer_reflexive::*; -use crate::candidate::candidate_relay::*; -use crate::candidate::candidate_server_reflexive::*; -use crate::control::AttrControlling; -use crate::priority::PriorityAttr; -use crate::use_candidate::UseCandidateAttr; - -#[tokio::test] -async fn test_pair_search() -> Result<()> { - let config = AgentConfig::default(); - let a = Agent::new(config).await?; - - { - { - let checklist = a.internal.agent_conn.checklist.lock().await; - assert!( - checklist.is_empty(), - "TestPairSearch is only a valid test if a.validPairs is empty on construction" - ); - } - - let cp = a - .internal - .agent_conn - .get_best_available_candidate_pair() - .await; - assert!(cp.is_none(), "No Candidate pairs should exist"); - } - - a.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_pair_priority() -> Result<()> { - let a = Agent::new(AgentConfig::default()).await?; - - let host_config = CandidateHostConfig { - base_config: CandidateBaseConfig { - network: "udp".to_owned(), - address: "192.168.1.1".to_owned(), - port: 19216, - component: 1, - ..Default::default() - }, - ..Default::default() - }; - let host_local: Arc = Arc::new(host_config.new_candidate_host()?); - - let relay_config = CandidateRelayConfig { - base_config: CandidateBaseConfig { - network: "udp".to_owned(), - address: "1.2.3.4".to_owned(), - port: 12340, - component: 1, - ..Default::default() - }, - rel_addr: "4.3.2.1".to_owned(), - rel_port: 43210, - ..Default::default() - }; - - let relay_remote = relay_config.new_candidate_relay()?; - - let srflx_config = CandidateServerReflexiveConfig { - base_config: CandidateBaseConfig { - network: "udp".to_owned(), - address: "10.10.10.2".to_owned(), - port: 19218, - component: 1, - ..Default::default() - }, - rel_addr: "4.3.2.1".to_owned(), - rel_port: 43212, - }; - - let srflx_remote = srflx_config.new_candidate_server_reflexive()?; - - let prflx_config = CandidatePeerReflexiveConfig { - base_config: CandidateBaseConfig { - network: "udp".to_owned(), - address: "10.10.10.2".to_owned(), - port: 19217, - component: 1, - ..Default::default() - }, - rel_addr: "4.3.2.1".to_owned(), - rel_port: 43211, - }; - - let prflx_remote = prflx_config.new_candidate_peer_reflexive()?; - - let host_config = CandidateHostConfig { - base_config: 
CandidateBaseConfig { - network: "udp".to_owned(), - address: "1.2.3.5".to_owned(), - port: 12350, - component: 1, - ..Default::default() - }, - ..Default::default() - }; - let host_remote = host_config.new_candidate_host()?; - - let remotes: Vec> = vec![ - Arc::new(relay_remote), - Arc::new(srflx_remote), - Arc::new(prflx_remote), - Arc::new(host_remote), - ]; - - { - for remote in remotes { - if a.internal.find_pair(&host_local, &remote).await.is_none() { - a.internal - .add_pair(host_local.clone(), remote.clone()) - .await; - } - - if let Some(p) = a.internal.find_pair(&host_local, &remote).await { - p.state - .store(CandidatePairState::Succeeded as u8, Ordering::SeqCst); - } - - if let Some(best_pair) = a - .internal - .agent_conn - .get_best_available_candidate_pair() - .await - { - assert_eq!( - best_pair.to_string(), - CandidatePair { - remote: remote.clone(), - local: host_local.clone(), - ..Default::default() - } - .to_string(), - "Unexpected bestPair {best_pair} (expected remote: {remote})", - ); - } else { - panic!("expected Some, but got None"); - } - } - } - - a.close().await?; - Ok(()) -} - -#[tokio::test] -async fn test_agent_get_stats() -> Result<()> { - let (conn_a, conn_b, agent_a, agent_b) = pipe(None, None).await?; - assert_eq!(agent_a.get_bytes_received(), 0); - assert_eq!(agent_a.get_bytes_sent(), 0); - assert_eq!(agent_b.get_bytes_received(), 0); - assert_eq!(agent_b.get_bytes_sent(), 0); - - let _na = conn_a.send(&[0u8; 10]).await?; - let mut buf = vec![0u8; 10]; - let _nb = conn_b.recv(&mut buf).await?; - - assert_eq!(agent_a.get_bytes_received(), 0); - assert_eq!(agent_a.get_bytes_sent(), 10); - - assert_eq!(agent_b.get_bytes_received(), 10); - assert_eq!(agent_b.get_bytes_sent(), 0); - - Ok(()) -} - -#[tokio::test] -async fn test_on_selected_candidate_pair_change() -> Result<()> { - let a = Agent::new(AgentConfig::default()).await?; - let (callback_called_tx, mut callback_called_rx) = mpsc::channel::<()>(1); - let callback_called_tx = Arc::new(Mutex::new(Some(callback_called_tx))); - let cb: OnSelectedCandidatePairChangeHdlrFn = Box::new(move |_, _| { - let callback_called_tx_clone = Arc::clone(&callback_called_tx); - Box::pin(async move { - let mut tx = callback_called_tx_clone.lock().await; - tx.take(); - }) - }); - a.on_selected_candidate_pair_change(cb); - - let host_config = CandidateHostConfig { - base_config: CandidateBaseConfig { - network: "udp".to_owned(), - address: "192.168.1.1".to_owned(), - port: 19216, - component: 1, - ..Default::default() - }, - ..Default::default() - }; - let host_local = host_config.new_candidate_host()?; - - let relay_config = CandidateRelayConfig { - base_config: CandidateBaseConfig { - network: "udp".to_owned(), - address: "1.2.3.4".to_owned(), - port: 12340, - component: 1, - ..Default::default() - }, - rel_addr: "4.3.2.1".to_owned(), - rel_port: 43210, - ..Default::default() - }; - let relay_remote = relay_config.new_candidate_relay()?; - - // select the pair - let p = Arc::new(CandidatePair::new( - Arc::new(host_local), - Arc::new(relay_remote), - false, - )); - a.internal.set_selected_pair(Some(p)).await; - - // ensure that the callback fired on setting the pair - let _ = callback_called_rx.recv().await; - - a.close().await?; - Ok(()) -} - -#[tokio::test] -async fn test_handle_peer_reflexive_udp_pflx_candidate() -> Result<()> { - let a = Agent::new(AgentConfig::default()).await?; - - let host_config = CandidateHostConfig { - base_config: CandidateBaseConfig { - network: "udp".to_owned(), - address: 
"192.168.0.2".to_owned(), - port: 777, - component: 1, - conn: Some(Arc::new(MockConn {})), - ..Default::default() - }, - ..Default::default() - }; - - let local: Arc = Arc::new(host_config.new_candidate_host()?); - let remote = SocketAddr::from_str("172.17.0.3:999")?; - - let (username, local_pwd, tie_breaker) = { - let ufrag_pwd = a.internal.ufrag_pwd.lock().await; - ( - ufrag_pwd.local_ufrag.to_owned() + ":" + ufrag_pwd.remote_ufrag.as_str(), - ufrag_pwd.local_pwd.clone(), - a.internal.tie_breaker.load(Ordering::SeqCst), - ) - }; - - let mut msg = Message::new(); - msg.build(&[ - Box::new(BINDING_REQUEST), - Box::new(TransactionId::new()), - Box::new(Username::new(ATTR_USERNAME, username)), - Box::new(UseCandidateAttr::new()), - Box::new(AttrControlling(tie_breaker)), - Box::new(PriorityAttr(local.priority())), - Box::new(MessageIntegrity::new_short_term_integrity(local_pwd)), - Box::new(FINGERPRINT), - ])?; - - { - a.internal.handle_inbound(&mut msg, &local, remote).await; - - let remote_candidates = a.internal.remote_candidates.lock().await; - // length of remote candidate list must be one now - assert_eq!( - remote_candidates.len(), - 1, - "failed to add a network type to the remote candidate list" - ); - - // length of remote candidate list for a network type must be 1 - if let Some(cands) = remote_candidates.get(&local.network_type()) { - assert_eq!( - cands.len(), - 1, - "failed to add prflx candidate to remote candidate list" - ); - - let c = &cands[0]; - - assert_eq!( - c.candidate_type(), - CandidateType::PeerReflexive, - "candidate type must be prflx" - ); - - assert_eq!(c.address(), "172.17.0.3", "IP address mismatch"); - - assert_eq!(c.port(), 999, "Port number mismatch"); - } else { - panic!( - "expected non-empty remote candidate for network type {}", - local.network_type() - ); - } - } - - a.close().await?; - Ok(()) -} - -#[tokio::test] -async fn test_handle_peer_reflexive_unknown_remote() -> Result<()> { - let a = Agent::new(AgentConfig::default()).await?; - - let mut tid = TransactionId::default(); - tid.0[..3].copy_from_slice("ABC".as_bytes()); - - let remote_pwd = { - { - let mut pending_binding_requests = a.internal.pending_binding_requests.lock().await; - *pending_binding_requests = vec![BindingRequest { - timestamp: Instant::now(), - transaction_id: tid, - destination: SocketAddr::from_str("0.0.0.0:0")?, - is_use_candidate: false, - }]; - } - let ufrag_pwd = a.internal.ufrag_pwd.lock().await; - ufrag_pwd.remote_pwd.clone() - }; - - let host_config = CandidateHostConfig { - base_config: CandidateBaseConfig { - network: "udp".to_owned(), - address: "192.168.0.2".to_owned(), - port: 777, - component: 1, - conn: Some(Arc::new(MockConn {})), - ..Default::default() - }, - ..Default::default() - }; - - let local: Arc = Arc::new(host_config.new_candidate_host()?); - let remote = SocketAddr::from_str("172.17.0.3:999")?; - - let mut msg = Message::new(); - msg.build(&[ - Box::new(BINDING_SUCCESS), - Box::new(tid), - Box::new(MessageIntegrity::new_short_term_integrity(remote_pwd)), - Box::new(FINGERPRINT), - ])?; - - { - a.internal.handle_inbound(&mut msg, &local, remote).await; - - let remote_candidates = a.internal.remote_candidates.lock().await; - assert_eq!( - remote_candidates.len(), - 0, - "unknown remote was able to create a candidate" - ); - } - - a.close().await?; - Ok(()) -} - -//use std::io::Write; - -// Assert that Agent on startup sends message, and doesn't wait for connectivityTicker to fire -#[tokio::test] -async fn test_connectivity_on_startup() -> Result<()> 
{ - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - // Create a network with two interfaces - let wan = Arc::new(Mutex::new(router::Router::new(router::RouterConfig { - cidr: "0.0.0.0/0".to_owned(), - ..Default::default() - })?)); - - let net0 = Arc::new(net::Net::new(Some(net::NetConfig { - static_ips: vec!["192.168.0.1".to_owned()], - ..Default::default() - }))); - let net1 = Arc::new(net::Net::new(Some(net::NetConfig { - static_ips: vec!["192.168.0.2".to_owned()], - ..Default::default() - }))); - - connect_net2router(&net0, &wan).await?; - connect_net2router(&net1, &wan).await?; - start_router(&wan).await?; - - let (a_notifier, mut a_connected) = on_connected(); - let (b_notifier, mut b_connected) = on_connected(); - - let keepalive_interval = Some(Duration::from_secs(3600)); //time.Hour - let check_interval = Duration::from_secs(3600); //time.Hour - let cfg0 = AgentConfig { - network_types: supported_network_types(), - multicast_dns_mode: MulticastDnsMode::Disabled, - net: Some(net0), - - keepalive_interval, - check_interval, - ..Default::default() - }; - - let a_agent = Arc::new(Agent::new(cfg0).await?); - a_agent.on_connection_state_change(a_notifier); - - let cfg1 = AgentConfig { - network_types: supported_network_types(), - multicast_dns_mode: MulticastDnsMode::Disabled, - net: Some(net1), - - keepalive_interval, - check_interval, - ..Default::default() - }; - - let b_agent = Arc::new(Agent::new(cfg1).await?); - b_agent.on_connection_state_change(b_notifier); - - // Manual signaling - let (a_ufrag, a_pwd) = a_agent.get_local_user_credentials().await; - let (b_ufrag, b_pwd) = b_agent.get_local_user_credentials().await; - - gather_and_exchange_candidates(&a_agent, &b_agent).await?; - - let (accepted_tx, mut accepted_rx) = mpsc::channel::<()>(1); - let (accepting_tx, mut accepting_rx) = mpsc::channel::<()>(1); - let (_a_cancel_tx, a_cancel_rx) = mpsc::channel(1); - let (_b_cancel_tx, b_cancel_rx) = mpsc::channel(1); - - let accepting_tx = Arc::new(Mutex::new(Some(accepting_tx))); - a_agent.on_connection_state_change(Box::new(move |s: ConnectionState| { - let accepted_tx_clone = Arc::clone(&accepting_tx); - Box::pin(async move { - if s == ConnectionState::Checking { - let mut tx = accepted_tx_clone.lock().await; - tx.take(); - } - }) - })); - - tokio::spawn(async move { - let result = a_agent.accept(a_cancel_rx, b_ufrag, b_pwd).await; - assert!(result.is_ok(), "agent accept expected OK"); - drop(accepted_tx); - }); - - let _ = accepting_rx.recv().await; - - let _ = b_agent.dial(b_cancel_rx, a_ufrag, a_pwd).await?; - - // Ensure accepted - let _ = accepted_rx.recv().await; - - // Ensure pair selected - // Note: this assumes ConnectionStateConnected is thrown after selecting the final pair - let _ = a_connected.recv().await; - let _ = b_connected.recv().await; - - { - let mut w = wan.lock().await; - w.stop().await?; - } - - Ok(()) -} - -#[tokio::test] -async fn test_connectivity_lite() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - let 
stun_server_url = Url { - scheme: SchemeType::Stun, - host: "1.2.3.4".to_owned(), - port: 3478, - proto: ProtoType::Udp, - ..Default::default() - }; - - let nat_type = nat::NatType { - mapping_behavior: nat::EndpointDependencyType::EndpointIndependent, - filtering_behavior: nat::EndpointDependencyType::EndpointIndependent, - ..Default::default() - }; - - let v = build_vnet(nat_type, nat_type).await?; - - let (a_notifier, mut a_connected) = on_connected(); - let (b_notifier, mut b_connected) = on_connected(); - - let cfg0 = AgentConfig { - urls: vec![stun_server_url], - network_types: supported_network_types(), - multicast_dns_mode: MulticastDnsMode::Disabled, - net: Some(Arc::clone(&v.net0)), - ..Default::default() - }; - - let a_agent = Arc::new(Agent::new(cfg0).await?); - a_agent.on_connection_state_change(a_notifier); - - let cfg1 = AgentConfig { - urls: vec![], - lite: true, - candidate_types: vec![CandidateType::Host], - network_types: supported_network_types(), - multicast_dns_mode: MulticastDnsMode::Disabled, - net: Some(Arc::clone(&v.net1)), - ..Default::default() - }; - - let b_agent = Arc::new(Agent::new(cfg1).await?); - b_agent.on_connection_state_change(b_notifier); - - let _ = connect_with_vnet(&a_agent, &b_agent).await?; - - // Ensure pair selected - // Note: this assumes ConnectionStateConnected is thrown after selecting the final pair - let _ = a_connected.recv().await; - let _ = b_connected.recv().await; - - v.close().await?; - - Ok(()) -} - -struct MockPacketConn; - -#[async_trait] -impl Conn for MockPacketConn { - async fn connect(&self, _addr: SocketAddr) -> std::result::Result<(), util::Error> { - Ok(()) - } - - async fn recv(&self, _buf: &mut [u8]) -> std::result::Result { - Ok(0) - } - - async fn recv_from( - &self, - _buf: &mut [u8], - ) -> std::result::Result<(usize, SocketAddr), util::Error> { - Ok((0, SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0))) - } - - async fn send(&self, _buf: &[u8]) -> std::result::Result { - Ok(0) - } - - async fn send_to( - &self, - _buf: &[u8], - _target: SocketAddr, - ) -> std::result::Result { - Ok(0) - } - - fn local_addr(&self) -> std::result::Result { - Ok(SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0)) - } - - fn remote_addr(&self) -> Option { - None - } - - async fn close(&self) -> std::result::Result<(), util::Error> { - Ok(()) - } - - fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) { - self - } -} - -fn build_msg(c: MessageClass, username: String, key: String) -> Result { - let mut msg = Message::new(); - msg.build(&[ - Box::new(MessageType::new(METHOD_BINDING, c)), - Box::new(TransactionId::new()), - Box::new(Username::new(ATTR_USERNAME, username)), - Box::new(MessageIntegrity::new_short_term_integrity(key)), - Box::new(FINGERPRINT), - ])?; - Ok(msg) -} - -#[tokio::test] -async fn test_inbound_validity() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - let remote = SocketAddr::from_str("172.17.0.3:999")?; - let local: Arc = Arc::new( - CandidateHostConfig { - base_config: CandidateBaseConfig { - network: "udp".to_owned(), - address: "192.168.0.2".to_owned(), - port: 777, - component: 1, - conn: Some(Arc::new(MockPacketConn {})), - ..Default::default() - }, - ..Default::default() - } - .new_candidate_host()?, - ); - - //"Invalid 
Binding requests should be discarded" - { - let a = Agent::new(AgentConfig::default()).await?; - - { - let local_pwd = { - let ufrag_pwd = a.internal.ufrag_pwd.lock().await; - ufrag_pwd.local_pwd.clone() - }; - a.internal - .handle_inbound( - &mut build_msg(CLASS_REQUEST, "invalid".to_owned(), local_pwd)?, - &local, - remote, - ) - .await; - { - let remote_candidates = a.internal.remote_candidates.lock().await; - assert_ne!( - remote_candidates.len(), - 1, - "Binding with invalid Username was able to create prflx candidate" - ); - } - - let username = { - let ufrag_pwd = a.internal.ufrag_pwd.lock().await; - format!("{}:{}", ufrag_pwd.local_ufrag, ufrag_pwd.remote_ufrag) - }; - a.internal - .handle_inbound( - &mut build_msg(CLASS_REQUEST, username, "Invalid".to_owned())?, - &local, - remote, - ) - .await; - { - let remote_candidates = a.internal.remote_candidates.lock().await; - assert_ne!( - remote_candidates.len(), - 1, - "Binding with invalid MessageIntegrity was able to create prflx candidate" - ); - } - } - - a.close().await?; - } - - //"Invalid Binding success responses should be discarded" - { - let a = Agent::new(AgentConfig::default()).await?; - - { - let username = { - let ufrag_pwd = a.internal.ufrag_pwd.lock().await; - format!("{}:{}", ufrag_pwd.local_ufrag, ufrag_pwd.remote_ufrag) - }; - a.internal - .handle_inbound( - &mut build_msg(CLASS_SUCCESS_RESPONSE, username, "Invalid".to_owned())?, - &local, - remote, - ) - .await; - { - let remote_candidates = a.internal.remote_candidates.lock().await; - assert_ne!( - remote_candidates.len(), - 1, - "Binding with invalid Username was able to create prflx candidate" - ); - } - } - - a.close().await?; - } - - //"Discard non-binding messages" - { - let a = Agent::new(AgentConfig::default()).await?; - - { - let username = { - let ufrag_pwd = a.internal.ufrag_pwd.lock().await; - format!("{}:{}", ufrag_pwd.local_ufrag, ufrag_pwd.remote_ufrag) - }; - a.internal - .handle_inbound( - &mut build_msg(CLASS_ERROR_RESPONSE, username, "Invalid".to_owned())?, - &local, - remote, - ) - .await; - let remote_candidates = a.internal.remote_candidates.lock().await; - assert_ne!( - remote_candidates.len(), - 1, - "non-binding message was able to create prflxRemote" - ); - } - - a.close().await?; - } - - //"Valid bind request" - { - let a = Agent::new(AgentConfig::default()).await?; - - { - let (username, local_pwd) = { - let ufrag_pwd = a.internal.ufrag_pwd.lock().await; - ( - format!("{}:{}", ufrag_pwd.local_ufrag, ufrag_pwd.remote_ufrag), - ufrag_pwd.local_pwd.clone(), - ) - }; - a.internal - .handle_inbound( - &mut build_msg(CLASS_REQUEST, username, local_pwd)?, - &local, - remote, - ) - .await; - let remote_candidates = a.internal.remote_candidates.lock().await; - assert_eq!( - remote_candidates.len(), - 1, - "Binding with valid values was unable to create prflx candidate" - ); - } - - a.close().await?; - } - - //"Valid bind without fingerprint" - { - let a = Agent::new(AgentConfig::default()).await?; - - { - let (username, local_pwd) = { - let ufrag_pwd = a.internal.ufrag_pwd.lock().await; - ( - format!("{}:{}", ufrag_pwd.local_ufrag, ufrag_pwd.remote_ufrag), - ufrag_pwd.local_pwd.clone(), - ) - }; - - let mut msg = Message::new(); - msg.build(&[ - Box::new(BINDING_REQUEST), - Box::new(TransactionId::new()), - Box::new(Username::new(ATTR_USERNAME, username)), - Box::new(MessageIntegrity::new_short_term_integrity(local_pwd)), - ])?; - - a.internal.handle_inbound(&mut msg, &local, remote).await; - let remote_candidates = 
a.internal.remote_candidates.lock().await; - assert_eq!( - remote_candidates.len(), - 1, - "Binding with valid values (but no fingerprint) was unable to create prflx candidate" - ); - } - - a.close().await?; - } - - //"Success with invalid TransactionID" - { - let a = Agent::new(AgentConfig::default()).await?; - - { - let remote = SocketAddr::from_str("172.17.0.3:999")?; - - let mut t_id = TransactionId::default(); - t_id.0[..3].copy_from_slice(b"ABC"); - - let remote_pwd = { - let ufrag_pwd = a.internal.ufrag_pwd.lock().await; - ufrag_pwd.remote_pwd.clone() - }; - - let mut msg = Message::new(); - msg.build(&[ - Box::new(BINDING_SUCCESS), - Box::new(t_id), - Box::new(MessageIntegrity::new_short_term_integrity(remote_pwd)), - Box::new(FINGERPRINT), - ])?; - - a.internal.handle_inbound(&mut msg, &local, remote).await; - - { - let remote_candidates = a.internal.remote_candidates.lock().await; - assert_eq!( - remote_candidates.len(), - 0, - "unknown remote was able to create a candidate" - ); - } - } - - a.close().await?; - } - - Ok(()) -} - -#[tokio::test] -async fn test_invalid_agent_starts() -> Result<()> { - let a = Agent::new(AgentConfig::default()).await?; - - let (_cancel_tx1, cancel_rx1) = mpsc::channel(1); - let result = a.dial(cancel_rx1, "".to_owned(), "bar".to_owned()).await; - assert!(result.is_err()); - if let Err(err) = result { - assert_eq!(Error::ErrRemoteUfragEmpty, err); - } - - let (_cancel_tx2, cancel_rx2) = mpsc::channel(1); - let result = a.dial(cancel_rx2, "foo".to_owned(), "".to_owned()).await; - assert!(result.is_err()); - if let Err(err) = result { - assert_eq!(Error::ErrRemotePwdEmpty, err); - } - - let (cancel_tx3, cancel_rx3) = mpsc::channel(1); - tokio::spawn(async move { - tokio::time::sleep(Duration::from_millis(100)).await; - drop(cancel_tx3); - }); - - let result = a.dial(cancel_rx3, "foo".to_owned(), "bar".to_owned()).await; - assert!(result.is_err()); - if let Err(err) = result { - assert_eq!(Error::ErrCanceledByCaller, err); - } - - let (_cancel_tx4, cancel_rx4) = mpsc::channel(1); - let result = a.dial(cancel_rx4, "foo".to_owned(), "bar".to_owned()).await; - assert!(result.is_err()); - if let Err(err) = result { - assert_eq!(Error::ErrMultipleStart, err); - } - - a.close().await?; - - Ok(()) -} - -//use std::io::Write; - -// Assert that Agent emits Connecting/Connected/Disconnected/Failed/Closed messages -#[tokio::test] -async fn test_connection_state_callback() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - let disconnected_duration = Duration::from_secs(1); - let failed_duration = Duration::from_secs(1); - let keepalive_interval = Duration::from_secs(0); - - let cfg0 = AgentConfig { - urls: vec![], - network_types: supported_network_types(), - disconnected_timeout: Some(disconnected_duration), - failed_timeout: Some(failed_duration), - keepalive_interval: Some(keepalive_interval), - ..Default::default() - }; - - let cfg1 = AgentConfig { - urls: vec![], - network_types: supported_network_types(), - disconnected_timeout: Some(disconnected_duration), - failed_timeout: Some(failed_duration), - keepalive_interval: Some(keepalive_interval), - ..Default::default() - }; - - let a_agent = Arc::new(Agent::new(cfg0).await?); - let b_agent = Arc::new(Agent::new(cfg1).await?); - - let 
(is_checking_tx, mut is_checking_rx) = mpsc::channel::<()>(1); - let (is_connected_tx, mut is_connected_rx) = mpsc::channel::<()>(1); - let (is_disconnected_tx, mut is_disconnected_rx) = mpsc::channel::<()>(1); - let (is_failed_tx, mut is_failed_rx) = mpsc::channel::<()>(1); - let (is_closed_tx, mut is_closed_rx) = mpsc::channel::<()>(1); - - let is_checking_tx = Arc::new(Mutex::new(Some(is_checking_tx))); - let is_connected_tx = Arc::new(Mutex::new(Some(is_connected_tx))); - let is_disconnected_tx = Arc::new(Mutex::new(Some(is_disconnected_tx))); - let is_failed_tx = Arc::new(Mutex::new(Some(is_failed_tx))); - let is_closed_tx = Arc::new(Mutex::new(Some(is_closed_tx))); - - a_agent.on_connection_state_change(Box::new(move |c: ConnectionState| { - let is_checking_tx_clone = Arc::clone(&is_checking_tx); - let is_connected_tx_clone = Arc::clone(&is_connected_tx); - let is_disconnected_tx_clone = Arc::clone(&is_disconnected_tx); - let is_failed_tx_clone = Arc::clone(&is_failed_tx); - let is_closed_tx_clone = Arc::clone(&is_closed_tx); - Box::pin(async move { - match c { - ConnectionState::Checking => { - log::debug!("drop is_checking_tx"); - let mut tx = is_checking_tx_clone.lock().await; - tx.take(); - } - ConnectionState::Connected => { - log::debug!("drop is_connected_tx"); - let mut tx = is_connected_tx_clone.lock().await; - tx.take(); - } - ConnectionState::Disconnected => { - log::debug!("drop is_disconnected_tx"); - let mut tx = is_disconnected_tx_clone.lock().await; - tx.take(); - } - ConnectionState::Failed => { - log::debug!("drop is_failed_tx"); - let mut tx = is_failed_tx_clone.lock().await; - tx.take(); - } - ConnectionState::Closed => { - log::debug!("drop is_closed_tx"); - let mut tx = is_closed_tx_clone.lock().await; - tx.take(); - } - _ => {} - }; - }) - })); - - connect_with_vnet(&a_agent, &b_agent).await?; - - log::debug!("wait is_checking_tx"); - let _ = is_checking_rx.recv().await; - log::debug!("wait is_connected_rx"); - let _ = is_connected_rx.recv().await; - log::debug!("wait is_disconnected_rx"); - let _ = is_disconnected_rx.recv().await; - log::debug!("wait is_failed_rx"); - let _ = is_failed_rx.recv().await; - - a_agent.close().await?; - b_agent.close().await?; - - log::debug!("wait is_closed_rx"); - let _ = is_closed_rx.recv().await; - - Ok(()) -} - -#[tokio::test] -async fn test_invalid_gather() -> Result<()> { - //"Gather with no OnCandidate should error" - let a = Agent::new(AgentConfig::default()).await?; - - if let Err(err) = a.gather_candidates() { - assert_eq!( - Error::ErrNoOnCandidateHandler, - err, - "trickle GatherCandidates succeeded without OnCandidate" - ); - } - - a.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_candidate_pair_stats() -> Result<()> { - let a = Agent::new(AgentConfig::default()).await?; - - let host_local: Arc = Arc::new( - CandidateHostConfig { - base_config: CandidateBaseConfig { - network: "udp".to_owned(), - address: "192.168.1.1".to_owned(), - port: 19216, - component: 1, - ..Default::default() - }, - ..Default::default() - } - .new_candidate_host()?, - ); - - let relay_remote: Arc = Arc::new( - CandidateRelayConfig { - base_config: CandidateBaseConfig { - network: "udp".to_owned(), - address: "1.2.3.4".to_owned(), - port: 2340, - component: 1, - ..Default::default() - }, - rel_addr: "4.3.2.1".to_owned(), - rel_port: 43210, - ..Default::default() - } - .new_candidate_relay()?, - ); - - let srflx_remote: Arc = Arc::new( - CandidateServerReflexiveConfig { - base_config: CandidateBaseConfig { - network: 
"udp".to_owned(), - address: "10.10.10.2".to_owned(), - port: 19218, - component: 1, - ..Default::default() - }, - rel_addr: "4.3.2.1".to_owned(), - rel_port: 43212, - } - .new_candidate_server_reflexive()?, - ); - - let prflx_remote: Arc = Arc::new( - CandidatePeerReflexiveConfig { - base_config: CandidateBaseConfig { - network: "udp".to_owned(), - address: "10.10.10.2".to_owned(), - port: 19217, - component: 1, - ..Default::default() - }, - rel_addr: "4.3.2.1".to_owned(), - rel_port: 43211, - } - .new_candidate_peer_reflexive()?, - ); - - let host_remote: Arc = Arc::new( - CandidateHostConfig { - base_config: CandidateBaseConfig { - network: "udp".to_owned(), - address: "1.2.3.5".to_owned(), - port: 12350, - component: 1, - ..Default::default() - }, - ..Default::default() - } - .new_candidate_host()?, - ); - - for remote in &[ - Arc::clone(&relay_remote), - Arc::clone(&srflx_remote), - Arc::clone(&prflx_remote), - Arc::clone(&host_remote), - ] { - let p = a.internal.find_pair(&host_local, remote).await; - - if p.is_none() { - a.internal - .add_pair(Arc::clone(&host_local), Arc::clone(remote)) - .await; - } - } - - { - if let Some(p) = a.internal.find_pair(&host_local, &prflx_remote).await { - p.state - .store(CandidatePairState::Failed as u8, Ordering::SeqCst); - } - } - - let stats = a.get_candidate_pairs_stats().await; - assert_eq!(stats.len(), 4, "expected 4 candidate pairs stats"); - - let (mut relay_pair_stat, mut srflx_pair_stat, mut prflx_pair_stat, mut host_pair_stat) = ( - CandidatePairStats::default(), - CandidatePairStats::default(), - CandidatePairStats::default(), - CandidatePairStats::default(), - ); - - for cps in stats { - assert_eq!( - cps.local_candidate_id, - host_local.id(), - "invalid local candidate id" - ); - - if cps.remote_candidate_id == relay_remote.id() { - relay_pair_stat = cps; - } else if cps.remote_candidate_id == srflx_remote.id() { - srflx_pair_stat = cps; - } else if cps.remote_candidate_id == prflx_remote.id() { - prflx_pair_stat = cps; - } else if cps.remote_candidate_id == host_remote.id() { - host_pair_stat = cps; - } else { - panic!("invalid remote candidate ID"); - } - } - - assert_eq!( - relay_pair_stat.remote_candidate_id, - relay_remote.id(), - "missing host-relay pair stat" - ); - assert_eq!( - srflx_pair_stat.remote_candidate_id, - srflx_remote.id(), - "missing host-srflx pair stat" - ); - assert_eq!( - prflx_pair_stat.remote_candidate_id, - prflx_remote.id(), - "missing host-prflx pair stat" - ); - assert_eq!( - host_pair_stat.remote_candidate_id, - host_remote.id(), - "missing host-host pair stat" - ); - assert_eq!( - prflx_pair_stat.state, - CandidatePairState::Failed, - "expected host-prfflx pair to have state failed, it has state {} instead", - prflx_pair_stat.state - ); - - a.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_local_candidate_stats() -> Result<()> { - let a = Agent::new(AgentConfig::default()).await?; - - let host_local: Arc = Arc::new( - CandidateHostConfig { - base_config: CandidateBaseConfig { - network: "udp".to_owned(), - address: "192.168.1.1".to_owned(), - port: 19216, - component: 1, - ..Default::default() - }, - ..Default::default() - } - .new_candidate_host()?, - ); - - let srflx_local: Arc = Arc::new( - CandidateServerReflexiveConfig { - base_config: CandidateBaseConfig { - network: "udp".to_owned(), - address: "192.168.1.1".to_owned(), - port: 19217, - component: 1, - ..Default::default() - }, - rel_addr: "4.3.2.1".to_owned(), - rel_port: 43212, - } - .new_candidate_server_reflexive()?, - ); - - { 
- let mut local_candidates = a.internal.local_candidates.lock().await; - local_candidates.insert( - NetworkType::Udp4, - vec![Arc::clone(&host_local), Arc::clone(&srflx_local)], - ); - } - - let local_stats = a.get_local_candidates_stats().await; - assert_eq!( - local_stats.len(), - 2, - "expected 2 local candidates stats, got {} instead", - local_stats.len() - ); - - let (mut host_local_stat, mut srflx_local_stat) = - (CandidateStats::default(), CandidateStats::default()); - for stats in local_stats { - let candidate = if stats.id == host_local.id() { - host_local_stat = stats.clone(); - Arc::clone(&host_local) - } else if stats.id == srflx_local.id() { - srflx_local_stat = stats.clone(); - Arc::clone(&srflx_local) - } else { - panic!("invalid local candidate ID"); - }; - - assert_eq!( - stats.candidate_type, - candidate.candidate_type(), - "invalid stats CandidateType" - ); - assert_eq!( - stats.priority, - candidate.priority(), - "invalid stats CandidateType" - ); - assert_eq!(stats.ip, candidate.address(), "invalid stats IP"); - } - - assert_eq!( - host_local_stat.id, - host_local.id(), - "missing host local stat" - ); - assert_eq!( - srflx_local_stat.id, - srflx_local.id(), - "missing srflx local stat" - ); - - a.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_remote_candidate_stats() -> Result<()> { - let a = Agent::new(AgentConfig::default()).await?; - - let relay_remote: Arc = Arc::new( - CandidateRelayConfig { - base_config: CandidateBaseConfig { - network: "udp".to_owned(), - address: "1.2.3.4".to_owned(), - port: 12340, - component: 1, - ..Default::default() - }, - rel_addr: "4.3.2.1".to_owned(), - rel_port: 43210, - ..Default::default() - } - .new_candidate_relay()?, - ); - - let srflx_remote: Arc = Arc::new( - CandidateServerReflexiveConfig { - base_config: CandidateBaseConfig { - network: "udp".to_owned(), - address: "10.10.10.2".to_owned(), - port: 19218, - component: 1, - ..Default::default() - }, - rel_addr: "4.3.2.1".to_owned(), - rel_port: 43212, - } - .new_candidate_server_reflexive()?, - ); - - let prflx_remote: Arc = Arc::new( - CandidatePeerReflexiveConfig { - base_config: CandidateBaseConfig { - network: "udp".to_owned(), - address: "10.10.10.2".to_owned(), - port: 19217, - component: 1, - ..Default::default() - }, - rel_addr: "4.3.2.1".to_owned(), - rel_port: 43211, - } - .new_candidate_peer_reflexive()?, - ); - - let host_remote: Arc = Arc::new( - CandidateHostConfig { - base_config: CandidateBaseConfig { - network: "udp".to_owned(), - address: "1.2.3.5".to_owned(), - port: 12350, - component: 1, - ..Default::default() - }, - ..Default::default() - } - .new_candidate_host()?, - ); - - { - let mut remote_candidates = a.internal.remote_candidates.lock().await; - remote_candidates.insert( - NetworkType::Udp4, - vec![ - Arc::clone(&relay_remote), - Arc::clone(&srflx_remote), - Arc::clone(&prflx_remote), - Arc::clone(&host_remote), - ], - ); - } - - let remote_stats = a.get_remote_candidates_stats().await; - assert_eq!( - remote_stats.len(), - 4, - "expected 4 remote candidates stats, got {} instead", - remote_stats.len() - ); - - let (mut relay_remote_stat, mut srflx_remote_stat, mut prflx_remote_stat, mut host_remote_stat) = ( - CandidateStats::default(), - CandidateStats::default(), - CandidateStats::default(), - CandidateStats::default(), - ); - for stats in remote_stats { - let candidate = if stats.id == relay_remote.id() { - relay_remote_stat = stats.clone(); - Arc::clone(&relay_remote) - } else if stats.id == srflx_remote.id() { - 
srflx_remote_stat = stats.clone(); - Arc::clone(&srflx_remote) - } else if stats.id == prflx_remote.id() { - prflx_remote_stat = stats.clone(); - Arc::clone(&prflx_remote) - } else if stats.id == host_remote.id() { - host_remote_stat = stats.clone(); - Arc::clone(&host_remote) - } else { - panic!("invalid remote candidate ID"); - }; - - assert_eq!( - stats.candidate_type, - candidate.candidate_type(), - "invalid stats CandidateType" - ); - assert_eq!( - stats.priority, - candidate.priority(), - "invalid stats CandidateType" - ); - assert_eq!(stats.ip, candidate.address(), "invalid stats IP"); - } - - assert_eq!( - relay_remote_stat.id, - relay_remote.id(), - "missing relay remote stat" - ); - assert_eq!( - srflx_remote_stat.id, - srflx_remote.id(), - "missing srflx remote stat" - ); - assert_eq!( - prflx_remote_stat.id, - prflx_remote.id(), - "missing prflx remote stat" - ); - assert_eq!( - host_remote_stat.id, - host_remote.id(), - "missing host remote stat" - ); - - a.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_init_ext_ip_mapping() -> Result<()> { - // a.extIPMapper should be nil by default - let a = Agent::new(AgentConfig::default()).await?; - assert!( - a.ext_ip_mapper.is_none(), - "a.extIPMapper should be none by default" - ); - a.close().await?; - - // a.extIPMapper should be nil when NAT1To1IPs is a non-nil empty array - let a = Agent::new(AgentConfig { - nat_1to1_ips: vec![], - nat_1to1_ip_candidate_type: CandidateType::Host, - ..Default::default() - }) - .await?; - assert!( - a.ext_ip_mapper.is_none(), - "a.extIPMapper should be none by default" - ); - a.close().await?; - - // NewAgent should return an error when 1:1 NAT for host candidate is enabled - // but the candidate type does not appear in the CandidateTypes. - if let Err(err) = Agent::new(AgentConfig { - nat_1to1_ips: vec!["1.2.3.4".to_owned()], - nat_1to1_ip_candidate_type: CandidateType::Host, - candidate_types: vec![CandidateType::Relay], - ..Default::default() - }) - .await - { - assert_eq!( - Error::ErrIneffectiveNat1to1IpMappingHost, - err, - "Unexpected error: {err}" - ); - } else { - panic!("expected error, but got ok"); - } - - // NewAgent should return an error when 1:1 NAT for srflx candidate is enabled - // but the candidate type does not appear in the CandidateTypes. - if let Err(err) = Agent::new(AgentConfig { - nat_1to1_ips: vec!["1.2.3.4".to_owned()], - nat_1to1_ip_candidate_type: CandidateType::ServerReflexive, - candidate_types: vec![CandidateType::Relay], - ..Default::default() - }) - .await - { - assert_eq!( - Error::ErrIneffectiveNat1to1IpMappingSrflx, - err, - "Unexpected error: {err}" - ); - } else { - panic!("expected error, but got ok"); - } - - // NewAgent should return an error when 1:1 NAT for host candidate is enabled - // along with mDNS with MulticastDNSModeQueryAndGather - if let Err(err) = Agent::new(AgentConfig { - nat_1to1_ips: vec!["1.2.3.4".to_owned()], - nat_1to1_ip_candidate_type: CandidateType::Host, - multicast_dns_mode: MulticastDnsMode::QueryAndGather, - ..Default::default() - }) - .await - { - assert_eq!( - Error::ErrMulticastDnsWithNat1to1IpMapping, - err, - "Unexpected error: {err}" - ); - } else { - panic!("expected error, but got ok"); - } - - // NewAgent should return if newExternalIPMapper() returns an error. 
- if let Err(err) = Agent::new(AgentConfig { - nat_1to1_ips: vec!["bad.2.3.4".to_owned()], // bad IP - nat_1to1_ip_candidate_type: CandidateType::Host, - ..Default::default() - }) - .await - { - assert_eq!( - Error::ErrInvalidNat1to1IpMapping, - err, - "Unexpected error: {err}" - ); - } else { - panic!("expected error, but got ok"); - } - - Ok(()) -} - -#[tokio::test] -async fn test_binding_request_timeout() -> Result<()> { - const EXPECTED_REMOVAL_COUNT: usize = 2; - - let a = Agent::new(AgentConfig::default()).await?; - - let now = Instant::now(); - { - { - let mut pending_binding_requests = a.internal.pending_binding_requests.lock().await; - pending_binding_requests.push(BindingRequest { - timestamp: now, // valid - ..Default::default() - }); - pending_binding_requests.push(BindingRequest { - timestamp: now.sub(Duration::from_millis(3900)), // valid - ..Default::default() - }); - pending_binding_requests.push(BindingRequest { - timestamp: now.sub(Duration::from_millis(4100)), // invalid - ..Default::default() - }); - pending_binding_requests.push(BindingRequest { - timestamp: now.sub(Duration::from_secs(75)), // invalid - ..Default::default() - }); - } - - a.internal.invalidate_pending_binding_requests(now).await; - { - let pending_binding_requests = a.internal.pending_binding_requests.lock().await; - assert_eq!(pending_binding_requests.len(), EXPECTED_REMOVAL_COUNT, "Binding invalidation due to timeout did not remove the correct number of binding requests") - } - } - - a.close().await?; - - Ok(()) -} - -// test_agent_credentials checks if local username fragments and passwords (if set) meet RFC standard -// and ensure it's backwards compatible with previous versions of the pion/ice -#[tokio::test] -async fn test_agent_credentials() -> Result<()> { - // Agent should not require any of the usernames and password to be set - // If set, they should follow the default 16/128 bits random number generator strategy - - let a = Agent::new(AgentConfig::default()).await?; - { - let ufrag_pwd = a.internal.ufrag_pwd.lock().await; - assert!(ufrag_pwd.local_ufrag.as_bytes().len() * 8 >= 24); - assert!(ufrag_pwd.local_pwd.as_bytes().len() * 8 >= 128); - } - a.close().await?; - - // Should honor RFC standards - // Local values MUST be unguessable, with at least 128 bits of - // random number generator output used to generate the password, and - // at least 24 bits of output to generate the username fragment. 
- - if let Err(err) = Agent::new(AgentConfig { - local_ufrag: "xx".to_owned(), - ..Default::default() - }) - .await - { - assert_eq!(Error::ErrLocalUfragInsufficientBits, err); - } else { - panic!("expected error, but got ok"); - } - - if let Err(err) = Agent::new(AgentConfig { - local_pwd: "xxxxxx".to_owned(), - ..Default::default() - }) - .await - { - assert_eq!(Error::ErrLocalPwdInsufficientBits, err); - } else { - panic!("expected error, but got ok"); - } - - Ok(()) -} - -// Assert that Agent on Failure deletes all existing candidates -// User can then do an ICE Restart to bring agent back -#[tokio::test] -async fn test_connection_state_failed_delete_all_candidates() -> Result<()> { - let one_second = Duration::from_secs(1); - let keepalive_interval = Duration::from_secs(0); - - let cfg0 = AgentConfig { - network_types: supported_network_types(), - disconnected_timeout: Some(one_second), - failed_timeout: Some(one_second), - keepalive_interval: Some(keepalive_interval), - ..Default::default() - }; - let cfg1 = AgentConfig { - network_types: supported_network_types(), - disconnected_timeout: Some(one_second), - failed_timeout: Some(one_second), - keepalive_interval: Some(keepalive_interval), - ..Default::default() - }; - - let a_agent = Arc::new(Agent::new(cfg0).await?); - let b_agent = Arc::new(Agent::new(cfg1).await?); - - let (is_failed_tx, mut is_failed_rx) = mpsc::channel::<()>(1); - let is_failed_tx = Arc::new(Mutex::new(Some(is_failed_tx))); - a_agent.on_connection_state_change(Box::new(move |c: ConnectionState| { - let is_failed_tx_clone = Arc::clone(&is_failed_tx); - Box::pin(async move { - if c == ConnectionState::Failed { - let mut tx = is_failed_tx_clone.lock().await; - tx.take(); - } - }) - })); - - connect_with_vnet(&a_agent, &b_agent).await?; - let _ = is_failed_rx.recv().await; - - { - { - let remote_candidates = a_agent.internal.remote_candidates.lock().await; - assert_eq!(remote_candidates.len(), 0); - } - { - let local_candidates = a_agent.internal.local_candidates.lock().await; - assert_eq!(local_candidates.len(), 0); - } - } - - a_agent.close().await?; - b_agent.close().await?; - - Ok(()) -} - -// Assert that the ICE Agent can go directly from Connecting -> Failed on both sides -#[tokio::test] -async fn test_connection_state_connecting_to_failed() -> Result<()> { - let one_second = Duration::from_secs(1); - let keepalive_interval = Duration::from_secs(0); - - let cfg0 = AgentConfig { - disconnected_timeout: Some(one_second), - failed_timeout: Some(one_second), - keepalive_interval: Some(keepalive_interval), - ..Default::default() - }; - let cfg1 = AgentConfig { - disconnected_timeout: Some(one_second), - failed_timeout: Some(one_second), - keepalive_interval: Some(keepalive_interval), - ..Default::default() - }; - - let a_agent = Arc::new(Agent::new(cfg0).await?); - let b_agent = Arc::new(Agent::new(cfg1).await?); - - let is_failed = WaitGroup::new(); - let is_checking = WaitGroup::new(); - - let connection_state_check = move |wf: Worker, wc: Worker| { - let wf = Arc::new(Mutex::new(Some(wf))); - let wc = Arc::new(Mutex::new(Some(wc))); - let hdlr_fn: OnConnectionStateChangeHdlrFn = Box::new(move |c: ConnectionState| { - let wf_clone = Arc::clone(&wf); - let wc_clone = Arc::clone(&wc); - Box::pin(async move { - if c == ConnectionState::Failed { - let mut f = wf_clone.lock().await; - f.take(); - } else if c == ConnectionState::Checking { - let mut c = wc_clone.lock().await; - c.take(); - } else if c == ConnectionState::Connected || c == ConnectionState::Completed { - 
panic!("Unexpected ConnectionState: {c}"); - } - }) - }); - hdlr_fn - }; - - let (wf1, wc1) = (is_failed.worker(), is_checking.worker()); - a_agent.on_connection_state_change(connection_state_check(wf1, wc1)); - - let (wf2, wc2) = (is_failed.worker(), is_checking.worker()); - b_agent.on_connection_state_change(connection_state_check(wf2, wc2)); - - let agent_a = Arc::clone(&a_agent); - tokio::spawn(async move { - let (_cancel_tx, cancel_rx) = mpsc::channel(1); - let result = agent_a - .accept(cancel_rx, "InvalidFrag".to_owned(), "InvalidPwd".to_owned()) - .await; - assert!(result.is_err()); - }); - - let agent_b = Arc::clone(&b_agent); - tokio::spawn(async move { - let (_cancel_tx, cancel_rx) = mpsc::channel(1); - let result = agent_b - .dial(cancel_rx, "InvalidFrag".to_owned(), "InvalidPwd".to_owned()) - .await; - assert!(result.is_err()); - }); - - is_checking.wait().await; - is_failed.wait().await; - - a_agent.close().await?; - b_agent.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_agent_restart_during_gather() -> Result<()> { - //"Restart During Gather" - - let agent = Agent::new(AgentConfig::default()).await?; - - agent - .gathering_state - .store(GatheringState::Gathering as u8, Ordering::SeqCst); - - if let Err(err) = agent.restart("".to_owned(), "".to_owned()).await { - assert_eq!(Error::ErrRestartWhenGathering, err); - } else { - panic!("expected error, but got ok"); - } - - agent.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_agent_restart_when_closed() -> Result<()> { - //"Restart When Closed" - - let agent = Agent::new(AgentConfig::default()).await?; - agent.close().await?; - - if let Err(err) = agent.restart("".to_owned(), "".to_owned()).await { - assert_eq!(Error::ErrClosed, err); - } else { - panic!("expected error, but got ok"); - } - - Ok(()) -} - -#[tokio::test] -async fn test_agent_restart_one_side() -> Result<()> { - let one_second = Duration::from_secs(1); - - //"Restart One Side" - let (_, _, agent_a, agent_b) = pipe( - Some(AgentConfig { - disconnected_timeout: Some(one_second), - failed_timeout: Some(one_second), - ..Default::default() - }), - Some(AgentConfig { - disconnected_timeout: Some(one_second), - failed_timeout: Some(one_second), - ..Default::default() - }), - ) - .await?; - - let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1); - let cancel_tx = Arc::new(Mutex::new(Some(cancel_tx))); - agent_b.on_connection_state_change(Box::new(move |c: ConnectionState| { - let cancel_tx_clone = Arc::clone(&cancel_tx); - Box::pin(async move { - if c == ConnectionState::Failed || c == ConnectionState::Disconnected { - let mut tx = cancel_tx_clone.lock().await; - tx.take(); - } - }) - })); - - agent_a.restart("".to_owned(), "".to_owned()).await?; - - let _ = cancel_rx.recv().await; - - agent_a.close().await?; - agent_b.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_agent_restart_both_side() -> Result<()> { - let one_second = Duration::from_secs(1); - //"Restart Both Sides" - - // Get all addresses of candidates concatenated - let generate_candidate_address_strings = - |res: Result>>| -> String { - assert!(res.is_ok()); - - let mut out = String::new(); - if let Ok(candidates) = res { - for c in candidates { - out += c.address().as_str(); - out += ":"; - out += c.port().to_string().as_str(); - } - } - out - }; - - // Store the original candidates, confirm that after we reconnect we have new pairs - let (_, _, agent_a, agent_b) = pipe( - Some(AgentConfig { - disconnected_timeout: Some(one_second), - failed_timeout: 
Some(one_second), - ..Default::default() - }), - Some(AgentConfig { - disconnected_timeout: Some(one_second), - failed_timeout: Some(one_second), - ..Default::default() - }), - ) - .await?; - - let conn_afirst_candidates = - generate_candidate_address_strings(agent_a.get_local_candidates().await); - let conn_bfirst_candidates = - generate_candidate_address_strings(agent_b.get_local_candidates().await); - - let (a_notifier, mut a_connected) = on_connected(); - agent_a.on_connection_state_change(a_notifier); - - let (b_notifier, mut b_connected) = on_connected(); - agent_b.on_connection_state_change(b_notifier); - - // Restart and Re-Signal - agent_a.restart("".to_owned(), "".to_owned()).await?; - agent_b.restart("".to_owned(), "".to_owned()).await?; - - // Exchange Candidates and Credentials - let (ufrag, pwd) = agent_b.get_local_user_credentials().await; - agent_a.set_remote_credentials(ufrag, pwd).await?; - - let (ufrag, pwd) = agent_a.get_local_user_credentials().await; - agent_b.set_remote_credentials(ufrag, pwd).await?; - - gather_and_exchange_candidates(&agent_a, &agent_b).await?; - - // Wait until both have gone back to connected - let _ = a_connected.recv().await; - let _ = b_connected.recv().await; - - // Assert that we have new candidates each time - assert_ne!( - conn_afirst_candidates, - generate_candidate_address_strings(agent_a.get_local_candidates().await) - ); - assert_ne!( - conn_bfirst_candidates, - generate_candidate_address_strings(agent_b.get_local_candidates().await) - ); - - agent_a.close().await?; - agent_b.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_get_remote_credentials() -> Result<()> { - let a = Agent::new(AgentConfig::default()).await?; - - let (remote_ufrag, remote_pwd) = { - let mut ufrag_pwd = a.internal.ufrag_pwd.lock().await; - "remoteUfrag".clone_into(&mut ufrag_pwd.remote_ufrag); - "remotePwd".clone_into(&mut ufrag_pwd.remote_pwd); - ( - ufrag_pwd.remote_ufrag.to_owned(), - ufrag_pwd.remote_pwd.to_owned(), - ) - }; - - let (actual_ufrag, actual_pwd) = a.get_remote_user_credentials().await; - - assert_eq!(actual_ufrag, remote_ufrag); - assert_eq!(actual_pwd, remote_pwd); - - a.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_close_in_connection_state_callback() -> Result<()> { - let disconnected_duration = Duration::from_secs(1); - let failed_duration = Duration::from_secs(1); - let keepalive_interval = Duration::from_secs(0); - - let cfg0 = AgentConfig { - urls: vec![], - network_types: supported_network_types(), - disconnected_timeout: Some(disconnected_duration), - failed_timeout: Some(failed_duration), - keepalive_interval: Some(keepalive_interval), - check_interval: Duration::from_millis(500), - ..Default::default() - }; - - let cfg1 = AgentConfig { - urls: vec![], - network_types: supported_network_types(), - disconnected_timeout: Some(disconnected_duration), - failed_timeout: Some(failed_duration), - keepalive_interval: Some(keepalive_interval), - check_interval: Duration::from_millis(500), - ..Default::default() - }; - - let a_agent = Arc::new(Agent::new(cfg0).await?); - let b_agent = Arc::new(Agent::new(cfg1).await?); - - let (is_closed_tx, mut is_closed_rx) = mpsc::channel::<()>(1); - let (is_connected_tx, mut is_connected_rx) = mpsc::channel::<()>(1); - let is_closed_tx = Arc::new(Mutex::new(Some(is_closed_tx))); - let is_connected_tx = Arc::new(Mutex::new(Some(is_connected_tx))); - a_agent.on_connection_state_change(Box::new(move |c: ConnectionState| { - let is_closed_tx_clone = Arc::clone(&is_closed_tx); - let 
is_connected_tx_clone = Arc::clone(&is_connected_tx); - Box::pin(async move { - if c == ConnectionState::Connected { - let mut tx = is_connected_tx_clone.lock().await; - tx.take(); - } else if c == ConnectionState::Closed { - let mut tx = is_closed_tx_clone.lock().await; - tx.take(); - } - }) - })); - - connect_with_vnet(&a_agent, &b_agent).await?; - - let _ = is_connected_rx.recv().await; - a_agent.close().await?; - - let _ = is_closed_rx.recv().await; - b_agent.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_run_task_in_connection_state_callback() -> Result<()> { - let one_second = Duration::from_secs(1); - let keepalive_interval = Duration::from_secs(0); - - let cfg0 = AgentConfig { - urls: vec![], - network_types: supported_network_types(), - disconnected_timeout: Some(one_second), - failed_timeout: Some(one_second), - keepalive_interval: Some(keepalive_interval), - check_interval: Duration::from_millis(50), - ..Default::default() - }; - - let cfg1 = AgentConfig { - urls: vec![], - network_types: supported_network_types(), - disconnected_timeout: Some(one_second), - failed_timeout: Some(one_second), - keepalive_interval: Some(keepalive_interval), - check_interval: Duration::from_millis(50), - ..Default::default() - }; - - let a_agent = Arc::new(Agent::new(cfg0).await?); - let b_agent = Arc::new(Agent::new(cfg1).await?); - - let (is_complete_tx, mut is_complete_rx) = mpsc::channel::<()>(1); - let is_complete_tx = Arc::new(Mutex::new(Some(is_complete_tx))); - a_agent.on_connection_state_change(Box::new(move |c: ConnectionState| { - let is_complete_tx_clone = Arc::clone(&is_complete_tx); - Box::pin(async move { - if c == ConnectionState::Connected { - let mut tx = is_complete_tx_clone.lock().await; - tx.take(); - } - }) - })); - - connect_with_vnet(&a_agent, &b_agent).await?; - - let _ = is_complete_rx.recv().await; - let _ = a_agent.get_local_user_credentials().await; - a_agent.restart("".to_owned(), "".to_owned()).await?; - - a_agent.close().await?; - b_agent.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_run_task_in_selected_candidate_pair_change_callback() -> Result<()> { - let one_second = Duration::from_secs(1); - let keepalive_interval = Duration::from_secs(0); - - let cfg0 = AgentConfig { - urls: vec![], - network_types: supported_network_types(), - disconnected_timeout: Some(one_second), - failed_timeout: Some(one_second), - keepalive_interval: Some(keepalive_interval), - check_interval: Duration::from_millis(50), - ..Default::default() - }; - - let cfg1 = AgentConfig { - urls: vec![], - network_types: supported_network_types(), - disconnected_timeout: Some(one_second), - failed_timeout: Some(one_second), - keepalive_interval: Some(keepalive_interval), - check_interval: Duration::from_millis(50), - ..Default::default() - }; - - let a_agent = Arc::new(Agent::new(cfg0).await?); - let b_agent = Arc::new(Agent::new(cfg1).await?); - - let (is_tested_tx, mut is_tested_rx) = mpsc::channel::<()>(1); - let is_tested_tx = Arc::new(Mutex::new(Some(is_tested_tx))); - a_agent.on_selected_candidate_pair_change(Box::new( - move |_: &Arc, _: &Arc| { - let is_tested_tx_clone = Arc::clone(&is_tested_tx); - Box::pin(async move { - let mut tx = is_tested_tx_clone.lock().await; - tx.take(); - }) - }, - )); - - let (is_complete_tx, mut is_complete_rx) = mpsc::channel::<()>(1); - let is_complete_tx = Arc::new(Mutex::new(Some(is_complete_tx))); - a_agent.on_connection_state_change(Box::new(move |c: ConnectionState| { - let is_complete_tx_clone = 
Arc::clone(&is_complete_tx); - Box::pin(async move { - if c == ConnectionState::Connected { - let mut tx = is_complete_tx_clone.lock().await; - tx.take(); - } - }) - })); - - connect_with_vnet(&a_agent, &b_agent).await?; - - let _ = is_complete_rx.recv().await; - let _ = is_tested_rx.recv().await; - - let _ = a_agent.get_local_user_credentials().await; - - a_agent.close().await?; - b_agent.close().await?; - - Ok(()) -} - -// Assert that a Lite agent goes to disconnected and failed -#[tokio::test] -async fn test_lite_lifecycle() -> Result<()> { - let (a_notifier, mut a_connected_rx) = on_connected(); - - let a_agent = Arc::new( - Agent::new(AgentConfig { - network_types: supported_network_types(), - multicast_dns_mode: MulticastDnsMode::Disabled, - ..Default::default() - }) - .await?, - ); - - a_agent.on_connection_state_change(a_notifier); - - let disconnected_duration = Duration::from_secs(1); - let failed_duration = Duration::from_secs(1); - let keepalive_interval = Duration::from_secs(0); - - let b_agent = Arc::new( - Agent::new(AgentConfig { - lite: true, - candidate_types: vec![CandidateType::Host], - network_types: supported_network_types(), - multicast_dns_mode: MulticastDnsMode::Disabled, - disconnected_timeout: Some(disconnected_duration), - failed_timeout: Some(failed_duration), - keepalive_interval: Some(keepalive_interval), - check_interval: Duration::from_millis(500), - ..Default::default() - }) - .await?, - ); - - let (b_connected_tx, mut b_connected_rx) = mpsc::channel::<()>(1); - let (b_disconnected_tx, mut b_disconnected_rx) = mpsc::channel::<()>(1); - let (b_failed_tx, mut b_failed_rx) = mpsc::channel::<()>(1); - let b_connected_tx = Arc::new(Mutex::new(Some(b_connected_tx))); - let b_disconnected_tx = Arc::new(Mutex::new(Some(b_disconnected_tx))); - let b_failed_tx = Arc::new(Mutex::new(Some(b_failed_tx))); - - b_agent.on_connection_state_change(Box::new(move |c: ConnectionState| { - let b_connected_tx_clone = Arc::clone(&b_connected_tx); - let b_disconnected_tx_clone = Arc::clone(&b_disconnected_tx); - let b_failed_tx_clone = Arc::clone(&b_failed_tx); - - Box::pin(async move { - if c == ConnectionState::Connected { - let mut tx = b_connected_tx_clone.lock().await; - tx.take(); - } else if c == ConnectionState::Disconnected { - let mut tx = b_disconnected_tx_clone.lock().await; - tx.take(); - } else if c == ConnectionState::Failed { - let mut tx = b_failed_tx_clone.lock().await; - tx.take(); - } - }) - })); - - connect_with_vnet(&b_agent, &a_agent).await?; - - let _ = a_connected_rx.recv().await; - let _ = b_connected_rx.recv().await; - a_agent.close().await?; - - let _ = b_disconnected_rx.recv().await; - let _ = b_failed_rx.recv().await; - - b_agent.close().await?; - - Ok(()) -} diff --git a/ice/src/agent/agent_transport.rs b/ice/src/agent/agent_transport.rs deleted file mode 100644 index 6b213554d..000000000 --- a/ice/src/agent/agent_transport.rs +++ /dev/null @@ -1,251 +0,0 @@ -use std::io; -use std::sync::atomic::Ordering; - -use arc_swap::ArcSwapOption; -use async_trait::async_trait; -use portable_atomic::AtomicBool; -use util::Conn; - -use super::*; -use crate::error::*; - -impl Agent { - /// Connects to the remote agent, acting as the controlling ice agent. - /// The method blocks until at least one ice candidate pair has successfully connected. - /// - /// The operation will be cancelled if `cancel_rx` either receives a message or its channel - /// closes. 
-    pub async fn dial(
-        &self,
-        mut cancel_rx: mpsc::Receiver<()>,
-        remote_ufrag: String,
-        remote_pwd: String,
-    ) -> Result<Arc<impl Conn>> {
-        let (on_connected_rx, agent_conn) = {
-            self.internal
-                .start_connectivity_checks(true, remote_ufrag, remote_pwd)
-                .await?;
-
-            let mut on_connected_rx = self.internal.on_connected_rx.lock().await;
-            (
-                on_connected_rx.take(),
-                Arc::clone(&self.internal.agent_conn),
-            )
-        };
-
-        if let Some(mut on_connected_rx) = on_connected_rx {
-            // block until pair selected
-            tokio::select! {
-                _ = on_connected_rx.recv() => {},
-                _ = cancel_rx.recv() => {
-                    return Err(Error::ErrCanceledByCaller);
-                }
-            }
-        }
-        Ok(agent_conn)
-    }
-
-    /// Connects to the remote agent, acting as the controlled ice agent.
-    /// The method blocks until at least one ice candidate pair has successfully connected.
-    ///
-    /// The operation will be cancelled if `cancel_rx` either receives a message or its channel
-    /// closes.
-    pub async fn accept(
-        &self,
-        mut cancel_rx: mpsc::Receiver<()>,
-        remote_ufrag: String,
-        remote_pwd: String,
-    ) -> Result<Arc<impl Conn>> {
-        let (on_connected_rx, agent_conn) = {
-            self.internal
-                .start_connectivity_checks(false, remote_ufrag, remote_pwd)
-                .await?;
-
-            let mut on_connected_rx = self.internal.on_connected_rx.lock().await;
-            (
-                on_connected_rx.take(),
-                Arc::clone(&self.internal.agent_conn),
-            )
-        };
-
-        if let Some(mut on_connected_rx) = on_connected_rx {
-            // block until pair selected
-            tokio::select! {
-                _ = on_connected_rx.recv() => {},
-                _ = cancel_rx.recv() => {
-                    return Err(Error::ErrCanceledByCaller);
-                }
-            }
-        }
-
-        Ok(agent_conn)
-    }
-}
-
-pub(crate) struct AgentConn {
-    pub(crate) selected_pair: ArcSwapOption<CandidatePair>,
-    pub(crate) checklist: Mutex<Vec<Arc<CandidatePair>>>,
-
-    pub(crate) buffer: Buffer,
-    pub(crate) bytes_received: AtomicUsize,
-    pub(crate) bytes_sent: AtomicUsize,
-    pub(crate) done: AtomicBool,
-}
-
-impl AgentConn {
-    pub(crate) fn new() -> Self {
-        Self {
-            selected_pair: ArcSwapOption::empty(),
-            checklist: Mutex::new(vec![]),
-            // Make sure the buffer doesn't grow indefinitely.
-            // NOTE: We actually won't get anywhere close to this limit.
-            // SRTP will constantly read from the endpoint and drop packets if it's full.
-            buffer: Buffer::new(0, MAX_BUFFER_SIZE),
-            bytes_received: AtomicUsize::new(0),
-            bytes_sent: AtomicUsize::new(0),
-            done: AtomicBool::new(false),
-        }
-    }
-    pub(crate) fn get_selected_pair(&self) -> Option<Arc<CandidatePair>> {
-        self.selected_pair.load().clone()
-    }
-
-    pub(crate) async fn get_best_available_candidate_pair(&self) -> Option<Arc<CandidatePair>> {
-        let mut best: Option<&Arc<CandidatePair>> = None;
-
-        let checklist = self.checklist.lock().await;
-        for p in &*checklist {
-            if p.state.load(Ordering::SeqCst) == CandidatePairState::Failed as u8 {
-                continue;
-            }
-
-            if let Some(b) = &mut best {
-                if b.priority() < p.priority() {
-                    *b = p;
-                }
-            } else {
-                best = Some(p);
-            }
-        }
-
-        best.cloned()
-    }
-
-    pub(crate) async fn get_best_valid_candidate_pair(&self) -> Option<Arc<CandidatePair>> {
-        let mut best: Option<&Arc<CandidatePair>> = None;
-
-        let checklist = self.checklist.lock().await;
-        for p in &*checklist {
-            if p.state.load(Ordering::SeqCst) != CandidatePairState::Succeeded as u8 {
-                continue;
-            }
-
-            if let Some(b) = &mut best {
-                if b.priority() < p.priority() {
-                    *b = p;
-                }
-            } else {
-                best = Some(p);
-            }
-        }
-
-        best.cloned()
-    }
-
-    /// Returns the number of bytes sent.
-    pub fn bytes_sent(&self) -> usize {
-        self.bytes_sent.load(Ordering::SeqCst)
-    }
-
-    /// Returns the number of bytes received.
- pub fn bytes_received(&self) -> usize { - self.bytes_received.load(Ordering::SeqCst) - } -} - -#[async_trait] -impl Conn for AgentConn { - async fn connect(&self, _addr: SocketAddr) -> std::result::Result<(), util::Error> { - Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) - } - - async fn recv(&self, buf: &mut [u8]) -> std::result::Result { - if self.done.load(Ordering::SeqCst) { - return Err(io::Error::new(io::ErrorKind::Other, "Conn is closed").into()); - } - - let n = match self.buffer.read(buf, None).await { - Ok(n) => n, - Err(err) => return Err(io::Error::new(io::ErrorKind::Other, err.to_string()).into()), - }; - self.bytes_received.fetch_add(n, Ordering::SeqCst); - - Ok(n) - } - - async fn recv_from( - &self, - buf: &mut [u8], - ) -> std::result::Result<(usize, SocketAddr), util::Error> { - if let Some(raddr) = self.remote_addr() { - let n = self.recv(buf).await?; - Ok((n, raddr)) - } else { - Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) - } - } - - async fn send(&self, buf: &[u8]) -> std::result::Result { - if self.done.load(Ordering::SeqCst) { - return Err(io::Error::new(io::ErrorKind::Other, "Conn is closed").into()); - } - - if is_message(buf) { - return Err(util::Error::Other("ErrIceWriteStunMessage".into())); - } - - let result = if let Some(pair) = self.get_selected_pair() { - pair.write(buf).await - } else if let Some(pair) = self.get_best_available_candidate_pair().await { - pair.write(buf).await - } else { - Ok(0) - }; - - match result { - Ok(n) => { - self.bytes_sent.fetch_add(buf.len(), Ordering::SeqCst); - Ok(n) - } - Err(err) => Err(io::Error::new(io::ErrorKind::Other, err.to_string()).into()), - } - } - - async fn send_to( - &self, - _buf: &[u8], - _target: SocketAddr, - ) -> std::result::Result { - Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) - } - - fn local_addr(&self) -> std::result::Result { - if let Some(pair) = self.get_selected_pair() { - Ok(pair.local.addr()) - } else { - Err(io::Error::new(io::ErrorKind::AddrNotAvailable, "Addr Not Available").into()) - } - } - - fn remote_addr(&self) -> Option { - self.get_selected_pair().map(|pair| pair.remote.addr()) - } - - async fn close(&self) -> std::result::Result<(), util::Error> { - Ok(()) - } - - fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) { - self - } -} diff --git a/ice/src/agent/agent_transport_test.rs b/ice/src/agent/agent_transport_test.rs deleted file mode 100644 index 8ab909df6..000000000 --- a/ice/src/agent/agent_transport_test.rs +++ /dev/null @@ -1,125 +0,0 @@ -use util::vnet::*; -use util::Conn; -use waitgroup::WaitGroup; - -use super::agent_vnet_test::*; -use super::*; -use crate::agent::agent_transport::AgentConn; - -pub(crate) async fn pipe( - default_config0: Option, - default_config1: Option, -) -> Result<(Arc, Arc, Arc, Arc)> { - let (a_notifier, mut a_connected) = on_connected(); - let (b_notifier, mut b_connected) = on_connected(); - - let mut cfg0 = default_config0.unwrap_or_default(); - cfg0.urls = vec![]; - cfg0.network_types = supported_network_types(); - - let a_agent = Arc::new(Agent::new(cfg0).await?); - a_agent.on_connection_state_change(a_notifier); - - let mut cfg1 = default_config1.unwrap_or_default(); - cfg1.urls = vec![]; - cfg1.network_types = supported_network_types(); - - let b_agent = Arc::new(Agent::new(cfg1).await?); - b_agent.on_connection_state_change(b_notifier); - - let (a_conn, b_conn) = connect_with_vnet(&a_agent, &b_agent).await?; - - // Ensure pair selected - // Note: this assumes 
ConnectionStateConnected is thrown after selecting the final pair - let _ = a_connected.recv().await; - let _ = b_connected.recv().await; - - Ok((a_conn, b_conn, a_agent, b_agent)) -} - -#[tokio::test] -async fn test_remote_local_addr() -> Result<()> { - // Agent0 is behind 1:1 NAT - let nat_type0 = nat::NatType { - mode: nat::NatMode::Nat1To1, - ..Default::default() - }; - // Agent1 is behind 1:1 NAT - let nat_type1 = nat::NatType { - mode: nat::NatMode::Nat1To1, - ..Default::default() - }; - - let v = build_vnet(nat_type0, nat_type1).await?; - - let stun_server_url = Url { - scheme: SchemeType::Stun, - host: VNET_STUN_SERVER_IP.to_owned(), - port: VNET_STUN_SERVER_PORT, - proto: ProtoType::Udp, - ..Default::default() - }; - - //"Disconnected Returns nil" - { - let disconnected_conn = AgentConn::new(); - let result = disconnected_conn.local_addr(); - assert!(result.is_err(), "Disconnected Returns nil"); - } - - //"Remote/Local Pair Match between Agents" - { - let (ca, cb) = pipe_with_vnet( - &v, - AgentTestConfig { - urls: vec![stun_server_url.clone()], - ..Default::default() - }, - AgentTestConfig { - urls: vec![stun_server_url], - ..Default::default() - }, - ) - .await?; - - let a_laddr = ca.local_addr()?; - let b_laddr = cb.local_addr()?; - - // Assert addresses - assert_eq!(a_laddr.ip().to_string(), VNET_LOCAL_IPA.to_string()); - assert_eq!(b_laddr.ip().to_string(), VNET_LOCAL_IPB.to_string()); - - // Close - //ca.close().await?; - //cb.close().await?; - } - - v.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_conn_stats() -> Result<()> { - let (ca, cb, _, _) = pipe(None, None).await?; - let na = ca.send(&[0u8; 10]).await?; - - let wg = WaitGroup::new(); - - let w = wg.worker(); - tokio::spawn(async move { - let _d = w; - - let mut buf = vec![0u8; 10]; - let nb = cb.recv(&mut buf).await?; - assert_eq!(nb, 10, "bytes received don't match"); - - Result::<()>::Ok(()) - }); - - wg.wait().await; - - assert_eq!(na, 10, "bytes sent don't match"); - - Ok(()) -} diff --git a/ice/src/agent/agent_vnet_test.rs b/ice/src/agent/agent_vnet_test.rs deleted file mode 100644 index 5974a63e0..000000000 --- a/ice/src/agent/agent_vnet_test.rs +++ /dev/null @@ -1,1019 +0,0 @@ -use std::net::{IpAddr, Ipv4Addr}; -use std::result::Result; -use std::str::FromStr; - -use async_trait::async_trait; -use portable_atomic::AtomicU64; -use util::vnet::chunk::Chunk; -use util::vnet::router::Nic; -use util::vnet::*; -use util::Conn; -use waitgroup::WaitGroup; - -use super::*; -use crate::candidate::candidate_base::unmarshal_candidate; - -pub(crate) struct MockConn; - -#[async_trait] -impl Conn for MockConn { - async fn connect(&self, _addr: SocketAddr) -> Result<(), util::Error> { - Ok(()) - } - async fn recv(&self, _buf: &mut [u8]) -> Result { - Ok(0) - } - async fn recv_from(&self, _buf: &mut [u8]) -> Result<(usize, SocketAddr), util::Error> { - Ok((0, SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0))) - } - async fn send(&self, _buf: &[u8]) -> Result { - Ok(0) - } - async fn send_to(&self, _buf: &[u8], _target: SocketAddr) -> Result { - Ok(0) - } - fn local_addr(&self) -> Result { - Ok(SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0)) - } - fn remote_addr(&self) -> Option { - None - } - async fn close(&self) -> Result<(), util::Error> { - Ok(()) - } - fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) { - self - } -} - -pub(crate) struct VNet { - pub(crate) wan: Arc>, - pub(crate) net0: Arc, - pub(crate) net1: Arc, - pub(crate) server: turn::server::Server, -} - -impl VNet { - pub(crate) 
async fn close(&self) -> Result<(), Error> { - self.server.close().await?; - let mut w = self.wan.lock().await; - w.stop().await?; - Ok(()) - } -} - -pub(crate) const VNET_GLOBAL_IPA: &str = "27.1.1.1"; -pub(crate) const VNET_LOCAL_IPA: &str = "192.168.0.1"; -pub(crate) const VNET_LOCAL_SUBNET_MASK_A: &str = "24"; -pub(crate) const VNET_GLOBAL_IPB: &str = "28.1.1.1"; -pub(crate) const VNET_LOCAL_IPB: &str = "10.2.0.1"; -pub(crate) const VNET_LOCAL_SUBNET_MASK_B: &str = "24"; -pub(crate) const VNET_STUN_SERVER_IP: &str = "1.2.3.4"; -pub(crate) const VNET_STUN_SERVER_PORT: u16 = 3478; - -pub(crate) async fn build_simple_vnet( - _nat_type0: nat::NatType, - _nat_type1: nat::NatType, -) -> Result { - // WAN - let wan = Arc::new(Mutex::new(router::Router::new(router::RouterConfig { - cidr: "0.0.0.0/0".to_owned(), - ..Default::default() - })?)); - - let wnet = Arc::new(net::Net::new(Some(net::NetConfig { - static_ip: VNET_STUN_SERVER_IP.to_owned(), // will be assigned to eth0 - ..Default::default() - }))); - - connect_net2router(&wnet, &wan).await?; - - // LAN - let lan = Arc::new(Mutex::new(router::Router::new(router::RouterConfig { - cidr: format!("{VNET_LOCAL_IPA}/{VNET_LOCAL_SUBNET_MASK_A}"), - ..Default::default() - })?)); - - let net0 = Arc::new(net::Net::new(Some(net::NetConfig { - static_ips: vec!["192.168.0.1".to_owned()], - ..Default::default() - }))); - let net1 = Arc::new(net::Net::new(Some(net::NetConfig { - static_ips: vec!["192.168.0.2".to_owned()], - ..Default::default() - }))); - - connect_net2router(&net0, &lan).await?; - connect_net2router(&net1, &lan).await?; - connect_router2router(&lan, &wan).await?; - - // start routers... - start_router(&wan).await?; - - let server = add_vnet_stun(wnet).await?; - - Ok(VNet { - wan, - net0, - net1, - server, - }) -} - -pub(crate) async fn build_vnet( - nat_type0: nat::NatType, - nat_type1: nat::NatType, -) -> Result { - // WAN - let wan = Arc::new(Mutex::new(router::Router::new(router::RouterConfig { - cidr: "0.0.0.0/0".to_owned(), - ..Default::default() - })?)); - - let wnet = Arc::new(net::Net::new(Some(net::NetConfig { - static_ip: VNET_STUN_SERVER_IP.to_owned(), // will be assigned to eth0 - ..Default::default() - }))); - - connect_net2router(&wnet, &wan).await?; - - // LAN 0 - let lan0 = Arc::new(Mutex::new(router::Router::new(router::RouterConfig { - static_ips: if nat_type0.mode == nat::NatMode::Nat1To1 { - vec![format!("{VNET_GLOBAL_IPA}/{VNET_LOCAL_IPA}")] - } else { - vec![VNET_GLOBAL_IPA.to_owned()] - }, - cidr: format!("{VNET_LOCAL_IPA}/{VNET_LOCAL_SUBNET_MASK_A}"), - nat_type: Some(nat_type0), - ..Default::default() - })?)); - - let net0 = Arc::new(net::Net::new(Some(net::NetConfig { - static_ips: vec![VNET_LOCAL_IPA.to_owned()], - ..Default::default() - }))); - - connect_net2router(&net0, &lan0).await?; - connect_router2router(&lan0, &wan).await?; - - // LAN 1 - let lan1 = Arc::new(Mutex::new(router::Router::new(router::RouterConfig { - static_ips: if nat_type1.mode == nat::NatMode::Nat1To1 { - vec![format!("{VNET_GLOBAL_IPB}/{VNET_LOCAL_IPB}")] - } else { - vec![VNET_GLOBAL_IPB.to_owned()] - }, - cidr: format!("{VNET_LOCAL_IPB}/{VNET_LOCAL_SUBNET_MASK_B}"), - nat_type: Some(nat_type1), - ..Default::default() - })?)); - - let net1 = Arc::new(net::Net::new(Some(net::NetConfig { - static_ips: vec![VNET_LOCAL_IPB.to_owned()], - ..Default::default() - }))); - - connect_net2router(&net1, &lan1).await?; - connect_router2router(&lan1, &wan).await?; - - // start routers... 
- start_router(&wan).await?; - - let server = add_vnet_stun(wnet).await?; - - Ok(VNet { - wan, - net0, - net1, - server, - }) -} - -pub(crate) struct TestAuthHandler { - pub(crate) cred_map: HashMap>, -} - -impl TestAuthHandler { - pub(crate) fn new() -> Self { - let mut cred_map = HashMap::new(); - cred_map.insert( - "user".to_owned(), - turn::auth::generate_auth_key("user", "webrtc.rs", "pass"), - ); - - TestAuthHandler { cred_map } - } -} - -impl turn::auth::AuthHandler for TestAuthHandler { - fn auth_handle( - &self, - username: &str, - _realm: &str, - _src_addr: SocketAddr, - ) -> Result, turn::Error> { - if let Some(pw) = self.cred_map.get(username) { - Ok(pw.to_vec()) - } else { - Err(turn::Error::Other("fake error".to_owned())) - } - } -} - -pub(crate) async fn add_vnet_stun(wan_net: Arc) -> Result { - // Run TURN(STUN) server - let conn = wan_net - .bind(SocketAddr::from_str(&format!( - "{VNET_STUN_SERVER_IP}:{VNET_STUN_SERVER_PORT}" - ))?) - .await?; - - let server = turn::server::Server::new(turn::server::config::ServerConfig { - conn_configs: vec![turn::server::config::ConnConfig { - conn, - relay_addr_generator: Box::new( - turn::relay::relay_static::RelayAddressGeneratorStatic { - relay_address: IpAddr::from_str(VNET_STUN_SERVER_IP)?, - address: "0.0.0.0".to_owned(), - net: wan_net, - }, - ), - }], - realm: "webrtc.rs".to_owned(), - auth_handler: Arc::new(TestAuthHandler::new()), - channel_bind_timeout: Duration::from_secs(0), - alloc_close_notify: None, - }) - .await?; - - Ok(server) -} - -pub(crate) async fn connect_with_vnet( - a_agent: &Arc, - b_agent: &Arc, -) -> Result<(Arc, Arc), Error> { - // Manual signaling - let (a_ufrag, a_pwd) = a_agent.get_local_user_credentials().await; - let (b_ufrag, b_pwd) = b_agent.get_local_user_credentials().await; - - gather_and_exchange_candidates(a_agent, b_agent).await?; - - let (accepted_tx, mut accepted_rx) = mpsc::channel(1); - let (_a_cancel_tx, a_cancel_rx) = mpsc::channel(1); - - let agent_a = Arc::clone(a_agent); - tokio::spawn(async move { - let a_conn = agent_a.accept(a_cancel_rx, b_ufrag, b_pwd).await?; - - let _ = accepted_tx.send(a_conn).await; - - Result::<(), Error>::Ok(()) - }); - - let (_b_cancel_tx, b_cancel_rx) = mpsc::channel(1); - let b_conn = b_agent.dial(b_cancel_rx, a_ufrag, a_pwd).await?; - - // Ensure accepted - if let Some(a_conn) = accepted_rx.recv().await { - Ok((a_conn, b_conn)) - } else { - Err(Error::Other("no a_conn".to_owned())) - } -} - -#[derive(Default)] -pub(crate) struct AgentTestConfig { - pub(crate) urls: Vec, - pub(crate) nat_1to1_ip_candidate_type: CandidateType, -} - -pub(crate) async fn pipe_with_vnet( - v: &VNet, - a0test_config: AgentTestConfig, - a1test_config: AgentTestConfig, -) -> Result<(Arc, Arc), Error> { - let (a_notifier, mut a_connected) = on_connected(); - let (b_notifier, mut b_connected) = on_connected(); - - let nat_1to1_ips = if a0test_config.nat_1to1_ip_candidate_type != CandidateType::Unspecified { - vec![VNET_GLOBAL_IPA.to_owned()] - } else { - vec![] - }; - - let cfg0 = AgentConfig { - urls: a0test_config.urls, - network_types: supported_network_types(), - multicast_dns_mode: MulticastDnsMode::Disabled, - nat_1to1_ips, - nat_1to1_ip_candidate_type: a0test_config.nat_1to1_ip_candidate_type, - net: Some(Arc::clone(&v.net0)), - ..Default::default() - }; - - let a_agent = Arc::new(Agent::new(cfg0).await?); - a_agent.on_connection_state_change(a_notifier); - - let nat_1to1_ips = if a1test_config.nat_1to1_ip_candidate_type != CandidateType::Unspecified { - 
vec![VNET_GLOBAL_IPB.to_owned()] - } else { - vec![] - }; - let cfg1 = AgentConfig { - urls: a1test_config.urls, - network_types: supported_network_types(), - multicast_dns_mode: MulticastDnsMode::Disabled, - nat_1to1_ips, - nat_1to1_ip_candidate_type: a1test_config.nat_1to1_ip_candidate_type, - net: Some(Arc::clone(&v.net1)), - ..Default::default() - }; - - let b_agent = Arc::new(Agent::new(cfg1).await?); - b_agent.on_connection_state_change(b_notifier); - - let (a_conn, b_conn) = connect_with_vnet(&a_agent, &b_agent).await?; - - // Ensure pair selected - // Note: this assumes ConnectionStateConnected is thrown after selecting the final pair - let _ = a_connected.recv().await; - let _ = b_connected.recv().await; - - Ok((a_conn, b_conn)) -} - -pub(crate) fn on_connected() -> (OnConnectionStateChangeHdlrFn, mpsc::Receiver<()>) { - let (done_tx, done_rx) = mpsc::channel::<()>(1); - let done_tx = Arc::new(Mutex::new(Some(done_tx))); - let hdlr_fn: OnConnectionStateChangeHdlrFn = Box::new(move |state: ConnectionState| { - let done_tx_clone = Arc::clone(&done_tx); - Box::pin(async move { - if state == ConnectionState::Connected { - let mut tx = done_tx_clone.lock().await; - tx.take(); - } - }) - }); - (hdlr_fn, done_rx) -} - -pub(crate) async fn gather_and_exchange_candidates( - a_agent: &Arc, - b_agent: &Arc, -) -> Result<(), Error> { - let wg = WaitGroup::new(); - - let w1 = Arc::new(Mutex::new(Some(wg.worker()))); - a_agent.on_candidate(Box::new( - move |candidate: Option>| { - let w3 = Arc::clone(&w1); - Box::pin(async move { - if candidate.is_none() { - let mut w = w3.lock().await; - w.take(); - } - }) - }, - )); - a_agent.gather_candidates()?; - - let w2 = Arc::new(Mutex::new(Some(wg.worker()))); - b_agent.on_candidate(Box::new( - move |candidate: Option>| { - let w3 = Arc::clone(&w2); - Box::pin(async move { - if candidate.is_none() { - let mut w = w3.lock().await; - w.take(); - } - }) - }, - )); - b_agent.gather_candidates()?; - - wg.wait().await; - - let candidates = a_agent.get_local_candidates().await?; - for c in candidates { - let c2: Arc = - Arc::new(unmarshal_candidate(c.marshal().as_str())?); - b_agent.add_remote_candidate(&c2)?; - } - - let candidates = b_agent.get_local_candidates().await?; - for c in candidates { - let c2: Arc = - Arc::new(unmarshal_candidate(c.marshal().as_str())?); - a_agent.add_remote_candidate(&c2)?; - } - - Ok(()) -} - -pub(crate) async fn start_router(router: &Arc>) -> Result<(), Error> { - let mut w = router.lock().await; - Ok(w.start().await?) 
-} - -pub(crate) async fn connect_net2router( - net: &Arc, - router: &Arc>, -) -> Result<(), Error> { - let nic = net.get_nic()?; - - { - let mut w = router.lock().await; - w.add_net(Arc::clone(&nic)).await?; - } - { - let n = nic.lock().await; - n.set_router(Arc::clone(router)).await?; - } - - Ok(()) -} - -pub(crate) async fn connect_router2router( - child: &Arc>, - parent: &Arc>, -) -> Result<(), Error> { - { - let mut w = parent.lock().await; - w.add_router(Arc::clone(child)).await?; - } - - { - let l = child.lock().await; - l.set_router(Arc::clone(parent)).await?; - } - - Ok(()) -} - -#[tokio::test] -async fn test_connectivity_simple_vnet_full_cone_nats_on_both_ends() -> Result<(), Error> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - let stun_server_url = Url { - scheme: SchemeType::Stun, - host: VNET_STUN_SERVER_IP.to_owned(), - port: VNET_STUN_SERVER_PORT, - proto: ProtoType::Udp, - ..Default::default() - }; - - // buildVNet with a Full-cone NATs both LANs - let nat_type = nat::NatType { - mapping_behavior: nat::EndpointDependencyType::EndpointIndependent, - filtering_behavior: nat::EndpointDependencyType::EndpointIndependent, - ..Default::default() - }; - - let v = build_simple_vnet(nat_type, nat_type).await?; - - log::debug!("Connecting..."); - let a0test_config = AgentTestConfig { - urls: vec![stun_server_url.clone()], - ..Default::default() - }; - let a1test_config = AgentTestConfig { - urls: vec![stun_server_url.clone()], - ..Default::default() - }; - let (_ca, _cb) = pipe_with_vnet(&v, a0test_config, a1test_config).await?; - - tokio::time::sleep(Duration::from_secs(1)).await; - - log::debug!("Closing..."); - v.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_connectivity_vnet_full_cone_nats_on_both_ends() -> Result<(), Error> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - let stun_server_url = Url { - scheme: SchemeType::Stun, - host: VNET_STUN_SERVER_IP.to_owned(), - port: VNET_STUN_SERVER_PORT, - proto: ProtoType::Udp, - ..Default::default() - }; - - let _turn_server_url = Url { - scheme: SchemeType::Turn, - host: VNET_STUN_SERVER_IP.to_owned(), - port: VNET_STUN_SERVER_PORT, - username: "user".to_owned(), - password: "pass".to_owned(), - proto: ProtoType::Udp, - }; - - // buildVNet with a Full-cone NATs both LANs - let nat_type = nat::NatType { - mapping_behavior: nat::EndpointDependencyType::EndpointIndependent, - filtering_behavior: nat::EndpointDependencyType::EndpointIndependent, - ..Default::default() - }; - - let v = build_vnet(nat_type, nat_type).await?; - - log::debug!("Connecting..."); - let a0test_config = AgentTestConfig { - urls: vec![stun_server_url.clone()], - ..Default::default() - }; - let a1test_config = AgentTestConfig { - urls: vec![stun_server_url.clone()], - ..Default::default() - }; - let (_ca, _cb) = pipe_with_vnet(&v, a0test_config, a1test_config).await?; - - tokio::time::sleep(Duration::from_secs(1)).await; - - log::debug!("Closing..."); - v.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn 
test_connectivity_vnet_symmetric_nats_on_both_ends() -> Result<(), Error> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - let stun_server_url = Url { - scheme: SchemeType::Stun, - host: VNET_STUN_SERVER_IP.to_owned(), - port: VNET_STUN_SERVER_PORT, - proto: ProtoType::Udp, - ..Default::default() - }; - - let turn_server_url = Url { - scheme: SchemeType::Turn, - host: VNET_STUN_SERVER_IP.to_owned(), - port: VNET_STUN_SERVER_PORT, - username: "user".to_owned(), - password: "pass".to_owned(), - proto: ProtoType::Udp, - }; - - // buildVNet with a Symmetric NATs for both LANs - let nat_type = nat::NatType { - mapping_behavior: nat::EndpointDependencyType::EndpointAddrPortDependent, - filtering_behavior: nat::EndpointDependencyType::EndpointAddrPortDependent, - ..Default::default() - }; - - let v = build_vnet(nat_type, nat_type).await?; - - log::debug!("Connecting..."); - let a0test_config = AgentTestConfig { - urls: vec![stun_server_url.clone(), turn_server_url.clone()], - ..Default::default() - }; - let a1test_config = AgentTestConfig { - urls: vec![stun_server_url.clone()], - ..Default::default() - }; - let (_ca, _cb) = pipe_with_vnet(&v, a0test_config, a1test_config).await?; - - tokio::time::sleep(Duration::from_secs(1)).await; - - log::debug!("Closing..."); - v.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_connectivity_vnet_1to1_nat_with_host_candidate_vs_symmetric_nats() -> Result<(), Error> -{ - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - // Agent0 is behind 1:1 NAT - let nat_type0 = nat::NatType { - mode: nat::NatMode::Nat1To1, - ..Default::default() - }; - // Agent1 is behind a symmetric NAT - let nat_type1 = nat::NatType { - mapping_behavior: nat::EndpointDependencyType::EndpointAddrPortDependent, - filtering_behavior: nat::EndpointDependencyType::EndpointAddrPortDependent, - ..Default::default() - }; - log::debug!("natType0: {:?}", nat_type0); - log::debug!("natType1: {:?}", nat_type1); - - let v = build_vnet(nat_type0, nat_type1).await?; - - log::debug!("Connecting..."); - let a0test_config = AgentTestConfig { - urls: vec![], - nat_1to1_ip_candidate_type: CandidateType::Host, // Use 1:1 NAT IP as a host candidate - }; - let a1test_config = AgentTestConfig { - urls: vec![], - ..Default::default() - }; - let (_ca, _cb) = pipe_with_vnet(&v, a0test_config, a1test_config).await?; - - tokio::time::sleep(Duration::from_secs(1)).await; - - log::debug!("Closing..."); - v.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_connectivity_vnet_1to1_nat_with_srflx_candidate_vs_symmetric_nats( -) -> Result<(), Error> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - // Agent0 is behind 1:1 NAT - let nat_type0 = nat::NatType { - mode: nat::NatMode::Nat1To1, - ..Default::default() 
- }; - // Agent1 is behind a symmetric NAT - let nat_type1 = nat::NatType { - mapping_behavior: nat::EndpointDependencyType::EndpointAddrPortDependent, - filtering_behavior: nat::EndpointDependencyType::EndpointAddrPortDependent, - ..Default::default() - }; - log::debug!("natType0: {:?}", nat_type0); - log::debug!("natType1: {:?}", nat_type1); - - let v = build_vnet(nat_type0, nat_type1).await?; - - log::debug!("Connecting..."); - let a0test_config = AgentTestConfig { - urls: vec![], - nat_1to1_ip_candidate_type: CandidateType::ServerReflexive, // Use 1:1 NAT IP as a srflx candidate - }; - let a1test_config = AgentTestConfig { - urls: vec![], - ..Default::default() - }; - let (_ca, _cb) = pipe_with_vnet(&v, a0test_config, a1test_config).await?; - - tokio::time::sleep(Duration::from_secs(1)).await; - - log::debug!("Closing..."); - v.close().await?; - - Ok(()) -} - -async fn block_until_state_seen( - expected_state: ConnectionState, - state_queue: &mut mpsc::Receiver, -) { - while let Some(s) = state_queue.recv().await { - if s == expected_state { - return; - } - } -} - -// test_disconnected_to_connected asserts that an agent can go to disconnected, and then return to connected successfully -#[tokio::test] -async fn test_disconnected_to_connected() -> Result<(), Error> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - // Create a network with two interfaces - let wan = router::Router::new(router::RouterConfig { - cidr: "0.0.0.0/0".to_owned(), - ..Default::default() - })?; - - let drop_all_data = Arc::new(AtomicU64::new(0)); - let drop_all_data2 = Arc::clone(&drop_all_data); - wan.add_chunk_filter(Box::new(move |_c: &(dyn Chunk + Send + Sync)| -> bool { - drop_all_data2.load(Ordering::SeqCst) != 1 - })) - .await; - let wan = Arc::new(Mutex::new(wan)); - - let net0 = Arc::new(net::Net::new(Some(net::NetConfig { - static_ips: vec!["192.168.0.1".to_owned()], - ..Default::default() - }))); - let net1 = Arc::new(net::Net::new(Some(net::NetConfig { - static_ips: vec!["192.168.0.2".to_owned()], - ..Default::default() - }))); - - connect_net2router(&net0, &wan).await?; - connect_net2router(&net1, &wan).await?; - start_router(&wan).await?; - - let disconnected_timeout = Duration::from_secs(1); - let keepalive_interval = Duration::from_millis(20); - - // Create two agents and connect them - let controlling_agent = Arc::new( - Agent::new(AgentConfig { - network_types: supported_network_types(), - multicast_dns_mode: MulticastDnsMode::Disabled, - net: Some(Arc::clone(&net0)), - disconnected_timeout: Some(disconnected_timeout), - keepalive_interval: Some(keepalive_interval), - check_interval: keepalive_interval, - ..Default::default() - }) - .await?, - ); - - let controlled_agent = Arc::new( - Agent::new(AgentConfig { - network_types: supported_network_types(), - multicast_dns_mode: MulticastDnsMode::Disabled, - net: Some(Arc::clone(&net1)), - disconnected_timeout: Some(disconnected_timeout), - keepalive_interval: Some(keepalive_interval), - check_interval: keepalive_interval, - ..Default::default() - }) - .await?, - ); - - let (controlling_state_changes_tx, mut controlling_state_changes_rx) = - mpsc::channel::(100); - let controlling_state_changes_tx = Arc::new(controlling_state_changes_tx); - 
controlling_agent.on_connection_state_change(Box::new(move |c: ConnectionState| { - let controlling_state_changes_tx_clone = Arc::clone(&controlling_state_changes_tx); - Box::pin(async move { - let _ = controlling_state_changes_tx_clone.try_send(c); - }) - })); - - let (controlled_state_changes_tx, mut controlled_state_changes_rx) = - mpsc::channel::(100); - let controlled_state_changes_tx = Arc::new(controlled_state_changes_tx); - controlled_agent.on_connection_state_change(Box::new(move |c: ConnectionState| { - let controlled_state_changes_tx_clone = Arc::clone(&controlled_state_changes_tx); - Box::pin(async move { - let _ = controlled_state_changes_tx_clone.try_send(c); - }) - })); - - connect_with_vnet(&controlling_agent, &controlled_agent).await?; - - // Assert we have gone to connected - block_until_state_seen( - ConnectionState::Connected, - &mut controlling_state_changes_rx, - ) - .await; - block_until_state_seen(ConnectionState::Connected, &mut controlled_state_changes_rx).await; - - // Drop all packets, and block until we have gone to disconnected - drop_all_data.store(1, Ordering::SeqCst); - block_until_state_seen( - ConnectionState::Disconnected, - &mut controlling_state_changes_rx, - ) - .await; - block_until_state_seen( - ConnectionState::Disconnected, - &mut controlled_state_changes_rx, - ) - .await; - - // Allow all packets through again, block until we have gone to connected - drop_all_data.store(0, Ordering::SeqCst); - block_until_state_seen( - ConnectionState::Connected, - &mut controlling_state_changes_rx, - ) - .await; - block_until_state_seen(ConnectionState::Connected, &mut controlled_state_changes_rx).await; - - { - let mut w = wan.lock().await; - w.stop().await?; - } - - controlling_agent.close().await?; - controlled_agent.close().await?; - - Ok(()) -} - -//use std::io::Write; - -// Agent.Write should use the best valid pair if a selected pair is not yet available -#[tokio::test] -async fn test_write_use_valid_pair() -> Result<(), Error> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - // Create a network with two interfaces - let wan = router::Router::new(router::RouterConfig { - cidr: "0.0.0.0/0".to_owned(), - ..Default::default() - })?; - - wan.add_chunk_filter(Box::new(move |c: &(dyn Chunk + Send + Sync)| -> bool { - let raw = c.user_data(); - if stun::message::is_message(&raw) { - let mut m = stun::message::Message { - raw, - ..Default::default() - }; - let result = m.decode(); - if result.is_err() | m.contains(stun::attributes::ATTR_USE_CANDIDATE) { - return false; - } - } - - true - })) - .await; - let wan = Arc::new(Mutex::new(wan)); - - let net0 = Arc::new(net::Net::new(Some(net::NetConfig { - static_ips: vec!["192.168.0.1".to_owned()], - ..Default::default() - }))); - let net1 = Arc::new(net::Net::new(Some(net::NetConfig { - static_ips: vec!["192.168.0.2".to_owned()], - ..Default::default() - }))); - - connect_net2router(&net0, &wan).await?; - connect_net2router(&net1, &wan).await?; - start_router(&wan).await?; - - // Create two agents and connect them - let controlling_agent = Arc::new( - Agent::new(AgentConfig { - network_types: supported_network_types(), - multicast_dns_mode: MulticastDnsMode::Disabled, - net: Some(Arc::clone(&net0)), - ..Default::default() - }) - .await?, - ); - - let 
controlled_agent = Arc::new( - Agent::new(AgentConfig { - network_types: supported_network_types(), - multicast_dns_mode: MulticastDnsMode::Disabled, - net: Some(Arc::clone(&net1)), - ..Default::default() - }) - .await?, - ); - - gather_and_exchange_candidates(&controlling_agent, &controlled_agent).await?; - - let (controlling_ufrag, controlling_pwd) = controlling_agent.get_local_user_credentials().await; - let (controlled_ufrag, controlled_pwd) = controlled_agent.get_local_user_credentials().await; - - let controlling_agent_tx = Arc::clone(&controlling_agent); - tokio::spawn(async move { - let test_message = "Test Message"; - let controlling_agent_conn = { - controlling_agent_tx - .internal - .start_connectivity_checks(true, controlled_ufrag, controlled_pwd) - .await?; - Arc::clone(&controlling_agent_tx.internal.agent_conn) as Arc - }; - - log::debug!("controlling_agent start_connectivity_checks done..."); - loop { - let result = controlling_agent_conn.send(test_message.as_bytes()).await; - if result.is_err() { - break; - } - - tokio::time::sleep(Duration::from_millis(20)).await; - } - - Result::<(), Error>::Ok(()) - }); - - let controlled_agent_conn = { - controlled_agent - .internal - .start_connectivity_checks(false, controlling_ufrag, controlling_pwd) - .await?; - Arc::clone(&controlled_agent.internal.agent_conn) as Arc - }; - - log::debug!("controlled_agent start_connectivity_checks done..."); - - let test_message = "Test Message"; - let mut read_buf = vec![0u8; test_message.as_bytes().len()]; - controlled_agent_conn.recv(&mut read_buf).await?; - - assert_eq!(read_buf, test_message.as_bytes(), "should match"); - - { - let mut w = wan.lock().await; - w.stop().await?; - } - - controlling_agent.close().await?; - controlled_agent.close().await?; - - Ok(()) -} diff --git a/ice/src/agent/mod.rs b/ice/src/agent/mod.rs deleted file mode 100644 index 389b861c9..000000000 --- a/ice/src/agent/mod.rs +++ /dev/null @@ -1,517 +0,0 @@ -#[cfg(test)] -mod agent_gather_test; -#[cfg(test)] -mod agent_test; -#[cfg(test)] -mod agent_transport_test; -#[cfg(test)] -pub(crate) mod agent_vnet_test; - -pub mod agent_config; -pub mod agent_gather; -pub(crate) mod agent_internal; -pub mod agent_selector; -pub mod agent_stats; -pub mod agent_transport; - -use std::collections::HashMap; -use std::future::Future; -use std::net::{Ipv4Addr, SocketAddr}; -use std::pin::Pin; -use std::sync::atomic::Ordering; -use std::sync::Arc; -use std::time::SystemTime; - -use agent_config::*; -use agent_internal::*; -use agent_stats::*; -use mdns::conn::*; -use portable_atomic::{AtomicU8, AtomicUsize}; -use stun::agent::*; -use stun::attributes::*; -use stun::fingerprint::*; -use stun::integrity::*; -use stun::message::*; -use stun::xoraddr::*; -use tokio::sync::{broadcast, mpsc, Mutex}; -use tokio::time::{Duration, Instant}; -use util::vnet::net::*; -use util::Buffer; - -use crate::agent::agent_gather::GatherCandidatesInternalParams; -use crate::candidate::*; -use crate::error::*; -use crate::external_ip_mapper::*; -use crate::mdns::*; -use crate::network_type::*; -use crate::rand::*; -use crate::state::*; -use crate::tcp_type::TcpType; -use crate::udp_mux::UDPMux; -use crate::udp_network::UDPNetwork; -use crate::url::*; - -#[derive(Debug, Clone)] -pub(crate) struct BindingRequest { - pub(crate) timestamp: Instant, - pub(crate) transaction_id: TransactionId, - pub(crate) destination: SocketAddr, - pub(crate) is_use_candidate: bool, -} - -impl Default for BindingRequest { - fn default() -> Self { - Self { - timestamp: 
Instant::now(),
-            transaction_id: TransactionId::default(),
-            destination: SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0),
-            is_use_candidate: false,
-        }
-    }
-}
-
-pub type OnConnectionStateChangeHdlrFn = Box<
-    dyn (FnMut(ConnectionState) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>)
-        + Send
-        + Sync,
->;
-pub type OnSelectedCandidatePairChangeHdlrFn = Box<
-    dyn (FnMut(
-        &Arc<CandidatePair>,
-        &Arc<CandidatePair>,
-    ) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>)
-        + Send
-        + Sync,
->;
-pub type OnCandidateHdlrFn = Box<
-    dyn (FnMut(
-        Option<Arc<dyn Candidate + Send + Sync>>,
-    ) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>)
-        + Send
-        + Sync,
->;
-pub type GatherCandidateCancelFn = Box<dyn Fn() + Send + Sync>;
-
-struct ChanReceivers {
-    chan_state_rx: mpsc::Receiver<ConnectionState>,
-    chan_candidate_rx: mpsc::Receiver<Option<Arc<dyn Candidate + Send + Sync>>>,
-    chan_candidate_pair_rx: mpsc::Receiver<()>,
-}
-
-/// Represents the ICE agent.
-pub struct Agent {
-    pub(crate) internal: Arc<AgentInternal>,
-
-    pub(crate) udp_network: UDPNetwork,
-    pub(crate) interface_filter: Arc<Option<InterfaceFilterFn>>,
-    pub(crate) ip_filter: Arc<Option<IpFilterFn>>,
-    pub(crate) mdns_mode: MulticastDnsMode,
-    pub(crate) mdns_name: String,
-    pub(crate) mdns_conn: Option<Arc<DnsConn>>,
-    pub(crate) net: Arc<Net>,
-
-    // 1:1 D-NAT IP address mapping
-    pub(crate) ext_ip_mapper: Arc<Option<ExternalIpMapper>>,
-    pub(crate) gathering_state: Arc<AtomicU8>, //GatheringState,
-    pub(crate) candidate_types: Vec<CandidateType>,
-    pub(crate) urls: Vec<Url>,
-    pub(crate) network_types: Vec<NetworkType>,
-
-    pub(crate) gather_candidate_cancel: Option<GatherCandidateCancelFn>,
-}
-
-impl Agent {
-    /// Creates a new Agent.
-    pub async fn new(config: AgentConfig) -> Result<Self> {
-        let mut mdns_name = config.multicast_dns_host_name.clone();
-        if mdns_name.is_empty() {
-            mdns_name = generate_multicast_dns_name();
-        }
-
-        if !mdns_name.ends_with(".local") || mdns_name.split('.').count() != 2 {
-            return Err(Error::ErrInvalidMulticastDnshostName);
-        }
-
-        let mdns_mode = config.multicast_dns_mode;
-
-        let mdns_conn =
-            match create_multicast_dns(mdns_mode, &mdns_name, &config.multicast_dns_dest_addr) {
-                Ok(c) => c,
-                Err(err) => {
-                    // Opportunistic mDNS: If we can't open the connection, that's ok: we
-                    // can continue without it.
- log::warn!("Failed to initialize mDNS {}: {}", mdns_name, err); - None - } - }; - - let (mut ai, chan_receivers) = AgentInternal::new(&config); - let (chan_state_rx, chan_candidate_rx, chan_candidate_pair_rx) = ( - chan_receivers.chan_state_rx, - chan_receivers.chan_candidate_rx, - chan_receivers.chan_candidate_pair_rx, - ); - - config.init_with_defaults(&mut ai); - - let candidate_types = if config.candidate_types.is_empty() { - default_candidate_types() - } else { - config.candidate_types.clone() - }; - - if ai.lite.load(Ordering::SeqCst) - && (candidate_types.len() != 1 || candidate_types[0] != CandidateType::Host) - { - Self::close_multicast_conn(&mdns_conn).await; - return Err(Error::ErrLiteUsingNonHostCandidates); - } - - if !config.urls.is_empty() - && !contains_candidate_type(CandidateType::ServerReflexive, &candidate_types) - && !contains_candidate_type(CandidateType::Relay, &candidate_types) - { - Self::close_multicast_conn(&mdns_conn).await; - return Err(Error::ErrUselessUrlsProvided); - } - - let ext_ip_mapper = match config.init_ext_ip_mapping(mdns_mode, &candidate_types) { - Ok(ext_ip_mapper) => ext_ip_mapper, - Err(err) => { - Self::close_multicast_conn(&mdns_conn).await; - return Err(err); - } - }; - - let net = if let Some(net) = config.net { - if net.is_virtual() { - log::warn!("vnet is enabled"); - if mdns_mode != MulticastDnsMode::Disabled { - log::warn!("vnet does not support mDNS yet"); - } - } - - net - } else { - Arc::new(Net::new(None)) - }; - - let agent = Self { - udp_network: config.udp_network, - internal: Arc::new(ai), - interface_filter: Arc::clone(&config.interface_filter), - ip_filter: Arc::clone(&config.ip_filter), - mdns_mode, - mdns_name, - mdns_conn, - net, - ext_ip_mapper: Arc::new(ext_ip_mapper), - gathering_state: Arc::new(AtomicU8::new(0)), //GatheringState::New, - candidate_types, - urls: config.urls.clone(), - network_types: config.network_types.clone(), - - gather_candidate_cancel: None, //TODO: add cancel - }; - - agent.internal.start_on_connection_state_change_routine( - chan_state_rx, - chan_candidate_rx, - chan_candidate_pair_rx, - ); - - // Restart is also used to initialize the agent for the first time - if let Err(err) = agent.restart(config.local_ufrag, config.local_pwd).await { - Self::close_multicast_conn(&agent.mdns_conn).await; - let _ = agent.close().await; - return Err(err); - } - - Ok(agent) - } - - pub fn get_bytes_received(&self) -> usize { - self.internal.agent_conn.bytes_received() - } - - pub fn get_bytes_sent(&self) -> usize { - self.internal.agent_conn.bytes_sent() - } - - /// Sets a handler that is fired when the connection state changes. - pub fn on_connection_state_change(&self, f: OnConnectionStateChangeHdlrFn) { - self.internal - .on_connection_state_change_hdlr - .store(Some(Arc::new(Mutex::new(f)))) - } - - /// Sets a handler that is fired when the final candidate pair is selected. - pub fn on_selected_candidate_pair_change(&self, f: OnSelectedCandidatePairChangeHdlrFn) { - self.internal - .on_selected_candidate_pair_change_hdlr - .store(Some(Arc::new(Mutex::new(f)))) - } - - /// Sets a handler that is fired when new candidates gathered. When the gathering process - /// complete the last candidate is nil. - pub fn on_candidate(&self, f: OnCandidateHdlrFn) { - self.internal - .on_candidate_hdlr - .store(Some(Arc::new(Mutex::new(f)))); - } - - /// Adds a new remote candidate. 
- pub fn add_remote_candidate(&self, c: &Arc) -> Result<()> { - // cannot check for network yet because it might not be applied - // when mDNS hostame is used. - if c.tcp_type() == TcpType::Active { - // TCP Candidates with tcptype active will probe server passive ones, so - // no need to do anything with them. - log::info!("Ignoring remote candidate with tcpType active: {}", c); - return Ok(()); - } - - // If we have a mDNS Candidate lets fully resolve it before adding it locally - if c.candidate_type() == CandidateType::Host && c.address().ends_with(".local") { - if self.mdns_mode == MulticastDnsMode::Disabled { - log::warn!( - "remote mDNS candidate added, but mDNS is disabled: ({})", - c.address() - ); - return Ok(()); - } - - if c.candidate_type() != CandidateType::Host { - return Err(Error::ErrAddressParseFailed); - } - - let ai = Arc::clone(&self.internal); - let host_candidate = Arc::clone(c); - let mdns_conn = self.mdns_conn.clone(); - tokio::spawn(async move { - if let Some(mdns_conn) = mdns_conn { - if let Ok(candidate) = - Self::resolve_and_add_multicast_candidate(mdns_conn, host_candidate).await - { - ai.add_remote_candidate(&candidate).await; - } - } - }); - } else { - let ai = Arc::clone(&self.internal); - let candidate = Arc::clone(c); - tokio::spawn(async move { - ai.add_remote_candidate(&candidate).await; - }); - } - - Ok(()) - } - - /// Returns the local candidates. - pub async fn get_local_candidates(&self) -> Result>> { - let mut res = vec![]; - - { - let local_candidates = self.internal.local_candidates.lock().await; - for candidates in local_candidates.values() { - for candidate in candidates { - res.push(Arc::clone(candidate)); - } - } - } - - Ok(res) - } - - /// Returns the local user credentials. - pub async fn get_local_user_credentials(&self) -> (String, String) { - let ufrag_pwd = self.internal.ufrag_pwd.lock().await; - (ufrag_pwd.local_ufrag.clone(), ufrag_pwd.local_pwd.clone()) - } - - /// Returns the remote user credentials. - pub async fn get_remote_user_credentials(&self) -> (String, String) { - let ufrag_pwd = self.internal.ufrag_pwd.lock().await; - (ufrag_pwd.remote_ufrag.clone(), ufrag_pwd.remote_pwd.clone()) - } - - /// Cleans up the Agent. - pub async fn close(&self) -> Result<()> { - if let Some(gather_candidate_cancel) = &self.gather_candidate_cancel { - gather_candidate_cancel(); - } - - if let UDPNetwork::Muxed(ref udp_mux) = self.udp_network { - let (ufrag, _) = self.get_local_user_credentials().await; - udp_mux.remove_conn_by_ufrag(&ufrag).await; - } - - //FIXME: deadlock here - self.internal.close().await - } - - /// Returns the selected pair or nil if there is none - pub fn get_selected_candidate_pair(&self) -> Option> { - self.internal.agent_conn.get_selected_pair() - } - - /// Sets the credentials of the remote agent. - pub async fn set_remote_credentials( - &self, - remote_ufrag: String, - remote_pwd: String, - ) -> Result<()> { - self.internal - .set_remote_credentials(remote_ufrag, remote_pwd) - .await - } - - /// Restarts the ICE Agent with the provided ufrag/pwd - /// If no ufrag/pwd is provided the Agent will generate one itself. - /// - /// Restart must only be called when `GatheringState` is `GatheringStateComplete` - /// a user must then call `GatherCandidates` explicitly to start generating new ones. 
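The remote-side wiring implied by `add_remote_candidate` and `set_remote_credentials` above, as a hedged sketch. The transport that carries the peer's ufrag/pwd and candidate strings is out of scope; `apply_remote` is a hypothetical helper, and `unmarshal_candidate` is the parser defined later in this diff.

```rust
// Sketch only: assumes the Agent and candidate APIs shown in this diff.
use std::sync::Arc;

use ice::agent::Agent;
use ice::candidate::candidate_base::unmarshal_candidate;
use ice::candidate::Candidate;
use ice::error::Error;

async fn apply_remote(
    agent: &Agent,
    remote_ufrag: String,
    remote_pwd: String,
    remote_candidates: &[&str],
) -> Result<(), Error> {
    // Credentials learned from the remote description.
    agent.set_remote_credentials(remote_ufrag, remote_pwd).await?;

    // Each remote candidate arrives in its string form and is parsed back into
    // a CandidateBase before being handed to the agent.
    for line in remote_candidates {
        let c: Arc<dyn Candidate + Send + Sync> = Arc::new(unmarshal_candidate(line)?);
        agent.add_remote_candidate(&c)?;
    }
    Ok(())
}
```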
- pub async fn restart(&self, mut ufrag: String, mut pwd: String) -> Result<()> { - if ufrag.is_empty() { - ufrag = generate_ufrag(); - } - if pwd.is_empty() { - pwd = generate_pwd(); - } - - if ufrag.len() * 8 < 24 { - return Err(Error::ErrLocalUfragInsufficientBits); - } - if pwd.len() * 8 < 128 { - return Err(Error::ErrLocalPwdInsufficientBits); - } - - if GatheringState::from(self.gathering_state.load(Ordering::SeqCst)) - == GatheringState::Gathering - { - return Err(Error::ErrRestartWhenGathering); - } - self.gathering_state - .store(GatheringState::New as u8, Ordering::SeqCst); - - { - let done_tx = self.internal.done_tx.lock().await; - if done_tx.is_none() { - return Err(Error::ErrClosed); - } - } - - // Clear all agent needed to take back to fresh state - { - let mut ufrag_pwd = self.internal.ufrag_pwd.lock().await; - ufrag_pwd.local_ufrag = ufrag; - ufrag_pwd.local_pwd = pwd; - ufrag_pwd.remote_ufrag = String::new(); - ufrag_pwd.remote_pwd = String::new(); - } - { - let mut pending_binding_requests = self.internal.pending_binding_requests.lock().await; - *pending_binding_requests = vec![]; - } - - { - let mut checklist = self.internal.agent_conn.checklist.lock().await; - *checklist = vec![]; - } - - self.internal.set_selected_pair(None).await; - self.internal.delete_all_candidates().await; - self.internal.start().await; - - // Restart is used by NewAgent. Accept/Connect should be used to move to checking - // for new Agents - if self.internal.connection_state.load(Ordering::SeqCst) != ConnectionState::New as u8 { - self.internal - .update_connection_state(ConnectionState::Checking) - .await; - } - - Ok(()) - } - - /// Initiates the trickle based gathering process. - pub fn gather_candidates(&self) -> Result<()> { - if self.gathering_state.load(Ordering::SeqCst) != GatheringState::New as u8 { - return Err(Error::ErrMultipleGatherAttempted); - } - - if self.internal.on_candidate_hdlr.load().is_none() { - return Err(Error::ErrNoOnCandidateHandler); - } - - if let Some(gather_candidate_cancel) = &self.gather_candidate_cancel { - gather_candidate_cancel(); // Cancel previous gathering routine - } - - //TODO: a.gatherCandidateCancel = cancel - - let params = GatherCandidatesInternalParams { - udp_network: self.udp_network.clone(), - candidate_types: self.candidate_types.clone(), - urls: self.urls.clone(), - network_types: self.network_types.clone(), - mdns_mode: self.mdns_mode, - mdns_name: self.mdns_name.clone(), - net: Arc::clone(&self.net), - interface_filter: self.interface_filter.clone(), - ip_filter: self.ip_filter.clone(), - ext_ip_mapper: Arc::clone(&self.ext_ip_mapper), - agent_internal: Arc::clone(&self.internal), - gathering_state: Arc::clone(&self.gathering_state), - chan_candidate_tx: Arc::clone(&self.internal.chan_candidate_tx), - }; - tokio::spawn(async move { - Self::gather_candidates_internal(params).await; - }); - - Ok(()) - } - - /// Returns a list of candidate pair stats. - pub async fn get_candidate_pairs_stats(&self) -> Vec { - self.internal.get_candidate_pairs_stats().await - } - - /// Returns a list of local candidates stats. - pub async fn get_local_candidates_stats(&self) -> Vec { - self.internal.get_local_candidates_stats().await - } - - /// Returns a list of remote candidates stats. 
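The `* 8` bit-length checks at the top of `restart` above amount to byte minimums: a ufrag of at least 3 bytes (24 bits) and a password of at least 16 bytes (128 bits), matching the `ErrLocalUfragInsufficientBits` / `ErrLocalPwdInsufficientBits` errors declared in this crate. A self-contained restatement for illustration only (the `check_credentials` helper is hypothetical):

```rust
// Illustration of the ufrag/pwd length rule enforced by Agent::restart above.
fn check_credentials(ufrag: &str, pwd: &str) -> Result<(), &'static str> {
    if ufrag.len() * 8 < 24 {
        return Err("local username fragment is less than 24 bits long"); // needs >= 3 bytes
    }
    if pwd.len() * 8 < 128 {
        return Err("local password is less than 128 bits long"); // needs >= 16 bytes
    }
    Ok(())
}

fn main() {
    assert!(check_credentials("ab", "0123456789abcdef").is_err()); // ufrag too short
    assert!(check_credentials("abc", "0123456789abcde").is_err()); // pwd too short (15 bytes)
    assert!(check_credentials("abc", "0123456789abcdef").is_ok());
}
```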
- pub async fn get_remote_candidates_stats(&self) -> Vec { - self.internal.get_remote_candidates_stats().await - } - - async fn resolve_and_add_multicast_candidate( - mdns_conn: Arc, - c: Arc, - ) -> Result> { - //TODO: hook up _close_query_signal_tx to Agent or Candidate's Close signal? - let (_close_query_signal_tx, close_query_signal_rx) = mpsc::channel(1); - let src = match mdns_conn.query(&c.address(), close_query_signal_rx).await { - Ok((_, src)) => src, - Err(err) => { - log::warn!("Failed to discover mDNS candidate {}: {}", c.address(), err); - return Err(err.into()); - } - }; - - c.set_ip(&src.ip())?; - - Ok(c) - } - - async fn close_multicast_conn(mdns_conn: &Option>) { - if let Some(conn) = mdns_conn { - if let Err(err) = conn.close().await { - log::warn!("failed to close mDNS Conn: {}", err); - } - } - } -} diff --git a/ice/src/candidate/candidate_base.rs b/ice/src/candidate/candidate_base.rs deleted file mode 100644 index 08125672c..000000000 --- a/ice/src/candidate/candidate_base.rs +++ /dev/null @@ -1,525 +0,0 @@ -use std::fmt; -use std::ops::Add; -use std::sync::atomic::Ordering; -use std::sync::Arc; -use std::time::{Duration, SystemTime, UNIX_EPOCH}; - -use async_trait::async_trait; -use crc::{Crc, CRC_32_ISCSI}; -use portable_atomic::{AtomicU16, AtomicU64, AtomicU8}; -use tokio::sync::{broadcast, Mutex}; -use util::sync::Mutex as SyncMutex; - -use super::*; -use crate::candidate::candidate_host::CandidateHostConfig; -use crate::candidate::candidate_peer_reflexive::CandidatePeerReflexiveConfig; -use crate::candidate::candidate_relay::CandidateRelayConfig; -use crate::candidate::candidate_server_reflexive::CandidateServerReflexiveConfig; -use crate::error::*; -use crate::util::*; - -#[derive(Default)] -pub struct CandidateBaseConfig { - pub candidate_id: String, - pub network: String, - pub address: String, - pub port: u16, - pub component: u16, - pub priority: u32, - pub foundation: String, - pub conn: Option>, - pub initialized_ch: Option>, -} - -pub struct CandidateBase { - pub(crate) id: String, - pub(crate) network_type: AtomicU8, - pub(crate) candidate_type: CandidateType, - - pub(crate) component: AtomicU16, - pub(crate) address: String, - pub(crate) port: u16, - pub(crate) related_address: Option, - pub(crate) tcp_type: TcpType, - - pub(crate) resolved_addr: SyncMutex, - - pub(crate) last_sent: AtomicU64, - pub(crate) last_received: AtomicU64, - - pub(crate) conn: Option>, - pub(crate) closed_ch: Arc>>>, - - pub(crate) foundation_override: String, - pub(crate) priority_override: u32, - - //CandidateHost - pub(crate) network: String, - //CandidateRelay - pub(crate) relay_client: Option>, -} - -impl Default for CandidateBase { - fn default() -> Self { - Self { - id: String::new(), - network_type: AtomicU8::new(0), - candidate_type: CandidateType::default(), - - component: AtomicU16::new(0), - address: String::new(), - port: 0, - related_address: None, - tcp_type: TcpType::default(), - - resolved_addr: SyncMutex::new(SocketAddr::new(IpAddr::from([0, 0, 0, 0]), 0)), - - last_sent: AtomicU64::new(0), - last_received: AtomicU64::new(0), - - conn: None, - closed_ch: Arc::new(Mutex::new(None)), - - foundation_override: String::new(), - priority_override: 0, - network: String::new(), - relay_client: None, - } - } -} - -// String makes the candidateBase printable -impl fmt::Display for CandidateBase { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if let Some(related_address) = self.related_address() { - write!( - f, - "{} {} {}:{}{}", - self.network_type(), - 
self.candidate_type(), - self.address(), - self.port(), - related_address, - ) - } else { - write!( - f, - "{} {} {}:{}", - self.network_type(), - self.candidate_type(), - self.address(), - self.port(), - ) - } - } -} - -#[async_trait] -impl Candidate for CandidateBase { - fn foundation(&self) -> String { - if !self.foundation_override.is_empty() { - return self.foundation_override.clone(); - } - - let mut buf = vec![]; - buf.extend_from_slice(self.candidate_type().to_string().as_bytes()); - buf.extend_from_slice(self.address.as_bytes()); - buf.extend_from_slice(self.network_type().to_string().as_bytes()); - - let checksum = Crc::::new(&CRC_32_ISCSI).checksum(&buf); - - format!("{checksum}") - } - - /// Returns Candidate ID. - fn id(&self) -> String { - self.id.clone() - } - - /// Returns candidate component. - fn component(&self) -> u16 { - self.component.load(Ordering::SeqCst) - } - - fn set_component(&self, component: u16) { - self.component.store(component, Ordering::SeqCst); - } - - /// Returns a time indicating the last time this candidate was received. - fn last_received(&self) -> SystemTime { - UNIX_EPOCH.add(Duration::from_nanos( - self.last_received.load(Ordering::SeqCst), - )) - } - - /// Returns a time indicating the last time this candidate was sent. - fn last_sent(&self) -> SystemTime { - UNIX_EPOCH.add(Duration::from_nanos(self.last_sent.load(Ordering::SeqCst))) - } - - /// Returns candidate NetworkType. - fn network_type(&self) -> NetworkType { - NetworkType::from(self.network_type.load(Ordering::SeqCst)) - } - - /// Returns Candidate Address. - fn address(&self) -> String { - self.address.clone() - } - - /// Returns Candidate Port. - fn port(&self) -> u16 { - self.port - } - - /// Computes the priority for this ICE Candidate. - fn priority(&self) -> u32 { - if self.priority_override != 0 { - return self.priority_override; - } - - // The local preference MUST be an integer from 0 (lowest preference) to - // 65535 (highest preference) inclusive. When there is only a single IP - // address, this value SHOULD be set to 65535. If there are multiple - // candidates for a particular component for a particular data stream - // that have the same type, the local preference MUST be unique for each - // one. - (1 << 24) * u32::from(self.candidate_type().preference()) - + (1 << 8) * u32::from(self.local_preference()) - + (256 - u32::from(self.component())) - } - - /// Returns `Option`. - fn related_address(&self) -> Option { - self.related_address.as_ref().cloned() - } - - /// Returns candidate type. - fn candidate_type(&self) -> CandidateType { - self.candidate_type - } - - fn tcp_type(&self) -> TcpType { - self.tcp_type - } - - /// Returns the string representation of the ICECandidate. - fn marshal(&self) -> String { - let mut val = format!( - "{} {} {} {} {} {} typ {}", - self.foundation(), - self.component(), - self.network_type().network_short(), - self.priority(), - self.address(), - self.port(), - self.candidate_type() - ); - - if self.tcp_type != TcpType::Unspecified { - val += format!(" tcptype {}", self.tcp_type()).as_str(); - } - - if let Some(related_address) = self.related_address() { - val += format!( - " raddr {} rport {}", - related_address.address, related_address.port, - ) - .as_str(); - } - - val - } - - fn addr(&self) -> SocketAddr { - *self.resolved_addr.lock() - } - - /// Stops the recvLoop. 
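To make the priority computation above concrete, here is a standalone restatement of the formula with numbers that appear elsewhere in this diff: a UDP host candidate (type preference 126, default local preference 65535, RTP component 1) works out to 2130706431, and a relayed candidate (type preference 0) to 16777215, the values the deleted candidate tests assert. The `candidate_priority` helper is hypothetical.

```rust
// Standalone restatement of the ICE candidate priority formula used above:
// priority = (2^24) * type_pref + (2^8) * local_pref + (256 - component)
fn candidate_priority(type_pref: u32, local_pref: u32, component: u32) -> u32 {
    (1 << 24) * type_pref + (1 << 8) * local_pref + (256 - component)
}

fn main() {
    // UDP host candidate, component 1 (RTP):
    assert_eq!(candidate_priority(126, 65535, 1), 2_130_706_431);
    // Relayed candidate, component 1:
    assert_eq!(candidate_priority(0, 65535, 1), 16_777_215);
}
```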
- async fn close(&self) -> Result<()> { - { - let mut closed_ch = self.closed_ch.lock().await; - if closed_ch.is_none() { - return Err(Error::ErrClosed); - } - closed_ch.take(); - } - - if let Some(relay_client) = &self.relay_client { - let _ = relay_client.close().await; - } - - if let Some(conn) = &self.conn { - let _ = conn.close().await; - } - - Ok(()) - } - - fn seen(&self, outbound: bool) { - let d = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap_or_else(|_| Duration::from_secs(0)); - - if outbound { - self.set_last_sent(d); - } else { - self.set_last_received(d); - } - } - - async fn write_to(&self, raw: &[u8], dst: &(dyn Candidate + Send + Sync)) -> Result { - let n = if let Some(conn) = &self.conn { - let addr = dst.addr(); - conn.send_to(raw, addr).await? - } else { - 0 - }; - self.seen(true); - Ok(n) - } - - /// Used to compare two candidateBases. - fn equal(&self, other: &dyn Candidate) -> bool { - self.network_type() == other.network_type() - && self.candidate_type() == other.candidate_type() - && self.address() == other.address() - && self.port() == other.port() - && self.tcp_type() == other.tcp_type() - && self.related_address() == other.related_address() - } - - fn set_ip(&self, ip: &IpAddr) -> Result<()> { - let network_type = determine_network_type(&self.network, ip)?; - - self.network_type - .store(network_type as u8, Ordering::SeqCst); - - let addr = create_addr(network_type, *ip, self.port); - *self.resolved_addr.lock() = addr; - - Ok(()) - } - - fn get_conn(&self) -> Option<&Arc> { - self.conn.as_ref() - } - - fn get_closed_ch(&self) -> Arc>>> { - self.closed_ch.clone() - } -} - -impl CandidateBase { - pub fn set_last_received(&self, d: Duration) { - #[allow(clippy::cast_possible_truncation)] - self.last_received - .store(d.as_nanos() as u64, Ordering::SeqCst); - } - - pub fn set_last_sent(&self, d: Duration) { - #[allow(clippy::cast_possible_truncation)] - self.last_sent.store(d.as_nanos() as u64, Ordering::SeqCst); - } - - /// Returns the local preference for this candidate. - pub fn local_preference(&self) -> u16 { - if self.network_type().is_tcp() { - // RFC 6544, section 4.2 - // - // In Section 4.1.2.1 of [RFC5245], a recommended formula for UDP ICE - // candidate prioritization is defined. For TCP candidates, the same - // formula and candidate type preferences SHOULD be used, and the - // RECOMMENDED type preferences for the new candidate types defined in - // this document (see Section 5) are 105 for NAT-assisted candidates and - // 75 for UDP-tunneled candidates. - // - // (...) - // - // With TCP candidates, the local preference part of the recommended - // priority formula is updated to also include the directionality - // (active, passive, or simultaneous-open) of the TCP connection. The - // RECOMMENDED local preference is then defined as: - // - // local preference = (2^13) * direction-pref + other-pref - // - // The direction-pref MUST be between 0 and 7 (both inclusive), with 7 - // being the most preferred. The other-pref MUST be between 0 and 8191 - // (both inclusive), with 8191 being the most preferred. It is - // RECOMMENDED that the host, UDP-tunneled, and relayed TCP candidates - // have the direction-pref assigned as follows: 6 for active, 4 for - // passive, and 2 for S-O. For the NAT-assisted and server reflexive - // candidates, the RECOMMENDED values are: 6 for S-O, 4 for active, and - // 2 for passive. - // - // (...) 
- // - // If any two candidates have the same type-preference and direction- - // pref, they MUST have a unique other-pref. With this specification, - // this usually only happens with multi-homed hosts, in which case - // other-pref is the preference for the particular IP address from which - // the candidate was obtained. When there is only a single IP address, - // this value SHOULD be set to the maximum allowed value (8191). - let other_pref: u16 = 8191; - - let direction_pref: u16 = match self.candidate_type() { - CandidateType::Host | CandidateType::Relay => match self.tcp_type() { - TcpType::Active => 6, - TcpType::Passive => 4, - TcpType::SimultaneousOpen => 2, - TcpType::Unspecified => 0, - }, - CandidateType::PeerReflexive | CandidateType::ServerReflexive => { - match self.tcp_type() { - TcpType::SimultaneousOpen => 6, - TcpType::Active => 4, - TcpType::Passive => 2, - TcpType::Unspecified => 0, - } - } - CandidateType::Unspecified => 0, - }; - - (1 << 13) * direction_pref + other_pref - } else { - DEFAULT_LOCAL_PREFERENCE - } - } -} - -/// Creates a Candidate from its string representation. -pub fn unmarshal_candidate(raw: &str) -> Result { - let split: Vec<&str> = raw.split_whitespace().collect(); - if split.len() < 8 { - return Err(Error::Other(format!( - "{:?} ({})", - Error::ErrAttributeTooShortIceCandidate, - split.len() - ))); - } - - // Foundation - let foundation = split[0].to_owned(); - - // Component - let component: u16 = split[1].parse()?; - - // Network - let network = split[2].to_owned(); - - // Priority - let priority: u32 = split[3].parse()?; - - // Address - let address = split[4].to_owned(); - - // Port - let port: u16 = split[5].parse()?; - - let typ = split[7]; - - let mut rel_addr = String::new(); - let mut rel_port = 0; - let mut tcp_type = TcpType::Unspecified; - - if split.len() > 8 { - let split2 = &split[8..]; - - if split2[0] == "raddr" { - if split2.len() < 4 { - return Err(Error::Other(format!( - "{:?}: incorrect length", - Error::ErrParseRelatedAddr - ))); - } - - // RelatedAddress - split2[1].clone_into(&mut rel_addr); - - // RelatedPort - rel_port = split2[3].parse()?; - } else if split2[0] == "tcptype" { - if split2.len() < 2 { - return Err(Error::Other(format!( - "{:?}: incorrect length", - Error::ErrParseType - ))); - } - - tcp_type = TcpType::from(split2[1]); - } - } - - match typ { - "host" => { - let config = CandidateHostConfig { - base_config: CandidateBaseConfig { - network, - address, - port, - component, - priority, - foundation, - ..CandidateBaseConfig::default() - }, - tcp_type, - }; - config.new_candidate_host() - } - "srflx" => { - let config = CandidateServerReflexiveConfig { - base_config: CandidateBaseConfig { - network, - address, - port, - component, - priority, - foundation, - ..CandidateBaseConfig::default() - }, - rel_addr, - rel_port, - }; - config.new_candidate_server_reflexive() - } - "prflx" => { - let config = CandidatePeerReflexiveConfig { - base_config: CandidateBaseConfig { - network, - address, - port, - component, - priority, - foundation, - ..CandidateBaseConfig::default() - }, - rel_addr, - rel_port, - }; - - config.new_candidate_peer_reflexive() - } - "relay" => { - let config = CandidateRelayConfig { - base_config: CandidateBaseConfig { - network, - address, - port, - component, - priority, - foundation, - ..CandidateBaseConfig::default() - }, - rel_addr, - rel_port, - ..CandidateRelayConfig::default() - }; - config.new_candidate_relay() - } - _ => Err(Error::Other(format!( - "{:?} ({})", - 
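Tying the RFC 6544 rule quoted above back to the overall priority: an active TCP host candidate gets direction-pref 6 and other-pref 8191, i.e. a local preference of 57343, which the priority formula turns into 2128609279, the value the deleted priority tests expect. A standalone restatement (helper names are hypothetical):

```rust
// Standalone restatement of the TCP local-preference rule quoted above (RFC 6544):
// local preference = (2^13) * direction-pref + other-pref
fn tcp_local_preference(direction_pref: u32, other_pref: u32) -> u32 {
    (1 << 13) * direction_pref + other_pref
}

fn candidate_priority(type_pref: u32, local_pref: u32, component: u32) -> u32 {
    (1 << 24) * type_pref + (1 << 8) * local_pref + (256 - component)
}

fn main() {
    // Host candidate over TCP with tcptype "active": direction-pref 6, other-pref 8191.
    let local_pref = tcp_local_preference(6, 8191);
    assert_eq!(local_pref, 57_343);
    // Plugged into the overall priority formula (type preference 126, component 1):
    assert_eq!(candidate_priority(126, local_pref, 1), 2_128_609_279);
}
```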
Error::ErrUnknownCandidateType, - typ - ))), - } -} diff --git a/ice/src/candidate/candidate_host.rs b/ice/src/candidate/candidate_host.rs deleted file mode 100644 index 6cc2441ba..000000000 --- a/ice/src/candidate/candidate_host.rs +++ /dev/null @@ -1,45 +0,0 @@ -use portable_atomic::{AtomicU16, AtomicU8}; - -use super::candidate_base::*; -use super::*; -use crate::rand::generate_cand_id; - -/// The config required to create a new `CandidateHost`. -#[derive(Default)] -pub struct CandidateHostConfig { - pub base_config: CandidateBaseConfig, - - pub tcp_type: TcpType, -} - -impl CandidateHostConfig { - /// Creates a new host candidate. - pub fn new_candidate_host(self) -> Result { - let mut candidate_id = self.base_config.candidate_id; - if candidate_id.is_empty() { - candidate_id = generate_cand_id(); - } - - let c = CandidateBase { - id: candidate_id, - address: self.base_config.address.clone(), - candidate_type: CandidateType::Host, - component: AtomicU16::new(self.base_config.component), - port: self.base_config.port, - tcp_type: self.tcp_type, - foundation_override: self.base_config.foundation, - priority_override: self.base_config.priority, - network: self.base_config.network, - network_type: AtomicU8::new(NetworkType::Udp4 as u8), - conn: self.base_config.conn, - ..CandidateBase::default() - }; - - if !self.base_config.address.ends_with(".local") { - let ip = self.base_config.address.parse()?; - c.set_ip(&ip)?; - }; - - Ok(c) - } -} diff --git a/ice/src/candidate/candidate_pair_test.rs b/ice/src/candidate/candidate_pair_test.rs deleted file mode 100644 index 7b2765a82..000000000 --- a/ice/src/candidate/candidate_pair_test.rs +++ /dev/null @@ -1,155 +0,0 @@ -use super::*; -use crate::candidate::candidate_host::CandidateHostConfig; -use crate::candidate::candidate_peer_reflexive::CandidatePeerReflexiveConfig; -use crate::candidate::candidate_relay::CandidateRelayConfig; -use crate::candidate::candidate_server_reflexive::CandidateServerReflexiveConfig; - -pub(crate) fn host_candidate() -> Result { - CandidateHostConfig { - base_config: CandidateBaseConfig { - network: "udp".to_owned(), - address: "0.0.0.0".to_owned(), - component: COMPONENT_RTP, - ..Default::default() - }, - ..Default::default() - } - .new_candidate_host() -} - -pub(crate) fn prflx_candidate() -> Result { - CandidatePeerReflexiveConfig { - base_config: CandidateBaseConfig { - network: "udp".to_owned(), - address: "0.0.0.0".to_owned(), - component: COMPONENT_RTP, - ..Default::default() - }, - ..Default::default() - } - .new_candidate_peer_reflexive() -} - -pub(crate) fn srflx_candidate() -> Result { - CandidateServerReflexiveConfig { - base_config: CandidateBaseConfig { - network: "udp".to_owned(), - address: "0.0.0.0".to_owned(), - component: COMPONENT_RTP, - ..Default::default() - }, - ..Default::default() - } - .new_candidate_server_reflexive() -} - -pub(crate) fn relay_candidate() -> Result { - CandidateRelayConfig { - base_config: CandidateBaseConfig { - network: "udp".to_owned(), - address: "0.0.0.0".to_owned(), - component: COMPONENT_RTP, - ..Default::default() - }, - ..Default::default() - } - .new_candidate_relay() -} - -#[test] -fn test_candidate_pair_priority() -> Result<()> { - let tests = vec![ - ( - CandidatePair::new( - Arc::new(host_candidate()?), - Arc::new(host_candidate()?), - false, - ), - 9151314440652587007, - ), - ( - CandidatePair::new( - Arc::new(host_candidate()?), - Arc::new(host_candidate()?), - true, - ), - 9151314440652587007, - ), - ( - CandidatePair::new( - Arc::new(host_candidate()?), - 
Arc::new(prflx_candidate()?), - true, - ), - 7998392936314175488, - ), - ( - CandidatePair::new( - Arc::new(host_candidate()?), - Arc::new(prflx_candidate()?), - false, - ), - 7998392936314175487, - ), - ( - CandidatePair::new( - Arc::new(host_candidate()?), - Arc::new(srflx_candidate()?), - true, - ), - 7277816996102668288, - ), - ( - CandidatePair::new( - Arc::new(host_candidate()?), - Arc::new(srflx_candidate()?), - false, - ), - 7277816996102668287, - ), - ( - CandidatePair::new( - Arc::new(host_candidate()?), - Arc::new(relay_candidate()?), - true, - ), - 72057593987596288, - ), - ( - CandidatePair::new( - Arc::new(host_candidate()?), - Arc::new(relay_candidate()?), - false, - ), - 72057593987596287, - ), - ]; - - for (pair, want) in tests { - let got = pair.priority(); - assert_eq!( - got, want, - "CandidatePair({pair}).Priority() = {got}, want {want}" - ); - } - - Ok(()) -} - -#[test] -fn test_candidate_pair_equality() -> Result<()> { - let pair_a = CandidatePair::new( - Arc::new(host_candidate()?), - Arc::new(srflx_candidate()?), - true, - ); - let pair_b = CandidatePair::new( - Arc::new(host_candidate()?), - Arc::new(srflx_candidate()?), - false, - ); - - assert_eq!(pair_a, pair_b, "Expected {pair_a} to equal {pair_b}"); - - Ok(()) -} diff --git a/ice/src/candidate/candidate_peer_reflexive.rs b/ice/src/candidate/candidate_peer_reflexive.rs deleted file mode 100644 index dbb0b7d27..000000000 --- a/ice/src/candidate/candidate_peer_reflexive.rs +++ /dev/null @@ -1,54 +0,0 @@ -use portable_atomic::{AtomicU16, AtomicU8}; - -use util::sync::Mutex as SyncMutex; - -use super::candidate_base::*; -use super::*; -use crate::error::*; -use crate::rand::generate_cand_id; -use crate::util::*; - -/// The config required to create a new `CandidatePeerReflexive`. -#[derive(Default)] -pub struct CandidatePeerReflexiveConfig { - pub base_config: CandidateBaseConfig, - - pub rel_addr: String, - pub rel_port: u16, -} - -impl CandidatePeerReflexiveConfig { - /// Creates a new peer reflective candidate. 
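For completeness, a hedged sketch of driving the config above: a peer-reflexive candidate is typically built from a transport address observed on the wire, with the base it was observed from carried as the related address. The `prflx_from_observed` helper and the `ice::...` paths are assumptions; the struct fields follow the definitions in this diff, and the address must parse as an IP or `new_candidate_peer_reflexive` returns `ErrAddressParseFailed`.

```rust
// Sketch only: assumes the candidate config structs shown in this diff.
use ice::candidate::candidate_base::{CandidateBase, CandidateBaseConfig};
use ice::candidate::candidate_peer_reflexive::CandidatePeerReflexiveConfig;
use ice::error::Result;

fn prflx_from_observed(
    observed_ip: &str,
    observed_port: u16,
    base_ip: &str,
    base_port: u16,
) -> Result<CandidateBase> {
    CandidatePeerReflexiveConfig {
        base_config: CandidateBaseConfig {
            network: "udp".to_owned(),
            address: observed_ip.to_owned(), // must parse as an IP address
            port: observed_port,
            component: 1,
            ..Default::default()
        },
        rel_addr: base_ip.to_owned(), // related address: the base this was observed from
        rel_port: base_port,
    }
    .new_candidate_peer_reflexive()
}
```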
- pub fn new_candidate_peer_reflexive(self) -> Result { - let ip: IpAddr = match self.base_config.address.parse() { - Ok(ip) => ip, - Err(_) => return Err(Error::ErrAddressParseFailed), - }; - let network_type = determine_network_type(&self.base_config.network, &ip)?; - - let mut candidate_id = self.base_config.candidate_id; - if candidate_id.is_empty() { - candidate_id = generate_cand_id(); - } - - let c = CandidateBase { - id: candidate_id, - network_type: AtomicU8::new(network_type as u8), - candidate_type: CandidateType::PeerReflexive, - address: self.base_config.address, - port: self.base_config.port, - resolved_addr: SyncMutex::new(create_addr(network_type, ip, self.base_config.port)), - component: AtomicU16::new(self.base_config.component), - foundation_override: self.base_config.foundation, - priority_override: self.base_config.priority, - related_address: Some(CandidateRelatedAddress { - address: self.rel_addr, - port: self.rel_port, - }), - conn: self.base_config.conn, - ..CandidateBase::default() - }; - - Ok(c) - } -} diff --git a/ice/src/candidate/candidate_relay.rs b/ice/src/candidate/candidate_relay.rs deleted file mode 100644 index 5a4548412..000000000 --- a/ice/src/candidate/candidate_relay.rs +++ /dev/null @@ -1,57 +0,0 @@ -use portable_atomic::{AtomicU16, AtomicU8}; -use std::sync::Arc; - -use util::sync::Mutex as SyncMutex; - -use super::candidate_base::*; -use super::*; -use crate::error::*; -use crate::rand::generate_cand_id; -use crate::util::*; - -/// The config required to create a new `CandidateRelay`. -#[derive(Default)] -pub struct CandidateRelayConfig { - pub base_config: CandidateBaseConfig, - - pub rel_addr: String, - pub rel_port: u16, - pub relay_client: Option>, -} - -impl CandidateRelayConfig { - /// Creates a new relay candidate. 
- pub fn new_candidate_relay(self) -> Result { - let mut candidate_id = self.base_config.candidate_id; - if candidate_id.is_empty() { - candidate_id = generate_cand_id(); - } - - let ip: IpAddr = match self.base_config.address.parse() { - Ok(ip) => ip, - Err(_) => return Err(Error::ErrAddressParseFailed), - }; - let network_type = determine_network_type(&self.base_config.network, &ip)?; - - let c = CandidateBase { - id: candidate_id, - network_type: AtomicU8::new(network_type as u8), - candidate_type: CandidateType::Relay, - address: self.base_config.address, - port: self.base_config.port, - resolved_addr: SyncMutex::new(create_addr(network_type, ip, self.base_config.port)), - component: AtomicU16::new(self.base_config.component), - foundation_override: self.base_config.foundation, - priority_override: self.base_config.priority, - related_address: Some(CandidateRelatedAddress { - address: self.rel_addr, - port: self.rel_port, - }), - conn: self.base_config.conn, - relay_client: self.relay_client.clone(), - ..CandidateBase::default() - }; - - Ok(c) - } -} diff --git a/ice/src/candidate/candidate_relay_test.rs b/ice/src/candidate/candidate_relay_test.rs deleted file mode 100644 index c1fd4bfd0..000000000 --- a/ice/src/candidate/candidate_relay_test.rs +++ /dev/null @@ -1,114 +0,0 @@ -use std::result::Result; -use std::time::Duration; - -use tokio::net::UdpSocket; -use turn::auth::AuthHandler; - -use super::*; -use crate::agent::agent_config::AgentConfig; -use crate::agent::agent_vnet_test::{connect_with_vnet, on_connected}; -use crate::agent::Agent; -use crate::error::Error; -use crate::url::{ProtoType, SchemeType, Url}; - -pub(crate) struct OptimisticAuthHandler; - -impl AuthHandler for OptimisticAuthHandler { - fn auth_handle( - &self, - _username: &str, - _realm: &str, - _src_addr: SocketAddr, - ) -> Result, turn::Error> { - Ok(turn::auth::generate_auth_key( - "username", - "webrtc.rs", - "password", - )) - } -} - -//use std::io::Write; - -#[tokio::test] -async fn test_relay_only_connection() -> Result<(), Error> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - let server_listener = Arc::new(UdpSocket::bind("127.0.0.1:0").await?); - let server_port = server_listener.local_addr()?.port(); - - let server = turn::server::Server::new(turn::server::config::ServerConfig { - realm: "webrtc.rs".to_owned(), - auth_handler: Arc::new(OptimisticAuthHandler {}), - conn_configs: vec![turn::server::config::ConnConfig { - conn: server_listener, - relay_addr_generator: Box::new(turn::relay::relay_none::RelayAddressGeneratorNone { - address: "127.0.0.1".to_owned(), - net: Arc::new(util::vnet::net::Net::new(None)), - }), - }], - channel_bind_timeout: Duration::from_secs(0), - alloc_close_notify: None, - }) - .await?; - - let cfg0 = AgentConfig { - network_types: supported_network_types(), - urls: vec![Url { - scheme: SchemeType::Turn, - host: "127.0.0.1".to_owned(), - username: "username".to_owned(), - password: "password".to_owned(), - port: server_port, - proto: ProtoType::Udp, - }], - candidate_types: vec![CandidateType::Relay], - ..Default::default() - }; - - let a_agent = Arc::new(Agent::new(cfg0).await?); - let (a_notifier, mut a_connected) = on_connected(); - a_agent.on_connection_state_change(a_notifier); - - let cfg1 = AgentConfig { - 
network_types: supported_network_types(), - urls: vec![Url { - scheme: SchemeType::Turn, - host: "127.0.0.1".to_owned(), - username: "username".to_owned(), - password: "password".to_owned(), - port: server_port, - proto: ProtoType::Udp, - }], - candidate_types: vec![CandidateType::Relay], - ..Default::default() - }; - - let b_agent = Arc::new(Agent::new(cfg1).await?); - let (b_notifier, mut b_connected) = on_connected(); - b_agent.on_connection_state_change(b_notifier); - - connect_with_vnet(&a_agent, &b_agent).await?; - - let _ = a_connected.recv().await; - let _ = b_connected.recv().await; - - a_agent.close().await?; - b_agent.close().await?; - server.close().await?; - - Ok(()) -} diff --git a/ice/src/candidate/candidate_server_reflexive.rs b/ice/src/candidate/candidate_server_reflexive.rs deleted file mode 100644 index c8e9133bb..000000000 --- a/ice/src/candidate/candidate_server_reflexive.rs +++ /dev/null @@ -1,54 +0,0 @@ -use portable_atomic::{AtomicU16, AtomicU8}; - -use util::sync::Mutex as SyncMutex; - -use super::candidate_base::*; -use super::*; -use crate::error::*; -use crate::rand::generate_cand_id; -use crate::util::*; - -/// The config required to create a new `CandidateServerReflexive`. -#[derive(Default)] -pub struct CandidateServerReflexiveConfig { - pub base_config: CandidateBaseConfig, - - pub rel_addr: String, - pub rel_port: u16, -} - -impl CandidateServerReflexiveConfig { - /// Creates a new server reflective candidate. - pub fn new_candidate_server_reflexive(self) -> Result { - let ip: IpAddr = match self.base_config.address.parse() { - Ok(ip) => ip, - Err(_) => return Err(Error::ErrAddressParseFailed), - }; - let network_type = determine_network_type(&self.base_config.network, &ip)?; - - let mut candidate_id = self.base_config.candidate_id; - if candidate_id.is_empty() { - candidate_id = generate_cand_id(); - } - - let c = CandidateBase { - id: candidate_id, - network_type: AtomicU8::new(network_type as u8), - candidate_type: CandidateType::ServerReflexive, - address: self.base_config.address, - port: self.base_config.port, - resolved_addr: SyncMutex::new(create_addr(network_type, ip, self.base_config.port)), - component: AtomicU16::new(self.base_config.component), - foundation_override: self.base_config.foundation, - priority_override: self.base_config.priority, - related_address: Some(CandidateRelatedAddress { - address: self.rel_addr, - port: self.rel_port, - }), - conn: self.base_config.conn, - ..CandidateBase::default() - }; - - Ok(c) - } -} diff --git a/ice/src/candidate/candidate_server_reflexive_test.rs b/ice/src/candidate/candidate_server_reflexive_test.rs deleted file mode 100644 index ca40de1a7..000000000 --- a/ice/src/candidate/candidate_server_reflexive_test.rs +++ /dev/null @@ -1,91 +0,0 @@ -use std::time::Duration; - -use tokio::net::UdpSocket; - -use super::candidate_relay_test::OptimisticAuthHandler; -use super::*; -use crate::agent::agent_config::AgentConfig; -use crate::agent::agent_vnet_test::{connect_with_vnet, on_connected}; -use crate::agent::Agent; -use crate::url::{SchemeType, Url}; - -//use std::io::Write; - -#[tokio::test] -async fn test_server_reflexive_only_connection() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - let server_listener = 
Arc::new(UdpSocket::bind("127.0.0.1:0").await?); - let server_port = server_listener.local_addr()?.port(); - - let server = turn::server::Server::new(turn::server::config::ServerConfig { - realm: "webrtc.rs".to_owned(), - auth_handler: Arc::new(OptimisticAuthHandler {}), - conn_configs: vec![turn::server::config::ConnConfig { - conn: server_listener, - relay_addr_generator: Box::new(turn::relay::relay_none::RelayAddressGeneratorNone { - address: "127.0.0.1".to_owned(), - net: Arc::new(util::vnet::net::Net::new(None)), - }), - }], - channel_bind_timeout: Duration::from_secs(0), - alloc_close_notify: None, - }) - .await?; - - let cfg0 = AgentConfig { - network_types: vec![NetworkType::Udp4], - urls: vec![Url { - scheme: SchemeType::Stun, - host: "127.0.0.1".to_owned(), - port: server_port, - ..Default::default() - }], - candidate_types: vec![CandidateType::ServerReflexive], - ..Default::default() - }; - - let a_agent = Arc::new(Agent::new(cfg0).await?); - let (a_notifier, mut a_connected) = on_connected(); - a_agent.on_connection_state_change(a_notifier); - - let cfg1 = AgentConfig { - network_types: vec![NetworkType::Udp4], - urls: vec![Url { - scheme: SchemeType::Stun, - host: "127.0.0.1".to_owned(), - port: server_port, - ..Default::default() - }], - candidate_types: vec![CandidateType::ServerReflexive], - ..Default::default() - }; - - let b_agent = Arc::new(Agent::new(cfg1).await?); - let (b_notifier, mut b_connected) = on_connected(); - b_agent.on_connection_state_change(b_notifier); - - connect_with_vnet(&a_agent, &b_agent).await?; - - let _ = a_connected.recv().await; - let _ = b_connected.recv().await; - - a_agent.close().await?; - b_agent.close().await?; - server.close().await?; - - Ok(()) -} diff --git a/ice/src/candidate/candidate_test.rs b/ice/src/candidate/candidate_test.rs deleted file mode 100644 index b9f2928c7..000000000 --- a/ice/src/candidate/candidate_test.rs +++ /dev/null @@ -1,411 +0,0 @@ -use std::time::UNIX_EPOCH; - -use super::*; - -#[test] -fn test_candidate_priority() -> Result<()> { - let tests = vec![ - ( - CandidateBase { - candidate_type: CandidateType::Host, - component: AtomicU16::new(COMPONENT_RTP), - ..Default::default() - }, - 2130706431, - ), - ( - CandidateBase { - candidate_type: CandidateType::Host, - component: AtomicU16::new(COMPONENT_RTP), - network_type: AtomicU8::new(NetworkType::Tcp4 as u8), - tcp_type: TcpType::Active, - ..Default::default() - }, - 2128609279, - ), - ( - CandidateBase { - candidate_type: CandidateType::Host, - component: AtomicU16::new(COMPONENT_RTP), - network_type: AtomicU8::new(NetworkType::Tcp4 as u8), - tcp_type: TcpType::Passive, - ..Default::default() - }, - 2124414975, - ), - ( - CandidateBase { - candidate_type: CandidateType::Host, - component: AtomicU16::new(COMPONENT_RTP), - network_type: AtomicU8::new(NetworkType::Tcp4 as u8), - tcp_type: TcpType::SimultaneousOpen, - ..Default::default() - }, - 2120220671, - ), - ( - CandidateBase { - candidate_type: CandidateType::PeerReflexive, - component: AtomicU16::new(COMPONENT_RTP), - ..Default::default() - }, - 1862270975, - ), - ( - CandidateBase { - candidate_type: CandidateType::PeerReflexive, - component: AtomicU16::new(COMPONENT_RTP), - network_type: AtomicU8::new(NetworkType::Tcp6 as u8), - tcp_type: TcpType::SimultaneousOpen, - ..Default::default() - }, - 1860173823, - ), - ( - CandidateBase { - candidate_type: CandidateType::PeerReflexive, - component: AtomicU16::new(COMPONENT_RTP), - network_type: AtomicU8::new(NetworkType::Tcp6 as u8), - tcp_type: TcpType::Active, 
- ..Default::default() - }, - 1855979519, - ), - ( - CandidateBase { - candidate_type: CandidateType::PeerReflexive, - component: AtomicU16::new(COMPONENT_RTP), - network_type: AtomicU8::new(NetworkType::Tcp6 as u8), - tcp_type: TcpType::Passive, - ..Default::default() - }, - 1851785215, - ), - ( - CandidateBase { - candidate_type: CandidateType::ServerReflexive, - component: AtomicU16::new(COMPONENT_RTP), - ..Default::default() - }, - 1694498815, - ), - ( - CandidateBase { - candidate_type: CandidateType::Relay, - component: AtomicU16::new(COMPONENT_RTP), - ..Default::default() - }, - 16777215, - ), - ]; - - for (candidate, want) in tests { - let got = candidate.priority(); - assert_eq!( - got, want, - "Candidate({candidate}).Priority() = {got}, want {want}" - ); - } - - Ok(()) -} - -#[test] -fn test_candidate_last_sent() -> Result<()> { - let candidate = CandidateBase::default(); - assert_eq!(candidate.last_sent(), UNIX_EPOCH); - - let now = SystemTime::now(); - let d = now.duration_since(UNIX_EPOCH)?; - candidate.set_last_sent(d); - assert_eq!(candidate.last_sent(), now); - - Ok(()) -} - -#[test] -fn test_candidate_last_received() -> Result<()> { - let candidate = CandidateBase::default(); - assert_eq!(candidate.last_received(), UNIX_EPOCH); - - let now = SystemTime::now(); - let d = now.duration_since(UNIX_EPOCH)?; - candidate.set_last_received(d); - assert_eq!(candidate.last_received(), now); - - Ok(()) -} - -#[test] -fn test_candidate_foundation() -> Result<()> { - // All fields are the same - assert_eq!( - (CandidateBase { - candidate_type: CandidateType::Host, - network_type: AtomicU8::new(NetworkType::Udp4 as u8), - address: "A".to_owned(), - ..Default::default() - }) - .foundation(), - (CandidateBase { - candidate_type: CandidateType::Host, - network_type: AtomicU8::new(NetworkType::Udp4 as u8), - address: "A".to_owned(), - ..Default::default() - }) - .foundation() - ); - - // Different Address - assert_ne!( - (CandidateBase { - candidate_type: CandidateType::Host, - network_type: AtomicU8::new(NetworkType::Udp4 as u8), - address: "A".to_owned(), - ..Default::default() - }) - .foundation(), - (CandidateBase { - candidate_type: CandidateType::Host, - network_type: AtomicU8::new(NetworkType::Udp4 as u8), - address: "B".to_owned(), - ..Default::default() - }) - .foundation(), - ); - - // Different networkType - assert_ne!( - (CandidateBase { - candidate_type: CandidateType::Host, - network_type: AtomicU8::new(NetworkType::Udp4 as u8), - address: "A".to_owned(), - ..Default::default() - }) - .foundation(), - (CandidateBase { - candidate_type: CandidateType::Host, - network_type: AtomicU8::new(NetworkType::Udp6 as u8), - address: "A".to_owned(), - ..Default::default() - }) - .foundation(), - ); - - // Different candidateType - assert_ne!( - (CandidateBase { - candidate_type: CandidateType::Host, - network_type: AtomicU8::new(NetworkType::Udp4 as u8), - address: "A".to_owned(), - ..Default::default() - }) - .foundation(), - (CandidateBase { - candidate_type: CandidateType::PeerReflexive, - network_type: AtomicU8::new(NetworkType::Udp4 as u8), - address: "A".to_owned(), - ..Default::default() - }) - .foundation(), - ); - - // Port has no effect - assert_eq!( - (CandidateBase { - candidate_type: CandidateType::Host, - network_type: AtomicU8::new(NetworkType::Udp4 as u8), - address: "A".to_owned(), - port: 8080, - ..Default::default() - }) - .foundation(), - (CandidateBase { - candidate_type: CandidateType::Host, - network_type: AtomicU8::new(NetworkType::Udp4 as u8), - address: 
"A".to_owned(), - port: 80, - ..Default::default() - }) - .foundation() - ); - - Ok(()) -} - -#[test] -fn test_candidate_pair_state_serialization() { - let tests = vec![ - (CandidatePairState::Unspecified, "\"unspecified\""), - (CandidatePairState::Waiting, "\"waiting\""), - (CandidatePairState::InProgress, "\"in-progress\""), - (CandidatePairState::Failed, "\"failed\""), - (CandidatePairState::Succeeded, "\"succeeded\""), - ]; - - for (candidate_pair_state, expected_string) in tests { - assert_eq!( - expected_string.to_string(), - serde_json::to_string(&candidate_pair_state).unwrap() - ); - } -} - -#[test] -fn test_candidate_pair_state_to_string() { - let tests = vec![ - (CandidatePairState::Unspecified, "unspecified"), - (CandidatePairState::Waiting, "waiting"), - (CandidatePairState::InProgress, "in-progress"), - (CandidatePairState::Failed, "failed"), - (CandidatePairState::Succeeded, "succeeded"), - ]; - - for (candidate_pair_state, expected_string) in tests { - assert_eq!(candidate_pair_state.to_string(), expected_string); - } -} - -#[test] -fn test_candidate_type_serialization() { - let tests = vec![ - (CandidateType::Unspecified, "\"unspecified\""), - (CandidateType::Host, "\"host\""), - (CandidateType::ServerReflexive, "\"srflx\""), - (CandidateType::PeerReflexive, "\"prflx\""), - (CandidateType::Relay, "\"relay\""), - ]; - - for (candidate_type, expected_string) in tests { - assert_eq!( - serde_json::to_string(&candidate_type).unwrap(), - expected_string.to_string() - ); - } -} - -#[test] -fn test_candidate_type_to_string() { - let tests = vec![ - (CandidateType::Unspecified, "Unknown candidate type"), - (CandidateType::Host, "host"), - (CandidateType::ServerReflexive, "srflx"), - (CandidateType::PeerReflexive, "prflx"), - (CandidateType::Relay, "relay"), - ]; - - for (candidate_type, expected_string) in tests { - assert_eq!(candidate_type.to_string(), expected_string); - } -} - -#[test] -fn test_candidate_marshal() -> Result<()> { - let tests = vec![ - ( - Some(CandidateBase{ - network_type: AtomicU8::new(NetworkType::Udp6 as u8), - candidate_type: CandidateType::Host, - address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a".to_owned(), - port: 53987, - priority_override: 500, - foundation_override: "750".to_owned(), - ..Default::default() - }), - "750 1 udp 500 fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a 53987 typ host", - ), - ( - Some(CandidateBase{ - network_type: AtomicU8::new(NetworkType::Udp4 as u8), - candidate_type: CandidateType::Host, - address: "10.0.75.1".to_owned(), - port: 53634, - ..Default::default() - }), - "4273957277 1 udp 2130706431 10.0.75.1 53634 typ host", - ), - ( - Some(CandidateBase{ - network_type: AtomicU8::new(NetworkType::Udp4 as u8), - candidate_type: CandidateType::ServerReflexive, - address: "191.228.238.68".to_owned(), - port: 53991, - related_address: Some(CandidateRelatedAddress{ - address: "192.168.0.274".to_owned(), - port:53991 - }), - ..Default::default() - }), - "647372371 1 udp 1694498815 191.228.238.68 53991 typ srflx raddr 192.168.0.274 rport 53991", - ), - ( - Some(CandidateBase{ - network_type: AtomicU8::new(NetworkType::Udp4 as u8), - candidate_type: CandidateType::Relay, - address: "50.0.0.1".to_owned(), - port: 5000, - related_address: Some( - CandidateRelatedAddress{ - address: "192.168.0.1".to_owned(), - port:5001} - ), - ..Default::default() - }), - "848194626 1 udp 16777215 50.0.0.1 5000 typ relay raddr 192.168.0.1 rport 5001", - ), - ( - Some(CandidateBase{ - network_type: AtomicU8::new(NetworkType::Tcp4 as u8), - candidate_type: 
CandidateType::Host, - address: "192.168.0.196".to_owned(), - port: 0, - tcp_type: TcpType::Active, - ..Default::default() - }), - "1052353102 1 tcp 2128609279 192.168.0.196 0 typ host tcptype active", - ), - ( - Some(CandidateBase{ - network_type: AtomicU8::new(NetworkType::Udp4 as u8), - candidate_type: CandidateType::Host, - address: "e2494022-4d9a-4c1e-a750-cc48d4f8d6ee.local".to_owned(), - port: 60542, - ..Default::default() - }), - "1380287402 1 udp 2130706431 e2494022-4d9a-4c1e-a750-cc48d4f8d6ee.local 60542 typ host", - ), - // Invalid candidates - (None, ""), - (None, "1938809241"), - (None, "1986380506 99999999 udp 2122063615 10.0.75.1 53634 typ host generation 0 network-id 2"), - (None, "1986380506 1 udp 99999999999 10.0.75.1 53634 typ host"), - (None, "4207374051 1 udp 1685790463 191.228.238.68 99999999 typ srflx raddr 192.168.0.278 rport 53991 generation 0 network-id 3"), - (None, "4207374051 1 udp 1685790463 191.228.238.68 53991 typ srflx raddr"), - (None, "4207374051 1 udp 1685790463 191.228.238.68 53991 typ srflx raddr 192.168.0.278 rport 99999999 generation 0 network-id 3"), - (None, "4207374051 INVALID udp 2130706431 10.0.75.1 53634 typ host"), - (None, "4207374051 1 udp INVALID 10.0.75.1 53634 typ host"), - (None, "4207374051 INVALID udp 2130706431 10.0.75.1 INVALID typ host"), - (None, "4207374051 1 udp 2130706431 10.0.75.1 53634 typ INVALID"), - ]; - - for (candidate, marshaled) in tests { - let actual_candidate = unmarshal_candidate(marshaled); - if let Some(candidate) = candidate { - if let Ok(actual_candidate) = actual_candidate { - assert!( - candidate.equal(&actual_candidate), - "{} vs {}", - candidate.marshal(), - marshaled - ); - assert_eq!(marshaled, actual_candidate.marshal()); - } else { - panic!("expected ok"); - } - } else { - assert!(actual_candidate.is_err(), "expected error"); - } - } - - Ok(()) -} diff --git a/ice/src/candidate/mod.rs b/ice/src/candidate/mod.rs deleted file mode 100644 index d764d44c7..000000000 --- a/ice/src/candidate/mod.rs +++ /dev/null @@ -1,325 +0,0 @@ -#[cfg(test)] -mod candidate_pair_test; -#[cfg(test)] -mod candidate_relay_test; -#[cfg(test)] -mod candidate_server_reflexive_test; -#[cfg(test)] -mod candidate_test; - -pub mod candidate_base; -pub mod candidate_host; -pub mod candidate_peer_reflexive; -pub mod candidate_relay; -pub mod candidate_server_reflexive; - -use std::fmt; -use std::net::{IpAddr, SocketAddr}; -use std::sync::atomic::Ordering; -use std::sync::Arc; -use std::time::SystemTime; - -use async_trait::async_trait; -use candidate_base::*; -use portable_atomic::{AtomicBool, AtomicU16, AtomicU8}; -use serde::{Deserialize, Serialize}; -use tokio::sync::{broadcast, Mutex}; - -use crate::error::Result; -use crate::network_type::*; -use crate::tcp_type::*; - -pub(crate) const RECEIVE_MTU: usize = 8192; -pub(crate) const DEFAULT_LOCAL_PREFERENCE: u16 = 65535; - -/// Indicates that the candidate is used for RTP. -pub(crate) const COMPONENT_RTP: u16 = 1; -/// Indicates that the candidate is used for RTCP. -pub(crate) const COMPONENT_RTCP: u16 = 0; - -/// Candidate represents an ICE candidate -#[async_trait] -pub trait Candidate: fmt::Display { - /// An arbitrary string used in the freezing algorithm to - /// group similar candidates. It is the same for two candidates that - /// have the same type, base IP address, protocol (UDP, TCP, etc.), - /// and STUN or TURN server. - fn foundation(&self) -> String; - - /// A unique identifier for just this candidate - /// Unlike the foundation this is different for each candidate. 
- fn id(&self) -> String; - - /// A component is a piece of a data stream. - /// An example is one for RTP, and one for RTCP - fn component(&self) -> u16; - fn set_component(&self, c: u16); - - /// The last time this candidate received traffic - fn last_received(&self) -> SystemTime; - - /// The last time this candidate sent traffic - fn last_sent(&self) -> SystemTime; - - fn network_type(&self) -> NetworkType; - fn address(&self) -> String; - fn port(&self) -> u16; - - fn priority(&self) -> u32; - - /// A transport address related to candidate, - /// which is useful for diagnostics and other purposes. - fn related_address(&self) -> Option; - - fn candidate_type(&self) -> CandidateType; - fn tcp_type(&self) -> TcpType; - - fn marshal(&self) -> String; - - fn addr(&self) -> SocketAddr; - - async fn close(&self) -> Result<()>; - fn seen(&self, outbound: bool); - - async fn write_to(&self, raw: &[u8], dst: &(dyn Candidate + Send + Sync)) -> Result; - fn equal(&self, other: &dyn Candidate) -> bool; - fn set_ip(&self, ip: &IpAddr) -> Result<()>; - fn get_conn(&self) -> Option<&Arc>; - fn get_closed_ch(&self) -> Arc>>>; -} - -/// Represents the type of candidate `CandidateType` enum. -#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] -pub enum CandidateType { - #[serde(rename = "unspecified")] - Unspecified, - #[serde(rename = "host")] - Host, - #[serde(rename = "srflx")] - ServerReflexive, - #[serde(rename = "prflx")] - PeerReflexive, - #[serde(rename = "relay")] - Relay, -} - -// String makes CandidateType printable -impl fmt::Display for CandidateType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - CandidateType::Host => "host", - CandidateType::ServerReflexive => "srflx", - CandidateType::PeerReflexive => "prflx", - CandidateType::Relay => "relay", - CandidateType::Unspecified => "Unknown candidate type", - }; - write!(f, "{s}") - } -} - -impl Default for CandidateType { - fn default() -> Self { - Self::Unspecified - } -} - -impl CandidateType { - /// Returns the preference weight of a `CandidateType`. - /// - /// 4.1.2.2. Guidelines for Choosing Type and Local Preferences - /// The RECOMMENDED values are 126 for host candidates, 100 - /// for server reflexive candidates, 110 for peer reflexive candidates, - /// and 0 for relayed candidates. - #[must_use] - pub const fn preference(self) -> u16 { - match self { - Self::Host => 126, - Self::PeerReflexive => 110, - Self::ServerReflexive => 100, - Self::Relay | CandidateType::Unspecified => 0, - } - } -} - -pub(crate) fn contains_candidate_type( - candidate_type: CandidateType, - candidate_type_list: &[CandidateType], -) -> bool { - if candidate_type_list.is_empty() { - return false; - } - for ct in candidate_type_list { - if *ct == candidate_type { - return true; - } - } - false -} - -/// Convey transport addresses related to the candidate, useful for diagnostics and other purposes. -#[derive(PartialEq, Eq, Debug, Clone)] -pub struct CandidateRelatedAddress { - pub address: String, - pub port: u16, -} - -// String makes CandidateRelatedAddress printable -impl fmt::Display for CandidateRelatedAddress { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, " related {}:{}", self.address, self.port) - } -} - -/// Represent the ICE candidate pair state. -#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] -pub enum CandidatePairState { - #[serde(rename = "unspecified")] - Unspecified = 0, - - /// Means a check has not been performed for this pair. 
- #[serde(rename = "waiting")] - Waiting = 1, - - /// Means a check has been sent for this pair, but the transaction is in progress. - #[serde(rename = "in-progress")] - InProgress = 2, - - /// Means a check for this pair was already done and failed, either never producing any response - /// or producing an unrecoverable failure response. - #[serde(rename = "failed")] - Failed = 3, - - /// Means a check for this pair was already done and produced a successful result. - #[serde(rename = "succeeded")] - Succeeded = 4, -} - -impl From for CandidatePairState { - fn from(v: u8) -> Self { - match v { - 1 => Self::Waiting, - 2 => Self::InProgress, - 3 => Self::Failed, - 4 => Self::Succeeded, - _ => Self::Unspecified, - } - } -} - -impl Default for CandidatePairState { - fn default() -> Self { - Self::Unspecified - } -} - -impl fmt::Display for CandidatePairState { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - Self::Waiting => "waiting", - Self::InProgress => "in-progress", - Self::Failed => "failed", - Self::Succeeded => "succeeded", - Self::Unspecified => "unspecified", - }; - - write!(f, "{s}") - } -} - -/// Represents a combination of a local and remote candidate. -pub struct CandidatePair { - pub(crate) ice_role_controlling: AtomicBool, - pub remote: Arc, - pub local: Arc, - pub(crate) binding_request_count: AtomicU16, - pub(crate) state: AtomicU8, // convert it to CandidatePairState, - pub(crate) nominated: AtomicBool, -} - -impl Default for CandidatePair { - fn default() -> Self { - Self { - ice_role_controlling: AtomicBool::new(false), - remote: Arc::new(CandidateBase::default()), - local: Arc::new(CandidateBase::default()), - state: AtomicU8::new(CandidatePairState::Waiting as u8), - binding_request_count: AtomicU16::new(0), - nominated: AtomicBool::new(false), - } - } -} - -impl fmt::Debug for CandidatePair { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "prio {} (local, prio {}) {} <-> {} (remote, prio {})", - self.priority(), - self.local.priority(), - self.local, - self.remote, - self.remote.priority() - ) - } -} - -impl fmt::Display for CandidatePair { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "prio {} (local, prio {}) {} <-> {} (remote, prio {})", - self.priority(), - self.local.priority(), - self.local, - self.remote, - self.remote.priority() - ) - } -} - -impl PartialEq for CandidatePair { - fn eq(&self, other: &Self) -> bool { - self.local.equal(&*other.local) && self.remote.equal(&*other.remote) - } -} - -impl CandidatePair { - #[must_use] - pub fn new( - local: Arc, - remote: Arc, - controlling: bool, - ) -> Self { - Self { - ice_role_controlling: AtomicBool::new(controlling), - remote, - local, - state: AtomicU8::new(CandidatePairState::Waiting as u8), - binding_request_count: AtomicU16::new(0), - nominated: AtomicBool::new(false), - } - } - - /// RFC 5245 - 5.7.2. Computing Pair Priority and Ordering Pairs - /// Let G be the priority for the candidate provided by the controlling - /// agent. Let D be the priority for the candidate provided by the - /// controlled agent. 
- /// pair priority = 2^32*MIN(G,D) + 2*MAX(G,D) + (G>D?1:0) - pub fn priority(&self) -> u64 { - let (g, d) = if self.ice_role_controlling.load(Ordering::SeqCst) { - (self.local.priority(), self.remote.priority()) - } else { - (self.remote.priority(), self.local.priority()) - }; - - // 1<<32 overflows uint32; and if both g && d are - // maxUint32, this result would overflow uint64 - ((1 << 32_u64) - 1) * u64::from(std::cmp::min(g, d)) - + 2 * u64::from(std::cmp::max(g, d)) - + u64::from(g > d) - } - - pub async fn write(&self, b: &[u8]) -> Result { - self.local.write_to(b, &*self.remote).await - } -} diff --git a/ice/src/control/control_test.rs b/ice/src/control/control_test.rs deleted file mode 100644 index 480c04555..000000000 --- a/ice/src/control/control_test.rs +++ /dev/null @@ -1,168 +0,0 @@ -use super::*; -use crate::error::Result; - -#[test] -fn test_controlled_get_from() -> Result<()> { - let mut m = Message::new(); - let mut c = AttrControlled(4321); - let result = c.get_from(&m); - if let Err(err) = result { - assert_eq!(stun::Error::ErrAttributeNotFound, err, "unexpected error"); - } else { - panic!("expected error, but got ok"); - } - - m.build(&[Box::new(BINDING_REQUEST), Box::new(c)])?; - - let mut m1 = Message::new(); - m1.write(&m.raw)?; - - let mut c1 = AttrControlled::default(); - c1.get_from(&m1)?; - - assert_eq!(c1, c, "not equal"); - - //"IncorrectSize" - { - let mut m3 = Message::new(); - m3.add(ATTR_ICE_CONTROLLED, &[0; 100]); - let mut c2 = AttrControlled::default(); - let result = c2.get_from(&m3); - if let Err(err) = result { - assert!(is_attr_size_invalid(&err), "should error"); - } else { - panic!("expected error, but got ok"); - } - } - - Ok(()) -} - -#[test] -fn test_controlling_get_from() -> Result<()> { - let mut m = Message::new(); - let mut c = AttrControlling(4321); - let result = c.get_from(&m); - if let Err(err) = result { - assert_eq!(stun::Error::ErrAttributeNotFound, err, "unexpected error"); - } else { - panic!("expected error, but got ok"); - } - - m.build(&[Box::new(BINDING_REQUEST), Box::new(c)])?; - - let mut m1 = Message::new(); - m1.write(&m.raw)?; - - let mut c1 = AttrControlling::default(); - c1.get_from(&m1)?; - - assert_eq!(c1, c, "not equal"); - - //"IncorrectSize" - { - let mut m3 = Message::new(); - m3.add(ATTR_ICE_CONTROLLING, &[0; 100]); - let mut c2 = AttrControlling::default(); - let result = c2.get_from(&m3); - if let Err(err) = result { - assert!(is_attr_size_invalid(&err), "should error"); - } else { - panic!("expected error, but got ok"); - } - } - - Ok(()) -} - -#[test] -fn test_control_get_from() -> Result<()> { - //"Blank" - { - let m = Message::new(); - let mut c = AttrControl::default(); - let result = c.get_from(&m); - if let Err(err) = result { - assert_eq!(stun::Error::ErrAttributeNotFound, err, "unexpected error"); - } else { - panic!("expected error, but got ok"); - } - } - //"Controlling" - { - let mut m = Message::new(); - let mut c = AttrControl::default(); - let result = c.get_from(&m); - if let Err(err) = result { - assert_eq!(stun::Error::ErrAttributeNotFound, err, "unexpected error"); - } else { - panic!("expected error, but got ok"); - } - - c.role = Role::Controlling; - c.tie_breaker = TieBreaker(4321); - - m.build(&[Box::new(BINDING_REQUEST), Box::new(c)])?; - - let mut m1 = Message::new(); - m1.write(&m.raw)?; - - let mut c1 = AttrControl::default(); - c1.get_from(&m1)?; - - assert_eq!(c1, c, "not equal"); - - //"IncorrectSize" - { - let mut m3 = Message::new(); - m3.add(ATTR_ICE_CONTROLLING, &[0; 100]); - 
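The pair priority computed above can be checked by hand against the deleted candidate-pair tests: for two UDP host candidates (each priority 2130706431) G and D are equal, so (2^32 − 1)·2130706431 + 2·2130706431 + 0 = 9151314440652587007. A standalone restatement (`pair_priority` is a hypothetical helper):

```rust
// Standalone restatement of the RFC 5245 5.7.2 pair priority formula above,
// written as in the code to avoid overflowing u64 when G and D are both u32::MAX.
fn pair_priority(g: u32, d: u32) -> u64 {
    ((1u64 << 32) - 1) * u64::from(g.min(d)) + 2 * u64::from(g.max(d)) + u64::from(g > d)
}

fn main() {
    let host = 2_130_706_431u32; // UDP host candidate priority
    let relay = 16_777_215u32;   // relayed candidate priority

    // Host <-> host pair:
    assert_eq!(pair_priority(host, host), 9_151_314_440_652_587_007);
    // Controlling side holds the host candidate, the remote one is relayed:
    assert_eq!(pair_priority(host, relay), 72_057_593_987_596_288);
}
```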
let mut c2 = AttrControl::default(); - let result = c2.get_from(&m3); - if let Err(err) = result { - assert!(is_attr_size_invalid(&err), "should error"); - } else { - panic!("expected error, but got ok"); - } - } - } - - //"Controlled" - { - let mut m = Message::new(); - let mut c = AttrControl::default(); - let result = c.get_from(&m); - if let Err(err) = result { - assert_eq!(stun::Error::ErrAttributeNotFound, err, "unexpected error"); - } else { - panic!("expected error, but got ok"); - } - - c.role = Role::Controlled; - c.tie_breaker = TieBreaker(1234); - - m.build(&[Box::new(BINDING_REQUEST), Box::new(c)])?; - - let mut m1 = Message::new(); - m1.write(&m.raw)?; - - let mut c1 = AttrControl::default(); - c1.get_from(&m1)?; - - assert_eq!(c1, c, "not equal"); - - //"IncorrectSize" - { - let mut m3 = Message::new(); - m3.add(ATTR_ICE_CONTROLLING, &[0; 100]); - let mut c2 = AttrControl::default(); - let result = c2.get_from(&m3); - if let Err(err) = result { - assert!(is_attr_size_invalid(&err), "should error"); - } else { - panic!("expected error, but got ok"); - } - } - } - - Ok(()) -} diff --git a/ice/src/control/mod.rs b/ice/src/control/mod.rs deleted file mode 100644 index a79e170c3..000000000 --- a/ice/src/control/mod.rs +++ /dev/null @@ -1,143 +0,0 @@ -#[cfg(test)] -mod control_test; - -use std::fmt; - -use stun::attributes::*; -use stun::checks::*; -use stun::message::*; - -/// Common helper for ICE-{CONTROLLED,CONTROLLING} and represents the so-called Tiebreaker number. -#[derive(Default, PartialEq, Eq, Debug, Copy, Clone)] -pub struct TieBreaker(pub u64); - -pub(crate) const TIE_BREAKER_SIZE: usize = 8; // 64 bit - -impl TieBreaker { - /// Adds Tiebreaker value to m as t attribute. - pub fn add_to_as(self, m: &mut Message, t: AttrType) -> Result<(), stun::Error> { - let mut v = vec![0; TIE_BREAKER_SIZE]; - v.copy_from_slice(&self.0.to_be_bytes()); - m.add(t, &v); - Ok(()) - } - - /// Decodes Tiebreaker value in message getting it as for t type. - pub fn get_from_as(&mut self, m: &Message, t: AttrType) -> Result<(), stun::Error> { - let v = m.get(t)?; - check_size(t, v.len(), TIE_BREAKER_SIZE)?; - self.0 = u64::from_be_bytes([v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]]); - Ok(()) - } -} -/// Represents ICE-CONTROLLED attribute. -#[derive(Default, PartialEq, Eq, Debug, Copy, Clone)] -pub struct AttrControlled(pub u64); - -impl Setter for AttrControlled { - /// Adds ICE-CONTROLLED to message. - fn add_to(&self, m: &mut Message) -> Result<(), stun::Error> { - TieBreaker(self.0).add_to_as(m, ATTR_ICE_CONTROLLED) - } -} - -impl Getter for AttrControlled { - /// Decodes ICE-CONTROLLED from message. - fn get_from(&mut self, m: &Message) -> Result<(), stun::Error> { - let mut t = TieBreaker::default(); - t.get_from_as(m, ATTR_ICE_CONTROLLED)?; - self.0 = t.0; - Ok(()) - } -} - -/// Represents ICE-CONTROLLING attribute. -#[derive(Default, PartialEq, Eq, Debug, Copy, Clone)] -pub struct AttrControlling(pub u64); - -impl Setter for AttrControlling { - // add_to adds ICE-CONTROLLING to message. - fn add_to(&self, m: &mut Message) -> Result<(), stun::Error> { - TieBreaker(self.0).add_to_as(m, ATTR_ICE_CONTROLLING) - } -} - -impl Getter for AttrControlling { - // get_from decodes ICE-CONTROLLING from message. - fn get_from(&mut self, m: &Message) -> Result<(), stun::Error> { - let mut t = TieBreaker::default(); - t.get_from_as(m, ATTR_ICE_CONTROLLING)?; - self.0 = t.0; - Ok(()) - } -} - -/// Helper that wraps ICE-{CONTROLLED,CONTROLLING}. 
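// Usage sketch for the tie-breaker setters/getters above (mirrors control_test.rs):
//
//     let mut m = Message::new();
//     m.build(&[Box::new(BINDING_REQUEST), Box::new(AttrControlling(4321))])?;
//     let mut m1 = Message::new();
//     m1.write(&m.raw)?;
//     let mut decoded = AttrControlling::default();
//     decoded.get_from(&m1)?; // decoded.0 == 4321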
-#[derive(Default, PartialEq, Eq, Debug, Copy, Clone)] -pub struct AttrControl { - role: Role, - tie_breaker: TieBreaker, -} - -impl Setter for AttrControl { - // add_to adds ICE-CONTROLLED or ICE-CONTROLLING attribute depending on Role. - fn add_to(&self, m: &mut Message) -> Result<(), stun::Error> { - if self.role == Role::Controlling { - self.tie_breaker.add_to_as(m, ATTR_ICE_CONTROLLING) - } else { - self.tie_breaker.add_to_as(m, ATTR_ICE_CONTROLLED) - } - } -} - -impl Getter for AttrControl { - // get_from decodes Role and Tiebreaker value from message. - fn get_from(&mut self, m: &Message) -> Result<(), stun::Error> { - if m.contains(ATTR_ICE_CONTROLLING) { - self.role = Role::Controlling; - return self.tie_breaker.get_from_as(m, ATTR_ICE_CONTROLLING); - } - if m.contains(ATTR_ICE_CONTROLLED) { - self.role = Role::Controlled; - return self.tie_breaker.get_from_as(m, ATTR_ICE_CONTROLLED); - } - - Err(stun::Error::ErrAttributeNotFound) - } -} - -/// Represents ICE agent role, which can be controlling or controlled. -/// Possible ICE agent roles. -#[derive(PartialEq, Eq, Copy, Clone, Debug)] -pub enum Role { - Controlling, - Controlled, - Unspecified, -} - -impl Default for Role { - fn default() -> Self { - Self::Controlling - } -} - -impl From<&str> for Role { - fn from(raw: &str) -> Self { - match raw { - "controlling" => Self::Controlling, - "controlled" => Self::Controlled, - _ => Self::Unspecified, - } - } -} - -impl fmt::Display for Role { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - Self::Controlling => "controlling", - Self::Controlled => "controlled", - Self::Unspecified => "unspecified", - }; - write!(f, "{s}") - } -} diff --git a/ice/src/error.rs b/ice/src/error.rs deleted file mode 100644 index a3f6ff84e..000000000 --- a/ice/src/error.rs +++ /dev/null @@ -1,238 +0,0 @@ -use std::num::ParseIntError; -use std::time::SystemTimeError; -use std::{io, net}; - -use thiserror::Error; - -pub type Result = std::result::Result; - -#[derive(Debug, Error, PartialEq)] -#[non_exhaustive] -pub enum Error { - /// Indicates an error with Unknown info. - #[error("Unknown type")] - ErrUnknownType, - - /// Indicates the scheme type could not be parsed. - #[error("unknown scheme type")] - ErrSchemeType, - - /// Indicates query arguments are provided in a STUN URL. - #[error("queries not supported in stun address")] - ErrStunQuery, - - /// Indicates an malformed query is provided. - #[error("invalid query")] - ErrInvalidQuery, - - /// Indicates malformed hostname is provided. - #[error("invalid hostname")] - ErrHost, - - /// Indicates malformed port is provided. - #[error("invalid port number")] - ErrPort, - - /// Indicates local username fragment insufficient bits are provided. - /// Have to be at least 24 bits long. - #[error("local username fragment is less than 24 bits long")] - ErrLocalUfragInsufficientBits, - - /// Indicates local passoword insufficient bits are provided. - /// Have to be at least 128 bits long. - #[error("local password is less than 128 bits long")] - ErrLocalPwdInsufficientBits, - - /// Indicates an unsupported transport type was provided. - #[error("invalid transport protocol type")] - ErrProtoType, - - /// Indicates the agent is closed. - #[error("the agent is closed")] - ErrClosed, - - /// Indicates agent does not have a valid candidate pair. - #[error("no candidate pairs available")] - ErrNoCandidatePairs, - - /// Indicates agent connection was canceled by the caller. 
- #[error("connecting canceled by caller")] - ErrCanceledByCaller, - - /// Indicates agent was started twice. - #[error("attempted to start agent twice")] - ErrMultipleStart, - - /// Indicates agent was started with an empty remote ufrag. - #[error("remote ufrag is empty")] - ErrRemoteUfragEmpty, - - /// Indicates agent was started with an empty remote pwd. - #[error("remote pwd is empty")] - ErrRemotePwdEmpty, - - /// Indicates agent was started without on_candidate. - #[error("no on_candidate provided")] - ErrNoOnCandidateHandler, - - /// Indicates GatherCandidates has been called multiple times. - #[error("attempting to gather candidates during gathering state")] - ErrMultipleGatherAttempted, - - /// Indicates agent was give TURN URL with an empty Username. - #[error("username is empty")] - ErrUsernameEmpty, - - /// Indicates agent was give TURN URL with an empty Password. - #[error("password is empty")] - ErrPasswordEmpty, - - /// Indicates we were unable to parse a candidate address. - #[error("failed to parse address")] - ErrAddressParseFailed, - - /// Indicates that non host candidates were selected for a lite agent. - #[error("lite agents must only use host candidates")] - ErrLiteUsingNonHostCandidates, - - /// Indicates that one or more URL was provided to the agent but no host candidate required them. - #[error("agent does not need URL with selected candidate types")] - ErrUselessUrlsProvided, - - /// Indicates that the specified NAT1To1IPCandidateType is unsupported. - #[error("unsupported 1:1 NAT IP candidate type")] - ErrUnsupportedNat1to1IpCandidateType, - - /// Indicates that the given 1:1 NAT IP mapping is invalid. - #[error("invalid 1:1 NAT IP mapping")] - ErrInvalidNat1to1IpMapping, - - /// IPNotFound in NAT1To1IPMapping. - #[error("external mapped IP not found")] - ErrExternalMappedIpNotFound, - - /// Indicates that the mDNS gathering cannot be used along with 1:1 NAT IP mapping for host - /// candidate. - #[error("mDNS gathering cannot be used with 1:1 NAT IP mapping for host candidate")] - ErrMulticastDnsWithNat1to1IpMapping, - - /// Indicates that 1:1 NAT IP mapping for host candidate is requested, but the host candidate - /// type is disabled. - #[error("1:1 NAT IP mapping for host candidate ineffective")] - ErrIneffectiveNat1to1IpMappingHost, - - /// Indicates that 1:1 NAT IP mapping for srflx candidate is requested, but the srflx candidate - /// type is disabled. - #[error("1:1 NAT IP mapping for srflx candidate ineffective")] - ErrIneffectiveNat1to1IpMappingSrflx, - - /// Indicates an invalid MulticastDNSHostName. - #[error("invalid mDNS HostName, must end with .local and can only contain a single '.'")] - ErrInvalidMulticastDnshostName, - - /// Indicates Restart was called when Agent is in GatheringStateGathering. - #[error("ICE Agent can not be restarted when gathering")] - ErrRestartWhenGathering, - - /// Indicates a run operation was canceled by its individual done. - #[error("run was canceled by done")] - ErrRunCanceled, - - /// Initialized Indicates TCPMux is not initialized and that invalidTCPMux is used. - #[error("TCPMux is not initialized")] - ErrTcpMuxNotInitialized, - - /// Indicates we already have the connection with same remote addr. 
- #[error("conn with same remote addr already exists")] - ErrTcpRemoteAddrAlreadyExists, - - #[error("failed to send packet")] - ErrSendPacket, - #[error("attribute not long enough to be ICE candidate")] - ErrAttributeTooShortIceCandidate, - #[error("could not parse component")] - ErrParseComponent, - #[error("could not parse priority")] - ErrParsePriority, - #[error("could not parse port")] - ErrParsePort, - #[error("could not parse related addresses")] - ErrParseRelatedAddr, - #[error("could not parse type")] - ErrParseType, - #[error("unknown candidate type")] - ErrUnknownCandidateType, - #[error("failed to get XOR-MAPPED-ADDRESS response")] - ErrGetXorMappedAddrResponse, - #[error("connection with same remote address already exists")] - ErrConnectionAddrAlreadyExist, - #[error("error reading streaming packet")] - ErrReadingStreamingPacket, - #[error("error writing to")] - ErrWriting, - #[error("error closing connection")] - ErrClosingConnection, - #[error("unable to determine networkType")] - ErrDetermineNetworkType, - #[error("missing protocol scheme")] - ErrMissingProtocolScheme, - #[error("too many colons in address")] - ErrTooManyColonsAddr, - #[error("unexpected error trying to read")] - ErrRead, - #[error("unknown role")] - ErrUnknownRole, - #[error("username mismatch")] - ErrMismatchUsername, - #[error("the ICE conn can't write STUN messages")] - ErrIceWriteStunMessage, - #[error("invalid url")] - ErrInvalidUrl, - #[error("relative URL without a base")] - ErrUrlParse, - #[error("Candidate IP could not be found")] - ErrCandidateIpNotFound, - - #[error("parse int: {0}")] - ParseInt(#[from] ParseIntError), - #[error("parse addr: {0}")] - ParseIp(#[from] net::AddrParseError), - #[error("{0}")] - Io(#[source] IoError), - #[error("{0}")] - Util(#[from] util::Error), - #[error("{0}")] - Stun(#[from] stun::Error), - #[error("{0}")] - ParseUrl(#[from] url::ParseError), - #[error("{0}")] - Mdns(#[from] mdns::Error), - #[error("{0}")] - Turn(#[from] turn::Error), - - #[error("{0}")] - Other(String), -} - -#[derive(Debug, Error)] -#[error("io error: {0}")] -pub struct IoError(#[from] pub io::Error); - -// Workaround for wanting PartialEq for io::Error. 
-impl PartialEq for IoError { - fn eq(&self, other: &Self) -> bool { - self.0.kind() == other.0.kind() - } -} - -impl From for Error { - fn from(e: io::Error) -> Self { - Error::Io(IoError(e)) - } -} - -impl From for Error { - fn from(e: SystemTimeError) -> Self { - Error::Other(e.to_string()) - } -} diff --git a/ice/src/external_ip_mapper/external_ip_mapper_test.rs b/ice/src/external_ip_mapper/external_ip_mapper_test.rs deleted file mode 100644 index d9f9e5d60..000000000 --- a/ice/src/external_ip_mapper/external_ip_mapper_test.rs +++ /dev/null @@ -1,251 +0,0 @@ -use super::*; - -#[test] -fn test_external_ip_mapper_validate_ip_string() -> Result<()> { - let ip = validate_ip_string("1.2.3.4")?; - assert!(ip.is_ipv4(), "should be true"); - assert_eq!("1.2.3.4", ip.to_string(), "should be true"); - - let ip = validate_ip_string("2601:4567::5678")?; - assert!(!ip.is_ipv4(), "should be false"); - assert_eq!("2601:4567::5678", ip.to_string(), "should be true"); - - let result = validate_ip_string("bad.6.6.6"); - assert!(result.is_err(), "should fail"); - - Ok(()) -} - -#[test] -fn test_external_ip_mapper_new_external_ip_mapper() -> Result<()> { - // ips being empty should succeed but mapper will still be nil - let m = ExternalIpMapper::new(CandidateType::Unspecified, &[])?; - assert!(m.is_none(), "should be none"); - - // IPv4 with no explicit local IP, defaults to CandidateTypeHost - let m = ExternalIpMapper::new(CandidateType::Unspecified, &["1.2.3.4".to_owned()])?.unwrap(); - assert_eq!(m.candidate_type, CandidateType::Host, "should match"); - assert!(m.ipv4_mapping.ip_sole.is_some()); - assert!(m.ipv6_mapping.ip_sole.is_none()); - assert_eq!(m.ipv4_mapping.ip_map.len(), 0, "should match"); - assert_eq!(m.ipv6_mapping.ip_map.len(), 0, "should match"); - - // IPv4 with no explicit local IP, using CandidateTypeServerReflexive - let m = - ExternalIpMapper::new(CandidateType::ServerReflexive, &["1.2.3.4".to_owned()])?.unwrap(); - assert_eq!( - CandidateType::ServerReflexive, - m.candidate_type, - "should match" - ); - assert!(m.ipv4_mapping.ip_sole.is_some()); - assert!(m.ipv6_mapping.ip_sole.is_none()); - assert_eq!(m.ipv4_mapping.ip_map.len(), 0, "should match"); - assert_eq!(m.ipv6_mapping.ip_map.len(), 0, "should match"); - - // IPv4 with no explicit local IP, defaults to CandidateTypeHost - let m = ExternalIpMapper::new(CandidateType::Unspecified, &["2601:4567::5678".to_owned()])? - .unwrap(); - assert_eq!(m.candidate_type, CandidateType::Host, "should match"); - assert!(m.ipv4_mapping.ip_sole.is_none()); - assert!(m.ipv6_mapping.ip_sole.is_some()); - assert_eq!(m.ipv4_mapping.ip_map.len(), 0, "should match"); - assert_eq!(m.ipv6_mapping.ip_map.len(), 0, "should match"); - - // IPv4 and IPv6 in the mix - let m = ExternalIpMapper::new( - CandidateType::Unspecified, - &["1.2.3.4".to_owned(), "2601:4567::5678".to_owned()], - )? 
- .unwrap(); - assert_eq!(m.candidate_type, CandidateType::Host, "should match"); - assert!(m.ipv4_mapping.ip_sole.is_some()); - assert!(m.ipv6_mapping.ip_sole.is_some()); - assert_eq!(m.ipv4_mapping.ip_map.len(), 0, "should match"); - assert_eq!(m.ipv6_mapping.ip_map.len(), 0, "should match"); - - // Unsupported candidate type - CandidateTypePeerReflexive - let result = ExternalIpMapper::new(CandidateType::PeerReflexive, &["1.2.3.4".to_owned()]); - assert!(result.is_err(), "should fail"); - - // Unsupported candidate type - CandidateTypeRelay - let result = ExternalIpMapper::new(CandidateType::PeerReflexive, &["1.2.3.4".to_owned()]); - assert!(result.is_err(), "should fail"); - - // Cannot duplicate mapping IPv4 family - let result = ExternalIpMapper::new( - CandidateType::ServerReflexive, - &["1.2.3.4".to_owned(), "5.6.7.8".to_owned()], - ); - assert!(result.is_err(), "should fail"); - - // Cannot duplicate mapping IPv6 family - let result = ExternalIpMapper::new( - CandidateType::ServerReflexive, - &["2201::1".to_owned(), "2201::0002".to_owned()], - ); - assert!(result.is_err(), "should fail"); - - // Invalid external IP string - let result = ExternalIpMapper::new(CandidateType::ServerReflexive, &["bad.2.3.4".to_owned()]); - assert!(result.is_err(), "should fail"); - - // Invalid local IP string - let result = ExternalIpMapper::new( - CandidateType::ServerReflexive, - &["1.2.3.4/10.0.0.bad".to_owned()], - ); - assert!(result.is_err(), "should fail"); - - Ok(()) -} - -#[test] -fn test_external_ip_mapper_new_external_ip_mapper_with_explicit_local_ip() -> Result<()> { - // IPv4 with explicit local IP, defaults to CandidateTypeHost - let m = ExternalIpMapper::new(CandidateType::Unspecified, &["1.2.3.4/10.0.0.1".to_owned()])? - .unwrap(); - assert_eq!(m.candidate_type, CandidateType::Host, "should match"); - assert!(m.ipv4_mapping.ip_sole.is_none()); - assert!(m.ipv6_mapping.ip_sole.is_none()); - assert_eq!(m.ipv4_mapping.ip_map.len(), 1, "should match"); - assert_eq!(m.ipv6_mapping.ip_map.len(), 0, "should match"); - - // Cannot assign two ext IPs for one local IPv4 - let result = ExternalIpMapper::new( - CandidateType::Unspecified, - &["1.2.3.4/10.0.0.1".to_owned(), "1.2.3.5/10.0.0.1".to_owned()], - ); - assert!(result.is_err(), "should fail"); - - // Cannot assign two ext IPs for one local IPv6 - let result = ExternalIpMapper::new( - CandidateType::Unspecified, - &[ - "2200::1/fe80::1".to_owned(), - "2200::0002/fe80::1".to_owned(), - ], - ); - assert!(result.is_err(), "should fail"); - - // Cannot mix different IP family in a pair (1) - let result = - ExternalIpMapper::new(CandidateType::Unspecified, &["2200::1/10.0.0.1".to_owned()]); - assert!(result.is_err(), "should fail"); - - // Cannot mix different IP family in a pair (2) - let result = ExternalIpMapper::new(CandidateType::Unspecified, &["1.2.3.4/fe80::1".to_owned()]); - assert!(result.is_err(), "should fail"); - - // Invalid pair - let result = ExternalIpMapper::new( - CandidateType::Unspecified, - &["1.2.3.4/192.168.0.2/10.0.0.1".to_owned()], - ); - assert!(result.is_err(), "should fail"); - - Ok(()) -} - -#[test] -fn test_external_ip_mapper_new_external_ip_mapper_with_implicit_local_ip() -> Result<()> { - // Mixing inpicit and explicit local IPs not allowed - let result = ExternalIpMapper::new( - CandidateType::Unspecified, - &["1.2.3.4".to_owned(), "1.2.3.5/10.0.0.1".to_owned()], - ); - assert!(result.is_err(), "should fail"); - - // Mixing inpicit and explicit local IPs not allowed - let result = ExternalIpMapper::new( - 
CandidateType::Unspecified, - &["1.2.3.5/10.0.0.1".to_owned(), "1.2.3.4".to_owned()], - ); - assert!(result.is_err(), "should fail"); - - Ok(()) -} - -#[test] -fn test_external_ip_mapper_find_external_ip_without_explicit_local_ip() -> Result<()> { - // IPv4 with explicit local IP, defaults to CandidateTypeHost - let m = ExternalIpMapper::new( - CandidateType::Unspecified, - &["1.2.3.4".to_owned(), "2200::1".to_owned()], - )? - .unwrap(); - assert!(m.ipv4_mapping.ip_sole.is_some()); - assert!(m.ipv6_mapping.ip_sole.is_some()); - - // find external IPv4 - let ext_ip = m.find_external_ip("10.0.0.1")?; - assert_eq!(ext_ip.to_string(), "1.2.3.4", "should match"); - - // find external IPv6 - let ext_ip = m.find_external_ip("fe80::0001")?; // use '0001' instead of '1' on purpose - assert_eq!(ext_ip.to_string(), "2200::1", "should match"); - - // Bad local IP string - let result = m.find_external_ip("really.bad"); - assert!(result.is_err(), "should fail"); - - Ok(()) -} - -#[test] -fn test_external_ip_mapper_find_external_ip_with_explicit_local_ip() -> Result<()> { - // IPv4 with explicit local IP, defaults to CandidateTypeHost - let m = ExternalIpMapper::new( - CandidateType::Unspecified, - &[ - "1.2.3.4/10.0.0.1".to_owned(), - "1.2.3.5/10.0.0.2".to_owned(), - "2200::1/fe80::1".to_owned(), - "2200::2/fe80::2".to_owned(), - ], - )? - .unwrap(); - - // find external IPv4 - let ext_ip = m.find_external_ip("10.0.0.1")?; - assert_eq!(ext_ip.to_string(), "1.2.3.4", "should match"); - - let ext_ip = m.find_external_ip("10.0.0.2")?; - assert_eq!(ext_ip.to_string(), "1.2.3.5", "should match"); - - let result = m.find_external_ip("10.0.0.3"); - assert!(result.is_err(), "should fail"); - - // find external IPv6 - let ext_ip = m.find_external_ip("fe80::0001")?; // use '0001' instead of '1' on purpose - assert_eq!(ext_ip.to_string(), "2200::1", "should match"); - - let ext_ip = m.find_external_ip("fe80::0002")?; // use '0002' instead of '2' on purpose - assert_eq!(ext_ip.to_string(), "2200::2", "should match"); - - let result = m.find_external_ip("fe80::3"); - assert!(result.is_err(), "should fail"); - - // Bad local IP string - let result = m.find_external_ip("really.bad"); - assert!(result.is_err(), "should fail"); - - Ok(()) -} - -#[test] -fn test_external_ip_mapper_find_external_ip_with_empty_map() -> Result<()> { - let m = ExternalIpMapper::new(CandidateType::Unspecified, &["1.2.3.4".to_owned()])?.unwrap(); - - // attempt to find IPv6 that does not exist in the map - let result = m.find_external_ip("fe80::1"); - assert!(result.is_err(), "should fail"); - - let m = ExternalIpMapper::new(CandidateType::Unspecified, &["2200::1".to_owned()])?.unwrap(); - - // attempt to find IPv4 that does not exist in the map - let result = m.find_external_ip("10.0.0.1"); - assert!(result.is_err(), "should fail"); - - Ok(()) -} diff --git a/ice/src/external_ip_mapper/mod.rs b/ice/src/external_ip_mapper/mod.rs deleted file mode 100644 index 0d968b83f..000000000 --- a/ice/src/external_ip_mapper/mod.rs +++ /dev/null @@ -1,133 +0,0 @@ -#[cfg(test)] -mod external_ip_mapper_test; - -use std::collections::HashMap; -use std::net::IpAddr; - -use crate::candidate::*; -use crate::error::*; - -pub(crate) fn validate_ip_string(ip_str: &str) -> Result { - match ip_str.parse() { - Ok(ip) => Ok(ip), - Err(_) => Err(Error::ErrInvalidNat1to1IpMapping), - } -} - -/// Holds the mapping of local and external IP address for a particular IP family. 
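// Mapping strings accepted by ExternalIpMapper::new below (values taken from the
// tests above):
//
//     ["1.2.3.4"]              sole external IPv4 applied to every local IPv4
//     ["1.2.3.4/10.0.0.1"]     external 1.2.3.4 paired with local 10.0.0.1 only
//     ["1.2.3.4", "2200::1"]   one sole mapping per address family
//
// Mixing address families in one entry ("2200::1/10.0.0.1") or mixing sole and
// per-local entries in the same list fails with Error::ErrInvalidNat1to1IpMapping.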
-#[derive(Default, PartialEq, Debug)] -pub(crate) struct IpMapping { - ip_sole: Option, // when non-nil, this is the sole external IP for one local IP assumed - ip_map: HashMap, // local-to-external IP mapping (k: local, v: external) -} - -impl IpMapping { - pub(crate) fn set_sole_ip(&mut self, ip: IpAddr) -> Result<()> { - if self.ip_sole.is_some() || !self.ip_map.is_empty() { - return Err(Error::ErrInvalidNat1to1IpMapping); - } - - self.ip_sole = Some(ip); - - Ok(()) - } - - pub(crate) fn add_ip_mapping(&mut self, loc_ip: IpAddr, ext_ip: IpAddr) -> Result<()> { - if self.ip_sole.is_some() { - return Err(Error::ErrInvalidNat1to1IpMapping); - } - - let loc_ip_str = loc_ip.to_string(); - - // check if dup of local IP - if self.ip_map.contains_key(&loc_ip_str) { - return Err(Error::ErrInvalidNat1to1IpMapping); - } - - self.ip_map.insert(loc_ip_str, ext_ip); - - Ok(()) - } - - pub(crate) fn find_external_ip(&self, loc_ip: IpAddr) -> Result { - if let Some(ip_sole) = &self.ip_sole { - return Ok(*ip_sole); - } - - self.ip_map.get(&loc_ip.to_string()).map_or_else( - || Err(Error::ErrExternalMappedIpNotFound), - |ext_ip| Ok(*ext_ip), - ) - } -} - -#[derive(Default)] -pub(crate) struct ExternalIpMapper { - pub(crate) ipv4_mapping: IpMapping, - pub(crate) ipv6_mapping: IpMapping, - pub(crate) candidate_type: CandidateType, -} - -impl ExternalIpMapper { - pub(crate) fn new(mut candidate_type: CandidateType, ips: &[String]) -> Result> { - if ips.is_empty() { - return Ok(None); - } - if candidate_type == CandidateType::Unspecified { - candidate_type = CandidateType::Host; // defaults to host - } else if candidate_type != CandidateType::Host - && candidate_type != CandidateType::ServerReflexive - { - return Err(Error::ErrUnsupportedNat1to1IpCandidateType); - } - - let mut m = Self { - ipv4_mapping: IpMapping::default(), - ipv6_mapping: IpMapping::default(), - candidate_type, - }; - - for ext_ip_str in ips { - let ip_pair: Vec<&str> = ext_ip_str.split('/').collect(); - if ip_pair.is_empty() || ip_pair.len() > 2 { - return Err(Error::ErrInvalidNat1to1IpMapping); - } - - let ext_ip = validate_ip_string(ip_pair[0])?; - if ip_pair.len() == 1 { - if ext_ip.is_ipv4() { - m.ipv4_mapping.set_sole_ip(ext_ip)?; - } else { - m.ipv6_mapping.set_sole_ip(ext_ip)?; - } - } else { - let loc_ip = validate_ip_string(ip_pair[1])?; - if ext_ip.is_ipv4() { - if !loc_ip.is_ipv4() { - return Err(Error::ErrInvalidNat1to1IpMapping); - } - - m.ipv4_mapping.add_ip_mapping(loc_ip, ext_ip)?; - } else { - if loc_ip.is_ipv4() { - return Err(Error::ErrInvalidNat1to1IpMapping); - } - - m.ipv6_mapping.add_ip_mapping(loc_ip, ext_ip)?; - } - } - } - - Ok(Some(m)) - } - - pub(crate) fn find_external_ip(&self, local_ip_str: &str) -> Result { - let loc_ip = validate_ip_string(local_ip_str)?; - - if loc_ip.is_ipv4() { - self.ipv4_mapping.find_external_ip(loc_ip) - } else { - self.ipv6_mapping.find_external_ip(loc_ip) - } - } -} diff --git a/ice/src/lib.rs b/ice/src/lib.rs deleted file mode 100644 index b2e0af1a2..000000000 --- a/ice/src/lib.rs +++ /dev/null @@ -1,22 +0,0 @@ -#![warn(rust_2018_idioms)] -#![allow(dead_code)] - -pub mod agent; -pub mod candidate; -pub mod control; -mod error; -pub mod external_ip_mapper; -pub mod mdns; -pub mod network_type; -pub mod priority; -pub mod rand; -pub mod state; -pub mod stats; -pub mod tcp_type; -pub mod udp_mux; -pub mod udp_network; -pub mod url; -pub mod use_candidate; -pub mod util; - -pub use error::Error; diff --git a/ice/src/mdns/mdns_test.rs b/ice/src/mdns/mdns_test.rs deleted file mode 
100644 index 604010390..000000000 --- a/ice/src/mdns/mdns_test.rs +++ /dev/null @@ -1,151 +0,0 @@ -use regex::Regex; -use tokio::sync::{mpsc, Mutex}; - -use super::*; -use crate::agent::agent_config::*; -use crate::agent::agent_vnet_test::*; -use crate::agent::*; -use crate::candidate::*; -use crate::error::Error; -use crate::network_type::*; - -#[tokio::test] -// This test is disabled on Windows for now because it gets stuck and never finishes. -// This does not seem to have happened due to a code change. It started happening with -// `ce55c3a066ab461c3e74f0d5ac6f1209205e79bc` but was verified as happening on -// `92cc698a3dc6da459f3bf3789fd046c2dffdf107` too. -#[cfg(not(windows))] -async fn test_multicast_dns_only_connection() -> Result<()> { - let cfg0 = AgentConfig { - network_types: vec![NetworkType::Udp4], - candidate_types: vec![CandidateType::Host], - multicast_dns_mode: MulticastDnsMode::QueryAndGather, - ..Default::default() - }; - - let a_agent = Arc::new(Agent::new(cfg0).await?); - let (a_notifier, mut a_connected) = on_connected(); - a_agent.on_connection_state_change(a_notifier); - - let cfg1 = AgentConfig { - network_types: vec![NetworkType::Udp4], - candidate_types: vec![CandidateType::Host], - multicast_dns_mode: MulticastDnsMode::QueryAndGather, - ..Default::default() - }; - - let b_agent = Arc::new(Agent::new(cfg1).await?); - let (b_notifier, mut b_connected) = on_connected(); - b_agent.on_connection_state_change(b_notifier); - - connect_with_vnet(&a_agent, &b_agent).await?; - let _ = a_connected.recv().await; - let _ = b_connected.recv().await; - - a_agent.close().await?; - b_agent.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_multicast_dns_mixed_connection() -> Result<()> { - let cfg0 = AgentConfig { - network_types: vec![NetworkType::Udp4], - candidate_types: vec![CandidateType::Host], - multicast_dns_mode: MulticastDnsMode::QueryAndGather, - ..Default::default() - }; - - let a_agent = Arc::new(Agent::new(cfg0).await?); - let (a_notifier, mut a_connected) = on_connected(); - a_agent.on_connection_state_change(a_notifier); - - let cfg1 = AgentConfig { - network_types: vec![NetworkType::Udp4], - candidate_types: vec![CandidateType::Host], - multicast_dns_mode: MulticastDnsMode::QueryOnly, - ..Default::default() - }; - - let b_agent = Arc::new(Agent::new(cfg1).await?); - let (b_notifier, mut b_connected) = on_connected(); - b_agent.on_connection_state_change(b_notifier); - - connect_with_vnet(&a_agent, &b_agent).await?; - let _ = a_connected.recv().await; - let _ = b_connected.recv().await; - - a_agent.close().await?; - b_agent.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_multicast_dns_static_host_name() -> Result<()> { - let cfg0 = AgentConfig { - network_types: vec![NetworkType::Udp4], - candidate_types: vec![CandidateType::Host], - multicast_dns_mode: MulticastDnsMode::QueryAndGather, - multicast_dns_host_name: "invalidHostName".to_owned(), - ..Default::default() - }; - if let Err(err) = Agent::new(cfg0).await { - assert_eq!(err, Error::ErrInvalidMulticastDnshostName); - } else { - panic!("expected error, but got ok"); - } - - let cfg1 = AgentConfig { - network_types: vec![NetworkType::Udp4], - candidate_types: vec![CandidateType::Host], - multicast_dns_mode: MulticastDnsMode::QueryAndGather, - multicast_dns_host_name: "validName.local".to_owned(), - ..Default::default() - }; - - let a = Agent::new(cfg1).await?; - - let (done_tx, mut done_rx) = mpsc::channel::<()>(1); - let done_tx = Arc::new(Mutex::new(Some(done_tx))); - 
a.on_candidate(Box::new( - move |c: Option>| { - let done_tx_clone = Arc::clone(&done_tx); - Box::pin(async move { - if c.is_none() { - let mut tx = done_tx_clone.lock().await; - tx.take(); - } - }) - }, - )); - - a.gather_candidates()?; - - log::debug!("wait for gathering is done..."); - let _ = done_rx.recv().await; - log::debug!("gathering is done"); - - Ok(()) -} - -#[test] -fn test_generate_multicast_dnsname() -> Result<()> { - let name = generate_multicast_dns_name(); - - let re = Regex::new( - r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-4[0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}.local+$", - ); - - if let Ok(re) = re { - assert!( - re.is_match(&name), - "mDNS name must be UUID v4 + \".local\" suffix, got {name}" - ); - } else { - panic!("expected ok, but got err"); - } - - Ok(()) -} diff --git a/ice/src/mdns/mod.rs b/ice/src/mdns/mod.rs deleted file mode 100644 index 981d58fb4..000000000 --- a/ice/src/mdns/mod.rs +++ /dev/null @@ -1,71 +0,0 @@ -#[cfg(test)] -mod mdns_test; - -use std::net::SocketAddr; -use std::str::FromStr; -use std::sync::Arc; - -use mdns::config::*; -use mdns::conn::*; -use uuid::Uuid; - -use crate::error::Result; - -/// Represents the different Multicast modes that ICE can run. -#[derive(PartialEq, Eq, Debug, Copy, Clone)] -pub enum MulticastDnsMode { - /// Means remote mDNS candidates will be discarded, and local host candidates will use IPs. - Disabled, - - /// Means remote mDNS candidates will be accepted, and local host candidates will use IPs. - QueryOnly, - - /// Means remote mDNS candidates will be accepted, and local host candidates will use mDNS. - QueryAndGather, -} - -impl Default for MulticastDnsMode { - fn default() -> Self { - Self::QueryOnly - } -} - -pub(crate) fn generate_multicast_dns_name() -> String { - // https://tools.ietf.org/id/draft-ietf-rtcweb-mdns-ice-candidates-02.html#gathering - // The unique name MUST consist of a version 4 UUID as defined in [RFC4122], followed by โ€œ.localโ€. - let u = Uuid::new_v4(); - format!("{u}.local") -} - -pub(crate) fn create_multicast_dns( - mdns_mode: MulticastDnsMode, - mdns_name: &str, - dest_addr: &str, -) -> Result>> { - let local_names = match mdns_mode { - MulticastDnsMode::QueryOnly => vec![], - MulticastDnsMode::QueryAndGather => vec![mdns_name.to_owned()], - MulticastDnsMode::Disabled => return Ok(None), - }; - - let addr = if dest_addr.is_empty() { - //TODO: why DEFAULT_DEST_ADDR doesn't work on Mac/Win? - if cfg!(target_os = "linux") { - SocketAddr::from_str(DEFAULT_DEST_ADDR)? - } else { - SocketAddr::from_str("0.0.0.0:5353")? - } - } else { - SocketAddr::from_str(dest_addr)? - }; - log::info!("mDNS is using {} as dest_addr", addr); - - let conn = DnsConn::server( - addr, - Config { - local_names, - ..Config::default() - }, - )?; - Ok(Some(Arc::new(conn))) -} diff --git a/ice/src/network_type/mod.rs b/ice/src/network_type/mod.rs deleted file mode 100644 index fcd50f9a2..000000000 --- a/ice/src/network_type/mod.rs +++ /dev/null @@ -1,148 +0,0 @@ -#[cfg(test)] -mod network_type_test; - -use std::fmt; -use std::net::IpAddr; - -use serde::{Deserialize, Serialize}; - -use crate::error::*; - -pub(crate) const UDP: &str = "udp"; -pub(crate) const TCP: &str = "tcp"; - -#[must_use] -pub fn supported_network_types() -> Vec { - vec![ - NetworkType::Udp4, - NetworkType::Udp6, - //NetworkType::TCP4, - //NetworkType::TCP6, - ] -} - -/// Represents the type of network. 
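// Example of a value produced by generate_multicast_dns_name() above (illustrative,
// the UUID is random per call): "9b2a1b2c-3d4e-4f5a-8b6c-7d8e9fa0b1c2.local".
// The version-4 UUID shape and the ".local" suffix are exactly what the regex in
// mdns_test.rs asserts.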
-#[derive(PartialEq, Debug, Copy, Clone, Eq, Hash, Serialize, Deserialize)] -pub enum NetworkType { - #[serde(rename = "unspecified")] - Unspecified, - - /// Indicates UDP over IPv4. - #[serde(rename = "udp4")] - Udp4, - - /// Indicates UDP over IPv6. - #[serde(rename = "udp6")] - Udp6, - - /// Indicates TCP over IPv4. - #[serde(rename = "tcp4")] - Tcp4, - - /// Indicates TCP over IPv6. - #[serde(rename = "tcp6")] - Tcp6, -} - -impl From for NetworkType { - fn from(v: u8) -> Self { - match v { - 1 => Self::Udp4, - 2 => Self::Udp6, - 3 => Self::Tcp4, - 4 => Self::Tcp6, - _ => Self::Unspecified, - } - } -} - -impl fmt::Display for NetworkType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - Self::Udp4 => "udp4", - Self::Udp6 => "udp6", - Self::Tcp4 => "tcp4", - Self::Tcp6 => "tcp6", - Self::Unspecified => "unspecified", - }; - write!(f, "{s}") - } -} - -impl Default for NetworkType { - fn default() -> Self { - Self::Unspecified - } -} - -impl NetworkType { - /// Returns true when network is UDP4 or UDP6. - #[must_use] - pub fn is_udp(self) -> bool { - self == Self::Udp4 || self == Self::Udp6 - } - - /// Returns true when network is TCP4 or TCP6. - #[must_use] - pub fn is_tcp(self) -> bool { - self == Self::Tcp4 || self == Self::Tcp6 - } - - /// Returns the short network description. - #[must_use] - pub fn network_short(self) -> String { - match self { - Self::Udp4 | Self::Udp6 => UDP.to_owned(), - Self::Tcp4 | Self::Tcp6 => TCP.to_owned(), - Self::Unspecified => "Unspecified".to_owned(), - } - } - - /// Returns true if the network is reliable. - #[must_use] - pub const fn is_reliable(self) -> bool { - match self { - Self::Tcp4 | Self::Tcp6 => true, - Self::Udp4 | Self::Udp6 | Self::Unspecified => false, - } - } - - /// Returns whether the network type is IPv4 or not. - #[must_use] - pub const fn is_ipv4(self) -> bool { - match self { - Self::Udp4 | Self::Tcp4 => true, - Self::Udp6 | Self::Tcp6 | Self::Unspecified => false, - } - } - - /// Returns whether the network type is IPv6 or not. - #[must_use] - pub const fn is_ipv6(self) -> bool { - match self { - Self::Udp6 | Self::Tcp6 => true, - Self::Udp4 | Self::Tcp4 | Self::Unspecified => false, - } - } -} - -/// Determines the type of network based on the short network string and an IP address. 
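// Quick reference for determine_network_type below (cases from network_type_test.rs):
//
//     determine_network_type("udp", &ipv4_addr)  -> Ok(NetworkType::Udp4)
//     determine_network_type("UDP", &ipv6_addr)  -> Ok(NetworkType::Udp6)
//     determine_network_type("junkNetwork", &ip) -> Err(Error::ErrDetermineNetworkType)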
-pub(crate) fn determine_network_type(network: &str, ip: &IpAddr) -> Result { - let ipv4 = ip.is_ipv4(); - let net = network.to_lowercase(); - if net.starts_with(UDP) { - if ipv4 { - Ok(NetworkType::Udp4) - } else { - Ok(NetworkType::Udp6) - } - } else if net.starts_with(TCP) { - if ipv4 { - Ok(NetworkType::Tcp4) - } else { - Ok(NetworkType::Tcp6) - } - } else { - Err(Error::ErrDetermineNetworkType) - } -} diff --git a/ice/src/network_type/network_type_test.rs b/ice/src/network_type/network_type_test.rs deleted file mode 100644 index fa2a91daa..000000000 --- a/ice/src/network_type/network_type_test.rs +++ /dev/null @@ -1,95 +0,0 @@ -use super::*; -use crate::error::Result; - -#[test] -fn test_network_type_parsing_success() -> Result<()> { - let ipv4: IpAddr = "192.168.0.1".parse().unwrap(); - let ipv6: IpAddr = "fe80::a3:6ff:fec4:5454".parse().unwrap(); - - let tests = vec![ - ("lowercase UDP4", "udp", ipv4, NetworkType::Udp4), - ("uppercase UDP4", "UDP", ipv4, NetworkType::Udp4), - ("lowercase UDP6", "udp", ipv6, NetworkType::Udp6), - ("uppercase UDP6", "UDP", ipv6, NetworkType::Udp6), - ]; - - for (name, in_network, in_ip, expected) in tests { - let actual = determine_network_type(in_network, &in_ip)?; - - assert_eq!( - actual, expected, - "NetworkTypeParsing: '{name}' -- input:{in_network} expected:{expected} actual:{actual}" - ); - } - - Ok(()) -} - -#[test] -fn test_network_type_parsing_failure() -> Result<()> { - let ipv6: IpAddr = "fe80::a3:6ff:fec4:5454".parse().unwrap(); - - let tests = vec![("invalid network", "junkNetwork", ipv6)]; - for (name, in_network, in_ip) in tests { - let result = determine_network_type(in_network, &in_ip); - assert!( - result.is_err(), - "NetworkTypeParsing should fail: '{name}' -- input:{in_network}", - ); - } - - Ok(()) -} - -#[test] -fn test_network_type_is_udp() -> Result<()> { - assert!(NetworkType::Udp4.is_udp()); - assert!(NetworkType::Udp6.is_udp()); - assert!(!NetworkType::Udp4.is_tcp()); - assert!(!NetworkType::Udp6.is_tcp()); - - Ok(()) -} - -#[test] -fn test_network_type_is_tcp() -> Result<()> { - assert!(NetworkType::Tcp4.is_tcp()); - assert!(NetworkType::Tcp6.is_tcp()); - assert!(!NetworkType::Tcp4.is_udp()); - assert!(!NetworkType::Tcp6.is_udp()); - - Ok(()) -} - -#[test] -fn test_network_type_serialization() { - let tests = vec![ - (NetworkType::Tcp4, "\"tcp4\""), - (NetworkType::Tcp6, "\"tcp6\""), - (NetworkType::Udp4, "\"udp4\""), - (NetworkType::Udp6, "\"udp6\""), - (NetworkType::Unspecified, "\"unspecified\""), - ]; - - for (network_type, expected_string) in tests { - assert_eq!( - expected_string.to_string(), - serde_json::to_string(&network_type).unwrap() - ); - } -} - -#[test] -fn test_network_type_to_string() { - let tests = vec![ - (NetworkType::Tcp4, "tcp4"), - (NetworkType::Tcp6, "tcp6"), - (NetworkType::Udp4, "udp4"), - (NetworkType::Udp6, "udp6"), - (NetworkType::Unspecified, "unspecified"), - ]; - - for (network_type, expected_string) in tests { - assert_eq!(network_type.to_string(), expected_string); - } -} diff --git a/ice/src/priority/mod.rs b/ice/src/priority/mod.rs deleted file mode 100644 index 8a00c81e9..000000000 --- a/ice/src/priority/mod.rs +++ /dev/null @@ -1,36 +0,0 @@ -#[cfg(test)] -mod priority_test; - -use stun::attributes::ATTR_PRIORITY; -use stun::checks::*; -use stun::message::*; - -/// Represents PRIORITY attribute. 
-#[derive(Default, PartialEq, Eq, Debug, Copy, Clone)] -pub struct PriorityAttr(pub u32); - -const PRIORITY_SIZE: usize = 4; // 32 bit - -impl Setter for PriorityAttr { - // add_to adds PRIORITY attribute to message. - fn add_to(&self, m: &mut Message) -> Result<(), stun::Error> { - let mut v = vec![0_u8; PRIORITY_SIZE]; - v.copy_from_slice(&self.0.to_be_bytes()); - m.add(ATTR_PRIORITY, &v); - Ok(()) - } -} - -impl PriorityAttr { - /// Decodes PRIORITY attribute from message. - pub fn get_from(&mut self, m: &Message) -> Result<(), stun::Error> { - let v = m.get(ATTR_PRIORITY)?; - - check_size(ATTR_PRIORITY, v.len(), PRIORITY_SIZE)?; - - let p = u32::from_be_bytes([v[0], v[1], v[2], v[3]]); - self.0 = p; - - Ok(()) - } -} diff --git a/ice/src/priority/priority_test.rs b/ice/src/priority/priority_test.rs deleted file mode 100644 index 231ca7c43..000000000 --- a/ice/src/priority/priority_test.rs +++ /dev/null @@ -1,39 +0,0 @@ -use super::*; -use crate::error::Result; - -#[test] -fn test_priority_get_from() -> Result<()> { - let mut m = Message::new(); - let mut p = PriorityAttr::default(); - let result = p.get_from(&m); - if let Err(err) = result { - assert_eq!(err, stun::Error::ErrAttributeNotFound, "unexpected error"); - } else { - panic!("expected error, but got ok"); - } - - m.build(&[Box::new(BINDING_REQUEST), Box::new(p)])?; - - let mut m1 = Message::new(); - m1.write(&m.raw)?; - - let mut p1 = PriorityAttr::default(); - p1.get_from(&m1)?; - - assert_eq!(p1, p, "not equal"); - - //"IncorrectSize" - { - let mut m3 = Message::new(); - m3.add(ATTR_PRIORITY, &[0; 100]); - let mut p2 = PriorityAttr::default(); - let result = p2.get_from(&m3); - if let Err(err) = result { - assert!(is_attr_size_invalid(&err), "should error"); - } else { - panic!("expected error, but got ok"); - } - } - - Ok(()) -} diff --git a/ice/src/rand/mod.rs b/ice/src/rand/mod.rs deleted file mode 100644 index db041ef0b..000000000 --- a/ice/src/rand/mod.rs +++ /dev/null @@ -1,48 +0,0 @@ -#[cfg(test)] -mod rand_test; - -use rand::{thread_rng, Rng}; - -const RUNES_ALPHA: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; -const RUNES_CANDIDATE_ID_FOUNDATION: &[u8] = - b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789/+"; - -const LEN_UFRAG: usize = 16; -const LEN_PWD: usize = 32; - -// TODO: cryptographically strong random source -pub fn generate_crypto_random_string(n: usize, runes: &[u8]) -> String { - let mut rng = thread_rng(); - - let rand_string: String = (0..n) - .map(|_| { - let idx = rng.gen_range(0..runes.len()); - runes[idx] as char - }) - .collect(); - - rand_string -} - -/// -/// candidate-id = "candidate" ":" foundation -/// foundation = 1*32ice-char -/// ice-char = ALPHA / DIGIT / "+" / "/" -pub fn generate_cand_id() -> String { - format!( - "candidate:{}", - generate_crypto_random_string(32, RUNES_CANDIDATE_ID_FOUNDATION) - ) -} - -/// Generates ICE pwd. -/// This internally uses `generate_crypto_random_string`. -pub fn generate_pwd() -> String { - generate_crypto_random_string(LEN_PWD, RUNES_ALPHA) -} - -/// ICE user fragment. -/// This internally uses `generate_crypto_random_string`. 
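// Sanity-check sketch for the generators in this module (illustrative): generate_ufrag()
// yields LEN_UFRAG (16) alphabetic characters and generate_pwd() LEN_PWD (32),
// well above the 24-bit ufrag / 128-bit pwd minimums referenced in error.rs:
//
//     assert_eq!(generate_ufrag().len(), 16);
//     assert_eq!(generate_pwd().len(), 32);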
-pub fn generate_ufrag() -> String { - generate_crypto_random_string(LEN_UFRAG, RUNES_ALPHA) -} diff --git a/ice/src/rand/rand_test.rs b/ice/src/rand/rand_test.rs deleted file mode 100644 index bf2fdcaa1..000000000 --- a/ice/src/rand/rand_test.rs +++ /dev/null @@ -1,77 +0,0 @@ -use std::sync::Arc; - -use tokio::sync::Mutex; -use waitgroup::WaitGroup; - -use super::*; -use crate::error::Result; - -#[tokio::test] -async fn test_random_generator_collision() -> Result<()> { - let test_cases = vec![ - ( - "CandidateID", - 0, /*||-> String { - generate_cand_id() - },*/ - ), - ( - "PWD", 1, /*||-> String { - generate_pwd() - },*/ - ), - ( - "Ufrag", 2, /*|| ->String { - generate_ufrag() - },*/ - ), - ]; - - const N: usize = 10; - const ITERATION: usize = 10; - - for (name, test_case) in test_cases { - for _ in 0..ITERATION { - let rands = Arc::new(Mutex::new(vec![])); - - // Create a new wait group. - let wg = WaitGroup::new(); - - for _ in 0..N { - let w = wg.worker(); - let rs = Arc::clone(&rands); - - tokio::spawn(async move { - let _d = w; - - let s = if test_case == 0 { - generate_cand_id() - } else if test_case == 1 { - generate_pwd() - } else { - generate_ufrag() - }; - - let mut r = rs.lock().await; - r.push(s); - }); - } - wg.wait().await; - - let rs = rands.lock().await; - assert_eq!(rs.len(), N, "{name} Failed to generate randoms"); - - for i in 0..N { - for j in i + 1..N { - assert_ne!( - rs[i], rs[j], - "{}: generateRandString caused collision: {} == {}", - name, rs[i], rs[j], - ); - } - } - } - } - - Ok(()) -} diff --git a/ice/src/state/mod.rs b/ice/src/state/mod.rs deleted file mode 100644 index bf21877c9..000000000 --- a/ice/src/state/mod.rs +++ /dev/null @@ -1,112 +0,0 @@ -#[cfg(test)] -mod state_test; - -use std::fmt; - -/// An enum showing the state of a ICE Connection List of supported States. -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub enum ConnectionState { - Unspecified, - - /// ICE agent is gathering addresses. - New, - - /// ICE agent has been given local and remote candidates, and is attempting to find a match. - Checking, - - /// ICE agent has a pairing, but is still checking other pairs. - Connected, - - /// ICE agent has finished. - Completed, - - /// ICE agent never could successfully connect. - Failed, - - /// ICE agent connected successfully, but has entered a failed state. - Disconnected, - - /// ICE agent has finished and is no longer handling requests. - Closed, -} - -impl Default for ConnectionState { - fn default() -> Self { - Self::Unspecified - } -} - -impl fmt::Display for ConnectionState { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - Self::Unspecified => "Unspecified", - Self::New => "New", - Self::Checking => "Checking", - Self::Connected => "Connected", - Self::Completed => "Completed", - Self::Failed => "Failed", - Self::Disconnected => "Disconnected", - Self::Closed => "Closed", - }; - write!(f, "{s}") - } -} - -impl From for ConnectionState { - fn from(v: u8) -> Self { - match v { - 1 => Self::New, - 2 => Self::Checking, - 3 => Self::Connected, - 4 => Self::Completed, - 5 => Self::Failed, - 6 => Self::Disconnected, - 7 => Self::Closed, - _ => Self::Unspecified, - } - } -} - -/// Describes the state of the candidate gathering process. -#[derive(PartialEq, Eq, Copy, Clone)] -pub enum GatheringState { - Unspecified, - - /// Indicates candidate gathering is not yet started. - New, - - /// Indicates candidate gathering is ongoing. - Gathering, - - /// Indicates candidate gathering has been completed. 
- Complete, -} - -impl From for GatheringState { - fn from(v: u8) -> Self { - match v { - 1 => Self::New, - 2 => Self::Gathering, - 3 => Self::Complete, - _ => Self::Unspecified, - } - } -} - -impl Default for GatheringState { - fn default() -> Self { - Self::Unspecified - } -} - -impl fmt::Display for GatheringState { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - Self::New => "new", - Self::Gathering => "gathering", - Self::Complete => "complete", - Self::Unspecified => "unspecified", - }; - write!(f, "{s}") - } -} diff --git a/ice/src/state/state_test.rs b/ice/src/state/state_test.rs deleted file mode 100644 index 9e9382039..000000000 --- a/ice/src/state/state_test.rs +++ /dev/null @@ -1,46 +0,0 @@ -use super::*; -use crate::error::Result; - -#[test] -fn test_connected_state_string() -> Result<()> { - let tests = vec![ - (ConnectionState::Unspecified, "Unspecified"), - (ConnectionState::New, "New"), - (ConnectionState::Checking, "Checking"), - (ConnectionState::Connected, "Connected"), - (ConnectionState::Completed, "Completed"), - (ConnectionState::Failed, "Failed"), - (ConnectionState::Disconnected, "Disconnected"), - (ConnectionState::Closed, "Closed"), - ]; - - for (connection_state, expected_string) in tests { - assert_eq!( - connection_state.to_string(), - expected_string, - "testCase: {expected_string} vs {connection_state}", - ) - } - - Ok(()) -} - -#[test] -fn test_gathering_state_string() -> Result<()> { - let tests = vec![ - (GatheringState::Unspecified, "unspecified"), - (GatheringState::New, "new"), - (GatheringState::Gathering, "gathering"), - (GatheringState::Complete, "complete"), - ]; - - for (gathering_state, expected_string) in tests { - assert_eq!( - gathering_state.to_string(), - expected_string, - "testCase: {expected_string} vs {gathering_state}", - ) - } - - Ok(()) -} diff --git a/ice/src/stats/mod.rs b/ice/src/stats/mod.rs deleted file mode 100644 index e3fc40676..000000000 --- a/ice/src/stats/mod.rs +++ /dev/null @@ -1,178 +0,0 @@ -use tokio::time::Instant; - -use crate::candidate::*; -use crate::network_type::*; - -// CandidatePairStats contains ICE candidate pair statistics -#[derive(Debug, Clone)] -pub struct CandidatePairStats { - // timestamp is the timestamp associated with this object. - pub timestamp: Instant, - - // local_candidate_id is the id of the local candidate - pub local_candidate_id: String, - - // remote_candidate_id is the id of the remote candidate - pub remote_candidate_id: String, - - // state represents the state of the checklist for the local and remote - // candidates in a pair. - pub state: CandidatePairState, - - // nominated is true when this valid pair that should be used for media - // if it is the highest-priority one amongst those whose nominated flag is set - pub nominated: bool, - - // packets_sent represents the total number of packets sent on this candidate pair. - pub packets_sent: u32, - - // packets_received represents the total number of packets received on this candidate pair. - pub packets_received: u32, - - // bytes_sent represents the total number of payload bytes sent on this candidate pair - // not including headers or padding. - pub bytes_sent: u64, - - // bytes_received represents the total number of payload bytes received on this candidate pair - // not including headers or padding. - pub bytes_received: u64, - - // last_packet_sent_timestamp represents the timestamp at which the last packet was - // sent on this particular candidate pair, excluding STUN packets. 
- pub last_packet_sent_timestamp: Instant, - - // last_packet_received_timestamp represents the timestamp at which the last packet - // was received on this particular candidate pair, excluding STUN packets. - pub last_packet_received_timestamp: Instant, - - // first_request_timestamp represents the timestamp at which the first STUN request - // was sent on this particular candidate pair. - pub first_request_timestamp: Instant, - - // last_request_timestamp represents the timestamp at which the last STUN request - // was sent on this particular candidate pair. The average interval between two - // consecutive connectivity checks sent can be calculated with - // (last_request_timestamp - first_request_timestamp) / requests_sent. - pub last_request_timestamp: Instant, - - // last_response_timestamp represents the timestamp at which the last STUN response - // was received on this particular candidate pair. - pub last_response_timestamp: Instant, - - // total_round_trip_time represents the sum of all round trip time measurements - // in seconds since the beginning of the session, based on STUN connectivity - // check responses (responses_received), including those that reply to requests - // that are sent in order to verify consent. The average round trip time can - // be computed from total_round_trip_time by dividing it by responses_received. - pub total_round_trip_time: f64, - - // current_round_trip_time represents the latest round trip time measured in seconds, - // computed from both STUN connectivity checks, including those that are sent - // for consent verification. - pub current_round_trip_time: f64, - - // available_outgoing_bitrate is calculated by the underlying congestion control - // by combining the available bitrate for all the outgoing RTP streams using - // this candidate pair. The bitrate measurement does not count the size of the - // ip or other transport layers like TCP or UDP. It is similar to the TIAS defined - // in RFC 3890, i.e., it is measured in bits per second and the bitrate is calculated - // over a 1 second window. - pub available_outgoing_bitrate: f64, - - // available_incoming_bitrate is calculated by the underlying congestion control - // by combining the available bitrate for all the incoming RTP streams using - // this candidate pair. The bitrate measurement does not count the size of the - // ip or other transport layers like TCP or UDP. It is similar to the TIAS defined - // in RFC 3890, i.e., it is measured in bits per second and the bitrate is - // calculated over a 1 second window. - pub available_incoming_bitrate: f64, - - // circuit_breaker_trigger_count represents the number of times the circuit breaker - // is triggered for this particular 5-tuple, ceasing transmission. - pub circuit_breaker_trigger_count: u32, - - // requests_received represents the total number of connectivity check requests - // received (including retransmissions). It is impossible for the receiver to - // tell whether the request was sent in order to check connectivity or check - // consent, so all connectivity checks requests are counted here. - pub requests_received: u64, - - // requests_sent represents the total number of connectivity check requests - // sent (not including retransmissions). - pub requests_sent: u64, - - // responses_received represents the total number of connectivity check responses received. - pub responses_received: u64, - - // responses_sent epresents the total number of connectivity check responses sent. 
- // Since we cannot distinguish connectivity check requests and consent requests, - // all responses are counted. - pub responses_sent: u64, - - // retransmissions_received represents the total number of connectivity check - // request retransmissions received. - pub retransmissions_received: u64, - - // retransmissions_sent represents the total number of connectivity check - // request retransmissions sent. - pub retransmissions_sent: u64, - - // consent_requests_sent represents the total number of consent requests sent. - pub consent_requests_sent: u64, - - // consent_expired_timestamp represents the timestamp at which the latest valid - // STUN binding response expired. - pub consent_expired_timestamp: Instant, -} - -// CandidateStats contains ICE candidate statistics related to the ICETransport objects. -#[derive(Debug, Clone)] -pub struct CandidateStats { - // timestamp is the timestamp associated with this object. - pub timestamp: Instant, - - // id is the candidate id - pub id: String, - - // network_type represents the type of network interface used by the base of a - // local candidate (the address the ICE agent sends from). Only present for - // local candidates; it's not possible to know what type of network interface - // a remote candidate is using. - // - // Note: - // This stat only tells you about the network interface used by the first "hop"; - // it's possible that a connection will be bottlenecked by another type of network. - // For example, when using Wi-Fi tethering, the networkType of the relevant candidate - // would be "wifi", even when the next hop is over a cellular connection. - pub network_type: NetworkType, - - // ip is the ip address of the candidate, allowing for IPv4 addresses and - // IPv6 addresses, but fully qualified domain names (FQDNs) are not allowed. - pub ip: String, - - // port is the port number of the candidate. - pub port: u16, - - // candidate_type is the "Type" field of the ICECandidate. - pub candidate_type: CandidateType, - - // priority is the "priority" field of the ICECandidate. - pub priority: u32, - - // url is the url of the TURN or STUN server indicated in the that translated - // this ip address. It is the url address surfaced in an PeerConnectionICEEvent. - pub url: String, - - // relay_protocol is the protocol used by the endpoint to communicate with the - // TURN server. This is only present for local candidates. Valid values for - // the TURN url protocol is one of udp, tcp, or tls. - pub relay_protocol: String, - - // deleted is true if the candidate has been deleted/freed. For host candidates, - // this means that any network resources (typically a socket) associated with the - // candidate have been released. For TURN candidates, this means the TURN allocation - // is no longer active. - // - // Only defined for local candidates. For remote candidates, this property is not applicable. - pub deleted: bool, -} diff --git a/ice/src/tcp_type/mod.rs b/ice/src/tcp_type/mod.rs deleted file mode 100644 index 11f7bdbcb..000000000 --- a/ice/src/tcp_type/mod.rs +++ /dev/null @@ -1,48 +0,0 @@ -#[cfg(test)] -mod tcp_type_test; - -use std::fmt; - -// TCPType is the type of ICE TCP candidate as described in -// ttps://tools.ietf.org/html/rfc6544#section-4.5 -#[derive(PartialEq, Eq, Debug, Copy, Clone)] -pub enum TcpType { - /// The default value. For example UDP candidates do not need this field. - Unspecified, - /// Active TCP candidate, which initiates TCP connections. - Active, - /// Passive TCP candidate, only accepts TCP connections. 
- Passive, - /// Like `Active` and `Passive` at the same time. - SimultaneousOpen, -} - -// from creates a new TCPType from string. -impl From<&str> for TcpType { - fn from(raw: &str) -> Self { - match raw { - "active" => Self::Active, - "passive" => Self::Passive, - "so" => Self::SimultaneousOpen, - _ => Self::Unspecified, - } - } -} - -impl fmt::Display for TcpType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - Self::Active => "active", - Self::Passive => "passive", - Self::SimultaneousOpen => "so", - Self::Unspecified => "unspecified", - }; - write!(f, "{s}") - } -} - -impl Default for TcpType { - fn default() -> Self { - Self::Unspecified - } -} diff --git a/ice/src/tcp_type/tcp_type_test.rs b/ice/src/tcp_type/tcp_type_test.rs deleted file mode 100644 index 26a189e00..000000000 --- a/ice/src/tcp_type/tcp_type_test.rs +++ /dev/null @@ -1,18 +0,0 @@ -use super::*; -use crate::error::Result; - -#[test] -fn test_tcp_type() -> Result<()> { - //assert_eq!(tcpType, TCPType::Unspecified) - assert_eq!(TcpType::from("active"), TcpType::Active); - assert_eq!(TcpType::from("passive"), TcpType::Passive); - assert_eq!(TcpType::from("so"), TcpType::SimultaneousOpen); - assert_eq!(TcpType::from("something else"), TcpType::Unspecified); - - assert_eq!(TcpType::Unspecified.to_string(), "unspecified"); - assert_eq!(TcpType::Active.to_string(), "active"); - assert_eq!(TcpType::Passive.to_string(), "passive"); - assert_eq!(TcpType::SimultaneousOpen.to_string(), "so"); - - Ok(()) -} diff --git a/ice/src/udp_mux/mod.rs b/ice/src/udp_mux/mod.rs deleted file mode 100644 index 59560ca43..000000000 --- a/ice/src/udp_mux/mod.rs +++ /dev/null @@ -1,338 +0,0 @@ -use std::collections::HashMap; -use std::io::ErrorKind; -use std::net::SocketAddr; -use std::sync::{Arc, Weak}; - -use async_trait::async_trait; -use tokio::sync::{watch, Mutex}; -use util::sync::RwLock; -use util::{Conn, Error}; - -mod udp_mux_conn; -pub use udp_mux_conn::{UDPMuxConn, UDPMuxConnParams, UDPMuxWriter}; - -#[cfg(test)] -mod udp_mux_test; - -mod socket_addr_ext; - -use stun::attributes::ATTR_USERNAME; -use stun::message::{is_message as is_stun_message, Message as STUNMessage}; - -use crate::candidate::RECEIVE_MTU; - -/// Normalize a target socket addr for sending over a given local socket addr. This is useful when -/// a dual stack socket is used, in which case an IPv4 target needs to be mapped to an IPv6 -/// address. -fn normalize_socket_addr(target: &SocketAddr, socket_addr: &SocketAddr) -> SocketAddr { - match (target, socket_addr) { - (SocketAddr::V4(target_ipv4), SocketAddr::V6(_)) => { - let ipv6_mapped = target_ipv4.ip().to_ipv6_mapped(); - - SocketAddr::new(std::net::IpAddr::V6(ipv6_mapped), target_ipv4.port()) - } - // This will fail later if target is IPv6 and socket is IPv4, we ignore it here - (_, _) => *target, - } -} - -#[async_trait] -pub trait UDPMux { - /// Close the muxing. - async fn close(&self) -> Result<(), Error>; - - /// Get the underlying connection for a given ufrag. - async fn get_conn(self: Arc, ufrag: &str) -> Result, Error>; - - /// Remove the underlying connection for a given ufrag. - async fn remove_conn_by_ufrag(&self, ufrag: &str); -} - -pub struct UDPMuxParams { - conn: Box, -} - -impl UDPMuxParams { - pub fn new(conn: C) -> Self - where - C: Conn + Send + Sync + 'static, - { - Self { - conn: Box::new(conn), - } - } -} - -pub struct UDPMuxDefault { - /// The params this instance is configured with. 
- /// Contains the underlying UDP socket in use - params: UDPMuxParams, - - /// Maps from ufrag to the underlying connection. - conns: Mutex>, - - /// Maps from ip address to the underlying connection. - address_map: RwLock>, - - // Close sender - closed_watch_tx: Mutex>>, - - /// Close receiver - closed_watch_rx: watch::Receiver<()>, -} - -impl UDPMuxDefault { - pub fn new(params: UDPMuxParams) -> Arc { - let (closed_watch_tx, closed_watch_rx) = watch::channel(()); - - let mux = Arc::new(Self { - params, - conns: Mutex::default(), - address_map: RwLock::default(), - closed_watch_tx: Mutex::new(Some(closed_watch_tx)), - closed_watch_rx: closed_watch_rx.clone(), - }); - - let cloned_mux = Arc::clone(&mux); - cloned_mux.start_conn_worker(closed_watch_rx); - - mux - } - - pub async fn is_closed(&self) -> bool { - self.closed_watch_tx.lock().await.is_none() - } - - /// Create a muxed connection for a given ufrag. - fn create_muxed_conn(self: &Arc, ufrag: &str) -> Result { - let local_addr = self.params.conn.local_addr()?; - - let params = UDPMuxConnParams { - local_addr, - key: ufrag.into(), - udp_mux: Arc::downgrade(self) as Weak, - }; - - Ok(UDPMuxConn::new(params)) - } - - async fn conn_from_stun_message(&self, buffer: &[u8], addr: &SocketAddr) -> Option { - let (result, message) = { - let mut m = STUNMessage::new(); - - (m.unmarshal_binary(buffer), m) - }; - - match result { - Err(err) => { - log::warn!("Failed to handle decode ICE from {}: {}", addr, err); - None - } - Ok(_) => { - let (attr, found) = message.attributes.get(ATTR_USERNAME); - if !found { - log::warn!("No username attribute in STUN message from {}", &addr); - return None; - } - - let s = match String::from_utf8(attr.value) { - // Per the RFC this shouldn't happen - // https://datatracker.ietf.org/doc/html/rfc5389#section-15.3 - Err(err) => { - log::warn!( - "Failed to decode USERNAME from STUN message as UTF-8: {}", - err - ); - return None; - } - Ok(s) => s, - }; - - let conns = self.conns.lock().await; - let conn = s - .split(':') - .next() - .and_then(|ufrag| conns.get(ufrag)) - .cloned(); - - conn - } - } - } - - fn start_conn_worker(self: Arc, mut closed_watch_rx: watch::Receiver<()>) { - tokio::spawn(async move { - let mut buffer = [0u8; RECEIVE_MTU]; - - loop { - let loop_self = Arc::clone(&self); - let conn = &loop_self.params.conn; - - tokio::select! { - res = conn.recv_from(&mut buffer) => { - match res { - Ok((len, addr)) => { - // Find connection based on previously having seen this source address - let conn = { - let address_map = loop_self - .address_map - .read(); - - address_map.get(&addr).cloned() - }; - - let conn = match conn { - // If we couldn't find the connection based on source address, see if - // this is a STUN message and if so if we can find the connection based on ufrag. 
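// Sketch of the ufrag lookup performed by `conn_from_stun_message` above (the
// helper name and the example values are illustrative): the STUN USERNAME
// attribute carries "<local ufrag>:<remote ufrag>", and the muxed connection is
// keyed on the part before the colon.
fn ufrag_from_stun_username(username: &str) -> Option<&str> {
    username.split(':').next()
}

#[test]
fn connection_lookup_uses_the_local_half_of_the_username() {
    assert_eq!(ufrag_from_stun_username("EsAw:jcXb"), Some("EsAw"));
    // A username without a separator is looked up as-is.
    assert_eq!(ufrag_from_stun_username("EsAw"), Some("EsAw"));
}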
- None if is_stun_message(&buffer) => { - loop_self.conn_from_stun_message(&buffer, &addr).await - } - s @ Some(_) => s, - _ => None, - }; - - match conn { - None => { - log::trace!("Dropping packet from {}", &addr); - } - Some(conn) => { - if let Err(err) = conn.write_packet(&buffer[..len], addr).await { - log::error!("Failed to write packet: {}", err); - } - } - } - } - Err(Error::Io(err)) if err.0.kind() == ErrorKind::TimedOut => continue, - Err(err) => { - log::error!("Could not read udp packet: {}", err); - break; - } - } - } - _ = closed_watch_rx.changed() => { - return; - } - } - } - }); - } -} - -#[async_trait] -impl UDPMux for UDPMuxDefault { - async fn close(&self) -> Result<(), Error> { - if self.is_closed().await { - return Err(Error::ErrAlreadyClosed); - } - - let mut closed_tx = self.closed_watch_tx.lock().await; - - if let Some(tx) = closed_tx.take() { - let _ = tx.send(()); - drop(closed_tx); - - let old_conns = { - let mut conns = self.conns.lock().await; - - std::mem::take(&mut (*conns)) - }; - - // NOTE: We don't wait for these closure to complete - for (_, conn) in old_conns { - conn.close(); - } - - { - let mut address_map = self.address_map.write(); - - // NOTE: This is important, we need to drop all instances of `UDPMuxConn` to - // avoid a retain cycle due to the use of [`std::sync::Arc`] on both sides. - let _ = std::mem::take(&mut (*address_map)); - } - } - - Ok(()) - } - - async fn get_conn(self: Arc, ufrag: &str) -> Result, Error> { - if self.is_closed().await { - return Err(Error::ErrUseClosedNetworkConn); - } - - { - let mut conns = self.conns.lock().await; - if let Some(conn) = conns.get(ufrag) { - // UDPMuxConn uses `Arc` internally so it's cheap to clone, but because - // we implement `Conn` we need to further wrap it in an `Arc` here. - return Ok(Arc::new(conn.clone()) as Arc); - } - - let muxed_conn = self.create_muxed_conn(ufrag)?; - let mut close_rx = muxed_conn.close_rx(); - let cloned_self = Arc::clone(&self); - let cloned_ufrag = ufrag.to_string(); - tokio::spawn(async move { - let _ = close_rx.changed().await; - - // Arc needed - cloned_self.remove_conn_by_ufrag(&cloned_ufrag).await; - }); - - conns.insert(ufrag.into(), muxed_conn.clone()); - - Ok(Arc::new(muxed_conn) as Arc) - } - } - - async fn remove_conn_by_ufrag(&self, ufrag: &str) { - // Pion's ice implementation has both `RemoveConnByFrag` and `RemoveConn`, but since `conns` - // is keyed on `ufrag` their implementation is equivalent. 
- - let removed_conn = { - let mut conns = self.conns.lock().await; - conns.remove(ufrag) - }; - - if let Some(conn) = removed_conn { - let mut address_map = self.address_map.write(); - - for address in conn.get_addresses() { - address_map.remove(&address); - } - } - } -} - -#[async_trait] -impl UDPMuxWriter for UDPMuxDefault { - async fn register_conn_for_address(&self, conn: &UDPMuxConn, addr: SocketAddr) { - if self.is_closed().await { - return; - } - - let key = conn.key(); - { - let mut addresses = self.address_map.write(); - - addresses - .entry(addr) - .and_modify(|e| { - if e.key() != key { - e.remove_address(&addr); - *e = conn.clone(); - } - }) - .or_insert_with(|| conn.clone()); - } - - log::debug!("Registered {} for {}", addr, key); - } - - async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> Result { - self.params - .conn - .send_to(buf, *target) - .await - .map_err(Into::into) - } -} diff --git a/ice/src/udp_mux/socket_addr_ext.rs b/ice/src/udp_mux/socket_addr_ext.rs deleted file mode 100644 index 7290b1b00..000000000 --- a/ice/src/udp_mux/socket_addr_ext.rs +++ /dev/null @@ -1,246 +0,0 @@ -use std::array::TryFromSliceError; -use std::convert::TryInto; -use std::net::SocketAddr; - -use util::Error; - -pub(super) trait SocketAddrExt { - ///Encode a representation of `self` into the buffer and return the length of this encoded - ///version. - /// - /// The buffer needs to be at least 27 bytes in length. - fn encode(&self, buffer: &mut [u8]) -> Result; - - /// Decode a `SocketAddr` from a buffer. The encoding should have previously been done with - /// [`SocketAddrExt::encode`]. - fn decode(buffer: &[u8]) -> Result; -} - -const IPV4_MARKER: u8 = 4; -const IPV4_ADDRESS_SIZE: usize = 7; -const IPV6_MARKER: u8 = 6; -const IPV6_ADDRESS_SIZE: usize = 27; - -pub(super) const MAX_ADDR_SIZE: usize = IPV6_ADDRESS_SIZE; - -impl SocketAddrExt for SocketAddr { - fn encode(&self, buffer: &mut [u8]) -> Result { - use std::net::SocketAddr::{V4, V6}; - - if buffer.len() < MAX_ADDR_SIZE { - return Err(Error::ErrBufferShort); - } - - match self { - V4(addr) => { - let marker = IPV4_MARKER; - let ip: [u8; 4] = addr.ip().octets(); - let port: u16 = addr.port(); - - buffer[0] = marker; - buffer[1..5].copy_from_slice(&ip); - buffer[5..7].copy_from_slice(&port.to_le_bytes()); - - Ok(7) - } - V6(addr) => { - let marker = IPV6_MARKER; - let ip: [u8; 16] = addr.ip().octets(); - let port: u16 = addr.port(); - let flowinfo = addr.flowinfo(); - let scope_id = addr.scope_id(); - - buffer[0] = marker; - buffer[1..17].copy_from_slice(&ip); - buffer[17..19].copy_from_slice(&port.to_le_bytes()); - buffer[19..23].copy_from_slice(&flowinfo.to_le_bytes()); - buffer[23..27].copy_from_slice(&scope_id.to_le_bytes()); - - Ok(MAX_ADDR_SIZE) - } - } - } - - fn decode(buffer: &[u8]) -> Result { - use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6}; - - match buffer[0] { - IPV4_MARKER => { - if buffer.len() < IPV4_ADDRESS_SIZE { - return Err(Error::ErrBufferShort); - } - - let ip_parts = &buffer[1..5]; - let port = match &buffer[5..7].try_into() { - Err(_) => return Err(Error::ErrFailedToParseIpaddr), - Ok(input) => u16::from_le_bytes(*input), - }; - - let ip = Ipv4Addr::new(ip_parts[0], ip_parts[1], ip_parts[2], ip_parts[3]); - - Ok(SocketAddr::V4(SocketAddrV4::new(ip, port))) - } - IPV6_MARKER => { - if buffer.len() < IPV6_ADDRESS_SIZE { - return Err(Error::ErrBufferShort); - } - - // Just to help the type system infer correctly - fn helper(b: &[u8]) -> Result<&[u8; 16], TryFromSliceError> { - 
b.try_into() - } - - let ip = match helper(&buffer[1..17]) { - Err(_) => return Err(Error::ErrFailedToParseIpaddr), - Ok(input) => Ipv6Addr::from(*input), - }; - let port = match &buffer[17..19].try_into() { - Err(_) => return Err(Error::ErrFailedToParseIpaddr), - Ok(input) => u16::from_le_bytes(*input), - }; - - let flowinfo = match &buffer[19..23].try_into() { - Err(_) => return Err(Error::ErrFailedToParseIpaddr), - Ok(input) => u32::from_le_bytes(*input), - }; - - let scope_id = match &buffer[23..27].try_into() { - Err(_) => return Err(Error::ErrFailedToParseIpaddr), - Ok(input) => u32::from_le_bytes(*input), - }; - - Ok(SocketAddr::V6(SocketAddrV6::new( - ip, port, flowinfo, scope_id, - ))) - } - _ => Err(Error::ErrFailedToParseIpaddr), - } - } -} - -#[cfg(test)] -mod test { - use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6}; - - use super::*; - - #[test] - fn test_ipv4() { - let ip = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from([56, 128, 35, 5]), 0x1234)); - - let mut buffer = [0_u8; MAX_ADDR_SIZE]; - let encoded_len = ip.encode(&mut buffer); - - assert_eq!(encoded_len, Ok(7)); - assert_eq!( - &buffer[0..7], - &[IPV4_MARKER, 56, 128, 35, 5, 0x34, 0x12][..] - ); - - let decoded = SocketAddr::decode(&buffer); - - assert_eq!(decoded, Ok(ip)); - } - - #[test] - fn test_ipv6() { - let ip = SocketAddr::V6(SocketAddrV6::new( - Ipv6Addr::from([ - 92, 114, 235, 3, 244, 64, 38, 111, 20, 100, 199, 241, 19, 174, 220, 123, - ]), - 0x1234, - 0x12345678, - 0x87654321, - )); - - let mut buffer = [0_u8; MAX_ADDR_SIZE]; - let encoded_len = ip.encode(&mut buffer); - - assert_eq!(encoded_len, Ok(27)); - assert_eq!( - &buffer[0..27], - &[ - IPV6_MARKER, // marker - // Start of ipv6 address - 92, - 114, - 235, - 3, - 244, - 64, - 38, - 111, - 20, - 100, - 199, - 241, - 19, - 174, - 220, - 123, - // LE port - 0x34, - 0x12, - // LE flowinfo - 0x78, - 0x56, - 0x34, - 0x12, - // LE scope_id - 0x21, - 0x43, - 0x65, - 0x87, - ][..] 
- ); - - let decoded = SocketAddr::decode(&buffer); - - assert_eq!(decoded, Ok(ip)); - } - - #[test] - fn test_encode_ipv4_with_short_buffer() { - let mut buffer = vec![0u8; IPV4_ADDRESS_SIZE - 1]; - let ip = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from([56, 128, 35, 5]), 0x1234)); - - let result = ip.encode(&mut buffer); - - assert_eq!(result, Err(Error::ErrBufferShort)); - } - - #[test] - fn test_encode_ipv6_with_short_buffer() { - let mut buffer = vec![0u8; MAX_ADDR_SIZE - 1]; - let ip = SocketAddr::V6(SocketAddrV6::new( - Ipv6Addr::from([ - 92, 114, 235, 3, 244, 64, 38, 111, 20, 100, 199, 241, 19, 174, 220, 123, - ]), - 0x1234, - 0x12345678, - 0x87654321, - )); - - let result = ip.encode(&mut buffer); - - assert_eq!(result, Err(Error::ErrBufferShort)); - } - - #[test] - fn test_decode_ipv4_with_short_buffer() { - let buffer = vec![IPV4_MARKER, 0]; - - let result = SocketAddr::decode(&buffer); - - assert_eq!(result, Err(Error::ErrBufferShort)); - } - - #[test] - fn test_decode_ipv6_with_short_buffer() { - let buffer = vec![IPV6_MARKER, 0]; - - let result = SocketAddr::decode(&buffer); - - assert_eq!(result, Err(Error::ErrBufferShort)); - } -} diff --git a/ice/src/udp_mux/udp_mux_conn.rs b/ice/src/udp_mux/udp_mux_conn.rs deleted file mode 100644 index b97183abf..000000000 --- a/ice/src/udp_mux/udp_mux_conn.rs +++ /dev/null @@ -1,320 +0,0 @@ -use std::collections::HashSet; -use std::convert::TryInto; -use std::io; -use std::net::SocketAddr; -use std::sync::{Arc, Weak}; - -use async_trait::async_trait; -use tokio::sync::watch; -use util::sync::Mutex; -use util::{Buffer, Conn, Error}; - -use super::socket_addr_ext::{SocketAddrExt, MAX_ADDR_SIZE}; -use super::{normalize_socket_addr, RECEIVE_MTU}; - -/// A trait for a [`UDPMuxConn`] to communicate with an UDP mux. -#[async_trait] -pub trait UDPMuxWriter { - /// Registers an address for the given connection. - async fn register_conn_for_address(&self, conn: &UDPMuxConn, addr: SocketAddr); - /// Sends the content of the buffer to the given target. - /// - /// Returns the number of bytes sent or an error, if any. - async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> Result; -} - -/// Parameters for a [`UDPMuxConn`]. -pub struct UDPMuxConnParams { - /// Local socket address. - pub local_addr: SocketAddr, - /// Static key identifying the connection. - pub key: String, - /// A `std::sync::Weak` reference to the UDP mux. - /// - /// NOTE: a non-owning reference should be used to prevent possible cycles. - pub udp_mux: Weak, -} - -type ConnResult = Result; - -/// A UDP mux connection. -#[derive(Clone)] -pub struct UDPMuxConn { - /// Close Receiver. A copy of this can be obtained via [`close_tx`]. - closed_watch_rx: watch::Receiver, - - inner: Arc, -} - -impl UDPMuxConn { - /// Creates a new [`UDPMuxConn`]. - pub fn new(params: UDPMuxConnParams) -> Self { - let (closed_watch_tx, closed_watch_rx) = watch::channel(false); - - Self { - closed_watch_rx, - inner: Arc::new(UDPMuxConnInner { - params, - closed_watch_tx: Mutex::new(Some(closed_watch_tx)), - addresses: Default::default(), - buffer: Buffer::new(0, 0), - }), - } - } - - /// Returns a key identifying this connection. - pub fn key(&self) -> &str { - &self.inner.params.key - } - - /// Writes data to the given address. Returns an error if the buffer is too short or there's an - /// encoding error. - pub async fn write_packet(&self, data: &[u8], addr: SocketAddr) -> ConnResult<()> { - // NOTE: Pion/ice uses Sync.Pool to optimise this. 
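// Worked example of the frame this method builds (illustrative values): sending
// data = [0x01, 0x02, 0x03] to 127.0.0.1:0x1234 produces, per the layout described
// below,
//
//   03 00 | 01 02 03 | 07 00 | 04 7f 00 00 01 34 12
//
// i.e. a little-endian data length, the payload, a little-endian address length,
// and the 7-byte `SocketAddrExt::encode` form of the IPv4 address.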
- let mut buffer = make_buffer(); - let mut offset = 0; - - if (data.len() + MAX_ADDR_SIZE) > (RECEIVE_MTU + MAX_ADDR_SIZE) { - return Err(Error::ErrBufferShort); - } - - // Format of buffer: | data len(2) | data bytes(dn) | addr len(2) | addr bytes(an) | - // Where the number in parenthesis indicate the number of bytes used - // `dn` and `an` are the length in bytes of data and addr respectively. - - // SAFETY: `data.len()` is at most RECEIVE_MTU(8192) - MAX_ADDR_SIZE(27) - buffer[0..2].copy_from_slice(&(data.len() as u16).to_le_bytes()[..]); - offset += 2; - - buffer[offset..offset + data.len()].copy_from_slice(data); - offset += data.len(); - - let len = addr.encode(&mut buffer[offset + 2..])?; - buffer[offset..offset + 2].copy_from_slice(&(len as u16).to_le_bytes()[..]); - offset += 2 + len; - - self.inner.buffer.write(&buffer[..offset]).await?; - - Ok(()) - } - - /// Returns true if this connection is closed. - pub fn is_closed(&self) -> bool { - self.inner.is_closed() - } - - /// Gets a copy of the close [`tokio::sync::watch::Receiver`] that fires when this - /// connection is closed. - pub fn close_rx(&self) -> watch::Receiver { - self.closed_watch_rx.clone() - } - - /// Closes this connection. - pub fn close(&self) { - self.inner.close(); - } - - /// Gets the list of the addresses associated with this connection. - pub fn get_addresses(&self) -> Vec { - self.inner.get_addresses() - } - - /// Registers a new address for this connection. - pub async fn add_address(&self, addr: SocketAddr) { - self.inner.add_address(addr); - if let Some(mux) = self.inner.params.udp_mux.upgrade() { - mux.register_conn_for_address(self, addr).await; - } - } - - /// Deregisters an address. - pub fn remove_address(&self, addr: &SocketAddr) { - self.inner.remove_address(addr) - } - - /// Returns true if the given address is associated with this connection. - pub fn contains_address(&self, addr: &SocketAddr) -> bool { - self.inner.contains_address(addr) - } -} - -struct UDPMuxConnInner { - params: UDPMuxConnParams, - - /// Close Sender. We'll send a value on this channel when we close - closed_watch_tx: Mutex>>, - - /// Remote addresses we've seen on this connection. - addresses: Mutex>, - - buffer: Buffer, -} - -impl UDPMuxConnInner { - // Sending/Recieving - async fn recv_from(&self, buf: &mut [u8]) -> ConnResult<(usize, SocketAddr)> { - // NOTE: Pion/ice uses Sync.Pool to optimise this. - let mut buffer = make_buffer(); - let mut offset = 0; - - let len = self.buffer.read(&mut buffer, None).await?; - // We always have at least. 
- // - // * 2 bytes for data len - // * 2 bytes for addr len - // * 7 bytes for an Ipv4 addr - if len < 11 { - return Err(Error::ErrBufferShort); - } - - let data_len: usize = buffer[..2] - .try_into() - .map(u16::from_le_bytes) - .map(From::from) - .unwrap(); - offset += 2; - - let total = 2 + data_len + 2 + 7; - if data_len > buf.len() || total > len { - return Err(Error::ErrBufferShort); - } - - buf[..data_len].copy_from_slice(&buffer[offset..offset + data_len]); - offset += data_len; - - let address_len: usize = buffer[offset..offset + 2] - .try_into() - .map(u16::from_le_bytes) - .map(From::from) - .unwrap(); - offset += 2; - - let addr = SocketAddr::decode(&buffer[offset..offset + address_len])?; - - Ok((data_len, addr)) - } - - async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> ConnResult { - if let Some(mux) = self.params.udp_mux.upgrade() { - mux.send_to(buf, target).await - } else { - Err(Error::Other(format!( - "wanted to send {} bytes to {}, but UDP mux is gone", - buf.len(), - target - ))) - } - } - - fn is_closed(&self) -> bool { - self.closed_watch_tx.lock().is_none() - } - - fn close(self: &Arc) { - let mut closed_tx = self.closed_watch_tx.lock(); - - if let Some(tx) = closed_tx.take() { - let _ = tx.send(true); - drop(closed_tx); - - let cloned_self = Arc::clone(self); - - { - let mut addresses = self.addresses.lock(); - *addresses = Default::default(); - } - - // NOTE: Alternatively we could wait on the buffer closing here so that - // our caller can wait for things to fully settle down - tokio::spawn(async move { - cloned_self.buffer.close().await; - }); - } - } - - fn local_addr(&self) -> SocketAddr { - self.params.local_addr - } - - // Address related methods - pub(super) fn get_addresses(&self) -> Vec { - let addresses = self.addresses.lock(); - - addresses.iter().copied().collect() - } - - pub(super) fn add_address(self: &Arc, addr: SocketAddr) { - { - let mut addresses = self.addresses.lock(); - addresses.insert(addr); - } - } - - pub(super) fn remove_address(&self, addr: &SocketAddr) { - { - let mut addresses = self.addresses.lock(); - addresses.remove(addr); - } - } - - pub(super) fn contains_address(&self, addr: &SocketAddr) -> bool { - let addresses = self.addresses.lock(); - - addresses.contains(addr) - } -} - -#[async_trait] -impl Conn for UDPMuxConn { - async fn connect(&self, _addr: SocketAddr) -> ConnResult<()> { - Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) - } - - async fn recv(&self, _buf: &mut [u8]) -> ConnResult { - Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) - } - - async fn recv_from(&self, buf: &mut [u8]) -> ConnResult<(usize, SocketAddr)> { - self.inner.recv_from(buf).await - } - - async fn send(&self, _buf: &[u8]) -> ConnResult { - Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) - } - - async fn send_to(&self, buf: &[u8], target: SocketAddr) -> ConnResult { - let normalized_target = normalize_socket_addr(&target, &self.inner.params.local_addr); - - if !self.contains_address(&normalized_target) { - self.add_address(normalized_target).await; - } - - self.inner.send_to(buf, &normalized_target).await - } - - fn local_addr(&self) -> ConnResult { - Ok(self.inner.local_addr()) - } - - fn remote_addr(&self) -> Option { - None - } - async fn close(&self) -> ConnResult<()> { - self.inner.close(); - - Ok(()) - } - - fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) { - self - } -} - -#[inline(always)] -/// Create a buffer of appropriate size to fit both a packet with max 
RECEIVE_MTU and the -/// additional metadata used for muxing. -fn make_buffer() -> Vec { - // The 4 extra bytes are used to encode the length of the data and address respectively. - // See [`write_packet`] for details. - vec![0u8; RECEIVE_MTU + MAX_ADDR_SIZE + 2 + 2] -} diff --git a/ice/src/udp_mux/udp_mux_test.rs b/ice/src/udp_mux/udp_mux_test.rs deleted file mode 100644 index 10e7701b7..000000000 --- a/ice/src/udp_mux/udp_mux_test.rs +++ /dev/null @@ -1,292 +0,0 @@ -use std::convert::TryInto; -use std::io; -use std::time::Duration; - -use rand::{thread_rng, Rng}; -use sha1::{Digest, Sha1}; -use stun::message::{Message, BINDING_REQUEST}; -use tokio::net::UdpSocket; -use tokio::time::{sleep, timeout}; - -use super::*; -use crate::error::Result; - -#[derive(Debug, Copy, Clone)] -enum Network { - Ipv4, - Ipv6, -} - -impl Network { - /// Bind the UDP socket for the "remote". - async fn bind(self) -> io::Result { - match self { - Network::Ipv4 => UdpSocket::bind("0.0.0.0:0").await, - Network::Ipv6 => UdpSocket::bind("[::]:0").await, - } - } - - /// Connect ip from the "remote". - fn connect_ip(self, port: u16) -> String { - match self { - Network::Ipv4 => format!("127.0.0.1:{port}"), - Network::Ipv6 => format!("[::1]:{port}"), - } - } -} - -const TIMEOUT: Duration = Duration::from_secs(60); - -#[tokio::test] -async fn test_udp_mux() -> Result<()> { - use std::io::Write; - env_logger::Builder::from_default_env() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .init(); - - // TODO: Support IPv6 dual stack. This works Linux and macOS, but not Windows. - #[cfg(all(unix, target_pointer_width = "64"))] - let udp_socket = UdpSocket::bind((std::net::Ipv6Addr::UNSPECIFIED, 0)).await?; - - #[cfg(any(not(unix), not(target_pointer_width = "64")))] - let udp_socket = UdpSocket::bind((std::net::Ipv4Addr::UNSPECIFIED, 0)).await?; - - let addr = udp_socket.local_addr()?; - log::info!("Listening on {}", addr); - - let udp_mux = UDPMuxDefault::new(UDPMuxParams::new(udp_socket)); - let udp_mux_dyn = Arc::clone(&udp_mux) as Arc; - - let udp_mux_dyn_1 = Arc::clone(&udp_mux_dyn); - let h1 = tokio::spawn(async move { - timeout( - TIMEOUT, - test_mux_connection(Arc::clone(&udp_mux_dyn_1), "ufrag1", addr, Network::Ipv4), - ) - .await - }); - - let udp_mux_dyn_2 = Arc::clone(&udp_mux_dyn); - let h2 = tokio::spawn(async move { - timeout( - TIMEOUT, - test_mux_connection(Arc::clone(&udp_mux_dyn_2), "ufrag2", addr, Network::Ipv4), - ) - .await - }); - - let all_results; - - #[cfg(all(unix, target_pointer_width = "64"))] - { - // TODO: Support IPv6 dual stack. This works Linux and macOS, but not Windows. 
- let udp_mux_dyn_3 = Arc::clone(&udp_mux_dyn); - let h3 = tokio::spawn(async move { - timeout( - TIMEOUT, - test_mux_connection(Arc::clone(&udp_mux_dyn_3), "ufrag3", addr, Network::Ipv6), - ) - .await - }); - - let (r1, r2, r3) = tokio::join!(h1, h2, h3); - all_results = [r1, r2, r3]; - } - - #[cfg(any(not(unix), not(target_pointer_width = "64")))] - { - let (r1, r2) = tokio::join!(h1, h2); - all_results = [r1, r2]; - } - - for timeout_result in &all_results { - // Timeout error - match timeout_result { - Err(timeout_err) => { - panic!("Mux test timedout: {timeout_err:?}"); - } - - // Join error - Ok(join_result) => match join_result { - Err(err) => { - panic!("Mux test failed with join error: {err:?}"); - } - // Actual error - Ok(mux_result) => { - if let Err(err) = mux_result { - panic!("Mux test failed with error: {err:?}"); - } - } - }, - } - } - - let timeout = all_results.iter().find_map(|r| r.as_ref().err()); - assert!( - timeout.is_none(), - "At least one of the muxed tasks timedout {all_results:?}" - ); - - let res = udp_mux.close().await; - assert!(res.is_ok()); - let res = udp_mux.get_conn("failurefrag").await; - - assert!( - res.is_err(), - "Getting connections after UDPMuxDefault is closed should fail" - ); - - Ok(()) -} - -async fn test_mux_connection( - mux: Arc, - ufrag: &str, - listener_addr: SocketAddr, - network: Network, -) -> Result<()> { - let conn = mux.get_conn(ufrag).await?; - // FIXME: Cleanup - - let connect_addr = network - .connect_ip(listener_addr.port()) - .parse::() - .unwrap(); - - let remote_connection = Arc::new(network.bind().await?); - log::info!("Bound for ufrag: {}", ufrag); - remote_connection.connect(connect_addr).await?; - log::info!("Connected to {} for ufrag: {}", connect_addr, ufrag); - log::info!( - "Testing muxing from {} over {}", - remote_connection.local_addr().unwrap(), - listener_addr - ); - - // These bytes should be dropped - remote_connection.send("Dropped bytes".as_bytes()).await?; - - sleep(Duration::from_millis(1)).await; - - let stun_msg = { - let mut m = Message { - typ: BINDING_REQUEST, - ..Message::default() - }; - - m.add(ATTR_USERNAME, format!("{ufrag}:otherufrag").as_bytes()); - - m.marshal_binary().unwrap() - }; - - let remote_connection_addr = remote_connection.local_addr()?; - - conn.send_to(&stun_msg, remote_connection_addr).await?; - - let mut buffer = vec![0u8; RECEIVE_MTU]; - let len = remote_connection.recv(&mut buffer).await?; - assert_eq!(buffer[..len], stun_msg); - - const TARGET_SIZE: usize = 1024 * 1024; - - // Read on the muxed side - let conn_2 = Arc::clone(&conn); - let mux_handle = tokio::spawn(async move { - let conn = conn_2; - - let mut buffer = vec![0u8; RECEIVE_MTU]; - let mut next_sequence = 0; - let mut read = 0; - - while read < TARGET_SIZE { - let (n, _) = conn - .recv_from(&mut buffer) - .await - .expect("recv_from should not error"); - assert_eq!(n, RECEIVE_MTU); - - verify_packet(&buffer[..n], next_sequence); - - conn.send_to(&buffer[..n], remote_connection_addr) - .await - .expect("Failed to write to muxxed connection"); - - read += n; - log::debug!("Muxxed read {}, sequence: {}", read, next_sequence); - next_sequence += 1; - } - }); - - let remote_connection_2 = Arc::clone(&remote_connection); - let remote_handle = tokio::spawn(async move { - let remote_connection = remote_connection_2; - let mut buffer = vec![0u8; RECEIVE_MTU]; - let mut next_sequence = 0; - let mut read = 0; - - while read < TARGET_SIZE { - let n = remote_connection - .recv(&mut buffer) - .await - .expect("recv_from should 
not error"); - assert_eq!(n, RECEIVE_MTU); - - verify_packet(&buffer[..n], next_sequence); - read += n; - log::debug!("Remote read {}, sequence: {}", read, next_sequence); - next_sequence += 1; - } - }); - - let mut sequence: u32 = 0; - let mut written = 0; - let mut buffer = vec![0u8; RECEIVE_MTU]; - while written < TARGET_SIZE { - thread_rng().fill(&mut buffer[24..]); - - let hash = sha1_hash(&buffer[24..]); - buffer[4..24].copy_from_slice(&hash); - buffer[0..4].copy_from_slice(&sequence.to_le_bytes()); - - let len = remote_connection.send(&buffer).await?; - - written += len; - log::debug!("Data written {}, sequence: {}", written, sequence); - sequence += 1; - - sleep(Duration::from_millis(1)).await; - } - - let (r1, r2) = tokio::join!(mux_handle, remote_handle); - assert!(r1.is_ok() && r2.is_ok()); - - let res = conn.close().await; - assert!(res.is_ok(), "Failed to close Conn: {res:?}"); - - Ok(()) -} - -fn verify_packet(buffer: &[u8], next_sequence: u32) { - let read_sequence = u32::from_le_bytes(buffer[0..4].try_into().unwrap()); - assert_eq!(read_sequence, next_sequence); - - let hash = sha1_hash(&buffer[24..]); - assert_eq!(hash, buffer[4..24]); -} - -fn sha1_hash(buffer: &[u8]) -> Vec { - let mut hasher = Sha1::new(); - hasher.update(&buffer[24..]); - - hasher.finalize().to_vec() -} diff --git a/ice/src/udp_network.rs b/ice/src/udp_network.rs deleted file mode 100644 index e2077bdf7..000000000 --- a/ice/src/udp_network.rs +++ /dev/null @@ -1,116 +0,0 @@ -use std::sync::Arc; - -use super::udp_mux::UDPMux; -use super::Error; - -#[derive(Default, Clone)] -pub struct EphemeralUDP { - port_min: u16, - port_max: u16, -} - -impl EphemeralUDP { - pub fn new(port_min: u16, port_max: u16) -> Result { - let mut s = Self::default(); - s.set_ports(port_min, port_max)?; - - Ok(s) - } - - pub fn port_min(&self) -> u16 { - self.port_min - } - - pub fn port_max(&self) -> u16 { - self.port_max - } - - pub fn set_ports(&mut self, port_min: u16, port_max: u16) -> Result<(), Error> { - if port_max < port_min { - return Err(Error::ErrPort); - } - - self.port_min = port_min; - self.port_max = port_max; - - Ok(()) - } -} - -/// Configuration for the underlying UDP network stack. -/// There are two ways to configure this Ephemeral and Muxed. -/// -/// **Ephemeral mode** -/// -/// In Ephemeral mode sockets are created and bound to random ports during ICE -/// gathering. The ports to use can be restricted by setting [`EphemeralUDP::port_min`] and -/// [`EphemeralUDP::port_max`] in which case only ports in this range will be used. -/// -/// **Muxed** -/// -/// In muxed mode a single UDP socket is used and all connections are muxed over this single socket. 
-/// -#[derive(Clone)] -pub enum UDPNetwork { - Ephemeral(EphemeralUDP), - Muxed(Arc), -} - -impl Default for UDPNetwork { - fn default() -> Self { - Self::Ephemeral(Default::default()) - } -} - -impl UDPNetwork { - fn is_ephemeral(&self) -> bool { - matches!(self, Self::Ephemeral(_)) - } - - fn is_muxed(&self) -> bool { - matches!(self, Self::Muxed(_)) - } -} - -#[cfg(test)] -mod test { - use super::EphemeralUDP; - - #[test] - fn test_ephemeral_udp_constructor() { - assert!( - EphemeralUDP::new(3000, 2999).is_err(), - "EphemeralUDP should not allow invalid port range" - ); - - let e = EphemeralUDP::default(); - assert_eq!(e.port_min(), 0, "EphemeralUDP should default port_min to 0"); - assert_eq!(e.port_max(), 0, "EphemeralUDP should default port_max to 0"); - } - - #[test] - fn test_ephemeral_udp_set_ports() { - let mut e = EphemeralUDP::default(); - - assert!( - e.set_ports(3000, 2999).is_err(), - "EphemeralUDP should not allow invalid port range" - ); - - assert!( - e.set_ports(6000, 6001).is_ok(), - "EphemeralUDP::set_ports should allow valid port range" - ); - - assert_eq!( - e.port_min(), - 6000, - "Ports set with `EphemeralUDP::set_ports` should be reflected" - ); - assert_eq!( - e.port_max(), - 6001, - "Ports set with `EphemeralUDP::set_ports` should be reflected" - ); - } -} diff --git a/ice/src/url/mod.rs b/ice/src/url/mod.rs deleted file mode 100644 index d55fd48f1..000000000 --- a/ice/src/url/mod.rs +++ /dev/null @@ -1,266 +0,0 @@ -#[cfg(test)] -mod url_test; - -use std::borrow::Cow; -use std::convert::From; -use std::fmt; - -use crate::error::*; - -/// The type of server used in the ice.URL structure. -#[derive(PartialEq, Eq, Debug, Copy, Clone)] -pub enum SchemeType { - /// The URL represents a STUN server. - Stun, - - /// The URL represents a STUNS (secure) server. - Stuns, - - /// The URL represents a TURN server. - Turn, - - /// The URL represents a TURNS (secure) server. - Turns, - - /// Default public constant to use for "enum" like struct comparisons when no value was defined. - Unknown, -} - -impl Default for SchemeType { - fn default() -> Self { - Self::Unknown - } -} - -impl From<&str> for SchemeType { - /// Defines a procedure for creating a new `SchemeType` from a raw - /// string naming the scheme type. - fn from(raw: &str) -> Self { - match raw { - "stun" => Self::Stun, - "stuns" => Self::Stuns, - "turn" => Self::Turn, - "turns" => Self::Turns, - _ => Self::Unknown, - } - } -} - -impl fmt::Display for SchemeType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - SchemeType::Stun => "stun", - SchemeType::Stuns => "stuns", - SchemeType::Turn => "turn", - SchemeType::Turns => "turns", - SchemeType::Unknown => "unknown", - }; - write!(f, "{s}") - } -} - -/// The transport protocol type that is used in the `ice::url::Url` structure. -#[derive(PartialEq, Eq, Debug, Copy, Clone)] -pub enum ProtoType { - /// The URL uses a UDP transport. - Udp, - - /// The URL uses a TCP transport. - Tcp, - - Unknown, -} - -impl Default for ProtoType { - fn default() -> Self { - Self::Udp - } -} - -// defines a procedure for creating a new ProtoType from a raw -// string naming the transport protocol type. -impl From<&str> for ProtoType { - // NewSchemeType defines a procedure for creating a new SchemeType from a raw - // string naming the scheme type. 
- fn from(raw: &str) -> Self { - match raw { - "udp" => Self::Udp, - "tcp" => Self::Tcp, - _ => Self::Unknown, - } - } -} - -impl fmt::Display for ProtoType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - Self::Udp => "udp", - Self::Tcp => "tcp", - Self::Unknown => "unknown", - }; - write!(f, "{s}") - } -} - -/// Represents a STUN (rfc7064) or TURN (rfc7065) URL. -#[derive(Debug, Clone, Default)] -pub struct Url { - pub scheme: SchemeType, - pub host: String, - pub port: u16, - pub username: String, - pub password: String, - pub proto: ProtoType, -} - -impl fmt::Display for Url { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let host = if self.host.contains("::") { - "[".to_owned() + self.host.as_str() + "]" - } else { - self.host.clone() - }; - if self.scheme == SchemeType::Turn || self.scheme == SchemeType::Turns { - write!( - f, - "{}:{}:{}?transport={}", - self.scheme, host, self.port, self.proto - ) - } else { - write!(f, "{}:{}:{}", self.scheme, host, self.port) - } - } -} - -impl Url { - /// Parses a STUN or TURN urls following the ABNF syntax described in - /// [IETF rfc-7064](https://tools.ietf.org/html/rfc7064) and - /// [IETF rfc-7065](https://tools.ietf.org/html/rfc7065) respectively. - pub fn parse_url(raw: &str) -> Result { - // work around for url crate - if raw.contains("//") { - return Err(Error::ErrInvalidUrl); - } - - let mut s = raw.to_string(); - let pos = raw.find(':'); - if let Some(p) = pos { - s.replace_range(p..=p, "://"); - } else { - return Err(Error::ErrSchemeType); - } - - let raw_parts = url::Url::parse(&s)?; - - let scheme = raw_parts.scheme().into(); - - let host = if let Some(host) = raw_parts.host_str() { - host.trim() - .trim_start_matches('[') - .trim_end_matches(']') - .to_owned() - } else { - return Err(Error::ErrHost); - }; - - let port = if let Some(port) = raw_parts.port() { - port - } else if scheme == SchemeType::Stun || scheme == SchemeType::Turn { - 3478 - } else { - 5349 - }; - - let mut q_args = raw_parts.query_pairs(); - let proto = match scheme { - SchemeType::Stun => { - if q_args.count() > 0 { - return Err(Error::ErrStunQuery); - } - ProtoType::Udp - } - SchemeType::Stuns => { - if q_args.count() > 0 { - return Err(Error::ErrStunQuery); - } - ProtoType::Tcp - } - SchemeType::Turn => { - if q_args.count() > 1 { - return Err(Error::ErrInvalidQuery); - } - if let Some((key, value)) = q_args.next() { - if key == Cow::Borrowed("transport") { - let proto: ProtoType = value.as_ref().into(); - if proto == ProtoType::Unknown { - return Err(Error::ErrProtoType); - } - proto - } else { - return Err(Error::ErrInvalidQuery); - } - } else { - ProtoType::Udp - } - } - SchemeType::Turns => { - if q_args.count() > 1 { - return Err(Error::ErrInvalidQuery); - } - if let Some((key, value)) = q_args.next() { - if key == Cow::Borrowed("transport") { - let proto: ProtoType = value.as_ref().into(); - if proto == ProtoType::Unknown { - return Err(Error::ErrProtoType); - } - proto - } else { - return Err(Error::ErrInvalidQuery); - } - } else { - ProtoType::Tcp - } - } - SchemeType::Unknown => { - return Err(Error::ErrSchemeType); - } - }; - - Ok(Self { - scheme, - host, - port, - username: "".to_owned(), - password: "".to_owned(), - proto, - }) - } - - /* - fn parse_proto(raw:&str) ->Result { - let qArgs= raw.split('='); - if qArgs.len() != 2 { - return Err(Error::ErrInvalidQuery.into()); - } - - var proto ProtoType - if rawProto := qArgs.Get("transport"); rawProto != "" { - if proto = 
NewProtoType(rawProto); proto == ProtoType(0) { - return ProtoType(Unknown), ErrProtoType - } - return proto, nil - } - - if len(qArgs) > 0 { - return ProtoType(Unknown), ErrInvalidQuery - } - - return proto, nil - }*/ - - /// Returns whether the this URL's scheme describes secure scheme or not. - #[must_use] - pub fn is_secure(&self) -> bool { - self.scheme == SchemeType::Stuns || self.scheme == SchemeType::Turns - } -} diff --git a/ice/src/url/url_test.rs b/ice/src/url/url_test.rs deleted file mode 100644 index acbf72789..000000000 --- a/ice/src/url/url_test.rs +++ /dev/null @@ -1,142 +0,0 @@ -use super::*; - -#[test] -fn test_parse_url_success() -> Result<()> { - let tests = vec![ - ( - "stun:google.de", - "stun:google.de:3478", - SchemeType::Stun, - false, - "google.de", - 3478, - ProtoType::Udp, - ), - ( - "stun:google.de:1234", - "stun:google.de:1234", - SchemeType::Stun, - false, - "google.de", - 1234, - ProtoType::Udp, - ), - ( - "stuns:google.de", - "stuns:google.de:5349", - SchemeType::Stuns, - true, - "google.de", - 5349, - ProtoType::Tcp, - ), - ( - "stun:[::1]:123", - "stun:[::1]:123", - SchemeType::Stun, - false, - "::1", - 123, - ProtoType::Udp, - ), - ( - "turn:google.de", - "turn:google.de:3478?transport=udp", - SchemeType::Turn, - false, - "google.de", - 3478, - ProtoType::Udp, - ), - ( - "turns:google.de", - "turns:google.de:5349?transport=tcp", - SchemeType::Turns, - true, - "google.de", - 5349, - ProtoType::Tcp, - ), - ( - "turn:google.de?transport=udp", - "turn:google.de:3478?transport=udp", - SchemeType::Turn, - false, - "google.de", - 3478, - ProtoType::Udp, - ), - ( - "turns:google.de?transport=tcp", - "turns:google.de:5349?transport=tcp", - SchemeType::Turns, - true, - "google.de", - 5349, - ProtoType::Tcp, - ), - ]; - - for ( - raw_url, - expected_url_string, - expected_scheme, - expected_secure, - expected_host, - expected_port, - expected_proto, - ) in tests - { - let url = Url::parse_url(raw_url)?; - - assert_eq!(url.scheme, expected_scheme, "testCase: {raw_url:?}"); - assert_eq!( - expected_url_string, - url.to_string(), - "testCase: {raw_url:?}" - ); - assert_eq!(url.is_secure(), expected_secure, "testCase: {raw_url:?}"); - assert_eq!(url.host, expected_host, "testCase: {raw_url:?}"); - assert_eq!(url.port, expected_port, "testCase: {raw_url:?}"); - assert_eq!(url.proto, expected_proto, "testCase: {raw_url:?}"); - } - - Ok(()) -} - -#[test] -fn test_parse_url_failure() -> Result<()> { - let tests = vec![ - ("", Error::ErrSchemeType), - (":::", Error::ErrUrlParse), - ("stun:[::1]:123:", Error::ErrPort), - ("stun:[::1]:123a", Error::ErrPort), - ("google.de", Error::ErrSchemeType), - ("stun:", Error::ErrHost), - ("stun:google.de:abc", Error::ErrPort), - ("stun:google.de?transport=udp", Error::ErrStunQuery), - ("stuns:google.de?transport=udp", Error::ErrStunQuery), - ("turn:google.de?trans=udp", Error::ErrInvalidQuery), - ("turns:google.de?trans=udp", Error::ErrInvalidQuery), - ( - "turns:google.de?transport=udp&another=1", - Error::ErrInvalidQuery, - ), - ("turn:google.de?transport=ip", Error::ErrProtoType), - ]; - - for (raw_url, expected_err) in tests { - let result = Url::parse_url(raw_url); - if let Err(err) = result { - assert_eq!( - err.to_string(), - expected_err.to_string(), - "testCase: '{raw_url}', expected err '{expected_err}', but got err '{err}'" - ); - } else { - panic!("expected error, but got ok"); - } - } - - Ok(()) -} diff --git a/ice/src/use_candidate/mod.rs b/ice/src/use_candidate/mod.rs deleted file mode 100644 index 8bb0d47ca..000000000 
--- a/ice/src/use_candidate/mod.rs +++ /dev/null @@ -1,31 +0,0 @@ -#[cfg(test)] -mod use_candidate_test; - -use stun::attributes::ATTR_USE_CANDIDATE; -use stun::message::*; - -/// Represents USE-CANDIDATE attribute. -#[derive(Default)] -pub struct UseCandidateAttr; - -impl Setter for UseCandidateAttr { - /// Adds USE-CANDIDATE attribute to message. - fn add_to(&self, m: &mut Message) -> Result<(), stun::Error> { - m.add(ATTR_USE_CANDIDATE, &[]); - Ok(()) - } -} - -impl UseCandidateAttr { - #[must_use] - pub const fn new() -> Self { - Self - } - - /// Returns true if USE-CANDIDATE attribute is set. - #[must_use] - pub fn is_set(m: &Message) -> bool { - let result = m.get(ATTR_USE_CANDIDATE); - result.is_ok() - } -} diff --git a/ice/src/use_candidate/use_candidate_test.rs b/ice/src/use_candidate/use_candidate_test.rs deleted file mode 100644 index 671a7544c..000000000 --- a/ice/src/use_candidate/use_candidate_test.rs +++ /dev/null @@ -1,19 +0,0 @@ -use stun::message::BINDING_REQUEST; - -use super::*; -use crate::error::Result; - -#[test] -fn test_use_candidate_attr_add_to() -> Result<()> { - let mut m = Message::new(); - assert!(!UseCandidateAttr::is_set(&m), "should not be set"); - - m.build(&[Box::new(BINDING_REQUEST), Box::new(UseCandidateAttr::new())])?; - - let mut m1 = Message::new(); - m1.write(&m.raw)?; - - assert!(UseCandidateAttr::is_set(&m1), "should be set"); - - Ok(()) -} diff --git a/ice/src/util/mod.rs b/ice/src/util/mod.rs deleted file mode 100644 index a44cb09f2..000000000 --- a/ice/src/util/mod.rs +++ /dev/null @@ -1,175 +0,0 @@ -#[cfg(test)] -mod util_test; - -use std::collections::HashSet; -use std::net::{IpAddr, SocketAddr}; -use std::sync::Arc; - -use stun::agent::*; -use stun::attributes::*; -use stun::integrity::*; -use stun::message::*; -use stun::textattrs::*; -use stun::xoraddr::*; -use tokio::time::Duration; -use util::vnet::net::*; -use util::Conn; - -use crate::agent::agent_config::{InterfaceFilterFn, IpFilterFn}; -use crate::error::*; -use crate::network_type::*; - -pub fn create_addr(_network: NetworkType, ip: IpAddr, port: u16) -> SocketAddr { - /*if network.is_tcp(){ - return &net.TCPAddr{IP: ip, Port: port} - default: - return &net.UDPAddr{IP: ip, Port: port} - }*/ - SocketAddr::new(ip, port) -} - -pub fn assert_inbound_username(m: &Message, expected_username: &str) -> Result<()> { - let mut username = Username::new(ATTR_USERNAME, String::new()); - username.get_from(m)?; - - if username.to_string() != expected_username { - return Err(Error::Other(format!( - "{:?} expected({}) actual({})", - Error::ErrMismatchUsername, - expected_username, - username, - ))); - } - - Ok(()) -} - -pub fn assert_inbound_message_integrity(m: &mut Message, key: &[u8]) -> Result<()> { - let message_integrity_attr = MessageIntegrity(key.to_vec()); - Ok(message_integrity_attr.check(m)?) -} - -/// Initiates a stun requests to `server_addr` using conn, reads the response and returns the -/// `XORMappedAddress` returned by the stun server. -/// Adapted from stun v0.2. 
-pub async fn get_xormapped_addr( - conn: &Arc, - server_addr: SocketAddr, - deadline: Duration, -) -> Result { - let resp = stun_request(conn, server_addr, deadline).await?; - let mut addr = XorMappedAddress::default(); - addr.get_from(&resp)?; - Ok(addr) -} - -const MAX_MESSAGE_SIZE: usize = 1280; - -pub async fn stun_request( - conn: &Arc, - server_addr: SocketAddr, - deadline: Duration, -) -> Result { - let mut request = Message::new(); - request.build(&[Box::new(BINDING_REQUEST), Box::new(TransactionId::new())])?; - - conn.send_to(&request.raw, server_addr).await?; - let mut bs = vec![0_u8; MAX_MESSAGE_SIZE]; - let (n, _) = if deadline > Duration::from_secs(0) { - match tokio::time::timeout(deadline, conn.recv_from(&mut bs)).await { - Ok(result) => match result { - Ok((n, addr)) => (n, addr), - Err(err) => return Err(Error::Other(err.to_string())), - }, - Err(err) => return Err(Error::Other(err.to_string())), - } - } else { - conn.recv_from(&mut bs).await? - }; - - let mut res = Message::new(); - res.raw = bs[..n].to_vec(); - res.decode()?; - - Ok(res) -} - -pub async fn local_interfaces( - vnet: &Arc, - interface_filter: &Option, - ip_filter: &Option, - network_types: &[NetworkType], -) -> HashSet { - let mut ips = HashSet::new(); - let interfaces = vnet.get_interfaces().await; - - let (mut ipv4requested, mut ipv6requested) = (false, false); - for typ in network_types { - if typ.is_ipv4() { - ipv4requested = true; - } - if typ.is_ipv6() { - ipv6requested = true; - } - } - - for iface in interfaces { - if let Some(filter) = interface_filter { - if !filter(iface.name()) { - continue; - } - } - - for ipnet in iface.addrs() { - let ipaddr = ipnet.addr(); - - if !ipaddr.is_loopback() - && ((ipv4requested && ipaddr.is_ipv4()) || (ipv6requested && ipaddr.is_ipv6())) - && ip_filter - .as_ref() - .map(|filter| filter(ipaddr)) - .unwrap_or(true) - { - ips.insert(ipaddr); - } - } - } - - ips -} - -pub async fn listen_udp_in_port_range( - vnet: &Arc, - port_max: u16, - port_min: u16, - laddr: SocketAddr, -) -> Result> { - if laddr.port() != 0 || (port_min == 0 && port_max == 0) { - return Ok(vnet.bind(laddr).await?); - } - let i = if port_min == 0 { 1 } else { port_min }; - let j = if port_max == 0 { 0xFFFF } else { port_max }; - if i > j { - return Err(Error::ErrPort); - } - - let port_start = rand::random::() % (j - i + 1) + i; - let mut port_current = port_start; - loop { - let laddr = SocketAddr::new(laddr.ip(), port_current); - match vnet.bind(laddr).await { - Ok(c) => return Ok(c), - Err(err) => log::debug!("failed to listen {}: {}", laddr, err), - }; - - port_current += 1; - if port_current > j { - port_current = i; - } - if port_current == port_start { - break; - } - } - - Err(Error::ErrPort) -} diff --git a/ice/src/util/util_test.rs b/ice/src/util/util_test.rs deleted file mode 100644 index ab0faf94f..000000000 --- a/ice/src/util/util_test.rs +++ /dev/null @@ -1,10 +0,0 @@ -use super::*; - -#[tokio::test] -async fn test_local_interfaces() -> Result<()> { - let vnet = Arc::new(Net::new(None)); - let interfaces = vnet.get_interfaces().await; - let ips = local_interfaces(&vnet, &None, &None, &[NetworkType::Udp4, NetworkType::Udp6]).await; - log::info!("interfaces: {:?}, ips: {:?}", interfaces, ips); - Ok(()) -} diff --git a/interceptor/.gitignore b/interceptor/.gitignore deleted file mode 100644 index 81561ed32..000000000 --- a/interceptor/.gitignore +++ /dev/null @@ -1,11 +0,0 @@ -# Generated by Cargo -# will have compiled files and executables -/target/ -/.idea/ - -# Remove Cargo.lock 
from gitignore if creating an executable, leave it for libraries -# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html -Cargo.lock - -# These are backup files generated by rustfmt -**/*.rs.bk diff --git a/interceptor/CHANGELOG.md b/interceptor/CHANGELOG.md deleted file mode 100644 index 5394a2e78..000000000 --- a/interceptor/CHANGELOG.md +++ /dev/null @@ -1,23 +0,0 @@ -# interceptor changelog - -## Unreleased - -## v0.9.0 - -* Fix over-NACK due to not resetting the lost_packets bitmask [\#372](https://github.com/webrtc-rs/webrtc/pull/372/). -* Further extended stats interceptors to collect stats for `RemoteOutboundRTPStats` and improve `RemoteInboundRTPStats` collection. [#282](https://github.com/webrtc-rs/webrtc/pull/282) by [@k0nserv](https://github.com/k0nserv). -* When generating periodic TWCC feedback packets we no longer burst several packets in a row to catch up, i.e., we now use `MissedTickBehavior::Skip` instead of the default `MissedTickBehavior::Burst` for the ticker in question. [#323](https://github.com/webrtc-rs/webrtc/pull/323) by [@k0nserv](https://github.com/k0nserv). -* Don't generate empty TWCC packets that libWebRTC will ignore. [#324](https://github.com/webrtc-rs/webrtc/pull/324) by [@k0nserv](https://github.com/k0nserv). -* Increased minimum supported Rust version to `1.60.0`. -* Increased required `webrtc-util` version to `0.7.0`. - -## v0.8.0 - -* [#14 Don't panic on seqnum rollover](https://github.com/webrtc-rs/interceptor/pull/14) contributed by [@pthatcher](https://github.com/pthatcher). -* Add stats interceptor. Contributed by [@k0nserv](https://github.com/k0nserv) in [#277](https://github.com/webrtc-rs/webrtc/pull/277/) and [#225](https://github.com/webrtc-rs/webrtc/pull/225). -* Increased min version of `log` dependency to `0.4.16`. [#250 Fix log at ^0.4.16 to make tests compile](https://github.com/webrtc-rs/webrtc/pull/250) by [@k0nserv](https://github.com/k0nserv). - -## Prior to 0.8.0 - -Before 0.8.0 there was no changelog; previous changes are sometimes, but not always, available in the [GitHub Releases](https://github.com/webrtc-rs/interceptor/releases).
- diff --git a/interceptor/Cargo.toml b/interceptor/Cargo.toml deleted file mode 100644 index 3edbcd3d0..000000000 --- a/interceptor/Cargo.toml +++ /dev/null @@ -1,29 +0,0 @@ -[package] -name = "interceptor" -version = "0.12.0" -authors = ["Rain Liu "] -edition = "2021" -description = "A pure Rust implementation of Pluggable RTP/RTCP processors" -license = "MIT OR Apache-2.0" -documentation = "https://docs.rs/interceptor" -homepage = "https://webrtc.rs" -repository = "https://github.com/webrtc-rs/webrtc/tree/master/interceptor" - -[dependencies] -util = { version = "0.9.0", path = "../util", package = "webrtc-util", default-features = false, features = ["marshal", "sync"] } -rtp = { version = "0.11.0", path = "../rtp" } -rtcp = { version = "0.11.0", path = "../rtcp" } -srtp = { version = "0.13.0", path = "../srtp", package = "webrtc-srtp" } - -tokio = { version = "1.32.0", features = ["sync", "time"] } -async-trait = "0.1" -bytes = "1" -thiserror = "1" -rand = "0.8" -waitgroup = "0.1" -log = "0.4" -portable-atomic = "1.6" - -[dev-dependencies] -tokio-test = "0.4" -chrono = "0.4.28" diff --git a/interceptor/LICENSE-APACHE b/interceptor/LICENSE-APACHE deleted file mode 100644 index 16fe87b06..000000000 --- a/interceptor/LICENSE-APACHE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - -1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. 
- - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - -2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - -3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - -4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - -5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - -6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - -8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - -END OF TERMS AND CONDITIONS - -APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - -Copyright [yyyy] [name of copyright owner] - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. diff --git a/interceptor/LICENSE-MIT b/interceptor/LICENSE-MIT deleted file mode 100644 index e11d93bef..000000000 --- a/interceptor/LICENSE-MIT +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2021 WebRTC.rs - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/interceptor/README.md b/interceptor/README.md deleted file mode 100644 index a35123ec3..000000000 --- a/interceptor/README.md +++ /dev/null @@ -1,30 +0,0 @@ -

- [logo: WebRTC.rs]
-
- [badges: License: MIT/Apache 2.0 | Discord]
-
- A pure Rust implementation of Pluggable RTP/RTCP processors. A rewrite of Pion Interceptor in Rust. -
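For orientation, a minimal usage sketch of the crate this README describes, assuming the `Registry` and NACK `Generator`/`Responder` APIs from the source files removed later in this diff (the helper name `build_nack_interceptor` is illustrative, not part of the crate):

```rust
use std::sync::Arc;

use interceptor::nack::generator::Generator;
use interceptor::nack::responder::Responder;
use interceptor::registry::Registry;
use interceptor::{Error, Interceptor};

// Illustrative sketch: register the NACK generator and responder builders,
// then let the Registry build them into a single chained Interceptor.
fn build_nack_interceptor() -> Result<Arc<dyn Interceptor + Send + Sync>, Error> {
    let mut registry = Registry::new();
    registry.add(Box::new(Generator::builder()));
    registry.add(Box::new(Responder::builder()));
    // An empty id is fine here; it is simply forwarded to each builder's build().
    registry.build("")
}
```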

diff --git a/interceptor/codecov.yml b/interceptor/codecov.yml deleted file mode 100644 index 2103309ff..000000000 --- a/interceptor/codecov.yml +++ /dev/null @@ -1,23 +0,0 @@ -codecov: - require_ci_to_pass: yes - max_report_age: off - token: a1ee2aa3-5623-4b41-8ba8-446a6b6792df - -coverage: - precision: 2 - round: down - range: 50..90 - status: - project: - default: - enabled: no - threshold: 0.2 - if_not_found: success - patch: - default: - enabled: no - if_not_found: success - changes: - default: - enabled: no - if_not_found: success diff --git a/interceptor/doc/webrtc.rs.png b/interceptor/doc/webrtc.rs.png deleted file mode 100644 index 7bf0dda2a..000000000 Binary files a/interceptor/doc/webrtc.rs.png and /dev/null differ diff --git a/interceptor/src/chain.rs b/interceptor/src/chain.rs deleted file mode 100644 index 12f3bb161..000000000 --- a/interceptor/src/chain.rs +++ /dev/null @@ -1,100 +0,0 @@ -use std::sync::Arc; - -use crate::error::*; -use crate::stream_info::StreamInfo; -use crate::*; - -/// Chain is an interceptor that runs all child interceptors in order. -#[derive(Default)] -pub struct Chain { - interceptors: Vec>, -} - -impl Chain { - /// new returns a new Chain interceptor. - pub fn new(interceptors: Vec>) -> Self { - Chain { interceptors } - } - - pub fn add(&mut self, icpr: Arc) { - self.interceptors.push(icpr); - } -} - -#[async_trait] -impl Interceptor for Chain { - /// bind_rtcp_reader lets you modify any incoming RTCP packets. It is called once per sender/receiver, however this might - /// change in the future. The returned method will be called once per packet batch. - async fn bind_rtcp_reader( - &self, - mut reader: Arc, - ) -> Arc { - for icpr in &self.interceptors { - reader = icpr.bind_rtcp_reader(reader).await; - } - reader - } - - /// bind_rtcp_writer lets you modify any outgoing RTCP packets. It is called once per PeerConnection. The returned method - /// will be called once per packet batch. - async fn bind_rtcp_writer( - &self, - mut writer: Arc, - ) -> Arc { - for icpr in &self.interceptors { - writer = icpr.bind_rtcp_writer(writer).await; - } - writer - } - - /// bind_local_stream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method - /// will be called once per rtp packet. - async fn bind_local_stream( - &self, - info: &StreamInfo, - mut writer: Arc, - ) -> Arc { - for icpr in &self.interceptors { - writer = icpr.bind_local_stream(info, writer).await; - } - writer - } - - /// unbind_local_stream is called when the Stream is removed. It can be used to clean up any data related to that track. - async fn unbind_local_stream(&self, info: &StreamInfo) { - for icpr in &self.interceptors { - icpr.unbind_local_stream(info).await; - } - } - - /// bind_remote_stream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method - /// will be called once per rtp packet. - async fn bind_remote_stream( - &self, - info: &StreamInfo, - mut reader: Arc, - ) -> Arc { - for icpr in &self.interceptors { - reader = icpr.bind_remote_stream(info, reader).await; - } - reader - } - - /// unbind_remote_stream is called when the Stream is removed. It can be used to clean up any data related to that track. - async fn unbind_remote_stream(&self, info: &StreamInfo) { - for icpr in &self.interceptors { - icpr.unbind_remote_stream(info).await; - } - } - - /// close closes the Interceptor, cleaning up any data if necessary. 
- async fn close(&self) -> Result<()> { - let mut errs = vec![]; - for icpr in &self.interceptors { - if let Err(err) = icpr.close().await { - errs.push(err); - } - } - flatten_errs(errs) - } -} diff --git a/interceptor/src/error.rs b/interceptor/src/error.rs deleted file mode 100644 index 1ee4d227e..000000000 --- a/interceptor/src/error.rs +++ /dev/null @@ -1,46 +0,0 @@ -use thiserror::Error; - -pub type Result = std::result::Result; - -#[derive(Error, Debug, PartialEq)] -#[non_exhaustive] -pub enum Error { - #[error("Invalid Parent RTCP Reader")] - ErrInvalidParentRtcpReader, - #[error("Invalid Parent RTP Reader")] - ErrInvalidParentRtpReader, - #[error("Invalid Next RTP Writer")] - ErrInvalidNextRtpWriter, - #[error("Invalid CloseRx Channel")] - ErrInvalidCloseRx, - #[error("Invalid PacketRx Channel")] - ErrInvalidPacketRx, - #[error("IO EOF")] - ErrIoEOF, - #[error("Buffer is too short")] - ErrShortBuffer, - #[error("Invalid buffer size")] - ErrInvalidSize, - - #[error("{0}")] - Srtp(#[from] srtp::Error), - #[error("{0}")] - Rtcp(#[from] rtcp::Error), - #[error("{0}")] - Rtp(#[from] rtp::Error), - #[error("{0}")] - Util(#[from] util::Error), - - #[error("{0}")] - Other(String), -} - -/// flatten_errs flattens multiple errors into one -pub fn flatten_errs(errs: Vec) -> Result<()> { - if errs.is_empty() { - Ok(()) - } else { - let errs_strs: Vec = errs.into_iter().map(|e| e.to_string()).collect(); - Err(Error::Other(errs_strs.join("\n"))) - } -} diff --git a/interceptor/src/lib.rs b/interceptor/src/lib.rs deleted file mode 100644 index 98008d405..000000000 --- a/interceptor/src/lib.rs +++ /dev/null @@ -1,228 +0,0 @@ -#![warn(rust_2018_idioms)] -#![allow(dead_code)] - -use std::collections::HashMap; -use std::future::Future; -use std::pin::Pin; -use std::sync::Arc; - -use async_trait::async_trait; -use error::Result; -use stream_info::StreamInfo; - -pub mod chain; -mod error; -pub mod mock; -pub mod nack; -pub mod noop; -pub mod registry; -pub mod report; -pub mod stats; -pub mod stream_info; -pub mod stream_reader; -pub mod twcc; - -pub use error::Error; - -/// Attributes are a generic key/value store used by interceptors -pub type Attributes = HashMap; - -/// InterceptorBuilder provides an interface for constructing interceptors -pub trait InterceptorBuilder { - fn build(&self, id: &str) -> Result>; -} - -/// Interceptor can be used to add functionality to you PeerConnections by modifying any incoming/outgoing rtp/rtcp -/// packets, or sending your own packets as needed. -#[async_trait] -pub trait Interceptor { - /// bind_rtcp_reader lets you modify any incoming RTCP packets. It is called once per sender/receiver, however this might - /// change in the future. The returned method will be called once per packet batch. - async fn bind_rtcp_reader( - &self, - reader: Arc, - ) -> Arc; - - /// bind_rtcp_writer lets you modify any outgoing RTCP packets. It is called once per PeerConnection. The returned method - /// will be called once per packet batch. - async fn bind_rtcp_writer( - &self, - writer: Arc, - ) -> Arc; - - /// bind_local_stream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method - /// will be called once per rtp packet. - async fn bind_local_stream( - &self, - info: &StreamInfo, - writer: Arc, - ) -> Arc; - - /// unbind_local_stream is called when the Stream is removed. It can be used to clean up any data related to that track. 
- async fn unbind_local_stream(&self, info: &StreamInfo); - - /// bind_remote_stream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method - /// will be called once per rtp packet. - async fn bind_remote_stream( - &self, - info: &StreamInfo, - reader: Arc, - ) -> Arc; - - /// unbind_remote_stream is called when the Stream is removed. It can be used to clean up any data related to that track. - async fn unbind_remote_stream(&self, info: &StreamInfo); - - async fn close(&self) -> Result<()>; -} - -/// RTPWriter is used by Interceptor.bind_local_stream. -#[async_trait] -pub trait RTPWriter { - /// write a rtp packet - async fn write(&self, pkt: &rtp::packet::Packet, attributes: &Attributes) -> Result; -} - -pub type RTPWriterBoxFn = Box< - dyn (Fn( - &rtp::packet::Packet, - &Attributes, - ) -> Pin> + Send + Sync>>) - + Send - + Sync, ->; -pub struct RTPWriterFn(pub RTPWriterBoxFn); - -#[async_trait] -impl RTPWriter for RTPWriterFn { - /// write a rtp packet - async fn write(&self, pkt: &rtp::packet::Packet, attributes: &Attributes) -> Result { - self.0(pkt, attributes).await - } -} - -/// RTPReader is used by Interceptor.bind_remote_stream. -#[async_trait] -pub trait RTPReader { - /// read a rtp packet - async fn read( - &self, - buf: &mut [u8], - attributes: &Attributes, - ) -> Result<(rtp::packet::Packet, Attributes)>; -} - -pub type RTPReaderBoxFn = Box< - dyn (Fn( - &mut [u8], - &Attributes, - ) - -> Pin> + Send + Sync>>) - + Send - + Sync, ->; -pub struct RTPReaderFn(pub RTPReaderBoxFn); - -#[async_trait] -impl RTPReader for RTPReaderFn { - /// read a rtp packet - async fn read( - &self, - buf: &mut [u8], - attributes: &Attributes, - ) -> Result<(rtp::packet::Packet, Attributes)> { - self.0(buf, attributes).await - } -} - -/// RTCPWriter is used by Interceptor.bind_rtcpwriter. -#[async_trait] -pub trait RTCPWriter { - /// write a batch of rtcp packets - async fn write( - &self, - pkts: &[Box], - attributes: &Attributes, - ) -> Result; -} - -pub type RTCPWriterBoxFn = Box< - dyn (Fn( - &[Box], - &Attributes, - ) -> Pin> + Send + Sync>>) - + Send - + Sync, ->; - -pub struct RTCPWriterFn(pub RTCPWriterBoxFn); - -#[async_trait] -impl RTCPWriter for RTCPWriterFn { - /// write a batch of rtcp packets - async fn write( - &self, - pkts: &[Box], - attributes: &Attributes, - ) -> Result { - self.0(pkts, attributes).await - } -} - -/// RTCPReader is used by Interceptor.bind_rtcpreader. -#[async_trait] -pub trait RTCPReader { - /// read a batch of rtcp packets - async fn read( - &self, - buf: &mut [u8], - attributes: &Attributes, - ) -> Result<(Vec>, Attributes)>; -} - -pub type RTCPReaderBoxFn = Box< - dyn (Fn( - &mut [u8], - &Attributes, - ) -> Pin< - Box< - dyn Future< - Output = Result<( - Vec>, - Attributes, - )>, - > + Send - + Sync, - >, - >) + Send - + Sync, ->; - -pub struct RTCPReaderFn(pub RTCPReaderBoxFn); - -#[async_trait] -impl RTCPReader for RTCPReaderFn { - /// read a batch of rtcp packets - async fn read( - &self, - buf: &mut [u8], - attributes: &Attributes, - ) -> Result<(Vec>, Attributes)> { - self.0(buf, attributes).await - } -} - -/// Helper for the tests. 
-#[cfg(test)] -mod test { - use std::future::Future; - use std::time::Duration; - - pub async fn timeout_or_fail(duration: Duration, future: T) -> T::Output - where - T: Future, - { - tokio::time::timeout(duration, future) - .await - .expect("should not time out") - } -} diff --git a/interceptor/src/mock/mock_builder.rs b/interceptor/src/mock/mock_builder.rs deleted file mode 100644 index d6faabbca..000000000 --- a/interceptor/src/mock/mock_builder.rs +++ /dev/null @@ -1,23 +0,0 @@ -use std::sync::Arc; - -use crate::error::Result; -use crate::{Interceptor, InterceptorBuilder}; - -pub type MockBuilderResult = Result>; - -/// MockBuilder is a mock Builder for testing. -pub struct MockBuilder { - pub build: Box MockBuilderResult) + Send + Sync + 'static>, -} - -impl MockBuilder { - pub fn new MockBuilderResult) + Send + Sync + 'static>(f: F) -> Self { - MockBuilder { build: Box::new(f) } - } -} - -impl InterceptorBuilder for MockBuilder { - fn build(&self, id: &str) -> MockBuilderResult { - (self.build)(id) - } -} diff --git a/interceptor/src/mock/mock_interceptor.rs b/interceptor/src/mock/mock_interceptor.rs deleted file mode 100644 index 7b5088efb..000000000 --- a/interceptor/src/mock/mock_interceptor.rs +++ /dev/null @@ -1,136 +0,0 @@ -use std::future::Future; -use std::pin::Pin; - -use crate::*; - -pub type BindRtcpReaderFn = Box< - dyn (Fn( - Arc, - ) - -> Pin> + Send + Sync>>) - + Send - + Sync, ->; - -pub type BindRtcpWriterFn = Box< - dyn (Fn( - Arc, - ) - -> Pin> + Send + Sync>>) - + Send - + Sync, ->; -pub type BindLocalStreamFn = Box< - dyn (Fn( - &StreamInfo, - Arc, - ) -> Pin> + Send + Sync>>) - + Send - + Sync, ->; -pub type UnbindLocalStreamFn = - Box Pin + Send + Sync>>) + Send + Sync>; -pub type BindRemoteStreamFn = Box< - dyn (Fn( - &StreamInfo, - Arc, - ) -> Pin> + Send + Sync>>) - + Send - + Sync, ->; -pub type UnbindRemoteStreamFn = - Box Pin + Send + Sync>>) + Send + Sync>; -pub type CloseFn = - Box Pin> + Send + Sync>>) + Send + Sync>; - -/// MockInterceptor is an mock Interceptor for testing. -#[derive(Default)] -pub struct MockInterceptor { - pub bind_rtcp_reader_fn: Option, - pub bind_rtcp_writer_fn: Option, - pub bind_local_stream_fn: Option, - pub unbind_local_stream_fn: Option, - pub bind_remote_stream_fn: Option, - pub unbind_remote_stream_fn: Option, - pub close_fn: Option, -} - -#[async_trait] -impl Interceptor for MockInterceptor { - /// bind_rtcp_reader lets you modify any incoming RTCP packets. It is called once per sender/receiver, however this might - /// change in the future. The returned method will be called once per packet batch. - async fn bind_rtcp_reader( - &self, - reader: Arc, - ) -> Arc { - if let Some(f) = &self.bind_rtcp_reader_fn { - f(reader).await - } else { - reader - } - } - - /// bind_rtcp_writer lets you modify any outgoing RTCP packets. It is called once per PeerConnection. The returned method - /// will be called once per packet batch. - async fn bind_rtcp_writer( - &self, - writer: Arc, - ) -> Arc { - if let Some(f) = &self.bind_rtcp_writer_fn { - f(writer).await - } else { - writer - } - } - - /// bind_local_stream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method - /// will be called once per rtp packet. - async fn bind_local_stream( - &self, - info: &StreamInfo, - writer: Arc, - ) -> Arc { - if let Some(f) = &self.bind_local_stream_fn { - f(info, writer).await - } else { - writer - } - } - - /// unbind_local_stream is called when the Stream is removed. 
It can be used to clean up any data related to that track. - async fn unbind_local_stream(&self, info: &StreamInfo) { - if let Some(f) = &self.unbind_local_stream_fn { - f(info).await - } - } - - /// bind_remote_stream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method - /// will be called once per rtp packet. - async fn bind_remote_stream( - &self, - info: &StreamInfo, - reader: Arc, - ) -> Arc { - if let Some(f) = &self.bind_remote_stream_fn { - f(info, reader).await - } else { - reader - } - } - - /// unbind_remote_stream is called when the Stream is removed. It can be used to clean up any data related to that track. - async fn unbind_remote_stream(&self, info: &StreamInfo) { - if let Some(f) = &self.unbind_remote_stream_fn { - f(info).await - } - } - - /// close closes the Interceptor, cleaning up any data if necessary. - async fn close(&self) -> Result<()> { - if let Some(f) = &self.close_fn { - f().await - } else { - Ok(()) - } - } -} diff --git a/interceptor/src/mock/mock_stream.rs b/interceptor/src/mock/mock_stream.rs deleted file mode 100644 index 8183197a3..000000000 --- a/interceptor/src/mock/mock_stream.rs +++ /dev/null @@ -1,355 +0,0 @@ -use std::sync::Arc; - -use async_trait::async_trait; -use tokio::sync::{mpsc, Mutex}; -use util::Marshal; - -use crate::error::{Error, Result}; -use crate::stream_info::StreamInfo; -use crate::{Attributes, Interceptor, RTCPReader, RTCPWriter, RTPReader, RTPWriter}; - -type RTCPPackets = Vec>; - -/// MockStream is a helper struct for testing interceptors. -pub struct MockStream { - interceptor: Arc, - - rtcp_writer: Mutex>>, - rtp_writer: Mutex>>, - - rtcp_out_modified_tx: mpsc::Sender, - rtp_out_modified_tx: mpsc::Sender, - rtcp_in_rx: Mutex>, - rtp_in_rx: Mutex>, - - rtcp_out_modified_rx: Mutex>, - rtp_out_modified_rx: Mutex>, - rtcp_in_tx: Mutex>>, - rtp_in_tx: Mutex>>, - - rtcp_in_modified_rx: Mutex>>, - rtp_in_modified_rx: Mutex>>, -} - -impl MockStream { - /// new creates a new MockStream - pub async fn new( - info: &StreamInfo, - interceptor: Arc, - ) -> Arc { - let (rtcp_in_tx, rtcp_in_rx) = mpsc::channel(1000); - let (rtp_in_tx, rtp_in_rx) = mpsc::channel(1000); - let (rtcp_out_modified_tx, rtcp_out_modified_rx) = mpsc::channel(1000); - let (rtp_out_modified_tx, rtp_out_modified_rx) = mpsc::channel(1000); - let (rtcp_in_modified_tx, rtcp_in_modified_rx) = mpsc::channel(1000); - let (rtp_in_modified_tx, rtp_in_modified_rx) = mpsc::channel(1000); - - let stream = Arc::new(MockStream { - interceptor: Arc::clone(&interceptor), - - rtcp_writer: Mutex::new(None), - rtp_writer: Mutex::new(None), - - rtcp_in_tx: Mutex::new(Some(rtcp_in_tx)), - rtp_in_tx: Mutex::new(Some(rtp_in_tx)), - rtcp_in_rx: Mutex::new(rtcp_in_rx), - rtp_in_rx: Mutex::new(rtp_in_rx), - - rtcp_out_modified_tx, - rtp_out_modified_tx, - rtcp_out_modified_rx: Mutex::new(rtcp_out_modified_rx), - rtp_out_modified_rx: Mutex::new(rtp_out_modified_rx), - - rtcp_in_modified_rx: Mutex::new(rtcp_in_modified_rx), - rtp_in_modified_rx: Mutex::new(rtp_in_modified_rx), - }); - - let rtcp_writer = interceptor - .bind_rtcp_writer(Arc::clone(&stream) as Arc) - .await; - { - let mut rw = stream.rtcp_writer.lock().await; - *rw = Some(rtcp_writer); - } - let rtp_writer = interceptor - .bind_local_stream( - info, - Arc::clone(&stream) as Arc, - ) - .await; - { - let mut rw = stream.rtp_writer.lock().await; - *rw = Some(rtp_writer); - } - - let rtcp_reader = interceptor - .bind_rtcp_reader(Arc::clone(&stream) as Arc) - .await; - 
tokio::spawn(async move { - let mut buf = vec![0u8; 1500]; - let a = Attributes::new(); - loop { - let pkts = match rtcp_reader.read(&mut buf, &a).await { - Ok((n, _)) => n, - Err(err) => { - let _ = rtcp_in_modified_tx.send(Err(err)).await; - break; - } - }; - - let _ = rtcp_in_modified_tx.send(Ok(pkts)).await; - } - }); - - let rtp_reader = interceptor - .bind_remote_stream( - info, - Arc::clone(&stream) as Arc, - ) - .await; - tokio::spawn(async move { - let mut buf = vec![0u8; 1500]; - let a = Attributes::new(); - loop { - let pkt = match rtp_reader.read(&mut buf, &a).await { - Ok((pkt, _)) => pkt, - Err(err) => { - let _ = rtp_in_modified_tx.send(Err(err)).await; - break; - } - }; - - let _ = rtp_in_modified_tx.send(Ok(pkt)).await; - } - }); - - stream - } - - /// write_rtcp writes a batch of rtcp packet to the stream, using the interceptor - pub async fn write_rtcp( - &self, - pkt: &[Box], - ) -> Result { - let a = Attributes::new(); - let rtcp_writer = self.rtcp_writer.lock().await; - if let Some(writer) = &*rtcp_writer { - writer.write(pkt, &a).await - } else { - Err(Error::Other("invalid rtcp_writer".to_owned())) - } - } - - /// write_rtp writes an rtp packet to the stream, using the interceptor - pub async fn write_rtp(&self, pkt: &rtp::packet::Packet) -> Result { - let a = Attributes::new(); - let rtp_writer = self.rtp_writer.lock().await; - if let Some(writer) = &*rtp_writer { - writer.write(pkt, &a).await - } else { - Err(Error::Other("invalid rtp_writer".to_owned())) - } - } - - /// receive_rtcp schedules a new rtcp batch, so it can be read be the stream - pub async fn receive_rtcp(&self, pkts: Vec>) { - let rtcp_in_tx = self.rtcp_in_tx.lock().await; - if let Some(tx) = &*rtcp_in_tx { - let _ = tx.send(pkts).await; - } - } - - /// receive_rtp schedules a rtp packet, so it can be read be the stream - pub async fn receive_rtp(&self, pkt: rtp::packet::Packet) { - let rtp_in_tx = self.rtp_in_tx.lock().await; - if let Some(tx) = &*rtp_in_tx { - let _ = tx.send(pkt).await; - } - } - - /// written_rtcp returns a channel containing the rtcp batches written, modified by the interceptor - pub async fn written_rtcp(&self) -> Option>> { - let mut rtcp_out_modified_rx = self.rtcp_out_modified_rx.lock().await; - rtcp_out_modified_rx.recv().await - } - - /// Returns the last rtcp packet bacth that was written, modified by the interceptor. - /// - /// NB: This method discards all other previously recoreded packet batches. 
- pub async fn last_written_rtcp( - &self, - ) -> Option>> { - let mut last = None; - let mut rtcp_out_modified_rx = self.rtcp_out_modified_rx.lock().await; - - while let Ok(v) = rtcp_out_modified_rx.try_recv() { - last = Some(v); - } - - last - } - - /// written_rtp returns a channel containing rtp packets written, modified by the interceptor - pub async fn written_rtp(&self) -> Option { - let mut rtp_out_modified_rx = self.rtp_out_modified_rx.lock().await; - rtp_out_modified_rx.recv().await - } - - /// read_rtcp returns a channel containing the rtcp batched read, modified by the interceptor - pub async fn read_rtcp( - &self, - ) -> Option>>> { - let mut rtcp_in_modified_rx = self.rtcp_in_modified_rx.lock().await; - rtcp_in_modified_rx.recv().await - } - - /// read_rtp returns a channel containing the rtp packets read, modified by the interceptor - pub async fn read_rtp(&self) -> Option> { - let mut rtp_in_modified_rx = self.rtp_in_modified_rx.lock().await; - rtp_in_modified_rx.recv().await - } - - /// close closes the stream and the underlying interceptor - pub async fn close(&self) -> Result<()> { - { - let mut rtcp_in_tx = self.rtcp_in_tx.lock().await; - rtcp_in_tx.take(); - } - { - let mut rtp_in_tx = self.rtp_in_tx.lock().await; - rtp_in_tx.take(); - } - self.interceptor.close().await - } -} - -#[async_trait] -impl RTCPWriter for MockStream { - async fn write( - &self, - pkts: &[Box], - _attributes: &Attributes, - ) -> Result { - let _ = self.rtcp_out_modified_tx.send(pkts.to_vec()).await; - - Ok(0) - } -} - -#[async_trait] -impl RTCPReader for MockStream { - async fn read( - &self, - buf: &mut [u8], - a: &Attributes, - ) -> Result<(Vec>, Attributes)> { - let pkts = { - let mut rtcp_in = self.rtcp_in_rx.lock().await; - rtcp_in.recv().await.ok_or(Error::ErrIoEOF)? - }; - - let marshaled = rtcp::packet::marshal(&pkts)?; - let n = marshaled.len(); - if n > buf.len() { - return Err(Error::ErrShortBuffer); - } - - buf[..n].copy_from_slice(&marshaled); - Ok((pkts, a.clone())) - } -} - -#[async_trait] -impl RTPWriter for MockStream { - async fn write(&self, pkt: &rtp::packet::Packet, _a: &Attributes) -> Result { - let _ = self.rtp_out_modified_tx.send(pkt.clone()).await; - Ok(0) - } -} - -#[async_trait] -impl RTPReader for MockStream { - async fn read( - &self, - buf: &mut [u8], - a: &Attributes, - ) -> Result<(rtp::packet::Packet, Attributes)> { - let pkt = { - let mut rtp_in = self.rtp_in_rx.lock().await; - rtp_in.recv().await.ok_or(Error::ErrIoEOF)? 
- }; - - let marshaled = pkt.marshal()?; - let n = marshaled.len(); - if n > buf.len() { - return Err(Error::ErrShortBuffer); - } - - buf[..n].copy_from_slice(&marshaled); - Ok((pkt, a.clone())) - } -} - -#[cfg(test)] -mod test { - use tokio::time::Duration; - - use super::*; - use crate::noop::NoOp; - use crate::test::timeout_or_fail; - - #[tokio::test] - async fn test_mock_stream() -> Result<()> { - use rtcp::payload_feedbacks::picture_loss_indication::PictureLossIndication; - - let s = MockStream::new(&StreamInfo::default(), Arc::new(NoOp)).await; - - s.write_rtcp(&[Box::::default()]) - .await?; - timeout_or_fail(Duration::from_millis(10), s.written_rtcp()).await; - let result = tokio::time::timeout(Duration::from_millis(10), s.written_rtcp()).await; - assert!( - result.is_err(), - "single rtcp packet written, but multiple found" - ); - - s.write_rtp(&rtp::packet::Packet::default()).await?; - timeout_or_fail(Duration::from_millis(10), s.written_rtp()).await; - let result = tokio::time::timeout(Duration::from_millis(10), s.written_rtp()).await; - assert!( - result.is_err(), - "single rtp packet written, but multiple found" - ); - - s.receive_rtcp(vec![Box::::default()]) - .await; - assert!( - timeout_or_fail(Duration::from_millis(10), s.read_rtcp()) - .await - .is_some(), - "read rtcp returned error", - ); - let result = tokio::time::timeout(Duration::from_millis(10), s.read_rtcp()).await; - assert!( - result.is_err(), - "single rtcp packet written, but multiple found" - ); - - s.receive_rtp(rtp::packet::Packet::default()).await; - assert!( - timeout_or_fail(Duration::from_millis(10), s.read_rtp()) - .await - .is_some(), - "read rtp returned error", - ); - let result = tokio::time::timeout(Duration::from_millis(10), s.read_rtp()).await; - assert!( - result.is_err(), - "single rtp packet written, but multiple found" - ); - - s.close().await?; - - Ok(()) - } -} diff --git a/interceptor/src/mock/mock_time.rs b/interceptor/src/mock/mock_time.rs deleted file mode 100644 index 566ae436e..000000000 --- a/interceptor/src/mock/mock_time.rs +++ /dev/null @@ -1,36 +0,0 @@ -use std::time::{Duration, SystemTime}; - -use util::sync::Mutex; - -/// MockTime is a helper to replace SystemTime::now() for testing purposes. -pub struct MockTime { - cur_now: Mutex, -} - -impl Default for MockTime { - fn default() -> Self { - MockTime { - cur_now: Mutex::new(SystemTime::UNIX_EPOCH), - } - } -} - -impl MockTime { - /// set_now sets the current time. - pub fn set_now(&self, now: SystemTime) { - let mut cur_now = self.cur_now.lock(); - *cur_now = now; - } - - /// now returns the current time. 
- pub fn now(&self) -> SystemTime { - let cur_now = self.cur_now.lock(); - *cur_now - } - - /// advance advances duration d - pub fn advance(&mut self, d: Duration) { - let mut cur_now = self.cur_now.lock(); - *cur_now = cur_now.checked_add(d).unwrap_or(*cur_now); - } -} diff --git a/interceptor/src/mock/mod.rs b/interceptor/src/mock/mod.rs deleted file mode 100644 index 0a561b8ff..000000000 --- a/interceptor/src/mock/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -pub mod mock_builder; -pub mod mock_interceptor; -pub mod mock_stream; -pub mod mock_time; diff --git a/interceptor/src/nack/generator/generator_stream.rs b/interceptor/src/nack/generator/generator_stream.rs deleted file mode 100644 index 8cd5a108c..000000000 --- a/interceptor/src/nack/generator/generator_stream.rs +++ /dev/null @@ -1,314 +0,0 @@ -use util::sync::Mutex; - -use super::*; -use crate::nack::UINT16SIZE_HALF; - -struct GeneratorStreamInternal { - packets: Vec, - size: u16, - end: u16, - started: bool, - last_consecutive: u16, -} - -impl GeneratorStreamInternal { - fn new(log2_size_minus_6: u8) -> Self { - GeneratorStreamInternal { - packets: vec![0u64; 1 << log2_size_minus_6], - size: 1 << (log2_size_minus_6 + 6), - end: 0, - started: false, - last_consecutive: 0, - } - } - - fn add(&mut self, seq: u16) { - if !self.started { - self.set_received(seq); - self.end = seq; - self.started = true; - self.last_consecutive = seq; - return; - } - - let last_consecutive_plus1 = self.last_consecutive.wrapping_add(1); - let diff = seq.wrapping_sub(self.end); - if diff == 0 { - return; - } else if diff < UINT16SIZE_HALF { - // this means a positive diff, in other words seq > end (with counting for rollovers) - let mut i = self.end.wrapping_add(1); - while i != seq { - // clear packets between end and seq (these may contain packets from a "size" ago) - self.del_received(i); - i = i.wrapping_add(1); - } - self.end = seq; - - let seq_sub_last_consecutive = seq.wrapping_sub(self.last_consecutive); - if last_consecutive_plus1 == seq { - self.last_consecutive = seq; - } else if seq_sub_last_consecutive > self.size { - let diff = seq.wrapping_sub(self.size); - self.last_consecutive = diff; - self.fix_last_consecutive(); // there might be valid packets at the beginning of the buffer now - } - } else if last_consecutive_plus1 == seq { - // negative diff, seq < end (with counting for rollovers) - self.last_consecutive = seq; - self.fix_last_consecutive(); // there might be other valid packets after seq - } - - self.set_received(seq); - } - - fn get(&self, seq: u16) -> bool { - let diff = self.end.wrapping_sub(seq); - if diff >= UINT16SIZE_HALF { - return false; - } - - if diff >= self.size { - return false; - } - - self.get_received(seq) - } - - fn missing_seq_numbers(&self, skip_last_n: u16) -> Vec { - let until = self.end.wrapping_sub(skip_last_n); - let diff = until.wrapping_sub(self.last_consecutive); - if diff >= UINT16SIZE_HALF { - // until < s.last_consecutive (counting for rollover) - return vec![]; - } - - let mut missing_packet_seq_nums = vec![]; - let mut i = self.last_consecutive.wrapping_add(1); - let util_plus1 = until.wrapping_add(1); - while i != util_plus1 { - if !self.get_received(i) { - missing_packet_seq_nums.push(i); - } - i = i.wrapping_add(1); - } - - missing_packet_seq_nums - } - - fn set_received(&mut self, seq: u16) { - let pos = (seq % self.size) as usize; - self.packets[pos / 64] |= 1u64 << (pos % 64); - } - - fn del_received(&mut self, seq: u16) { - let pos = (seq % self.size) as usize; - self.packets[pos / 64] &= 
u64::MAX ^ (1u64 << (pos % 64)); - } - - fn get_received(&self, seq: u16) -> bool { - let pos = (seq % self.size) as usize; - (self.packets[pos / 64] & (1u64 << (pos % 64))) != 0 - } - - fn fix_last_consecutive(&mut self) { - let mut i = self.last_consecutive.wrapping_add(1); - while i != self.end.wrapping_add(1) && self.get_received(i) { - // find all consecutive packets - i = i.wrapping_add(1); - } - self.last_consecutive = i.wrapping_sub(1); - } -} - -pub(super) struct GeneratorStream { - parent_rtp_reader: Arc, - - internal: Mutex, -} - -impl GeneratorStream { - pub(super) fn new(log2_size_minus_6: u8, reader: Arc) -> Self { - GeneratorStream { - parent_rtp_reader: reader, - internal: Mutex::new(GeneratorStreamInternal::new(log2_size_minus_6)), - } - } - - pub(super) fn missing_seq_numbers(&self, skip_last_n: u16) -> Vec { - let internal = self.internal.lock(); - internal.missing_seq_numbers(skip_last_n) - } - - pub(super) fn add(&self, seq: u16) { - let mut internal = self.internal.lock(); - internal.add(seq); - } -} - -/// RTPReader is used by Interceptor.bind_remote_stream. -#[async_trait] -impl RTPReader for GeneratorStream { - /// read a rtp packet - async fn read( - &self, - buf: &mut [u8], - a: &Attributes, - ) -> Result<(rtp::packet::Packet, Attributes)> { - let (pkt, attr) = self.parent_rtp_reader.read(buf, a).await?; - - self.add(pkt.header.sequence_number); - - Ok((pkt, attr)) - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_generator_stream() -> Result<()> { - let tests: Vec = vec![ - 0, 1, 127, 128, 129, 511, 512, 513, 32767, 32768, 32769, 65407, 65408, 65409, 65534, - 65535, - ]; - for start in tests { - let mut rl = GeneratorStreamInternal::new(1); - - let all = |min: u16, max: u16| -> Vec { - let mut result = vec![]; - let mut i = min; - let max_plus_1 = max.wrapping_add(1); - while i != max_plus_1 { - result.push(i); - i = i.wrapping_add(1); - } - result - }; - - let join = |parts: &[&[u16]]| -> Vec { - let mut result = vec![]; - for p in parts { - result.extend_from_slice(p); - } - result - }; - - let add = |rl: &mut GeneratorStreamInternal, nums: &[u16]| { - for n in nums { - let seq = start.wrapping_add(*n); - rl.add(seq); - } - }; - - let assert_get = |rl: &GeneratorStreamInternal, nums: &[u16]| { - for n in nums { - let seq = start.wrapping_add(*n); - assert!(rl.get(seq), "not found: {seq}"); - } - }; - - let assert_not_get = |rl: &GeneratorStreamInternal, nums: &[u16]| { - for n in nums { - let seq = start.wrapping_add(*n); - assert!( - !rl.get(seq), - "packet found: start {}, n {}, seq {}", - start, - *n, - seq - ); - } - }; - - let assert_missing = |rl: &GeneratorStreamInternal, skip_last_n: u16, nums: &[u16]| { - let missing = rl.missing_seq_numbers(skip_last_n); - let mut want = vec![]; - for n in nums { - let seq = start.wrapping_add(*n); - want.push(seq); - } - assert_eq!(want, missing, "missing want/got, "); - }; - - let assert_last_consecutive = |rl: &GeneratorStreamInternal, last_consecutive: u16| { - let want = last_consecutive.wrapping_add(start); - assert_eq!(rl.last_consecutive, want, "invalid last_consecutive want"); - }; - - add(&mut rl, &[0]); - assert_get(&rl, &[0]); - assert_missing(&rl, 0, &[]); - assert_last_consecutive(&rl, 0); // first element added - - add(&mut rl, &all(1, 127)); - assert_get(&rl, &all(1, 127)); - assert_missing(&rl, 0, &[]); - assert_last_consecutive(&rl, 127); - - add(&mut rl, &[128]); - assert_get(&rl, &[128]); - assert_not_get(&rl, &[0]); - assert_missing(&rl, 0, &[]); - 
assert_last_consecutive(&rl, 128); - - add(&mut rl, &[130]); - assert_get(&rl, &[130]); - assert_not_get(&rl, &[1, 2, 129]); - assert_missing(&rl, 0, &[129]); - assert_last_consecutive(&rl, 128); - - add(&mut rl, &[333]); - assert_get(&rl, &[333]); - assert_not_get(&rl, &all(0, 332)); - assert_missing(&rl, 0, &all(206, 332)); // all 127 elements missing before 333 - assert_missing(&rl, 10, &all(206, 323)); // skip last 10 packets (324-333) from check - assert_last_consecutive(&rl, 205); // lastConsecutive is still out of the buffer - - add(&mut rl, &[329]); - assert_get(&rl, &[329]); - assert_missing(&rl, 0, &join(&[&all(206, 328), &all(330, 332)])); - assert_missing(&rl, 5, &join(&[&all(206, 328)])); // skip last 5 packets (329-333) from check - assert_last_consecutive(&rl, 205); - - add(&mut rl, &all(207, 320)); - assert_get(&rl, &all(207, 320)); - assert_missing(&rl, 0, &join(&[&[206], &all(321, 328), &all(330, 332)])); - assert_last_consecutive(&rl, 205); - - add(&mut rl, &[334]); - assert_get(&rl, &[334]); - assert_not_get(&rl, &[206]); - assert_missing(&rl, 0, &join(&[&all(321, 328), &all(330, 332)])); - assert_last_consecutive(&rl, 320); // head of buffer is full of consecutive packages - - add(&mut rl, &all(322, 328)); - assert_get(&rl, &all(322, 328)); - assert_missing(&rl, 0, &join(&[&[321], &all(330, 332)])); - assert_last_consecutive(&rl, 320); - - add(&mut rl, &[321]); - assert_get(&rl, &[321]); - assert_missing(&rl, 0, &all(330, 332)); - assert_last_consecutive(&rl, 329); // after adding a single missing packet, lastConsecutive should jump forward - } - - Ok(()) - } - - #[test] - fn test_generator_stream_rollover() { - let mut rl = GeneratorStreamInternal::new(1); - // Make sure it doesn't panic. - rl.add(65533); - rl.add(65535); - rl.add(65534); - - let mut rl = GeneratorStreamInternal::new(1); - // Make sure it doesn't panic. 
- rl.add(65534); - rl.add(0); - rl.add(65535); - } -} diff --git a/interceptor/src/nack/generator/generator_test.rs b/interceptor/src/nack/generator/generator_test.rs deleted file mode 100644 index b0ed1805e..000000000 --- a/interceptor/src/nack/generator/generator_test.rs +++ /dev/null @@ -1,66 +0,0 @@ -use rtcp::transport_feedbacks::transport_layer_nack::TransportLayerNack; - -use super::*; -use crate::mock::mock_stream::MockStream; -use crate::stream_info::RTCPFeedback; -use crate::test::timeout_or_fail; - -#[tokio::test] -async fn test_generator_interceptor() -> Result<()> { - const INTERVAL: Duration = Duration::from_millis(10); - let icpr: Arc = Generator::builder() - .with_log2_size_minus_6(0) - .with_skip_last_n(2) - .with_interval(INTERVAL) - .build("")?; - - let stream = MockStream::new( - &StreamInfo { - ssrc: 1, - rtcp_feedback: vec![RTCPFeedback { - typ: "nack".to_owned(), - ..Default::default() - }], - ..Default::default() - }, - icpr, - ) - .await; - - for seq_num in [10, 11, 12, 14, 16, 18] { - stream - .receive_rtp(rtp::packet::Packet { - header: rtp::header::Header { - sequence_number: seq_num, - ..Default::default() - }, - ..Default::default() - }) - .await; - - let r = timeout_or_fail(Duration::from_millis(10), stream.read_rtp()) - .await - .expect("A read packet") - .expect("Not an error"); - assert_eq!(r.header.sequence_number, seq_num); - } - - tokio::time::sleep(INTERVAL * 2).await; // wait for at least 2 nack packets - - // ignore the first nack, it might only contain the sequence id 13 as missing - let _ = stream.written_rtcp().await; - - let r = timeout_or_fail(Duration::from_millis(10), stream.written_rtcp()) - .await - .expect("Write rtcp"); - if let Some(p) = r[0].as_any().downcast_ref::() { - assert_eq!(p.nacks[0].packet_id, 13); - assert_eq!(p.nacks[0].lost_packets, 0b10); // we want packets: 13, 15 (not packet 17, because skipLastN is setReceived to 2) - } else { - panic!("single packet RTCP Compound Packet expected"); - } - - stream.close().await?; - - Ok(()) -} diff --git a/interceptor/src/nack/generator/mod.rs b/interceptor/src/nack/generator/mod.rs deleted file mode 100644 index 93d13dac4..000000000 --- a/interceptor/src/nack/generator/mod.rs +++ /dev/null @@ -1,255 +0,0 @@ -mod generator_stream; -#[cfg(test)] -mod generator_test; - -use std::collections::HashMap; -use std::sync::Arc; -use std::time::Duration; - -use async_trait::async_trait; -use generator_stream::GeneratorStream; -use rtcp::transport_feedbacks::transport_layer_nack::{ - nack_pairs_from_sequence_numbers, TransportLayerNack, -}; -use tokio::sync::{mpsc, Mutex}; -use waitgroup::WaitGroup; - -use crate::error::{Error, Result}; -use crate::nack::stream_support_nack; -use crate::stream_info::StreamInfo; -use crate::{ - Attributes, Interceptor, InterceptorBuilder, RTCPReader, RTCPWriter, RTPReader, RTPWriter, -}; - -/// GeneratorBuilder can be used to configure Generator Interceptor -#[derive(Default)] -pub struct GeneratorBuilder { - log2_size_minus_6: Option, - skip_last_n: Option, - interval: Option, -} - -impl GeneratorBuilder { - /// with_size sets the size of the interceptor. - /// Size must be one of: 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768 - pub fn with_log2_size_minus_6(mut self, log2_size_minus_6: u8) -> GeneratorBuilder { - self.log2_size_minus_6 = Some(log2_size_minus_6); - self - } - - /// with_skip_last_n sets the number of packets (n-1 packets before the last received packets) to ignore when generating - /// nack requests. 
- pub fn with_skip_last_n(mut self, skip_last_n: u16) -> GeneratorBuilder { - self.skip_last_n = Some(skip_last_n); - self - } - - /// with_interval sets the nack send interval for the interceptor - pub fn with_interval(mut self, interval: Duration) -> GeneratorBuilder { - self.interval = Some(interval); - self - } -} - -impl InterceptorBuilder for GeneratorBuilder { - fn build(&self, _id: &str) -> Result> { - let (close_tx, close_rx) = mpsc::channel(1); - Ok(Arc::new(Generator { - internal: Arc::new(GeneratorInternal { - log2_size_minus_6: if let Some(log2_size_minus_6) = self.log2_size_minus_6 { - log2_size_minus_6 - } else { - 13 - 6 // 8192 = 1 << 13 - }, - skip_last_n: self.skip_last_n.unwrap_or_default(), - interval: if let Some(interval) = self.interval { - interval - } else { - Duration::from_millis(100) - }, - - streams: Mutex::new(HashMap::new()), - close_rx: Mutex::new(Some(close_rx)), - }), - - wg: Mutex::new(Some(WaitGroup::new())), - close_tx: Mutex::new(Some(close_tx)), - })) - } -} - -struct GeneratorInternal { - log2_size_minus_6: u8, - skip_last_n: u16, - interval: Duration, - - streams: Mutex>>, - close_rx: Mutex>>, -} - -/// Generator interceptor generates nack feedback messages. -pub struct Generator { - internal: Arc, - - pub(crate) wg: Mutex>, - pub(crate) close_tx: Mutex>>, -} - -impl Generator { - /// builder returns a new GeneratorBuilder. - pub fn builder() -> GeneratorBuilder { - GeneratorBuilder::default() - } - - async fn is_closed(&self) -> bool { - let close_tx = self.close_tx.lock().await; - close_tx.is_none() - } - - async fn run( - rtcp_writer: Arc, - internal: Arc, - ) -> Result<()> { - let mut ticker = tokio::time::interval(internal.interval); - let mut close_rx = { - let mut close_rx = internal.close_rx.lock().await; - if let Some(close) = close_rx.take() { - close - } else { - return Err(Error::ErrInvalidCloseRx); - } - }; - - let sender_ssrc = rand::random::(); - loop { - tokio::select! { - _ = ticker.tick() =>{ - let nacks = { - let mut nacks = vec![]; - let streams = internal.streams.lock().await; - for (ssrc, stream) in streams.iter() { - let missing = stream.missing_seq_numbers(internal.skip_last_n); - if missing.is_empty(){ - continue; - } - - nacks.push(TransportLayerNack{ - sender_ssrc, - media_ssrc: *ssrc, - nacks: nack_pairs_from_sequence_numbers(&missing), - }); - } - nacks - }; - - let a = Attributes::new(); - for nack in nacks{ - if let Err(err) = rtcp_writer.write(&[Box::new(nack)], &a).await{ - log::warn!("failed sending nack: {}", err); - } - } - } - _ = close_rx.recv() =>{ - return Ok(()); - } - } - } - } -} - -#[async_trait] -impl Interceptor for Generator { - /// bind_rtcp_reader lets you modify any incoming RTCP packets. It is called once per sender/receiver, however this might - /// change in the future. The returned method will be called once per packet batch. - async fn bind_rtcp_reader( - &self, - reader: Arc, - ) -> Arc { - reader - } - - /// bind_rtcp_writer lets you modify any outgoing RTCP packets. It is called once per PeerConnection. The returned method - /// will be called once per packet batch. 
- async fn bind_rtcp_writer( - &self, - writer: Arc, - ) -> Arc { - if self.is_closed().await { - return writer; - } - - let mut w = { - let wait_group = self.wg.lock().await; - wait_group.as_ref().map(|wg| wg.worker()) - }; - let writer2 = Arc::clone(&writer); - let internal = Arc::clone(&self.internal); - tokio::spawn(async move { - let _d = w.take(); - if let Err(err) = Generator::run(writer2, internal).await { - log::warn!("bind_rtcp_writer NACK Generator::run got error: {}", err); - } - }); - - writer - } - - /// bind_local_stream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method - /// will be called once per rtp packet. - async fn bind_local_stream( - &self, - _info: &StreamInfo, - writer: Arc, - ) -> Arc { - writer - } - - /// unbind_local_stream is called when the Stream is removed. It can be used to clean up any data related to that track. - async fn unbind_local_stream(&self, _info: &StreamInfo) {} - - /// bind_remote_stream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method - /// will be called once per rtp packet. - async fn bind_remote_stream( - &self, - info: &StreamInfo, - reader: Arc, - ) -> Arc { - if !stream_support_nack(info) { - return reader; - } - - let stream = Arc::new(GeneratorStream::new( - self.internal.log2_size_minus_6, - reader, - )); - { - let mut streams = self.internal.streams.lock().await; - streams.insert(info.ssrc, Arc::clone(&stream)); - } - - stream - } - - /// unbind_remote_stream is called when the Stream is removed. It can be used to clean up any data related to that track. - async fn unbind_remote_stream(&self, info: &StreamInfo) { - let mut receive_logs = self.internal.streams.lock().await; - receive_logs.remove(&info.ssrc); - } - - /// close closes the Interceptor, cleaning up any data if necessary. 
- async fn close(&self) -> Result<()> { - { - let mut close_tx = self.close_tx.lock().await; - close_tx.take(); - } - - { - let mut wait_group = self.wg.lock().await; - if let Some(wg) = wait_group.take() { - wg.wait().await; - } - } - - Ok(()) - } -} diff --git a/interceptor/src/nack/mod.rs b/interceptor/src/nack/mod.rs deleted file mode 100644 index 87abe5039..000000000 --- a/interceptor/src/nack/mod.rs +++ /dev/null @@ -1,16 +0,0 @@ -use crate::stream_info::StreamInfo; - -pub mod generator; -pub mod responder; - -const UINT16SIZE_HALF: u16 = 1 << 15; - -fn stream_support_nack(info: &StreamInfo) -> bool { - for fb in &info.rtcp_feedback { - if fb.typ == "nack" && fb.parameter.is_empty() { - return true; - } - } - - false -} diff --git a/interceptor/src/nack/responder/mod.rs b/interceptor/src/nack/responder/mod.rs deleted file mode 100644 index 934a948a2..000000000 --- a/interceptor/src/nack/responder/mod.rs +++ /dev/null @@ -1,203 +0,0 @@ -mod responder_stream; -#[cfg(test)] -mod responder_test; - -use std::collections::HashMap; -use std::future::Future; -use std::pin::Pin; -use std::sync::Arc; - -use async_trait::async_trait; -use responder_stream::ResponderStream; -use rtcp::transport_feedbacks::transport_layer_nack::TransportLayerNack; -use tokio::sync::Mutex; - -use crate::error::Result; -use crate::nack::stream_support_nack; -use crate::stream_info::StreamInfo; -use crate::{ - Attributes, Interceptor, InterceptorBuilder, RTCPReader, RTCPWriter, RTPReader, RTPWriter, -}; - -/// GeneratorBuilder can be used to configure Responder Interceptor -#[derive(Default)] -pub struct ResponderBuilder { - log2_size: Option, -} - -impl ResponderBuilder { - /// with_log2_size sets the size of the interceptor. - /// Size must be one of: 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768 - pub fn with_log2_size(mut self, log2_size: u8) -> ResponderBuilder { - self.log2_size = Some(log2_size); - self - } -} - -impl InterceptorBuilder for ResponderBuilder { - fn build(&self, _id: &str) -> Result> { - Ok(Arc::new(Responder { - internal: Arc::new(ResponderInternal { - log2_size: if let Some(log2_size) = self.log2_size { - log2_size - } else { - 13 // 8192 = 1 << 13 - }, - streams: Arc::new(Mutex::new(HashMap::new())), - }), - })) - } -} - -pub struct ResponderInternal { - log2_size: u8, - streams: Arc>>>, -} - -impl ResponderInternal { - async fn resend_packets( - streams: Arc>>>, - nack: TransportLayerNack, - ) { - let stream = { - let m = streams.lock().await; - if let Some(stream) = m.get(&nack.media_ssrc) { - stream.clone() - } else { - return; - } - }; - - for n in &nack.nacks { - // can't use n.range() since this callback is async fn, - // instead, use NackPair into_iter() - let stream2 = Arc::clone(&stream); - let f = Box::new( - move |seq: u16| -> Pin + Send + 'static>> { - let stream3 = Arc::clone(&stream2); - Box::pin(async move { - if let Some(p) = stream3.get(seq).await { - let a = Attributes::new(); - if let Err(err) = stream3.next_rtp_writer.write(&p, &a).await { - log::warn!("failed resending nacked packet: {}", err); - } - } - true - }) - }, - ); - for packet_id in n.into_iter() { - if !f(packet_id).await { - return; - } - } - } - } -} - -pub struct ResponderRtcpReader { - parent_rtcp_reader: Arc, - internal: Arc, -} - -#[async_trait] -impl RTCPReader for ResponderRtcpReader { - async fn read( - &self, - buf: &mut [u8], - a: &Attributes, - ) -> Result<(Vec>, Attributes)> { - let (pkts, attr) = { self.parent_rtcp_reader.read(buf, a).await? 
}; - for p in &pkts { - if let Some(nack) = p.as_any().downcast_ref::() { - let nack = nack.clone(); - let streams = Arc::clone(&self.internal.streams); - tokio::spawn(async move { - ResponderInternal::resend_packets(streams, nack).await; - }); - } - } - - Ok((pkts, attr)) - } -} - -/// Responder responds to nack feedback messages -pub struct Responder { - internal: Arc, -} - -impl Responder { - /// builder returns a new ResponderBuilder. - pub fn builder() -> ResponderBuilder { - ResponderBuilder::default() - } -} - -#[async_trait] -impl Interceptor for Responder { - /// bind_rtcp_reader lets you modify any incoming RTCP packets. It is called once per sender/receiver, however this might - /// change in the future. The returned method will be called once per packet batch. - async fn bind_rtcp_reader( - &self, - reader: Arc, - ) -> Arc { - Arc::new(ResponderRtcpReader { - internal: Arc::clone(&self.internal), - parent_rtcp_reader: reader, - }) as Arc - } - - /// bind_rtcp_writer lets you modify any outgoing RTCP packets. It is called once per PeerConnection. The returned method - /// will be called once per packet batch. - async fn bind_rtcp_writer( - &self, - writer: Arc, - ) -> Arc { - writer - } - - /// bind_local_stream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method - /// will be called once per rtp packet. - async fn bind_local_stream( - &self, - info: &StreamInfo, - writer: Arc, - ) -> Arc { - if !stream_support_nack(info) { - return writer; - } - - let stream = Arc::new(ResponderStream::new(self.internal.log2_size, writer)); - { - let mut streams = self.internal.streams.lock().await; - streams.insert(info.ssrc, Arc::clone(&stream)); - } - - stream - } - - /// unbind_local_stream is called when the Stream is removed. It can be used to clean up any data related to that track. - async fn unbind_local_stream(&self, info: &StreamInfo) { - let mut streams = self.internal.streams.lock().await; - streams.remove(&info.ssrc); - } - - /// bind_remote_stream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method - /// will be called once per rtp packet. - async fn bind_remote_stream( - &self, - _info: &StreamInfo, - reader: Arc, - ) -> Arc { - reader - } - - /// unbind_remote_stream is called when the Stream is removed. It can be used to clean up any data related to that track. - async fn unbind_remote_stream(&self, _info: &StreamInfo) {} - - /// close closes the Interceptor, cleaning up any data if necessary. 
- async fn close(&self) -> Result<()> { - Ok(()) - } -} diff --git a/interceptor/src/nack/responder/responder_stream.rs b/interceptor/src/nack/responder/responder_stream.rs deleted file mode 100644 index ec714da6d..000000000 --- a/interceptor/src/nack/responder/responder_stream.rs +++ /dev/null @@ -1,176 +0,0 @@ -use std::sync::Arc; - -use async_trait::async_trait; -use tokio::sync::Mutex; - -use crate::error::Result; -use crate::nack::UINT16SIZE_HALF; -use crate::{Attributes, RTPWriter}; - -struct ResponderStreamInternal { - packets: Vec>, - size: u16, - last_added: u16, - started: bool, -} - -impl ResponderStreamInternal { - fn new(log2_size: u8) -> Self { - ResponderStreamInternal { - packets: vec![None; 1 << log2_size], - size: 1 << log2_size, - last_added: 0, - started: false, - } - } - - fn add(&mut self, packet: &rtp::packet::Packet) { - let seq = packet.header.sequence_number; - if !self.started { - self.packets[(seq % self.size) as usize] = Some(packet.clone()); - self.last_added = seq; - self.started = true; - return; - } - - let diff = seq.wrapping_sub(self.last_added); - if diff == 0 { - return; - } else if diff < UINT16SIZE_HALF { - let mut i = self.last_added.wrapping_add(1); - while i != seq { - self.packets[(i % self.size) as usize] = None; - i = i.wrapping_add(1); - } - } - - self.packets[(seq % self.size) as usize] = Some(packet.clone()); - self.last_added = seq; - } - - fn get(&self, seq: u16) -> Option<&rtp::packet::Packet> { - let diff = self.last_added.wrapping_sub(seq); - if diff >= UINT16SIZE_HALF { - return None; - } - - if diff >= self.size { - return None; - } - - self.packets[(seq % self.size) as usize].as_ref() - } -} - -pub(super) struct ResponderStream { - internal: Mutex, - pub(super) next_rtp_writer: Arc, -} - -impl ResponderStream { - pub(super) fn new(log2_size: u8, writer: Arc) -> Self { - ResponderStream { - internal: Mutex::new(ResponderStreamInternal::new(log2_size)), - next_rtp_writer: writer, - } - } - - async fn add(&self, pkt: &rtp::packet::Packet) { - let mut internal = self.internal.lock().await; - internal.add(pkt); - } - - pub(super) async fn get(&self, seq: u16) -> Option { - let internal = self.internal.lock().await; - internal.get(seq).cloned() - } -} - -/// RTPWriter is used by Interceptor.bind_local_stream. 
-#[async_trait] -impl RTPWriter for ResponderStream { - /// write a rtp packet - async fn write(&self, pkt: &rtp::packet::Packet, a: &Attributes) -> Result { - self.add(pkt).await; - - self.next_rtp_writer.write(pkt, a).await - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_responder_stream() -> Result<()> { - let tests: Vec = vec![ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 511, 512, 513, 32767, 32768, 32769, 65527, 65528, 65529, - 65530, 65531, 65532, 65533, 65534, 65535, - ]; - for start in tests { - let mut sb = ResponderStreamInternal::new(3); - - let add = |sb: &mut ResponderStreamInternal, nums: &[u16]| { - for n in nums { - let seq = start.wrapping_add(*n); - sb.add(&rtp::packet::Packet { - header: rtp::header::Header { - sequence_number: seq, - ..Default::default() - }, - ..Default::default() - }); - } - }; - - let assert_get = |sb: &ResponderStreamInternal, nums: &[u16]| { - for n in nums { - let seq = start.wrapping_add(*n); - if let Some(packet) = sb.get(seq) { - assert_eq!( - packet.header.sequence_number, seq, - "packet for {} returned with incorrect SequenceNumber: {}", - seq, packet.header.sequence_number - ); - } else { - panic!("packet not found: {seq}"); - } - } - }; - - let assert_not_get = |sb: &ResponderStreamInternal, nums: &[u16]| { - for n in nums { - let seq = start.wrapping_add(*n); - if let Some(packet) = sb.get(seq) { - panic!( - "packet found for {}: {}", - seq, packet.header.sequence_number - ); - } - } - }; - - add(&mut sb, &[0, 1, 2, 3, 4, 5, 6, 7]); - assert_get(&sb, &[0, 1, 2, 3, 4, 5, 6, 7]); - - add(&mut sb, &[8]); - assert_get(&sb, &[8]); - assert_not_get(&sb, &[0]); - - add(&mut sb, &[10]); - assert_get(&sb, &[10]); - assert_not_get(&sb, &[1, 2, 9]); - - add(&mut sb, &[22]); - assert_get(&sb, &[22]); - assert_not_get( - &sb, - &[ - 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, - ], - ); - } - - Ok(()) - } -} diff --git a/interceptor/src/nack/responder/responder_test.rs b/interceptor/src/nack/responder/responder_test.rs deleted file mode 100644 index e61cf0ea6..000000000 --- a/interceptor/src/nack/responder/responder_test.rs +++ /dev/null @@ -1,76 +0,0 @@ -use rtcp::transport_feedbacks::transport_layer_nack::{NackPair, TransportLayerNack}; -use tokio::time::Duration; - -use super::*; -use crate::mock::mock_stream::MockStream; -use crate::stream_info::RTCPFeedback; -use crate::test::timeout_or_fail; - -#[tokio::test] -async fn test_responder_interceptor() -> Result<()> { - let icpr: Arc = - Responder::builder().with_log2_size(3).build("")?; - - let stream = MockStream::new( - &StreamInfo { - ssrc: 1, - rtcp_feedback: vec![RTCPFeedback { - typ: "nack".to_owned(), - ..Default::default() - }], - ..Default::default() - }, - icpr, - ) - .await; - - for seq_num in [10, 11, 12, 14, 15] { - stream - .write_rtp(&rtp::packet::Packet { - header: rtp::header::Header { - sequence_number: seq_num, - ..Default::default() - }, - ..Default::default() - }) - .await?; - - let p = timeout_or_fail(Duration::from_millis(10), stream.written_rtp()) - .await - .expect("A packet"); - assert_eq!(p.header.sequence_number, seq_num); - } - - stream - .receive_rtcp(vec![Box::new(TransportLayerNack { - media_ssrc: 1, - sender_ssrc: 2, - nacks: vec![ - NackPair { - packet_id: 11, - lost_packets: 0b1011, - }, // sequence numbers: 11, 12, 13, 15 - ], - })]) - .await; - - // seq number 13 was never sent, so it can't be resent - for seq_num in [11, 12, 15] { - if let Ok(r) = tokio::time::timeout(Duration::from_millis(50), stream.written_rtp()).await { - 
if let Some(p) = r { - assert_eq!(p.header.sequence_number, seq_num); - } else { - panic!("seq_num {seq_num} is not sent due to channel closed"); - } - } else { - panic!("seq_num {seq_num} is not sent yet"); - } - } - - let result = tokio::time::timeout(Duration::from_millis(10), stream.written_rtp()).await; - assert!(result.is_err(), "no more rtp packets expected"); - - stream.close().await?; - - Ok(()) -} diff --git a/interceptor/src/noop.rs b/interceptor/src/noop.rs deleted file mode 100644 index 597c57a8f..000000000 --- a/interceptor/src/noop.rs +++ /dev/null @@ -1,80 +0,0 @@ -use super::*; -use crate::error::Result; - -/// NoOp is an Interceptor that does not modify any packets. It can embedded in other interceptors, so it's -/// possible to implement only a subset of the methods. -pub struct NoOp; - -#[async_trait] -impl Interceptor for NoOp { - /// bind_rtcp_reader lets you modify any incoming RTCP packets. It is called once per sender/receiver, however this might - /// change in the future. The returned method will be called once per packet batch. - async fn bind_rtcp_reader( - &self, - reader: Arc, - ) -> Arc { - reader - } - - /// bind_rtcp_writer lets you modify any outgoing RTCP packets. It is called once per PeerConnection. The returned method - /// will be called once per packet batch. - async fn bind_rtcp_writer( - &self, - writer: Arc, - ) -> Arc { - writer - } - - /// bind_local_stream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method - /// will be called once per rtp packet. - async fn bind_local_stream( - &self, - _info: &StreamInfo, - writer: Arc, - ) -> Arc { - writer - } - - /// unbind_local_stream is called when the Stream is removed. It can be used to clean up any data related to that track. - async fn unbind_local_stream(&self, _info: &StreamInfo) {} - - /// bind_remote_stream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method - /// will be called once per rtp packet. - async fn bind_remote_stream( - &self, - _info: &StreamInfo, - reader: Arc, - ) -> Arc { - reader - } - - /// unbind_remote_stream is called when the Stream is removed. It can be used to clean up any data related to that track. - async fn unbind_remote_stream(&self, _info: &StreamInfo) {} - - /// close closes the Interceptor, cleaning up any data if necessary. - async fn close(&self) -> Result<()> { - Ok(()) - } -} - -#[async_trait] -impl RTPReader for NoOp { - async fn read( - &self, - _buf: &mut [u8], - a: &Attributes, - ) -> Result<(rtp::packet::Packet, Attributes)> { - Ok((rtp::packet::Packet::default(), a.clone())) - } -} - -#[async_trait] -impl RTCPReader for NoOp { - async fn read( - &self, - _buf: &mut [u8], - a: &Attributes, - ) -> Result<(Vec>, Attributes)> { - Ok((vec![], a.clone())) - } -} diff --git a/interceptor/src/registry.rs b/interceptor/src/registry.rs deleted file mode 100644 index 6a3ed6c98..000000000 --- a/interceptor/src/registry.rs +++ /dev/null @@ -1,44 +0,0 @@ -use std::sync::Arc; - -use crate::chain::Chain; -use crate::error::Result; -use crate::noop::NoOp; -use crate::{Interceptor, InterceptorBuilder}; - -/// Registry is a collector for interceptors. -#[derive(Default)] -pub struct Registry { - builders: Vec>, -} - -impl Registry { - pub fn new() -> Self { - Registry { builders: vec![] } - } - - /// add adds a new InterceptorBuilder to the registry. 
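The registry removed here simply collects `InterceptorBuilder`s and, when asked to build (see `build`/`build_chain` just below), falls back to the `NoOp` interceptor if nothing was registered. A simplified, self-contained analogue of that pattern, with stand-in types rather than the crate's traits:

```rust
// Simplified analogue of the registry pattern (stand-in types, not the
// crate's): builders are collected, then `build` turns them into a chain,
// falling back to a no-op when the registry is empty.
trait Builder {
    fn build(&self, id: &str) -> String; // stands in for Arc<dyn Interceptor>
}

struct Nack;
impl Builder for Nack {
    fn build(&self, id: &str) -> String {
        format!("nack-{id}")
    }
}

#[derive(Default)]
struct MiniRegistry {
    builders: Vec<Box<dyn Builder>>,
}

impl MiniRegistry {
    fn add(&mut self, b: Box<dyn Builder>) {
        self.builders.push(b);
    }
    fn build(&self, id: &str) -> Vec<String> {
        if self.builders.is_empty() {
            return vec!["noop".to_owned()]; // NoOp fallback
        }
        self.builders.iter().map(|b| b.build(id)).collect()
    }
}

fn main() {
    let mut reg = MiniRegistry::default();
    assert_eq!(reg.build("pc-1"), vec!["noop"]); // empty registry -> NoOp
    reg.add(Box::new(Nack));
    assert_eq!(reg.build("pc-1"), vec!["nack-pc-1"]);
}
```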
- pub fn add(&mut self, builder: Box) { - self.builders.push(builder); - } - - /// build constructs a single Interceptor from an InterceptorRegistry - pub fn build(&self, id: &str) -> Result> { - if self.builders.is_empty() { - return Ok(Arc::new(NoOp {})); - } - - self.build_chain(id) - .map(|c| Arc::new(c) as Arc) - } - - /// build_chain constructs a non-type erased Chain from an Interceptor registry. - pub fn build_chain(&self, id: &str) -> Result { - if self.builders.is_empty() { - return Ok(Chain::new(vec![Arc::new(NoOp {})])); - } - - let interceptors: Result> = self.builders.iter().map(|b| b.build(id)).collect(); - - Ok(Chain::new(interceptors?)) - } -} diff --git a/interceptor/src/report/mod.rs b/interceptor/src/report/mod.rs deleted file mode 100644 index bb366a2a0..000000000 --- a/interceptor/src/report/mod.rs +++ /dev/null @@ -1,87 +0,0 @@ -use std::collections::HashMap; -use std::sync::Arc; -use std::time::{Duration, SystemTime}; - -use tokio::sync::{mpsc, Mutex}; -use waitgroup::WaitGroup; - -pub mod receiver; -pub mod sender; - -use receiver::{ReceiverReport, ReceiverReportInternal}; -use sender::{SenderReport, SenderReportInternal}; - -use crate::error::Result; -use crate::{Interceptor, InterceptorBuilder}; - -type FnTimeGen = Arc SystemTime + Sync + 'static + Send>; - -/// ReceiverBuilder can be used to configure ReceiverReport Interceptor. -#[derive(Default)] -pub struct ReportBuilder { - is_rr: bool, - interval: Option, - now: Option, -} - -impl ReportBuilder { - /// with_interval sets send interval for the interceptor. - pub fn with_interval(mut self, interval: Duration) -> ReportBuilder { - self.interval = Some(interval); - self - } - - /// with_now_fn sets an alternative for the time.Now function. - pub fn with_now_fn(mut self, now: FnTimeGen) -> ReportBuilder { - self.now = Some(now); - self - } - - fn build_rr(&self) -> ReceiverReport { - let (close_tx, close_rx) = mpsc::channel(1); - ReceiverReport { - internal: Arc::new(ReceiverReportInternal { - interval: if let Some(interval) = &self.interval { - *interval - } else { - Duration::from_secs(1) - }, - now: self.now.clone(), - streams: Mutex::new(HashMap::new()), - close_rx: Mutex::new(Some(close_rx)), - }), - - wg: Mutex::new(Some(WaitGroup::new())), - close_tx: Mutex::new(Some(close_tx)), - } - } - - fn build_sr(&self) -> SenderReport { - let (close_tx, close_rx) = mpsc::channel(1); - SenderReport { - internal: Arc::new(SenderReportInternal { - interval: if let Some(interval) = &self.interval { - *interval - } else { - Duration::from_secs(1) - }, - now: self.now.clone(), - streams: Mutex::new(HashMap::new()), - close_rx: Mutex::new(Some(close_rx)), - }), - - wg: Mutex::new(Some(WaitGroup::new())), - close_tx: Mutex::new(Some(close_tx)), - } - } -} - -impl InterceptorBuilder for ReportBuilder { - fn build(&self, _id: &str) -> Result> { - if self.is_rr { - Ok(Arc::new(self.build_rr())) - } else { - Ok(Arc::new(self.build_sr())) - } - } -} diff --git a/interceptor/src/report/receiver/mod.rs b/interceptor/src/report/receiver/mod.rs deleted file mode 100644 index ef3381949..000000000 --- a/interceptor/src/report/receiver/mod.rs +++ /dev/null @@ -1,225 +0,0 @@ -mod receiver_stream; -#[cfg(test)] -mod receiver_test; - -use std::collections::HashMap; -use std::time::{Duration, SystemTime}; - -use receiver_stream::ReceiverStream; -use tokio::sync::{mpsc, Mutex}; -use waitgroup::WaitGroup; - -use super::*; -use crate::error::Error; -use crate::*; - -pub(crate) struct ReceiverReportInternal { - pub(crate) interval: Duration, 
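`ReportBuilder` stores its knobs as `Option`s and resolves them at build time: an unset interval becomes one second, and an unset `now` falls back to `SystemTime::now` at the call sites. A small sketch of that default-resolution pattern (stand-in type, not the crate's builder):

```rust
use std::time::Duration;

// Sketch of the default-filling pattern used by the report builder above:
// optional settings are stored as Option and resolved to defaults at build
// time (the interval falls back to one second when not configured).
#[derive(Default)]
struct MiniReportBuilder {
    interval: Option<Duration>,
}

impl MiniReportBuilder {
    fn with_interval(mut self, interval: Duration) -> Self {
        self.interval = Some(interval);
        self
    }
    fn resolved_interval(&self) -> Duration {
        self.interval.unwrap_or(Duration::from_secs(1))
    }
}

fn main() {
    assert_eq!(
        MiniReportBuilder::default().resolved_interval(),
        Duration::from_secs(1)
    );
    assert_eq!(
        MiniReportBuilder::default()
            .with_interval(Duration::from_millis(50))
            .resolved_interval(),
        Duration::from_millis(50)
    );
}
```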
- pub(crate) now: Option, - pub(crate) streams: Mutex>>, - pub(crate) close_rx: Mutex>>, -} - -pub(crate) struct ReceiverReportRtcpReader { - pub(crate) internal: Arc, - pub(crate) parent_rtcp_reader: Arc, -} - -#[async_trait] -impl RTCPReader for ReceiverReportRtcpReader { - async fn read( - &self, - buf: &mut [u8], - a: &Attributes, - ) -> Result<(Vec>, Attributes)> { - let (pkts, attr) = self.parent_rtcp_reader.read(buf, a).await?; - - let now = if let Some(f) = &self.internal.now { - f() - } else { - SystemTime::now() - }; - - for p in &pkts { - if let Some(sr) = p - .as_any() - .downcast_ref::() - { - let stream = { - let m = self.internal.streams.lock().await; - m.get(&sr.ssrc).cloned() - }; - if let Some(stream) = stream { - stream.process_sender_report(now, sr); - } - } - } - - Ok((pkts, attr)) - } -} - -/// ReceiverReport interceptor generates receiver reports. -pub struct ReceiverReport { - pub(crate) internal: Arc, - - pub(crate) wg: Mutex>, - pub(crate) close_tx: Mutex>>, -} - -impl ReceiverReport { - /// builder returns a new ReportBuilder. - pub fn builder() -> ReportBuilder { - ReportBuilder { - is_rr: true, - ..Default::default() - } - } - - async fn is_closed(&self) -> bool { - let close_tx = self.close_tx.lock().await; - close_tx.is_none() - } - - async fn run( - rtcp_writer: Arc, - internal: Arc, - ) -> Result<()> { - let mut ticker = tokio::time::interval(internal.interval); - let mut close_rx = { - let mut close_rx = internal.close_rx.lock().await; - if let Some(close) = close_rx.take() { - close - } else { - return Err(Error::ErrInvalidCloseRx); - } - }; - - loop { - tokio::select! { - _ = ticker.tick() =>{ - // TODO(cancel safety): This branch isn't cancel safe - - let now = if let Some(f) = &internal.now { - f() - } else { - SystemTime::now() - }; - let streams:Vec> = { - let m = internal.streams.lock().await; - m.values().cloned().collect() - }; - for stream in streams { - let pkt = stream.generate_report(now); - - let a = Attributes::new(); - if let Err(err) = rtcp_writer.write(&[Box::new(pkt)], &a).await{ - log::warn!("failed sending: {}", err); - } - } - } - _ = close_rx.recv() =>{ - return Ok(()); - } - } - } - } -} - -#[async_trait] -impl Interceptor for ReceiverReport { - /// bind_rtcp_reader lets you modify any incoming RTCP packets. It is called once per sender/receiver, however this might - /// change in the future. The returned method will be called once per packet batch. - async fn bind_rtcp_reader( - &self, - reader: Arc, - ) -> Arc { - Arc::new(ReceiverReportRtcpReader { - internal: Arc::clone(&self.internal), - parent_rtcp_reader: reader, - }) - } - - /// bind_rtcp_writer lets you modify any outgoing RTCP packets. It is called once per PeerConnection. The returned method - /// will be called once per packet batch. - async fn bind_rtcp_writer( - &self, - writer: Arc, - ) -> Arc { - if self.is_closed().await { - return writer; - } - - let mut w = { - let wait_group = self.wg.lock().await; - wait_group.as_ref().map(|wg| wg.worker()) - }; - let writer2 = Arc::clone(&writer); - let internal = Arc::clone(&self.internal); - tokio::spawn(async move { - let _d = w.take(); - if let Err(err) = ReceiverReport::run(writer2, internal).await { - log::warn!("bind_rtcp_writer ReceiverReport::run got error: {}", err); - } - }); - - writer - } - - /// bind_local_stream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method - /// will be called once per rtp packet. 
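The `run` loop above is the usual tokio shape: a periodic `interval` ticker raced against a close channel inside `select!`. A stripped-down sketch of just that shape, assuming only tokio (no interceptor types):

```rust
use tokio::sync::mpsc;
use tokio::time::{interval, Duration};

// Minimal sketch of the report loop's shape: tick on an interval until the
// close channel fires (dropping the sender ends the loop).
async fn run(mut close_rx: mpsc::Receiver<()>) {
    let mut ticker = interval(Duration::from_millis(10));
    loop {
        tokio::select! {
            _ = ticker.tick() => {
                // In the interceptor this is where a report is generated
                // and handed to the RTCP writer.
            }
            _ = close_rx.recv() => return,
        }
    }
}

#[tokio::main]
async fn main() {
    let (close_tx, close_rx) = mpsc::channel::<()>(1);
    let handle = tokio::spawn(run(close_rx));
    tokio::time::sleep(Duration::from_millis(30)).await;
    drop(close_tx); // close_rx.recv() yields None and the loop exits
    handle.await.unwrap();
}
```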
- async fn bind_local_stream( - &self, - _info: &StreamInfo, - writer: Arc, - ) -> Arc { - writer - } - - /// UnbindLocalStream is called when the Stream is removed. It can be used to clean up any data related to that track. - async fn unbind_local_stream(&self, _info: &StreamInfo) {} - - /// bind_remote_stream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method - /// will be called once per rtp packet. - async fn bind_remote_stream( - &self, - info: &StreamInfo, - reader: Arc, - ) -> Arc { - let stream = Arc::new(ReceiverStream::new( - info.ssrc, - info.clock_rate, - reader, - self.internal.now.clone(), - )); - { - let mut streams = self.internal.streams.lock().await; - streams.insert(info.ssrc, Arc::clone(&stream)); - } - - stream - } - - /// unbind_remote_stream is called when the Stream is removed. It can be used to clean up any data related to that track. - async fn unbind_remote_stream(&self, info: &StreamInfo) { - let mut streams = self.internal.streams.lock().await; - streams.remove(&info.ssrc); - } - - /// close closes the Interceptor, cleaning up any data if necessary. - async fn close(&self) -> Result<()> { - { - let mut close_tx = self.close_tx.lock().await; - close_tx.take(); - } - - { - let mut wait_group = self.wg.lock().await; - if let Some(wg) = wait_group.take() { - wg.wait().await; - } - } - - Ok(()) - } -} diff --git a/interceptor/src/report/receiver/receiver_stream.rs b/interceptor/src/report/receiver/receiver_stream.rs deleted file mode 100644 index d170922e8..000000000 --- a/interceptor/src/report/receiver/receiver_stream.rs +++ /dev/null @@ -1,227 +0,0 @@ -use std::time::SystemTime; - -use async_trait::async_trait; -use util::sync::Mutex; - -use super::*; -use crate::{Attributes, RTPReader}; - -struct ReceiverStreamInternal { - ssrc: u32, - receiver_ssrc: u32, - clock_rate: f64, - - packets: Vec, - started: bool, - seq_num_cycles: u16, - last_seq_num: i32, - last_report_seq_num: i32, - last_rtp_time_rtp: u32, - last_rtp_time_time: SystemTime, - jitter: f64, - last_sender_report: u32, - last_sender_report_time: SystemTime, - total_lost: u32, -} - -impl ReceiverStreamInternal { - fn set_received(&mut self, seq: u16) { - let pos = (seq as usize) % self.packets.len(); - self.packets[pos / 64] |= 1 << (pos % 64); - } - - fn del_received(&mut self, seq: u16) { - let pos = (seq as usize) % self.packets.len(); - self.packets[pos / 64] &= u64::MAX ^ (1u64 << (pos % 64)); - } - - fn get_received(&self, seq: u16) -> bool { - let pos = (seq as usize) % self.packets.len(); - (self.packets[pos / 64] & (1 << (pos % 64))) != 0 - } - - fn process_rtp(&mut self, now: SystemTime, pkt: &rtp::packet::Packet) { - if !self.started { - // first frame - self.started = true; - self.set_received(pkt.header.sequence_number); - self.last_seq_num = pkt.header.sequence_number as i32; - self.last_report_seq_num = pkt.header.sequence_number as i32 - 1; - } else { - // following frames - self.set_received(pkt.header.sequence_number); - - let diff = pkt.header.sequence_number as i32 - self.last_seq_num; - if !(-0x0FFF..=0).contains(&diff) { - // overflow - if diff < -0x0FFF { - self.seq_num_cycles += 1; - } - - // set missing packets as missing - for i in self.last_seq_num + 1..pkt.header.sequence_number as i32 { - self.del_received(i as u16); - } - - self.last_seq_num = pkt.header.sequence_number as i32; - } - - // compute jitter - // https://tools.ietf.org/html/rfc3550#page-39 - let d = now - .duration_since(self.last_rtp_time_time) - 
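The receiver stream tracks which recent sequence numbers arrived in a fixed array of `u64` words, one bit per packet, so loss can be counted when a report is generated. A simplified bitmap sketch of the same idea (not byte-for-byte the deleted code's indexing):

```rust
// Simplified received-packet bitmap (one bit per recent sequence number).
// This is a sketch of the idea, not the deleted code's exact indexing.
struct Bitmap {
    words: Vec<u64>,
}

impl Bitmap {
    fn new(words: usize) -> Self {
        Bitmap { words: vec![0; words] }
    }
    fn capacity_bits(&self) -> usize {
        self.words.len() * 64
    }
    fn set(&mut self, seq: u16) {
        let pos = seq as usize % self.capacity_bits();
        self.words[pos / 64] |= 1 << (pos % 64);
    }
    fn clear(&mut self, seq: u16) {
        let pos = seq as usize % self.capacity_bits();
        self.words[pos / 64] &= !(1u64 << (pos % 64));
    }
    fn get(&self, seq: u16) -> bool {
        let pos = seq as usize % self.capacity_bits();
        (self.words[pos / 64] & (1 << (pos % 64))) != 0
    }
}

fn main() {
    let mut bm = Bitmap::new(128);
    bm.set(65_535);
    assert!(bm.get(65_535));
    bm.clear(65_535);
    assert!(!bm.get(65_535));
}
```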
.unwrap_or_else(|_| Duration::from_secs(0)) - .as_secs_f64() - * self.clock_rate - - (pkt.header.timestamp as f64 - self.last_rtp_time_rtp as f64); - self.jitter += (d.abs() - self.jitter) / 16.0; - } - - self.last_rtp_time_rtp = pkt.header.timestamp; - self.last_rtp_time_time = now; - } - - fn process_sender_report(&mut self, now: SystemTime, sr: &rtcp::sender_report::SenderReport) { - self.last_sender_report = (sr.ntp_time >> 16) as u32; - self.last_sender_report_time = now; - } - - fn generate_report(&mut self, now: SystemTime) -> rtcp::receiver_report::ReceiverReport { - let total_since_report = (self.last_seq_num - self.last_report_seq_num) as u16; - let mut total_lost_since_report = { - if self.last_seq_num == self.last_report_seq_num { - 0 - } else { - let mut ret = 0u32; - let mut i = (self.last_report_seq_num + 1) as u16; - while i != self.last_seq_num as u16 { - if !self.get_received(i) { - ret += 1; - } - i = i.wrapping_add(1); - } - ret - } - }; - - self.total_lost += total_lost_since_report; - - // allow up to 24 bits - if total_lost_since_report > 0xFFFFFF { - total_lost_since_report = 0xFFFFFF; - } - if self.total_lost > 0xFFFFFF { - self.total_lost = 0xFFFFFF - } - - let r = rtcp::receiver_report::ReceiverReport { - ssrc: self.receiver_ssrc, - reports: vec![rtcp::reception_report::ReceptionReport { - ssrc: self.ssrc, - last_sequence_number: (self.seq_num_cycles as u32) << 16 - | (self.last_seq_num as u32), - last_sender_report: self.last_sender_report, - fraction_lost: ((total_lost_since_report * 256) as f64 / total_since_report as f64) - as u8, - total_lost: self.total_lost, - delay: { - if self.last_sender_report_time == SystemTime::UNIX_EPOCH { - 0 - } else { - match now.duration_since(self.last_sender_report_time) { - Ok(d) => (d.as_secs_f64() * 65536.0) as u32, - Err(_) => 0, - } - } - }, - jitter: self.jitter as u32, - }], - ..Default::default() - }; - - self.last_report_seq_num = self.last_seq_num; - - r - } -} - -pub(crate) struct ReceiverStream { - parent_rtp_reader: Arc, - now: Option, - - internal: Mutex, -} - -impl ReceiverStream { - pub(crate) fn new( - ssrc: u32, - clock_rate: u32, - reader: Arc, - now: Option, - ) -> Self { - let receiver_ssrc = rand::random::(); - ReceiverStream { - parent_rtp_reader: reader, - now, - - internal: Mutex::new(ReceiverStreamInternal { - ssrc, - receiver_ssrc, - clock_rate: clock_rate as f64, - - packets: vec![0u64; 128], - started: false, - seq_num_cycles: 0, - last_seq_num: 0, - last_report_seq_num: 0, - last_rtp_time_rtp: 0, - last_rtp_time_time: SystemTime::UNIX_EPOCH, - jitter: 0.0, - last_sender_report: 0, - last_sender_report_time: SystemTime::UNIX_EPOCH, - total_lost: 0, - }), - } - } - - pub(crate) fn process_rtp(&self, now: SystemTime, pkt: &rtp::packet::Packet) { - let mut internal = self.internal.lock(); - internal.process_rtp(now, pkt); - } - - pub(crate) fn process_sender_report( - &self, - now: SystemTime, - sr: &rtcp::sender_report::SenderReport, - ) { - let mut internal = self.internal.lock(); - internal.process_sender_report(now, sr); - } - - pub(crate) fn generate_report(&self, now: SystemTime) -> rtcp::receiver_report::ReceiverReport { - let mut internal = self.internal.lock(); - internal.generate_report(now) - } -} - -/// RTPReader is used by Interceptor.bind_remote_stream. 
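`generate_report` packs three pieces of arithmetic worth calling out: `fraction_lost` is an 8-bit fixed-point fraction of 256, the delay since the last sender report is expressed in units of 1/65536 of a second, and the extended highest sequence number is the wrap-around cycle count shifted into the top 16 bits. A plain-std sketch of those conversions, with values matching the tests further down:

```rust
// Arithmetic used when filling a reception report (sketch, std only; helper
// names are illustrative, not the crate's API).
fn fraction_lost(lost_since_report: u32, expected_since_report: u32) -> u8 {
    if expected_since_report == 0 {
        return 0;
    }
    // 8-bit fixed point: 256 means "everything lost".
    ((lost_since_report * 256) as f64 / expected_since_report as f64) as u8
}

fn delay_since_last_sr(seconds: f64) -> u32 {
    // DLSR is expressed in units of 1/65536 of a second.
    (seconds * 65_536.0) as u32
}

fn extended_highest_seq(cycles: u16, last_seq: u16) -> u32 {
    (cycles as u32) << 16 | last_seq as u32
}

fn main() {
    // 1 packet lost out of 3 expected -> 256/3 ~= 85, as in the packet-loss test.
    assert_eq!(fraction_lost(1, 3), 85);
    // 1 second since the last SR -> 65536, as in the delay test.
    assert_eq!(delay_since_last_sr(1.0), 65_536);
    // One 16-bit wrap plus a last sequence number of 1.
    assert_eq!(extended_highest_seq(1, 1), (1u32 << 16) | 1);
}
```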
-#[async_trait] -impl RTPReader for ReceiverStream { - /// read a rtp packet - async fn read( - &self, - buf: &mut [u8], - a: &Attributes, - ) -> Result<(rtp::packet::Packet, Attributes)> { - let (pkt, attr) = self.parent_rtp_reader.read(buf, a).await?; - - let now = if let Some(f) = &self.now { - f() - } else { - SystemTime::now() - }; - self.process_rtp(now, &pkt); - - Ok((pkt, attr)) - } -} diff --git a/interceptor/src/report/receiver/receiver_test.rs b/interceptor/src/report/receiver/receiver_test.rs deleted file mode 100644 index 77fa4a929..000000000 --- a/interceptor/src/report/receiver/receiver_test.rs +++ /dev/null @@ -1,772 +0,0 @@ -//use bytes::Bytes; -use chrono::prelude::*; -use rtp::extension::abs_send_time_extension::unix2ntp; - -use super::*; -use crate::mock::mock_stream::MockStream; -use crate::mock::mock_time::MockTime; - -#[tokio::test] -async fn test_receiver_interceptor_before_any_packet() -> Result<()> { - let mt = Arc::new(MockTime::default()); - let time_gen = { - let mt = Arc::clone(&mt); - Arc::new(move || mt.now()) - }; - - let icpr: Arc = ReceiverReport::builder() - .with_interval(Duration::from_millis(50)) - .with_now_fn(time_gen) - .build("")?; - - let stream = MockStream::new( - &StreamInfo { - ssrc: 123456, - clock_rate: 90000, - ..Default::default() - }, - icpr, - ) - .await; - - let pkts = stream.written_rtcp().await.unwrap(); - assert_eq!(pkts.len(), 1); - - if let Some(rr) = pkts[0] - .as_any() - .downcast_ref::() - { - assert_eq!(rr.reports.len(), 1); - assert_eq!( - rr.reports[0], - rtcp::reception_report::ReceptionReport { - ssrc: 123456, - last_sequence_number: 0, - last_sender_report: 0, - fraction_lost: 0, - total_lost: 0, - delay: 0, - jitter: 0, - } - ) - } else { - panic!(); - } - - stream.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_receiver_interceptor_after_rtp_packets() -> Result<()> { - let mt = Arc::new(MockTime::default()); - let time_gen = { - let mt = Arc::clone(&mt); - Arc::new(move || mt.now()) - }; - - let icpr: Arc = ReceiverReport::builder() - .with_interval(Duration::from_millis(50)) - .with_now_fn(time_gen) - .build("")?; - - let stream = MockStream::new( - &StreamInfo { - ssrc: 123456, - clock_rate: 90000, - ..Default::default() - }, - icpr, - ) - .await; - - for i in 0..10u16 { - stream - .receive_rtp(rtp::packet::Packet { - header: rtp::header::Header { - sequence_number: i, - ..Default::default() - }, - ..Default::default() - }) - .await; - } - - let pkts = stream.written_rtcp().await.unwrap(); - assert_eq!(pkts.len(), 1); - if let Some(rr) = pkts[0] - .as_any() - .downcast_ref::() - { - assert_eq!(rr.reports.len(), 1); - assert_eq!( - rr.reports[0], - rtcp::reception_report::ReceptionReport { - ssrc: 123456, - last_sequence_number: 9, - last_sender_report: 0, - fraction_lost: 0, - total_lost: 0, - delay: 0, - jitter: 0, - } - ) - } else { - panic!(); - } - - stream.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_receiver_interceptor_after_rtp_and_rtcp_packets() -> Result<()> { - let rtp_time: SystemTime = Utc.with_ymd_and_hms(2009, 10, 23, 0, 0, 0).unwrap().into(); - - let mt = Arc::new(MockTime::default()); - let time_gen = { - let mt = Arc::clone(&mt); - Arc::new(move || mt.now()) - }; - - let icpr: Arc = ReceiverReport::builder() - .with_interval(Duration::from_millis(50)) - .with_now_fn(time_gen) - .build("")?; - - let stream = MockStream::new( - &StreamInfo { - ssrc: 123456, - clock_rate: 90000, - ..Default::default() - }, - icpr, - ) - .await; - - for i in 0..10u16 { - stream - 
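The receiver tests that follow never rely on the wall clock: they hand the interceptor an `Arc`'d closure over a controllable time source via `with_now_fn`. A sketch of that injection pattern, using a stand-in `MockClock` rather than the crate's `MockTime`:

```rust
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime, UNIX_EPOCH};

// Sketch of the clock-injection pattern the tests rely on: the interceptor
// stores an `Arc<dyn Fn() -> SystemTime>` and the test swaps in a
// controllable source. `MockClock` is a stand-in, not the crate's MockTime.
#[derive(Clone, Default)]
struct MockClock {
    now: Arc<Mutex<Option<SystemTime>>>,
}

impl MockClock {
    fn set_now(&self, t: SystemTime) {
        *self.now.lock().unwrap() = Some(t);
    }
    fn now(&self) -> SystemTime {
        self.now.lock().unwrap().unwrap_or(UNIX_EPOCH)
    }
}

fn main() {
    let clock = MockClock::default();
    let time_gen: Arc<dyn Fn() -> SystemTime + Send + Sync> = {
        let clock = clock.clone();
        Arc::new(move || clock.now())
    };

    clock.set_now(UNIX_EPOCH + Duration::from_secs(42));
    assert_eq!(time_gen(), UNIX_EPOCH + Duration::from_secs(42));
}
```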
.receive_rtp(rtp::packet::Packet { - header: rtp::header::Header { - sequence_number: i, - ..Default::default() - }, - ..Default::default() - }) - .await; - } - - let now: SystemTime = Utc.with_ymd_and_hms(2009, 11, 10, 23, 0, 1).unwrap().into(); - let rt = 987654321u32.wrapping_add( - (now.duration_since(rtp_time) - .unwrap_or(Duration::from_secs(0)) - .as_secs_f64() - * 90000.0) as u32, - ); - stream - .receive_rtcp(vec![Box::new(rtcp::sender_report::SenderReport { - ssrc: 123456, - ntp_time: unix2ntp(now), - rtp_time: rt, - packet_count: 10, - octet_count: 0, - ..Default::default() - })]) - .await; - - let pkts = stream.written_rtcp().await.unwrap(); - assert_eq!(pkts.len(), 1); - if let Some(rr) = pkts[0] - .as_any() - .downcast_ref::() - { - assert_eq!(rr.reports.len(), 1); - assert_eq!( - rr.reports[0], - rtcp::reception_report::ReceptionReport { - ssrc: 123456, - last_sequence_number: 9, - last_sender_report: 1861287936, - fraction_lost: 0, - total_lost: 0, - delay: rr.reports[0].delay, - jitter: 0, - } - ) - } else { - panic!(); - } - - stream.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_receiver_interceptor_overflow() -> Result<()> { - #![allow(clippy::identity_op)] - - let mt = Arc::new(MockTime::default()); - let _mt2 = Arc::clone(&mt); - let time_gen = { - let mt = Arc::clone(&mt); - Arc::new(move || mt.now()) - }; - - let icpr: Arc = ReceiverReport::builder() - .with_interval(Duration::from_millis(50)) - .with_now_fn(time_gen) - .build("")?; - - let stream = MockStream::new( - &StreamInfo { - ssrc: 123456, - clock_rate: 90000, - ..Default::default() - }, - icpr, - ) - .await; - - stream - .receive_rtp(rtp::packet::Packet { - header: rtp::header::Header { - sequence_number: 0xffff, - ..Default::default() - }, - ..Default::default() - }) - .await; - - stream - .receive_rtp(rtp::packet::Packet { - header: rtp::header::Header { - sequence_number: 0, - ..Default::default() - }, - ..Default::default() - }) - .await; - - let pkts = stream.written_rtcp().await.unwrap(); - assert_eq!(pkts.len(), 1); - if let Some(rr) = pkts[0] - .as_any() - .downcast_ref::() - { - assert_eq!(rr.reports.len(), 1); - assert_eq!( - rr.reports[0], - rtcp::reception_report::ReceptionReport { - ssrc: 123456, - last_sequence_number: { - // most significant bits: 1 << 16 - // least significant bits: 0x0000 - (1 << 16) | 0x0000 - }, - last_sender_report: 0, - fraction_lost: 0, - total_lost: 0, - delay: rr.reports[0].delay, - jitter: 0, - } - ) - } else { - panic!(); - } - - stream.close().await?; - Ok(()) -} - -#[tokio::test] -async fn test_receiver_interceptor_overflow_five_pkts() -> Result<()> { - let mt = Arc::new(MockTime::default()); - let time_gen = { - let mt = Arc::clone(&mt); - Arc::new(move || mt.now()) - }; - - let icpr: Arc = ReceiverReport::builder() - .with_interval(Duration::from_millis(50)) - .with_now_fn(time_gen) - .build("")?; - - let stream = MockStream::new( - &StreamInfo { - ssrc: 123456, - clock_rate: 90000, - ..Default::default() - }, - icpr, - ) - .await; - - stream - .receive_rtp(rtp::packet::Packet { - header: rtp::header::Header { - sequence_number: 0xfffd, - ..Default::default() - }, - ..Default::default() - }) - .await; - - stream - .receive_rtp(rtp::packet::Packet { - header: rtp::header::Header { - sequence_number: 0xfffe, - ..Default::default() - }, - ..Default::default() - }) - .await; - - stream - .receive_rtp(rtp::packet::Packet { - header: rtp::header::Header { - sequence_number: 0xffff, - ..Default::default() - }, - ..Default::default() - }) - .await; - - 
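The expected `last_sender_report` value in the test above comes from compressing the 64-bit NTP timestamp of the received SR to its middle 32 bits (the low 16 bits of the seconds part, the high 16 bits of the fraction), matching `(sr.ntp_time >> 16) as u32` in the receiver stream. A tiny sketch of that extraction:

```rust
// Sketch: the LSR field carries the middle 32 bits of the 64-bit NTP
// timestamp (low 16 bits of the seconds, high 16 bits of the fraction).
fn middle_32_bits(ntp_time: u64) -> u32 {
    (ntp_time >> 16) as u32
}

fn main() {
    // 64-bit NTP layout: seconds in the high word, fraction in the low word.
    let ntp = (0x1234_5678u64 << 32) | 0x9abc_def0u64;
    assert_eq!(middle_32_bits(ntp), 0x5678_9abcu32);
}
```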
stream - .receive_rtp(rtp::packet::Packet { - header: rtp::header::Header { - sequence_number: 0, - ..Default::default() - }, - ..Default::default() - }) - .await; - - stream - .receive_rtp(rtp::packet::Packet { - header: rtp::header::Header { - sequence_number: 1, - ..Default::default() - }, - ..Default::default() - }) - .await; - - let pkts = stream.written_rtcp().await.unwrap(); - assert_eq!(pkts.len(), 1); - if let Some(rr) = pkts[0] - .as_any() - .downcast_ref::() - { - assert_eq!(rr.reports.len(), 1); - assert_eq!( - rr.reports[0], - rtcp::reception_report::ReceptionReport { - ssrc: 123456, - last_sequence_number: (1 << 16) | 0x0001, - last_sender_report: 0, - fraction_lost: 0, - total_lost: 0, - delay: rr.reports[0].delay, - jitter: 0, - } - ) - } else { - panic!(); - } - - stream.close().await?; - Ok(()) -} - -#[tokio::test] -async fn test_receiver_interceptor_packet_loss() -> Result<()> { - let rtp_time: SystemTime = Utc.with_ymd_and_hms(2009, 11, 10, 23, 0, 0).unwrap().into(); - - let mt = Arc::new(MockTime::default()); - let time_gen = { - let mt = Arc::clone(&mt); - Arc::new(move || mt.now()) - }; - - let icpr: Arc = ReceiverReport::builder() - .with_interval(Duration::from_millis(50)) - .with_now_fn(time_gen) - .build("")?; - - let stream = MockStream::new( - &StreamInfo { - ssrc: 123456, - clock_rate: 90000, - ..Default::default() - }, - icpr, - ) - .await; - - stream - .receive_rtp(rtp::packet::Packet { - header: rtp::header::Header { - sequence_number: 0x01, - ..Default::default() - }, - ..Default::default() - }) - .await; - - stream - .receive_rtp(rtp::packet::Packet { - header: rtp::header::Header { - sequence_number: 0x03, - ..Default::default() - }, - ..Default::default() - }) - .await; - - let pkts = stream.written_rtcp().await.unwrap(); - assert_eq!(pkts.len(), 1); - if let Some(rr) = pkts[0] - .as_any() - .downcast_ref::() - { - assert_eq!(rr.reports.len(), 1); - assert_eq!( - rr.reports[0], - rtcp::reception_report::ReceptionReport { - ssrc: 123456, - last_sequence_number: 0x03, - last_sender_report: 0, - fraction_lost: ((1u16 << 8) / 3) as u8, - total_lost: 1, - delay: 0, - jitter: 0, - } - ) - } else { - panic!(); - } - - let now: SystemTime = Utc.with_ymd_and_hms(2009, 11, 10, 23, 0, 1).unwrap().into(); - let rt = 987654321u32.wrapping_add( - (now.duration_since(rtp_time) - .unwrap_or(Duration::from_secs(0)) - .as_secs_f64() - * 90000.0) as u32, - ); - stream - .receive_rtcp(vec![Box::new(rtcp::sender_report::SenderReport { - ssrc: 123456, - ntp_time: unix2ntp(now), - rtp_time: rt, - packet_count: 10, - octet_count: 0, - ..Default::default() - })]) - .await; - - let pkts = stream.written_rtcp().await.unwrap(); - assert_eq!(pkts.len(), 1); - if let Some(rr) = pkts[0] - .as_any() - .downcast_ref::() - { - assert_eq!(rr.reports.len(), 1); - assert_eq!( - rr.reports[0], - rtcp::reception_report::ReceptionReport { - ssrc: 123456, - last_sequence_number: 0x03, - last_sender_report: 1861287936, - fraction_lost: 0, - total_lost: 1, - delay: rr.reports[0].delay, - jitter: 0, - } - ) - } else { - panic!(); - } - - stream.close().await?; - Ok(()) -} - -#[tokio::test] -async fn test_receiver_interceptor_overflow_and_packet_loss() -> Result<()> { - let mt = Arc::new(MockTime::default()); - let time_gen = { - let mt = Arc::clone(&mt); - Arc::new(move || mt.now()) - }; - - let icpr: Arc = ReceiverReport::builder() - .with_interval(Duration::from_millis(50)) - .with_now_fn(time_gen) - .build("")?; - - let stream = MockStream::new( - &StreamInfo { - ssrc: 123456, - clock_rate: 
90000, - ..Default::default() - }, - icpr, - ) - .await; - - stream - .receive_rtp(rtp::packet::Packet { - header: rtp::header::Header { - sequence_number: 0xffff, - ..Default::default() - }, - ..Default::default() - }) - .await; - - stream - .receive_rtp(rtp::packet::Packet { - header: rtp::header::Header { - sequence_number: 0x01, - ..Default::default() - }, - ..Default::default() - }) - .await; - - let pkts = stream.written_rtcp().await.unwrap(); - assert_eq!(pkts.len(), 1); - if let Some(rr) = pkts[0] - .as_any() - .downcast_ref::() - { - assert_eq!(rr.reports.len(), 1); - assert_eq!( - rr.reports[0], - rtcp::reception_report::ReceptionReport { - ssrc: 123456, - last_sequence_number: 1 << 16 | 0x01, - last_sender_report: 0, - fraction_lost: ((1u16 << 8) / 3) as u8, - total_lost: 1, - delay: 0, - jitter: 0, - } - ) - } else { - panic!(); - } - - stream.close().await?; - Ok(()) -} - -#[tokio::test] -async fn test_receiver_interceptor_reordered_packets() -> Result<()> { - let mt = Arc::new(MockTime::default()); - let time_gen = { - let mt = Arc::clone(&mt); - Arc::new(move || mt.now()) - }; - - let icpr: Arc = ReceiverReport::builder() - .with_interval(Duration::from_millis(50)) - .with_now_fn(time_gen) - .build("")?; - - let stream = MockStream::new( - &StreamInfo { - ssrc: 123456, - clock_rate: 90000, - ..Default::default() - }, - icpr, - ) - .await; - - for sequence_number in [0x01, 0x03, 0x02, 0x04] { - stream - .receive_rtp(rtp::packet::Packet { - header: rtp::header::Header { - sequence_number, - ..Default::default() - }, - ..Default::default() - }) - .await; - } - - let pkts = stream.written_rtcp().await.unwrap(); - assert_eq!(pkts.len(), 1); - if let Some(rr) = pkts[0] - .as_any() - .downcast_ref::() - { - assert_eq!(rr.reports.len(), 1); - assert_eq!( - rr.reports[0], - rtcp::reception_report::ReceptionReport { - ssrc: 123456, - last_sequence_number: 0x04, - last_sender_report: 0, - fraction_lost: 0, - total_lost: 0, - delay: 0, - jitter: 0, - } - ) - } else { - panic!(); - } - - stream.close().await?; - Ok(()) -} - -#[tokio::test(start_paused = true)] -async fn test_receiver_interceptor_jitter() -> Result<()> { - let mt = Arc::new(MockTime::default()); - let time_gen = { - let mt = Arc::clone(&mt); - Arc::new(move || mt.now()) - }; - - let icpr: Arc = ReceiverReport::builder() - .with_interval(Duration::from_millis(50)) - .with_now_fn(time_gen) - .build("")?; - - let stream = MockStream::new( - &StreamInfo { - ssrc: 123456, - clock_rate: 90000, - ..Default::default() - }, - icpr, - ) - .await; - - mt.set_now(Utc.with_ymd_and_hms(2009, 11, 10, 23, 0, 0).unwrap().into()); - stream - .receive_rtp(rtp::packet::Packet { - header: rtp::header::Header { - sequence_number: 0x01, - timestamp: 42378934, - ..Default::default() - }, - ..Default::default() - }) - .await; - stream.read_rtp().await; - - mt.set_now(Utc.with_ymd_and_hms(2009, 11, 10, 23, 0, 1).unwrap().into()); - stream - .receive_rtp(rtp::packet::Packet { - header: rtp::header::Header { - sequence_number: 0x02, - timestamp: 42378934 + 60000, - ..Default::default() - }, - ..Default::default() - }) - .await; - - // Advance the time to generate a report - tokio::time::advance(Duration::from_millis(60)).await; - // Yield to let the reporting task run - tokio::task::yield_now().await; - - let pkts = stream.last_written_rtcp().await.unwrap(); - assert_eq!(pkts.len(), 1); - if let Some(rr) = pkts[0] - .as_any() - .downcast_ref::() - { - assert_eq!(rr.reports.len(), 1); - assert_eq!( - rr.reports[0], - 
rtcp::reception_report::ReceptionReport { - ssrc: 123456, - last_sequence_number: 0x02, - last_sender_report: 0, - fraction_lost: 0, - total_lost: 0, - delay: 0, - jitter: 30000 / 16, - } - ) - } else { - panic!(); - } - - stream.close().await?; - Ok(()) -} - -#[tokio::test] -async fn test_receiver_interceptor_delay() -> Result<()> { - let mt = Arc::new(MockTime::default()); - let time_gen = { - let mt = Arc::clone(&mt); - Arc::new(move || mt.now()) - }; - - let icpr: Arc = ReceiverReport::builder() - .with_interval(Duration::from_millis(50)) - .with_now_fn(time_gen) - .build("")?; - - let stream = MockStream::new( - &StreamInfo { - ssrc: 123456, - clock_rate: 90000, - ..Default::default() - }, - icpr, - ) - .await; - - mt.set_now(Utc.with_ymd_and_hms(2009, 11, 10, 23, 0, 0).unwrap().into()); - stream - .receive_rtcp(vec![Box::new(rtcp::sender_report::SenderReport { - ssrc: 123456, - ntp_time: unix2ntp(Utc.with_ymd_and_hms(2009, 11, 10, 23, 0, 0).unwrap().into()), - rtp_time: 987654321, - packet_count: 0, - octet_count: 0, - ..Default::default() - })]) - .await; - stream.read_rtcp().await; - - mt.set_now(Utc.with_ymd_and_hms(2009, 11, 10, 23, 0, 1).unwrap().into()); - - let pkts = stream.written_rtcp().await.unwrap(); - assert_eq!(pkts.len(), 1); - if let Some(rr) = pkts[0] - .as_any() - .downcast_ref::() - { - assert_eq!(rr.reports.len(), 1); - assert_eq!( - rr.reports[0], - rtcp::reception_report::ReceptionReport { - ssrc: 123456, - last_sequence_number: 0, - last_sender_report: 1861222400, - fraction_lost: 0, - total_lost: 0, - delay: 65536, - jitter: 0, - } - ) - } else { - panic!(); - } - - stream.close().await?; - Ok(()) -} diff --git a/interceptor/src/report/sender/mod.rs b/interceptor/src/report/sender/mod.rs deleted file mode 100644 index 83d1aa38a..000000000 --- a/interceptor/src/report/sender/mod.rs +++ /dev/null @@ -1,182 +0,0 @@ -mod sender_stream; -#[cfg(test)] -mod sender_test; - -use std::collections::HashMap; -use std::time::{Duration, SystemTime}; - -use sender_stream::SenderStream; -use tokio::sync::{mpsc, Mutex}; -use waitgroup::WaitGroup; - -use super::*; -use crate::error::Error; -use crate::*; - -pub(crate) struct SenderReportInternal { - pub(crate) interval: Duration, - pub(crate) now: Option, - pub(crate) streams: Mutex>>, - pub(crate) close_rx: Mutex>>, -} - -/// SenderReport interceptor generates sender reports. -pub struct SenderReport { - pub(crate) internal: Arc, - - pub(crate) wg: Mutex>, - pub(crate) close_tx: Mutex>>, -} - -impl SenderReport { - /// builder returns a new ReportBuilder. - pub fn builder() -> ReportBuilder { - ReportBuilder { - is_rr: false, - ..Default::default() - } - } - - async fn is_closed(&self) -> bool { - let close_tx = self.close_tx.lock().await; - close_tx.is_none() - } - - async fn run( - rtcp_writer: Arc, - internal: Arc, - ) -> Result<()> { - let mut ticker = tokio::time::interval(internal.interval); - let mut close_rx = { - let mut close_rx = internal.close_rx.lock().await; - if let Some(close) = close_rx.take() { - close - } else { - return Err(Error::ErrInvalidCloseRx); - } - }; - - loop { - tokio::select! 
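The jitter figure asserted in the test above follows the RFC 3550 interarrival-jitter update: compare the wall-clock gap between two packets (converted to RTP clock units) with the gap in their RTP timestamps, then move the running estimate 1/16 of the way toward the absolute difference. A worked sketch reproducing the test's `30000 / 16`:

```rust
// Worked sketch of the RFC 3550 interarrival jitter update:
// D = (arrival delta in RTP units) - (RTP timestamp delta), J += (|D| - J)/16.
fn update_jitter(jitter: f64, elapsed_secs: f64, clock_rate: f64, rtp_delta: f64) -> f64 {
    let d = elapsed_secs * clock_rate - rtp_delta;
    jitter + (d.abs() - jitter) / 16.0
}

fn main() {
    // The jitter test: 1 s of wall clock at 90 kHz (90_000 ticks) versus an
    // RTP timestamp advance of 60_000 ticks -> |D| = 30_000, first update = 30_000 / 16.
    let j = update_jitter(0.0, 1.0, 90_000.0, 60_000.0);
    assert_eq!(j as u32, 30_000 / 16);
}
```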
{ - _ = ticker.tick() =>{ - // TODO(cancel safety): This branch isn't cancel safe - let now = if let Some(f) = &internal.now { - f() - } else { - SystemTime::now() - }; - let streams:Vec> = { - let m = internal.streams.lock().await; - m.values().cloned().collect() - }; - for stream in streams { - let pkt = stream.generate_report(now).await; - - let a = Attributes::new(); - if let Err(err) = rtcp_writer.write(&[Box::new(pkt)], &a).await{ - log::warn!("failed sending: {}", err); - } - } - } - _ = close_rx.recv() =>{ - return Ok(()); - } - } - } - } -} - -#[async_trait] -impl Interceptor for SenderReport { - /// bind_rtcp_reader lets you modify any incoming RTCP packets. It is called once per sender/receiver, however this might - /// change in the future. The returned method will be called once per packet batch. - async fn bind_rtcp_reader( - &self, - reader: Arc, - ) -> Arc { - reader - } - - /// bind_rtcp_writer lets you modify any outgoing RTCP packets. It is called once per PeerConnection. The returned method - /// will be called once per packet batch. - async fn bind_rtcp_writer( - &self, - writer: Arc, - ) -> Arc { - if self.is_closed().await { - return writer; - } - - let mut w = { - let wait_group = self.wg.lock().await; - wait_group.as_ref().map(|wg| wg.worker()) - }; - let writer2 = Arc::clone(&writer); - let internal = Arc::clone(&self.internal); - tokio::spawn(async move { - let _d = w.take(); - if let Err(err) = SenderReport::run(writer2, internal).await { - log::warn!("bind_rtcp_writer Generator::run got error: {}", err); - } - }); - - writer - } - - /// bind_local_stream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method - /// will be called once per rtp packet. - async fn bind_local_stream( - &self, - info: &StreamInfo, - writer: Arc, - ) -> Arc { - let stream = Arc::new(SenderStream::new( - info.ssrc, - info.clock_rate, - writer, - self.internal.now.clone(), - )); - { - let mut streams = self.internal.streams.lock().await; - streams.insert(info.ssrc, Arc::clone(&stream)); - } - - stream - } - - /// unbind_local_stream is called when the Stream is removed. It can be used to clean up any data related to that track. - async fn unbind_local_stream(&self, info: &StreamInfo) { - let mut streams = self.internal.streams.lock().await; - streams.remove(&info.ssrc); - } - - /// bind_remote_stream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method - /// will be called once per rtp packet. - async fn bind_remote_stream( - &self, - _info: &StreamInfo, - reader: Arc, - ) -> Arc { - reader - } - - /// unbind_remote_stream is called when the Stream is removed. It can be used to clean up any data related to that track. - async fn unbind_remote_stream(&self, _info: &StreamInfo) {} - - /// close closes the Interceptor, cleaning up any data if necessary. 
- async fn close(&self) -> Result<()> { - { - let mut close_tx = self.close_tx.lock().await; - close_tx.take(); - } - - { - let mut wait_group = self.wg.lock().await; - if let Some(wg) = wait_group.take() { - wg.wait().await; - } - } - - Ok(()) - } -} diff --git a/interceptor/src/report/sender/sender_stream.rs b/interceptor/src/report/sender/sender_stream.rs deleted file mode 100644 index df0602afa..000000000 --- a/interceptor/src/report/sender/sender_stream.rs +++ /dev/null @@ -1,142 +0,0 @@ -use std::convert::TryInto; -use std::sync::Arc; -use std::time::{Duration, SystemTime}; - -use async_trait::async_trait; -use rtp::extension::abs_send_time_extension::unix2ntp; -use tokio::sync::Mutex; - -use super::*; -use crate::{Attributes, RTPWriter}; - -struct SenderStreamInternal { - ssrc: u32, - clock_rate: f64, - - /// data from rtp packets - last_rtp_time_rtp: u32, - last_rtp_time_time: SystemTime, - counters: Counters, -} - -impl SenderStreamInternal { - fn process_rtp(&mut self, now: SystemTime, pkt: &rtp::packet::Packet) { - // always update time to minimize errors - self.last_rtp_time_rtp = pkt.header.timestamp; - self.last_rtp_time_time = now; - - self.counters.increment_packets(); - self.counters.count_octets(pkt.payload.len()); - } - - fn generate_report(&mut self, now: SystemTime) -> rtcp::sender_report::SenderReport { - rtcp::sender_report::SenderReport { - ssrc: self.ssrc, - ntp_time: unix2ntp(now), - rtp_time: self.last_rtp_time_rtp.wrapping_add( - (now.duration_since(self.last_rtp_time_time) - .unwrap_or_else(|_| Duration::from_secs(0)) - .as_secs_f64() - * self.clock_rate) as u32, - ), - packet_count: self.counters.packet_count(), - octet_count: self.counters.octet_count(), - ..Default::default() - } - } -} - -pub(crate) struct SenderStream { - next_rtp_writer: Arc, - now: Option, - - internal: Mutex, -} - -impl SenderStream { - pub(crate) fn new( - ssrc: u32, - clock_rate: u32, - writer: Arc, - now: Option, - ) -> Self { - SenderStream { - next_rtp_writer: writer, - now, - - internal: Mutex::new(SenderStreamInternal { - ssrc, - clock_rate: clock_rate as f64, - last_rtp_time_rtp: 0, - last_rtp_time_time: SystemTime::UNIX_EPOCH, - counters: Default::default(), - }), - } - } - - async fn process_rtp(&self, now: SystemTime, pkt: &rtp::packet::Packet) { - let mut internal = self.internal.lock().await; - internal.process_rtp(now, pkt); - } - - pub(crate) async fn generate_report( - &self, - now: SystemTime, - ) -> rtcp::sender_report::SenderReport { - let mut internal = self.internal.lock().await; - internal.generate_report(now) - } -} - -/// RTPWriter is used by Interceptor.bind_local_stream. 
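The sender report's RTP timestamp is not the last packet's timestamp verbatim: `generate_report` extrapolates it by the wall-clock time elapsed since that packet, converted to RTP clock units and added with wrap-around. A small sketch of that extrapolation:

```rust
use std::time::{Duration, SystemTime, UNIX_EPOCH};

// Sketch of how the sender report's RTP timestamp is extrapolated: take the
// timestamp of the last RTP packet and advance it by the wall-clock time that
// has passed since, converted to RTP clock units (wrapping on overflow).
fn extrapolate_rtp_time(
    last_rtp_time: u32,
    last_wall_time: SystemTime,
    now: SystemTime,
    clock_rate: f64,
) -> u32 {
    let elapsed = now
        .duration_since(last_wall_time)
        .unwrap_or(Duration::ZERO)
        .as_secs_f64();
    last_rtp_time.wrapping_add((elapsed * clock_rate) as u32)
}

fn main() {
    let t0 = UNIX_EPOCH + Duration::from_secs(100);
    let t1 = t0 + Duration::from_secs(2);
    // 2 seconds at a 90 kHz video clock advances the RTP timestamp by 180_000.
    assert_eq!(extrapolate_rtp_time(1_000, t0, t1, 90_000.0), 181_000);
}
```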
-#[async_trait] -impl RTPWriter for SenderStream { - /// write a rtp packet - async fn write(&self, pkt: &rtp::packet::Packet, a: &Attributes) -> Result { - let now = if let Some(f) = &self.now { - f() - } else { - SystemTime::now() - }; - self.process_rtp(now, pkt).await; - - self.next_rtp_writer.write(pkt, a).await - } -} - -#[derive(Default)] -pub(crate) struct Counters { - packets: u32, - octets: u32, -} - -/// Wrapping counters used for generating [`rtcp::sender_report::SenderReport`] -impl Counters { - pub fn increment_packets(&mut self) { - self.packets = self.packets.wrapping_add(1); - } - - pub fn count_octets(&mut self, octets: usize) { - // account for a payload size of at most `u32::MAX` - // and log a message if larger - self.octets = self - .octets - .wrapping_add(octets.try_into().unwrap_or_else(|_| { - log::warn!("packet payload larger than 32 bits"); - u32::MAX - })); - } - - pub fn packet_count(&self) -> u32 { - self.packets - } - - pub fn octet_count(&self) -> u32 { - self.octets - } - - #[cfg(test)] - pub fn mock(packets: u32, octets: u32) -> Self { - Self { packets, octets } - } -} diff --git a/interceptor/src/report/sender/sender_test.rs b/interceptor/src/report/sender/sender_test.rs deleted file mode 100644 index fdd2bbd5c..000000000 --- a/interceptor/src/report/sender/sender_test.rs +++ /dev/null @@ -1,260 +0,0 @@ -use bytes::Bytes; -use chrono::prelude::*; -use rtp::extension::abs_send_time_extension::unix2ntp; - -use super::*; -use crate::mock::mock_stream::MockStream; -use crate::mock::mock_time::MockTime; - -#[tokio::test] -async fn test_sender_interceptor_before_any_packet() -> Result<()> { - let mt = Arc::new(MockTime::default()); - let time_gen = { - let mt = Arc::clone(&mt); - Arc::new(move || mt.now()) - }; - - let icpr: Arc = SenderReport::builder() - .with_interval(Duration::from_millis(50)) - .with_now_fn(time_gen) - .build("")?; - - let stream = MockStream::new( - &StreamInfo { - ssrc: 123456, - clock_rate: 90000, - ..Default::default() - }, - icpr, - ) - .await; - - let dt = Utc.with_ymd_and_hms(2009, 10, 23, 0, 0, 0).unwrap(); - mt.set_now(dt.into()); - - let pkts = stream.written_rtcp().await.unwrap(); - assert_eq!(pkts.len(), 1); - if let Some(sr) = pkts[0] - .as_any() - .downcast_ref::() - { - assert_eq!( - sr, - &rtcp::sender_report::SenderReport { - ssrc: 123456, - ntp_time: unix2ntp(mt.now()), - rtp_time: 4294967295, // pion: 2269117121, - packet_count: 0, - octet_count: 0, - ..Default::default() - } - ) - } else { - panic!(); - } - - stream.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_sender_interceptor_after_rtp_packets() -> Result<()> { - let mt = Arc::new(MockTime::default()); - let time_gen = { - let mt = Arc::clone(&mt); - Arc::new(move || mt.now()) - }; - - let icpr: Arc = SenderReport::builder() - .with_interval(Duration::from_millis(50)) - .with_now_fn(time_gen) - .build("")?; - - let stream = MockStream::new( - &StreamInfo { - ssrc: 123456, - clock_rate: 90000, - ..Default::default() - }, - icpr, - ) - .await; - - for i in 0..10u16 { - stream - .write_rtp(&rtp::packet::Packet { - header: rtp::header::Header { - sequence_number: i, - ..Default::default() - }, - payload: Bytes::from_static(b"\x00\x00"), - }) - .await?; - } - - let dt = Utc.with_ymd_and_hms(2009, 10, 23, 0, 0, 0).unwrap(); - mt.set_now(dt.into()); - - let pkts = stream.written_rtcp().await.unwrap(); - assert_eq!(pkts.len(), 1); - if let Some(sr) = pkts[0] - .as_any() - .downcast_ref::() - { - assert_eq!( - sr, - &rtcp::sender_report::SenderReport { - 
ssrc: 123456, - ntp_time: unix2ntp(mt.now()), - rtp_time: 4294967295, // pion: 2269117121, - packet_count: 10, - octet_count: 20, - ..Default::default() - } - ) - } else { - panic!(); - } - - stream.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_sender_interceptor_after_rtp_packets_overflow() -> Result<()> { - let mt = Arc::new(MockTime::default()); - let time_gen = { - let mt = Arc::clone(&mt); - Arc::new(move || mt.now()) - }; - - let icpr: Arc = SenderReport::builder() - .with_interval(Duration::from_millis(50)) - .with_now_fn(time_gen) - .build("")?; - - let stream = MockStream::new( - &StreamInfo { - ssrc: 123456, - clock_rate: 90000, - ..Default::default() - }, - icpr, - ) - .await; - - stream - .write_rtp(&rtp::packet::Packet { - header: rtp::header::Header { - sequence_number: 0xfffd, - ..Default::default() - }, - payload: Bytes::from_static(b"\x00\x00"), - }) - .await?; - - stream - .write_rtp(&rtp::packet::Packet { - header: rtp::header::Header { - sequence_number: 0xfffe, - ..Default::default() - }, - payload: Bytes::from_static(b"\x00\x00"), - }) - .await?; - - stream - .write_rtp(&rtp::packet::Packet { - header: rtp::header::Header { - sequence_number: 0xffff, - ..Default::default() - }, - payload: Bytes::from_static(b"\x00\x00"), - }) - .await?; - - stream - .write_rtp(&rtp::packet::Packet { - header: rtp::header::Header { - sequence_number: 0, - ..Default::default() - }, - payload: Bytes::from_static(b"\x00\x00"), - }) - .await?; - - stream - .write_rtp(&rtp::packet::Packet { - header: rtp::header::Header { - sequence_number: 1, - ..Default::default() - }, - payload: Bytes::from_static(b"\x00\x00"), - }) - .await?; - - let dt = Utc.with_ymd_and_hms(2009, 10, 23, 0, 0, 0).unwrap(); - mt.set_now(dt.into()); - - let pkts = stream.written_rtcp().await.unwrap(); - assert_eq!(pkts.len(), 1); - if let Some(sr) = pkts[0] - .as_any() - .downcast_ref::() - { - assert_eq!( - sr, - &rtcp::sender_report::SenderReport { - ssrc: 123456, - ntp_time: unix2ntp(mt.now()), - rtp_time: 4294967295, // pion: 2269117121, - packet_count: 5, - octet_count: 10, - ..Default::default() - } - ) - } else { - panic!(); - } - - stream.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_stream_counters_initially_zero() -> Result<()> { - let counters = sender_stream::Counters::default(); - assert_eq!(counters.octet_count(), 0); - assert_eq!(counters.packet_count(), 0); - Ok(()) -} - -#[tokio::test] -async fn test_stream_packet_counter_wraps_on_overflow() -> Result<()> { - let mut counters = sender_stream::Counters::mock(u32::MAX, 0); - for _ in 0..3 { - counters.increment_packets(); - } - assert_eq!(counters.packet_count(), 2); - Ok(()) -} - -#[tokio::test] -async fn test_stream_octet_counter_wraps_on_overflow() -> Result<()> { - let mut counters = sender_stream::Counters::default(); - counters.count_octets(u32::MAX as usize); - counters.count_octets(3); - assert_eq!(counters.octet_count(), 2); - Ok(()) -} - -#[tokio::test] -async fn test_stream_octet_counter_saturates_u32_from_usize() -> Result<()> { - let mut counters = sender_stream::Counters::default(); - counters.count_octets(0xabcdef01234567_usize); - assert_eq!(counters.octet_count(), 0xffffffff_u32); - Ok(()) -} diff --git a/interceptor/src/stats/interceptor.rs b/interceptor/src/stats/interceptor.rs deleted file mode 100644 index 272336e97..000000000 --- a/interceptor/src/stats/interceptor.rs +++ /dev/null @@ -1,1194 +0,0 @@ -use std::collections::HashMap; -use std::fmt; -use std::sync::Arc; -use std::time::SystemTime; - 
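The counter tests above pin down two behaviours of `Counters`: 32-bit totals wrap on overflow (the RTCP fields themselves are 32-bit), and a payload size that does not fit in `u32` is clamped before being added. A compact sketch of both, using a hypothetical `add_octets` helper:

```rust
// Sketch of the wrapping/saturating behaviour the counter tests assert:
// 32-bit RTCP counters wrap on overflow, and payload sizes that do not fit
// in u32 are clamped before being added. `add_octets` is illustrative only.
fn add_octets(counter: u32, octets: usize) -> u32 {
    counter.wrapping_add(u32::try_from(octets).unwrap_or(u32::MAX))
}

fn main() {
    // u32::MAX + 3 wraps to 2 ...
    assert_eq!(add_octets(u32::MAX, 3), 2);
    // ... and a payload size larger than u32::MAX is clamped to u32::MAX first.
    assert_eq!(add_octets(0, usize::MAX), u32::MAX);
}
```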
-use async_trait::async_trait; -use rtcp::extended_report::{DLRRReportBlock, ExtendedReport}; -use rtcp::payload_feedbacks::full_intra_request::FullIntraRequest; -use rtcp::payload_feedbacks::picture_loss_indication::PictureLossIndication; -use rtcp::receiver_report::ReceiverReport; -use rtcp::sender_report::SenderReport; -use rtcp::transport_feedbacks::transport_layer_nack::TransportLayerNack; -use rtp::extension::abs_send_time_extension::unix2ntp; -use tokio::sync::{mpsc, oneshot}; -use tokio::time::Duration; -use util::sync::Mutex; -use util::MarshalSize; - -use super::{inbound, outbound, StatsContainer}; -use crate::error::Result; -use crate::stream_info::StreamInfo; -use crate::{Attributes, Interceptor, RTCPReader, RTCPWriter, RTPReader, RTPWriter}; - -#[derive(Debug)] -enum Message { - StatUpdate { - ssrc: u32, - update: StatsUpdate, - }, - RequestInboundSnapshot { - ssrcs: Vec, - chan: oneshot::Sender>>, - }, - RequestOutboundSnapshot { - ssrcs: Vec, - chan: oneshot::Sender>>, - }, -} - -#[derive(Debug)] -enum StatsUpdate { - /// Stats collected on the receiving end(inbound) of an RTP stream. - InboundRTP { - packets: u64, - header_bytes: u64, - payload_bytes: u64, - last_packet_timestamp: SystemTime, - }, - /// Stats collected on the sending end(outbound) of an RTP stream. - OutboundRTP { - packets: u64, - header_bytes: u64, - payload_bytes: u64, - last_packet_timestamp: SystemTime, - }, - /// Stats collected from received RTCP packets. - InboundRTCP { - fir_count: Option, - pli_count: Option, - nack_count: Option, - }, - /// Stats collected from sent RTCP packets. - OutboundRTCP { - fir_count: Option, - pli_count: Option, - nack_count: Option, - }, - /// An extended sequence number sent in an SR. - OutboundSRExtSeqNum { seq_num: u32 }, - /// Stats collected from received Receiver Reports i.e. where we have an outbound RTP stream. - InboundReceiverReport { - ext_seq_num: u32, - total_lost: u32, - jitter: u32, - rtt_ms: Option, - fraction_lost: u8, - }, - /// Stats collected from received Sender Reports i.e. where we have an inbound RTP stream. 
- InboundSenderRerport { - packets_and_bytes_sent: Option<(u32, u32)>, - rtt_ms: Option, - }, -} - -pub struct StatsInterceptor { - // Wrapped RTP streams - recv_streams: Mutex>>, - send_streams: Mutex>>, - - tx: mpsc::Sender, - - id: String, - now_gen: Arc SystemTime + Send + Sync>, -} - -impl StatsInterceptor { - pub fn new(id: String) -> Self { - let (tx, rx) = mpsc::channel(100); - - tokio::spawn(run_stats_reducer(rx)); - - Self { - id, - recv_streams: Default::default(), - send_streams: Default::default(), - tx, - now_gen: Arc::new(SystemTime::now), - } - } - - fn with_time_gen(id: String, now_gen: F) -> Self - where - F: Fn() -> SystemTime + Send + Sync + 'static, - { - let (tx, rx) = mpsc::channel(100); - tokio::spawn(run_stats_reducer(rx)); - - Self { - id, - recv_streams: Default::default(), - send_streams: Default::default(), - tx, - now_gen: Arc::new(now_gen), - } - } - - pub async fn fetch_inbound_stats( - &self, - ssrcs: Vec, - ) -> Vec> { - let (tx, rx) = oneshot::channel(); - - if let Err(e) = self - .tx - .send(Message::RequestInboundSnapshot { ssrcs, chan: tx }) - .await - { - log::debug!( - "Failed to fetch inbound RTP stream stats from stats task with error: {}", - e - ); - - return vec![]; - } - - rx.await.unwrap_or_default() - } - - pub async fn fetch_outbound_stats( - &self, - ssrcs: Vec, - ) -> Vec> { - let (tx, rx) = oneshot::channel(); - - if let Err(e) = self - .tx - .send(Message::RequestOutboundSnapshot { ssrcs, chan: tx }) - .await - { - log::debug!( - "Failed to fetch outbound RTP stream stats from stats task with error: {}", - e - ); - - return vec![]; - } - - rx.await.unwrap_or_default() - } -} - -async fn run_stats_reducer(mut rx: mpsc::Receiver) { - let mut ssrc_stats: StatsContainer = Default::default(); - let mut cleanup_ticker = tokio::time::interval(Duration::from_secs(10)); - - loop { - tokio::select! 
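`StatsInterceptor` keeps all mutable stats inside a single spawned task and talks to it over channels: updates flow in on an `mpsc` sender, and snapshot requests carry a `oneshot` sender for the reply, which is what `fetch_inbound_stats`/`fetch_outbound_stats` await. A self-contained sketch of that command/reply pattern, assuming only tokio and stand-in message types:

```rust
use tokio::sync::{mpsc, oneshot};

// Sketch of the command/reply pattern the stats task uses: callers send a
// request carrying a oneshot sender and await the reply, while one task owns
// the mutable state. `Command` and `reducer` are stand-ins, not crate types.
enum Command {
    Get { key: u32, reply: oneshot::Sender<Option<u64>> },
    Bump { key: u32 },
}

async fn reducer(mut rx: mpsc::Receiver<Command>) {
    let mut counts = std::collections::HashMap::<u32, u64>::new();
    while let Some(cmd) = rx.recv().await {
        match cmd {
            Command::Bump { key } => *counts.entry(key).or_default() += 1,
            Command::Get { key, reply } => {
                let _ = reply.send(counts.get(&key).copied());
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::channel(16);
    tokio::spawn(reducer(rx));

    tx.send(Command::Bump { key: 123_456 }).await.unwrap();
    let (reply_tx, reply_rx) = oneshot::channel();
    tx.send(Command::Get { key: 123_456, reply: reply_tx }).await.unwrap();
    assert_eq!(reply_rx.await.unwrap(), Some(1));
}
```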
{ - maybe_msg = rx.recv() => { - let msg = match maybe_msg { - Some(m) => m, - None => break, - }; - - match msg { - Message::StatUpdate { ssrc, update } => { - handle_stats_update(&mut ssrc_stats, ssrc, update); - } - Message::RequestInboundSnapshot { ssrcs, chan} => { - let result = ssrcs - .into_iter() - .map(|ssrc| ssrc_stats.get_inbound_stats(ssrc).map(inbound::StreamStats::snapshot)) - .collect(); - - let _ = chan.send(result); - } - Message::RequestOutboundSnapshot { ssrcs, chan} => { - let result = ssrcs - .into_iter() - .map(|ssrc| ssrc_stats.get_outbound_stats(ssrc).map(outbound::StreamStats::snapshot)) - .collect(); - - let _ = chan.send(result); - - } - } - - } - _ = cleanup_ticker.tick() => { - ssrc_stats.remove_stale_entries(); - } - } - } -} - -fn handle_stats_update(ssrc_stats: &mut StatsContainer, ssrc: u32, update: StatsUpdate) { - match update { - StatsUpdate::InboundRTP { - packets, - header_bytes, - payload_bytes, - last_packet_timestamp, - } => { - let stats = ssrc_stats.get_or_create_inbound_stream_stats(ssrc); - - stats - .rtp_stats - .update(header_bytes, payload_bytes, packets, last_packet_timestamp); - stats.mark_updated(); - } - StatsUpdate::OutboundRTP { - packets, - header_bytes, - payload_bytes, - last_packet_timestamp, - } => { - let stats = ssrc_stats.get_or_create_outbound_stream_stats(ssrc); - stats - .rtp_stats - .update(header_bytes, payload_bytes, packets, last_packet_timestamp); - stats.mark_updated(); - } - StatsUpdate::InboundRTCP { - fir_count, - pli_count, - nack_count, - } => { - let stats = ssrc_stats.get_or_create_outbound_stream_stats(ssrc); - stats.rtcp_stats.update(fir_count, pli_count, nack_count); - stats.mark_updated(); - } - StatsUpdate::OutboundRTCP { - fir_count, - pli_count, - nack_count, - } => { - let stats = ssrc_stats.get_or_create_inbound_stream_stats(ssrc); - stats.rtcp_stats.update(fir_count, pli_count, nack_count); - stats.mark_updated(); - } - StatsUpdate::OutboundSRExtSeqNum { seq_num } => { - let stats = ssrc_stats.get_or_create_outbound_stream_stats(ssrc); - stats.record_sr_ext_seq_num(seq_num); - stats.mark_updated(); - } - StatsUpdate::InboundReceiverReport { - ext_seq_num, - total_lost, - jitter, - rtt_ms, - fraction_lost, - } => { - let stats = ssrc_stats.get_or_create_outbound_stream_stats(ssrc); - stats.record_remote_round_trip_time(rtt_ms); - stats.update_remote_fraction_lost(fraction_lost); - stats.update_remote_total_lost(total_lost); - stats.update_remote_inbound_packets_received(ext_seq_num, total_lost); - stats.update_remote_jitter(jitter); - - stats.mark_updated(); - } - StatsUpdate::InboundSenderRerport { - rtt_ms, - packets_and_bytes_sent, - } => { - // This is a sender report we received, as such it concerns an RTP stream that's - // outbound at the remote. - let stats = ssrc_stats.get_or_create_inbound_stream_stats(ssrc); - - if let Some((packets_sent, bytes_sent)) = packets_and_bytes_sent { - stats.record_sender_report(packets_sent, bytes_sent); - } - stats.record_remote_round_trip_time(rtt_ms); - - stats.mark_updated(); - } - } -} - -#[async_trait] -impl Interceptor for StatsInterceptor { - /// bind_remote_stream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method - /// will be called once per rtp packet. 
- async fn bind_remote_stream( - &self, - info: &StreamInfo, - reader: Arc, - ) -> Arc { - let mut lock = self.recv_streams.lock(); - - let e = lock - .entry(info.ssrc) - .or_insert_with(|| Arc::new(RTPReadRecorder::new(reader, self.tx.clone()))); - - e.clone() - } - - /// unbind_remote_stream is called when the Stream is removed. It can be used to clean up any data related to that track. - async fn unbind_remote_stream(&self, info: &StreamInfo) { - let mut lock = self.recv_streams.lock(); - - lock.remove(&info.ssrc); - } - - /// bind_local_stream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method - /// will be called once per rtp packet. - async fn bind_local_stream( - &self, - info: &StreamInfo, - writer: Arc, - ) -> Arc { - let mut lock = self.send_streams.lock(); - - let e = lock - .entry(info.ssrc) - .or_insert_with(|| Arc::new(RTPWriteRecorder::new(writer, self.tx.clone()))); - - e.clone() - } - - /// unbind_local_stream is called when the Stream is removed. It can be used to clean up any data related to that track. - async fn unbind_local_stream(&self, info: &StreamInfo) { - let mut lock = self.send_streams.lock(); - - lock.remove(&info.ssrc); - } - - async fn close(&self) -> Result<()> { - Ok(()) - } - - /// bind_rtcp_writer lets you modify any outgoing RTCP packets. It is called once per PeerConnection. The returned method - /// will be called once per packet batch. - async fn bind_rtcp_writer( - &self, - writer: Arc, - ) -> Arc { - let now = self.now_gen.clone(); - - Arc::new(RTCPWriteInterceptor { - rtcp_writer: writer, - tx: self.tx.clone(), - now_gen: move || now(), - }) - } - - /// bind_rtcp_reader lets you modify any incoming RTCP packets. It is called once per sender/receiver, however this might - /// change in the future. The returned method will be called once per packet batch. - async fn bind_rtcp_reader( - &self, - reader: Arc, - ) -> Arc { - let now = self.now_gen.clone(); - - Arc::new(RTCPReadInterceptor { - rtcp_reader: reader, - tx: self.tx.clone(), - now_gen: move || now(), - }) - } -} - -pub struct RTCPReadInterceptor { - rtcp_reader: Arc, - tx: mpsc::Sender, - now_gen: F, -} - -#[async_trait] -impl RTCPReader for RTCPReadInterceptor -where - F: Fn() -> SystemTime + Send + Sync, -{ - /// read a batch of rtcp packets - async fn read( - &self, - buf: &mut [u8], - attributes: &Attributes, - ) -> Result<(Vec>, Attributes)> { - let (pkts, attributes) = self.rtcp_reader.read(buf, attributes).await?; - - // Middle 32 bits - let now = (unix2ntp((self.now_gen)()) >> 16) as u32; - - #[derive(Default, Debug)] - struct GenericRTCP { - fir_count: Option, - pli_count: Option, - nack_count: Option, - } - - #[derive(Default, Debug)] - struct ReceiverReportEntry { - /// Extended sequence number value from Receiver Report, used to calculate remote - /// stats. - ext_seq_num: u32, - /// Total loss value from Receiver Report, used to calculate remote - /// stats. - total_lost: u32, - /// Jitter from Receiver Report. - jitter: u32, - /// Round Trip Time calculated from Receiver Report. - rtt_ms: Option, - /// Fraction of packets lost. - fraction_lost: u8, - } - - #[derive(Default, Debug)] - struct SenderReportEntry { - /// NTP timestamp(from Sender Report). - sr_ntp_time: Option, - /// Packets Sent(from Sender Report). - sr_packets_sent: Option, - /// Bytes Sent(from Sender Report). - sr_bytes_sent: Option, - /// Last RR timestamp(middle bits) from DLRR extended report block. 
- dlrr_last_rr: Option, - /// Delay since last RR from DLRR extended report block. - dlrr_delay_rr: Option, - } - - #[derive(Default, Debug)] - struct Entry { - generic_rtcp: GenericRTCP, - receiver_reports: Vec, - sender_reports: Vec, - } - let updates = pkts - .iter() - .fold(HashMap::::new(), |mut acc, p| { - if let Some(rr) = p.as_any().downcast_ref::() { - for recp in &rr.reports { - let e = acc.entry(recp.ssrc).or_default(); - - let rtt_ms = if recp.delay != 0 { - calculate_rtt_ms(now, recp.delay, recp.last_sender_report) - } else { - None - }; - - e.receiver_reports.push(ReceiverReportEntry { - ext_seq_num: recp.last_sequence_number, - total_lost: recp.total_lost, - jitter: recp.jitter, - rtt_ms, - fraction_lost: recp.fraction_lost, - }); - } - } else if let Some(fir) = p.as_any().downcast_ref::() { - for fir_entry in &fir.fir { - let e = acc.entry(fir_entry.ssrc).or_default(); - e.generic_rtcp.fir_count = - e.generic_rtcp.fir_count.map(|v| v + 1).or(Some(1)); - } - } else if let Some(pli) = p.as_any().downcast_ref::() { - let e = acc.entry(pli.media_ssrc).or_default(); - e.generic_rtcp.pli_count = e.generic_rtcp.pli_count.map(|v| v + 1).or(Some(1)); - } else if let Some(nack) = p.as_any().downcast_ref::() { - let count = nack.nacks.iter().flat_map(|p| p.into_iter()).count() as u64; - - let e = acc.entry(nack.media_ssrc).or_default(); - e.generic_rtcp.nack_count = - e.generic_rtcp.nack_count.map(|v| v + count).or(Some(count)); - } else if let Some(sr) = p.as_any().downcast_ref::() { - let e = acc.entry(sr.ssrc).or_default(); - let sr_e = { - let need_new_entry = e - .sender_reports - .last() - .map(|e| e.sr_packets_sent.is_some()) - .unwrap_or(true); - - if need_new_entry { - e.sender_reports.push(Default::default()); - } - - // SAFETY: Unrwap ok because we just added an entry above - e.sender_reports.last_mut().unwrap() - }; - - sr_e.sr_ntp_time = Some(sr.ntp_time); - sr_e.sr_packets_sent = Some(sr.packet_count); - sr_e.sr_bytes_sent = Some(sr.octet_count); - } else if let Some(xr) = p.as_any().downcast_ref::() { - // Extended Report(XR) - - // We only care about DLRR reports - let dlrrs = xr.reports.iter().flat_map(|report| { - let dlrr = report.as_any().downcast_ref::(); - - dlrr.map(|b| b.reports.iter()).into_iter().flatten() - }); - - for dlrr in dlrrs { - let e = acc.entry(dlrr.ssrc).or_default(); - let sr_e = { - let need_new_entry = e - .sender_reports - .last() - .map(|e| e.dlrr_last_rr.is_some()) - .unwrap_or(true); - - if need_new_entry { - e.sender_reports.push(Default::default()); - } - - // SAFETY: Unrwap ok because we just added an entry above - e.sender_reports.last_mut().unwrap() - }; - - sr_e.dlrr_last_rr = Some(dlrr.last_rr); - sr_e.dlrr_delay_rr = Some(dlrr.dlrr); - } - } - - acc - }); - - for ( - ssrc, - Entry { - generic_rtcp, - mut receiver_reports, - mut sender_reports, - }, - ) in updates.into_iter() - { - // Sort RR by seq number low to high - receiver_reports.sort_by(|a, b| a.ext_seq_num.cmp(&b.ext_seq_num)); - // Sort SR by ntp time, low to high - sender_reports - .sort_by(|a, b| a.sr_ntp_time.unwrap_or(0).cmp(&b.sr_ntp_time.unwrap_or(0))); - - let _ = self - .tx - .send(Message::StatUpdate { - ssrc, - update: StatsUpdate::InboundRTCP { - fir_count: generic_rtcp.fir_count, - pli_count: generic_rtcp.pli_count, - nack_count: generic_rtcp.nack_count, - }, - }) - .await; - - let futures = receiver_reports.into_iter().map(|rr| { - self.tx.send(Message::StatUpdate { - ssrc, - update: StatsUpdate::InboundReceiverReport { - ext_seq_num: rr.ext_seq_num, - 
total_lost: rr.total_lost, - jitter: rr.jitter, - rtt_ms: rr.rtt_ms, - fraction_lost: rr.fraction_lost, - }, - }) - }); - for fut in futures { - // TODO: Use futures::join_all - let _ = fut.await; - } - - let futures = sender_reports.into_iter().map(|sr| { - let rtt_ms = match (sr.dlrr_last_rr, sr.dlrr_delay_rr, sr.sr_packets_sent) { - (Some(last_rr), Some(delay_rr), Some(_)) if last_rr != 0 && delay_rr != 0 => { - calculate_rtt_ms(now, delay_rr, last_rr) - } - _ => None, - }; - - self.tx.send(Message::StatUpdate { - ssrc, - update: StatsUpdate::InboundSenderRerport { - packets_and_bytes_sent: sr - .sr_packets_sent - .and_then(|ps| sr.sr_bytes_sent.map(|bs| (ps, bs))), - rtt_ms, - }, - }) - }); - for fut in futures { - // TODO: Use futures::join_all - let _ = fut.await; - } - } - - Ok((pkts, attributes)) - } -} - -pub struct RTCPWriteInterceptor { - rtcp_writer: Arc, - tx: mpsc::Sender, - now_gen: F, -} - -#[async_trait] -impl RTCPWriter for RTCPWriteInterceptor -where - F: Fn() -> SystemTime + Send + Sync, -{ - async fn write( - &self, - pkts: &[Box], - attributes: &Attributes, - ) -> Result { - #[derive(Default, Debug)] - struct Entry { - fir_count: Option, - pli_count: Option, - nack_count: Option, - sr_ext_seq_num: Option, - } - let updates = pkts - .iter() - .fold(HashMap::::new(), |mut acc, p| { - if let Some(fir) = p.as_any().downcast_ref::() { - for fir_entry in &fir.fir { - let e = acc.entry(fir_entry.ssrc).or_default(); - e.fir_count = e.fir_count.map(|v| v + 1).or(Some(1)); - } - } else if let Some(pli) = p.as_any().downcast_ref::() { - let e = acc.entry(pli.media_ssrc).or_default(); - e.pli_count = e.pli_count.map(|v| v + 1).or(Some(1)); - } else if let Some(nack) = p.as_any().downcast_ref::() { - let count = nack.nacks.iter().flat_map(|p| p.into_iter()).count() as u64; - - let e = acc.entry(nack.media_ssrc).or_default(); - e.nack_count = e.nack_count.map(|v| v + count).or(Some(count)); - } else if let Some(sr) = p.as_any().downcast_ref::() { - for rep in &sr.reports { - let e = acc.entry(rep.ssrc).or_default(); - - match e.sr_ext_seq_num { - // We want the initial value for `last_sequence_number` from the first - // SR. It's possible that an RTCP batch contains more than one SR, in - // which case we should use the lowest value. 
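// Illustration of the NACK counting done above: a TransportLayerNack entry
// names one packet ID explicitly and up to sixteen more via its bitmask, so
// each NackPair contributes 1 + popcount(lost_packets) NACKed packets. The
// helper and values here are illustrative only; 0b0011_0110 is the mask used
// in the RTCP test below and expands to five packets.
fn nacked_packets(packet_id: u16, lost_packets: u16) -> Vec<u16> {
    let mut pkts = vec![packet_id];
    for bit in 0u16..16 {
        if lost_packets & (1 << bit) != 0 {
            // Bit i of the mask flags packet_id + i + 1 as lost.
            pkts.push(packet_id.wrapping_add(bit + 1));
        }
    }
    pkts
}

fn main() {
    let pkts = nacked_packets(5, 0b0011_0110);
    assert_eq!(pkts, vec![5, 7, 8, 10, 11]);
    assert_eq!(pkts.len(), 5);
}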
- Some(seq_num) if seq_num > rep.last_sequence_number => { - e.sr_ext_seq_num = Some(rep.last_sequence_number) - } - None => e.sr_ext_seq_num = Some(rep.last_sequence_number), - _ => {} - } - } - } - - acc - }); - - for ( - ssrc, - Entry { - fir_count, - pli_count, - nack_count, - sr_ext_seq_num, - }, - ) in updates.into_iter() - { - let _ = self - .tx - .send(Message::StatUpdate { - ssrc, - update: StatsUpdate::OutboundRTCP { - fir_count, - pli_count, - nack_count, - }, - }) - .await; - - if let Some(seq_num) = sr_ext_seq_num { - let _ = self - .tx - .send(Message::StatUpdate { - ssrc, - update: StatsUpdate::OutboundSRExtSeqNum { seq_num }, - }) - .await; - } - } - - self.rtcp_writer.write(pkts, attributes).await - } -} - -pub struct RTPReadRecorder { - rtp_reader: Arc, - tx: mpsc::Sender, -} - -impl RTPReadRecorder { - fn new(rtp_reader: Arc, tx: mpsc::Sender) -> Self { - Self { rtp_reader, tx } - } -} - -impl fmt::Debug for RTPReadRecorder { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("RTPReadRecorder").finish() - } -} - -#[async_trait] -impl RTPReader for RTPReadRecorder { - async fn read( - &self, - buf: &mut [u8], - attributes: &Attributes, - ) -> Result<(rtp::packet::Packet, Attributes)> { - let (pkt, attributes) = self.rtp_reader.read(buf, attributes).await?; - - let _ = self - .tx - .send(Message::StatUpdate { - ssrc: pkt.header.ssrc, - update: StatsUpdate::InboundRTP { - packets: 1, - header_bytes: pkt.header.marshal_size() as u64, - payload_bytes: pkt.payload.len() as u64, - last_packet_timestamp: SystemTime::now(), - }, - }) - .await; - - Ok((pkt, attributes)) - } -} - -pub struct RTPWriteRecorder { - rtp_writer: Arc, - tx: mpsc::Sender, -} - -impl RTPWriteRecorder { - fn new(rtp_writer: Arc, tx: mpsc::Sender) -> Self { - Self { rtp_writer, tx } - } -} - -impl fmt::Debug for RTPWriteRecorder { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("RTPWriteRecorder").finish() - } -} - -#[async_trait] -impl RTPWriter for RTPWriteRecorder { - /// write a rtp packet - async fn write(&self, pkt: &rtp::packet::Packet, attributes: &Attributes) -> Result { - let n = self.rtp_writer.write(pkt, attributes).await?; - - let _ = self - .tx - .send(Message::StatUpdate { - ssrc: pkt.header.ssrc, - update: StatsUpdate::OutboundRTP { - packets: 1, - header_bytes: pkt.header.marshal_size() as u64, - payload_bytes: pkt.payload.len() as u64, - last_packet_timestamp: SystemTime::now(), - }, - }) - .await; - - Ok(n) - } -} - -/// Calculate the round trip time for a given peer as described in -/// [RFC3550 6.4.1](https://datatracker.ietf.org/doc/html/rfc3550#section-6.4.1). -/// -/// ## Params -/// -/// - `now` the current middle 32 bits of an NTP timestamp for the current time. -/// - `delay` the delay(`DLSR`) since last sender report expressed as fractions of a second in 32 bits. -/// - `last_report` the middle 32 bits of an NTP timestamp for the most recent sender report(LSR) or Receiver Report(LRR). 
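// Worked example of the arithmetic documented above: all three inputs are
// "middle 32 bits" NTP values, i.e. 16.16 fixed-point seconds, so the
// round-trip time A - DLSR - LSR is also a 16.16 value. This standalone helper
// treats the low 16 bits as 1/65536ths of a second; it is an illustration, not
// the crate's calculate_rtt_ms.
fn rtt_ms_example(now: u32, delay: u32, last_report: u32) -> Option<f64> {
    let rtt = now.checked_sub(delay)?.checked_sub(last_report)?;
    Some((rtt >> 16) as f64 * 1000.0 + (rtt & 0xFFFF) as f64 / 65536.0 * 1000.0)
}

fn main() {
    // Figures from the RFC 3550 example reproduced in the comment below:
    let a = 0xb710_8000;    // RR arrival time:      46864.500 s
    let dlsr = 0x0005_4000; // delay since last SR:      5.250 s
    let lsr = 0xb705_2000;  // LSR echoed in the RR: 46853.125 s
    // 46864.500 - 5.250 - 46853.125 = 6.125 s, i.e. 6125 ms.
    assert_eq!(rtt_ms_example(a, dlsr, lsr), Some(6125.0));
}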
-fn calculate_rtt_ms(now: u32, delay: u32, last_report: u32) -> Option { - // [10 Nov 1995 11:33:25.125 UTC] [10 Nov 1995 11:33:36.5 UTC] - // n SR(n) A=b710:8000 (46864.500 s) - // ----------------------------------------------------------------> - // v ^ - // ntp_sec =0xb44db705 v ^ dlsr=0x0005:4000 ( 5.250s) - // ntp_frac=0x20000000 v ^ lsr =0xb705:2000 (46853.125s) - // (3024992005.125 s) v ^ - // r v ^ RR(n) - // ----------------------------------------------------------------> - // |<-DLSR->| - // (5.250 s) - // - // A 0xb710:8000 (46864.500 s) - // DLSR -0x0005:4000 ( 5.250 s) - // LSR -0xb705:2000 (46853.125 s) - // ------------------------------- - // delay 0x0006:2000 ( 6.125 s) - - let rtt = now.checked_sub(delay)?.checked_sub(last_report)?; - let rtt_seconds = rtt >> 16; - let rtt_fraction = (rtt & (u16::MAX as u32)) as f64 / (u16::MAX as u32) as f64; - - Some(rtt_seconds as f64 * 1000.0 + rtt_fraction * 1000.0) -} - -#[cfg(test)] -mod test { - // Silence warning on `..Default::default()` with no effect: - #![allow(clippy::needless_update)] - - macro_rules! assert_feq { - ($left: expr, $right: expr) => { - assert_feq!($left, $right, 0.01); - }; - ($left: expr, $right: expr, $eps: expr) => { - if ($left - $right).abs() >= $eps { - panic!("{:?} was not within {:?} of {:?}", $left, $eps, $right); - } - }; - } - - use std::sync::Arc; - use std::time::{Duration, SystemTime}; - - use bytes::Bytes; - use rtcp::extended_report::{DLRRReport, DLRRReportBlock, ExtendedReport}; - use rtcp::payload_feedbacks::full_intra_request::{FirEntry, FullIntraRequest}; - use rtcp::payload_feedbacks::picture_loss_indication::PictureLossIndication; - use rtcp::receiver_report::ReceiverReport; - use rtcp::reception_report::ReceptionReport; - use rtcp::sender_report::SenderReport; - use rtcp::transport_feedbacks::transport_layer_nack::{NackPair, TransportLayerNack}; - - use super::StatsInterceptor; - use crate::error::Result; - use crate::mock::mock_stream::MockStream; - use crate::stream_info::StreamInfo; - - #[tokio::test] - async fn test_stats_interceptor_rtp() -> Result<()> { - let icpr: Arc<_> = Arc::new(StatsInterceptor::new("Hello".to_owned())); - - let recv_stream = MockStream::new( - &StreamInfo { - ssrc: 123456, - ..Default::default() - }, - icpr.clone(), - ) - .await; - - let send_stream = MockStream::new( - &StreamInfo { - ssrc: 234567, - ..Default::default() - }, - icpr.clone(), - ) - .await; - - recv_stream - .receive_rtp(rtp::packet::Packet { - header: rtp::header::Header { - ssrc: 123456, - ..Default::default() - }, - payload: Bytes::from_static(b"\xde\xad\xbe\xef"), - }) - .await; - - let _ = recv_stream - .read_rtp() - .await - .expect("After calling receive_rtp read_rtp should return Some")?; - - let _ = send_stream - .write_rtp(&rtp::packet::Packet { - header: rtp::header::Header { - ssrc: 234567, - ..Default::default() - }, - payload: Bytes::from_static(b"\xde\xad\xbe\xef\xde\xad\xbe\xef"), - }) - .await; - - let _ = send_stream - .write_rtp(&rtp::packet::Packet { - header: rtp::header::Header { - ssrc: 234567, - ..Default::default() - }, - payload: Bytes::from_static(&[0x13, 0x37]), - }) - .await; - - let snapshots = icpr.fetch_inbound_stats(vec![123456]).await; - let recv_snapshot = snapshots[0] - .as_ref() - .expect("Stats should exist for ssrc: 123456"); - assert_eq!(recv_snapshot.packets_received(), 1); - assert_eq!(recv_snapshot.header_bytes_received(), 12); - assert_eq!(recv_snapshot.payload_bytes_received(), 4); - - let snapshots = 
icpr.fetch_outbound_stats(vec![234567]).await; - let send_snapshot = snapshots[0] - .as_ref() - .expect("Stats should exist for ssrc: 234567"); - assert_eq!(send_snapshot.packets_sent(), 2); - assert_eq!(send_snapshot.header_bytes_sent(), 24); - assert_eq!(send_snapshot.payload_bytes_sent(), 10); - - Ok(()) - } - - #[tokio::test] - async fn test_stats_interceptor_rtcp() -> Result<()> { - let icpr: Arc<_> = Arc::new(StatsInterceptor::with_time_gen("Hello".to_owned(), || { - // 10 Nov 1995 11:33:36.5 UTC - SystemTime::UNIX_EPOCH + Duration::from_secs_f64(816003216.5) - })); - - let recv_stream = MockStream::new( - &StreamInfo { - ssrc: 123456, - ..Default::default() - }, - icpr.clone(), - ) - .await; - - let send_stream = MockStream::new( - &StreamInfo { - ssrc: 234567, - ..Default::default() - }, - icpr.clone(), - ) - .await; - - send_stream - .write_rtcp(&[Box::new(SenderReport { - ssrc: 234567, - reports: vec![ - ReceptionReport { - ssrc: 234567, - last_sequence_number: (5 << 16) | 10, - ..Default::default() - }, - ReceptionReport { - ssrc: 234567, - last_sequence_number: (5 << 16) | 85, - ..Default::default() - }, - ], - ..Default::default() - })]) - .await - .expect("Failed to write RTCP packets"); - - send_stream - .receive_rtcp(vec![ - Box::new(ReceiverReport { - reports: vec![ - ReceptionReport { - ssrc: 234567, - last_sequence_number: (5 << 16) | 64, - total_lost: 5, - ..Default::default() - }, - ReceptionReport { - ssrc: 234567, - last_sender_report: 0xb705_2000, - delay: 0x0005_4000, - last_sequence_number: (5 << 16) | 70, - total_lost: 8, - fraction_lost: 32, - jitter: 2250, - ..Default::default() - }, - ], - ..Default::default() - }), - Box::new(TransportLayerNack { - sender_ssrc: 0, - media_ssrc: 234567, - nacks: vec![NackPair { - packet_id: 5, - lost_packets: 0b0011_0110, - }], - }), - Box::new(TransportLayerNack { - sender_ssrc: 0, - // NB: Different SSRC - media_ssrc: 999999, - nacks: vec![NackPair { - packet_id: 5, - lost_packets: 0b0011_0110, - }], - }), - Box::new(PictureLossIndication { - sender_ssrc: 0, - media_ssrc: 234567, - }), - Box::new(PictureLossIndication { - sender_ssrc: 0, - media_ssrc: 234567, - }), - Box::new(FullIntraRequest { - sender_ssrc: 0, - media_ssrc: 234567, - fir: vec![ - FirEntry { - ssrc: 234567, - sequence_number: 132, - }, - FirEntry { - ssrc: 234567, - sequence_number: 135, - }, - ], - }), - ]) - .await; - let snapshots = icpr.fetch_outbound_stats(vec![234567]).await; - let send_snapshot = snapshots[0] - .as_ref() - .expect("Outbound Stats should exist for ssrc: 234567"); - - assert!( - send_snapshot.remote_round_trip_time().is_none() - && send_snapshot.remote_round_trip_time_measurements() == 0, - "Before receiving the first RR we should not have a remote round trip time" - ); - let _ = send_stream - .read_rtcp() - .await - .expect("After calling `receive_rtcp`, `read_rtcp` should return some packets"); - - recv_stream - .write_rtcp(&[ - Box::new(TransportLayerNack { - sender_ssrc: 0, - media_ssrc: 123456, - nacks: vec![NackPair { - packet_id: 5, - lost_packets: 0b0011_0111, - }], - }), - Box::new(TransportLayerNack { - sender_ssrc: 0, - // NB: Different SSRC - media_ssrc: 999999, - nacks: vec![NackPair { - packet_id: 5, - lost_packets: 0b1111_0110, - }], - }), - Box::new(PictureLossIndication { - sender_ssrc: 0, - media_ssrc: 123456, - }), - Box::new(PictureLossIndication { - sender_ssrc: 0, - media_ssrc: 123456, - }), - Box::new(PictureLossIndication { - sender_ssrc: 0, - media_ssrc: 123456, - }), - Box::new(FullIntraRequest { - 
sender_ssrc: 0, - media_ssrc: 123456, - fir: vec![FirEntry { - ssrc: 123456, - sequence_number: 132, - }], - }), - ]) - .await - .expect("Failed to write RTCP packets for recv_stream"); - - recv_stream - .receive_rtcp(vec![ - Box::new(SenderReport { - ssrc: 123456, - ntp_time: 12345, // Used for ordering - packet_count: 52, - octet_count: 8172, - reports: vec![], - ..Default::default() - }), - Box::new(SenderReport { - ssrc: 123456, - ntp_time: 23456, // Used for ordering - packet_count: 82, - octet_count: 10351, - reports: vec![], - ..Default::default() - }), - Box::new(ExtendedReport { - sender_ssrc: 928191, - reports: vec![Box::new(DLRRReportBlock { - reports: vec![DLRRReport { - ssrc: 123456, - last_rr: 0xb705_2000, - dlrr: 0x0005_4000, - }], - })], - }), - Box::new(SenderReport { - // NB: Different SSRC - ssrc: 9999999, - ntp_time: 99999, // Used for ordering - packet_count: 1231, - octet_count: 193812, - reports: vec![], - ..Default::default() - }), - ]) - .await; - - let snapshots = icpr.fetch_inbound_stats(vec![123456]).await; - let recv_snapshot = snapshots[0] - .as_ref() - .expect("Stats should exist for ssrc: 123456"); - assert!( - recv_snapshot.remote_round_trip_time().is_none() - && recv_snapshot.remote_round_trip_time_measurements() == 0, - "Before receiving the first SR/DLRR we should not have a remote round trip time" - ); - - let _ = recv_stream.read_rtcp().await.expect("read_rtcp failed"); - - let snapshots = icpr.fetch_outbound_stats(vec![234567]).await; - let send_snapshot = snapshots[0] - .as_ref() - .expect("Outbound Stats should exist for ssrc: 234567"); - let rtt_ms = send_snapshot.remote_round_trip_time().expect( - "After receiving an RR with a DSLR block we should have a remote round trip time", - ); - assert_feq!(rtt_ms, 6125.0); - - assert_eq!(send_snapshot.nacks_received(), 5); - assert_eq!(send_snapshot.plis_received(), 2); - assert_eq!(send_snapshot.firs_received(), 2); - // Last Seq Num(RR) - total lost(RR) - Initial Seq Num(SR) + 1 - // 70 - 8 - 10 + 1 = 53 - assert_eq!(send_snapshot.remote_packets_received(), 53); - assert_feq!( - send_snapshot - .remote_fraction_lost() - .expect("Should have a fraction lost values after receiving RR"), - 32.0 / 256.0 - ); - assert_eq!(send_snapshot.remote_total_lost(), 8); - assert_eq!(send_snapshot.remote_jitter(), 2250); - - let snapshots = icpr.fetch_inbound_stats(vec![123456]).await; - let recv_snapshot = snapshots[0] - .as_ref() - .expect("Stats should exist for ssrc: 123456"); - assert_eq!(recv_snapshot.nacks_sent(), 6); - assert_eq!(recv_snapshot.plis_sent(), 3); - assert_eq!(recv_snapshot.firs_sent(), 1); - assert_eq!(recv_snapshot.remote_packets_sent(), 82); - assert_eq!(recv_snapshot.remote_bytes_sent(), 10351); - let rtt_ms = recv_snapshot - .remote_round_trip_time() - .expect("After receiving SR and DLRR we should have a round trip time "); - assert_feq!(rtt_ms, 6125.0); - assert_eq!(recv_snapshot.remote_reports_sent(), 2); - assert_eq!(recv_snapshot.remote_round_trip_time_measurements(), 1); - assert_feq!(recv_snapshot.remote_total_round_trip_time(), 6125.0); - - Ok(()) - } -} diff --git a/interceptor/src/stats/mod.rs b/interceptor/src/stats/mod.rs deleted file mode 100644 index 2f1bb0074..000000000 --- a/interceptor/src/stats/mod.rs +++ /dev/null @@ -1,617 +0,0 @@ -use std::collections::HashMap; -use std::sync::Arc; -use std::time::SystemTime; - -use tokio::time::Duration; - -mod interceptor; - -pub use self::interceptor::StatsInterceptor; - -pub fn make_stats_interceptor(id: &str) -> Arc { - 
Arc::new(StatsInterceptor::new(id.to_owned())) -} - -/// Types related to inbound RTP streams. -mod inbound { - use std::time::SystemTime; - - use tokio::time::{Duration, Instant}; - - use super::{RTCPStats, RTPStats}; - - #[derive(Debug, Clone)] - /// Stats collected for an inbound RTP stream. - /// Contains both stats relating to the inbound stream and remote stats for the corresponding - /// outbound stream at the remote end. - pub(super) struct StreamStats { - /// Received RTP stats. - pub(super) rtp_stats: RTPStats, - /// Common RTCP stats derived from inbound and outbound RTCP packets. - pub(super) rtcp_stats: RTCPStats, - - /// The last time any stats where update, used for garbage collection to remove obsolete stats. - last_update: Instant, - - /// The number of packets sent as reported in the latest SR from the remote. - remote_packets_sent: u32, - - /// The number of bytes sent as reported in the latest SR from the remote. - remote_bytes_sent: u32, - - /// The total number of sender reports sent by the remote and received. - remote_reports_sent: u64, - - /// The last remote round trip time measurement in ms. [`None`] if no round trip time has - /// been derived yet, or if it wasn't possible to derive it. - remote_round_trip_time: Option, - - /// The cumulative total round trip times reported in ms. - remote_total_round_trip_time: f64, - - /// The total number of measurements of the remote round trip time. - remote_round_trip_time_measurements: u64, - } - - impl Default for StreamStats { - fn default() -> Self { - Self { - rtp_stats: RTPStats::default(), - rtcp_stats: RTCPStats::default(), - last_update: Instant::now(), - remote_packets_sent: 0, - remote_bytes_sent: 0, - remote_reports_sent: 0, - remote_round_trip_time: None, - remote_total_round_trip_time: 0.0, - remote_round_trip_time_measurements: 0, - } - } - } - - impl StreamStats { - pub(super) fn snapshot(&self) -> StatsSnapshot { - self.into() - } - - pub(super) fn mark_updated(&mut self) { - self.last_update = Instant::now(); - } - - pub(super) fn duration_since_last_update(&self) -> Duration { - self.last_update.elapsed() - } - - pub(super) fn record_sender_report(&mut self, packets_sent: u32, bytes_sent: u32) { - self.remote_reports_sent += 1; - self.remote_packets_sent = packets_sent; - self.remote_bytes_sent = bytes_sent; - } - - pub(super) fn record_remote_round_trip_time(&mut self, round_trip_time: Option) { - // Store the latest measurement, even if it's None. - self.remote_round_trip_time = round_trip_time; - - if let Some(rtt) = round_trip_time { - // Only if we have a valid measurement do we update the totals - self.remote_total_round_trip_time += rtt; - self.remote_round_trip_time_measurements += 1; - } - } - } - - /// A point in time snapshot of the stream stats for an inbound RTP stream. - /// - /// Created by [`StreamStats::snapshot`]. - #[derive(Debug)] - pub struct StatsSnapshot { - /// Received RTP stats. - rtp_stats: RTPStats, - /// Common RTCP stats derived from inbound and outbound RTCP packets. - rtcp_stats: RTCPStats, - - /// The number of packets sent as reported in the latest SR from the remote. - remote_packets_sent: u32, - - /// The number of bytes sent as reported in the latest SR from the remote. - remote_bytes_sent: u32, - - /// The total number of sender reports sent by the remote and received. - remote_reports_sent: u64, - - /// The last remote round trip time measurement in ms. [`None`] if no round trip time has - /// been derived yet, or if it wasn't possible to derive it. 
- remote_round_trip_time: Option, - - /// The cumulative total round trip times reported in ms. - remote_total_round_trip_time: f64, - - /// The total number of measurements of the remote round trip time. - remote_round_trip_time_measurements: u64, - } - - impl StatsSnapshot { - pub fn packets_received(&self) -> u64 { - self.rtp_stats.packets - } - - pub fn payload_bytes_received(&self) -> u64 { - self.rtp_stats.payload_bytes - } - - pub fn header_bytes_received(&self) -> u64 { - self.rtp_stats.header_bytes - } - - pub fn last_packet_received_timestamp(&self) -> Option { - self.rtp_stats.last_packet_timestamp - } - - pub fn nacks_sent(&self) -> u64 { - self.rtcp_stats.nack_count - } - - pub fn firs_sent(&self) -> u64 { - self.rtcp_stats.fir_count - } - - pub fn plis_sent(&self) -> u64 { - self.rtcp_stats.pli_count - } - pub fn remote_packets_sent(&self) -> u32 { - self.remote_packets_sent - } - - pub fn remote_bytes_sent(&self) -> u32 { - self.remote_bytes_sent - } - - pub fn remote_reports_sent(&self) -> u64 { - self.remote_reports_sent - } - - pub fn remote_round_trip_time(&self) -> Option { - self.remote_round_trip_time - } - - pub fn remote_total_round_trip_time(&self) -> f64 { - self.remote_total_round_trip_time - } - - pub fn remote_round_trip_time_measurements(&self) -> u64 { - self.remote_round_trip_time_measurements - } - } - - impl From<&StreamStats> for StatsSnapshot { - fn from(stream_stats: &StreamStats) -> Self { - Self { - rtp_stats: stream_stats.rtp_stats.clone(), - rtcp_stats: stream_stats.rtcp_stats.clone(), - remote_packets_sent: stream_stats.remote_packets_sent, - remote_bytes_sent: stream_stats.remote_bytes_sent, - remote_reports_sent: stream_stats.remote_reports_sent, - remote_round_trip_time: stream_stats.remote_round_trip_time, - remote_total_round_trip_time: stream_stats.remote_total_round_trip_time, - remote_round_trip_time_measurements: stream_stats - .remote_round_trip_time_measurements, - } - } - } -} - -/// Types related to outbound RTP streams. -mod outbound { - use std::time::SystemTime; - - use tokio::time::{Duration, Instant}; - - use super::{RTCPStats, RTPStats}; - - #[derive(Debug, Clone)] - /// Stats collected for an outbound RTP stream. - /// Contains both stats relating to the outbound stream and remote stats for the corresponding - /// inbound stream. - pub(super) struct StreamStats { - /// Sent RTP stats. - pub(super) rtp_stats: RTPStats, - /// Common RTCP stats derived from inbound and outbound RTCP packets. - pub(super) rtcp_stats: RTCPStats, - - /// The last time any stats where update, used for garbage collection to remove obsolete stats. - last_update: Instant, - - /// The first value of extended seq num that was sent in an SR for this SSRC. [`None`] before - /// the first SR is sent. - /// - /// Used to calculate packet statistic for remote stats. - initial_outbound_ext_seq_num: Option, - - /// The number of inbound packets received by the remote side for this stream. - remote_packets_received: u64, - - /// The number of lost packets reported by the remote for this tream. - remote_total_lost: u32, - - /// The estimated remote jitter for this stream in timestamp units. - remote_jitter: u32, - - /// The last remote round trip time measurement in ms. [`None`] if no round trip time has - /// been derived yet, or if it wasn't possible to derive it. - remote_round_trip_time: Option, - - /// The cumulative total round trip times reported in ms. 
- remote_total_round_trip_time: f64, - - /// The total number of measurements of the remote round trip time. - remote_round_trip_time_measurements: u64, - - /// The latest fraction lost value from RR. - remote_fraction_lost: Option, - } - - impl Default for StreamStats { - fn default() -> Self { - Self { - rtp_stats: RTPStats::default(), - rtcp_stats: RTCPStats::default(), - last_update: Instant::now(), - initial_outbound_ext_seq_num: None, - remote_packets_received: 0, - remote_total_lost: 0, - remote_jitter: 0, - remote_round_trip_time: None, - remote_total_round_trip_time: 0.0, - remote_round_trip_time_measurements: 0, - remote_fraction_lost: None, - } - } - } - - impl StreamStats { - pub(super) fn snapshot(&self) -> StatsSnapshot { - self.into() - } - - pub(super) fn mark_updated(&mut self) { - self.last_update = Instant::now(); - } - - pub(super) fn duration_since_last_update(&self) -> Duration { - self.last_update.elapsed() - } - - pub(super) fn update_remote_inbound_packets_received( - &mut self, - rr_ext_seq_num: u32, - rr_total_lost: u32, - ) { - if let Some(initial_ext_seq_num) = self.initial_outbound_ext_seq_num { - // Total number of RTP packets received for this SSRC. - // At the receiving endpoint, this is calculated as defined in [RFC3550] section 6.4.1. - // At the sending endpoint the packetsReceived is estimated by subtracting the - // Cumulative Number of Packets Lost from the Extended Highest Sequence Number Received, - // both reported in the RTCP Receiver Report, and then subtracting the - // initial Extended Sequence Number that was sent to this SSRC in a RTCP Sender Report and then adding one, - // to mirror what is discussed in Appendix A.3 in [RFC3550], but for the sender side. - // If no RTCP Receiver Report has been received yet, then return 0. - self.remote_packets_received = - (rr_ext_seq_num as u64) - (rr_total_lost as u64) - (initial_ext_seq_num as u64) - + 1; - } - } - - #[inline(always)] - pub(super) fn record_sr_ext_seq_num(&mut self, seq_num: u32) { - // Only record the initial value - if self.initial_outbound_ext_seq_num.is_none() { - self.initial_outbound_ext_seq_num = Some(seq_num); - } - } - - pub(super) fn record_remote_round_trip_time(&mut self, round_trip_time: Option) { - // Store the latest measurement, even if it's None. - self.remote_round_trip_time = round_trip_time; - - if let Some(rtt) = round_trip_time { - // Only if we have a valid measurement do we update the totals - self.remote_total_round_trip_time += rtt; - self.remote_round_trip_time_measurements += 1; - } - } - - pub(super) fn update_remote_fraction_lost(&mut self, fraction_lost: u8) { - self.remote_fraction_lost = Some(fraction_lost); - } - - pub(super) fn update_remote_jitter(&mut self, jitter: u32) { - self.remote_jitter = jitter; - } - - pub(super) fn update_remote_total_lost(&mut self, lost: u32) { - self.remote_total_lost = lost; - } - } - - /// A point in time snapshot of the stream stats for an outbound RTP stream. - /// - /// Created by [`StreamStats::snapshot`]. - #[derive(Debug)] - pub struct StatsSnapshot { - /// Sent RTP stats. - rtp_stats: RTPStats, - /// Common RTCP stats derived from inbound and outbound RTCP packets. - rtcp_stats: RTCPStats, - - /// The number of inbound packets received by the remote side for this stream. - remote_packets_received: u64, - - /// The number of lost packets reported by the remote for this tream. - remote_total_lost: u32, - - /// The estimated remote jitter for this stream in timestamp units. 
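// Worked example of the remote "packets received" estimate implemented above
// (RFC 3550 Appendix A.3 applied from the sender side): extended highest
// sequence number from the RR, minus the cumulative loss from the RR, minus
// the first extended sequence number advertised in a Sender Report, plus one.
// Standalone helper with values borrowed from the interceptor test earlier in
// this diff (70 - 8 - 10 + 1 = 53, all within sequence-number cycle 5).
fn estimate_remote_packets_received(
    rr_ext_highest_seq: u32,
    rr_cumulative_lost: u32,
    first_sr_ext_seq: u32,
) -> u64 {
    rr_ext_highest_seq as u64 - rr_cumulative_lost as u64 - first_sr_ext_seq as u64 + 1
}

fn main() {
    let rr_ext_highest_seq = (5 << 16) | 70;
    let first_sr_ext_seq = (5 << 16) | 10;
    assert_eq!(
        estimate_remote_packets_received(rr_ext_highest_seq, 8, first_sr_ext_seq),
        53
    );
}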
- remote_jitter: u32, - - /// The most recent remote round trip time in milliseconds. - remote_round_trip_time: Option, - - /// The cumulative total round trip times reported in ms. - remote_total_round_trip_time: f64, - - /// The total number of measurements of the remote round trip time. - remote_round_trip_time_measurements: u64, - - /// The fraction of packets lost reported for this stream. - /// Calculated as defined in [RFC3550](https://www.rfc-editor.org/rfc/rfc3550) section 6.4.1 and Appendix A.3. - remote_fraction_lost: Option, - } - - impl StatsSnapshot { - pub fn packets_sent(&self) -> u64 { - self.rtp_stats.packets - } - - pub fn payload_bytes_sent(&self) -> u64 { - self.rtp_stats.payload_bytes - } - - pub fn header_bytes_sent(&self) -> u64 { - self.rtp_stats.header_bytes - } - - pub fn last_packet_sent_timestamp(&self) -> Option { - self.rtp_stats.last_packet_timestamp - } - - pub fn nacks_received(&self) -> u64 { - self.rtcp_stats.nack_count - } - - pub fn firs_received(&self) -> u64 { - self.rtcp_stats.fir_count - } - - pub fn plis_received(&self) -> u64 { - self.rtcp_stats.pli_count - } - - /// Packets received on the remote side. - pub fn remote_packets_received(&self) -> u64 { - self.remote_packets_received - } - - /// The number of lost packets reported by the remote for this tream. - pub fn remote_total_lost(&self) -> u32 { - self.remote_total_lost - } - - /// The estimated remote jitter for this stream in timestamp units. - pub fn remote_jitter(&self) -> u32 { - self.remote_jitter - } - - /// The latest RTT in ms if enough data is available to measure it. - pub fn remote_round_trip_time(&self) -> Option { - self.remote_round_trip_time - } - - /// Total RTT in ms. - pub fn remote_total_round_trip_time(&self) -> f64 { - self.remote_total_round_trip_time - } - - /// The number of RTT measurements so far. - pub fn remote_round_trip_time_measurements(&self) -> u64 { - self.remote_round_trip_time_measurements - } - - /// The latest fraction lost value from the remote or None if it hasn't been reported yet. 
- pub fn remote_fraction_lost(&self) -> Option { - self.remote_fraction_lost - } - } - - impl From<&StreamStats> for StatsSnapshot { - fn from(stream_stats: &StreamStats) -> Self { - Self { - rtp_stats: stream_stats.rtp_stats.clone(), - rtcp_stats: stream_stats.rtcp_stats.clone(), - remote_packets_received: stream_stats.remote_packets_received, - remote_total_lost: stream_stats.remote_total_lost, - remote_jitter: stream_stats.remote_jitter, - remote_round_trip_time: stream_stats.remote_round_trip_time, - remote_total_round_trip_time: stream_stats.remote_total_round_trip_time, - remote_round_trip_time_measurements: stream_stats - .remote_round_trip_time_measurements, - remote_fraction_lost: stream_stats - .remote_fraction_lost - .map(|fraction| (fraction as f64) / (u8::MAX as f64)), - } - } - } -} - -#[derive(Default, Debug)] -struct StatsContainer { - inbound_stats: HashMap, - outbound_stats: HashMap, -} - -impl StatsContainer { - fn get_or_create_inbound_stream_stats(&mut self, ssrc: u32) -> &mut inbound::StreamStats { - self.inbound_stats.entry(ssrc).or_default() - } - - fn get_or_create_outbound_stream_stats(&mut self, ssrc: u32) -> &mut outbound::StreamStats { - self.outbound_stats.entry(ssrc).or_default() - } - - fn get_inbound_stats(&self, ssrc: u32) -> Option<&inbound::StreamStats> { - self.inbound_stats.get(&ssrc) - } - - fn get_outbound_stats(&self, ssrc: u32) -> Option<&outbound::StreamStats> { - self.outbound_stats.get(&ssrc) - } - - fn remove_stale_entries(&mut self) { - const MAX_AGE: Duration = Duration::from_secs(60); - - self.inbound_stats - .retain(|_, s| s.duration_since_last_update() < MAX_AGE); - self.outbound_stats - .retain(|_, s| s.duration_since_last_update() < MAX_AGE); - } -} - -#[derive(Debug, Default, Clone, PartialEq, Eq)] -/// Records stats about a given RTP stream. -pub struct RTPStats { - /// Packets sent or received - packets: u64, - - /// Payload bytes sent or received - payload_bytes: u64, - - /// Header bytes sent or received - header_bytes: u64, - - /// A wall clock timestamp for when the last packet was sent or received encoded as milliseconds since - /// [`SystemTime::UNIX_EPOCH`]. 
- last_packet_timestamp: Option, -} - -impl RTPStats { - fn update(&mut self, header_bytes: u64, payload_bytes: u64, packets: u64, now: SystemTime) { - self.header_bytes += header_bytes; - self.payload_bytes += payload_bytes; - self.packets += packets; - self.last_packet_timestamp = Some(now); - } - - pub fn header_bytes(&self) -> u64 { - self.header_bytes - } - - pub fn payload_bytes(&self) -> u64 { - self.payload_bytes - } - - pub fn packets(&self) -> u64 { - self.packets - } - - pub fn last_packet_timestamp(&self) -> Option { - self.last_packet_timestamp - } -} - -#[derive(Debug, Default, Clone)] -pub struct RTCPStats { - /// The number of FIRs sent or received - fir_count: u64, - - /// The number of PLIs sent or received - pli_count: u64, - - /// The number of NACKs sent or received - nack_count: u64, -} - -impl RTCPStats { - #[allow(clippy::too_many_arguments)] - fn update(&mut self, fir_count: Option, pli_count: Option, nack_count: Option) { - if let Some(fir_count) = fir_count { - self.fir_count += fir_count; - } - - if let Some(pli_count) = pli_count { - self.pli_count += pli_count; - } - - if let Some(nack_count) = nack_count { - self.nack_count += nack_count; - } - } - - pub fn fir_count(&self) -> u64 { - self.fir_count - } - - pub fn pli_count(&self) -> u64 { - self.pli_count - } - - pub fn nack_count(&self) -> u64 { - self.nack_count - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_rtp_stats() { - let mut stats: RTPStats = Default::default(); - assert_eq!( - (stats.header_bytes(), stats.payload_bytes(), stats.packets()), - (0, 0, 0), - ); - - stats.update(24, 960, 1, SystemTime::now()); - - assert_eq!( - (stats.header_bytes(), stats.payload_bytes(), stats.packets()), - (24, 960, 1), - ); - } - - #[test] - fn test_rtcp_stats() { - let mut stats: RTCPStats = Default::default(); - assert_eq!( - (stats.fir_count(), stats.pli_count(), stats.nack_count()), - (0, 0, 0), - ); - - stats.update(Some(1), Some(2), Some(3)); - - assert_eq!( - (stats.fir_count(), stats.pli_count(), stats.nack_count()), - (1, 2, 3), - ); - } - - #[test] - fn test_rtp_stats_send_sync() { - fn test_send_sync() {} - test_send_sync::(); - } - - #[test] - fn test_rtcp_stats_send_sync() { - fn test_send_sync() {} - test_send_sync::(); - } -} diff --git a/interceptor/src/stream_info.rs b/interceptor/src/stream_info.rs deleted file mode 100644 index 5e9f93d5e..000000000 --- a/interceptor/src/stream_info.rs +++ /dev/null @@ -1,37 +0,0 @@ -use crate::Attributes; - -/// RTPHeaderExtension represents a negotiated RFC5285 RTP header extension. -#[derive(Default, Debug, Clone)] -pub struct RTPHeaderExtension { - pub uri: String, - pub id: isize, -} - -/// StreamInfo is the Context passed when a StreamLocal or StreamRemote has been Binded or Unbinded -#[derive(Default, Debug, Clone)] -pub struct StreamInfo { - pub id: String, - pub attributes: Attributes, - pub ssrc: u32, - pub payload_type: u8, - pub rtp_header_extensions: Vec, - pub mime_type: String, - pub clock_rate: u32, - pub channels: u16, - pub sdp_fmtp_line: String, - pub rtcp_feedback: Vec, -} - -/// RTCPFeedback signals the connection to use additional RTCP packet types. -/// -#[derive(Default, Debug, Clone)] -pub struct RTCPFeedback { - /// Type is the type of feedback. - /// see: - /// valid: ack, ccm, nack, goog-remb, transport-cc - pub typ: String, - - /// The parameter value depends on the type. - /// For example, type="nack" parameter="pli" will send Picture Loss Indicator packets. 
- pub parameter: String, -} diff --git a/interceptor/src/stream_reader.rs b/interceptor/src/stream_reader.rs deleted file mode 100644 index 4f4578631..000000000 --- a/interceptor/src/stream_reader.rs +++ /dev/null @@ -1,27 +0,0 @@ -use async_trait::async_trait; -use srtp::stream::Stream; - -use crate::error::Result; -use crate::{Attributes, RTCPReader, RTPReader}; - -#[async_trait] -impl RTPReader for Stream { - async fn read( - &self, - buf: &mut [u8], - a: &Attributes, - ) -> Result<(rtp::packet::Packet, Attributes)> { - Ok((self.read_rtp(buf).await?, a.clone())) - } -} - -#[async_trait] -impl RTCPReader for Stream { - async fn read( - &self, - buf: &mut [u8], - a: &Attributes, - ) -> Result<(Vec>, Attributes)> { - Ok((self.read_rtcp(buf).await?, a.clone())) - } -} diff --git a/interceptor/src/twcc/mod.rs b/interceptor/src/twcc/mod.rs deleted file mode 100644 index ed3fb10dc..000000000 --- a/interceptor/src/twcc/mod.rs +++ /dev/null @@ -1,279 +0,0 @@ -#[cfg(test)] -mod twcc_test; - -pub mod receiver; -pub mod sender; - -use std::cmp::Ordering; - -use rtcp::transport_feedbacks::transport_layer_cc::{ - PacketStatusChunk, RecvDelta, RunLengthChunk, StatusChunkTypeTcc, StatusVectorChunk, - SymbolSizeTypeTcc, SymbolTypeTcc, TransportLayerCc, -}; - -#[derive(Default, Debug, PartialEq, Clone)] -struct PktInfo { - sequence_number: u32, - arrival_time: i64, -} - -/// Recorder records incoming RTP packets and their delays and creates -/// transport wide congestion control feedback reports as specified in -/// -#[derive(Default, Debug, PartialEq, Clone)] -pub struct Recorder { - received_packets: Vec, - - cycles: u32, - last_sequence_number: u16, - - sender_ssrc: u32, - media_ssrc: u32, - fb_pkt_cnt: u8, -} - -impl Recorder { - /// new creates a new Recorder which uses the given sender_ssrc in the created - /// feedback packets. - pub fn new(sender_ssrc: u32) -> Self { - Recorder { - sender_ssrc, - ..Default::default() - } - } - - /// record marks a packet with media_ssrc and a transport wide sequence number sequence_number as received at arrival_time. - pub fn record(&mut self, media_ssrc: u32, sequence_number: u16, arrival_time: i64) { - self.media_ssrc = media_ssrc; - if sequence_number < 0x0fff && self.last_sequence_number > 0xf000 { - self.cycles += 1 << 16; - } - self.received_packets.push(PktInfo { - sequence_number: self.cycles | sequence_number as u32, - arrival_time, - }); - self.last_sequence_number = sequence_number; - } - - /// build_feedback_packet creates a new RTCP packet containing a TWCC feedback report. 
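// Sketch of the sequence-number unwrapping done by Recorder::record above:
// transport-wide sequence numbers are only 16 bits, so a cycle counter kept in
// the upper bits turns them into monotonically increasing 32-bit values across
// wraparound. Standalone illustration using the same wrap-detection heuristic
// as the code.
struct Unwrapper {
    cycles: u32,
    last: u16,
}

impl Unwrapper {
    fn new() -> Self {
        Self { cycles: 0, last: 0 }
    }

    fn extend(&mut self, seq: u16) -> u32 {
        // A small new value right after a large old one means the 16-bit
        // counter wrapped, so bump the cycle count by 1 << 16.
        if seq < 0x0fff && self.last > 0xf000 {
            self.cycles += 1 << 16;
        }
        self.last = seq;
        self.cycles | seq as u32
    }
}

fn main() {
    let mut u = Unwrapper::new();
    assert_eq!(u.extend(65530), 65530);
    assert_eq!(u.extend(65535), 65535);
    assert_eq!(u.extend(1), (1 << 16) | 1); // wrapped: now in the second cycle
    assert_eq!(u.extend(10), (1 << 16) | 10);
}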
- pub fn build_feedback_packet(&mut self) -> Vec> { - if self.received_packets.len() < 2 { - return vec![]; - } - let mut feedback = Feedback::new(self.sender_ssrc, self.media_ssrc, self.fb_pkt_cnt); - self.fb_pkt_cnt = self.fb_pkt_cnt.wrapping_add(1); - - self.received_packets - .sort_by(|a: &PktInfo, b: &PktInfo| -> Ordering { - a.sequence_number.cmp(&b.sequence_number) - }); - feedback.set_base( - (self.received_packets[0].sequence_number & 0xffff) as u16, - self.received_packets[0].arrival_time, - ); - - let mut pkts = vec![]; - for pkt in &self.received_packets { - let built = - feedback.add_received((pkt.sequence_number & 0xffff) as u16, pkt.arrival_time); - if !built { - let p: Box = Box::new(feedback.get_rtcp()); - pkts.push(p); - feedback = Feedback::new(self.sender_ssrc, self.media_ssrc, self.fb_pkt_cnt); - self.fb_pkt_cnt = self.fb_pkt_cnt.wrapping_add(1); - feedback.add_received((pkt.sequence_number & 0xffff) as u16, pkt.arrival_time); - } - } - self.received_packets.clear(); - let p: Box = Box::new(feedback.get_rtcp()); - pkts.push(p); - pkts - } -} - -#[derive(Default, Debug, PartialEq, Clone)] -struct Feedback { - rtcp: TransportLayerCc, - base_sequence_number: u16, - ref_timestamp64ms: i64, - last_timestamp_us: i64, - next_sequence_number: u16, - sequence_number_count: u16, - len: usize, - last_chunk: Chunk, - chunks: Vec, - deltas: Vec, -} - -impl Feedback { - fn new(sender_ssrc: u32, media_ssrc: u32, fb_pkt_count: u8) -> Self { - Feedback { - rtcp: TransportLayerCc { - sender_ssrc, - media_ssrc, - fb_pkt_count, - ..Default::default() - }, - ..Default::default() - } - } - - fn set_base(&mut self, sequence_number: u16, time_us: i64) { - self.base_sequence_number = sequence_number; - self.next_sequence_number = self.base_sequence_number; - self.ref_timestamp64ms = time_us / 64000; - self.last_timestamp_us = self.ref_timestamp64ms * 64000; - } - - fn get_rtcp(&mut self) -> TransportLayerCc { - self.rtcp.packet_status_count = self.sequence_number_count; - self.rtcp.reference_time = self.ref_timestamp64ms as u32; - self.rtcp.base_sequence_number = self.base_sequence_number; - while !self.last_chunk.deltas.is_empty() { - self.chunks.push(self.last_chunk.encode()); - } - self.rtcp.packet_chunks.extend_from_slice(&self.chunks); - self.rtcp.recv_deltas.clone_from(&self.deltas); - - self.rtcp.clone() - } - - fn add_received(&mut self, sequence_number: u16, timestamp_us: i64) -> bool { - let delta_us = timestamp_us - self.last_timestamp_us; - let delta250us = delta_us / 250; - if delta250us < i16::MIN as i64 || delta250us > i16::MAX as i64 { - // delta doesn't fit into 16 bit, need to create new packet - return false; - } - - while self.next_sequence_number != sequence_number { - if !self - .last_chunk - .can_add(SymbolTypeTcc::PacketNotReceived as u16) - { - self.chunks.push(self.last_chunk.encode()); - } - self.last_chunk.add(SymbolTypeTcc::PacketNotReceived as u16); - self.sequence_number_count = self.sequence_number_count.wrapping_add(1); - self.next_sequence_number = self.next_sequence_number.wrapping_add(1); - } - - let recv_delta = if (0..=0xff).contains(&delta250us) { - self.len += 1; - SymbolTypeTcc::PacketReceivedSmallDelta - } else { - self.len += 2; - SymbolTypeTcc::PacketReceivedLargeDelta - }; - - if !self.last_chunk.can_add(recv_delta as u16) { - self.chunks.push(self.last_chunk.encode()); - } - self.last_chunk.add(recv_delta as u16); - self.deltas.push(RecvDelta { - type_tcc_packet: recv_delta, - delta: delta_us, - }); - self.last_timestamp_us = timestamp_us; - 
self.sequence_number_count = self.sequence_number_count.wrapping_add(1); - self.next_sequence_number = self.next_sequence_number.wrapping_add(1); - true - } -} - -const MAX_RUN_LENGTH_CAP: usize = 0x1fff; // 13 bits -const MAX_ONE_BIT_CAP: usize = 14; // bits -const MAX_TWO_BIT_CAP: usize = 7; // bits - -#[derive(Default, Debug, PartialEq, Clone)] -struct Chunk { - has_large_delta: bool, - has_different_types: bool, - deltas: Vec, -} - -impl Chunk { - fn can_add(&self, delta: u16) -> bool { - if self.deltas.len() < MAX_TWO_BIT_CAP { - return true; - } - if self.deltas.len() < MAX_ONE_BIT_CAP - && !self.has_large_delta - && delta != SymbolTypeTcc::PacketReceivedLargeDelta as u16 - { - return true; - } - if self.deltas.len() < MAX_RUN_LENGTH_CAP - && !self.has_different_types - && delta == self.deltas[0] - { - return true; - } - false - } - - fn add(&mut self, delta: u16) { - self.deltas.push(delta); - self.has_large_delta = - self.has_large_delta || delta == SymbolTypeTcc::PacketReceivedLargeDelta as u16; - self.has_different_types = self.has_different_types || delta != self.deltas[0]; - } - - fn encode(&mut self) -> PacketStatusChunk { - if !self.has_different_types { - let p = PacketStatusChunk::RunLengthChunk(RunLengthChunk { - type_tcc: StatusChunkTypeTcc::RunLengthChunk, - packet_status_symbol: self.deltas[0].into(), - run_length: self.deltas.len() as u16, - }); - self.reset(); - return p; - } - if self.deltas.len() == MAX_ONE_BIT_CAP { - let p = PacketStatusChunk::StatusVectorChunk(StatusVectorChunk { - type_tcc: StatusChunkTypeTcc::StatusVectorChunk, - symbol_size: SymbolSizeTypeTcc::OneBit, - symbol_list: self - .deltas - .iter() - .map(|x| SymbolTypeTcc::from(*x)) - .collect::>(), - }); - self.reset(); - return p; - } - - let min_cap = std::cmp::min(MAX_TWO_BIT_CAP, self.deltas.len()); - let svc = PacketStatusChunk::StatusVectorChunk(StatusVectorChunk { - type_tcc: StatusChunkTypeTcc::StatusVectorChunk, - symbol_size: SymbolSizeTypeTcc::TwoBit, - symbol_list: self.deltas[..min_cap] - .iter() - .map(|x| SymbolTypeTcc::from(*x)) - .collect::>(), - }); - self.deltas.drain(..min_cap); - self.has_different_types = false; - self.has_large_delta = false; - - if !self.deltas.is_empty() { - let tmp = self.deltas[0]; - for d in &self.deltas { - if tmp != *d { - self.has_different_types = true; - } - if *d == SymbolTypeTcc::PacketReceivedLargeDelta as u16 { - self.has_large_delta = true; - } - } - } - - svc - } - - fn reset(&mut self) { - self.deltas = vec![]; - self.has_large_delta = false; - self.has_different_types = false; - } -} diff --git a/interceptor/src/twcc/receiver/mod.rs b/interceptor/src/twcc/receiver/mod.rs deleted file mode 100644 index aaaf81369..000000000 --- a/interceptor/src/twcc/receiver/mod.rs +++ /dev/null @@ -1,262 +0,0 @@ -mod receiver_stream; -#[cfg(test)] -mod receiver_test; - -use std::time::Duration; - -use receiver_stream::ReceiverStream; -use rtp::extension::transport_cc_extension::TransportCcExtension; -use tokio::sync::{mpsc, Mutex}; -use tokio::time::MissedTickBehavior; -use util::Unmarshal; -use waitgroup::WaitGroup; - -use crate::twcc::sender::TRANSPORT_CC_URI; -use crate::twcc::Recorder; -use crate::*; - -/// ReceiverBuilder is a InterceptorBuilder for a SenderInterceptor -#[derive(Default)] -pub struct ReceiverBuilder { - interval: Option, -} - -impl ReceiverBuilder { - /// with_interval sets send interval for the interceptor. 
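// Sketch of the chunk-type decision made by Chunk::encode above: a run of
// identical symbols becomes a run-length chunk (13-bit length), while mixed
// symbols become a status-vector chunk holding up to 14 one-bit symbols
// (received/not received only) or up to 7 two-bit symbols once a large delta
// has to be represented. Simplified standalone types, not the rtcp crate's.
#[derive(Clone, Copy, PartialEq, Debug)]
enum Symbol {
    NotReceived,
    ReceivedSmallDelta,
    ReceivedLargeDelta,
}

#[derive(Debug, PartialEq)]
enum Chunk {
    RunLength { symbol: Symbol, run_length: u16 },
    OneBitVector { received: Vec<bool> },  // up to 14 symbols
    TwoBitVector { symbols: Vec<Symbol> }, // up to 7 symbols
}

fn encode(symbols: &[Symbol]) -> Chunk {
    let uniform = symbols.iter().all(|s| *s == symbols[0]);
    let has_large = symbols.iter().any(|s| *s == Symbol::ReceivedLargeDelta);
    if uniform {
        Chunk::RunLength { symbol: symbols[0], run_length: symbols.len() as u16 }
    } else if !has_large && symbols.len() <= 14 {
        Chunk::OneBitVector {
            received: symbols.iter().map(|s| *s != Symbol::NotReceived).collect(),
        }
    } else {
        Chunk::TwoBitVector { symbols: symbols[..symbols.len().min(7)].to_vec() }
    }
}

fn main() {
    use Symbol::*;
    assert_eq!(
        encode(&[ReceivedSmallDelta; 10]),
        Chunk::RunLength { symbol: ReceivedSmallDelta, run_length: 10 }
    );
    assert!(matches!(
        encode(&[ReceivedSmallDelta, NotReceived, ReceivedLargeDelta]),
        Chunk::TwoBitVector { .. }
    ));
}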
- pub fn with_interval(mut self, interval: Duration) -> ReceiverBuilder { - self.interval = Some(interval); - self - } -} - -impl InterceptorBuilder for ReceiverBuilder { - fn build(&self, _id: &str) -> Result> { - let (close_tx, close_rx) = mpsc::channel(1); - let (packet_chan_tx, packet_chan_rx) = mpsc::channel(1); - Ok(Arc::new(Receiver { - internal: Arc::new(ReceiverInternal { - interval: if let Some(interval) = &self.interval { - *interval - } else { - Duration::from_millis(100) - }, - recorder: Mutex::new(Recorder::default()), - packet_chan_rx: Mutex::new(Some(packet_chan_rx)), - streams: Mutex::new(HashMap::new()), - close_rx: Mutex::new(Some(close_rx)), - }), - start_time: tokio::time::Instant::now(), - packet_chan_tx, - wg: Mutex::new(Some(WaitGroup::new())), - close_tx: Mutex::new(Some(close_tx)), - })) - } -} - -struct Packet { - hdr: rtp::header::Header, - sequence_number: u16, - arrival_time: i64, - ssrc: u32, -} - -struct ReceiverInternal { - interval: Duration, - recorder: Mutex, - packet_chan_rx: Mutex>>, - streams: Mutex>>, - close_rx: Mutex>>, -} - -/// Receiver sends transport-wide congestion control reports as specified in: -/// -pub struct Receiver { - internal: Arc, - - // we use tokio's Instant because it makes testing easier via `tokio::time::advance`. - start_time: tokio::time::Instant, - packet_chan_tx: mpsc::Sender, - - wg: Mutex>, - close_tx: Mutex>>, -} - -impl Receiver { - /// builder returns a new ReceiverBuilder. - pub fn builder() -> ReceiverBuilder { - ReceiverBuilder::default() - } - - async fn is_closed(&self) -> bool { - let close_tx = self.close_tx.lock().await; - close_tx.is_none() - } - - async fn run( - rtcp_writer: Arc, - internal: Arc, - ) -> Result<()> { - let mut close_rx = { - let mut close_rx = internal.close_rx.lock().await; - if let Some(close_rx) = close_rx.take() { - close_rx - } else { - return Err(Error::ErrInvalidCloseRx); - } - }; - let mut packet_chan_rx = { - let mut packet_chan_rx = internal.packet_chan_rx.lock().await; - if let Some(packet_chan_rx) = packet_chan_rx.take() { - packet_chan_rx - } else { - return Err(Error::ErrInvalidPacketRx); - } - }; - - let a = Attributes::new(); - let mut ticker = tokio::time::interval(internal.interval); - ticker.set_missed_tick_behavior(MissedTickBehavior::Skip); - loop { - tokio::select! { - _ = close_rx.recv() =>{ - return Ok(()); - } - p = packet_chan_rx.recv() => { - if let Some(p) = p { - let mut recorder = internal.recorder.lock().await; - recorder.record(p.ssrc, p.sequence_number, p.arrival_time); - } - } - _ = ticker.tick() =>{ - // build and send twcc - let pkts = { - let mut recorder = internal.recorder.lock().await; - recorder.build_feedback_packet() - }; - - if pkts.is_empty() { - continue; - } - - if let Err(err) = rtcp_writer.write(&pkts, &a).await{ - log::error!("rtcp_writer.write got err: {}", err); - } - } - } - } - } -} - -#[async_trait] -impl Interceptor for Receiver { - /// bind_rtcp_reader lets you modify any incoming RTCP packets. It is called once per sender/receiver, however this might - /// change in the future. The returned method will be called once per packet batch. - async fn bind_rtcp_reader( - &self, - reader: Arc, - ) -> Arc { - reader - } - - /// bind_rtcp_writer lets you modify any outgoing RTCP packets. It is called once per PeerConnection. The returned method - /// will be called once per packet batch. 
- async fn bind_rtcp_writer( - &self, - writer: Arc, - ) -> Arc { - if self.is_closed().await { - return writer; - } - - { - let mut recorder = self.internal.recorder.lock().await; - *recorder = Recorder::new(rand::random::()); - } - - let mut w = { - let wait_group = self.wg.lock().await; - wait_group.as_ref().map(|wg| wg.worker()) - }; - let writer2 = Arc::clone(&writer); - let internal = Arc::clone(&self.internal); - tokio::spawn(async move { - let _d = w.take(); - if let Err(err) = Receiver::run(writer2, internal).await { - log::warn!("bind_rtcp_writer TWCC Sender::run got error: {}", err); - } - }); - - writer - } - - /// bind_local_stream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method - /// will be called once per rtp packet. - async fn bind_local_stream( - &self, - _info: &StreamInfo, - writer: Arc, - ) -> Arc { - writer - } - - /// unbind_local_stream is called when the Stream is removed. It can be used to clean up any data related to that track. - async fn unbind_local_stream(&self, _info: &StreamInfo) {} - - /// bind_remote_stream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method - /// will be called once per rtp packet. - async fn bind_remote_stream( - &self, - info: &StreamInfo, - reader: Arc, - ) -> Arc { - let mut hdr_ext_id = 0u8; - for e in &info.rtp_header_extensions { - if e.uri == TRANSPORT_CC_URI { - hdr_ext_id = e.id as u8; - break; - } - } - if hdr_ext_id == 0 { - // Don't try to read header extension if ID is 0, because 0 is an invalid extension ID - return reader; - } - - let stream = Arc::new(ReceiverStream::new( - reader, - hdr_ext_id, - info.ssrc, - self.packet_chan_tx.clone(), - self.start_time, - )); - - { - let mut streams = self.internal.streams.lock().await; - streams.insert(info.ssrc, Arc::clone(&stream)); - } - - stream - } - - /// unbind_remote_stream is called when the Stream is removed. It can be used to clean up any data related to that track. - async fn unbind_remote_stream(&self, info: &StreamInfo) { - let mut streams = self.internal.streams.lock().await; - streams.remove(&info.ssrc); - } - - /// close closes the Interceptor, cleaning up any data if necessary. - async fn close(&self) -> Result<()> { - { - let mut close_tx = self.close_tx.lock().await; - close_tx.take(); - } - - { - let mut wait_group = self.wg.lock().await; - if let Some(wg) = wait_group.take() { - wg.wait().await; - } - } - - Ok(()) - } -} diff --git a/interceptor/src/twcc/receiver/receiver_stream.rs b/interceptor/src/twcc/receiver/receiver_stream.rs deleted file mode 100644 index 764b26c9d..000000000 --- a/interceptor/src/twcc/receiver/receiver_stream.rs +++ /dev/null @@ -1,57 +0,0 @@ -use super::*; - -pub(super) struct ReceiverStream { - parent_rtp_reader: Arc, - hdr_ext_id: u8, - ssrc: u32, - packet_chan_tx: mpsc::Sender, - // we use tokio's Instant because it makes testing easier via `tokio::time::advance`. 
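// End-to-end illustration of the header-extension round trip used by the TWCC
// streams in this hunk: marshal a transport-wide sequence number into extension
// ID 1, then read it back the way ReceiverStream::read does. It uses the rtp
// and util crates exactly as the tests in this diff do; ID 1 is an arbitrary
// example value, in practice it comes from the negotiated StreamInfo.
use rtp::extension::transport_cc_extension::TransportCcExtension;
use util::{Marshal, Unmarshal};

fn main() {
    let mut hdr = rtp::header::Header::default();
    let tcc = TransportCcExtension { transport_sequence: 42 }
        .marshal()
        .expect("marshal transport-cc extension");
    hdr.set_extension(1, tcc).expect("set extension id 1");

    let mut ext = hdr.get_extension(1).expect("extension id 1 was just set");
    let parsed = TransportCcExtension::unmarshal(&mut ext).expect("unmarshal extension");
    assert_eq!(parsed.transport_sequence, 42);
}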
- start_time: tokio::time::Instant, -} - -impl ReceiverStream { - pub(super) fn new( - parent_rtp_reader: Arc, - hdr_ext_id: u8, - ssrc: u32, - packet_chan_tx: mpsc::Sender, - start_time: tokio::time::Instant, - ) -> Self { - ReceiverStream { - parent_rtp_reader, - hdr_ext_id, - ssrc, - packet_chan_tx, - start_time, - } - } -} - -#[async_trait] -impl RTPReader for ReceiverStream { - /// read a rtp packet - async fn read( - &self, - buf: &mut [u8], - attributes: &Attributes, - ) -> Result<(rtp::packet::Packet, Attributes)> { - let (pkt, attr) = self.parent_rtp_reader.read(buf, attributes).await?; - - if let Some(mut ext) = pkt.header.get_extension(self.hdr_ext_id) { - let tcc_ext = TransportCcExtension::unmarshal(&mut ext)?; - - let _ = self - .packet_chan_tx - .send(Packet { - hdr: pkt.header.clone(), - sequence_number: tcc_ext.transport_sequence, - arrival_time: (tokio::time::Instant::now() - self.start_time).as_micros() - as i64, - ssrc: self.ssrc, - }) - .await; - } - - Ok((pkt, attr)) - } -} diff --git a/interceptor/src/twcc/receiver/receiver_test.rs b/interceptor/src/twcc/receiver/receiver_test.rs deleted file mode 100644 index 24f60f184..000000000 --- a/interceptor/src/twcc/receiver/receiver_test.rs +++ /dev/null @@ -1,361 +0,0 @@ -// Silence warning on `..Default::default()` with no effect: -#![allow(clippy::needless_update)] - -use rtcp::transport_feedbacks::transport_layer_cc::{ - PacketStatusChunk, RunLengthChunk, StatusChunkTypeTcc, StatusVectorChunk, SymbolSizeTypeTcc, - SymbolTypeTcc, TransportLayerCc, -}; -use util::Marshal; - -use super::*; -use crate::mock::mock_stream::MockStream; -use crate::stream_info::RTPHeaderExtension; - -#[tokio::test] -async fn test_twcc_receiver_interceptor_before_any_packets() -> Result<()> { - let builder = Receiver::builder(); - let icpr = builder.build("")?; - - let stream = MockStream::new( - &StreamInfo { - ssrc: 1, - rtp_header_extensions: vec![RTPHeaderExtension { - uri: TRANSPORT_CC_URI.to_owned(), - id: 1, - ..Default::default() - }], - ..Default::default() - }, - icpr, - ) - .await; - - tokio::select! 
{ - pkts = stream.written_rtcp() => { - assert!(pkts.map(|p| p.is_empty()).unwrap_or(true), "Should not have sent an RTCP packet before receiving the first RTP packets") - } - _ = tokio::time::sleep(Duration::from_millis(300)) => { - // All good - } - } - - stream.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_twcc_receiver_interceptor_after_rtp_packets() -> Result<()> { - let builder = Receiver::builder(); - let icpr = builder.build("")?; - - let stream = MockStream::new( - &StreamInfo { - ssrc: 1, - rtp_header_extensions: vec![RTPHeaderExtension { - uri: TRANSPORT_CC_URI.to_owned(), - id: 1, - ..Default::default() - }], - ..Default::default() - }, - icpr, - ) - .await; - - for i in 0..10 { - let mut hdr = rtp::header::Header::default(); - let tcc = TransportCcExtension { - transport_sequence: i, - } - .marshal()?; - hdr.set_extension(1, tcc)?; - stream - .receive_rtp(rtp::packet::Packet { - header: hdr, - ..Default::default() - }) - .await; - } - - let pkts = stream.written_rtcp().await.unwrap(); - assert_eq!(pkts.len(), 1); - if let Some(cc) = pkts[0].as_any().downcast_ref::() { - assert_eq!(cc.media_ssrc, 1); - assert_eq!(cc.base_sequence_number, 0); - assert_eq!( - cc.packet_chunks, - vec![PacketStatusChunk::RunLengthChunk(RunLengthChunk { - type_tcc: StatusChunkTypeTcc::RunLengthChunk, - packet_status_symbol: SymbolTypeTcc::PacketReceivedSmallDelta, - run_length: 10, - })] - ); - } else { - panic!(); - } - - stream.close().await?; - - Ok(()) -} - -#[tokio::test(start_paused = true)] -async fn test_twcc_receiver_interceptor_different_delays_between_rtp_packets() -> Result<()> { - let builder = Receiver::builder().with_interval(Duration::from_millis(500)); - let icpr = builder.build("")?; - - let stream = MockStream::new( - &StreamInfo { - ssrc: 1, - rtp_header_extensions: vec![RTPHeaderExtension { - uri: TRANSPORT_CC_URI.to_owned(), - id: 1, - ..Default::default() - }], - ..Default::default() - }, - icpr, - ) - .await; - - let delays = [0, 10, 100, 200]; - for (i, d) in delays.iter().enumerate() { - tokio::time::advance(Duration::from_millis(*d)).await; - - let mut hdr = rtp::header::Header::default(); - let tcc = TransportCcExtension { - transport_sequence: i as u16, - } - .marshal()?; - - hdr.set_extension(1, tcc)?; - stream - .receive_rtp(rtp::packet::Packet { - header: hdr, - ..Default::default() - }) - .await; - - // Yield so this packet can be processed - tokio::task::yield_now().await; - } - - // Force a packet to be generated - tokio::time::advance(Duration::from_millis(2001)).await; - tokio::task::yield_now().await; - - let pkts = stream.written_rtcp().await.unwrap(); - - assert_eq!(pkts.len(), 1); - if let Some(cc) = pkts[0].as_any().downcast_ref::() { - assert_eq!(cc.base_sequence_number, 0); - assert_eq!( - cc.packet_chunks, - vec![PacketStatusChunk::StatusVectorChunk(StatusVectorChunk { - type_tcc: StatusChunkTypeTcc::StatusVectorChunk, - symbol_size: SymbolSizeTypeTcc::TwoBit, - symbol_list: vec![ - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedLargeDelta, - SymbolTypeTcc::PacketReceivedLargeDelta, - ], - })] - ); - } else { - panic!(); - } - - stream.close().await?; - - Ok(()) -} - -#[tokio::test(start_paused = true)] -async fn test_twcc_receiver_interceptor_packet_loss() -> Result<()> { - let builder = Receiver::builder().with_interval(Duration::from_secs(2)); - let icpr = builder.build("")?; - - let stream = MockStream::new( - &StreamInfo { - ssrc: 1, - rtp_header_extensions: 
vec![RTPHeaderExtension { - uri: TRANSPORT_CC_URI.to_owned(), - id: 1, - ..Default::default() - }], - ..Default::default() - }, - icpr, - ) - .await; - - let sequence_number_to_delay = &[ - (0, 0), - (1, 10), - (4, 100), - (8, 200), - (9, 20), - (10, 20), - (30, 300), - ]; - - for (i, d) in sequence_number_to_delay { - tokio::time::advance(Duration::from_millis(*d)).await; - let mut hdr = rtp::header::Header::default(); - let tcc = TransportCcExtension { - transport_sequence: *i, - } - .marshal()?; - hdr.set_extension(1, tcc)?; - stream - .receive_rtp(rtp::packet::Packet { - header: hdr, - ..Default::default() - }) - .await; - - // Yield so this packet can be processed - tokio::task::yield_now().await; - } - - // Force a packet to be generated - tokio::time::advance(Duration::from_millis(2001)).await; - tokio::task::yield_now().await; - - let pkts = stream.written_rtcp().await.unwrap(); - - assert_eq!(pkts.len(), 1); - if let Some(cc) = pkts[0].as_any().downcast_ref::() { - assert_eq!(cc.base_sequence_number, 0); - assert_eq!( - cc.packet_chunks, - vec![ - PacketStatusChunk::StatusVectorChunk(StatusVectorChunk { - type_tcc: StatusChunkTypeTcc::StatusVectorChunk, - symbol_size: SymbolSizeTypeTcc::TwoBit, - symbol_list: vec![ - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketReceivedLargeDelta, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - ], - }), - PacketStatusChunk::StatusVectorChunk(StatusVectorChunk { - type_tcc: StatusChunkTypeTcc::StatusVectorChunk, - symbol_size: SymbolSizeTypeTcc::TwoBit, - symbol_list: vec![ - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketReceivedLargeDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - ], - }), - PacketStatusChunk::RunLengthChunk(RunLengthChunk { - type_tcc: StatusChunkTypeTcc::RunLengthChunk, - packet_status_symbol: SymbolTypeTcc::PacketNotReceived, - run_length: 16, - }), - PacketStatusChunk::RunLengthChunk(RunLengthChunk { - type_tcc: StatusChunkTypeTcc::RunLengthChunk, - packet_status_symbol: SymbolTypeTcc::PacketReceivedLargeDelta, - run_length: 1, - }), - ] - ); - } else { - panic!(); - } - - stream.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_twcc_receiver_interceptor_overflow() -> Result<()> { - let builder = Receiver::builder(); - let icpr = builder.build("")?; - - let stream = MockStream::new( - &StreamInfo { - ssrc: 1, - rtp_header_extensions: vec![RTPHeaderExtension { - uri: TRANSPORT_CC_URI.to_owned(), - id: 1, - ..Default::default() - }], - ..Default::default() - }, - icpr, - ) - .await; - - for i in [65530, 65534, 65535, 1, 2, 10] { - let mut hdr = rtp::header::Header::default(); - let tcc = TransportCcExtension { - transport_sequence: i, - } - .marshal()?; - hdr.set_extension(1, tcc)?; - stream - .receive_rtp(rtp::packet::Packet { - header: hdr, - ..Default::default() - }) - .await; - } - - let pkts = stream.written_rtcp().await.unwrap(); - assert_eq!(pkts.len(), 1); - if let Some(cc) = pkts[0].as_any().downcast_ref::() { - assert_eq!(cc.base_sequence_number, 65530); - assert_eq!( - cc.packet_chunks, - vec![ - PacketStatusChunk::StatusVectorChunk(StatusVectorChunk { - type_tcc: StatusChunkTypeTcc::StatusVectorChunk, - symbol_size: SymbolSizeTypeTcc::OneBit, - symbol_list: vec![ - 
SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - ], - }), - PacketStatusChunk::StatusVectorChunk(StatusVectorChunk { - type_tcc: StatusChunkTypeTcc::StatusVectorChunk, - symbol_size: SymbolSizeTypeTcc::TwoBit, - symbol_list: vec![ - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketReceivedSmallDelta, - ], - }), - ] - ); - } else { - panic!(); - } - - stream.close().await?; - - Ok(()) -} diff --git a/interceptor/src/twcc/sender/mod.rs b/interceptor/src/twcc/sender/mod.rs deleted file mode 100644 index d3ed5673d..000000000 --- a/interceptor/src/twcc/sender/mod.rs +++ /dev/null @@ -1,132 +0,0 @@ -mod sender_stream; -#[cfg(test)] -mod sender_test; - -use std::sync::atomic::Ordering; -use std::sync::Arc; - -use portable_atomic::AtomicU32; -use rtp::extension::transport_cc_extension::TransportCcExtension; -use sender_stream::SenderStream; -use tokio::sync::Mutex; -use util::Marshal; - -use crate::{Attributes, RTPWriter, *}; - -pub(crate) const TRANSPORT_CC_URI: &str = - "http://www.ietf.org/id/draft-holmer-rmcat-transport-wide-cc-extensions-01"; - -/// HeaderExtensionBuilder is a InterceptorBuilder for a HeaderExtension Interceptor -#[derive(Default)] -pub struct SenderBuilder { - init_sequence_nr: u32, -} - -impl SenderBuilder { - /// with_init_sequence_nr sets the init sequence number of the interceptor. - pub fn with_init_sequence_nr(mut self, init_sequence_nr: u32) -> SenderBuilder { - self.init_sequence_nr = init_sequence_nr; - self - } -} - -impl InterceptorBuilder for SenderBuilder { - /// build constructs a new SenderInterceptor - fn build(&self, _id: &str) -> Result> { - Ok(Arc::new(Sender { - next_sequence_nr: Arc::new(AtomicU32::new(self.init_sequence_nr)), - streams: Mutex::new(HashMap::new()), - })) - } -} - -/// Sender adds transport wide sequence numbers as header extension to each RTP packet -pub struct Sender { - next_sequence_nr: Arc, - streams: Mutex>>, -} - -impl Sender { - /// builder returns a new SenderBuilder. - pub fn builder() -> SenderBuilder { - SenderBuilder::default() - } -} - -#[async_trait] -impl Interceptor for Sender { - /// bind_rtcp_reader lets you modify any incoming RTCP packets. It is called once per sender/receiver, however this might - /// change in the future. The returned method will be called once per packet batch. - async fn bind_rtcp_reader( - &self, - reader: Arc, - ) -> Arc { - reader - } - - /// bind_rtcp_writer lets you modify any outgoing RTCP packets. It is called once per PeerConnection. The returned method - /// will be called once per packet batch. - async fn bind_rtcp_writer( - &self, - writer: Arc, - ) -> Arc { - writer - } - - /// bind_local_stream returns a writer that adds a rtp TransportCCExtension - /// header with increasing sequence numbers to each outgoing packet. 
- async fn bind_local_stream( - &self, - info: &StreamInfo, - writer: Arc, - ) -> Arc { - let mut hdr_ext_id = 0u8; - for e in &info.rtp_header_extensions { - if e.uri == TRANSPORT_CC_URI { - hdr_ext_id = e.id as u8; - break; - } - } - if hdr_ext_id == 0 { - // Don't add header extension if ID is 0, because 0 is an invalid extension ID - return writer; - } - - let stream = Arc::new(SenderStream::new( - writer, - Arc::clone(&self.next_sequence_nr), - hdr_ext_id, - )); - - { - let mut streams = self.streams.lock().await; - streams.insert(info.ssrc, Arc::clone(&stream)); - } - - stream - } - - /// unbind_local_stream is called when the Stream is removed. It can be used to clean up any data related to that track. - async fn unbind_local_stream(&self, info: &StreamInfo) { - let mut streams = self.streams.lock().await; - streams.remove(&info.ssrc); - } - - /// bind_remote_stream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method - /// will be called once per rtp packet. - async fn bind_remote_stream( - &self, - _info: &StreamInfo, - reader: Arc, - ) -> Arc { - reader - } - - /// unbind_remote_stream is called when the Stream is removed. It can be used to clean up any data related to that track. - async fn unbind_remote_stream(&self, _info: &StreamInfo) {} - - /// close closes the Interceptor, cleaning up any data if necessary. - async fn close(&self) -> Result<()> { - Ok(()) - } -} diff --git a/interceptor/src/twcc/sender/sender_stream.rs b/interceptor/src/twcc/sender/sender_stream.rs deleted file mode 100644 index 29754070d..000000000 --- a/interceptor/src/twcc/sender/sender_stream.rs +++ /dev/null @@ -1,40 +0,0 @@ -use super::*; - -pub(super) struct SenderStream { - next_rtp_writer: Arc, - next_sequence_nr: Arc, - hdr_ext_id: u8, -} - -impl SenderStream { - pub(super) fn new( - next_rtp_writer: Arc, - next_sequence_nr: Arc, - hdr_ext_id: u8, - ) -> Self { - SenderStream { - next_rtp_writer, - next_sequence_nr, - hdr_ext_id, - } - } -} - -/// RTPWriter is used by Interceptor.bind_local_stream. 
-#[async_trait] -impl RTPWriter for SenderStream { - /// write a rtp packet - async fn write(&self, pkt: &rtp::packet::Packet, a: &Attributes) -> Result { - let sequence_number = self.next_sequence_nr.fetch_add(1, Ordering::SeqCst); - - let tcc_ext = TransportCcExtension { - transport_sequence: sequence_number as u16, - }; - let tcc_payload = tcc_ext.marshal()?; - - let mut pkt = pkt.clone(); - pkt.header.set_extension(self.hdr_ext_id, tcc_payload)?; - - self.next_rtp_writer.write(&pkt, a).await - } -} diff --git a/interceptor/src/twcc/sender/sender_test.rs b/interceptor/src/twcc/sender/sender_test.rs deleted file mode 100644 index 4ae519792..000000000 --- a/interceptor/src/twcc/sender/sender_test.rs +++ /dev/null @@ -1,85 +0,0 @@ -use rtp::packet::Packet; -use tokio::sync::mpsc; -use tokio::time::Duration; -use util::Unmarshal; -use waitgroup::WaitGroup; - -use super::*; -use crate::mock::mock_stream::MockStream; -use crate::stream_info::RTPHeaderExtension; - -#[tokio::test] -async fn test_twcc_sender_interceptor() -> Result<()> { - // "add transport wide cc to each packet" - let builder = Sender::builder().with_init_sequence_nr(0); - let icpr = builder.build("")?; - - let (p_chan_tx, mut p_chan_rx) = mpsc::channel::(10 * 5); - tokio::spawn(async move { - // start some parallel streams using the same interceptor to test for race conditions - let wg = WaitGroup::new(); - for i in 0..10 { - let w = wg.worker(); - let p_chan_tx2 = p_chan_tx.clone(); - let icpr2 = Arc::clone(&icpr); - tokio::spawn(async move { - let _d = w; - let stream = MockStream::new( - &StreamInfo { - rtp_header_extensions: vec![RTPHeaderExtension { - uri: TRANSPORT_CC_URI.to_owned(), - id: 1, - }], - ..Default::default() - }, - icpr2, - ) - .await; - - let id = i + 1; - #[allow(clippy::identity_op)] - for seq_num in [id * 1, id * 2, id * 3, id * 4, id * 5] { - stream - .write_rtp(&rtp::packet::Packet { - header: rtp::header::Header { - sequence_number: seq_num, - ..Default::default() - }, - ..Default::default() - }) - .await - .unwrap(); - - let timeout = tokio::time::sleep(Duration::from_millis(10)); - tokio::pin!(timeout); - - tokio::select! { - p = stream.written_rtp() =>{ - if let Some(p) = p { - assert_eq!(p.header.sequence_number, seq_num); - let _ = p_chan_tx2.send(p).await; - }else{ - panic!("stream.written_rtp none"); - } - } - _ = timeout.as_mut()=>{ - panic!("written rtp packet not found"); - } - }; - } - - let _ = stream.close().await; - }); - } - wg.wait().await; - }); - - while let Some(p) = p_chan_rx.recv().await { - // Can't check for increasing transport cc sequence number, since we can't ensure ordering between the streams - // on pChan is same as in the interceptor, but at least make sure each packet has a seq nr. 
- let mut extension_header = p.header.get_extension(1).unwrap(); - let _twcc = TransportCcExtension::unmarshal(&mut extension_header)?; - } - - Ok(()) -} diff --git a/interceptor/src/twcc/twcc_test.rs b/interceptor/src/twcc/twcc_test.rs deleted file mode 100644 index 2bef1ffda..000000000 --- a/interceptor/src/twcc/twcc_test.rs +++ /dev/null @@ -1,565 +0,0 @@ -use rtcp::packet::Packet; -use util::Marshal; - -use super::*; -use crate::error::Result; - -#[test] -fn test_chunk_add() -> Result<()> { - //"fill with not received" - { - let mut c = Chunk::default(); - - for i in 0..MAX_RUN_LENGTH_CAP { - assert!(c.can_add(SymbolTypeTcc::PacketNotReceived as u16), "{}", i); - c.add(SymbolTypeTcc::PacketNotReceived as u16); - } - assert_eq!(c.deltas, vec![0u16; MAX_RUN_LENGTH_CAP]); - assert!(!c.has_different_types); - - assert!(!c.can_add(SymbolTypeTcc::PacketNotReceived as u16)); - assert!(!c.can_add(SymbolTypeTcc::PacketReceivedSmallDelta as u16)); - assert!(!c.can_add(SymbolTypeTcc::PacketReceivedLargeDelta as u16)); - - let status_chunk = c.encode(); - match status_chunk { - PacketStatusChunk::RunLengthChunk(_) => {} - _ => panic!(), - }; - - let buf = status_chunk.marshal()?; - assert_eq!(&buf[..], &[0x1f, 0xff]); - } - - //"fill with small delta" - { - let mut c = Chunk::default(); - - for i in 0..MAX_ONE_BIT_CAP { - assert!( - c.can_add(SymbolTypeTcc::PacketReceivedSmallDelta as u16), - "{}", - i - ); - c.add(SymbolTypeTcc::PacketReceivedSmallDelta as u16); - } - - assert_eq!(c.deltas, vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]); - assert!(!c.has_different_types); - - assert!(!c.can_add(SymbolTypeTcc::PacketReceivedLargeDelta as u16)); - assert!(!c.can_add(SymbolTypeTcc::PacketNotReceived as u16)); - - let status_chunk = c.encode(); - match status_chunk { - PacketStatusChunk::RunLengthChunk(_) => {} - _ => panic!(), - }; - - let buf = status_chunk.marshal()?; - assert_eq!(&buf[..], &[0x20, 0xe]); - } - - //"fill with large delta" - { - let mut c = Chunk::default(); - - for i in 0..MAX_TWO_BIT_CAP { - assert!( - c.can_add(SymbolTypeTcc::PacketReceivedLargeDelta as u16), - "{}", - i - ); - c.add(SymbolTypeTcc::PacketReceivedLargeDelta as u16); - } - - assert_eq!(c.deltas, vec![2, 2, 2, 2, 2, 2, 2]); - assert!(c.has_large_delta); - assert!(!c.has_different_types); - - assert!(!c.can_add(SymbolTypeTcc::PacketReceivedSmallDelta as u16)); - assert!(!c.can_add(SymbolTypeTcc::PacketNotReceived as u16)); - - let status_chunk = c.encode(); - match status_chunk { - PacketStatusChunk::RunLengthChunk(_) => {} - _ => panic!(), - }; - - let buf = status_chunk.marshal()?; - assert_eq!(&buf[..], &[0x40, 0x7]); - } - - // "fill with different types" - { - let mut c = Chunk::default(); - - assert!(c.can_add(SymbolTypeTcc::PacketReceivedSmallDelta as u16)); - c.add(SymbolTypeTcc::PacketReceivedSmallDelta as u16); - assert!(c.can_add(SymbolTypeTcc::PacketReceivedSmallDelta as u16)); - c.add(SymbolTypeTcc::PacketReceivedSmallDelta as u16); - assert!(c.can_add(SymbolTypeTcc::PacketReceivedSmallDelta as u16)); - c.add(SymbolTypeTcc::PacketReceivedSmallDelta as u16); - assert!(c.can_add(SymbolTypeTcc::PacketReceivedSmallDelta as u16)); - c.add(SymbolTypeTcc::PacketReceivedSmallDelta as u16); - - assert!(c.can_add(SymbolTypeTcc::PacketReceivedLargeDelta as u16)); - c.add(SymbolTypeTcc::PacketReceivedLargeDelta as u16); - assert!(c.can_add(SymbolTypeTcc::PacketReceivedLargeDelta as u16)); - c.add(SymbolTypeTcc::PacketReceivedLargeDelta as u16); - assert!(c.can_add(SymbolTypeTcc::PacketReceivedLargeDelta as u16)); 
- c.add(SymbolTypeTcc::PacketReceivedLargeDelta as u16); - - assert!(!c.can_add(SymbolTypeTcc::PacketReceivedLargeDelta as u16)); - - let status_chunk = c.encode(); - match status_chunk { - PacketStatusChunk::StatusVectorChunk(_) => {} - _ => panic!(), - }; - - let buf = status_chunk.marshal()?; - assert_eq!(&buf[..], &[0xd5, 0x6a]); - } - - //"overfill and encode" - { - let mut c = Chunk::default(); - - assert!(c.can_add(SymbolTypeTcc::PacketReceivedSmallDelta as u16)); - c.add(SymbolTypeTcc::PacketReceivedSmallDelta as u16); - assert!(c.can_add(SymbolTypeTcc::PacketNotReceived as u16)); - c.add(SymbolTypeTcc::PacketNotReceived as u16); - assert!(c.can_add(SymbolTypeTcc::PacketNotReceived as u16)); - c.add(SymbolTypeTcc::PacketNotReceived as u16); - assert!(c.can_add(SymbolTypeTcc::PacketNotReceived as u16)); - c.add(SymbolTypeTcc::PacketNotReceived as u16); - assert!(c.can_add(SymbolTypeTcc::PacketNotReceived as u16)); - c.add(SymbolTypeTcc::PacketNotReceived as u16); - assert!(c.can_add(SymbolTypeTcc::PacketNotReceived as u16)); - c.add(SymbolTypeTcc::PacketNotReceived as u16); - assert!(c.can_add(SymbolTypeTcc::PacketNotReceived as u16)); - c.add(SymbolTypeTcc::PacketNotReceived as u16); - assert!(c.can_add(SymbolTypeTcc::PacketNotReceived as u16)); - c.add(SymbolTypeTcc::PacketNotReceived as u16); - - assert!(!c.can_add(SymbolTypeTcc::PacketReceivedLargeDelta as u16)); - - let status_chunk1 = c.encode(); - match status_chunk1 { - PacketStatusChunk::StatusVectorChunk(_) => {} - _ => panic!(), - }; - assert_eq!(c.deltas.len(), 1); - - assert!(c.can_add(SymbolTypeTcc::PacketReceivedLargeDelta as u16)); - c.add(SymbolTypeTcc::PacketReceivedLargeDelta as u16); - - let status_chunk2 = c.encode(); - match status_chunk2 { - PacketStatusChunk::StatusVectorChunk(_) => {} - _ => panic!(), - }; - assert_eq!(c.deltas.len(), 0); - - assert_eq!( - PacketStatusChunk::StatusVectorChunk(StatusVectorChunk { - type_tcc: StatusChunkTypeTcc::StatusVectorChunk, - symbol_size: SymbolSizeTypeTcc::TwoBit, - symbol_list: vec![ - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketReceivedLargeDelta - ], - }), - status_chunk2 - ); - } - - Ok(()) -} - -#[test] -fn test_feedback() -> Result<()> { - //"add simple" - { - let mut f = Feedback::default(); - let got = f.add_received(0, 10); - assert!(got); - } - - //"add too large" - { - let mut f = Feedback::default(); - - assert!(!f.add_received(12, 8200 * 1000 * 250)); - } - - // "add received 1" - { - let mut f = Feedback::default(); - f.set_base(1, 1000 * 1000); - - let got = f.add_received(1, 1023 * 1000); - - assert!(got); - assert_eq!(f.next_sequence_number, 2); - assert_eq!(f.ref_timestamp64ms, 15); - - let got = f.add_received(4, 1086 * 1000); - assert!(got); - assert_eq!(f.next_sequence_number, 5); - assert_eq!(f.ref_timestamp64ms, 15); - - assert!(f.last_chunk.has_different_types); - assert_eq!(f.last_chunk.deltas.len(), 4); - assert!(!f - .last_chunk - .deltas - .contains(&(SymbolTypeTcc::PacketReceivedLargeDelta as u16))); - } - - //"add received 2" - { - let mut f = Feedback::new(0, 0, 0); - f.set_base(5, 320 * 1000); - - let mut got = f.add_received(5, 320 * 1000); - assert!(got); - got = f.add_received(7, 448 * 1000); - assert!(got); - got = f.add_received(8, 512 * 1000); - assert!(got); - got = f.add_received(11, 768 * 1000); - assert!(got); - - let pkt = f.get_rtcp(); - - assert!(pkt.header().padding); - assert_eq!(pkt.header().length, 7); - assert_eq!(pkt.base_sequence_number, 5); - assert_eq!(pkt.packet_status_count, 7); - 
assert_eq!(pkt.reference_time, 5); - assert_eq!(pkt.fb_pkt_count, 0); - assert_eq!(pkt.packet_chunks.len(), 1); - - assert_eq!( - vec![PacketStatusChunk::StatusVectorChunk(StatusVectorChunk { - type_tcc: StatusChunkTypeTcc::StatusVectorChunk, - symbol_size: SymbolSizeTypeTcc::TwoBit, - symbol_list: vec![ - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketReceivedLargeDelta, - SymbolTypeTcc::PacketReceivedLargeDelta, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketReceivedLargeDelta, - ], - })], - pkt.packet_chunks - ); - - let expected_deltas = [ - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 0, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedLargeDelta, - delta: 0x0200 * TYPE_TCC_DELTA_SCALE_FACTOR, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedLargeDelta, - delta: 0x0100 * TYPE_TCC_DELTA_SCALE_FACTOR, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedLargeDelta, - delta: 0x0400 * TYPE_TCC_DELTA_SCALE_FACTOR, - }, - ]; - assert_eq!(pkt.recv_deltas.len(), expected_deltas.len()); - for (i, expected) in expected_deltas.iter().enumerate() { - assert_eq!(&pkt.recv_deltas[i], expected); - } - } - - //"add received wrapped sequence number" - { - let mut f = Feedback::new(0, 0, 0); - f.set_base(65535, 320 * 1000); - - let mut got = f.add_received(65535, 320 * 1000); - assert!(got); - got = f.add_received(7, 448 * 1000); - assert!(got); - got = f.add_received(8, 512 * 1000); - assert!(got); - got = f.add_received(11, 768 * 1000); - assert!(got); - - let pkt = f.get_rtcp(); - - assert!(pkt.header().padding); - assert_eq!(pkt.header().length, 7); - assert_eq!(pkt.base_sequence_number, 65535); - assert_eq!(pkt.packet_status_count, 13); - assert_eq!(pkt.reference_time, 5); - assert_eq!(pkt.fb_pkt_count, 0); - assert_eq!(pkt.packet_chunks.len(), 2); - - assert_eq!( - vec![ - PacketStatusChunk::StatusVectorChunk(StatusVectorChunk { - type_tcc: StatusChunkTypeTcc::StatusVectorChunk, - symbol_size: SymbolSizeTypeTcc::TwoBit, - symbol_list: vec![ - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - ], - }), - PacketStatusChunk::StatusVectorChunk(StatusVectorChunk { - type_tcc: StatusChunkTypeTcc::StatusVectorChunk, - symbol_size: SymbolSizeTypeTcc::TwoBit, - symbol_list: vec![ - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketReceivedLargeDelta, - SymbolTypeTcc::PacketReceivedLargeDelta, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketReceivedLargeDelta, - ], - }), - ], - pkt.packet_chunks - ); - - let expected_deltas = [ - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 0, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedLargeDelta, - delta: 0x0200 * TYPE_TCC_DELTA_SCALE_FACTOR, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedLargeDelta, - delta: 0x0100 * TYPE_TCC_DELTA_SCALE_FACTOR, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedLargeDelta, - delta: 0x0400 * TYPE_TCC_DELTA_SCALE_FACTOR, - }, - ]; - assert_eq!(pkt.recv_deltas.len(), expected_deltas.len()); - for (i, expected) in expected_deltas.iter().enumerate() { - assert_eq!(&pkt.recv_deltas[i], expected); - } - } - - 
//"get RTCP" - { - let tests = vec![(320, 1, 5, 1), (1000, 2, 15, 2)]; - for (arrival_ts, sequence_number, want_ref_time, want_base_sequence_number) in tests { - let mut f = Feedback::new(0, 0, 0); - f.set_base(sequence_number, arrival_ts * 1000); - - let got = f.get_rtcp(); - assert_eq!(got.reference_time, want_ref_time); - assert_eq!(got.base_sequence_number, want_base_sequence_number); - } - } - - Ok(()) -} - -fn add_run(r: &mut Recorder, sequence_numbers: &[u16], arrival_times: &[i64]) { - assert_eq!(sequence_numbers.len(), arrival_times.len()); - - for i in 0..sequence_numbers.len() { - r.record(5000, sequence_numbers[i], arrival_times[i]); - } -} - -const TYPE_TCC_DELTA_SCALE_FACTOR: i64 = 250; -const SCALE_FACTOR_REFERENCE_TIME: i64 = 64000; - -fn increase_time(arrival_time: &mut i64, increase_amount: i64) -> i64 { - *arrival_time += increase_amount; - *arrival_time -} - -fn marshal_all(pkts: &[Box]) -> Result<()> { - for pkt in pkts { - let _ = pkt.marshal()?; - } - Ok(()) -} - -#[test] -fn test_build_feedback_packet() -> Result<()> { - let mut r = Recorder::new(5000); - - let mut arrival_time = SCALE_FACTOR_REFERENCE_TIME; - add_run( - &mut r, - &[0, 1, 2, 3, 4, 5, 6, 7], - &[ - SCALE_FACTOR_REFERENCE_TIME, - increase_time(&mut arrival_time, TYPE_TCC_DELTA_SCALE_FACTOR), - increase_time(&mut arrival_time, TYPE_TCC_DELTA_SCALE_FACTOR), - increase_time(&mut arrival_time, TYPE_TCC_DELTA_SCALE_FACTOR), - increase_time(&mut arrival_time, TYPE_TCC_DELTA_SCALE_FACTOR), - increase_time(&mut arrival_time, TYPE_TCC_DELTA_SCALE_FACTOR), - increase_time(&mut arrival_time, TYPE_TCC_DELTA_SCALE_FACTOR), - increase_time(&mut arrival_time, TYPE_TCC_DELTA_SCALE_FACTOR * 256), - ], - ); - - let rtcp_packets = r.build_feedback_packet(); - assert_eq!(1, rtcp_packets.len()); - - let expected = TransportLayerCc { - sender_ssrc: 5000, - media_ssrc: 5000, - base_sequence_number: 0, - reference_time: 1, - fb_pkt_count: 0, - packet_status_count: 8, - packet_chunks: vec![ - PacketStatusChunk::RunLengthChunk(RunLengthChunk { - type_tcc: StatusChunkTypeTcc::RunLengthChunk, - packet_status_symbol: SymbolTypeTcc::PacketReceivedSmallDelta, - run_length: 7, - }), - PacketStatusChunk::RunLengthChunk(RunLengthChunk { - type_tcc: StatusChunkTypeTcc::RunLengthChunk, - packet_status_symbol: SymbolTypeTcc::PacketReceivedLargeDelta, - run_length: 1, - }), - ], - recv_deltas: vec![ - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 0, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: TYPE_TCC_DELTA_SCALE_FACTOR, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: TYPE_TCC_DELTA_SCALE_FACTOR, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: TYPE_TCC_DELTA_SCALE_FACTOR, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: TYPE_TCC_DELTA_SCALE_FACTOR, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: TYPE_TCC_DELTA_SCALE_FACTOR, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: TYPE_TCC_DELTA_SCALE_FACTOR, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedLargeDelta, - delta: TYPE_TCC_DELTA_SCALE_FACTOR * 256, - }, - ], - }; - - if let Some(tcc) = rtcp_packets[0].as_any().downcast_ref::() { - assert_eq!(tcc, &expected); - } else { - panic!(); - } - - marshal_all(&rtcp_packets[..])?; - - Ok(()) -} - -#[test] -fn test_build_feedback_packet_rolling() -> 
Result<()> { - let mut r = Recorder::new(5000); - - let mut arrival_time = SCALE_FACTOR_REFERENCE_TIME; - add_run(&mut r, &[3], &[arrival_time]); - - let rtcp_packets = r.build_feedback_packet(); - assert_eq!(0, rtcp_packets.len()); - - add_run( - &mut r, - &[4, 8, 9], - &[ - increase_time(&mut arrival_time, TYPE_TCC_DELTA_SCALE_FACTOR), - increase_time(&mut arrival_time, TYPE_TCC_DELTA_SCALE_FACTOR), - increase_time(&mut arrival_time, TYPE_TCC_DELTA_SCALE_FACTOR), - ], - ); - - let rtcp_packets = r.build_feedback_packet(); - assert_eq!(rtcp_packets.len(), 1); - - let expected = TransportLayerCc { - sender_ssrc: 5000, - media_ssrc: 5000, - base_sequence_number: 3, - reference_time: 1, - fb_pkt_count: 0, - packet_status_count: 7, - packet_chunks: vec![PacketStatusChunk::StatusVectorChunk(StatusVectorChunk { - type_tcc: StatusChunkTypeTcc::StatusVectorChunk, - symbol_size: SymbolSizeTypeTcc::TwoBit, - symbol_list: vec![ - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - ], - })], - recv_deltas: vec![ - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 0, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: TYPE_TCC_DELTA_SCALE_FACTOR, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: TYPE_TCC_DELTA_SCALE_FACTOR, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: TYPE_TCC_DELTA_SCALE_FACTOR, - }, - ], - }; - - if let Some(tcc) = rtcp_packets[0].as_any().downcast_ref::() { - assert_eq!(tcc, &expected); - } else { - panic!(); - } - - marshal_all(&rtcp_packets[..])?; - - Ok(()) -} diff --git a/mdns/.gitignore b/mdns/.gitignore deleted file mode 100644 index 81561ed32..000000000 --- a/mdns/.gitignore +++ /dev/null @@ -1,11 +0,0 @@ -# Generated by Cargo -# will have compiled files and executables -/target/ -/.idea/ - -# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries -# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html -Cargo.lock - -# These are backup files generated by rustfmt -**/*.rs.bk diff --git a/mdns/CHANGELOG.md b/mdns/CHANGELOG.md deleted file mode 100644 index 597cbc79c..000000000 --- a/mdns/CHANGELOG.md +++ /dev/null @@ -1,23 +0,0 @@ -# webrtc-mdns changelog - -## Unreleased - -## v0.5.2 - -* Change log level for packet reception [#366](https://github.com/webrtc-rs/webrtc/pull/366). - -## v0.5.1 - -* Increased minimum support rust version to `1.60.0`. -* Increased required `webrtc-util` version to `0.7.0`. - -## v0.5.0 - -* Increased min version of `log` dependency to `0.4.16`. [#250 Fix log at ^0.4.16 to make tests compile](https://github.com/webrtc-rs/webrtc/pull/250) by [@k0nserv](https://github.com/k0nserv). -* [#275 mdns: get_interface_addr_for_ip once per query](https://github.com/webrtc-rs/webrtc/pull/275) by [@melekes](https://github.com/melekes). - - -## Prior to 0.5.0 - -Before 0.5.0 there was no changelog, previous changes are sometimes, but not always, available in the [GitHub Releases](https://github.com/webrtc-rs/mdns/releases). 
- diff --git a/mdns/Cargo.toml b/mdns/Cargo.toml deleted file mode 100644 index 33bfafbc0..000000000 --- a/mdns/Cargo.toml +++ /dev/null @@ -1,53 +0,0 @@ -[package] -name = "webrtc-mdns" -version = "0.7.0" -authors = ["Rain Liu "] -edition = "2021" -description = "A pure Rust implementation of mDNS" -license = "MIT OR Apache-2.0" -documentation = "https://docs.rs/webrtc-mdns" -homepage = "https://webrtc.rs" -repository = "https://github.com/webrtc-rs/webrtc/tree/master/mdns" - -[features] -default = [ "reuse_port" ] -reuse_port = [] - -[dependencies] -util = { version = "0.9.0", path = "../util", package = "webrtc-util", default-features = false, features = ["ifaces"] } - -tokio = { version = "1.32.0", features = [ - "fs", - "io-util", - "io-std", - "macros", - "net", - "parking_lot", - "rt", - "rt-multi-thread", - "sync", - "time", -] } -socket2 = { version = "0.5", features = ["all"] } -log = "0.4" -thiserror = "1" - -[dev-dependencies] -env_logger = "0.10" -chrono = "0.4.28" -clap = "3" - -[[example]] -name = "mdns_query" -path = "examples/mdns_query.rs" -bench = false - -[[example]] -name = "mdns_server" -path = "examples/mdns_server.rs" -bench = false - -[[example]] -name = "mdns_server_query" -path = "examples/mdns_server_query.rs" -bench = false diff --git a/mdns/LICENSE-APACHE b/mdns/LICENSE-APACHE deleted file mode 100644 index 16fe87b06..000000000 --- a/mdns/LICENSE-APACHE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - -1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. 
For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - -2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - -3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - -4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - -5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - -6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - -8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - -END OF TERMS AND CONDITIONS - -APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - -Copyright [yyyy] [name of copyright owner] - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. diff --git a/mdns/LICENSE-MIT b/mdns/LICENSE-MIT deleted file mode 100644 index e11d93bef..000000000 --- a/mdns/LICENSE-MIT +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2021 WebRTC.rs - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/mdns/README.md b/mdns/README.md deleted file mode 100644 index e19a09ee8..000000000 --- a/mdns/README.md +++ /dev/null @@ -1,30 +0,0 @@ -

- WebRTC.rs
-
- License: MIT/Apache 2.0
- Discord
-
- A pure Rust implementation of mDNS. Rewrite Pion mDNS in Rust

diff --git a/mdns/codecov.yml b/mdns/codecov.yml deleted file mode 100644 index b69bc9bed..000000000 --- a/mdns/codecov.yml +++ /dev/null @@ -1,23 +0,0 @@ -codecov: - require_ci_to_pass: yes - max_report_age: off - token: eb4a349b-5dd5-442b-a57c-a3de73aedf9b - -coverage: - precision: 2 - round: down - range: 50..90 - status: - project: - default: - enabled: no - threshold: 0.2 - if_not_found: success - patch: - default: - enabled: no - if_not_found: success - changes: - default: - enabled: no - if_not_found: success diff --git a/mdns/doc/webrtc.rs.png b/mdns/doc/webrtc.rs.png deleted file mode 100644 index 7bf0dda2a..000000000 Binary files a/mdns/doc/webrtc.rs.png and /dev/null differ diff --git a/mdns/examples/mdns_query.rs b/mdns/examples/mdns_query.rs deleted file mode 100644 index 1cac50471..000000000 --- a/mdns/examples/mdns_query.rs +++ /dev/null @@ -1,90 +0,0 @@ -use std::io::Write; -use std::net::SocketAddr; -use std::str::FromStr; - -use clap::{App, AppSettings, Arg}; -use mdns::config::*; -use mdns::conn::*; -use mdns::Error; -use tokio::sync::mpsc; -use webrtc_mdns as mdns; - -// For interop with webrtc-rs/mdns_server -// cargo run --color=always --package webrtc-mdns --example mdns_query - -// For interop with pion/mdns_server: -// cargo run --color=always --package webrtc-mdns --example mdns_query -- --local-name pion-test.local - -#[tokio::main] -async fn main() -> Result<(), Error> { - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init(); - - let mut app = App::new("mDNS Query") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of mDNS Query") - .setting(AppSettings::DeriveDisplayOrder) - .setting(AppSettings::SubcommandsNegateReqs) - .arg( - Arg::with_name("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::with_name("server") - .required_unless("FULLHELP") - .takes_value(true) - .default_value("0.0.0.0:5353") - .long("server") - .help("mDNS Server name."), - ) - .arg( - Arg::with_name("local-name") - .long("local-name") - .takes_value(true) - .default_value("webrtc-rs-test.local") - .help("Local name"), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let server = matches.value_of("server").unwrap(); - let local_name = matches.value_of("local-name").unwrap(); - - let server = DnsConn::server( - SocketAddr::from_str(server)?, - Config { - ..Default::default() - }, - ) - .unwrap(); - - log::info!("querying dns"); - - let (_a, b) = mpsc::channel(1); - - let (answer, src) = server.query(local_name, b).await.unwrap(); - log::info!("dns queried"); - println!("answer = {answer}, src = {src}"); - - server.close().await.unwrap(); - Ok(()) -} diff --git a/mdns/examples/mdns_server.rs b/mdns/examples/mdns_server.rs deleted file mode 100644 index a66d5afcf..000000000 --- a/mdns/examples/mdns_server.rs +++ /dev/null @@ -1,84 +0,0 @@ -use std::io::Write; -use std::net::SocketAddr; -use std::str::FromStr; - -use clap::{App, AppSettings, Arg}; -use mdns::config::*; -use mdns::conn::*; -use mdns::Error; -use webrtc_mdns as mdns; - -// For interop with webrtc-rs/mdns_server -// cargo run --color=always --package webrtc-mdns --example mdns_server - -// For interop with 
pion/mdns_client: -// cargo run --color=always --package webrtc-mdns --example mdns_server -- --local-name pion-test.local - -#[tokio::main] -async fn main() -> Result<(), Error> { - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init(); - - let mut app = App::new("mDNS Sever") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of mDNS Sever") - .setting(AppSettings::DeriveDisplayOrder) - .setting(AppSettings::SubcommandsNegateReqs) - .arg( - Arg::with_name("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::with_name("server") - .required_unless("FULLHELP") - .takes_value(true) - .default_value("0.0.0.0:5353") - .long("server") - .help("mDNS Server name."), - ) - .arg( - Arg::with_name("local-name") - .long("local-name") - .takes_value(true) - .default_value("webrtc-rs-test.local") - .help("Local name"), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let server = matches.value_of("server").unwrap(); - let local_name = matches.value_of("local-name").unwrap(); - - let server = DnsConn::server( - SocketAddr::from_str(server)?, - Config { - local_names: vec![local_name.to_owned()], - ..Default::default() - }, - ) - .unwrap(); - - println!("Press ctlr-c to stop server"); - tokio::signal::ctrl_c().await.unwrap(); - server.close().await.unwrap(); - Ok(()) -} diff --git a/mdns/examples/mdns_server_query.rs b/mdns/examples/mdns_server_query.rs deleted file mode 100644 index b1a195ad9..000000000 --- a/mdns/examples/mdns_server_query.rs +++ /dev/null @@ -1,55 +0,0 @@ -use std::net::{Ipv4Addr, SocketAddr}; - -use tokio::sync::mpsc; -use webrtc_mdns::config::*; -use webrtc_mdns::conn::*; - -#[tokio::main] -async fn main() { - env_logger::init(); - - log::trace!("server a created"); - - let server_a = DnsConn::server( - SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 5353), - Config { - local_names: vec![ - "webrtc-rs-mdns-1.local".to_owned(), - "webrtc-rs-mdns-2.local".to_owned(), - ], - ..Default::default() - }, - ) - .unwrap(); - - let server_b = DnsConn::server( - SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 5353), - Config { - ..Default::default() - }, - ) - .unwrap(); - - let (a, b) = mpsc::channel(1); - - tokio::spawn(async move { - tokio::time::sleep(tokio::time::Duration::from_secs(20)).await; - a.send(()).await - }); - - let (answer, src) = server_b.query("webrtc-rs-mdns-1.local", b).await.unwrap(); - println!("webrtc-rs-mdns-1.local answer = {answer}, src = {src}"); - - let (a, b) = mpsc::channel(1); - - tokio::spawn(async move { - tokio::time::sleep(tokio::time::Duration::from_secs(20)).await; - a.send(()).await - }); - - let (answer, src) = server_b.query("webrtc-rs-mdns-2.local", b).await.unwrap(); - println!("webrtc-rs-mdns-2.local answer = {answer}, src = {src}"); - - server_a.close().await.unwrap(); - server_b.close().await.unwrap(); -} diff --git a/mdns/src/config.rs b/mdns/src/config.rs deleted file mode 100644 index b9f06de46..000000000 --- a/mdns/src/config.rs +++ /dev/null @@ -1,14 +0,0 @@ -use std::time::Duration; - -// Config is used to configure a mDNS client or server. 
-#[derive(Default, Debug)] -pub struct Config { - // query_interval controls how often we sends Queries until we - // get a response for the requested name - pub query_interval: Duration, - - // local_names are the names that we will generate answers for - // when we get questions - pub local_names: Vec, - //LoggerFactory logging.LoggerFactory -} diff --git a/mdns/src/conn/conn_test.rs b/mdns/src/conn/conn_test.rs deleted file mode 100644 index f284b86d9..000000000 --- a/mdns/src/conn/conn_test.rs +++ /dev/null @@ -1,47 +0,0 @@ -#[cfg(test)] -mod test { - use tokio::time::timeout; - - use crate::config::Config; - use crate::conn::*; - - #[tokio::test] - async fn test_multiple_close() -> Result<()> { - let server_a = DnsConn::server( - SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 5353), - Config::default(), - )?; - - server_a.close().await?; - - if let Err(err) = server_a.close().await { - assert_eq!(err, Error::ErrConnectionClosed); - } else { - panic!("expected error, but got ok"); - } - - Ok(()) - } - - #[tokio::test] - async fn test_query_respect_timeout() -> Result<()> { - let server_a = DnsConn::server( - SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 5353), - Config::default(), - )?; - - let (a, b) = mpsc::channel(1); - - timeout(Duration::from_millis(100), a.send(())) - .await - .unwrap() - .unwrap(); - - let res = server_a.query("invalid-host", b).await; - assert!(res.is_err(), "server_a.query expects timeout!"); - - server_a.close().await?; - - Ok(()) - } -} diff --git a/mdns/src/conn/mod.rs b/mdns/src/conn/mod.rs deleted file mode 100644 index e7377977b..000000000 --- a/mdns/src/conn/mod.rs +++ /dev/null @@ -1,436 +0,0 @@ -use core::sync::atomic; -use std::net::{IpAddr, Ipv4Addr, SocketAddr}; -use std::sync::Arc; -use std::time::Duration; - -use socket2::SockAddr; -use tokio::net::{ToSocketAddrs, UdpSocket}; -use tokio::sync::{mpsc, Mutex}; -use util::ifaces; - -use crate::config::*; -use crate::error::*; -use crate::message::header::*; -use crate::message::name::*; -use crate::message::parser::*; -use crate::message::question::*; -use crate::message::resource::a::*; -use crate::message::resource::*; -use crate::message::*; - -mod conn_test; - -pub const DEFAULT_DEST_ADDR: &str = "224.0.0.251:5353"; - -const INBOUND_BUFFER_SIZE: usize = 65535; -const DEFAULT_QUERY_INTERVAL: Duration = Duration::from_secs(1); -const MAX_MESSAGE_RECORDS: usize = 3; -const RESPONSE_TTL: u32 = 120; - -// Conn represents a mDNS Server -pub struct DnsConn { - socket: Arc, - dst_addr: SocketAddr, - - query_interval: Duration, - queries: Arc>>, - - is_server_closed: Arc, - close_server: mpsc::Sender<()>, -} - -struct Query { - name_with_suffix: String, - query_result_chan: mpsc::Sender, -} - -struct QueryResult { - answer: ResourceHeader, - addr: SocketAddr, -} - -impl DnsConn { - /// server establishes a mDNS connection over an existing connection - pub fn server(addr: SocketAddr, config: Config) -> Result { - let socket = socket2::Socket::new( - socket2::Domain::IPV4, - socket2::Type::DGRAM, - Some(socket2::Protocol::UDP), - )?; - - #[cfg(feature = "reuse_port")] - #[cfg(target_family = "unix")] - socket.set_reuse_port(true)?; - - socket.set_reuse_address(true)?; - socket.set_broadcast(true)?; - socket.set_nonblocking(true)?; - - socket.bind(&SockAddr::from(addr))?; - { - let mut join_error_count = 0; - let interfaces = match ifaces::ifaces() { - Ok(e) => e, - Err(e) => { - log::error!("Error getting interfaces: {:?}", e); - return Err(Error::Other(e.to_string())); - } - }; - - for interface in 
&interfaces { - if let Some(SocketAddr::V4(e)) = interface.addr { - if let Err(e) = socket.join_multicast_v4(&Ipv4Addr::new(224, 0, 0, 251), e.ip()) - { - log::trace!("Error connecting multicast, error: {:?}", e); - join_error_count += 1; - continue; - } - - log::trace!("Connected to interface address {:?}", e); - } - } - - if join_error_count >= interfaces.len() { - return Err(Error::ErrJoiningMulticastGroup); - } - } - - let socket = UdpSocket::from_std(socket.into())?; - - let local_names = config - .local_names - .iter() - .map(|l| l.to_string() + ".") - .collect(); - - let dst_addr: SocketAddr = DEFAULT_DEST_ADDR.parse()?; - - let is_server_closed = Arc::new(atomic::AtomicBool::new(false)); - - let (close_server_send, close_server_rcv) = mpsc::channel(1); - - let c = DnsConn { - query_interval: if config.query_interval != Duration::from_secs(0) { - config.query_interval - } else { - DEFAULT_QUERY_INTERVAL - }, - - queries: Arc::new(Mutex::new(vec![])), - socket: Arc::new(socket), - dst_addr, - is_server_closed: Arc::clone(&is_server_closed), - close_server: close_server_send, - }; - - let queries = c.queries.clone(); - let socket = Arc::clone(&c.socket); - - tokio::spawn(async move { - DnsConn::start( - close_server_rcv, - is_server_closed, - socket, - local_names, - dst_addr, - queries, - ) - .await - }); - - Ok(c) - } - - /// Close closes the mDNS Conn - pub async fn close(&self) -> Result<()> { - log::info!("Closing connection"); - if self.is_server_closed.load(atomic::Ordering::SeqCst) { - return Err(Error::ErrConnectionClosed); - } - - log::trace!("Sending close command to server"); - match self.close_server.send(()).await { - Ok(_) => { - log::trace!("Close command sent"); - Ok(()) - } - Err(e) => { - log::warn!("Error sending close command to server: {:?}", e); - Err(Error::ErrConnectionClosed) - } - } - } - - /// Query sends mDNS Queries for the following name until - /// either there's a close signal or we get a result - pub async fn query( - &self, - name: &str, - mut close_query_signal: mpsc::Receiver<()>, - ) -> Result<(ResourceHeader, SocketAddr)> { - if self.is_server_closed.load(atomic::Ordering::SeqCst) { - return Err(Error::ErrConnectionClosed); - } - - let name_with_suffix = name.to_owned() + "."; - - let (query_tx, mut query_rx) = mpsc::channel(1); - { - let mut queries = self.queries.lock().await; - queries.push(Query { - name_with_suffix: name_with_suffix.clone(), - query_result_chan: query_tx, - }); - } - - log::trace!("Sending query"); - self.send_question(&name_with_suffix).await; - - loop { - tokio::select! 
{ - _ = tokio::time::sleep(self.query_interval) => { - log::trace!("Sending query"); - self.send_question(&name_with_suffix).await - }, - - _ = close_query_signal.recv() => { - log::info!("Query close signal received."); - return Err(Error::ErrConnectionClosed) - }, - - res_opt = query_rx.recv() =>{ - log::info!("Received query result"); - if let Some(res) = res_opt{ - return Ok((res.answer, res.addr)); - } - } - } - } - } - - async fn send_question(&self, name: &str) { - let packed_name = match Name::new(name) { - Ok(pn) => pn, - Err(err) => { - log::warn!("Failed to construct mDNS packet: {}", err); - return; - } - }; - - let raw_query = { - let mut msg = Message { - header: Header::default(), - questions: vec![Question { - typ: DnsType::A, - class: DNSCLASS_INET, - name: packed_name, - }], - ..Default::default() - }; - - match msg.pack() { - Ok(v) => v, - Err(err) => { - log::error!("Failed to construct mDNS packet {}", err); - return; - } - } - }; - - log::trace!("{:?} sending {:?}...", self.socket.local_addr(), raw_query); - if let Err(err) = self.socket.send_to(&raw_query, self.dst_addr).await { - log::error!("Failed to send mDNS packet {}", err); - } - } - - async fn start( - mut closed_rx: mpsc::Receiver<()>, - close_server: Arc, - socket: Arc, - local_names: Vec, - dst_addr: SocketAddr, - queries: Arc>>, - ) -> Result<()> { - log::info!("Looping and listening {:?}", socket.local_addr()); - - let mut b = vec![0u8; INBOUND_BUFFER_SIZE]; - let (mut n, mut src); - - loop { - tokio::select! { - _ = closed_rx.recv() => { - log::info!("Closing server connection"); - close_server.store(true, atomic::Ordering::SeqCst); - - return Ok(()); - } - - result = socket.recv_from(&mut b) => { - match result{ - Ok((len, addr)) => { - n = len; - src = addr; - log::trace!("Received new connection from {:?}", addr); - }, - - Err(err) => { - log::error!("Error receiving from socket connection: {:?}", err); - continue; - }, - } - } - } - - let mut p = Parser::default(); - if let Err(err) = p.start(&b[..n]) { - log::error!("Failed to parse mDNS packet {}", err); - continue; - } - - run(&mut p, &socket, &local_names, src, dst_addr, &queries).await - } - } -} - -async fn run( - p: &mut Parser<'_>, - socket: &Arc, - local_names: &[String], - src: SocketAddr, - dst_addr: SocketAddr, - queries: &Arc>>, -) { - let mut interface_addr = None; - for _ in 0..=MAX_MESSAGE_RECORDS { - let q = match p.question() { - Ok(q) => q, - Err(err) => { - if Error::ErrSectionDone == err { - log::trace!("Parsing has completed"); - break; - } else { - log::error!("Failed to parse mDNS packet {}", err); - return; - } - } - }; - - for local_name in local_names { - if *local_name == q.name.data { - let interface_addr = match interface_addr { - Some(addr) => addr, - None => match get_interface_addr_for_ip(src).await { - Ok(addr) => { - interface_addr.replace(addr); - addr - } - Err(e) => { - log::warn!( - "Failed to get local interface to communicate with {}: {:?}", - &src, - e - ); - continue; - } - }, - }; - - log::trace!( - "Found local name: {} to send answer, IP {}, interface addr {}", - local_name, - src.ip(), - interface_addr - ); - if let Err(e) = - send_answer(socket, &interface_addr, &q.name.data, src.ip(), dst_addr).await - { - log::error!("Error sending answer to client: {:?}", e); - continue; - }; - } - } - } - - // There might be more than MAX_MESSAGE_RECORDS questions, so skip the rest - let _ = p.skip_all_questions(); - - for _ in 0..=MAX_MESSAGE_RECORDS { - let a = match p.answer_header() { - Ok(a) => a, - Err(err) 
=> { - if Error::ErrSectionDone != err { - log::warn!("Failed to parse mDNS packet {}", err); - } - return; - } - }; - - if a.typ != DnsType::A && a.typ != DnsType::Aaaa { - continue; - } - - let mut qs = queries.lock().await; - for j in (0..qs.len()).rev() { - if qs[j].name_with_suffix == a.name.data { - let _ = qs[j] - .query_result_chan - .send(QueryResult { - answer: a.clone(), - addr: src, - }) - .await; - qs.remove(j); - } - } - } -} - -async fn send_answer( - socket: &Arc, - interface_addr: &SocketAddr, - name: &str, - dst: IpAddr, - dst_addr: SocketAddr, -) -> Result<()> { - let raw_answer = { - let mut msg = Message { - header: Header { - response: true, - authoritative: true, - ..Default::default() - }, - - answers: vec![Resource { - header: ResourceHeader { - typ: DnsType::A, - class: DNSCLASS_INET, - name: Name::new(name)?, - ttl: RESPONSE_TTL, - ..Default::default() - }, - body: Some(Box::new(AResource { - a: match interface_addr.ip() { - IpAddr::V4(ip) => ip.octets(), - IpAddr::V6(_) => { - return Err(Error::Other("Unexpected IpV6 addr".to_owned())) - } - }, - })), - }], - ..Default::default() - }; - - msg.pack()? - }; - - socket.send_to(&raw_answer, dst_addr).await?; - log::trace!("Sent answer to IP {}", dst); - - Ok(()) -} - -async fn get_interface_addr_for_ip(addr: impl ToSocketAddrs) -> std::io::Result { - let socket = UdpSocket::bind("0.0.0.0:0").await?; - socket.connect(addr).await?; - socket.local_addr() -} diff --git a/mdns/src/error.rs b/mdns/src/error.rs deleted file mode 100644 index 1149f3fed..000000000 --- a/mdns/src/error.rs +++ /dev/null @@ -1,86 +0,0 @@ -use std::string::FromUtf8Error; -use std::{io, net}; - -use thiserror::Error; - -pub type Result = std::result::Result; - -#[derive(Debug, Error, PartialEq)] -#[non_exhaustive] -pub enum Error { - #[error("mDNS: failed to join multicast group")] - ErrJoiningMulticastGroup, - #[error("mDNS: connection is closed")] - ErrConnectionClosed, - #[error("mDNS: context has elapsed")] - ErrContextElapsed, - #[error("mDNS: config must not be nil")] - ErrNilConfig, - #[error("parsing/packing of this type isn't available yet")] - ErrNotStarted, - #[error("parsing/packing of this section has completed")] - ErrSectionDone, - #[error("parsing/packing of this section is header")] - ErrSectionHeader, - #[error("insufficient data for base length type")] - ErrBaseLen, - #[error("insufficient data for calculated length type")] - ErrCalcLen, - #[error("segment prefix is reserved")] - ErrReserved, - #[error("too many pointers (>10)")] - ErrTooManyPtr, - #[error("invalid pointer")] - ErrInvalidPtr, - #[error("nil resource body")] - ErrNilResourceBody, - #[error("insufficient data for resource body length")] - ErrResourceLen, - #[error("segment length too long")] - ErrSegTooLong, - #[error("zero length segment")] - ErrZeroSegLen, - #[error("resource length too long")] - ErrResTooLong, - #[error("too many Questions to pack (>65535)")] - ErrTooManyQuestions, - #[error("too many Answers to pack (>65535)")] - ErrTooManyAnswers, - #[error("too many Authorities to pack (>65535)")] - ErrTooManyAuthorities, - #[error("too many Additionals to pack (>65535)")] - ErrTooManyAdditionals, - #[error("name is not in canonical format (it must end with a .)")] - ErrNonCanonicalName, - #[error("character string exceeds maximum length (255)")] - ErrStringTooLong, - #[error("compressed name in SRV resource data")] - ErrCompressedSrv, - #[error("empty builder msg")] - ErrEmptyBuilderMsg, - #[error("{0}")] - Io(#[source] IoError), - #[error("utf-8 error: 
{0}")] - Utf8(#[from] FromUtf8Error), - #[error("parse addr: {0}")] - ParseIp(#[from] net::AddrParseError), - #[error("{0}")] - Other(String), -} - -#[derive(Debug, Error)] -#[error("io error: {0}")] -pub struct IoError(#[from] pub io::Error); - -// Workaround for wanting PartialEq for io::Error. -impl PartialEq for IoError { - fn eq(&self, other: &Self) -> bool { - self.0.kind() == other.0.kind() - } -} - -impl From for Error { - fn from(e: io::Error) -> Self { - Error::Io(IoError(e)) - } -} diff --git a/mdns/src/lib.rs b/mdns/src/lib.rs deleted file mode 100644 index eda3a3e83..000000000 --- a/mdns/src/lib.rs +++ /dev/null @@ -1,9 +0,0 @@ -#![warn(rust_2018_idioms)] -#![allow(dead_code)] - -pub mod config; -pub mod conn; -mod error; -pub mod message; - -pub use error::Error; diff --git a/mdns/src/message/builder.rs b/mdns/src/message/builder.rs deleted file mode 100644 index 183a593c4..000000000 --- a/mdns/src/message/builder.rs +++ /dev/null @@ -1,212 +0,0 @@ -use std::collections::HashMap; - -use super::header::*; -use super::question::*; -use super::resource::*; -use super::*; -use crate::error::*; - -// A Builder allows incrementally packing a DNS message. -// -// Example usage: -// b := NewBuilder(Header{...}) -// b.enable_compression() -// // Optionally start a section and add things to that section. -// // Repeat adding sections as necessary. -// buf, err := b.Finish() -// // If err is nil, buf[2:] will contain the built bytes. -#[derive(Default)] -pub struct Builder { - // msg is the storage for the message being built. - pub msg: Option>, - - // section keeps track of the current section being built. - pub section: Section, - - // header keeps track of what should go in the header when Finish is - // called. - pub header: HeaderInternal, - - // start is the starting index of the bytes allocated in msg for header. - pub start: usize, - - // compression is a mapping from name suffixes to their starting index - // in msg. - pub compression: Option>, -} - -impl Builder { - // NewBuilder creates a new builder with compression disabled. - // - // Note: Most users will want to immediately enable compression with the - // enable_compression method. See that method's comment for why you may or may - // not want to enable compression. - // - // The DNS message is appended to the provided initial buffer buf (which may be - // nil) as it is built. The final message is returned by the (*Builder).Finish - // method, which may return the same underlying array if there was sufficient - // capacity in the slice. - pub fn new(h: &Header) -> Self { - let (id, bits) = h.pack(); - - Builder { - msg: Some(vec![0; HEADER_LEN]), - start: 0, - section: Section::Header, - header: HeaderInternal { - id, - bits, - ..Default::default() - }, - compression: None, - } - - //var hb [HEADER_LEN]byte - //b.msg = append(b.msg, hb[:]...) - //return b - } - - // enable_compression enables compression in the Builder. - // - // Leaving compression disabled avoids compression related allocations, but can - // result in larger message sizes. Be careful with this mode as it can cause - // messages to exceed the UDP size limit. - // - // According to RFC 1035, section 4.1.4, the use of compression is optional, but - // all implementations must accept both compressed and uncompressed DNS - // messages. - // - // Compression should be enabled before any sections are added for best results. 
- pub fn enable_compression(&mut self) { - self.compression = Some(HashMap::new()); - } - - fn start_check(&self, section: Section) -> Result<()> { - if self.section <= Section::NotStarted { - return Err(Error::ErrNotStarted); - } - if self.section > section { - return Err(Error::ErrSectionDone); - } - - Ok(()) - } - - // start_questions prepares the builder for packing Questions. - pub fn start_questions(&mut self) -> Result<()> { - self.start_check(Section::Questions)?; - self.section = Section::Questions; - Ok(()) - } - - // start_answers prepares the builder for packing Answers. - pub fn start_answers(&mut self) -> Result<()> { - self.start_check(Section::Answers)?; - self.section = Section::Answers; - Ok(()) - } - - // start_authorities prepares the builder for packing Authorities. - pub fn start_authorities(&mut self) -> Result<()> { - self.start_check(Section::Authorities)?; - self.section = Section::Authorities; - Ok(()) - } - - // start_additionals prepares the builder for packing Additionals. - pub fn start_additionals(&mut self) -> Result<()> { - self.start_check(Section::Additionals)?; - self.section = Section::Additionals; - Ok(()) - } - - fn increment_section_count(&mut self) -> Result<()> { - let section = self.section; - let (count, err) = match section { - Section::Questions => (&mut self.header.questions, Error::ErrTooManyQuestions), - Section::Answers => (&mut self.header.answers, Error::ErrTooManyAnswers), - Section::Authorities => (&mut self.header.authorities, Error::ErrTooManyAuthorities), - Section::Additionals => (&mut self.header.additionals, Error::ErrTooManyAdditionals), - Section::NotStarted => return Err(Error::ErrNotStarted), - Section::Done => return Err(Error::ErrSectionDone), - Section::Header => return Err(Error::ErrSectionHeader), - }; - - if *count == u16::MAX { - Err(err) - } else { - *count += 1; - Ok(()) - } - } - - // question adds a single question. - pub fn add_question(&mut self, q: &Question) -> Result<()> { - if self.section < Section::Questions { - return Err(Error::ErrNotStarted); - } - if self.section > Section::Questions { - return Err(Error::ErrSectionDone); - } - let msg = self.msg.take(); - if let Some(mut msg) = msg { - msg = q.pack(msg, &mut self.compression, self.start)?; - self.increment_section_count()?; - self.msg = Some(msg); - } - - Ok(()) - } - - fn check_resource_section(&self) -> Result<()> { - if self.section < Section::Answers { - return Err(Error::ErrNotStarted); - } - if self.section > Section::Additionals { - return Err(Error::ErrSectionDone); - } - Ok(()) - } - - // Resource adds a single resource. - pub fn add_resource(&mut self, r: &mut Resource) -> Result<()> { - self.check_resource_section()?; - - if let Some(body) = &r.body { - r.header.typ = body.real_type(); - } else { - return Err(Error::ErrNilResourceBody); - } - - if let Some(msg) = self.msg.take() { - let (mut msg, len_off) = r.header.pack(msg, &mut self.compression, self.start)?; - let pre_len = msg.len(); - if let Some(body) = &r.body { - msg = body.pack(msg, &mut self.compression, self.start)?; - r.header.fix_len(&mut msg, len_off, pre_len)?; - self.increment_section_count()?; - } - self.msg = Some(msg); - } - - Ok(()) - } - - // Finish ends message building and generates a binary message. - pub fn finish(&mut self) -> Result> { - if self.section < Section::Header { - return Err(Error::ErrNotStarted); - } - self.section = Section::Done; - - // Space for the header was allocated in NewBuilder. 
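        // (Descriptive note: the counts in `self.header` were bumped by
        // increment_section_count() as records were added, so the header is
        // re-packed here and copied over the HEADER_LEN placeholder bytes
        // reserved by Builder::new.)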
- let buf = self.header.pack(vec![]); - assert_eq!(buf.len(), HEADER_LEN); - if let Some(mut msg) = self.msg.take() { - msg[..HEADER_LEN].copy_from_slice(&buf[..HEADER_LEN]); - Ok(msg) - } else { - Err(Error::ErrEmptyBuilderMsg) - } - } -} diff --git a/mdns/src/message/header.rs b/mdns/src/message/header.rs deleted file mode 100644 index 7a382e2cf..000000000 --- a/mdns/src/message/header.rs +++ /dev/null @@ -1,165 +0,0 @@ -use super::*; - -// Header is a representation of a DNS message header. -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] -pub struct Header { - pub id: u16, - pub response: bool, - pub op_code: OpCode, - pub authoritative: bool, - pub truncated: bool, - pub recursion_desired: bool, - pub recursion_available: bool, - pub rcode: RCode, -} - -impl fmt::Display for Header { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "dnsmessage.Header{{id: {}, response: {}, op_code: {}, authoritative: {}, truncated: {}, recursion_desired: {}, recursion_available: {}, rcode: {} }}", - self.id, - self.response, - self.op_code, - self.authoritative, - self.truncated, - self.recursion_desired, - self.recursion_available, - self.rcode - ) - } -} - -impl Header { - pub fn pack(&self) -> (u16, u16) { - let id = self.id; - let mut bits = self.op_code << 11 | self.rcode as u16; - if self.recursion_available { - bits |= HEADER_BIT_RA - } - if self.recursion_desired { - bits |= HEADER_BIT_RD - } - if self.truncated { - bits |= HEADER_BIT_TC - } - if self.authoritative { - bits |= HEADER_BIT_AA - } - if self.response { - bits |= HEADER_BIT_QR - } - - (id, bits) - } -} - -#[derive(Default, Copy, Clone, PartialOrd, PartialEq, Eq)] -pub enum Section { - #[default] - NotStarted = 0, - Header = 1, - Questions = 2, - Answers = 3, - Authorities = 4, - Additionals = 5, - Done = 6, -} - -impl From for Section { - fn from(v: u8) -> Self { - match v { - 0 => Section::NotStarted, - 1 => Section::Header, - 2 => Section::Questions, - 3 => Section::Answers, - 4 => Section::Authorities, - 5 => Section::Additionals, - _ => Section::Done, - } - } -} - -impl fmt::Display for Section { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - Section::NotStarted => "NotStarted", - Section::Header => "Header", - Section::Questions => "question", - Section::Answers => "answer", - Section::Authorities => "authority", - Section::Additionals => "additional", - Section::Done => "Done", - }; - write!(f, "{s}") - } -} - -// header is the wire format for a DNS message header. -#[derive(Default)] -pub struct HeaderInternal { - pub id: u16, - pub bits: u16, - pub questions: u16, - pub answers: u16, - pub authorities: u16, - pub additionals: u16, -} - -impl HeaderInternal { - pub(crate) fn count(&self, sec: Section) -> u16 { - match sec { - Section::Questions => self.questions, - Section::Answers => self.answers, - Section::Authorities => self.authorities, - Section::Additionals => self.additionals, - _ => 0, - } - } - - // pack appends the wire format of the header to msg. 
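    // Wire layout (RFC 1035 section 4.1.1): six big-endian u16 fields -- ID,
    // the flag/opcode/rcode bits, QDCOUNT, ANCOUNT, NSCOUNT and ARCOUNT --
    // for 12 bytes total, matching HEADER_LEN = 6 * UINT16LEN in mod.rs.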
- pub(crate) fn pack(&self, mut msg: Vec) -> Vec { - msg = pack_uint16(msg, self.id); - msg = pack_uint16(msg, self.bits); - msg = pack_uint16(msg, self.questions); - msg = pack_uint16(msg, self.answers); - msg = pack_uint16(msg, self.authorities); - msg = pack_uint16(msg, self.additionals); - msg - } - - pub(crate) fn unpack(&mut self, msg: &[u8], off: usize) -> Result { - let (id, off) = unpack_uint16(msg, off)?; - self.id = id; - - let (bits, off) = unpack_uint16(msg, off)?; - self.bits = bits; - - let (questions, off) = unpack_uint16(msg, off)?; - self.questions = questions; - - let (answers, off) = unpack_uint16(msg, off)?; - self.answers = answers; - - let (authorities, off) = unpack_uint16(msg, off)?; - self.authorities = authorities; - - let (additionals, off) = unpack_uint16(msg, off)?; - self.additionals = additionals; - - Ok(off) - } - - pub(crate) fn header(&self) -> Header { - Header { - id: self.id, - response: (self.bits & HEADER_BIT_QR) != 0, - op_code: ((self.bits >> 11) & 0xF) as OpCode, - authoritative: (self.bits & HEADER_BIT_AA) != 0, - truncated: (self.bits & HEADER_BIT_TC) != 0, - recursion_desired: (self.bits & HEADER_BIT_RD) != 0, - recursion_available: (self.bits & HEADER_BIT_RA) != 0, - rcode: RCode::from((self.bits & 0xF) as u8), - } - } -} diff --git a/mdns/src/message/message_test.rs b/mdns/src/message/message_test.rs deleted file mode 100644 index d18acbab6..000000000 --- a/mdns/src/message/message_test.rs +++ /dev/null @@ -1,1311 +0,0 @@ -// Silence warning on complex types: -#![allow(clippy::type_complexity)] - -use std::collections::HashMap; - -use super::builder::*; -use super::header::*; -use super::name::*; -use super::parser::*; -use super::question::*; -use super::resource::a::*; -use super::resource::aaaa::*; -use super::resource::cname::*; -use super::resource::mx::*; -use super::resource::ns::*; -use super::resource::opt::*; -use super::resource::ptr::*; -use super::resource::soa::*; -use super::resource::srv::*; -use super::resource::txt::*; -use super::resource::*; -use super::*; -use crate::error::*; - -fn small_test_msg() -> Result { - let name = Name::new("example.com.")?; - Ok(Message { - header: Header { - response: true, - authoritative: true, - ..Default::default() - }, - questions: vec![Question { - name: name.clone(), - typ: DnsType::A, - class: DNSCLASS_INET, - }], - answers: vec![Resource { - header: ResourceHeader { - name: name.clone(), - typ: DnsType::A, - class: DNSCLASS_INET, - ..Default::default() - }, - body: Some(Box::new(AResource { a: [127, 0, 0, 1] })), - }], - authorities: vec![Resource { - header: ResourceHeader { - name: name.clone(), - typ: DnsType::A, - class: DNSCLASS_INET, - ..Default::default() - }, - body: Some(Box::new(AResource { a: [127, 0, 0, 1] })), - }], - additionals: vec![Resource { - header: ResourceHeader { - name, - typ: DnsType::A, - class: DNSCLASS_INET, - ..Default::default() - }, - body: Some(Box::new(AResource { a: [127, 0, 0, 1] })), - }], - }) -} - -fn large_test_msg() -> Result { - let name = Name::new("foo.bar.example.com.")?; - Ok(Message { - header: Header { - response: true, - authoritative: true, - ..Default::default() - }, - questions: vec![Question { - name: name.clone(), - typ: DnsType::A, - class: DNSCLASS_INET, - }], - answers: vec![ - Resource { - header: ResourceHeader { - name: name.clone(), - typ: DnsType::A, - class: DNSCLASS_INET, - ..Default::default() - }, - body: Some(Box::new(AResource { a: [127, 0, 0, 1] })), - }, - Resource { - header: ResourceHeader { - name: name.clone(), - 
typ: DnsType::A, - class: DNSCLASS_INET, - ..Default::default() - }, - body: Some(Box::new(AResource { a: [127, 0, 0, 2] })), - }, - Resource { - header: ResourceHeader { - name: name.clone(), - typ: DnsType::Aaaa, - class: DNSCLASS_INET, - ..Default::default() - }, - body: Some(Box::new(AaaaResource { - aaaa: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], - })), - }, - Resource { - header: ResourceHeader { - name: name.clone(), - typ: DnsType::Cname, - class: DNSCLASS_INET, - ..Default::default() - }, - body: Some(Box::new(CnameResource { - cname: Name::new("alias.example.com.")?, - })), - }, - Resource { - header: ResourceHeader { - name: name.clone(), - typ: DnsType::Soa, - class: DNSCLASS_INET, - ..Default::default() - }, - body: Some(Box::new(SoaResource { - ns: Name::new("ns1.example.com.")?, - mbox: Name::new("mb.example.com.")?, - serial: 1, - refresh: 2, - retry: 3, - expire: 4, - min_ttl: 5, - })), - }, - Resource { - header: ResourceHeader { - name: name.clone(), - typ: DnsType::Ptr, - class: DNSCLASS_INET, - ..Default::default() - }, - body: Some(Box::new(PtrResource { - ptr: Name::new("ptr.example.com.")?, - })), - }, - Resource { - header: ResourceHeader { - name: name.clone(), - typ: DnsType::Mx, - class: DNSCLASS_INET, - ..Default::default() - }, - body: Some(Box::new(MxResource { - pref: 7, - mx: Name::new("mx.example.com.")?, - })), - }, - Resource { - header: ResourceHeader { - name: name.clone(), - typ: DnsType::Srv, - class: DNSCLASS_INET, - ..Default::default() - }, - body: Some(Box::new(SrvResource { - priority: 8, - weight: 9, - port: 11, - target: Name::new("srv.example.com.")?, - })), - }, - ], - authorities: vec![ - Resource { - header: ResourceHeader { - name: name.clone(), - typ: DnsType::Ns, - class: DNSCLASS_INET, - ..Default::default() - }, - body: Some(Box::new(NsResource { - ns: Name::new("ns1.example.com.")?, - })), - }, - Resource { - header: ResourceHeader { - name: name.clone(), - typ: DnsType::Ns, - class: DNSCLASS_INET, - ..Default::default() - }, - body: Some(Box::new(NsResource { - ns: Name::new("ns2.example.com.")?, - })), - }, - ], - additionals: vec![ - Resource { - header: ResourceHeader { - name: name.clone(), - typ: DnsType::Txt, - class: DNSCLASS_INET, - ..Default::default() - }, - body: Some(Box::new(TxtResource { - txt: vec!["So Long, and Thanks for All the Fish".into()], - })), - }, - Resource { - header: ResourceHeader { - name, - typ: DnsType::Txt, - class: DNSCLASS_INET, - ..Default::default() - }, - body: Some(Box::new(TxtResource { - txt: vec!["Hamster Huey and the Gooey Kablooie".into()], - })), - }, - Resource { - header: must_edns0_resource_header(4096, 0xfe0 | (RCode::Success as u32), false)?, - body: Some(Box::new(OptResource { - options: vec![DnsOption { - code: 10, // see RFC 7873 - data: vec![0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef], - }], - })), - }, - ], - }) -} - -fn must_edns0_resource_header(l: u16, extrc: u32, d: bool) -> Result { - let mut h = ResourceHeader { - class: DNSCLASS_INET, - ..Default::default() - }; - h.set_edns0(l, extrc, d)?; - Ok(h) -} - -#[test] -fn test_name_string() -> Result<()> { - let want = "foo"; - let name = Name::new(want)?; - assert_eq!(name.to_string(), want); - - Ok(()) -} - -#[test] -fn test_question_pack_unpack() -> Result<()> { - let want = Question { - name: Name::new(".")?, - typ: DnsType::A, - class: DNSCLASS_INET, - }; - let buf = want.pack(vec![0; 1], &mut Some(HashMap::new()), 1)?; - let mut p = Parser { - msg: &buf, - header: HeaderInternal { - questions: 1, - 
..Default::default() - }, - section: Section::Questions, - off: 1, - ..Default::default() - }; - - let got = p.question()?; - assert_eq!( - p.off, - buf.len(), - "unpacked different amount than packed: got = {}, want = {}", - p.off, - buf.len(), - ); - assert_eq!( - got, want, - "got from Parser.Question() = {got}, want = {want}" - ); - - Ok(()) -} - -#[test] -fn test_name() -> Result<()> { - let tests = vec![ - "", - ".", - "google..com", - "google.com", - "google..com.", - "google.com.", - ".google.com.", - "www..google.com.", - "www.google.com.", - ]; - - for test in tests { - let name = Name::new(test)?; - let ns = name.to_string(); - assert_eq!(ns, test, "got {name} = {ns}, want = {test}"); - } - - Ok(()) -} - -#[test] -fn test_name_pack_unpack() -> Result<()> { - let tests: Vec<(&str, &str, Option)> = vec![ - ("", "", Some(Error::ErrNonCanonicalName)), - (".", ".", None), - ("google..com", "", Some(Error::ErrNonCanonicalName)), - ("google.com", "", Some(Error::ErrNonCanonicalName)), - ("google..com.", "", Some(Error::ErrZeroSegLen)), - ("google.com.", "google.com.", None), - (".google.com.", "", Some(Error::ErrZeroSegLen)), - ("www..google.com.", "", Some(Error::ErrZeroSegLen)), - ("www.google.com.", "www.google.com.", None), - ]; - - for (input, want, want_err) in tests { - let input = Name::new(input)?; - let result = input.pack(vec![], &mut Some(HashMap::new()), 0); - if let Some(want_err) = want_err { - if let Err(actual_err) = result { - assert_eq!(actual_err, want_err); - } else { - panic!(); - } - continue; - } else { - assert!(result.is_ok()); - } - - let buf = result.unwrap(); - - let want = Name::new(want)?; - - let mut got = Name::default(); - let n = got.unpack(&buf, 0)?; - assert_eq!( - n, - buf.len(), - "unpacked different amount than packed for {}: got = {}, want = {}", - input, - n, - buf.len(), - ); - - assert_eq!( - got, want, - "unpacking packing of {input}: got = {got}, want = {want}" - ); - } - - Ok(()) -} - -#[test] -fn test_incompressible_name() -> Result<()> { - let name = Name::new("example.com.")?; - let mut compression = Some(HashMap::new()); - let buf = name.pack(vec![], &mut compression, 0)?; - let buf = name.pack(buf, &mut compression, 0)?; - let mut n1 = Name::default(); - let off = n1.unpack_compressed(&buf, 0, false /* allowCompression */)?; - let mut n2 = Name::default(); - let result = n2.unpack_compressed(&buf, off, false /* allowCompression */); - if let Err(err) = result { - assert_eq!( - Error::ErrCompressedSrv, - err, - "unpacking compressed incompressible name with pointers: got {}, want = {}", - err, - Error::ErrCompressedSrv - ); - } else { - panic!(); - } - - Ok(()) -} - -#[test] -fn test_header_unpack_error() -> Result<()> { - let wants = vec![ - "id", - "bits", - "questions", - "answers", - "authorities", - "additionals", - ]; - - let mut buf = vec![]; - for want in wants { - let mut h = HeaderInternal::default(); - let result = h.unpack(&buf, 0); - assert!(result.is_err(), "{}", want); - buf.extend_from_slice(&[0, 0]); - } - - Ok(()) -} - -#[test] -fn test_parser_start() -> Result<()> { - let mut p = Parser::default(); - let result = p.start(&[]); - assert!(result.is_err()); - - Ok(()) -} - -#[test] -fn test_resource_not_started() -> Result<()> { - let tests: Vec<(&str, Box) -> Result<()>>)> = vec![ - ( - "CNAMEResource", - Box::new(|p: &mut Parser<'_>| -> Result<()> { p.resource_body().map(|_| ()) }), - ), - ( - "MXResource", - Box::new(|p: &mut Parser<'_>| -> Result<()> { p.resource_body().map(|_| ()) }), - ), - ( - "NSResource", - 
Box::new(|p: &mut Parser<'_>| -> Result<()> { p.resource_body().map(|_| ()) }), - ), - ( - "PTRResource", - Box::new(|p: &mut Parser<'_>| -> Result<()> { p.resource_body().map(|_| ()) }), - ), - ( - "SOAResource", - Box::new(|p: &mut Parser<'_>| -> Result<()> { p.resource_body().map(|_| ()) }), - ), - ( - "TXTResource", - Box::new(|p: &mut Parser<'_>| -> Result<()> { p.resource_body().map(|_| ()) }), - ), - ( - "SRVResource", - Box::new(|p: &mut Parser<'_>| -> Result<()> { p.resource_body().map(|_| ()) }), - ), - ( - "AResource", - Box::new(|p: &mut Parser<'_>| -> Result<()> { p.resource_body().map(|_| ()) }), - ), - ( - "AAAAResource", - Box::new(|p: &mut Parser<'_>| -> Result<()> { p.resource_body().map(|_| ()) }), - ), - ]; - - for (name, test_fn) in tests { - let mut p = Parser::default(); - if let Err(err) = test_fn(&mut p) { - assert_eq!(err, Error::ErrNotStarted, "{name}"); - } - } - - Ok(()) -} - -#[test] -fn test_srv_pack_unpack() -> Result<()> { - let want = Box::new(SrvResource { - priority: 8, - weight: 9, - port: 11, - target: Name::new("srv.example.com.")?, - }); - - let b = want.pack(vec![], &mut None, 0)?; - let mut got = SrvResource::default(); - got.unpack(&b, 0, 0)?; - assert_eq!(got.to_string(), want.to_string(),); - - Ok(()) -} - -#[test] -fn test_dns_pack_unpack() -> Result<()> { - let wants = vec![ - Message { - header: Header::default(), - questions: vec![Question { - name: Name::new(".")?, - typ: DnsType::Aaaa, - class: DNSCLASS_INET, - }], - answers: vec![], - authorities: vec![], - additionals: vec![], - }, - large_test_msg()?, - ]; - - for mut want in wants { - let b = want.pack()?; - let mut got = Message::default(); - got.unpack(&b)?; - assert_eq!(got.to_string(), want.to_string()); - } - - Ok(()) -} - -#[test] -fn test_dns_append_pack_unpack() -> Result<()> { - let wants = vec![ - Message { - header: Header::default(), - questions: vec![Question { - name: Name::new(".")?, - typ: DnsType::Aaaa, - class: DNSCLASS_INET, - }], - answers: vec![], - authorities: vec![], - additionals: vec![], - }, - large_test_msg()?, - ]; - - for mut want in wants { - let mut b = vec![0; 2]; - b = want.append_pack(b)?; - let mut got = Message::default(); - got.unpack(&b[2..])?; - assert_eq!(got.to_string(), want.to_string()); - } - - Ok(()) -} - -#[test] -fn test_skip_all() -> Result<()> { - let mut msg = large_test_msg()?; - let buf = msg.pack()?; - let mut p = Parser::default(); - p.start(&buf)?; - - for _ in 1..=3 { - p.skip_all_questions()?; - } - for _ in 1..=3 { - p.skip_all_answers()?; - } - for _ in 1..=3 { - p.skip_all_authorities()?; - } - for _ in 1..=3 { - p.skip_all_additionals()?; - } - - Ok(()) -} - -#[test] -fn test_skip_each() -> Result<()> { - let mut msg = small_test_msg()?; - let buf = msg.pack()?; - let mut p = Parser::default(); - p.start(&buf)?; - - // {"SkipQuestion", p.SkipQuestion}, - // {"SkipAnswer", p.SkipAnswer}, - // {"SkipAuthority", p.SkipAuthority}, - // {"SkipAdditional", p.SkipAdditional}, - - p.skip_question()?; - if let Err(err) = p.skip_question() { - assert_eq!(err, Error::ErrSectionDone); - } else { - panic!("expected error, but got ok"); - } - - p.skip_answer()?; - if let Err(err) = p.skip_answer() { - assert_eq!(err, Error::ErrSectionDone); - } else { - panic!("expected error, but got ok"); - } - - p.skip_authority()?; - if let Err(err) = p.skip_authority() { - assert_eq!(err, Error::ErrSectionDone); - } else { - panic!("expected error, but got ok"); - } - - p.skip_additional()?; - if let Err(err) = p.skip_additional() { - assert_eq!(err, 
Error::ErrSectionDone); - } else { - panic!("expected error, but got ok"); - } - - Ok(()) -} - -#[test] -fn test_skip_after_read() -> Result<()> { - let mut msg = small_test_msg()?; - let buf = msg.pack()?; - let mut p = Parser::default(); - p.start(&buf)?; - - let tests: Vec<(&str, Box) -> Result<()>>)> = vec![ - ( - "Question", - Box::new(|p: &mut Parser<'_>| -> Result<()> { p.question().map(|_| ()) }), - ), - ( - "Answer", - Box::new(|p: &mut Parser<'_>| -> Result<()> { p.answer().map(|_| ()) }), - ), - ( - "Authority", - Box::new(|p: &mut Parser<'_>| -> Result<()> { p.authority().map(|_| ()) }), - ), - ( - "Additional", - Box::new(|p: &mut Parser<'_>| -> Result<()> { p.additional().map(|_| ()) }), - ), - ]; - - for (name, read_fn) in tests { - read_fn(&mut p)?; - - let result = match name { - "Question" => p.skip_question(), - "Answer" => p.skip_answer(), - "Authority" => p.skip_authority(), - _ => p.skip_additional(), - }; - - if let Err(err) = result { - assert_eq!(err, Error::ErrSectionDone); - } else { - panic!("expected error, but got ok"); - } - } - - Ok(()) -} - -#[test] -fn test_skip_not_started() -> Result<()> { - let tests: Vec<(&str, Box) -> Result<()>>)> = vec![ - ( - "SkipAllQuestions", - Box::new(|p: &mut Parser<'_>| -> Result<()> { p.skip_all_questions() }), - ), - ( - "SkipAllAnswers", - Box::new(|p: &mut Parser<'_>| -> Result<()> { p.skip_all_answers() }), - ), - ( - "SkipAllAuthorities", - Box::new(|p: &mut Parser<'_>| -> Result<()> { p.skip_all_authorities() }), - ), - ( - "SkipAllAdditionals", - Box::new(|p: &mut Parser<'_>| -> Result<()> { p.skip_all_additionals() }), - ), - ]; - - let mut p = Parser::default(); - for (name, test_fn) in tests { - if let Err(err) = test_fn(&mut p) { - assert_eq!(err, Error::ErrNotStarted); - } else { - panic!("{name} expected error, but got ok"); - } - } - - Ok(()) -} - -#[test] -fn test_too_many_records() -> Result<()> { - let recs: usize = u16::MAX as usize + 1; - let tests = vec![ - ( - "Questions", - Message { - questions: vec![Question::default(); recs], - ..Default::default() - }, - Error::ErrTooManyQuestions, - ), - ( - "Answers", - Message { - answers: { - let mut a = vec![]; - for _ in 0..recs { - a.push(Resource::default()); - } - a - }, - ..Default::default() - }, - Error::ErrTooManyAnswers, - ), - ( - "Authorities", - Message { - authorities: { - let mut a = vec![]; - for _ in 0..recs { - a.push(Resource::default()); - } - a - }, - ..Default::default() - }, - Error::ErrTooManyAuthorities, - ), - ( - "Additionals", - Message { - additionals: { - let mut a = vec![]; - for _ in 0..recs { - a.push(Resource::default()); - } - a - }, - ..Default::default() - }, - Error::ErrTooManyAdditionals, - ), - ]; - - for (name, mut msg, want) in tests { - if let Err(got) = msg.pack() { - assert_eq!( - got, want, - "got Message.Pack() for {name} = {got}, want = {want}" - ) - } else { - panic!("expected error, but got ok"); - } - } - - Ok(()) -} - -#[test] -fn test_very_long_txt() -> Result<()> { - let mut str255 = String::new(); - for _ in 0..255 { - str255.push('.'); - } - - let mut want = Resource { - header: ResourceHeader { - name: Name::new("foo.bar.example.com.")?, - typ: DnsType::Txt, - class: DNSCLASS_INET, - ..Default::default() - }, - body: Some(Box::new(TxtResource { - txt: vec![ - "".to_owned(), - "".to_owned(), - "foo bar".to_owned(), - "".to_owned(), - "www.example.com".to_owned(), - "www.example.com.".to_owned(), - str255, - ], - })), - }; - - let buf = want.pack(vec![], &mut Some(HashMap::new()), 0)?; - let mut got = 
Resource::default(); - let off = got.header.unpack(&buf, 0, 0)?; - let (body, n) = unpack_resource_body(got.header.typ, &buf, off, got.header.length as usize)?; - got.body = Some(body); - assert_eq!( - n, - buf.len(), - "unpacked different amount than packed: got = {}, want = {}", - n, - buf.len(), - ); - assert_eq!(got.to_string(), want.to_string()); - - Ok(()) -} - -#[test] -fn test_too_long_txt() -> Result<()> { - let mut str256 = String::new(); - for _ in 0..256 { - str256.push('.'); - } - let rb = TxtResource { txt: vec![str256] }; - if let Err(err) = rb.pack(vec![], &mut Some(HashMap::new()), 0) { - assert_eq!(err, Error::ErrStringTooLong); - } else { - panic!("expected error, but got ok"); - } - - Ok(()) -} - -#[test] -fn test_start_error() -> Result<()> { - let tests: Vec<(&str, Box Result<()>>)> = vec![ - ( - "Questions", - Box::new(|b: &mut Builder| -> Result<()> { b.start_questions() }), - ), - ( - "Answers", - Box::new(|b: &mut Builder| -> Result<()> { b.start_answers() }), - ), - ( - "Authorities", - Box::new(|b: &mut Builder| -> Result<()> { b.start_authorities() }), - ), - ( - "Additionals", - Box::new(|b: &mut Builder| -> Result<()> { b.start_additionals() }), - ), - ]; - - let envs: Vec<(&str, Box Builder>, Error)> = vec![ - ( - "sectionNotStarted", - Box::new(|| -> Builder { - Builder { - section: Section::NotStarted, - ..Default::default() - } - }), - Error::ErrNotStarted, - ), - ( - "sectionDone", - Box::new(|| -> Builder { - Builder { - section: Section::Done, - ..Default::default() - } - }), - Error::ErrSectionDone, - ), - ]; - - for (env_name, env_fn, env_err) in &envs { - for (test_name, test_fn) in &tests { - let mut b = env_fn(); - if let Err(got_err) = test_fn(&mut b) { - assert_eq!( - got_err, *env_err, - "got Builder{env_name}.{test_name} = {got_err}, want = {env_err}" - ); - } else { - panic!("{env_name}.{test_name}expected error, but got ok"); - } - } - } - - Ok(()) -} - -#[test] -fn test_builder_resource_error() -> Result<()> { - let tests: Vec<(&str, Box Result<()>>)> = vec![ - ( - "CNAMEResource", - Box::new(|b: &mut Builder| -> Result<()> { - b.add_resource(&mut Resource { - header: ResourceHeader::default(), - body: Some(Box::::default()), - }) - }), - ), - ( - "MXResource", - Box::new(|b: &mut Builder| -> Result<()> { - b.add_resource(&mut Resource { - header: ResourceHeader::default(), - body: Some(Box::::default()), - }) - }), - ), - ( - "NSResource", - Box::new(|b: &mut Builder| -> Result<()> { - b.add_resource(&mut Resource { - header: ResourceHeader::default(), - body: Some(Box::::default()), - }) - }), - ), - ( - "PTRResource", - Box::new(|b: &mut Builder| -> Result<()> { - b.add_resource(&mut Resource { - header: ResourceHeader::default(), - body: Some(Box::::default()), - }) - }), - ), - ( - "SOAResource", - Box::new(|b: &mut Builder| -> Result<()> { - b.add_resource(&mut Resource { - header: ResourceHeader::default(), - body: Some(Box::::default()), - }) - }), - ), - ( - "TXTResource", - Box::new(|b: &mut Builder| -> Result<()> { - b.add_resource(&mut Resource { - header: ResourceHeader::default(), - body: Some(Box::::default()), - }) - }), - ), - ( - "SRVResource", - Box::new(|b: &mut Builder| -> Result<()> { - b.add_resource(&mut Resource { - header: ResourceHeader::default(), - body: Some(Box::::default()), - }) - }), - ), - ( - "AResource", - Box::new(|b: &mut Builder| -> Result<()> { - b.add_resource(&mut Resource { - header: ResourceHeader::default(), - body: Some(Box::::default()), - }) - }), - ), - ( - "AAAAResource", - Box::new(|b: 
&mut Builder| -> Result<()> { - b.add_resource(&mut Resource { - header: ResourceHeader::default(), - body: Some(Box::::default()), - }) - }), - ), - ( - "OPTResource", - Box::new(|b: &mut Builder| -> Result<()> { - b.add_resource(&mut Resource { - header: ResourceHeader::default(), - body: Some(Box::::default()), - }) - }), - ), - ]; - - let envs: Vec<(&str, Box Builder>, Error)> = vec![ - ( - "sectionNotStarted", - Box::new(|| -> Builder { - Builder { - section: Section::NotStarted, - ..Default::default() - } - }), - Error::ErrNotStarted, - ), - ( - "sectionHeader", - Box::new(|| -> Builder { - Builder { - section: Section::Header, - ..Default::default() - } - }), - Error::ErrNotStarted, - ), - ( - "sectionQuestions", - Box::new(|| -> Builder { - Builder { - section: Section::Questions, - ..Default::default() - } - }), - Error::ErrNotStarted, - ), - ( - "sectionDone", - Box::new(|| -> Builder { - Builder { - section: Section::Done, - ..Default::default() - } - }), - Error::ErrSectionDone, - ), - ]; - - for (env_name, env_fn, env_err) in &envs { - for (test_name, test_fn) in &tests { - let mut b = env_fn(); - if let Err(got_err) = test_fn(&mut b) { - assert_eq!( - got_err, *env_err, - "got Builder{env_name}.{test_name} = {got_err}, want = {env_err}" - ); - } else { - panic!("{env_name}.{test_name}expected error, but got ok"); - } - } - } - - Ok(()) -} - -#[test] -fn test_finish_error() -> Result<()> { - let mut b = Builder::default(); - let want = Error::ErrNotStarted; - if let Err(got) = b.finish() { - assert_eq!(got, want, "got Builder.Finish() = {got}, want = {want}"); - } else { - panic!("expected error, but got ok"); - } - - Ok(()) -} - -#[test] -fn test_builder() -> Result<()> { - let mut msg = large_test_msg()?; - let want = msg.pack()?; - - let mut b = Builder::new(&msg.header); - b.enable_compression(); - - b.start_questions()?; - for q in &msg.questions { - b.add_question(q)?; - } - - b.start_answers()?; - for r in &mut msg.answers { - b.add_resource(r)?; - } - - b.start_authorities()?; - for r in &mut msg.authorities { - b.add_resource(r)?; - } - - b.start_additionals()?; - for r in &mut msg.additionals { - b.add_resource(r)?; - } - - let got = b.finish()?; - assert_eq!( - got, - want, - "got.len()={}, want.len()={}", - got.len(), - want.len() - ); - - Ok(()) -} - -#[test] -fn test_resource_pack() -> Result<()> { - let tests = vec![ - ( - Message { - questions: vec![Question { - name: Name::new(".")?, - typ: DnsType::Aaaa, - class: DNSCLASS_INET, - }], - answers: vec![Resource { - header: ResourceHeader::default(), - body: None, - }], - ..Default::default() - }, - Error::ErrNilResourceBody, - ), - ( - Message { - questions: vec![Question { - name: Name::new(".")?, - typ: DnsType::Aaaa, - class: DNSCLASS_INET, - }], - authorities: vec![Resource { - header: ResourceHeader::default(), - body: Some(Box::::default()), - }], - ..Default::default() - }, - Error::ErrNonCanonicalName, - ), - ( - Message { - questions: vec![Question { - name: Name::new(".")?, - typ: DnsType::A, - class: DNSCLASS_INET, - }], - additionals: vec![Resource { - header: ResourceHeader::default(), - body: None, - }], - ..Default::default() - }, - Error::ErrNilResourceBody, - ), - ]; - - for (mut m, want_err) in tests { - if let Err(err) = m.pack() { - assert_eq!(err, want_err); - } else { - panic!("expected error, but got ok"); - } - } - - Ok(()) -} - -#[test] -fn test_resource_pack_length() -> Result<()> { - let mut r = Resource { - header: ResourceHeader { - name: Name::new(".")?, - typ: DnsType::A, - class: 
DNSCLASS_INET, - ..Default::default() - }, - body: Some(Box::new(AResource { a: [127, 0, 0, 2] })), - }; - - let (hb, _) = r.header.pack(vec![], &mut None, 0)?; - let buf = r.pack(vec![], &mut None, 0)?; - - let mut hdr = ResourceHeader::default(); - hdr.unpack(&buf, 0, 0)?; - - let (got, want) = (hdr.length as usize, buf.len() - hb.len()); - assert_eq!(got, want, "got hdr.Length = {got}, want = {want}"); - - Ok(()) -} - -#[test] -fn test_option_pack_unpack() -> Result<()> { - let tests = vec![ - ( - "without EDNS(0) options", - vec![ - 0x00, 0x00, 0x29, 0x10, 0x00, 0xfe, 0x00, 0x80, 0x00, 0x00, 0x00, - ], - Message { - header: Header { - rcode: RCode::FormatError, - ..Default::default() - }, - questions: vec![Question { - name: Name::new(".")?, - typ: DnsType::A, - class: DNSCLASS_INET, - }], - additionals: vec![Resource { - header: must_edns0_resource_header( - 4096, - 0xfe0 | RCode::FormatError as u32, - true, - )?, - body: Some(Box::::default()), - }], - ..Default::default() - }, - //true, - //0xfe0 | RCode::FormatError as u32, - ), - ( - "with EDNS(0) options", - vec![ - 0x00, 0x00, 0x29, 0x10, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x0c, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x02, 0x12, 0x34, - ], - Message { - header: Header { - rcode: RCode::ServerFailure, - ..Default::default() - }, - questions: vec![Question { - name: Name::new(".")?, - typ: DnsType::Aaaa, - class: DNSCLASS_INET, - }], - additionals: vec![Resource { - header: must_edns0_resource_header( - 4096, - 0xff0 | RCode::ServerFailure as u32, - false, - )?, - body: Some(Box::new(OptResource { - options: vec![ - DnsOption { - code: 12, // see RFC 7828 - data: vec![0x00, 0x00], - }, - DnsOption { - code: 11, // see RFC 7830 - data: vec![0x12, 0x34], - }, - ], - })), - }], - ..Default::default() - }, - //dnssecOK: false, - //extRCode: 0xff0 | RCodeServerFailure, - ), - ( - // Containing multiple OPT resources in a - // message is invalid, but it's necessary for - // protocol conformance testing. 
- "with multiple OPT resources", - vec![ - 0x00, 0x00, 0x29, 0x10, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x0b, 0x00, - 0x02, 0x12, 0x34, 0x00, 0x00, 0x29, 0x10, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0x06, - 0x00, 0x0c, 0x00, 0x02, 0x00, 0x00, - ], - Message { - header: Header { - rcode: RCode::NameError, - ..Default::default() - }, - questions: vec![Question { - name: Name::new(".")?, - typ: DnsType::Aaaa, - class: DNSCLASS_INET, - }], - additionals: vec![ - Resource { - header: must_edns0_resource_header( - 4096, - 0xff0 | RCode::NameError as u32, - false, - )?, - body: Some(Box::new(OptResource { - options: vec![DnsOption { - code: 11, // see RFC 7830 - data: vec![0x12, 0x34], - }], - })), - }, - Resource { - header: must_edns0_resource_header( - 4096, - 0xff0 | RCode::NameError as u32, - false, - )?, - body: Some(Box::new(OptResource { - options: vec![DnsOption { - code: 12, // see RFC 7828 - data: vec![0x00, 0x00], - }], - })), - }, - ], - ..Default::default() - }, - ), - ]; - - for (_tt_name, tt_w, mut tt_m) in tests { - let w = tt_m.pack()?; - - assert_eq!(&w[w.len() - tt_w.len()..], &tt_w[..]); - - let mut m = Message::default(); - m.unpack(&w)?; - - let ms: Vec = m.additionals.iter().map(|s| s.to_string()).collect(); - let tt_ms: Vec = tt_m.additionals.iter().map(|s| s.to_string()).collect(); - assert_eq!(ms, tt_ms); - } - - Ok(()) -} diff --git a/mdns/src/message/mod.rs b/mdns/src/message/mod.rs deleted file mode 100644 index ae1e35351..000000000 --- a/mdns/src/message/mod.rs +++ /dev/null @@ -1,349 +0,0 @@ -#[cfg(test)] -mod message_test; - -pub mod builder; -pub mod header; -pub mod name; -mod packer; -pub mod parser; -pub mod question; -pub mod resource; - -use std::collections::HashMap; -use std::fmt; - -use header::*; -use packer::*; -use parser::*; -use question::*; -use resource::*; - -use crate::error::*; - -// Message formats - -// A Type is a type of DNS request and response. -#[derive(Default, Copy, Clone, Debug, PartialEq, Eq)] -pub enum DnsType { - // ResourceHeader.Type and question.Type - A = 1, - Ns = 2, - Cname = 5, - Soa = 6, - Ptr = 12, - Mx = 15, - Txt = 16, - Aaaa = 28, - Srv = 33, - Opt = 41, - - // question.Type - Wks = 11, - Hinfo = 13, - Minfo = 14, - Axfr = 252, - All = 255, - - #[default] - Unsupported = 0, -} - -impl From for DnsType { - fn from(v: u16) -> Self { - match v { - 1 => DnsType::A, - 2 => DnsType::Ns, - 5 => DnsType::Cname, - 6 => DnsType::Soa, - 12 => DnsType::Ptr, - 15 => DnsType::Mx, - 16 => DnsType::Txt, - 28 => DnsType::Aaaa, - 33 => DnsType::Srv, - 41 => DnsType::Opt, - - // question.Type - 11 => DnsType::Wks, - 13 => DnsType::Hinfo, - 14 => DnsType::Minfo, - 252 => DnsType::Axfr, - 255 => DnsType::All, - - _ => DnsType::Unsupported, - } - } -} - -impl fmt::Display for DnsType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - DnsType::A => "A", - DnsType::Ns => "NS", - DnsType::Cname => "CNAME", - DnsType::Soa => "SOA", - DnsType::Ptr => "PTR", - DnsType::Mx => "MX", - DnsType::Txt => "TXT", - DnsType::Aaaa => "AAAA", - DnsType::Srv => "SRV", - DnsType::Opt => "OPT", - DnsType::Wks => "WKS", - DnsType::Hinfo => "HINFO", - DnsType::Minfo => "MINFO", - DnsType::Axfr => "AXFR", - DnsType::All => "ALL", - _ => "Unsupported", - }; - write!(f, "{s}") - } -} - -impl DnsType { - // pack_type appends the wire format of field to msg. 
- pub(crate) fn pack(&self, msg: Vec) -> Vec { - pack_uint16(msg, *self as u16) - } - - pub(crate) fn unpack(&mut self, msg: &[u8], off: usize) -> Result { - let (t, o) = unpack_uint16(msg, off)?; - *self = DnsType::from(t); - Ok(o) - } - - pub(crate) fn skip(msg: &[u8], off: usize) -> Result { - skip_uint16(msg, off) - } -} - -// A Class is a type of network. -#[derive(Default, Copy, Clone, Debug, PartialEq, Eq)] -pub struct DnsClass(pub u16); - -// ResourceHeader.Class and question.Class -pub const DNSCLASS_INET: DnsClass = DnsClass(1); -pub const DNSCLASS_CSNET: DnsClass = DnsClass(2); -pub const DNSCLASS_CHAOS: DnsClass = DnsClass(3); -pub const DNSCLASS_HESIOD: DnsClass = DnsClass(4); -// question.Class -pub const DNSCLASS_ANY: DnsClass = DnsClass(255); - -impl fmt::Display for DnsClass { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let other = format!("{}", self.0); - let s = match *self { - DNSCLASS_INET => "ClassINET", - DNSCLASS_CSNET => "ClassCSNET", - DNSCLASS_CHAOS => "ClassCHAOS", - DNSCLASS_HESIOD => "ClassHESIOD", - DNSCLASS_ANY => "ClassANY", - _ => other.as_str(), - }; - write!(f, "{s}") - } -} - -impl DnsClass { - // pack_class appends the wire format of field to msg. - pub(crate) fn pack(&self, msg: Vec) -> Vec { - pack_uint16(msg, self.0) - } - - pub(crate) fn unpack(&mut self, msg: &[u8], off: usize) -> Result { - let (c, o) = unpack_uint16(msg, off)?; - *self = DnsClass(c); - Ok(o) - } - - pub(crate) fn skip(msg: &[u8], off: usize) -> Result { - skip_uint16(msg, off) - } -} - -// An OpCode is a DNS operation code. -pub type OpCode = u16; - -// An RCode is a DNS response status code. -#[derive(Default, Copy, Clone, Debug, PartialEq, Eq)] -pub enum RCode { - // Message.Rcode - #[default] - Success = 0, - FormatError = 1, - ServerFailure = 2, - NameError = 3, - NotImplemented = 4, - Refused = 5, - Unsupported, -} - -impl From for RCode { - fn from(v: u8) -> Self { - match v { - 0 => RCode::Success, - 1 => RCode::FormatError, - 2 => RCode::ServerFailure, - 3 => RCode::NameError, - 4 => RCode::NotImplemented, - 5 => RCode::Refused, - _ => RCode::Unsupported, - } - } -} - -impl fmt::Display for RCode { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - RCode::Success => "RCodeSuccess", - RCode::FormatError => "RCodeFormatError", - RCode::ServerFailure => "RCodeServerFailure", - RCode::NameError => "RCodeNameError", - RCode::NotImplemented => "RCodeNotImplemented", - RCode::Refused => "RCodeRefused", - RCode::Unsupported => "RCodeUnsupported", - }; - write!(f, "{s}") - } -} - -// Internal constants. - -// PACK_STARTING_CAP is the default initial buffer size allocated during -// packing. -// -// The starting capacity doesn't matter too much, but most DNS responses -// Will be <= 512 bytes as it is the limit for DNS over UDP. -const PACK_STARTING_CAP: usize = 512; - -// UINT16LEN is the length (in bytes) of a uint16. -const UINT16LEN: usize = 2; - -// UINT32LEN is the length (in bytes) of a uint32. -const UINT32LEN: usize = 4; - -// HEADER_LEN is the length (in bytes) of a DNS header. -// -// A header is comprised of 6 uint16s and no padding. -const HEADER_LEN: usize = 6 * UINT16LEN; - -const HEADER_BIT_QR: u16 = 1 << 15; // query/response (response=1) -const HEADER_BIT_AA: u16 = 1 << 10; // authoritative -const HEADER_BIT_TC: u16 = 1 << 9; // truncated -const HEADER_BIT_RD: u16 = 1 << 8; // recursion desired -const HEADER_BIT_RA: u16 = 1 << 7; // recursion available - -// Message is a representation of a DNS message. 
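// Illustrative round trip (a sketch mirroring message_test.rs above, not
// part of the original source; the name is arbitrary):
//
//   let mut msg = Message {
//       questions: vec![Question {
//           name: Name::new("example.local.")?,
//           typ: DnsType::A,
//           class: DNSCLASS_INET,
//       }],
//       ..Default::default()
//   };
//   let wire = msg.pack()?;
//   let mut parsed = Message::default();
//   parsed.unpack(&wire)?;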
-#[derive(Default, Debug)] -pub struct Message { - pub header: Header, - pub questions: Vec, - pub answers: Vec, - pub authorities: Vec, - pub additionals: Vec, -} - -impl fmt::Display for Message { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut s = "dnsmessage.Message{Header: ".to_owned(); - s += self.header.to_string().as_str(); - - s += ", Questions: "; - let v: Vec = self.questions.iter().map(|q| q.to_string()).collect(); - s += &v.join(", "); - - s += ", Answers: "; - let v: Vec = self.answers.iter().map(|q| q.to_string()).collect(); - s += &v.join(", "); - - s += ", Authorities: "; - let v: Vec = self.authorities.iter().map(|q| q.to_string()).collect(); - s += &v.join(", "); - - s += ", Additionals: "; - let v: Vec = self.additionals.iter().map(|q| q.to_string()).collect(); - s += &v.join(", "); - - write!(f, "{s}") - } -} - -impl Message { - // Unpack parses a full Message. - pub fn unpack(&mut self, msg: &[u8]) -> Result<()> { - let mut p = Parser::default(); - self.header = p.start(msg)?; - self.questions = p.all_questions()?; - self.answers = p.all_answers()?; - self.authorities = p.all_authorities()?; - self.additionals = p.all_additionals()?; - Ok(()) - } - - // Pack packs a full Message. - pub fn pack(&mut self) -> Result> { - self.append_pack(vec![]) - } - - // append_pack is like Pack but appends the full Message to b and returns the - // extended buffer. - pub fn append_pack(&mut self, b: Vec) -> Result> { - // Validate the lengths. It is very unlikely that anyone will try to - // pack more than 65535 of any particular type, but it is possible and - // we should fail gracefully. - if self.questions.len() > u16::MAX as usize { - return Err(Error::ErrTooManyQuestions); - } - if self.answers.len() > u16::MAX as usize { - return Err(Error::ErrTooManyAnswers); - } - if self.authorities.len() > u16::MAX as usize { - return Err(Error::ErrTooManyAuthorities); - } - if self.additionals.len() > u16::MAX as usize { - return Err(Error::ErrTooManyAdditionals); - } - - let (id, bits) = self.header.pack(); - - let questions = self.questions.len() as u16; - let answers = self.answers.len() as u16; - let authorities = self.authorities.len() as u16; - let additionals = self.additionals.len() as u16; - - let h = HeaderInternal { - id, - bits, - questions, - answers, - authorities, - additionals, - }; - - let compression_off = b.len(); - let mut msg = h.pack(b); - - // RFC 1035 allows (but does not require) compression for packing. RFC - // 1035 requires unpacking implementations to support compression, so - // unconditionally enabling it is fine. - // - // DNS lookups are typically done over UDP, and RFC 1035 states that UDP - // DNS messages can be a maximum of 512 bytes long. Without compression, - // many DNS response messages are over this limit, so enabling - // compression will help ensure compliance. 
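        // (Descriptive note: compression_off is the index where the DNS
        // header begins in `b`, so pointer offsets computed in Name::pack
        // remain relative to the start of the message even when the caller
        // prepended its own bytes to the buffer.)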
- let mut compression = Some(HashMap::new()); - - for question in &self.questions { - msg = question.pack(msg, &mut compression, compression_off)?; - } - for answer in &mut self.answers { - msg = answer.pack(msg, &mut compression, compression_off)?; - } - for authority in &mut self.authorities { - msg = authority.pack(msg, &mut compression, compression_off)?; - } - for additional in &mut self.additionals { - msg = additional.pack(msg, &mut compression, compression_off)?; - } - - Ok(msg) - } -} diff --git a/mdns/src/message/name.rs b/mdns/src/message/name.rs deleted file mode 100644 index f813fbd0f..000000000 --- a/mdns/src/message/name.rs +++ /dev/null @@ -1,237 +0,0 @@ -use std::collections::HashMap; -use std::fmt; - -use crate::error::*; - -const NAME_LEN: usize = 255; - -// A Name is a non-encoded domain name. It is used instead of strings to avoid -// allocations. -#[derive(Default, PartialEq, Eq, Debug, Clone)] -pub struct Name { - pub data: String, -} - -// String implements fmt.Stringer.String. -impl fmt::Display for Name { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.data) - } -} - -impl Name { - pub fn new(data: &str) -> Result { - if data.len() > NAME_LEN { - Err(Error::ErrCalcLen) - } else { - Ok(Name { - data: data.to_owned(), - }) - } - } - - // pack appends the wire format of the Name to msg. - // - // Domain names are a sequence of counted strings split at the dots. They end - // with a zero-length string. Compression can be used to reuse domain suffixes. - // - // The compression map will be updated with new domain suffixes. If compression - // is nil, compression will not be used. - pub fn pack( - &self, - mut msg: Vec, - compression: &mut Option>, - compression_off: usize, - ) -> Result> { - let data = self.data.as_bytes(); - - // Add a trailing dot to canonicalize name. - if data.is_empty() || data[data.len() - 1] != b'.' { - return Err(Error::ErrNonCanonicalName); - } - - // Allow root domain. - if data.len() == 1 && data[0] == b'.' { - msg.push(0); - return Ok(msg); - } - - // Emit sequence of counted strings, chopping at dots. - let mut begin = 0; - for i in 0..data.len() { - // Check for the end of the segment. - if data[i] == b'.' { - // The two most significant bits have special meaning. - // It isn't allowed for segments to be long enough to - // need them. - if i - begin >= (1 << 6) { - return Err(Error::ErrSegTooLong); - } - - // Segments must have a non-zero length. - if i - begin == 0 { - return Err(Error::ErrZeroSegLen); - } - - msg.push((i - begin) as u8); - msg.extend_from_slice(&data[begin..i]); - - begin = i + 1; - continue; - } - - // We can only compress domain suffixes starting with a new - // segment. A pointer is two bytes with the two most significant - // bits set to 1 to indicate that it is a pointer. - if i == 0 || data[i - 1] == b'.' { - if let Some(compression) = compression { - let key: String = self.data.chars().skip(i).collect(); - if let Some(ptr) = compression.get(&key) { - // Hit. Emit a pointer instead of the rest of - // the domain. - msg.push(((ptr >> 8) | 0xC0) as u8); - msg.push((ptr & 0xFF) as u8); - return Ok(msg); - } - - // Miss. Add the suffix to the compression table if the - // offset can be stored in the available 14 bytes. - if msg.len() <= 0x3FFF { - compression.insert(key, msg.len() - compression_off); - } - } - } - } - - msg.push(0); - Ok(msg) - } - - // unpack unpacks a domain name. 
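    // Wire format reminder: a name is a sequence of length-prefixed labels
    // ("\x03foo\x07example\x03com\x00" for "foo.example.com.") ending in a
    // zero byte; a byte whose top two bits are set (0xC0) instead starts a
    // two-byte compression pointer back into the message, which is why the
    // loop below counts pointer hops and rejects more than 10.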
- pub fn unpack(&mut self, msg: &[u8], off: usize) -> Result { - self.unpack_compressed(msg, off, true /* allowCompression */) - } - - pub fn unpack_compressed( - &mut self, - msg: &[u8], - off: usize, - allow_compression: bool, - ) -> Result { - // curr_off is the current working offset. - let mut curr_off = off; - - // new_off is the offset where the next record will start. Pointers lead - // to data that belongs to other names and thus doesn't count towards to - // the usage of this name. - let mut new_off = off; - - // ptr is the number of pointers followed. - let mut ptr = 0; - - // Name is a slice representation of the name data. - let mut name = String::new(); //n.Data[:0] - - loop { - if curr_off >= msg.len() { - return Err(Error::ErrBaseLen); - } - let c = msg[curr_off]; - curr_off += 1; - match c & 0xC0 { - 0x00 => { - // String segment - if c == 0x00 { - // A zero length signals the end of the name. - break; - } - let end_off = curr_off + c as usize; - if end_off > msg.len() { - return Err(Error::ErrCalcLen); - } - name.push_str(String::from_utf8(msg[curr_off..end_off].to_vec())?.as_str()); - name.push('.'); - curr_off = end_off; - } - 0xC0 => { - // Pointer - if !allow_compression { - return Err(Error::ErrCompressedSrv); - } - if curr_off >= msg.len() { - return Err(Error::ErrInvalidPtr); - } - let c1 = msg[curr_off]; - curr_off += 1; - if ptr == 0 { - new_off = curr_off; - } - // Don't follow too many pointers, maybe there's a loop. - ptr += 1; - if ptr > 10 { - return Err(Error::ErrTooManyPtr); - } - curr_off = ((c ^ 0xC0) as usize) << 8 | (c1 as usize); - } - _ => { - // Prefixes 0x80 and 0x40 are reserved. - return Err(Error::ErrReserved); - } - } - } - if name.is_empty() { - name.push('.'); - } - if name.len() > NAME_LEN { - return Err(Error::ErrCalcLen); - } - self.data = name; - if ptr == 0 { - new_off = curr_off; - } - Ok(new_off) - } - - pub(crate) fn skip(msg: &[u8], off: usize) -> Result { - // new_off is the offset where the next record will start. Pointers lead - // to data that belongs to other names and thus doesn't count towards to - // the usage of this name. - let mut new_off = off; - - loop { - if new_off >= msg.len() { - return Err(Error::ErrBaseLen); - } - let c = msg[new_off]; - new_off += 1; - match c & 0xC0 { - 0x00 => { - if c == 0x00 { - // A zero length signals the end of the name. - break; - } - // literal string - new_off += c as usize; - if new_off > msg.len() { - return Err(Error::ErrCalcLen); - } - } - 0xC0 => { - // Pointer to somewhere else in msg. - - // Pointers are two bytes. - new_off += 1; - - // Don't follow the pointer as the data here has ended. - break; - } - _ => { - // Prefixes 0x80 and 0x40 are reserved. - return Err(Error::ErrReserved); - } - } - } - - Ok(new_off) - } -} diff --git a/mdns/src/message/packer.rs b/mdns/src/message/packer.rs deleted file mode 100644 index 76c617cc5..000000000 --- a/mdns/src/message/packer.rs +++ /dev/null @@ -1,92 +0,0 @@ -use super::*; -use crate::error::*; - -// pack_bytes appends the wire format of field to msg. -pub(crate) fn pack_bytes(mut msg: Vec, field: &[u8]) -> Vec { - msg.extend_from_slice(field); - msg -} - -pub(crate) fn unpack_bytes(msg: &[u8], off: usize, field: &mut [u8]) -> Result { - let new_off = off + field.len(); - if new_off > msg.len() { - return Err(Error::ErrBaseLen); - } - field.copy_from_slice(&msg[off..new_off]); - Ok(new_off) -} - -// pack_uint16 appends the wire format of field to msg. 
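// Descriptive note: all multi-byte integers here are packed in network byte
// order (big endian), e.g. 0x1234u16 becomes the bytes [0x12, 0x34].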
-pub(crate) fn pack_uint16(mut msg: Vec, field: u16) -> Vec { - msg.extend_from_slice(&field.to_be_bytes()); - msg -} - -pub(crate) fn unpack_uint16(msg: &[u8], off: usize) -> Result<(u16, usize)> { - if off + UINT16LEN > msg.len() { - return Err(Error::ErrBaseLen); - } - - Ok(( - (msg[off] as u16) << 8 | (msg[off + 1] as u16), - off + UINT16LEN, - )) -} - -pub(crate) fn skip_uint16(msg: &[u8], off: usize) -> Result { - if off + UINT16LEN > msg.len() { - return Err(Error::ErrBaseLen); - } - Ok(off + UINT16LEN) -} - -// pack_uint32 appends the wire format of field to msg. -pub(crate) fn pack_uint32(mut msg: Vec, field: u32) -> Vec { - msg.extend_from_slice(&field.to_be_bytes()); - msg -} - -pub(crate) fn unpack_uint32(msg: &[u8], off: usize) -> Result<(u32, usize)> { - if off + UINT32LEN > msg.len() { - return Err(Error::ErrBaseLen); - } - let v = (msg[off] as u32) << 24 - | (msg[off + 1] as u32) << 16 - | (msg[off + 2] as u32) << 8 - | (msg[off + 3] as u32); - Ok((v, off + UINT32LEN)) -} - -pub(crate) fn skip_uint32(msg: &[u8], off: usize) -> Result { - if off + UINT32LEN > msg.len() { - return Err(Error::ErrBaseLen); - } - Ok(off + UINT32LEN) -} - -// pack_text appends the wire format of field to msg. -pub(crate) fn pack_str(mut msg: Vec, field: &str) -> Result> { - let l = field.len(); - if l > 255 { - return Err(Error::ErrStringTooLong); - } - msg.push(l as u8); - msg.extend_from_slice(field.as_bytes()); - Ok(msg) -} - -pub(crate) fn unpack_str(msg: &[u8], off: usize) -> Result<(String, usize)> { - if off >= msg.len() { - return Err(Error::ErrBaseLen); - } - let begin_off = off + 1; - let end_off = begin_off + msg[off] as usize; - if end_off > msg.len() { - return Err(Error::ErrCalcLen); - } - - Ok(( - String::from_utf8(msg[begin_off..end_off].to_vec())?, - end_off, - )) -} diff --git a/mdns/src/message/parser.rs b/mdns/src/message/parser.rs deleted file mode 100644 index 8b120f8ee..000000000 --- a/mdns/src/message/parser.rs +++ /dev/null @@ -1,347 +0,0 @@ -use crate::error::*; -use crate::message::header::{Header, HeaderInternal, Section}; -use crate::message::name::Name; -use crate::message::question::Question; -use crate::message::resource::{unpack_resource_body, Resource, ResourceBody, ResourceHeader}; -use crate::message::{DnsClass, DnsType}; - -// A Parser allows incrementally parsing a DNS message. -// -// When parsing is started, the Header is parsed. Next, each question can be -// either parsed or skipped. Alternatively, all Questions can be skipped at -// once. When all Questions have been parsed, attempting to parse Questions -// will return (nil, nil) and attempting to skip Questions will return -// (true, nil). After all Questions have been either parsed or skipped, all -// Answers, Authorities and Additionals can be either parsed or skipped in the -// same way, and each type of Resource must be fully parsed or skipped before -// proceeding to the next type of Resource. -// -// Note that there is no requirement to fully skip or parse the message. -#[derive(Default)] -pub struct Parser<'a> { - pub msg: &'a [u8], - pub header: HeaderInternal, - - pub section: Section, - pub off: usize, - pub index: usize, - pub res_header_valid: bool, - pub res_header: ResourceHeader, -} - -impl<'a> Parser<'a> { - // start parses the header and enables the parsing of Questions. - pub fn start(&mut self, msg: &'a [u8]) -> Result
{ - *self = Parser { - msg, - ..Default::default() - }; - self.off = self.header.unpack(msg, 0)?; - self.section = Section::Questions; - Ok(self.header.header()) - } - - fn check_advance(&mut self, sec: Section) -> Result<()> { - if self.section < sec { - return Err(Error::ErrNotStarted); - } - if self.section > sec { - return Err(Error::ErrSectionDone); - } - self.res_header_valid = false; - if self.index == self.header.count(sec) as usize { - self.index = 0; - self.section = Section::from(1 + self.section as u8); - return Err(Error::ErrSectionDone); - } - Ok(()) - } - - fn resource(&mut self, sec: Section) -> Result { - let header = self.resource_header(sec)?; - self.res_header_valid = false; - let (body, off) = - unpack_resource_body(header.typ, self.msg, self.off, header.length as usize)?; - self.off = off; - self.index += 1; - Ok(Resource { - header, - body: Some(body), - }) - } - - fn resource_header(&mut self, sec: Section) -> Result { - if self.res_header_valid { - return Ok(self.res_header.clone()); - } - self.check_advance(sec)?; - let mut hdr = ResourceHeader::default(); - let off = hdr.unpack(self.msg, self.off, 0)?; - - self.res_header_valid = true; - self.res_header = hdr.clone(); - self.off = off; - Ok(hdr) - } - - fn skip_resource(&mut self, sec: Section) -> Result<()> { - if self.res_header_valid { - let new_off = self.off + self.res_header.length as usize; - if new_off > self.msg.len() { - return Err(Error::ErrResourceLen); - } - self.off = new_off; - self.res_header_valid = false; - self.index += 1; - return Ok(()); - } - self.check_advance(sec)?; - - self.off = Resource::skip(self.msg, self.off)?; - self.index += 1; - Ok(()) - } - - // question parses a single question. - pub fn question(&mut self) -> Result { - self.check_advance(Section::Questions)?; - let mut name = Name::new("")?; - let mut off = name.unpack(self.msg, self.off)?; - let mut typ = DnsType::Unsupported; - off = typ.unpack(self.msg, off)?; - let mut class = DnsClass::default(); - off = class.unpack(self.msg, off)?; - self.off = off; - self.index += 1; - Ok(Question { name, typ, class }) - } - - // all_questions parses all Questions. - pub fn all_questions(&mut self) -> Result> { - // Multiple questions are valid according to the spec, - // but servers don't actually support them. There will - // be at most one question here. - // - // Do not pre-allocate based on info in self.header, since - // the data is untrusted. - let mut qs = vec![]; - loop { - match self.question() { - Err(err) => { - if Error::ErrSectionDone == err { - return Ok(qs); - } else { - return Err(err); - } - } - Ok(q) => qs.push(q), - } - } - } - - // skip_question skips a single question. - pub fn skip_question(&mut self) -> Result<()> { - self.check_advance(Section::Questions)?; - let mut off = Name::skip(self.msg, self.off)?; - off = DnsType::skip(self.msg, off)?; - off = DnsClass::skip(self.msg, off)?; - self.off = off; - self.index += 1; - Ok(()) - } - - // skip_all_questions skips all Questions. - pub fn skip_all_questions(&mut self) -> Result<()> { - loop { - if let Err(err) = self.skip_question() { - if Error::ErrSectionDone == err { - return Ok(()); - } else { - return Err(err); - } - } - } - } - - // answer_header parses a single answer ResourceHeader. - pub fn answer_header(&mut self) -> Result { - self.resource_header(Section::Answers) - } - - // answer parses a single answer Resource. - pub fn answer(&mut self) -> Result { - self.resource(Section::Answers) - } - - // all_answers parses all answer Resources. 
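A hedged usage sketch of the incremental parser described above: the question section has to be parsed or skipped before the answer section becomes readable. The `msg` buffer is assumed to hold a complete DNS message; everything else comes from the API shown in this file.

```rust
fn dump_message(msg: &[u8]) -> Result<()> {
    let mut parser = Parser::default();
    let _header = parser.start(msg)?;

    // Questions first, then answers; each call advances the parser's section.
    for question in parser.all_questions()? {
        println!("question: {question}");
    }
    for answer in parser.all_answers()? {
        println!("answer: {answer}");
    }
    Ok(())
}
```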
- pub fn all_answers(&mut self) -> Result> { - // The most common query is for A/AAAA, which usually returns - // a handful of IPs. - // - // Pre-allocate up to a certain limit, since self.header is - // untrusted data. - let mut n = self.header.answers as usize; - if n > 20 { - n = 20 - } - let mut a = Vec::with_capacity(n); - loop { - match self.answer() { - Err(err) => { - if Error::ErrSectionDone == err { - return Ok(a); - } else { - return Err(err); - } - } - Ok(r) => a.push(r), - } - } - } - - // skip_answer skips a single answer Resource. - pub fn skip_answer(&mut self) -> Result<()> { - self.skip_resource(Section::Answers) - } - - // skip_all_answers skips all answer Resources. - pub fn skip_all_answers(&mut self) -> Result<()> { - loop { - if let Err(err) = self.skip_answer() { - if Error::ErrSectionDone == err { - return Ok(()); - } else { - return Err(err); - } - } - } - } - - // authority_header parses a single authority ResourceHeader. - pub fn authority_header(&mut self) -> Result { - self.resource_header(Section::Authorities) - } - - // authority parses a single authority Resource. - pub fn authority(&mut self) -> Result { - self.resource(Section::Authorities) - } - - // all_authorities parses all authority Resources. - pub fn all_authorities(&mut self) -> Result> { - // Authorities contains SOA in case of NXDOMAIN and friends, - // otherwise it is empty. - // - // Pre-allocate up to a certain limit, since self.header is - // untrusted data. - let mut n = self.header.authorities as usize; - if n > 10 { - n = 10; - } - let mut a = Vec::with_capacity(n); - loop { - match self.authority() { - Err(err) => { - if Error::ErrSectionDone == err { - return Ok(a); - } else { - return Err(err); - } - } - Ok(r) => a.push(r), - } - } - } - - // skip_authority skips a single authority Resource. - pub fn skip_authority(&mut self) -> Result<()> { - self.skip_resource(Section::Authorities) - } - - // skip_all_authorities skips all authority Resources. - pub fn skip_all_authorities(&mut self) -> Result<()> { - loop { - if let Err(err) = self.skip_authority() { - if Error::ErrSectionDone == err { - return Ok(()); - } else { - return Err(err); - } - } - } - } - - // additional_header parses a single additional ResourceHeader. - pub fn additional_header(&mut self) -> Result { - self.resource_header(Section::Additionals) - } - - // additional parses a single additional Resource. - pub fn additional(&mut self) -> Result { - self.resource(Section::Additionals) - } - - // all_additionals parses all additional Resources. - pub fn all_additionals(&mut self) -> Result> { - // Additionals usually contain OPT, and sometimes A/AAAA - // glue records. - // - // Pre-allocate up to a certain limit, since self.header is - // untrusted data. - let mut n = self.header.additionals as usize; - if n > 10 { - n = 10; - } - let mut a = Vec::with_capacity(n); - loop { - match self.additional() { - Err(err) => { - if Error::ErrSectionDone == err { - return Ok(a); - } else { - return Err(err); - } - } - Ok(r) => a.push(r), - } - } - } - - // skip_additional skips a single additional Resource. - pub fn skip_additional(&mut self) -> Result<()> { - self.skip_resource(Section::Additionals) - } - - // skip_all_additionals skips all additional Resources. - pub fn skip_all_additionals(&mut self) -> Result<()> { - loop { - if let Err(err) = self.skip_additional() { - if Error::ErrSectionDone == err { - return Ok(()); - } else { - return Err(err); - } - } - } - } - - // resource_body parses a single resource_boy. 
- // - // One of the XXXHeader methods must have been called before calling this - // method. - pub fn resource_body(&mut self) -> Result> { - if !self.res_header_valid { - return Err(Error::ErrNotStarted); - } - let (rb, _off) = unpack_resource_body( - self.res_header.typ, - self.msg, - self.off, - self.res_header.length as usize, - )?; - self.off += self.res_header.length as usize; - self.res_header_valid = false; - self.index += 1; - Ok(rb) - } -} diff --git a/mdns/src/message/question.rs b/mdns/src/message/question.rs deleted file mode 100644 index ef2023244..000000000 --- a/mdns/src/message/question.rs +++ /dev/null @@ -1,38 +0,0 @@ -use std::collections::HashMap; -use std::fmt; - -use super::name::*; -use super::*; -use crate::error::Result; - -// A question is a DNS query. -#[derive(Default, Debug, PartialEq, Eq, Clone)] -pub struct Question { - pub name: Name, - pub typ: DnsType, - pub class: DnsClass, -} - -impl fmt::Display for Question { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "dnsmessage.question{{Name: {}, Type: {}, Class: {}}}", - self.name, self.typ, self.class - ) - } -} - -impl Question { - // pack appends the wire format of the question to msg. - pub fn pack( - &self, - mut msg: Vec, - compression: &mut Option>, - compression_off: usize, - ) -> Result> { - msg = self.name.pack(msg, compression, compression_off)?; - msg = self.typ.pack(msg); - Ok(self.class.pack(msg)) - } -} diff --git a/mdns/src/message/resource/a.rs b/mdns/src/message/resource/a.rs deleted file mode 100644 index dedb1d942..000000000 --- a/mdns/src/message/resource/a.rs +++ /dev/null @@ -1,34 +0,0 @@ -use super::*; -use crate::message::packer::*; - -// An AResource is an A Resource record. -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub struct AResource { - pub a: [u8; 4], -} - -impl fmt::Display for AResource { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "dnsmessage.AResource{{A: {:?}}}", self.a) - } -} - -impl ResourceBody for AResource { - fn real_type(&self) -> DnsType { - DnsType::A - } - - // pack appends the wire format of the AResource to msg. - fn pack( - &self, - msg: Vec, - _compression: &mut Option>, - _compression_off: usize, - ) -> Result> { - Ok(pack_bytes(msg, &self.a)) - } - - fn unpack(&mut self, msg: &[u8], off: usize, _length: usize) -> Result { - unpack_bytes(msg, off, &mut self.a) - } -} diff --git a/mdns/src/message/resource/aaaa.rs b/mdns/src/message/resource/aaaa.rs deleted file mode 100644 index 6a23da84a..000000000 --- a/mdns/src/message/resource/aaaa.rs +++ /dev/null @@ -1,34 +0,0 @@ -use super::*; -use crate::message::packer::*; - -// An AAAAResource is an aaaa Resource record. -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub struct AaaaResource { - pub aaaa: [u8; 16], -} - -impl fmt::Display for AaaaResource { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "dnsmessage.AAAAResource{{aaaa: {:?}}}", self.aaaa) - } -} - -impl ResourceBody for AaaaResource { - fn real_type(&self) -> DnsType { - DnsType::Aaaa - } - - // pack appends the wire format of the AAAAResource to msg. 
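Since `AResource` and `AaaaResource` store raw address bytes, a caller will typically convert them into the standard library address types. A small sketch, assuming nothing beyond the two public fields shown above:

```rust
use std::net::{Ipv4Addr, Ipv6Addr};

// The fixed-size arrays map directly onto std address types.
fn to_ip_addrs(a: &AResource, aaaa: &AaaaResource) -> (Ipv4Addr, Ipv6Addr) {
    (Ipv4Addr::from(a.a), Ipv6Addr::from(aaaa.aaaa))
}
```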
- fn pack( - &self, - msg: Vec, - _compression: &mut Option>, - _compression_off: usize, - ) -> Result> { - Ok(pack_bytes(msg, &self.aaaa)) - } - - fn unpack(&mut self, msg: &[u8], off: usize, _length: usize) -> Result { - unpack_bytes(msg, off, &mut self.aaaa) - } -} diff --git a/mdns/src/message/resource/cname.rs b/mdns/src/message/resource/cname.rs deleted file mode 100644 index b2e2281dd..000000000 --- a/mdns/src/message/resource/cname.rs +++ /dev/null @@ -1,34 +0,0 @@ -use super::*; -use crate::message::name::*; - -// A cnameresource is a cname Resource record. -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub struct CnameResource { - pub cname: Name, -} - -impl fmt::Display for CnameResource { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "dnsmessage.cnameresource{{cname: {}}}", self.cname) - } -} - -impl ResourceBody for CnameResource { - fn real_type(&self) -> DnsType { - DnsType::Cname - } - - // pack appends the wire format of the cnameresource to msg. - fn pack( - &self, - msg: Vec, - compression: &mut Option>, - compression_off: usize, - ) -> Result> { - self.cname.pack(msg, compression, compression_off) - } - - fn unpack(&mut self, msg: &[u8], off: usize, _length: usize) -> Result { - self.cname.unpack(msg, off) - } -} diff --git a/mdns/src/message/resource/mod.rs b/mdns/src/message/resource/mod.rs deleted file mode 100644 index 6ca5e57bd..000000000 --- a/mdns/src/message/resource/mod.rs +++ /dev/null @@ -1,273 +0,0 @@ -pub mod a; -pub mod aaaa; -pub mod cname; -pub mod mx; -pub mod ns; -pub mod opt; -pub mod ptr; -pub mod soa; -pub mod srv; -pub mod txt; - -use std::collections::HashMap; -use std::fmt; - -use a::*; -use aaaa::*; -use cname::*; -use mx::*; -use ns::*; -use opt::*; -use ptr::*; -use soa::*; -use srv::*; -use txt::*; - -use super::name::*; -use super::packer::*; -use super::*; -use crate::error::*; - -// EDNS(0) wire constants. - -const EDNS0_VERSION: u32 = 0; -const EDNS0_DNSSEC_OK: u32 = 0x00008000; -const EDNS_VERSION_MASK: u32 = 0x00ff0000; -const EDNS0_DNSSEC_OK_MASK: u32 = 0x00ff8000; - -// A Resource is a DNS resource record. -#[derive(Default, Debug)] -pub struct Resource { - pub header: ResourceHeader, - pub body: Option>, -} - -impl fmt::Display for Resource { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "dnsmessage.Resource{{Header: {}, Body: {}}}", - self.header, - if let Some(body) = &self.body { - body.to_string() - } else { - "None".to_owned() - } - ) - } -} - -impl Resource { - // pack appends the wire format of the Resource to msg. 
- pub fn pack( - &mut self, - msg: Vec, - compression: &mut Option>, - compression_off: usize, - ) -> Result> { - if let Some(body) = &self.body { - self.header.typ = body.real_type(); - } else { - return Err(Error::ErrNilResourceBody); - } - let (mut msg, len_off) = self.header.pack(msg, compression, compression_off)?; - let pre_len = msg.len(); - if let Some(body) = &self.body { - msg = body.pack(msg, compression, compression_off)?; - self.header.fix_len(&mut msg, len_off, pre_len)?; - } - Ok(msg) - } - - pub fn unpack(&mut self, msg: &[u8], mut off: usize) -> Result { - off = self.header.unpack(msg, off, 0)?; - let (rb, off) = - unpack_resource_body(self.header.typ, msg, off, self.header.length as usize)?; - self.body = Some(rb); - Ok(off) - } - - pub(crate) fn skip(msg: &[u8], off: usize) -> Result { - let mut new_off = Name::skip(msg, off)?; - new_off = DnsType::skip(msg, new_off)?; - new_off = DnsClass::skip(msg, new_off)?; - new_off = skip_uint32(msg, new_off)?; - let (length, mut new_off) = unpack_uint16(msg, new_off)?; - new_off += length as usize; - if new_off > msg.len() { - return Err(Error::ErrResourceLen); - } - Ok(new_off) - } -} - -// A ResourceHeader is the header of a DNS resource record. There are -// many types of DNS resource records, but they all share the same header. -#[derive(Clone, Default, PartialEq, Eq, Debug)] -pub struct ResourceHeader { - // Name is the domain name for which this resource record pertains. - pub name: Name, - - // Type is the type of DNS resource record. - // - // This field will be set automatically during packing. - pub typ: DnsType, - - // Class is the class of network to which this DNS resource record - // pertains. - pub class: DnsClass, - - // TTL is the length of time (measured in seconds) which this resource - // record is valid for (time to live). All Resources in a set should - // have the same TTL (RFC 2181 Section 5.2). - pub ttl: u32, - - // Length is the length of data in the resource record after the header. - // - // This field will be set automatically during packing. - pub length: u16, -} - -impl fmt::Display for ResourceHeader { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "dnsmessage.ResourceHeader{{Name: {}, Type: {}, Class: {}, TTL: {}, Length: {}}}", - self.name, self.typ, self.class, self.ttl, self.length, - ) - } -} - -impl ResourceHeader { - // pack appends the wire format of the ResourceHeader to oldMsg. - // - // lenOff is the offset in msg where the Length field was packed. - pub fn pack( - &self, - mut msg: Vec, - compression: &mut Option>, - compression_off: usize, - ) -> Result<(Vec, usize)> { - msg = self.name.pack(msg, compression, compression_off)?; - msg = self.typ.pack(msg); - msg = self.class.pack(msg); - msg = pack_uint32(msg, self.ttl); - let len_off = msg.len(); - msg = pack_uint16(msg, self.length); - Ok((msg, len_off)) - } - - pub fn unpack(&mut self, msg: &[u8], off: usize, _length: usize) -> Result { - let mut new_off = off; - new_off = self.name.unpack(msg, new_off)?; - new_off = self.typ.unpack(msg, new_off)?; - new_off = self.class.unpack(msg, new_off)?; - let (ttl, new_off) = unpack_uint32(msg, new_off)?; - self.ttl = ttl; - let (l, new_off) = unpack_uint16(msg, new_off)?; - self.length = l; - - Ok(new_off) - } - - // fixLen updates a packed ResourceHeader to include the length of the - // ResourceBody. - // - // lenOff is the offset of the ResourceHeader.Length field in msg. - // - // preLen is the length that msg was before the ResourceBody was packed. 
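Putting `Resource`, `ResourceHeader`, and a body type together, a record can be packed roughly as follows. This is a sketch with made-up values; `DnsClass(1)` (the INTERNET class) assumes the tuple-struct constructor used elsewhere in this crate.

```rust
use std::collections::HashMap;

// Assemble and pack a single A record; `typ` and `length` are filled in
// automatically by `pack()`.
fn pack_a_record() -> Result<Vec<u8>> {
    let mut record = Resource {
        header: ResourceHeader {
            name: Name::new("host.local.")?,
            class: DnsClass(1), // IN class (assumed constructor)
            ttl: 120,
            ..Default::default()
        },
        body: Some(Box::new(AResource { a: [192, 168, 1, 10] })),
    };
    let mut compression = Some(HashMap::new());
    record.pack(Vec::new(), &mut compression, 0)
}
```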
- pub fn fix_len(&mut self, msg: &mut [u8], len_off: usize, pre_len: usize) -> Result<()> { - if msg.len() < pre_len || msg.len() > pre_len + u16::MAX as usize { - return Err(Error::ErrResTooLong); - } - - let con_len = msg.len() - pre_len; - - // Fill in the length now that we know how long the content is. - msg[len_off] = ((con_len >> 8) & 0xFF) as u8; - msg[len_off + 1] = (con_len & 0xFF) as u8; - self.length = con_len as u16; - - Ok(()) - } - - // set_edns0 configures h for EDNS(0). - // - // The provided ext_rcode must be an extended RCode. - pub fn set_edns0( - &mut self, - udp_payload_len: u16, - ext_rcode: u32, - dnssec_ok: bool, - ) -> Result<()> { - self.name = Name { - data: ".".to_owned(), - }; // RFC 6891 section 6.1.2 - self.typ = DnsType::Opt; - self.class = DnsClass(udp_payload_len); - self.ttl = (ext_rcode >> 4) << 24; - if dnssec_ok { - self.ttl |= EDNS0_DNSSEC_OK; - } - Ok(()) - } - - // dnssec_allowed reports whether the DNSSEC OK bit is set. - pub fn dnssec_allowed(&self) -> bool { - self.ttl & EDNS0_DNSSEC_OK_MASK == EDNS0_DNSSEC_OK // RFC 6891 section 6.1.3 - } - - // extended_rcode returns an extended RCode. - // - // The provided rcode must be the RCode in DNS message header. - pub fn extended_rcode(&self, rcode: RCode) -> RCode { - if self.ttl & EDNS_VERSION_MASK == EDNS0_VERSION { - // RFC 6891 section 6.1.3 - let ttl = ((self.ttl >> 24) << 4) as u8 | rcode as u8; - return RCode::from(ttl); - } - rcode - } -} - -// A ResourceBody is a DNS resource record minus the header. -pub trait ResourceBody: fmt::Display + fmt::Debug { - // real_type returns the actual type of the Resource. This is used to - // fill in the header Type field. - fn real_type(&self) -> DnsType; - - // pack packs a Resource except for its header. - fn pack( - &self, - msg: Vec, - compression: &mut Option>, - compression_off: usize, - ) -> Result>; - - fn unpack(&mut self, msg: &[u8], off: usize, length: usize) -> Result; -} - -pub fn unpack_resource_body( - typ: DnsType, - msg: &[u8], - mut off: usize, - length: usize, -) -> Result<(Box, usize)> { - let mut rb: Box = match typ { - DnsType::A => Box::::default(), - DnsType::Ns => Box::::default(), - DnsType::Cname => Box::::default(), - DnsType::Soa => Box::::default(), - DnsType::Ptr => Box::::default(), - DnsType::Mx => Box::::default(), - DnsType::Txt => Box::::default(), - DnsType::Aaaa => Box::::default(), - DnsType::Srv => Box::::default(), - DnsType::Opt => Box::::default(), - _ => return Err(Error::ErrNilResourceBody), - }; - - off = rb.unpack(msg, off, length)?; - - Ok((rb, off)) -} diff --git a/mdns/src/message/resource/mx.rs b/mdns/src/message/resource/mx.rs deleted file mode 100644 index b9bcfdfff..000000000 --- a/mdns/src/message/resource/mx.rs +++ /dev/null @@ -1,45 +0,0 @@ -use super::*; -use crate::error::Result; -use crate::message::name::*; -use crate::message::packer::*; - -// An MXResource is an mx Resource record. -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub struct MxResource { - pub pref: u16, - pub mx: Name, -} - -impl fmt::Display for MxResource { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "dnsmessage.MXResource{{pref: {}, mx: {}}}", - self.pref, self.mx - ) - } -} - -impl ResourceBody for MxResource { - fn real_type(&self) -> DnsType { - DnsType::Mx - } - - // pack appends the wire format of the MXResource to msg. 
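The EDNS(0) helpers on `ResourceHeader` can be exercised as in the sketch below; the payload size and flag values are illustrative, not taken from any caller.

```rust
// Build an OPT pseudo-record header advertising a 1232-byte UDP payload,
// extended RCODE 0, and the DNSSEC OK bit.
fn build_opt_header() -> Result<ResourceHeader> {
    let mut opt = ResourceHeader::default();
    opt.set_edns0(1232, 0, true)?;
    assert_eq!(opt.typ, DnsType::Opt);
    assert!(opt.dnssec_allowed());
    Ok(opt)
}
```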
- fn pack( - &self, - mut msg: Vec, - compression: &mut Option>, - compression_off: usize, - ) -> Result> { - msg = pack_uint16(msg, self.pref); - msg = self.mx.pack(msg, compression, compression_off)?; - Ok(msg) - } - - fn unpack(&mut self, msg: &[u8], off: usize, _length: usize) -> Result { - let (pref, off) = unpack_uint16(msg, off)?; - self.pref = pref; - self.mx.unpack(msg, off) - } -} diff --git a/mdns/src/message/resource/ns.rs b/mdns/src/message/resource/ns.rs deleted file mode 100644 index bf2819429..000000000 --- a/mdns/src/message/resource/ns.rs +++ /dev/null @@ -1,35 +0,0 @@ -use super::*; -use crate::error::Result; -use crate::message::name::*; - -// An NSResource is an NS Resource record. -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub struct NsResource { - pub ns: Name, -} - -impl fmt::Display for NsResource { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "dnsmessage.NSResource{{NS: {}}}", self.ns) - } -} - -impl ResourceBody for NsResource { - fn real_type(&self) -> DnsType { - DnsType::Ns - } - - // pack appends the wire format of the NSResource to msg. - fn pack( - &self, - msg: Vec, - compression: &mut Option>, - compression_off: usize, - ) -> Result> { - self.ns.pack(msg, compression, compression_off) - } - - fn unpack(&mut self, msg: &[u8], off: usize, _txt_length: usize) -> Result { - self.ns.unpack(msg, off) - } -} diff --git a/mdns/src/message/resource/opt.rs b/mdns/src/message/resource/opt.rs deleted file mode 100644 index 70d0f9e1f..000000000 --- a/mdns/src/message/resource/opt.rs +++ /dev/null @@ -1,84 +0,0 @@ -use super::*; -use crate::error::{Result, *}; -use crate::message::packer::*; - -// An OPTResource is an OPT pseudo Resource record. -// -// The pseudo resource record is part of the extension mechanisms for DNS -// as defined in RFC 6891. -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub struct OptResource { - pub options: Vec, -} - -// An Option represents a DNS message option within OPTResource. -// -// The message option is part of the extension mechanisms for DNS as -// defined in RFC 6891. 
-#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub struct DnsOption { - pub code: u16, // option code - pub data: Vec, -} - -impl fmt::Display for DnsOption { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "dnsmessage.Option{{Code: {}, Data: {:?}}}", - self.code, self.data - ) - } -} - -impl fmt::Display for OptResource { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s: Vec = self.options.iter().map(|o| o.to_string()).collect(); - write!(f, "dnsmessage.OPTResource{{options: {}}}", s.join(",")) - } -} - -impl ResourceBody for OptResource { - fn real_type(&self) -> DnsType { - DnsType::Opt - } - - fn pack( - &self, - mut msg: Vec, - _compression: &mut Option>, - _compression_off: usize, - ) -> Result> { - for opt in &self.options { - msg = pack_uint16(msg, opt.code); - msg = pack_uint16(msg, opt.data.len() as u16); - msg = pack_bytes(msg, &opt.data); - } - Ok(msg) - } - - fn unpack(&mut self, msg: &[u8], mut off: usize, length: usize) -> Result { - let mut opts = vec![]; - let old_off = off; - while off < old_off + length { - let (code, new_off) = unpack_uint16(msg, off)?; - off = new_off; - - let (l, new_off) = unpack_uint16(msg, off)?; - off = new_off; - - let mut opt = DnsOption { - code, - data: vec![0; l as usize], - }; - if off + l as usize > msg.len() { - return Err(Error::ErrCalcLen); - } - opt.data.copy_from_slice(&msg[off..off + l as usize]); - off += l as usize; - opts.push(opt); - } - self.options = opts; - Ok(off) - } -} diff --git a/mdns/src/message/resource/ptr.rs b/mdns/src/message/resource/ptr.rs deleted file mode 100644 index 24d427500..000000000 --- a/mdns/src/message/resource/ptr.rs +++ /dev/null @@ -1,35 +0,0 @@ -use super::*; -use crate::error::Result; -use crate::message::name::*; - -// A PTRResource is a PTR Resource record. -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub struct PtrResource { - pub ptr: Name, -} - -impl fmt::Display for PtrResource { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "dnsmessage.PTRResource{{PTR: {}}}", self.ptr) - } -} - -impl ResourceBody for PtrResource { - fn real_type(&self) -> DnsType { - DnsType::Ptr - } - - // pack appends the wire format of the PTRResource to msg. - fn pack( - &self, - msg: Vec, - compression: &mut Option>, - compression_off: usize, - ) -> Result> { - self.ptr.pack(msg, compression, compression_off) - } - - fn unpack(&mut self, msg: &[u8], off: usize, _length: usize) -> Result { - self.ptr.unpack(msg, off) - } -} diff --git a/mdns/src/message/resource/soa.rs b/mdns/src/message/resource/soa.rs deleted file mode 100644 index ea7d92873..000000000 --- a/mdns/src/message/resource/soa.rs +++ /dev/null @@ -1,80 +0,0 @@ -use super::*; -use crate::error::Result; -use crate::message::name::*; -use crate::message::packer::*; - -// An SOAResource is an SOA Resource record. -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub struct SoaResource { - pub ns: Name, - pub mbox: Name, - pub serial: u32, - pub refresh: u32, - pub retry: u32, - pub expire: u32, - - // min_ttl the is the default TTL of Resources records which did not - // contain a TTL value and the TTL of negative responses. 
(RFC 2308 - // Section 4) - pub min_ttl: u32, -} - -impl fmt::Display for SoaResource { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "dnsmessage.SOAResource{{ns: {}, mbox: {}, serial: {}, refresh: {}, retry: {}, expire: {}, min_ttl: {}}}", - self.ns, - self.mbox, - self.serial, - self.refresh, - self.retry, - self.expire, - self.min_ttl, - ) - } -} - -impl ResourceBody for SoaResource { - fn real_type(&self) -> DnsType { - DnsType::Soa - } - - // pack appends the wire format of the SOAResource to msg. - fn pack( - &self, - mut msg: Vec, - compression: &mut Option>, - compression_off: usize, - ) -> Result> { - msg = self.ns.pack(msg, compression, compression_off)?; - msg = self.mbox.pack(msg, compression, compression_off)?; - msg = pack_uint32(msg, self.serial); - msg = pack_uint32(msg, self.refresh); - msg = pack_uint32(msg, self.retry); - msg = pack_uint32(msg, self.expire); - Ok(pack_uint32(msg, self.min_ttl)) - } - - fn unpack(&mut self, msg: &[u8], mut off: usize, _length: usize) -> Result { - off = self.ns.unpack(msg, off)?; - off = self.mbox.unpack(msg, off)?; - - let (serial, off) = unpack_uint32(msg, off)?; - self.serial = serial; - - let (refresh, off) = unpack_uint32(msg, off)?; - self.refresh = refresh; - - let (retry, off) = unpack_uint32(msg, off)?; - self.retry = retry; - - let (expire, off) = unpack_uint32(msg, off)?; - self.expire = expire; - - let (min_ttl, off) = unpack_uint32(msg, off)?; - self.min_ttl = min_ttl; - - Ok(off) - } -} diff --git a/mdns/src/message/resource/srv.rs b/mdns/src/message/resource/srv.rs deleted file mode 100644 index 5299bb4fe..000000000 --- a/mdns/src/message/resource/srv.rs +++ /dev/null @@ -1,60 +0,0 @@ -use super::*; -use crate::error::Result; -use crate::message::name::*; -use crate::message::packer::*; - -// An SRVResource is an SRV Resource record. -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub struct SrvResource { - pub priority: u16, - pub weight: u16, - pub port: u16, - pub target: Name, // Not compressed as per RFC 2782. -} - -impl fmt::Display for SrvResource { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "dnsmessage.SRVResource{{priority: {}, weight: {}, port: {}, target: {}}}", - self.priority, self.weight, self.port, self.target - ) - } -} - -impl ResourceBody for SrvResource { - fn real_type(&self) -> DnsType { - DnsType::Srv - } - - // pack appends the wire format of the SRVResource to msg. - fn pack( - &self, - mut msg: Vec, - _compression: &mut Option>, - compression_off: usize, - ) -> Result> { - msg = pack_uint16(msg, self.priority); - msg = pack_uint16(msg, self.weight); - msg = pack_uint16(msg, self.port); - msg = self.target.pack(msg, &mut None, compression_off)?; - Ok(msg) - } - - fn unpack(&mut self, msg: &[u8], off: usize, _length: usize) -> Result { - let (priority, off) = unpack_uint16(msg, off)?; - self.priority = priority; - - let (weight, off) = unpack_uint16(msg, off)?; - self.weight = weight; - - let (port, off) = unpack_uint16(msg, off)?; - self.port = port; - - let off = self - .target - .unpack_compressed(msg, off, false /* allowCompression */)?; - - Ok(off) - } -} diff --git a/mdns/src/message/resource/txt.rs b/mdns/src/message/resource/txt.rs deleted file mode 100644 index 64888ff84..000000000 --- a/mdns/src/message/resource/txt.rs +++ /dev/null @@ -1,56 +0,0 @@ -use super::*; -use crate::error::*; -use crate::message::packer::*; - -// A TXTResource is a txt Resource record. 
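TXT data is a run of length-prefixed character-strings (see `pack_str` above), so packing the `TxtResource` defined just below produces the bytes shown in this sketch; the strings themselves are made up.

```rust
fn pack_txt_strings() -> Result<Vec<u8>> {
    let txt = TxtResource {
        txt: vec!["v=1".to_owned(), "ok".to_owned()],
    };
    // Name compression does not apply to TXT data, hence the `None` table.
    let msg = txt.pack(Vec::new(), &mut None, 0)?;
    assert_eq!(msg, b"\x03v=1\x02ok".to_vec());
    Ok(msg)
}
```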
-#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub struct TxtResource { - pub txt: Vec, -} - -impl fmt::Display for TxtResource { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if self.txt.is_empty() { - write!(f, "dnsmessage.TXTResource{{txt: {{}}}}",) - } else { - write!(f, "dnsmessage.TXTResource{{txt: {{{}}}", self.txt.join(",")) - } - } -} - -impl ResourceBody for TxtResource { - fn real_type(&self) -> DnsType { - DnsType::Txt - } - - // pack appends the wire format of the TXTResource to msg. - fn pack( - &self, - mut msg: Vec, - _compression: &mut Option>, - _compression_off: usize, - ) -> Result> { - for s in &self.txt { - msg = pack_str(msg, s)? - } - Ok(msg) - } - - fn unpack(&mut self, msg: &[u8], mut off: usize, length: usize) -> Result { - let mut txts = vec![]; - let mut n = 0; - while n < length { - let (t, new_off) = unpack_str(msg, off)?; - off = new_off; - // Check if we got too many bytes. - if length < n + t.as_bytes().len() + 1 { - return Err(Error::ErrCalcLen); - } - n += t.len() + 1; - txts.push(t); - } - self.txt = txts; - - Ok(off) - } -} diff --git a/media/.gitignore b/media/.gitignore deleted file mode 100644 index 81561ed32..000000000 --- a/media/.gitignore +++ /dev/null @@ -1,11 +0,0 @@ -# Generated by Cargo -# will have compiled files and executables -/target/ -/.idea/ - -# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries -# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html -Cargo.lock - -# These are backup files generated by rustfmt -**/*.rs.bk diff --git a/media/CHANGELOG.md b/media/CHANGELOG.md deleted file mode 100644 index 0f56fba7f..000000000 --- a/media/CHANGELOG.md +++ /dev/null @@ -1,23 +0,0 @@ -# webrtc-media changelog - -## Unreleased - -## v0.5.0 - -* Improve handling of padding packets in `SampleBuiler`. Prior to this `SampleBuilder` would sometimes, incorrectly, drop packets that carry media when they appeared adjacent to runs of padding packets. Contributed by [@k0nserv](https://github.com/k0nserv) in [#309](https://github.com/webrtc-rs/webrtc/pull/309) -* Increased minimum support rust version to `1.60.0`. -* Increased required `webrtc-util` version to `0.7.0`. - -### Breaking - -* Introduced a new field in `Sample`, `prev_padding_packets`, that reflects the number of observed padding only packets while building the Sample. This can be use to differentiate inconsequential padding packets being dropped from those carrying media. Contributed by [@k0nserv](https://github.com/k0nserv) in [#303](https://github.com/webrtc-rs/webrtc/pull/303). - -## v0.4.7 - -* Bumped util dependency to `0.6.0`. -* Bumped rtp dependency to `0.6.0`. - - -## Prior to 0.4.7 - -Before 0.4.7 there was no changelog, previous changes are sometimes, but not always, available in the [GitHub Releases](https://github.com/webrtc-rs/media/releases). 
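As context for the `prev_padding_packets` entry above, a consumer might inspect the field roughly as follows; this is a hedged sketch, and the `webrtc_media::Sample` import path is assumed rather than stated in the changelog.

```rust
use webrtc_media::Sample;

// Padding-only packets carry no media; a non-zero count here means they were
// observed while the sample was assembled, not that media was lost.
fn log_padding(sample: &Sample) {
    if sample.prev_padding_packets > 0 {
        println!(
            "sample assembled after {} padding-only packets",
            sample.prev_padding_packets
        );
    }
}
```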
diff --git a/media/Cargo.toml b/media/Cargo.toml deleted file mode 100644 index b89f3d5b9..000000000 --- a/media/Cargo.toml +++ /dev/null @@ -1,26 +0,0 @@ -[package] -name = "webrtc-media" -version = "0.8.0" -authors = ["Rain Liu "] -edition = "2021" -description = "A pure Rust implementation of WebRTC Media API" -license = "MIT OR Apache-2.0" -documentation = "https://docs.rs/webrtc-media" -homepage = "https://webrtc.rs" -repository = "https://github.com/webrtc-rs/webrtc/tree/master/media" - -[dependencies] -rtp = { version = "0.11.0", path = "../rtp" } - -byteorder = "1" -bytes = "1" -thiserror = "1" -rand = "0.8" - -[dev-dependencies] -criterion = { version = "0.5", features = ["html_reports"] } -nearly_eq = "0.2" - -[[bench]] -name = "audio_buffer" -harness = false diff --git a/media/LICENSE-APACHE b/media/LICENSE-APACHE deleted file mode 100644 index b2e847a43..000000000 --- a/media/LICENSE-APACHE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - version 2.0, January 2004 - http://www.apache.org/licenses/ - -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - -1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. 
For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - -2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - -3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - -4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. 
The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - -5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - -6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - -8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. 
- -END OF TERMS AND CONDITIONS - -APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - -Copyright [yyyy] [name of copyright owner] - -Licensed under the Apache License, version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. diff --git a/media/LICENSE-MIT b/media/LICENSE-MIT deleted file mode 100644 index e11d93bef..000000000 --- a/media/LICENSE-MIT +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2021 WebRTC.rs - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/media/README.md b/media/README.md deleted file mode 100644 index ca8d135ac..000000000 --- a/media/README.md +++ /dev/null @@ -1,30 +0,0 @@ -

- WebRTC.rs
- [badges: License: MIT/Apache 2.0, Discord]
- A pure Rust implementation of WebRTC Media. Rewrite Pion MediaDevices in Rust
diff --git a/media/benches/audio_buffer.rs b/media/benches/audio_buffer.rs deleted file mode 100644 index a7164658d..000000000 --- a/media/benches/audio_buffer.rs +++ /dev/null @@ -1,36 +0,0 @@ -use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use webrtc_media::audio::buffer::layout::{Deinterleaved, Interleaved}; -use webrtc_media::audio::buffer::Buffer; - -fn benchmark_from(c: &mut Criterion) { - type Sample = i32; - let channels = 4; - let frames = 100_000; - let deinterleaved_buffer: Buffer = { - let samples = (0..(channels * frames)).map(|i| i as i32).collect(); - Buffer::new(samples, channels) - }; - let interleaved_buffer: Buffer = { - let samples = (0..(channels * frames)).map(|i| i as i32).collect(); - Buffer::new(samples, channels) - }; - - c.bench_function("Buffer => Buffer", |b| { - b.iter(|| { - black_box(Buffer::::from( - deinterleaved_buffer.as_ref(), - )); - }) - }); - - c.bench_function("Buffer => Buffer", |b| { - b.iter(|| { - black_box(Buffer::::from( - interleaved_buffer.as_ref(), - )); - }) - }); -} - -criterion_group!(benches, benchmark_from); -criterion_main!(benches); diff --git a/media/codecov.yml b/media/codecov.yml deleted file mode 100644 index 788310caa..000000000 --- a/media/codecov.yml +++ /dev/null @@ -1,23 +0,0 @@ -codecov: - require_ci_to_pass: yes - max_report_age: off - token: 87b3c77b-91fc-48bd-8560-0c9dfdb774e8 - -coverage: - precision: 2 - round: down - range: 50..90 - status: - project: - default: - enabled: no - threshold: 0.2 - if_not_found: success - patch: - default: - enabled: no - if_not_found: success - changes: - default: - enabled: no - if_not_found: success diff --git a/media/doc/webrtc.rs.png b/media/doc/webrtc.rs.png deleted file mode 100644 index 7bf0dda2a..000000000 Binary files a/media/doc/webrtc.rs.png and /dev/null differ diff --git a/media/src/audio/buffer.rs b/media/src/audio/buffer.rs deleted file mode 100644 index 07ddb804d..000000000 --- a/media/src/audio/buffer.rs +++ /dev/null @@ -1,443 +0,0 @@ -pub mod info; -pub mod layout; - -use std::mem::{ManuallyDrop, MaybeUninit}; -use std::ops::Range; - -use byteorder::ByteOrder; -pub use info::BufferInfo; -pub use layout::BufferLayout; -use layout::{Deinterleaved, Interleaved}; -use thiserror::Error; - -pub trait FromBytes: Sized { - type Error; - - fn from_bytes(bytes: &[u8], channels: usize) -> Result; -} - -pub trait ToByteBufferRef: Sized { - type Error; - - fn bytes_len(&self); - fn to_bytes( - &self, - bytes: &mut [u8], - channels: usize, - ) -> Result; -} - -#[derive(Debug, Error, PartialEq, Eq)] -pub enum Error { - #[error("Unexpected end of buffer: (expected: {expected}, actual: {actual})")] - UnexpectedEndOfBuffer { expected: usize, actual: usize }, -} - -#[derive(Eq, PartialEq, Clone, Debug)] -pub struct BufferRef<'a, T, L> { - samples: &'a [T], - info: BufferInfo, -} - -impl<'a, T, L> BufferRef<'a, T, L> { - pub fn new(samples: &'a [T], channels: usize) -> Self { - debug_assert_eq!(samples.len() % channels, 0); - let info = { - let frames = samples.len() / channels; - BufferInfo::new(channels, frames) - }; - Self { samples, info } - } -} - -/// Buffer multi-channel interlaced Audio. 
-#[derive(Eq, PartialEq, Clone, Debug)] -pub struct Buffer { - samples: Vec, - info: BufferInfo, -} - -impl Buffer { - pub fn new(samples: Vec, channels: usize) -> Self { - debug_assert_eq!(samples.len() % channels, 0); - let info = { - let frames = samples.len() / channels; - BufferInfo::new(channels, frames) - }; - Self { samples, info } - } - - pub fn as_ref(&'_ self) -> BufferRef<'_, T, L> { - BufferRef { - samples: &self.samples[..], - info: self.info, - } - } - - pub fn sub_range(&'_ self, range: Range) -> BufferRef<'_, T, L> { - let samples_len = range.len(); - let samples = &self.samples[range]; - let info = { - let channels = self.info.channels(); - assert_eq!(samples_len % channels, 0); - let frames = samples_len / channels; - BufferInfo::new(channels, frames) - }; - BufferRef { samples, info } - } -} - -impl From> for Buffer -where - T: Default + Copy, -{ - fn from(buffer: Buffer) -> Self { - Self::from(buffer.as_ref()) - } -} - -impl<'a, T> From> for Buffer -where - T: Default + Copy, -{ - fn from(buffer: BufferRef<'a, T, Deinterleaved>) -> Self { - // Writing into a vec of uninitialized `samples` is about 10% faster than - // cloning it or creating a default-initialized one and over-writing it. - // - // # Safety - // - // The performance boost comes with a cost though: - // At the end of the block each and every single item in - // `samples` needs to have been initialized, or else you get UB! - let samples = { - // Create a vec of uninitialized samples. - let mut samples: Vec> = - vec![MaybeUninit::uninit(); buffer.samples.len()]; - - // Initialize all of its values: - layout::interleaved_by( - buffer.samples, - &mut samples[..], - buffer.info.channels(), - |sample| MaybeUninit::new(*sample), - ); - - // Transmute the vec to the initialized type. - unsafe { std::mem::transmute::>, Vec>(samples) } - }; - - let info = buffer.info.into(); - Self { samples, info } - } -} - -impl From> for Buffer -where - T: Default + Copy, -{ - fn from(buffer: Buffer) -> Self { - Self::from(buffer.as_ref()) - } -} - -impl<'a, T> From> for Buffer -where - T: Default + Copy, -{ - fn from(buffer: BufferRef<'a, T, Interleaved>) -> Self { - // Writing into a vec of uninitialized `samples` is about 10% faster than - // cloning it or creating a default-initialized one and over-writing it. - // - // # Safety - // - // The performance boost comes with a cost though: - // At the end of the block each and every single item in - // `samples` needs to have been initialized, or else you get UB! - let samples = { - // Create a vec of uninitialized samples. - let mut samples: Vec> = - vec![MaybeUninit::uninit(); buffer.samples.len()]; - - // Initialize the vec's values: - layout::deinterleaved_by( - buffer.samples, - &mut samples[..], - buffer.info.channels(), - |sample| MaybeUninit::new(*sample), - ); - - // Everything is initialized. Transmute the vec to the initialized type. 
- unsafe { std::mem::transmute::>, Vec>(samples) } - }; - - let info = buffer.info.into(); - Self { samples, info } - } -} - -impl FromBytes for Buffer { - type Error = (); - - fn from_bytes(bytes: &[u8], channels: usize) -> Result { - const STRIDE: usize = std::mem::size_of::(); - assert_eq!(bytes.len() % STRIDE, 0); - - let chunks = { - let chunks_ptr = bytes.as_ptr() as *const [u8; STRIDE]; - let chunks_len = bytes.len() / STRIDE; - unsafe { std::slice::from_raw_parts(chunks_ptr, chunks_len) } - }; - - let samples: Vec<_> = chunks.iter().map(|chunk| B::read_i16(&chunk[..])).collect(); - - let info = { - let frames = samples.len() / channels; - BufferInfo::new(channels, frames) - }; - Ok(Self { samples, info }) - } -} - -impl FromBytes for Buffer { - type Error = (); - - fn from_bytes(bytes: &[u8], channels: usize) -> Result { - const STRIDE: usize = std::mem::size_of::(); - assert_eq!(bytes.len() % STRIDE, 0); - - let chunks = { - let chunks_ptr = bytes.as_ptr() as *const [u8; STRIDE]; - let chunks_len = bytes.len() / STRIDE; - unsafe { std::slice::from_raw_parts(chunks_ptr, chunks_len) } - }; - - // Writing into a vec of uninitialized `samples` is about 10% faster than - // cloning it or creating a default-initialized one and over-writing it. - // - // # Safety - // - // The performance boost comes with a cost though: - // At the end of the block each and every single item in - // `samples` needs to have been initialized, or else you get UB! - let samples = unsafe { - init_vec(chunks.len(), |samples| { - layout::interleaved_by(chunks, samples, channels, |chunk| { - MaybeUninit::new(B::read_i16(&chunk[..])) - }); - }) - }; - - let info = { - let frames = samples.len() / channels; - BufferInfo::new(channels, frames) - }; - Ok(Self { samples, info }) - } -} - -impl FromBytes for Buffer { - type Error = (); - - fn from_bytes(bytes: &[u8], channels: usize) -> Result { - const STRIDE: usize = std::mem::size_of::(); - assert_eq!(bytes.len() % STRIDE, 0); - - let chunks = { - let chunks_ptr = bytes.as_ptr() as *const [u8; STRIDE]; - let chunks_len = bytes.len() / STRIDE; - unsafe { std::slice::from_raw_parts(chunks_ptr, chunks_len) } - }; - - let samples: Vec<_> = chunks.iter().map(|chunk| B::read_i16(&chunk[..])).collect(); - - let info = { - let frames = samples.len() / channels; - BufferInfo::new(channels, frames) - }; - Ok(Self { samples, info }) - } -} - -impl FromBytes for Buffer { - type Error = (); - - fn from_bytes(bytes: &[u8], channels: usize) -> Result { - const STRIDE: usize = std::mem::size_of::(); - assert_eq!(bytes.len() % STRIDE, 0); - - let chunks = { - let chunks_ptr = bytes.as_ptr() as *const [u8; STRIDE]; - let chunks_len = bytes.len() / STRIDE; - unsafe { std::slice::from_raw_parts(chunks_ptr, chunks_len) } - }; - - // Writing into a vec of uninitialized `samples` is about 10% faster than - // cloning it or creating a default-initialized one and over-writing it. - // - // # Safety - // - // The performance boost comes with a cost though: - // At the end of the block each and every single item in - // `samples` needs to have been initialized, or else you get UB! - let samples = unsafe { - init_vec(chunks.len(), |samples| { - layout::deinterleaved_by(chunks, samples, channels, |chunk| { - MaybeUninit::new(B::read_i16(&chunk[..])) - }); - }) - }; - - let info = { - let frames = samples.len() / channels; - BufferInfo::new(channels, frames) - }; - Ok(Self { samples, info }) - } -} - -/// Creates a vec with deferred initialization. 
-/// -/// # Safety -/// -/// The closure `f` MUST initialize every single item in the provided slice. -unsafe fn init_vec(len: usize, f: F) -> Vec -where - MaybeUninit: Clone, - F: FnOnce(&mut [MaybeUninit]), -{ - // Create a vec of uninitialized values. - let mut vec: Vec> = vec![MaybeUninit::uninit(); len]; - - // Initialize values: - f(&mut vec[..]); - - // Take owner-ship away from `vec`: - let mut manually_drop: ManuallyDrop<_> = ManuallyDrop::new(vec); - - // Create vec of proper type from `vec`'s raw parts. - let ptr = manually_drop.as_mut_ptr() as *mut T; - let len = manually_drop.len(); - let cap = manually_drop.capacity(); - Vec::from_raw_parts(ptr, len, cap) -} - -#[cfg(test)] -mod tests { - use byteorder::NativeEndian; - - use super::*; - - #[test] - fn deinterleaved_from_interleaved() { - let channels = 3; - - let input_samples: Vec = vec![0, 5, 10, 1, 6, 11, 2, 7, 12, 3, 8, 13, 4, 9, 14]; - let input: Buffer = Buffer::new(input_samples, channels); - - let output = Buffer::::from(input); - - let actual = output.samples; - let expected = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]; - - assert_eq!(actual, expected); - } - - #[test] - fn interleaved_from_deinterleaved() { - let channels = 3; - - let input_samples: Vec = vec![0, 3, 6, 9, 12, 1, 4, 7, 10, 13, 2, 5, 8, 11, 14]; - let input: Buffer = Buffer::new(input_samples, channels); - - let output = Buffer::::from(input); - - let actual = output.samples; - let expected = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]; - - assert_eq!(actual, expected); - } - - #[test] - fn deinterleaved_from_deinterleaved_bytes() { - let channels = 3; - let stride = 2; - - let input_samples: Vec = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]; - let input_bytes: &[u8] = { - let bytes_ptr = input_samples.as_ptr() as *const u8; - let bytes_len = input_samples.len() * stride; - unsafe { std::slice::from_raw_parts(bytes_ptr, bytes_len) } - }; - - let output: Buffer = - FromBytes::::from_bytes::(input_bytes, channels).unwrap(); - - let actual = output.samples; - let expected = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]; - - assert_eq!(actual, expected); - } - - #[test] - fn deinterleaved_from_interleaved_bytes() { - let channels = 3; - let stride = 2; - - let input_samples: Vec = vec![0, 5, 10, 1, 6, 11, 2, 7, 12, 3, 8, 13, 4, 9, 14]; - let input_bytes: &[u8] = { - let bytes_ptr = input_samples.as_ptr() as *const u8; - let bytes_len = input_samples.len() * stride; - unsafe { std::slice::from_raw_parts(bytes_ptr, bytes_len) } - }; - - let output: Buffer = - FromBytes::::from_bytes::(input_bytes, channels).unwrap(); - - let actual = output.samples; - let expected = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]; - - assert_eq!(actual, expected); - } - - #[test] - fn interleaved_from_interleaved_bytes() { - let channels = 3; - let stride = 2; - - let input_samples: Vec = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]; - let input_bytes: &[u8] = { - let bytes_ptr = input_samples.as_ptr() as *const u8; - let bytes_len = input_samples.len() * stride; - unsafe { std::slice::from_raw_parts(bytes_ptr, bytes_len) } - }; - - let output: Buffer = - FromBytes::::from_bytes::(input_bytes, channels).unwrap(); - - let actual = output.samples; - let expected = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]; - - assert_eq!(actual, expected); - } - - #[test] - fn interleaved_from_deinterleaved_bytes() { - let channels = 3; - let stride = 2; - - let input_samples: Vec = vec![0, 3, 6, 9, 12, 1, 4, 7, 10, 13, 
2, 5, 8, 11, 14]; - let input_bytes: &[u8] = { - let bytes_ptr = input_samples.as_ptr() as *const u8; - let bytes_len = input_samples.len() * stride; - unsafe { std::slice::from_raw_parts(bytes_ptr, bytes_len) } - }; - - let output: Buffer = - FromBytes::::from_bytes::(input_bytes, channels).unwrap(); - - let actual = output.samples; - let expected = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]; - - assert_eq!(actual, expected); - } -} diff --git a/media/src/audio/buffer/info.rs b/media/src/audio/buffer/info.rs deleted file mode 100644 index bd70e12de..000000000 --- a/media/src/audio/buffer/info.rs +++ /dev/null @@ -1,118 +0,0 @@ -use std::marker::PhantomData; - -use crate::audio::buffer::layout::{Deinterleaved, Interleaved}; - -#[derive(Eq, PartialEq, Debug)] -pub struct BufferInfo { - channels: usize, - frames: usize, - _phantom: PhantomData, -} - -impl BufferInfo { - pub fn new(channels: usize, frames: usize) -> Self { - Self { - channels, - frames, - _phantom: PhantomData, - } - } - - /// Get a reference to the buffer info's channels. - pub fn channels(&self) -> usize { - self.channels - } - - /// Set the buffer info's channels. - pub fn set_channels(&mut self, channels: usize) { - self.channels = channels; - } - - /// Get a reference to the buffer info's frames. - pub fn frames(&self) -> usize { - self.frames - } - - /// Set the buffer info's frames. - pub fn set_frames(&mut self, frames: usize) { - self.frames = frames; - } - - pub fn samples(&self) -> usize { - self.channels * self.frames - } -} - -impl Copy for BufferInfo {} - -impl Clone for BufferInfo { - fn clone(&self) -> Self { - *self - } -} - -macro_rules! impl_from_buffer_info { - ($in_layout:ty => $out_layout:ty) => { - impl From> for BufferInfo<$out_layout> { - fn from(info: BufferInfo<$in_layout>) -> Self { - Self { - channels: info.channels, - frames: info.frames, - _phantom: PhantomData, - } - } - } - }; -} - -impl_from_buffer_info!(Interleaved => Deinterleaved); -impl_from_buffer_info!(Deinterleaved => Interleaved); - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn new() { - let channels = 3; - let frames = 100; - - let interleaved = BufferInfo::::new(channels, frames); - - assert_eq!(interleaved.channels, channels); - assert_eq!(interleaved.frames, frames); - - let deinterleaved = BufferInfo::::new(channels, frames); - - assert_eq!(deinterleaved.channels, channels); - assert_eq!(deinterleaved.frames, frames); - } - - #[test] - fn clone() { - let channels = 3; - let frames = 100; - - let interleaved = BufferInfo::::new(channels, frames); - - assert_eq!(interleaved.clone(), interleaved); - - let deinterleaved = BufferInfo::::new(channels, frames); - - assert_eq!(deinterleaved.clone(), deinterleaved); - } - - #[test] - fn samples() { - let channels = 3; - let frames = 100; - - let interleaved = BufferInfo::::new(channels, frames); - - assert_eq!(interleaved.samples(), channels * frames); - - let deinterleaved = BufferInfo::::new(channels, frames); - - assert_eq!(deinterleaved.samples(), channels * frames); - } -} diff --git a/media/src/audio/buffer/layout.rs b/media/src/audio/buffer/layout.rs deleted file mode 100644 index d26f1fe78..000000000 --- a/media/src/audio/buffer/layout.rs +++ /dev/null @@ -1,179 +0,0 @@ -use crate::audio::buffer::BufferInfo; -use crate::audio::sealed::Sealed; - -pub trait BufferLayout: Sized + Sealed { - fn index_of(info: &BufferInfo, channel: usize, frame: usize) -> usize; -} - -#[derive(Eq, PartialEq, Copy, Clone, Debug)] -pub enum Deinterleaved {} - -impl Sealed for 
Deinterleaved {} - -impl BufferLayout for Deinterleaved { - #[inline] - fn index_of(info: &BufferInfo, channel: usize, frame: usize) -> usize { - (channel * info.frames()) + frame - } -} - -#[derive(Eq, PartialEq, Copy, Clone, Debug)] -pub enum Interleaved {} - -impl Sealed for Interleaved {} - -impl BufferLayout for Interleaved { - #[inline] - fn index_of(info: &BufferInfo, channel: usize, frame: usize) -> usize { - (frame * info.channels()) + channel - } -} - -#[cfg(test)] -#[inline(always)] -pub(crate) fn deinterleaved(input: &[T], output: &mut [T], channels: usize) -where - T: Copy, -{ - deinterleaved_by(input, output, channels, |sample| *sample) -} - -/// De-interleaves an interleaved slice using a memory access pattern -/// that's optimized for efficient cached (i.e. sequential) reads. -pub(crate) fn deinterleaved_by(input: &[T], output: &mut [U], channels: usize, f: F) -where - F: Fn(&T) -> U, -{ - assert_eq!(input.len(), output.len()); - assert_eq!(input.len() % channels, 0); - - let frames = input.len() / channels; - let mut interleaved_index = 0; - for frame in 0..frames { - let mut deinterleaved_index = frame; - for _channel in 0..channels { - output[deinterleaved_index] = f(&input[interleaved_index]); - interleaved_index += 1; - deinterleaved_index += frames; - } - } -} - -#[cfg(test)] -#[inline(always)] -pub(crate) fn interleaved(input: &[T], output: &mut [T], channels: usize) -where - T: Copy, -{ - interleaved_by(input, output, channels, |sample| *sample) -} - -/// Interleaves an de-interleaved slice using a memory access pattern -/// that's optimized for efficient cached (i.e. sequential) reads. -pub(crate) fn interleaved_by(input: &[T], output: &mut [U], channels: usize, f: F) -where - F: Fn(&T) -> U, -{ - assert_eq!(input.len(), output.len()); - assert_eq!(input.len() % channels, 0); - - let frames = input.len() / channels; - let mut deinterleaved_index = 0; - for channel in 0..channels { - let mut interleaved_index = channel; - for _frame in 0..frames { - output[interleaved_index] = f(&input[deinterleaved_index]); - deinterleaved_index += 1; - interleaved_index += channels; - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn interleaved_1_channel() { - let input: Vec<_> = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]; - let mut output = vec![0; input.len()]; - let channels = 1; - - interleaved(&input[..], &mut output[..], channels); - - let actual = output; - let expected = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]; - - assert_eq!(actual, expected); - } - - #[test] - fn deinterleaved_1_channel() { - let input: Vec<_> = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]; - let mut output = vec![0; input.len()]; - let channels = 1; - - deinterleaved(&input[..], &mut output[..], channels); - - let actual = output; - let expected = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]; - - assert_eq!(actual, expected); - } - - #[test] - fn interleaved_2_channel() { - let input: Vec<_> = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]; - let mut output = vec![0; input.len()]; - let channels = 2; - - interleaved(&input[..], &mut output[..], channels); - - let actual = output; - let expected = vec![0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15]; - - assert_eq!(actual, expected); - } - - #[test] - fn deinterleaved_2_channel() { - let input: Vec<_> = vec![0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15]; - let mut output = vec![0; input.len()]; - let channels = 2; - - deinterleaved(&input[..], 
&mut output[..], channels); - - let actual = output; - let expected = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]; - - assert_eq!(actual, expected); - } - - #[test] - fn interleaved_3_channel() { - let input: Vec<_> = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]; - let mut output = vec![0; input.len()]; - let channels = 3; - - interleaved(&input[..], &mut output[..], channels); - - let actual = output; - let expected = vec![0, 5, 10, 1, 6, 11, 2, 7, 12, 3, 8, 13, 4, 9, 14]; - - assert_eq!(actual, expected); - } - - #[test] - fn deinterleaved_3_channel() { - let input: Vec<_> = vec![0, 5, 10, 1, 6, 11, 2, 7, 12, 3, 8, 13, 4, 9, 14]; - let mut output = vec![0; input.len()]; - let channels = 3; - - deinterleaved(&input[..], &mut output[..], channels); - - let actual = output; - let expected = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]; - - assert_eq!(actual, expected); - } -} diff --git a/media/src/audio/mod.rs b/media/src/audio/mod.rs deleted file mode 100644 index e259ae9c4..000000000 --- a/media/src/audio/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -pub mod buffer; -mod sample; - -pub use sample::Sample; - -mod sealed { - pub trait Sealed {} -} diff --git a/media/src/audio/sample.rs b/media/src/audio/sample.rs deleted file mode 100644 index a98910a1e..000000000 --- a/media/src/audio/sample.rs +++ /dev/null @@ -1,197 +0,0 @@ -use std::io::{Cursor, Read}; - -use byteorder::{ByteOrder, ReadBytesExt}; -#[cfg(test)] -use nearly_eq::NearlyEq; - -#[derive(Eq, PartialEq, Copy, Clone, Default, Debug)] -#[repr(transparent)] -pub struct Sample(Raw); - -impl From for Sample { - #[inline] - fn from(raw: i16) -> Self { - Self(raw) - } -} - -impl From for Sample { - #[inline] - fn from(raw: f32) -> Self { - Self(raw.clamp(-1.0, 1.0)) - } -} - -macro_rules! 
impl_from_sample_for_raw { - ($raw:ty) => { - impl From> for $raw { - #[inline] - fn from(sample: Sample<$raw>) -> $raw { - sample.0 - } - } - }; -} - -impl_from_sample_for_raw!(i16); -impl_from_sample_for_raw!(f32); - -// impl From> for Sample { -// #[inline] -// fn from(sample: Sample) -> Self { -// // Fast but imprecise approach: -// // Perform crude but fast upsample by bit-shifting the raw value: -// Self::from((sample.0 as i64) << 16) - -// // Slow but precise approach: -// // Perform a proper but expensive lerp from -// // i16::MIN..i16::MAX to i32::MIN..i32::MAX: - -// // let value = sample.0 as i64; - -// // let from = if value <= 0 { i16::MIN } else { i16::MAX } as i64; -// // let to = if value <= 0 { i32::MIN } else { i32::MAX } as i64; - -// // Self::from((value * to + from / 2) / from) -// } -// } - -impl From> for Sample { - #[inline] - fn from(sample: Sample) -> Self { - let divisor = if sample.0 < 0 { - i16::MIN as f32 - } else { - i16::MAX as f32 - } - .abs(); - Self::from((sample.0 as f32) / divisor) - } -} - -impl From> for Sample { - #[inline] - fn from(sample: Sample) -> Self { - let multiplier = if sample.0 < 0.0 { - i16::MIN as f32 - } else { - i16::MAX as f32 - } - .abs(); - Self::from((sample.0 * multiplier) as i16) - } -} - -trait FromBytes: Sized { - fn from_reader(reader: &mut R) -> Result; - - fn from_bytes(bytes: &[u8]) -> Result { - let mut cursor = Cursor::new(bytes); - Self::from_reader::(&mut cursor) - } -} - -impl FromBytes for Sample { - fn from_reader(reader: &mut R) -> Result { - reader.read_i16::().map(Self::from) - } -} - -impl FromBytes for Sample { - fn from_reader(reader: &mut R) -> Result { - reader.read_f32::().map(Self::from) - } -} - -#[cfg(test)] -impl NearlyEq for Sample -where - Raw: NearlyEq, -{ - fn eps() -> Raw { - Raw::eps() - } - - fn eq(&self, other: &Self, eps: &Raw) -> bool { - NearlyEq::eq(&self.0, &other.0, eps) - } -} - -#[cfg(test)] -mod tests { - use nearly_eq::assert_nearly_eq; - - use super::*; - - #[test] - fn sample_i16_from_i16() { - // i16: - assert_eq!(Sample::::from(i16::MIN).0, i16::MIN); - assert_eq!(Sample::::from(i16::MIN / 2).0, i16::MIN / 2); - assert_eq!(Sample::::from(0).0, 0); - assert_eq!(Sample::::from(i16::MAX / 2).0, i16::MAX / 2); - assert_eq!(Sample::::from(i16::MAX).0, i16::MAX); - } - - #[test] - fn sample_f32_from_f32() { - assert_eq!(Sample::::from(-1.0).0, -1.0); - assert_eq!(Sample::::from(-0.5).0, -0.5); - assert_eq!(Sample::::from(0.0).0, 0.0); - assert_eq!(Sample::::from(0.5).0, 0.5); - assert_eq!(Sample::::from(1.0).0, 1.0); - - // For any values outside of -1.0..=1.0 we expect clamping: - assert_eq!(Sample::::from(f32::MIN).0, -1.0); - assert_eq!(Sample::::from(f32::MAX).0, 1.0); - } - - #[test] - fn sample_i16_from_sample_f32() { - assert_nearly_eq!( - Sample::::from(Sample::::from(-1.0)), - Sample::from(i16::MIN) - ); - assert_nearly_eq!( - Sample::::from(Sample::::from(-0.5)), - Sample::from(i16::MIN / 2) - ); - assert_nearly_eq!( - Sample::::from(Sample::::from(0.0)), - Sample::from(0) - ); - assert_nearly_eq!( - Sample::::from(Sample::::from(0.5)), - Sample::from(i16::MAX / 2) - ); - assert_nearly_eq!( - Sample::::from(Sample::::from(1.0)), - Sample::from(i16::MAX) - ); - } - - #[test] - fn sample_f32_from_sample_i16() { - assert_nearly_eq!( - Sample::::from(Sample::::from(i16::MIN)), - Sample::from(-1.0) - ); - assert_nearly_eq!( - Sample::::from(Sample::::from(i16::MIN / 2)), - Sample::from(-0.5) - ); - assert_nearly_eq!( - Sample::::from(Sample::::from(0)), - Sample::from(0.0) - ); - 
assert_nearly_eq!( - Sample::::from(Sample::::from(i16::MAX / 2)), - Sample::from(0.5), - 0.0001 // rounding error due to i16::MAX being odd - ); - assert_nearly_eq!( - Sample::::from(Sample::::from(i16::MAX)), - Sample::from(1.0) - ); - } -} diff --git a/media/src/error.rs b/media/src/error.rs deleted file mode 100644 index 3c5272a4a..000000000 --- a/media/src/error.rs +++ /dev/null @@ -1,71 +0,0 @@ -use std::io; - -use thiserror::Error; - -pub type Result = std::result::Result; - -#[derive(Error, Debug, PartialEq)] -#[non_exhaustive] -pub enum Error { - #[error("stream is nil")] - ErrNilStream, - #[error("incomplete frame header")] - ErrIncompleteFrameHeader, - #[error("incomplete frame data")] - ErrIncompleteFrameData, - #[error("incomplete file header")] - ErrIncompleteFileHeader, - #[error("IVF signature mismatch")] - ErrSignatureMismatch, - #[error("IVF version unknown, parser may not parse correctly")] - ErrUnknownIVFVersion, - - #[error("file not opened")] - ErrFileNotOpened, - #[error("invalid nil packet")] - ErrInvalidNilPacket, - - #[error("bad header signature")] - ErrBadIDPageSignature, - #[error("wrong header, expected beginning of stream")] - ErrBadIDPageType, - #[error("payload for id page must be 19 bytes")] - ErrBadIDPageLength, - #[error("bad payload signature")] - ErrBadIDPagePayloadSignature, - #[error("not enough data for payload header")] - ErrShortPageHeader, - #[error("expected and actual checksum do not match")] - ErrChecksumMismatch, - - #[error("data is not a H264 bitstream")] - ErrDataIsNotH264Stream, - #[error("Io EOF")] - ErrIoEOF, - - #[allow(non_camel_case_types)] - #[error("{0}")] - Io(#[source] IoError), - #[error("{0}")] - Rtp(#[from] rtp::Error), - - #[error("{0}")] - Other(String), -} - -#[derive(Debug, Error)] -#[error("io error: {0}")] -pub struct IoError(#[from] pub io::Error); - -// Workaround for wanting PartialEq for io::Error. 
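// `std::io::Error` does not implement `PartialEq`, which is why the workaround named
// in the comment above compares `ErrorKind`s instead; two distinct errors of the same
// kind therefore compare equal. A self-contained sketch of that idea (the `KindEq`
// name is illustrative and not part of the crate):
#[derive(Debug)]
struct KindEq(std::io::Error);

impl PartialEq for KindEq {
    fn eq(&self, other: &Self) -> bool {
        // Compare only the kind; the message payload is ignored.
        self.0.kind() == other.0.kind()
    }
}

fn kind_eq_demo() {
    use std::io::{Error, ErrorKind};
    let a = KindEq(Error::new(ErrorKind::UnexpectedEof, "short read"));
    let b = KindEq(Error::new(ErrorKind::UnexpectedEof, "eof while parsing header"));
    assert_eq!(a, b); // equal by kind even though the messages differ
}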
-impl PartialEq for IoError { - fn eq(&self, other: &Self) -> bool { - self.0.kind() == other.0.kind() - } -} - -impl From for Error { - fn from(e: io::Error) -> Self { - Error::Io(IoError(e)) - } -} diff --git a/media/src/io/h264_reader/h264_reader_test.rs b/media/src/io/h264_reader/h264_reader_test.rs deleted file mode 100644 index 5bbb3ff7a..000000000 --- a/media/src/io/h264_reader/h264_reader_test.rs +++ /dev/null @@ -1,106 +0,0 @@ -use std::io::Cursor; - -use super::*; - -#[test] -fn test_data_does_not_start_with_h264header() -> Result<()> { - let test_function = |input: &[u8]| { - let mut reader = H264Reader::new(Cursor::new(input), 1_048_576); - if let Err(err) = reader.next_nal() { - assert_eq!(err, Error::ErrDataIsNotH264Stream); - } else { - panic!(); - } - }; - - test_function(&[2]); - test_function(&[0, 2]); - test_function(&[0, 0, 2]); - test_function(&[0, 0, 2, 0]); - test_function(&[0, 0, 0, 2]); - - Ok(()) -} - -#[test] -fn test_parse_header() -> Result<()> { - let h264bytes = &[0x0, 0x0, 0x1, 0xAB]; - let mut reader = H264Reader::new(Cursor::new(h264bytes), 1_048_576); - - let nal = reader.next_nal()?; - - assert_eq!(nal.data.len(), 1); - assert!(nal.forbidden_zero_bit); - assert_eq!(nal.picture_order_count, 0); - assert_eq!(nal.ref_idc, 1); - assert_eq!(NalUnitType::EndOfStream, nal.unit_type); - - Ok(()) -} - -#[test] -fn test_eof() -> Result<()> { - let test_function = |input: &[u8]| { - let mut reader = H264Reader::new(Cursor::new(input), 1_048_576); - if let Err(err) = reader.next_nal() { - assert_eq!(Error::ErrIoEOF, err); - } else { - panic!(); - } - }; - - test_function(&[0, 0, 0, 1]); - test_function(&[0, 0, 1]); - test_function(&[]); - - Ok(()) -} - -#[test] -fn test_skip_sei() -> Result<()> { - let h264bytes = &[ - 0x0, 0x0, 0x0, 0x1, 0xAA, 0x0, 0x0, 0x0, 0x1, 0x6, // SEI - 0x0, 0x0, 0x0, 0x1, 0xAB, - ]; - - let mut reader = H264Reader::new(Cursor::new(h264bytes), 1_048_576); - - let nal = reader.next_nal()?; - assert_eq!(nal.data[0], 0xAA); - - let nal = reader.next_nal()?; - assert_eq!(nal.data[0], 0xAB); - - Ok(()) -} - -#[test] -fn test_issue1734_next_nal() -> Result<()> { - let tests: Vec<&[u8]> = vec![ - &[0x00, 0x00, 0x010, 0x00, 0x00, 0x01, 0x00, 0x00, 0x01], - &[0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x01], - ]; - - for test in tests { - let mut reader = H264Reader::new(Cursor::new(test), 1_048_576); - - // Just make sure it doesn't crash - while reader.next_nal().is_ok() { - //do nothing - } - } - - Ok(()) -} - -#[test] -fn test_trailing01after_start_code() -> Result<()> { - let test = vec![0x0, 0x0, 0x0, 0x1, 0x01, 0x0, 0x0, 0x0, 0x1, 0x01]; - let mut r = H264Reader::new(Cursor::new(test), 1_048_576); - - for _ in 0..=1 { - let _nal = r.next_nal()?; - } - - Ok(()) -} diff --git a/media/src/io/h264_reader/mod.rs b/media/src/io/h264_reader/mod.rs deleted file mode 100644 index 7b2b974b8..000000000 --- a/media/src/io/h264_reader/mod.rs +++ /dev/null @@ -1,334 +0,0 @@ -#[cfg(test)] -mod h264_reader_test; - -use std::fmt; -use std::io::Read; - -use bytes::{BufMut, BytesMut}; - -use crate::error::{Error, Result}; - -/// NalUnitType is the type of a NAL -/// Enums for NalUnitTypes -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] -pub enum NalUnitType { - /// Unspecified - #[default] - Unspecified = 0, - /// Coded slice of a non-IDR picture - CodedSliceNonIdr = 1, - /// Coded slice data partition A - CodedSliceDataPartitionA = 2, - /// Coded slice data partition B - CodedSliceDataPartitionB = 3, - /// Coded slice data partition C - CodedSliceDataPartitionC 
= 4, - /// Coded slice of an IDR picture - CodedSliceIdr = 5, - /// Supplemental enhancement information (SEI) - SEI = 6, - /// Sequence parameter set - SPS = 7, - /// Picture parameter set - PPS = 8, - /// Access unit delimiter - AUD = 9, - /// End of sequence - EndOfSequence = 10, - /// End of stream - EndOfStream = 11, - /// Filler data - Filler = 12, - /// Sequence parameter set extension - SpsExt = 13, - /// Coded slice of an auxiliary coded picture without partitioning - CodedSliceAux = 19, - ///Reserved - Reserved, - // 14..18 // Reserved - // 20..23 // Reserved - // 24..31 // Unspecified -} - -impl fmt::Display for NalUnitType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - NalUnitType::Unspecified => "Unspecified", - NalUnitType::CodedSliceNonIdr => "CodedSliceNonIdr", - NalUnitType::CodedSliceDataPartitionA => "CodedSliceDataPartitionA", - NalUnitType::CodedSliceDataPartitionB => "CodedSliceDataPartitionB", - NalUnitType::CodedSliceDataPartitionC => "CodedSliceDataPartitionC", - NalUnitType::CodedSliceIdr => "CodedSliceIdr", - NalUnitType::SEI => "SEI", - NalUnitType::SPS => "SPS", - NalUnitType::PPS => "PPS", - NalUnitType::AUD => "AUD", - NalUnitType::EndOfSequence => "EndOfSequence", - NalUnitType::EndOfStream => "EndOfStream", - NalUnitType::Filler => "Filler", - NalUnitType::SpsExt => "SpsExt", - NalUnitType::CodedSliceAux => "NalUnitTypeCodedSliceAux", - _ => "Reserved", - }; - write!(f, "{}({})", s, *self as u8) - } -} - -impl From for NalUnitType { - fn from(v: u8) -> Self { - match v { - 0 => NalUnitType::Unspecified, - 1 => NalUnitType::CodedSliceNonIdr, - 2 => NalUnitType::CodedSliceDataPartitionA, - 3 => NalUnitType::CodedSliceDataPartitionB, - 4 => NalUnitType::CodedSliceDataPartitionC, - 5 => NalUnitType::CodedSliceIdr, - 6 => NalUnitType::SEI, - 7 => NalUnitType::SPS, - 8 => NalUnitType::PPS, - 9 => NalUnitType::AUD, - 10 => NalUnitType::EndOfSequence, - 11 => NalUnitType::EndOfStream, - 12 => NalUnitType::Filler, - 13 => NalUnitType::SpsExt, - 19 => NalUnitType::CodedSliceAux, - _ => NalUnitType::Reserved, - } - } -} - -/// NAL H.264 Network Abstraction Layer -pub struct NAL { - pub picture_order_count: u32, - - /// NAL header - pub forbidden_zero_bit: bool, - pub ref_idc: u8, - pub unit_type: NalUnitType, - - /// header byte + rbsp - pub data: BytesMut, -} - -impl NAL { - fn new(data: BytesMut) -> Self { - NAL { - picture_order_count: 0, - forbidden_zero_bit: false, - ref_idc: 0, - unit_type: NalUnitType::Unspecified, - data, - } - } - - fn parse_header(&mut self) { - let first_byte = self.data[0]; - self.forbidden_zero_bit = ((first_byte & 0x80) >> 7) == 1; // 0x80 = 0b10000000 - self.ref_idc = (first_byte & 0x60) >> 5; // 0x60 = 0b01100000 - self.unit_type = NalUnitType::from(first_byte & 0x1F); // 0x1F = 0b00011111 - } -} - -const NAL_PREFIX_3BYTES: [u8; 3] = [0, 0, 1]; -const NAL_PREFIX_4BYTES: [u8; 4] = [0, 0, 0, 1]; - -/// Wrapper class around reading buffer -struct ReadBuffer { - buffer: Box<[u8]>, - read_end: usize, - filled_end: usize, -} - -impl ReadBuffer { - fn new(capacity: usize) -> ReadBuffer { - Self { - buffer: vec![0u8; capacity].into_boxed_slice(), - read_end: 0, - filled_end: 0, - } - } - - #[inline] - fn in_buffer(&self) -> usize { - self.filled_end - self.read_end - } - - fn consume(&mut self, consume: usize) -> &[u8] { - debug_assert!(self.read_end + consume <= self.filled_end); - let result = &self.buffer[self.read_end..][..consume]; - self.read_end += consume; - result - } - - pub(crate) fn 
fill_buffer(&mut self, reader: &mut impl Read) -> Result<()> { - debug_assert_eq!(self.read_end, self.filled_end); - - self.read_end = 0; - self.filled_end = reader.read(&mut self.buffer)?; - - Ok(()) - } -} - -/// H264Reader reads data from stream and constructs h264 nal units -pub struct H264Reader { - reader: R, - // reading buffers - buffer: ReadBuffer, - // for reading - nal_prefix_parsed: bool, - count_of_consecutive_zero_bytes: usize, - nal_buffer: BytesMut, -} - -impl H264Reader { - /// new creates new `H264Reader` with `capacity` sized read buffer. - pub fn new(reader: R, capacity: usize) -> H264Reader { - H264Reader { - reader, - nal_prefix_parsed: false, - buffer: ReadBuffer::new(capacity), - count_of_consecutive_zero_bytes: 0, - nal_buffer: BytesMut::new(), - } - } - - fn read4(&mut self) -> Result<([u8; 4], usize)> { - let mut result = [0u8; 4]; - let mut result_filled = 0; - loop { - let in_buffer = self.buffer.in_buffer(); - - if in_buffer + result_filled >= 4 { - let consume = 4 - result_filled; - result[result_filled..].copy_from_slice(self.buffer.consume(consume)); - return Ok((result, 4)); - } - - result[result_filled..][..in_buffer].copy_from_slice(self.buffer.consume(in_buffer)); - result_filled += in_buffer; - - self.buffer.fill_buffer(&mut self.reader)?; - - if self.buffer.in_buffer() == 0 { - return Ok((result, result_filled)); - } - } - } - - fn read1(&mut self) -> Result> { - if self.buffer.in_buffer() == 0 { - self.buffer.fill_buffer(&mut self.reader)?; - - if self.buffer.in_buffer() == 0 { - return Ok(None); - } - } - - Ok(Some(self.buffer.consume(1)[0])) - } - - fn bit_stream_starts_with_h264prefix(&mut self) -> Result { - let (prefix_buffer, n) = self.read4()?; - - if n == 0 { - return Err(Error::ErrIoEOF); - } - - if n < 3 { - return Err(Error::ErrDataIsNotH264Stream); - } - - let nal_prefix3bytes_found = NAL_PREFIX_3BYTES[..] == prefix_buffer[..3]; - if n == 3 { - if nal_prefix3bytes_found { - return Err(Error::ErrIoEOF); - } - return Err(Error::ErrDataIsNotH264Stream); - } - - // n == 4 - if nal_prefix3bytes_found { - self.nal_buffer.put_u8(prefix_buffer[3]); - return Ok(3); - } - - let nal_prefix4bytes_found = NAL_PREFIX_4BYTES[..] == prefix_buffer; - if nal_prefix4bytes_found { - Ok(4) - } else { - Err(Error::ErrDataIsNotH264Stream) - } - } - - /// next_nal reads from stream and returns then next NAL, - /// and an error if there is incomplete frame data. - /// Returns all nil values when no more NALs are available. - pub fn next_nal(&mut self) -> Result { - if !self.nal_prefix_parsed { - self.bit_stream_starts_with_h264prefix()?; - - self.nal_prefix_parsed = true; - } - - loop { - let Some(read_byte) = self.read1()? 
else { - break; - }; - - let nal_found = self.process_byte(read_byte); - if nal_found { - let nal_unit_type = NalUnitType::from(self.nal_buffer[0] & 0x1F); - if nal_unit_type == NalUnitType::SEI { - self.nal_buffer.clear(); - continue; - } else { - break; - } - } - - self.nal_buffer.put_u8(read_byte); - } - - if self.nal_buffer.is_empty() { - return Err(Error::ErrIoEOF); - } - - let mut nal = NAL::new(self.nal_buffer.split()); - nal.parse_header(); - - Ok(nal) - } - - fn process_byte(&mut self, read_byte: u8) -> bool { - let mut nal_found = false; - - match read_byte { - 0 => { - self.count_of_consecutive_zero_bytes += 1; - } - 1 => { - if self.count_of_consecutive_zero_bytes >= 2 { - let count_of_consecutive_zero_bytes_in_prefix = - if self.count_of_consecutive_zero_bytes > 2 { - 3 - } else { - 2 - }; - let nal_unit_length = - self.nal_buffer.len() - count_of_consecutive_zero_bytes_in_prefix; - if nal_unit_length > 0 { - let _ = self.nal_buffer.split_off(nal_unit_length); - nal_found = true; - } - } - self.count_of_consecutive_zero_bytes = 0; - } - _ => { - self.count_of_consecutive_zero_bytes = 0; - } - } - - nal_found - } -} diff --git a/media/src/io/h264_writer/h264_writer_test.rs b/media/src/io/h264_writer/h264_writer_test.rs deleted file mode 100644 index 45e4bbc0f..000000000 --- a/media/src/io/h264_writer/h264_writer_test.rs +++ /dev/null @@ -1,127 +0,0 @@ -use std::io::Cursor; - -use bytes::Bytes; - -use super::*; - -#[test] -fn test_is_key_frame() -> Result<()> { - let tests = vec![ - ( - "When given a non-keyframe; it should return false", - vec![0x27, 0x90, 0x90], - false, - ), - ( - "When given a SPS packetized with STAP-A;; it should return true", - vec![ - 0x38, 0x00, 0x03, 0x27, 0x90, 0x90, 0x00, 0x05, 0x28, 0x90, 0x90, 0x90, 0x90, - ], - true, - ), - ( - "When given a SPS with no packetization; it should return true", - vec![0x27, 0x90, 0x90, 0x00], - true, - ), - ]; - - for (name, payload, want) in tests { - let got = is_key_frame(&payload); - assert_eq!(got, want, "{name} failed"); - } - - Ok(()) -} - -#[test] -fn test_write_rtp() -> Result<()> { - let tests = vec![ - ( - "When given an empty payload; it should return nil", - vec![], - false, - vec![], - false, - ), - ( - "When no keyframe is defined; it should discard the packet", - vec![0x25, 0x90, 0x90], - false, - vec![], - false, - ), - ( - "When a valid Single NAL Unit packet is given; it should unpack it without error", - vec![0x27, 0x90, 0x90], - true, - vec![0x00, 0x00, 0x00, 0x01, 0x27, 0x90, 0x90], - false, - ), - ( - "When a valid STAP-A packet is given; it should unpack it without error", - vec![ - 0x38, 0x00, 0x03, 0x27, 0x90, 0x90, 0x00, 0x05, 0x28, 0x90, 0x90, 0x90, 0x90, - ], - true, - vec![ - 0x00, 0x00, 0x00, 0x01, 0x27, 0x90, 0x90, 0x00, 0x00, 0x00, 0x01, 0x28, 0x90, 0x90, - 0x90, 0x90, - ], - false, - ), - ]; - - for (_name, payload, has_key_frame, want_bytes, _reuse) in tests { - let mut writer = vec![]; - { - let w = Cursor::new(&mut writer); - let mut h264writer = H264Writer::new(w); - h264writer.has_key_frame = has_key_frame; - - let packet = rtp::packet::Packet { - payload: Bytes::from(payload), - ..Default::default() - }; - - h264writer.write_rtp(&packet)?; - h264writer.close()?; - } - - assert_eq!(writer, want_bytes); - } - - Ok(()) -} - -#[test] -fn test_write_rtp_fu() -> Result<()> { - let tests = vec![ - vec![0x3C, 0x85, 0x90, 0x90, 0x90], - vec![0x3C, 0x45, 0x90, 0x90, 0x90], - ]; - - let want_bytes = vec![ - 0x00, 0x00, 0x00, 0x01, 0x25, 0x90, 0x90, 0x90, 0x90, 0x90, 0x90, - ]; - - let mut 
writer = vec![]; - { - let w = Cursor::new(&mut writer); - let mut h264writer = H264Writer::new(w); - h264writer.has_key_frame = true; - - for payload in tests { - let packet = rtp::packet::Packet { - payload: Bytes::from(payload), - ..Default::default() - }; - - h264writer.write_rtp(&packet)?; - } - h264writer.close()?; - } - assert_eq!(writer, want_bytes); - - Ok(()) -} diff --git a/media/src/io/h264_writer/mod.rs b/media/src/io/h264_writer/mod.rs deleted file mode 100644 index a79db29a6..000000000 --- a/media/src/io/h264_writer/mod.rs +++ /dev/null @@ -1,83 +0,0 @@ -#[cfg(test)] -mod h264_writer_test; - -use std::io::{Seek, Write}; - -use rtp::codecs::h264::H264Packet; -use rtp::packetizer::Depacketizer; - -use crate::error::Result; -use crate::io::Writer; - -const NALU_TTYPE_STAP_A: u32 = 24; -const NALU_TTYPE_SPS: u32 = 7; -const NALU_TYPE_BITMASK: u32 = 0x1F; - -fn is_key_frame(data: &[u8]) -> bool { - if data.len() < 4 { - false - } else { - let word = u32::from_be_bytes([data[0], data[1], data[2], data[3]]); - let nalu_type = (word >> 24) & NALU_TYPE_BITMASK; - (nalu_type == NALU_TTYPE_STAP_A && (word & NALU_TYPE_BITMASK) == NALU_TTYPE_SPS) - || (nalu_type == NALU_TTYPE_SPS) - } -} - -/// H264Writer is used to take RTP packets, parse them and -/// write the data to an io.Writer. -/// Currently it only supports non-interleaved mode -/// Therefore, only 1-23, 24 (STAP-A), 28 (FU-A) NAL types are allowed. -/// -pub struct H264Writer { - writer: W, - has_key_frame: bool, - cached_packet: Option, -} - -impl H264Writer { - // new initializes a new H264 writer with an io.Writer output - pub fn new(writer: W) -> Self { - H264Writer { - writer, - has_key_frame: false, - cached_packet: None, - } - } -} - -impl Writer for H264Writer { - /// write_rtp adds a new packet and writes the appropriate headers for it - fn write_rtp(&mut self, packet: &rtp::packet::Packet) -> Result<()> { - if packet.payload.is_empty() { - return Ok(()); - } - - if !self.has_key_frame { - self.has_key_frame = is_key_frame(&packet.payload); - if !self.has_key_frame { - // key frame not defined yet. discarding packet - return Ok(()); - } - } - - if self.cached_packet.is_none() { - self.cached_packet = Some(H264Packet::default()); - } - - if let Some(cached_packet) = &mut self.cached_packet { - let payload = cached_packet.depacketize(&packet.payload)?; - - self.writer.write_all(&payload)?; - } - - Ok(()) - } - - /// close closes the underlying writer - fn close(&mut self) -> Result<()> { - self.cached_packet = None; - self.writer.flush()?; - Ok(()) - } -} diff --git a/media/src/io/ivf_reader/ivf_reader_test.rs b/media/src/io/ivf_reader/ivf_reader_test.rs deleted file mode 100644 index 4cc2d6836..000000000 --- a/media/src/io/ivf_reader/ivf_reader_test.rs +++ /dev/null @@ -1,164 +0,0 @@ -use std::io::BufReader; - -use bytes::Bytes; - -use super::*; - -/// build_ivf_container takes frames and prepends valid IVF file header -fn build_ivf_container(frames: &[Bytes]) -> Bytes { - // Valid IVF file header taken from: https://github.com/webmproject/... 
- // vp8-test-vectors/blob/master/vp80-00-comprehensive-001.ivf - // Video Image Width - 176 - // Video Image Height - 144 - // Frame Rate Rate - 30000 - // Frame Rate Scale - 1000 - // Video Length in Frames - 29 - // BitRate: 64.01 kb/s - let header = Bytes::from_static(&[ - 0x44, 0x4b, 0x49, 0x46, 0x00, 0x00, 0x20, 0x00, 0x56, 0x50, 0x38, 0x30, 0xb0, 0x00, 0x90, - 0x00, 0x30, 0x75, 0x00, 0x00, 0xe8, 0x03, 0x00, 0x00, 0x1d, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, - ]); - - let mut ivf = BytesMut::new(); - ivf.extend(header); - - for frame in frames { - ivf.extend(frame); - } - - ivf.freeze() -} - -#[test] -fn test_ivf_reader_parse_valid_file_header() -> Result<()> { - let ivf = build_ivf_container(&[]); - - let r = BufReader::new(&ivf[..]); - let (_, header) = IVFReader::new(r)?; - - assert_eq!(&header.signature, b"DKIF", "signature is 'DKIF'"); - assert_eq!(header.version, 0, "version should be 0"); - assert_eq!(&header.four_cc, b"VP80", "FourCC should be 'VP80'"); - assert_eq!(header.width, 176, "width should be 176"); - assert_eq!(header.height, 144, "height should be 144"); - assert_eq!( - header.timebase_denominator, 30000, - "timebase denominator should be 30000" - ); - assert_eq!( - header.timebase_numerator, 1000, - "timebase numerator should be 1000" - ); - assert_eq!(header.num_frames, 29, "number of frames should be 29"); - assert_eq!(header.unused, 0, "bytes should be unused"); - - Ok(()) -} - -#[test] -fn test_ivf_reader_parse_valid_frames() -> Result<()> { - // Frame Length - 4 - // Timestamp - None - // Frame Payload - 0xDEADBEEF - let valid_frame1 = Bytes::from_static(&[ - 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xDE, 0xAD, 0xBE, - 0xEF, - ]); - - // Frame Length - 12 - // Timestamp - None - // Frame Payload - 0xDEADBEEFDEADBEEF - let valid_frame2 = Bytes::from_static(&[ - 0x0C, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xDE, 0xAD, 0xBE, - 0xEF, 0xDE, 0xAD, 0xBE, 0xEF, 0xDE, 0xAD, 0xBE, 0xEF, - ]); - - let ivf = build_ivf_container(&[valid_frame1, valid_frame2]); - let r = BufReader::new(&ivf[..]); - let (mut reader, _) = IVFReader::new(r)?; - - // Parse Frame #1 - let (payload, header) = reader.parse_next_frame()?; - - assert_eq!(header.frame_size, 4, "Frame header frameSize should be 4"); - assert_eq!(payload.len(), 4, "Payload should be length 4"); - assert_eq!( - payload, - Bytes::from_static(&[0xDE, 0xAD, 0xBE, 0xEF,]), - "Payload value should be 0xDEADBEEF" - ); - assert_eq!( - reader.bytes_read, - IVF_FILE_HEADER_SIZE + IVF_FRAME_HEADER_SIZE + header.frame_size as usize - ); - let previous_bytes_read = reader.bytes_read; - - // Parse Frame #2 - let (payload, header) = reader.parse_next_frame()?; - - assert_eq!(header.frame_size, 12, "Frame header frameSize should be 4"); - assert_eq!(payload.len(), 12, "Payload should be length 12"); - assert_eq!( - payload, - Bytes::from_static(&[ - 0xDE, 0xAD, 0xBE, 0xEF, 0xDE, 0xAD, 0xBE, 0xEF, 0xDE, 0xAD, 0xBE, 0xEF, - ]), - "Payload value should be 0xDEADBEEFDEADBEEF" - ); - assert_eq!( - reader.bytes_read, - previous_bytes_read + IVF_FRAME_HEADER_SIZE + header.frame_size as usize, - ); - - Ok(()) -} - -#[test] -fn test_ivf_reader_parse_incomplete_frame_header() -> Result<()> { - // frame with 11-byte header (missing 1 byte) - let incomplete_frame = Bytes::from_static(&[ - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - ]); - - let ivf = build_ivf_container(&[incomplete_frame]); - let r = BufReader::new(&ivf[..]); - let (mut reader, _) = 
IVFReader::new(r)?; - - // Parse Frame #1 - let result = reader.parse_next_frame(); - assert!(result.is_err(), "Expected Error but got Ok"); - - Ok(()) -} - -#[test] -fn test_ivf_reader_parse_incomplete_frame_payload() -> Result<()> { - // frame with header defining frameSize of 4 - // but only 2 bytes available (missing 2 bytes) - let incomplete_frame = Bytes::from_static(&[ - 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xDE, 0xAD, - ]); - - let ivf = build_ivf_container(&[incomplete_frame]); - let r = BufReader::new(&ivf[..]); - let (mut reader, _) = IVFReader::new(r)?; - - // Parse Frame #1 - let result = reader.parse_next_frame(); - assert!(result.is_err(), "Expected Error but got Ok"); - - Ok(()) -} - -#[test] -fn test_ivf_reader_eof_when_no_frames_left() -> Result<()> { - let ivf = build_ivf_container(&[]); - let r = BufReader::new(&ivf[..]); - let (mut reader, _) = IVFReader::new(r)?; - - let result = reader.parse_next_frame(); - assert!(result.is_err(), "Expected Error but got Ok"); - - Ok(()) -} diff --git a/media/src/io/ivf_reader/mod.rs b/media/src/io/ivf_reader/mod.rs deleted file mode 100644 index 62f314917..000000000 --- a/media/src/io/ivf_reader/mod.rs +++ /dev/null @@ -1,127 +0,0 @@ -#[cfg(test)] -mod ivf_reader_test; - -use std::io::Read; - -use byteorder::{LittleEndian, ReadBytesExt}; -use bytes::BytesMut; - -use crate::error::{Error, Result}; -use crate::io::ResetFn; - -pub const IVF_FILE_HEADER_SIGNATURE: &[u8] = b"DKIF"; -pub const IVF_FILE_HEADER_SIZE: usize = 32; -pub const IVF_FRAME_HEADER_SIZE: usize = 12; - -/// IVFFileHeader 32-byte header for IVF files -/// -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] -pub struct IVFFileHeader { - pub signature: [u8; 4], // 0-3 - pub version: u16, // 4-5 - pub header_size: u16, // 6-7 - pub four_cc: [u8; 4], // 8-11 - pub width: u16, // 12-13 - pub height: u16, // 14-15 - pub timebase_denominator: u32, // 16-19 - pub timebase_numerator: u32, // 20-23 - pub num_frames: u32, // 24-27 - pub unused: u32, // 28-31 -} - -/// IVFFrameHeader 12-byte header for IVF frames -/// -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] -pub struct IVFFrameHeader { - pub frame_size: u32, // 0-3 - pub timestamp: u64, // 4-11 -} - -/// IVFReader is used to read IVF files and return frame payloads -pub struct IVFReader { - reader: R, - bytes_read: usize, -} - -impl IVFReader { - /// new returns a new IVF reader and IVF file header - /// with an io.Reader input - pub fn new(reader: R) -> Result<(IVFReader, IVFFileHeader)> { - let mut r = IVFReader { - reader, - bytes_read: 0, - }; - - let header = r.parse_file_header()?; - - Ok((r, header)) - } - - /// reset_reader resets the internal stream of IVFReader. This is useful - /// for live streams, where the end of the file might be read without the - /// data being finished. - pub fn reset_reader(&mut self, mut reset: ResetFn) { - self.reader = reset(self.bytes_read); - } - - /// parse_next_frame reads from stream and returns IVF frame payload, header, - /// and an error if there is incomplete frame data. - /// Returns all nil values when no more frames are available. 
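// A minimal sketch of the 12-byte IVF frame header that `parse_next_frame` below
// decodes: `frame_size` in bytes 0-3 and `timestamp` in bytes 4-11, both little-endian,
// followed by `frame_size` payload bytes. Written against plain std (no `byteorder`);
// the helper name is illustrative only.
fn parse_ivf_frame_header(bytes: &[u8; 12]) -> (u32, u64) {
    let frame_size = u32::from_le_bytes(bytes[0..4].try_into().unwrap());
    let timestamp = u64::from_le_bytes(bytes[4..12].try_into().unwrap());
    (frame_size, timestamp)
}
// The first valid frame in the tests above starts with
// [0x04, 0, 0, 0,  0, 0, 0, 0, 0, 0, 0, 0], i.e. (frame_size, timestamp) == (4, 0).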
- pub fn parse_next_frame(&mut self) -> Result<(BytesMut, IVFFrameHeader)> { - let frame_size = self.reader.read_u32::()?; - let timestamp = self.reader.read_u64::()?; - let header = IVFFrameHeader { - frame_size, - timestamp, - }; - - let mut payload = BytesMut::with_capacity(header.frame_size as usize); - payload.resize(header.frame_size as usize, 0); - self.reader.read_exact(&mut payload)?; - - self.bytes_read += IVF_FRAME_HEADER_SIZE + header.frame_size as usize; - - Ok((payload, header)) - } - - /// parse_file_header reads 32 bytes from stream and returns - /// IVF file header. This is always called before parse_next_frame() - fn parse_file_header(&mut self) -> Result { - let mut signature = [0u8; 4]; - let mut four_cc = [0u8; 4]; - - self.reader.read_exact(&mut signature)?; - let version = self.reader.read_u16::()?; - let header_size = self.reader.read_u16::()?; - self.reader.read_exact(&mut four_cc)?; - let width = self.reader.read_u16::()?; - let height = self.reader.read_u16::()?; - let timebase_denominator = self.reader.read_u32::()?; - let timebase_numerator = self.reader.read_u32::()?; - let num_frames = self.reader.read_u32::()?; - let unused = self.reader.read_u32::()?; - - let header = IVFFileHeader { - signature, - version, - header_size, - four_cc, - width, - height, - timebase_denominator, - timebase_numerator, - num_frames, - unused, - }; - - if header.signature != IVF_FILE_HEADER_SIGNATURE { - return Err(Error::ErrSignatureMismatch); - } else if header.version != 0 { - return Err(Error::ErrUnknownIVFVersion); - } - - self.bytes_read += IVF_FILE_HEADER_SIZE; - - Ok(header) - } -} diff --git a/media/src/io/ivf_writer/ivf_writer_test.rs b/media/src/io/ivf_writer/ivf_writer_test.rs deleted file mode 100644 index 9a67199e6..000000000 --- a/media/src/io/ivf_writer/ivf_writer_test.rs +++ /dev/null @@ -1,194 +0,0 @@ -use std::io::Cursor; - -use super::*; -use crate::error::Error; - -#[test] -fn test_ivf_writer_add_packet_and_close() -> Result<()> { - // Construct valid packet - let raw_valid_pkt = Bytes::from_static(&[ - 0x90, 0xe0, 0x69, 0x8f, 0xd9, 0xc2, 0x93, 0xda, 0x1c, 0x64, 0x27, 0x82, 0x00, 0x01, 0x00, - 0x01, 0xFF, 0xFF, 0xFF, 0xFF, 0x98, 0x36, 0xbe, 0x89, 0x9e, - ]); - - let mut valid_packet = rtp::packet::Packet { - header: rtp::header::Header { - marker: true, - extension: true, - extension_profile: 1, - version: 2, - //payloadOffset: 20, - payload_type: 96, - sequence_number: 27023, - timestamp: 3653407706, - ssrc: 476325762, - csrc: vec![], - padding: false, - extensions: vec![], - extensions_padding: 0, - }, - payload: raw_valid_pkt.slice(20..), - }; - valid_packet - .header - .set_extension(0, Bytes::from_static(&[0xFF, 0xFF, 0xFF, 0xFF]))?; - - // Construct mid partition packet - let raw_mid_part_pkt = Bytes::from_static(&[ - 0x90, 0xe0, 0x69, 0x8f, 0xd9, 0xc2, 0x93, 0xda, 0x1c, 0x64, 0x27, 0x82, 0x00, 0x01, 0x00, - 0x01, 0xFF, 0xFF, 0xFF, 0xFF, 0x88, 0x36, 0xbe, 0x89, 0x9e, - ]); - - let mut mid_part_packet = rtp::packet::Packet { - header: rtp::header::Header { - marker: true, - extension: true, - extension_profile: 1, - version: 2, - //PayloadOffset: 20, - payload_type: 96, - sequence_number: 27023, - timestamp: 3653407706, - ssrc: 476325762, - csrc: vec![], - padding: raw_mid_part_pkt.len() % 4 != 0, - extensions: vec![], - extensions_padding: 0, - }, - payload: raw_mid_part_pkt.slice(20..), - }; - mid_part_packet - .header - .set_extension(0, Bytes::from_static(&[0xFF, 0xFF, 0xFF, 0xFF]))?; - - // Construct keyframe packet - let raw_keyframe_pkt = 
Bytes::from_static(&[ - 0x90, 0xe0, 0x69, 0x8f, 0xd9, 0xc2, 0x93, 0xda, 0x1c, 0x64, 0x27, 0x82, 0x00, 0x01, 0x00, - 0x01, 0xFF, 0xFF, 0xFF, 0xFF, 0x98, 0x36, 0xbe, 0x88, 0x9e, - ]); - - let mut keyframe_packet = rtp::packet::Packet { - header: rtp::header::Header { - marker: true, - extension: true, - extension_profile: 1, - version: 2, - //PayloadOffset: 20, - payload_type: 96, - sequence_number: 27023, - timestamp: 3653407706, - ssrc: 476325762, - csrc: vec![], - padding: raw_keyframe_pkt.len() % 4 != 0, - extensions: vec![], - extensions_padding: 0, - }, - payload: raw_keyframe_pkt.slice(20..), - }; - keyframe_packet - .header - .set_extension(0, Bytes::from_static(&[0xFF, 0xFF, 0xFF, 0xFF]))?; - - // Check valid packet parameters - let mut vp8packet = rtp::codecs::vp8::Vp8Packet::default(); - let payload = vp8packet.depacketize(&valid_packet.payload)?; - assert_eq!(1, vp8packet.s, "Start packet S value should be 1"); - assert_eq!( - payload[0] & 0x01, - 1, - "Non Keyframe packet P value should be 1" - ); - - // Check mid partition packet parameters - let mut vp8packet = rtp::codecs::vp8::Vp8Packet::default(); - let payload = vp8packet.depacketize(&mid_part_packet.payload)?; - assert_eq!(vp8packet.s, 0, "Mid Partition packet S value should be 0"); - assert_eq!( - payload[0] & 0x01, - 1, - "Non Keyframe packet P value should be 1" - ); - - // Check keyframe packet parameters - let mut vp8packet = rtp::codecs::vp8::Vp8Packet::default(); - let payload = vp8packet.depacketize(&keyframe_packet.payload)?; - assert_eq!(vp8packet.s, 1, "Start packet S value should be 1"); - assert_eq!(payload[0] & 0x01, 0, "Keyframe packet P value should be 0"); - - let add_packet_test_case = vec![ - ( - "IVFWriter shouldn't be able to write something an empty packet", - "IVFWriter should be able to close the file", - rtp::packet::Packet::default(), - Some(Error::ErrInvalidNilPacket), - false, - 0, - ), - ( - "IVFWriter should be able to write an IVF packet", - "IVFWriter should be able to close the file", - valid_packet.clone(), - None, - false, - 1, - ), - ( - "IVFWriter should be able to write a Keframe IVF packet", - "IVFWriter should be able to close the file", - keyframe_packet, - None, - true, - 2, - ), - ]; - - let header = IVFFileHeader { - signature: *b"DKIF", // DKIF - version: 0, // version - header_size: 32, // Header size - four_cc: *b"VP80", // FOURCC - width: 640, // Width in pixels - height: 480, // Height in pixels - timebase_denominator: 30, // Framerate denominator - timebase_numerator: 1, // Framerate numerator - num_frames: 900, // Frame count, will be updated on first Close() call - unused: 0, // Unused - }; - - for (msg1, _msg2, packet, err, seen_key_frame, count) in add_packet_test_case { - let mut writer = IVFWriter::new(Cursor::new(Vec::::new()), &header)?; - assert!( - !writer.seen_key_frame, - "Writer's seenKeyFrame should initialize false" - ); - assert_eq!(writer.count, 0, "Writer's packet count should initialize 0"); - let result = writer.write_rtp(&packet); - if err.is_some() { - assert!(result.is_err(), "{}", msg1); - continue; - } else { - assert!(result.is_ok(), "{}", msg1); - } - - assert_eq!(seen_key_frame, writer.seen_key_frame, "{msg1} failed"); - if count == 1 { - assert_eq!(writer.count, 0); - } else if count == 2 { - assert_eq!(writer.count, 1); - } - - writer.write_rtp(&mid_part_packet)?; - if count == 1 { - assert_eq!(writer.count, 0); - } else if count == 2 { - assert_eq!(writer.count, 1); - - writer.write_rtp(&valid_packet)?; - assert_eq!(writer.count, 2); - } - - 
writer.close()?; - } - - Ok(()) -} diff --git a/media/src/io/ivf_writer/mod.rs b/media/src/io/ivf_writer/mod.rs deleted file mode 100644 index 3e2476e73..000000000 --- a/media/src/io/ivf_writer/mod.rs +++ /dev/null @@ -1,122 +0,0 @@ -#[cfg(test)] -mod ivf_writer_test; - -use std::io::{Seek, SeekFrom, Write}; - -use byteorder::{LittleEndian, WriteBytesExt}; -use bytes::{Bytes, BytesMut}; -use rtp::packetizer::Depacketizer; - -use crate::error::Result; -use crate::io::ivf_reader::IVFFileHeader; -use crate::io::Writer; - -/// IVFWriter is used to take RTP packets and write them to an IVF on disk -pub struct IVFWriter { - writer: W, - count: u64, - seen_key_frame: bool, - current_frame: Option, - is_vp9: bool, -} - -impl IVFWriter { - /// new initialize a new IVF writer with an io.Writer output - pub fn new(writer: W, header: &IVFFileHeader) -> Result { - let mut w = IVFWriter { - writer, - count: 0, - seen_key_frame: false, - current_frame: None, - is_vp9: &header.four_cc != b"VP80", - }; - - w.write_header(header)?; - - Ok(w) - } - - fn write_header(&mut self, header: &IVFFileHeader) -> Result<()> { - self.writer.write_all(&header.signature)?; // DKIF - self.writer.write_u16::(header.version)?; // version - self.writer.write_u16::(header.header_size)?; // Header size - self.writer.write_all(&header.four_cc)?; // FOURCC - self.writer.write_u16::(header.width)?; // Width in pixels - self.writer.write_u16::(header.height)?; // Height in pixels - self.writer - .write_u32::(header.timebase_denominator)?; // Framerate denominator - self.writer - .write_u32::(header.timebase_numerator)?; // Framerate numerator - self.writer.write_u32::(header.num_frames)?; // Frame count, will be updated on first Close() call - self.writer.write_u32::(header.unused)?; // Unused - - Ok(()) - } -} - -impl Writer for IVFWriter { - /// write_rtp adds a new packet and writes the appropriate headers for it - fn write_rtp(&mut self, packet: &rtp::packet::Packet) -> Result<()> { - let mut depacketizer: Box = if self.is_vp9 { - Box::::default() - } else { - Box::::default() - }; - - let payload = depacketizer.depacketize(&packet.payload)?; - - let is_key_frame = payload[0] & 0x01; - - if (!self.seen_key_frame && is_key_frame == 1) - || (self.current_frame.is_none() && !depacketizer.is_partition_head(&packet.payload)) - { - return Ok(()); - } - - self.seen_key_frame = true; - let frame_length = if let Some(current_frame) = &mut self.current_frame { - current_frame.extend(payload); - current_frame.len() - } else { - let mut current_frame = BytesMut::new(); - current_frame.extend(payload); - let frame_length = current_frame.len(); - self.current_frame = Some(current_frame); - frame_length - }; - - if !packet.header.marker { - return Ok(()); - } else if let Some(current_frame) = &self.current_frame { - if current_frame.is_empty() { - return Ok(()); - } - } else { - return Ok(()); - } - - self.writer.write_u32::(frame_length as u32)?; // Frame length - self.writer.write_u64::(self.count)?; // PTS - self.count += 1; - - let frame_content = if let Some(current_frame) = self.current_frame.take() { - current_frame.freeze() - } else { - Bytes::new() - }; - - self.writer.write_all(&frame_content)?; - - Ok(()) - } - - /// close stops the recording - fn close(&mut self) -> Result<()> { - // Update the frame count - self.writer.seek(SeekFrom::Start(24))?; - self.writer.write_u32::(self.count as u32)?; - - self.writer.flush()?; - Ok(()) - } -} diff --git a/media/src/io/mod.rs b/media/src/io/mod.rs deleted file mode 100644 index 
ade982656..000000000 --- a/media/src/io/mod.rs +++ /dev/null @@ -1,21 +0,0 @@ -pub mod h264_reader; -pub mod h264_writer; -use crate::error::Result; - -pub mod ivf_reader; -pub mod ivf_writer; -pub mod ogg_reader; -pub mod ogg_writer; -pub mod sample_builder; - -pub type ResetFn = Box R>; - -// Writer defines an interface to handle -// the creation of media files -pub trait Writer { - // Add the content of an RTP packet to the media - fn write_rtp(&mut self, pkt: &rtp::packet::Packet) -> Result<()>; - // close the media - // Note: close implementation must be idempotent - fn close(&mut self) -> Result<()>; -} diff --git a/media/src/io/ogg_reader/mod.rs b/media/src/io/ogg_reader/mod.rs deleted file mode 100644 index 4a1051622..000000000 --- a/media/src/io/ogg_reader/mod.rs +++ /dev/null @@ -1,204 +0,0 @@ -#[cfg(test)] -mod ogg_reader_test; - -use std::io::{Cursor, Read}; - -use byteorder::{LittleEndian, ReadBytesExt}; -use bytes::BytesMut; - -use crate::error::{Error, Result}; -use crate::io::ResetFn; - -pub const PAGE_HEADER_TYPE_CONTINUATION_OF_STREAM: u8 = 0x00; -pub const PAGE_HEADER_TYPE_BEGINNING_OF_STREAM: u8 = 0x02; -pub const PAGE_HEADER_TYPE_END_OF_STREAM: u8 = 0x04; -pub const DEFAULT_PRE_SKIP: u16 = 3840; // 3840 recommended in the RFC -pub const PAGE_HEADER_SIGNATURE: &[u8] = b"OggS"; -pub const ID_PAGE_SIGNATURE: &[u8] = b"OpusHead"; -pub const COMMENT_PAGE_SIGNATURE: &[u8] = b"OpusTags"; -pub const PAGE_HEADER_SIZE: usize = 27; -pub const ID_PAGE_PAYLOAD_SIZE: usize = 19; - -/// OggReader is used to read Ogg files and return page payloads -pub struct OggReader { - reader: R, - bytes_read: usize, - checksum_table: [u32; 256], - do_checksum: bool, -} - -/// OggHeader is the metadata from the first two pages -/// in the file (ID and Comment) -/// -pub struct OggHeader { - pub channel_map: u8, - pub channels: u8, - pub output_gain: u16, - pub pre_skip: u16, - pub sample_rate: u32, - pub version: u8, -} - -/// OggPageHeader is the metadata for a Page -/// Pages are the fundamental unit of multiplexing in an Ogg stream -/// -pub struct OggPageHeader { - pub granule_position: u64, - - sig: [u8; 4], - version: u8, - header_type: u8, - serial: u32, - index: u32, - segments_count: u8, -} - -impl OggReader { - /// new returns a new Ogg reader and Ogg header - /// with an io.Reader input - pub fn new(reader: R, do_checksum: bool) -> Result<(OggReader, OggHeader)> { - let mut r = OggReader { - reader, - bytes_read: 0, - checksum_table: generate_checksum_table(), - do_checksum, - }; - - let header = r.read_headers()?; - - Ok((r, header)) - } - - fn read_headers(&mut self) -> Result { - let (payload, page_header) = self.parse_next_page()?; - - if page_header.sig != PAGE_HEADER_SIGNATURE { - return Err(Error::ErrBadIDPageSignature); - } - - if page_header.header_type != PAGE_HEADER_TYPE_BEGINNING_OF_STREAM { - return Err(Error::ErrBadIDPageType); - } - - if payload.len() != ID_PAGE_PAYLOAD_SIZE { - return Err(Error::ErrBadIDPageLength); - } - - let s = &payload[..8]; - if s != ID_PAGE_SIGNATURE { - return Err(Error::ErrBadIDPagePayloadSignature); - } - - let mut reader = Cursor::new(&payload[8..]); - let version = reader.read_u8()?; //8 - let channels = reader.read_u8()?; //9 - let pre_skip = reader.read_u16::()?; //10-11 - let sample_rate = reader.read_u32::()?; //12-15 - let output_gain = reader.read_u16::()?; //16-17 - let channel_map = reader.read_u8()?; //18 - - Ok(OggHeader { - channel_map, - channels, - output_gain, - pre_skip, - sample_rate, - version, - }) - } - - // parse_next_page 
reads from stream and returns Ogg page payload, header, - // and an error if there is incomplete page data. - pub fn parse_next_page(&mut self) -> Result<(BytesMut, OggPageHeader)> { - let mut h = [0u8; PAGE_HEADER_SIZE]; - self.reader.read_exact(&mut h)?; - - let mut head_reader = Cursor::new(h); - let mut sig = [0u8; 4]; //0-3 - head_reader.read_exact(&mut sig)?; - let version = head_reader.read_u8()?; //4 - let header_type = head_reader.read_u8()?; //5 - let granule_position = head_reader.read_u64::()?; //6-13 - let serial = head_reader.read_u32::()?; //14-17 - let index = head_reader.read_u32::()?; //18-21 - let checksum = head_reader.read_u32::()?; //22-25 - let segments_count = head_reader.read_u8()?; //26 - - let mut size_buffer = vec![0u8; segments_count as usize]; - self.reader.read_exact(&mut size_buffer)?; - - let mut payload_size = 0usize; - for s in &size_buffer { - payload_size += *s as usize; - } - - let mut payload = BytesMut::with_capacity(payload_size); - payload.resize(payload_size, 0); - self.reader.read_exact(&mut payload)?; - - if self.do_checksum { - let mut sum = 0; - - for (index, v) in h.iter().enumerate() { - // Don't include expected checksum in our generation - if index > 21 && index < 26 { - sum = self.update_checksum(0, sum); - continue; - } - sum = self.update_checksum(*v, sum); - } - - for v in &size_buffer { - sum = self.update_checksum(*v, sum); - } - for v in &payload[..] { - sum = self.update_checksum(*v, sum); - } - - if sum != checksum { - return Err(Error::ErrChecksumMismatch); - } - } - - let page_header = OggPageHeader { - granule_position, - sig, - version, - header_type, - serial, - index, - segments_count, - }; - - Ok((payload, page_header)) - } - - /// reset_reader resets the internal stream of OggReader. This is useful - /// for live streams, where the end of the file might be read without the - /// data being finished. 
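// A compact, standalone sketch of the page checksum that `parse_next_page` above
// verifies: CRC-32 with polynomial 0x04c11db7, no bit reflection, initial value 0 and
// no final XOR. The page's own checksum field (header bytes 22-25) is summed as zero,
// which is what the verification loop above does before comparing.
fn ogg_crc(data: &[u8]) -> u32 {
    const POLY: u32 = 0x04c1_1db7;
    // Same table as `generate_checksum_table` below, built inline for self-containment.
    let mut table = [0u32; 256];
    for (i, t) in table.iter_mut().enumerate() {
        let mut r = (i as u32) << 24;
        for _ in 0..8 {
            r = if (r & 0x8000_0000) != 0 { (r << 1) ^ POLY } else { r << 1 };
        }
        *t = r;
    }
    // Identical update step to `update_checksum` below.
    data.iter().fold(0u32, |sum, &v| {
        (sum << 8) ^ table[(((sum >> 24) as u8) ^ v) as usize]
    })
}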
- pub fn reset_reader(&mut self, mut reset: ResetFn) { - self.reader = reset(self.bytes_read); - } - - fn update_checksum(&self, v: u8, sum: u32) -> u32 { - (sum << 8) ^ self.checksum_table[(((sum >> 24) as u8) ^ v) as usize] - } -} - -pub(crate) fn generate_checksum_table() -> [u32; 256] { - let mut table = [0u32; 256]; - const POLY: u32 = 0x04c11db7; - - for (i, t) in table.iter_mut().enumerate() { - let mut r = (i as u32) << 24; - for _ in 0..8 { - if (r & 0x80000000) != 0 { - r = (r << 1) ^ POLY; - } else { - r <<= 1; - } - } - *t = r; - } - table -} diff --git a/media/src/io/ogg_reader/ogg_reader_test.rs b/media/src/io/ogg_reader/ogg_reader_test.rs deleted file mode 100644 index 311f35e8e..000000000 --- a/media/src/io/ogg_reader/ogg_reader_test.rs +++ /dev/null @@ -1,111 +0,0 @@ -use bytes::Bytes; - -use super::*; - -// generates a valid ogg file that can be used for tests -fn build_ogg_container() -> Vec { - vec![ - 0x4f, 0x67, 0x67, 0x53, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x8e, - 0x9b, 0x20, 0xaa, 0x00, 0x00, 0x00, 0x00, 0x61, 0xee, 0x61, 0x17, 0x01, 0x13, 0x4f, 0x70, - 0x75, 0x73, 0x48, 0x65, 0x61, 0x64, 0x01, 0x02, 0x00, 0x0f, 0x80, 0xbb, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x4f, 0x67, 0x67, 0x53, 0x00, 0x00, 0xda, 0x93, 0xc2, 0xd9, 0x00, 0x00, 0x00, - 0x00, 0x8e, 0x9b, 0x20, 0xaa, 0x02, 0x00, 0x00, 0x00, 0x49, 0x97, 0x03, 0x37, 0x01, 0x05, - 0x98, 0x36, 0xbe, 0x88, 0x9e, - ] -} - -#[test] -fn test_ogg_reader_parse_valid_header() -> Result<()> { - let ogg = build_ogg_container(); - let r = Cursor::new(&ogg); - let (_reader, header) = OggReader::new(r, true)?; - - assert_eq!(header.channel_map, 0); - assert_eq!(header.channels, 2); - assert_eq!(header.output_gain, 0); - assert_eq!(header.pre_skip, 0xf00); - assert_eq!(header.sample_rate, 48000); - assert_eq!(header.version, 1); - - Ok(()) -} - -#[test] -fn test_ogg_reader_parse_next_page() -> Result<()> { - let ogg = build_ogg_container(); - let r = Cursor::new(&ogg); - let (mut reader, _header) = OggReader::new(r, true)?; - - let (payload, _) = reader.parse_next_page()?; - assert_eq!(payload, Bytes::from_static(&[0x98, 0x36, 0xbe, 0x88, 0x9e])); - - let result = reader.parse_next_page(); - assert!(result.is_err()); - - Ok(()) -} - -#[test] -fn test_ogg_reader_parse_errors() -> Result<()> { - //"Invalid ID Page Header Signature" - { - let mut ogg = build_ogg_container(); - ogg[0] = 0; - - let result = OggReader::new(Cursor::new(ogg), false); - assert!(result.is_err()); - if let Err(err) = result { - assert_eq!(err, Error::ErrBadIDPageSignature); - } - } - - //"Invalid ID Page Header Type" - { - let mut ogg = build_ogg_container(); - ogg[5] = 0; - - let result = OggReader::new(Cursor::new(ogg), false); - assert!(result.is_err()); - if let Err(err) = result { - assert_eq!(err, Error::ErrBadIDPageType); - } - } - - //"Invalid ID Page Payload Length" - { - let mut ogg = build_ogg_container(); - ogg[27] = 0; - - let result = OggReader::new(Cursor::new(ogg), false); - assert!(result.is_err()); - if let Err(err) = result { - assert_eq!(err, Error::ErrBadIDPageLength); - } - } - - //"Invalid ID Page Payload Length" - { - let mut ogg = build_ogg_container(); - ogg[35] = 0; - - let result = OggReader::new(Cursor::new(ogg), false); - assert!(result.is_err()); - if let Err(err) = result { - assert_eq!(err, Error::ErrBadIDPagePayloadSignature); - } - } - - //"Invalid Page Checksum" - { - let mut ogg = build_ogg_container(); - ogg[22] = 0; - - let result = OggReader::new(Cursor::new(ogg), true); - assert!(result.is_err()); 
- if let Err(err) = result { - assert_eq!(err, Error::ErrChecksumMismatch); - } - } - - Ok(()) -} diff --git a/media/src/io/ogg_writer/mod.rs b/media/src/io/ogg_writer/mod.rs deleted file mode 100644 index 18530528d..000000000 --- a/media/src/io/ogg_writer/mod.rs +++ /dev/null @@ -1,206 +0,0 @@ -#[cfg(test)] -mod ogg_writer_test; - -use std::io::{BufWriter, Seek, Write}; - -use byteorder::{LittleEndian, WriteBytesExt}; -use bytes::Bytes; -use rtp::packetizer::Depacketizer; - -use crate::error::Result; -use crate::io::ogg_reader::*; -use crate::io::Writer; - -/// OggWriter is used to take RTP packets and write them to an OGG on disk -pub struct OggWriter { - writer: W, - sample_rate: u32, - channel_count: u8, - serial: u32, - page_index: u32, - checksum_table: [u32; 256], - previous_granule_position: u64, - previous_timestamp: u32, - last_payload_size: usize, - last_payload: Bytes, -} - -impl OggWriter { - /// new initialize a new OGG Opus writer with an io.Writer output - pub fn new(writer: W, sample_rate: u32, channel_count: u8) -> Result { - let mut w = OggWriter { - writer, - sample_rate, - channel_count, - serial: rand::random::(), - page_index: 0, - checksum_table: generate_checksum_table(), - - // Timestamp and Granule MUST start from 1 - // Only headers can have 0 values - previous_timestamp: 1, - previous_granule_position: 1, - last_payload_size: 0, - last_payload: Bytes::new(), - }; - - w.write_headers()?; - - Ok(w) - } - - /* - ref: https://tools.ietf.org/html/rfc7845.html - https://git.xiph.org/?p=opus-tools.git;a=blob;f=src/opus_header.c#l219 - - Page 0 Pages 1 ... n Pages (n+1) ... - +------------+ +---+ +---+ ... +---+ +-----------+ +---------+ +-- - | | | | | | | | | | | | | - |+----------+| |+-----------------+| |+-------------------+ +----- - |||ID Header|| || Comment Header || ||Audio Data Packet 1| | ... - |+----------+| |+-----------------+| |+-------------------+ +----- - | | | | | | | | | | | | | - +------------+ +---+ +---+ ... 
+---+ +-----------+ +---------+ +-- - ^ ^ ^ - | | | - | | Mandatory Page Break - | | - | ID header is contained on a single page - | - 'Beginning Of Stream' - - Figure 1: Example Packet Organization for a Logical Ogg Opus Stream - */ - - fn write_headers(&mut self) -> Result<()> { - // ID Header - let mut ogg_id_header = Vec::with_capacity(19); - { - let mut header_writer = BufWriter::new(&mut ogg_id_header); - header_writer.write_all(ID_PAGE_SIGNATURE)?; // Magic Signature 'OpusHead' - header_writer.write_u8(1)?; // Version //8 - header_writer.write_u8(self.channel_count)?; // Channel count //9 - header_writer.write_u16::(DEFAULT_PRE_SKIP)?; // pre-skip //10-11 - header_writer.write_u32::(self.sample_rate)?; // original sample rate, any valid sample e.g 48000, //12-15 - header_writer.write_u16::(0)?; // output gain // 16-17 - header_writer.write_u8(0)?; // channel map 0 = one stream: mono or stereo, //18 - } - - // Reference: https://tools.ietf.org/html/rfc7845.html#page-6 - // RFC specifies that the ID Header page should have a granule position of 0 and a Header Type set to 2 (StartOfStream) - self.write_page( - &Bytes::from(ogg_id_header), - PAGE_HEADER_TYPE_BEGINNING_OF_STREAM, - 0, - self.page_index, - )?; - self.page_index += 1; - - // Comment Header - let mut ogg_comment_header = Vec::with_capacity(25); - { - let mut header_writer = BufWriter::new(&mut ogg_comment_header); - header_writer.write_all(COMMENT_PAGE_SIGNATURE)?; // Magic Signature 'OpusTags' //0-7 - header_writer.write_u32::(10)?; // Vendor Length //8-11 - header_writer.write_all(b"WebRTC.rs")?; // Vendor name 'WebRTC.rs' //12-20 - header_writer.write_u32::(0)?; // User Comment List Length //21-24 - } - - // RFC specifies that the page where the CommentHeader completes should have a granule position of 0 - self.write_page( - &Bytes::from(ogg_comment_header), - PAGE_HEADER_TYPE_CONTINUATION_OF_STREAM, - 0, - self.page_index, - )?; - self.page_index += 1; - - Ok(()) - } - - fn write_page( - &mut self, - payload: &Bytes, - header_type: u8, - granule_pos: u64, - page_index: u32, - ) -> Result<()> { - self.last_payload_size = payload.len(); - self.last_payload = payload.clone(); - let n_segments = (self.last_payload_size + 255 - 1) / 255; - - let mut page = - Vec::with_capacity(PAGE_HEADER_SIZE + 1 + self.last_payload_size + n_segments); - { - let mut header_writer = BufWriter::new(&mut page); - header_writer.write_all(PAGE_HEADER_SIGNATURE)?; // page headers starts with 'OggS'//0-3 - header_writer.write_u8(0)?; // Version//4 - header_writer.write_u8(header_type)?; // 1 = continuation, 2 = beginning of stream, 4 = end of stream//5 - header_writer.write_u64::(granule_pos)?; // granule position //6-13 - header_writer.write_u32::(self.serial)?; // Bitstream serial number//14-17 - header_writer.write_u32::(page_index)?; // Page sequence number//18-21 - header_writer.write_u32::(0)?; //Checksum reserve //22-25 - header_writer.write_u8(n_segments as u8)?; // Number of segments in page //26 - - // Filling the segment table with the lacing values. - // First (n_segments - 1) values will always be 255. - for _ in 0..n_segments - 1 { - header_writer.write_u8(255)?; - } - // The last value will be the remainder. 
- header_writer.write_u8((self.last_payload_size - (n_segments * 255 - 255)) as u8)?; - - header_writer.write_all(payload)?; // inserting at 28th since Segment Table(1) + header length(27) - } - - let mut checksum = 0u32; - for v in &page { - checksum = - (checksum << 8) ^ self.checksum_table[(((checksum >> 24) as u8) ^ (*v)) as usize]; - } - page[22..26].copy_from_slice(&checksum.to_le_bytes()); // Checksum - generating for page data and inserting at 22th position into 32 bits - - self.writer.write_all(&page)?; - - Ok(()) - } -} - -impl Writer for OggWriter { - /// write_rtp adds a new packet and writes the appropriate headers for it - fn write_rtp(&mut self, packet: &rtp::packet::Packet) -> Result<()> { - let mut opus_packet = rtp::codecs::opus::OpusPacket; - let payload = opus_packet.depacketize(&packet.payload)?; - - // Should be equivalent to sample_rate * duration - if self.previous_timestamp != 1 { - let increment = packet.header.timestamp - self.previous_timestamp; - self.previous_granule_position += increment as u64; - } - self.previous_timestamp = packet.header.timestamp; - - self.write_page( - &payload, - PAGE_HEADER_TYPE_CONTINUATION_OF_STREAM, - self.previous_granule_position, - self.page_index, - )?; - self.page_index += 1; - - Ok(()) - } - - /// close stops the recording - fn close(&mut self) -> Result<()> { - let payload = self.last_payload.clone(); - self.write_page( - &payload, - PAGE_HEADER_TYPE_END_OF_STREAM, - self.previous_granule_position, - self.page_index - 1, - )?; - - self.writer.flush()?; - Ok(()) - } -} diff --git a/media/src/io/ogg_writer/ogg_writer_test.rs b/media/src/io/ogg_writer/ogg_writer_test.rs deleted file mode 100644 index 1c0e43868..000000000 --- a/media/src/io/ogg_writer/ogg_writer_test.rs +++ /dev/null @@ -1,229 +0,0 @@ -use std::io::Cursor; - -use super::*; -use crate::error::Error; - -#[test] -fn test_ogg_writer_add_packet_and_close() -> Result<()> { - let raw_pkt = Bytes::from_static(&[ - 0x90, 0xe0, 0x69, 0x8f, 0xd9, 0xc2, 0x93, 0xda, 0x1c, 0x64, 0x27, 0x82, 0x00, 0x01, 0x00, - 0x01, 0xFF, 0xFF, 0xFF, 0xFF, 0x98, 0x36, 0xbe, 0x88, 0x9e, - ]); - - let mut valid_packet = rtp::packet::Packet { - header: rtp::header::Header { - marker: true, - extension: true, - extension_profile: 1, - version: 2, - //PayloadOffset: 20, - payload_type: 111, - sequence_number: 27023, - timestamp: 3653407706, - ssrc: 476325762, - csrc: vec![], - padding: false, - extensions: vec![], - extensions_padding: 0, - }, - payload: raw_pkt.slice(20..), - }; - valid_packet - .header - .set_extension(0, Bytes::from_static(&[0xFF, 0xFF, 0xFF, 0xFF]))?; - - // The linter misbehave and thinks this code is the same as the tests in ivf-writer_test - // nolint:dupl - let add_packet_test_case = vec![ - ( - "OggWriter shouldn't be able to write an empty packet", - "OggWriter should be able to close the file", - rtp::packet::Packet::default(), - Some(Error::ErrInvalidNilPacket), - ), - ( - "OggWriter should be able to write an Opus packet", - "OggWriter should be able to close the file", - valid_packet, - None, - ), - ]; - - for (msg1, _msg2, packet, err) in add_packet_test_case { - let mut writer = OggWriter::new(Cursor::new(Vec::::new()), 4800, 2)?; - let result = writer.write_rtp(&packet); - if err.is_some() { - assert!(result.is_err(), "{}", msg1); - continue; - } else { - assert!(result.is_ok(), "{}", msg1); - } - writer.close()?; - } - - Ok(()) -} - -#[test] -fn test_ogg_writer_add_packet() -> Result<()> { - let raw_pkt = Bytes::from_iter(std::iter::repeat(0x45).take(235)); - - 
let mut valid_packet = rtp::packet::Packet { - header: rtp::header::Header { - marker: true, - extension: true, - extension_profile: 1, - version: 2, - payload_type: 111, - sequence_number: 27023, - timestamp: 3653407706, - ssrc: 476325762, - csrc: vec![], - padding: false, - extensions: vec![], - extensions_padding: 0, - }, - payload: raw_pkt, - }; - valid_packet - .header - .set_extension(0, Bytes::from_static(&[0xFF, 0xFF, 0xFF, 0xFF]))?; - - let buffer = Cursor::new(Vec::::new()); - let mut writer = OggWriter::new(buffer, 48000, 2)?; - let result = writer.write_rtp(&valid_packet); - - assert!( - result.is_ok(), - "OggWriter should be able to write an Opus packet smaller than 255 bytes" - ); - assert!( - writer.writer.into_inner()[126..128] == [1, 235], - "OggWriter should be able to write an Opus packet smaller than 255 bytes" - ); - - Ok(()) -} - -#[test] -fn test_ogg_writer_add_packet_of_255() -> Result<()> { - let raw_pkt = Bytes::from_iter(std::iter::repeat(0x45).take(255)); - - let mut valid_packet = rtp::packet::Packet { - header: rtp::header::Header { - marker: true, - extension: true, - extension_profile: 1, - version: 2, - payload_type: 111, - sequence_number: 27023, - timestamp: 3653407706, - ssrc: 476325762, - csrc: vec![], - padding: false, - extensions: vec![], - extensions_padding: 0, - }, - payload: raw_pkt, - }; - valid_packet - .header - .set_extension(0, Bytes::from_static(&[0xFF, 0xFF, 0xFF, 0xFF]))?; - - let buffer = Cursor::new(Vec::::new()); - let mut writer = OggWriter::new(buffer, 48000, 2)?; - let result = writer.write_rtp(&valid_packet); - - assert!( - result.is_ok(), - "OggWriter should be able to write an Opus packet of exactly 255" - ); - assert!( - writer.writer.into_inner()[126..128] == [1, 255], - "OggWriter should be able to write an Opus packet of exactly 255" - ); - - Ok(()) -} - -#[test] -fn test_ogg_writer_add_large_packet() -> Result<()> { - let raw_pkt = Bytes::from_iter(std::iter::repeat(0x45).take(1000)); - - let mut valid_packet = rtp::packet::Packet { - header: rtp::header::Header { - marker: true, - extension: true, - extension_profile: 1, - version: 2, - payload_type: 111, - sequence_number: 27023, - timestamp: 3653407706, - ssrc: 476325762, - csrc: vec![], - padding: false, - extensions: vec![], - extensions_padding: 0, - }, - payload: raw_pkt, - }; - valid_packet - .header - .set_extension(0, Bytes::from_static(&[0xFF, 0xFF, 0xFF, 0xFF]))?; - - let buffer = Cursor::new(Vec::::new()); - let mut writer = OggWriter::new(buffer, 48000, 2)?; - let result = writer.write_rtp(&valid_packet); - - assert!( - result.is_ok(), - "OggWriter should be able to write a large (> 255 bytes) Opus packet" - ); - assert!( - writer.writer.into_inner()[126..131] == [4, 255, 255, 255, 235], - "OggWriter should be able to write multiple segments per page, for 1000 bytes, 4 segments of 255, 255, 255 and 235 long" - ); - - Ok(()) -} - -#[test] -fn test_ogg_writer_add_large_packet_with_multiple_of_255() -> Result<()> { - let raw_pkt = Bytes::from_iter(std::iter::repeat(0x45).take(255 * 4)); - - let mut valid_packet = rtp::packet::Packet { - header: rtp::header::Header { - marker: true, - extension: true, - extension_profile: 1, - version: 2, - payload_type: 111, - sequence_number: 27023, - timestamp: 3653407706, - ssrc: 476325762, - csrc: vec![], - padding: false, - extensions: vec![], - extensions_padding: 0, - }, - payload: raw_pkt, - }; - valid_packet - .header - .set_extension(0, Bytes::from_static(&[0xFF, 0xFF, 0xFF, 0xFF]))?; - - let buffer = 
Cursor::new(Vec::::new()); - let mut writer = OggWriter::new(buffer, 48000, 2)?; - let result = writer.write_rtp(&valid_packet); - - assert!( - result.is_ok(), - "OggWriter should be able to write a large (> 255 bytes) Opus packet" - ); - assert!( - writer.writer.into_inner()[126..131] == [4, 255, 255, 255, 255], - "OggWriter should be able to write multiple segments per page, for 1020 bytes, 4 segments of 255 each" - ); - - Ok(()) -} diff --git a/media/src/io/sample_builder/mod.rs b/media/src/io/sample_builder/mod.rs deleted file mode 100644 index 9bf5e9328..000000000 --- a/media/src/io/sample_builder/mod.rs +++ /dev/null @@ -1,433 +0,0 @@ -#[cfg(test)] -mod sample_builder_test; -#[cfg(test)] -mod sample_sequence_location_test; - -pub mod sample_sequence_location; - -use std::time::{Duration, SystemTime}; - -use bytes::Bytes; -use rtp::packet::Packet; -use rtp::packetizer::Depacketizer; - -use self::sample_sequence_location::{Comparison, SampleSequenceLocation}; -use crate::Sample; - -/// SampleBuilder buffers packets until media frames are complete. -pub struct SampleBuilder { - /// how many packets to wait until we get a valid Sample - max_late: u16, - /// max timestamp between old and new timestamps before dropping packets - max_late_timestamp: u32, - buffer: Vec>, - prepared_samples: Vec>, - last_sample_timestamp: Option, - - /// Interface that allows us to take RTP packets to samples - depacketizer: T, - - /// sample_rate allows us to compute duration of media.SamplecA - sample_rate: u32, - - /// filled contains the head/tail of the packets inserted into the buffer - filled: SampleSequenceLocation, - - /// active contains the active head/tail of the timestamp being actively processed - active: SampleSequenceLocation, - - /// prepared contains the samples that have been processed to date - prepared: SampleSequenceLocation, - - /// number of packets forced to be dropped - dropped_packets: u16, - - /// number of padding packets detected and dropped. This number will be a subset of - /// `dropped_packets` - padding_packets: u16, -} - -impl SampleBuilder { - /// Constructs a new SampleBuilder. - /// `max_late` is how long to wait until we can construct a completed [`Sample`]. - /// `max_late` is measured in RTP packet sequence numbers. - /// A large max_late will result in less packet loss but higher latency. - /// The depacketizer extracts media samples from RTP packets. - /// Several depacketizers are available in package [github.com/pion/rtp/codecs](https://github.com/webrtc-rs/rtp/tree/main/src/codecs). 
- pub fn new(max_late: u16, depacketizer: T, sample_rate: u32) -> Self { - Self { - max_late, - max_late_timestamp: 0, - buffer: vec![None; u16::MAX as usize + 1], - prepared_samples: (0..=u16::MAX as usize).map(|_| None).collect(), - last_sample_timestamp: None, - depacketizer, - sample_rate, - filled: SampleSequenceLocation::new(), - active: SampleSequenceLocation::new(), - prepared: SampleSequenceLocation::new(), - dropped_packets: 0, - padding_packets: 0, - } - } - - pub fn with_max_time_delay(mut self, max_late_duration: Duration) -> Self { - self.max_late_timestamp = - (self.sample_rate as u128 * max_late_duration.as_millis() / 1000) as u32; - self - } - - fn too_old(&self, location: &SampleSequenceLocation) -> bool { - if self.max_late_timestamp == 0 { - return false; - } - - let mut found_head: Option = None; - let mut found_tail: Option = None; - - let mut i = location.head; - while i != location.tail { - if let Some(ref packet) = self.buffer[i as usize] { - found_head = Some(packet.header.timestamp); - break; - } - i = i.wrapping_add(1); - } - - if found_head.is_none() { - return false; - } - - let mut i = location.tail.wrapping_sub(1); - while i != location.head { - if let Some(ref packet) = self.buffer[i as usize] { - found_tail = Some(packet.header.timestamp); - break; - } - i = i.wrapping_sub(1); - } - - if found_tail.is_none() { - return false; - } - - found_tail.unwrap() - found_head.unwrap() > self.max_late_timestamp - } - - /// Returns the timestamp associated with a given sample location - fn fetch_timestamp(&self, location: &SampleSequenceLocation) -> Option { - if location.empty() { - None - } else { - Some( - (self.buffer[location.head as usize]) - .as_ref()? - .header - .timestamp, - ) - } - } - - fn release_packet(&mut self, i: u16) { - self.buffer[i as usize] = None; - } - - /// Clears all buffers that have already been consumed by - /// popping. - fn purge_consumed_buffers(&mut self) { - let active = self.active; - self.purge_consumed_location(&active, false); - } - - /// Clears all buffers that have already been consumed - /// during a sample building method. - fn purge_consumed_location(&mut self, consume: &SampleSequenceLocation, force_consume: bool) { - if !self.filled.has_data() { - return; - } - match consume.compare(self.filled.head) { - Comparison::Inside if force_consume => { - self.release_packet(self.filled.head); - self.filled.head = self.filled.head.wrapping_add(1); - } - Comparison::Before => { - self.release_packet(self.filled.head); - self.filled.head = self.filled.head.wrapping_add(1); - } - _ => {} - } - } - - /// Flushes all buffers that are already consumed or those buffers - /// that are too late to consume. - fn purge_buffers(&mut self) { - self.purge_consumed_buffers(); - - while (self.too_old(&self.filled) || (self.filled.count() > self.max_late)) - && self.filled.has_data() - { - if self.active.empty() { - // refill the active based on the filled packets - self.active = self.filled; - } - - if self.active.has_data() && (self.active.head == self.filled.head) { - // attempt to force the active packet to be consumed even though - // outstanding data may be pending arrival - let err = match self.build_sample(true) { - Ok(_) => continue, - Err(e) => e, - }; - - if !matches!(err, BuildError::InvalidPartition(_)) { - // In the InvalidPartition case `build_sample` will have already adjusted `dropped_packets`. 
- self.dropped_packets += 1; - } - - // could not build the sample so drop it - self.active.head = self.active.head.wrapping_add(1); - } - - self.release_packet(self.filled.head); - self.filled.head = self.filled.head.wrapping_add(1); - } - } - - /// Adds an RTP Packet to self's buffer. - /// - /// Push does not copy the input. If you wish to reuse - /// this memory make sure to copy before calling push - pub fn push(&mut self, p: Packet) { - let sequence_number = p.header.sequence_number; - self.buffer[sequence_number as usize] = Some(p); - match self.filled.compare(sequence_number) { - Comparison::Void => { - self.filled.head = sequence_number; - self.filled.tail = sequence_number.wrapping_add(1); - } - Comparison::Before => { - self.filled.head = sequence_number; - } - Comparison::After => { - self.filled.tail = sequence_number.wrapping_add(1); - } - _ => {} - } - self.purge_buffers(); - } - - /// Creates a sample from a valid collection of RTP Packets by - /// walking forwards building a sample if everything looks good clear and - /// update buffer+values - fn build_sample( - &mut self, - purging_buffers: bool, - ) -> Result { - if self.active.empty() { - self.active = self.filled; - } - - if self.active.empty() { - return Err(BuildError::NoActiveSegment); - } - - if self.filled.compare(self.active.tail) == Comparison::Inside { - self.active.tail = self.filled.tail; - } - - let mut consume = SampleSequenceLocation::new(); - - let mut i = self.active.head; - // `self.active` isn't modified in the loop, fetch the timestamp once and cache it. - let head_timestamp = self.fetch_timestamp(&self.active); - while let Some(ref packet) = self.buffer[i as usize] { - if self.active.compare(i) == Comparison::After { - break; - } - let is_same_timestamp = head_timestamp.map(|t| packet.header.timestamp == t); - let is_different_timestamp = is_same_timestamp.map(std::ops::Not::not); - let is_partition_tail = self - .depacketizer - .is_partition_tail(packet.header.marker, &packet.payload); - - // If the timestamp is not the same it might be because the next packet is both a start - // and end of the next partition in which case a sample should be generated now. This - // can happen when padding packets are used .e.g: - // - // p1(t=1), p2(t=1), p3(t=1), p4(t=2, marker=true, start=true) - // - // In thic case the generated sample should be p1 through p3, but excluding p4 which is - // its own sample. 
- if is_partition_tail && is_same_timestamp.unwrap_or(true) { - consume.head = self.active.head; - consume.tail = i.wrapping_add(1); - break; - } - - if is_different_timestamp.unwrap_or(false) { - consume.head = self.active.head; - consume.tail = i; - break; - } - i = i.wrapping_add(1); - } - - if consume.empty() { - return Err(BuildError::NothingToConsume); - } - - if !purging_buffers && self.buffer[consume.tail as usize].is_none() { - // wait for the next packet after this set of packets to arrive - // to ensure at least one post sample timestamp is known - // (unless we have to release right now) - return Err(BuildError::PendingTimestampPacket); - } - - let sample_timestamp = self.fetch_timestamp(&self.active).unwrap_or(0); - let mut after_timestamp = sample_timestamp; - - // scan for any packet after the current and use that time stamp as the diff point - for i in consume.tail..self.active.tail { - if let Some(ref packet) = self.buffer[i as usize] { - after_timestamp = packet.header.timestamp; - break; - } - } - - // prior to decoding all the packets, check if this packet - // would end being disposed anyway - let head_payload = self.buffer[consume.head as usize] - .as_ref() - .map(|p| &p.payload) - .ok_or(BuildError::GapInSegment)?; - if !self.depacketizer.is_partition_head(head_payload) { - // libWebRTC will sometimes send several empty padding packets to smooth out send - // rate. These packets don't carry any media payloads. - let is_padding = consume.range(&self.buffer).all(|p| { - p.map(|p| { - self.last_sample_timestamp == Some(p.header.timestamp) && p.payload.is_empty() - }) - .unwrap_or(false) - }); - - self.dropped_packets += consume.count(); - if is_padding { - self.padding_packets += consume.count(); - } - self.purge_consumed_location(&consume, true); - self.purge_consumed_buffers(); - - self.active.head = consume.tail; - return Err(BuildError::InvalidPartition(consume)); - } - - // the head set of packets is now fully consumed - self.active.head = consume.tail; - - // merge all the buffers into a sample - let mut data: Vec = Vec::new(); - let mut i = consume.head; - while i != consume.tail { - let payload = self.buffer[i as usize] - .as_ref() - .map(|p| &p.payload) - .ok_or(BuildError::GapInSegment)?; - - let p = self - .depacketizer - .depacketize(payload) - .map_err(|_| BuildError::DepacketizerFailed)?; - - data.extend_from_slice(&p); - i = i.wrapping_add(1); - } - let samples = after_timestamp - sample_timestamp; - - let sample = Sample { - data: Bytes::copy_from_slice(&data), - timestamp: SystemTime::now(), - duration: Duration::from_secs_f64((samples as f64) / (self.sample_rate as f64)), - packet_timestamp: sample_timestamp, - prev_dropped_packets: self.dropped_packets, - prev_padding_packets: self.padding_packets, - }; - - self.dropped_packets = 0; - self.padding_packets = 0; - self.last_sample_timestamp = Some(sample_timestamp); - - self.prepared_samples[self.prepared.tail as usize] = Some(sample); - self.prepared.tail = self.prepared.tail.wrapping_add(1); - - self.purge_consumed_location(&consume, true); - self.purge_consumed_buffers(); - - Ok(consume) - } - - /// Compiles pushed RTP packets into media samples and then - /// returns the next valid sample (or None if no sample is compiled). 
- pub fn pop(&mut self) -> Option { - let _ = self.build_sample(false); - - if self.prepared.empty() { - return None; - } - let result = self.prepared_samples[self.prepared.head as usize].take(); - self.prepared.head = self.prepared.head.wrapping_add(1); - result - } - - /// Compiles pushed RTP packets into media samples and then - /// returns the next valid sample with its associated RTP timestamp (or `None` if - /// no sample is compiled). - pub fn pop_with_timestamp(&mut self) -> Option<(Sample, u32)> { - if let Some(sample) = self.pop() { - let timestamp = sample.packet_timestamp; - Some((sample, timestamp)) - } else { - None - } - } -} - -/// Computes the distance between two sequence numbers -/*pub(crate) fn seqnum_distance(head: u16, tail: u16) -> u16 { - if head > tail { - head.wrapping_add(tail) - } else { - tail - head - } -}*/ - -pub(crate) fn seqnum_distance(x: u16, y: u16) -> u16 { - let diff = x.wrapping_sub(y); - if diff > 0xFFFF / 2 { - 0xFFFF - diff + 1 - } else { - diff - } -} - -#[derive(Debug)] -enum BuildError { - /// There's no active segment of RTP packets to consider yet. - NoActiveSegment, - - /// No sample partition could be found in the active segment. - NothingToConsume, - - /// A segment to consume was identified, but a subsequent packet is needed to determine the - /// duration of the sample. - PendingTimestampPacket, - - /// The active segment's head was not aligned with a sample partition head. Some packets were - /// dropped. - InvalidPartition(SampleSequenceLocation), - - /// There was a gap in the active segment because of one or more missing RTP packets. - GapInSegment, - - /// We failed to depacketize an RTP packet. - DepacketizerFailed, -} diff --git a/media/src/io/sample_builder/sample_builder_test.rs b/media/src/io/sample_builder/sample_builder_test.rs deleted file mode 100644 index 097a82d6d..000000000 --- a/media/src/io/sample_builder/sample_builder_test.rs +++ /dev/null @@ -1,1499 +0,0 @@ -use rtp::header::Header; -use rtp::packet::Packet; -use rtp::packetizer::Depacketizer; - -use super::*; - -// Turns u8 integers into Bytes Array -macro_rules! bytes { - ($($item:expr),*) => ({ - static STATIC_SLICE: &'static [u8] = &[$($item), *]; - Bytes::from_static(STATIC_SLICE) - }); -} -#[derive(Default)] -pub struct SampleBuilderTest { - message: String, - packets: Vec, - with_head_checker: bool, - head_bytes: Vec, - samples: Vec, - max_late: u16, - max_late_timestamp: Duration, - extra_pop_attempts: usize, -} - -pub struct FakeDepacketizer { - head_checker: bool, - head_bytes: Vec, -} - -impl FakeDepacketizer { - fn new() -> Self { - Self { - head_checker: false, - head_bytes: vec![], - } - } -} - -impl Depacketizer for FakeDepacketizer { - fn depacketize(&mut self, b: &Bytes) -> std::result::Result { - Ok(b.clone()) - } - - /// Checks if the packet is at the beginning of a partition. This - /// should return false if the result could not be determined, in - /// which case the caller will detect timestamp discontinuities. - fn is_partition_head(&self, payload: &Bytes) -> bool { - if !self.head_checker { - // from .go: simulates a bug in 3.0 version, the tests should not assume the bug - return true; - } - - for b in &self.head_bytes { - if *payload == b { - return true; - } - } - false - } - - /// Checks if the packet is at the end of a partition. This should - /// return false if the result could not be determined. 
- fn is_partition_tail(&self, marker: bool, _payload: &Bytes) -> bool { - marker - } -} - -#[test] -pub fn test_sample_builder() { - #![allow(clippy::needless_update)] - let test_data: Vec = vec![ - SampleBuilderTest { - #[rustfmt::skip] - message: "Sample builder shouldn't emit anything if only one RTP packet has been pushed".into(), - packets: vec![Packet { - header: Header { - sequence_number: 5000, - timestamp: 5, - ..Default::default() - }, - payload: bytes!(1), - ..Default::default() - }], - samples: vec![], - max_late: 50, - max_late_timestamp: Duration::from_secs(0), - ..Default::default() - }, - SampleBuilderTest { - #[rustfmt::skip] - message: "Sample builder shouldn't emit anything if only one RTP packet has been pushed even if the marker bit is set".into(), - packets: vec![Packet { - header: Header { - sequence_number: 5000, - timestamp: 5, - marker: true, - ..Default::default() - }, - payload: bytes!(1), - ..Default::default() - }], - samples: vec![], - max_late: 50, - max_late_timestamp: Duration::from_secs(0), - ..Default::default() - }, - SampleBuilderTest { - #[rustfmt::skip] - message: "Sample builder should emit two packets, we had three packets with unique timestamps".into(), - packets: vec![ - Packet { - // First packet - header: Header { - sequence_number: 5000, - timestamp: 5, - ..Default::default() - }, - payload: bytes!(1), - ..Default::default() - }, - Packet { - // Second packet - header: Header { - sequence_number: 5001, - timestamp: 6, - ..Default::default() - }, - payload: bytes!(2), - ..Default::default() - }, - Packet { - // Third packet - header: Header { - sequence_number: 5002, - timestamp: 7, - ..Default::default() - }, - payload: bytes!(3), - ..Default::default() - }, - ], - samples: vec![ - Sample { - // First sample - data: bytes!(1), - duration: Duration::from_secs(1), // technically this is the default value, but since it was in .go source.... 
- packet_timestamp: 5, - ..Default::default() - }, - Sample { - // Second sample - data: bytes!(2), - duration: Duration::from_secs(1), - packet_timestamp: 6, - ..Default::default() - }, - ], - max_late: 50, - max_late_timestamp: Duration::from_secs(0), - ..Default::default() - }, - SampleBuilderTest { - #[rustfmt::skip] - message: "Sample builder should emit one packet, we had a packet end of sequence marker and run out of space".into(), - packets: vec![ - Packet { - // First packet - header: Header { - sequence_number: 5000, - timestamp: 5, - marker: true, - ..Default::default() - }, - payload: bytes!(1), - ..Default::default() - }, - Packet { - // Second packet - header: Header { - sequence_number: 5002, - timestamp: 7, - ..Default::default() - }, - payload: bytes!(2), - ..Default::default() - }, - Packet { - // Third packet - header: Header { - sequence_number: 5004, - timestamp: 9, - ..Default::default() - }, - payload: bytes!(3), - ..Default::default() - }, - Packet { - // Fourth packet - header: Header { - sequence_number: 5006, - timestamp: 11, - ..Default::default() - }, - payload: bytes!(4), - ..Default::default() - }, - Packet { - // Fifth packet - header: Header { - sequence_number: 5008, - timestamp: 13, - ..Default::default() - }, - payload: bytes!(5), - ..Default::default() - }, - Packet { - // Sixth packet - header: Header { - sequence_number: 5010, - timestamp: 15, - ..Default::default() - }, - payload: bytes!(6), - ..Default::default() - }, - Packet { - // Seventh packet - header: Header { - sequence_number: 5012, - timestamp: 17, - ..Default::default() - }, - payload: bytes!(7), - ..Default::default() - }, - ], - samples: vec![Sample { - // First sample - data: bytes!(1), - duration: Duration::from_secs(2), - packet_timestamp: 5, - ..Default::default() - }], - max_late: 5, - max_late_timestamp: Duration::from_secs(0), - ..Default::default() - }, - SampleBuilderTest { - #[rustfmt::skip] - message: "Sample builder shouldn't emit any packet, we do not have a valid end of sequence and run out of space".into(), - packets: vec![ - Packet { - // First packet - header: Header { - sequence_number: 5000, - timestamp: 5, - ..Default::default() - }, - payload: bytes!(1), - ..Default::default() - }, - Packet { - // Second packet - header: Header { - sequence_number: 5002, - timestamp: 7, - ..Default::default() - }, - payload: bytes!(2), - ..Default::default() - }, - Packet { - // Third packet - header: Header { - sequence_number: 5004, - timestamp: 9, - ..Default::default() - }, - payload: bytes!(3), - ..Default::default() - }, - Packet { - // Fourth packet - header: Header { - sequence_number: 5006, - timestamp: 11, - ..Default::default() - }, - payload: bytes!(4), - ..Default::default() - }, - Packet { - // Fifth packet - header: Header { - sequence_number: 5008, - timestamp: 13, - ..Default::default() - }, - payload: bytes!(5), - ..Default::default() - }, - Packet { - // Sixth packet - header: Header { - sequence_number: 5010, - timestamp: 15, - ..Default::default() - }, - payload: bytes!(6), - ..Default::default() - }, - Packet { - // Seventh packet - header: Header { - sequence_number: 5012, - timestamp: 17, - ..Default::default() - }, - payload: bytes!(7), - ..Default::default() - }, - ], - samples: vec![], - max_late: 5, - max_late_timestamp: Duration::from_secs(0), - ..Default::default() - }, - SampleBuilderTest { - #[rustfmt::skip] - message: "Sample builder should emit one packet, we had a packet end of sequence marker and run out of space".into(), - packets: vec![ - Packet 
{ - // First packet - header: Header { - sequence_number: 5000, - timestamp: 5, - marker: true, - ..Default::default() - }, - payload: bytes!(1), - ..Default::default() - }, - Packet { - // Second packet - header: Header { - sequence_number: 5002, - timestamp: 7, - marker: true, - ..Default::default() - }, - payload: bytes!(2), - ..Default::default() - }, - Packet { - // Third packet - header: Header { - sequence_number: 5004, - timestamp: 9, - ..Default::default() - }, - payload: bytes!(3), - ..Default::default() - }, - Packet { - // Fourth packet - header: Header { - sequence_number: 5006, - timestamp: 11, - ..Default::default() - }, - payload: bytes!(4), - ..Default::default() - }, - Packet { - // Fifth packet - header: Header { - sequence_number: 5008, - timestamp: 13, - ..Default::default() - }, - payload: bytes!(5), - ..Default::default() - }, - Packet { - // Sixth packet - header: Header { - sequence_number: 5010, - timestamp: 15, - ..Default::default() - }, - payload: bytes!(6), - ..Default::default() - }, - Packet { - // Seventh packet - header: Header { - sequence_number: 5012, - timestamp: 17, - ..Default::default() - }, - payload: bytes!(7), - ..Default::default() - }, - ], - samples: vec![ - Sample { - // First (dropped) sample - data: bytes!(1), - duration: Duration::from_secs(2), - packet_timestamp: 5, - ..Default::default() - }, - Sample { - // First correct sample - data: bytes!(2), - duration: Duration::from_secs(2), - packet_timestamp: 7, - prev_dropped_packets: 1, - ..Default::default() - }, - ], - max_late: 5, - max_late_timestamp: Duration::from_secs(0), - ..Default::default() - }, - SampleBuilderTest { - #[rustfmt::skip] - message: "Sample builder should emit one packet, we had two packets but with duplicate timestamps".into(), - packets: vec![ - Packet { - // First packet - header: Header { - sequence_number: 5000, - timestamp: 5, - ..Default::default() - }, - payload: bytes!(1), - ..Default::default() - }, - Packet { - // Second packet - header: Header { - sequence_number: 5001, - timestamp: 6, - ..Default::default() - }, - payload: bytes!(2), - ..Default::default() - }, - Packet { - // Third packet - header: Header { - sequence_number: 5002, - timestamp: 6, - ..Default::default() - }, - payload: bytes!(3), - ..Default::default() - }, - Packet { - // Fourth packet - header: Header { - sequence_number: 5003, - timestamp: 7, - ..Default::default() - }, - payload: bytes!(4), - ..Default::default() - }, - ], - samples: vec![ - Sample { - // First sample - data: bytes!(1), - duration: Duration::from_secs(1), - packet_timestamp: 5, - ..Default::default() - }, - Sample { - // Second (duplicate) correct sample - data: bytes!(2, 3), - duration: Duration::from_secs(1), - packet_timestamp: 6, - ..Default::default() - }, - ], - max_late: 50, - max_late_timestamp: Duration::from_secs(0), - ..Default::default() - }, - SampleBuilderTest { - #[rustfmt::skip] - message: "Sample builder shouldn't emit a packet because we have a gap before a valid one".into(), - packets: vec![ - Packet { - // First packet - header: Header { - sequence_number: 5000, - timestamp: 5, - marker: true, - ..Default::default() - }, - payload: bytes!(1), - ..Default::default() - }, - Packet { - // Second packet - header: Header { - sequence_number: 5007, - timestamp: 6, - marker: true, - ..Default::default() - }, - payload: bytes!(2), - ..Default::default() - }, - Packet { - // Third packet - header: Header { - sequence_number: 5008, - timestamp: 7, - ..Default::default() - }, - payload: bytes!(3), - 
..Default::default() - }, - ], - samples: vec![], - max_late: 50, - max_late_timestamp: Duration::from_secs(0), - ..Default::default() - }, - SampleBuilderTest { - #[rustfmt::skip] - message: "Sample builder shouldn't emit a packet after a gap as there are gaps and have not reached maxLate yet".into(), - packets: vec![ - Packet { - // First packet - header: Header { - sequence_number: 5000, - timestamp: 5, - marker: true, - ..Default::default() - }, - payload: bytes!(1), - ..Default::default() - }, - Packet { - // Second packet - header: Header { - sequence_number: 5007, - timestamp: 6, - marker: true, - ..Default::default() - }, - payload: bytes!(2), - ..Default::default() - }, - Packet { - // Third packet - header: Header { - sequence_number: 5008, - timestamp: 7, - ..Default::default() - }, - payload: bytes!(3), - ..Default::default() - }, - ], - with_head_checker: true, - head_bytes: vec![bytes!(2)], - samples: vec![], - max_late: 50, - max_late_timestamp: Duration::from_secs(0), - ..Default::default() - }, - SampleBuilderTest { - #[rustfmt::skip] - message: "Sample builder shouldn't emit a packet after a gap if PartitionHeadChecker doesn't assume it head".into(), - packets: vec![ - Packet { - // First packet - header: Header { - sequence_number: 5000, - timestamp: 5, - marker: true, - ..Default::default() - }, - payload: bytes!(1), - ..Default::default() - }, - Packet { - // Second packet - header: Header { - sequence_number: 5007, - timestamp: 6, - marker: true, - ..Default::default() - }, - payload: bytes!(2), - ..Default::default() - }, - Packet { - // Third packet - header: Header { - sequence_number: 5008, - timestamp: 7, - ..Default::default() - }, - payload: bytes!(3), - ..Default::default() - }, - ], - with_head_checker: true, - head_bytes: vec![], - samples: vec![], - max_late: 50, - max_late_timestamp: Duration::from_secs(0), - ..Default::default() - }, - SampleBuilderTest { - #[rustfmt::skip] - message: "Sample builder should emit multiple valid packets".into(), - packets: vec![ - Packet { - // First packet - header: Header { - sequence_number: 5000, - timestamp: 1, - ..Default::default() - }, - payload: bytes!(1), - ..Default::default() - }, - Packet { - // Second packet - header: Header { - sequence_number: 5001, - timestamp: 2, - ..Default::default() - }, - payload: bytes!(2), - ..Default::default() - }, - Packet { - // Third packet - header: Header { - sequence_number: 5002, - timestamp: 3, - ..Default::default() - }, - payload: bytes!(3), - ..Default::default() - }, - Packet { - // Fourth packet - header: Header { - sequence_number: 5003, - timestamp: 4, - ..Default::default() - }, - payload: bytes!(4), - ..Default::default() - }, - Packet { - // Fifth packet - header: Header { - sequence_number: 5004, - timestamp: 5, - ..Default::default() - }, - payload: bytes!(5), - ..Default::default() - }, - Packet { - // Sixth packet - header: Header { - sequence_number: 5005, - timestamp: 6, - ..Default::default() - }, - payload: bytes!(6), - ..Default::default() - }, - ], - samples: vec![ - Sample { - // First sample - data: bytes!(1), - duration: Duration::from_secs(1), - packet_timestamp: 1, - ..Default::default() - }, - Sample { - // Second sample - data: bytes!(2), - duration: Duration::from_secs(1), - packet_timestamp: 2, - ..Default::default() - }, - Sample { - // Third sample - data: bytes!(3), - duration: Duration::from_secs(1), - packet_timestamp: 3, - ..Default::default() - }, - Sample { - // Fourth sample - data: bytes!(4), - duration: Duration::from_secs(1), - 
packet_timestamp: 4, - ..Default::default() - }, - Sample { - // Fifth sample - data: bytes!(5), - duration: Duration::from_secs(1), - packet_timestamp: 5, - ..Default::default() - }, - ], - max_late: 50, - max_late_timestamp: Duration::from_secs(0), - ..Default::default() - }, - SampleBuilderTest { - #[rustfmt::skip] - message: "Sample builder should skip timestamps too old".into(), - packets: vec![ - Packet { - // First packet - header: Header { - sequence_number: 5000, - timestamp: 1, - ..Default::default() - }, - payload: bytes!(1), - ..Default::default() - }, - Packet { - // Second packet - header: Header { - sequence_number: 5001, - timestamp: 2, - ..Default::default() - }, - payload: bytes!(2), - ..Default::default() - }, - Packet { - // Third packet - header: Header { - sequence_number: 5002, - timestamp: 3, - ..Default::default() - }, - payload: bytes!(3), - ..Default::default() - }, - Packet { - // Fourth packet - header: Header { - sequence_number: 5013, - timestamp: 4000, - ..Default::default() - }, - payload: bytes!(4), - ..Default::default() - }, - Packet { - // Fifth packet - header: Header { - sequence_number: 5014, - timestamp: 4000, - ..Default::default() - }, - payload: bytes!(5), - ..Default::default() - }, - Packet { - // Sixth packet - header: Header { - sequence_number: 5015, - timestamp: 4002, - ..Default::default() - }, - payload: bytes!(6), - ..Default::default() - }, - Packet { - // Seventh packet - header: Header { - sequence_number: 5016, - timestamp: 7000, - ..Default::default() - }, - payload: bytes!(4), - ..Default::default() - }, - Packet { - // Eighth packet - header: Header { - sequence_number: 5017, - timestamp: 7001, - ..Default::default() - }, - payload: bytes!(5), - ..Default::default() - }, - ], - samples: vec![Sample { - // First sample - data: bytes!(4, 5), - duration: Duration::from_secs(2), - packet_timestamp: 4000, - prev_dropped_packets: 12, - ..Default::default() - }], - with_head_checker: true, - head_bytes: vec![bytes!(4)], - max_late: 50, - max_late_timestamp: Duration::from_secs(2000), - ..Default::default() - }, - // This test is based on observed RTP packet streams from Chrome. libWebRTC inserts padding - // packets to keep send rates steady, these are not important for sample building but we - // should identify them as padding packets to differentiate them from lost packets. 
- SampleBuilderTest { - #[rustfmt::skip] - message: "Sample builder should recognise padding packets".into(), - packets: vec![ - Packet { - // First packet - header: Header { - sequence_number: 5000, - timestamp: 1, - ..Default::default() - }, - payload: bytes!(1), - ..Default::default() - }, - Packet { - // Second packet - header: Header { - sequence_number: 5001, - timestamp: 1, - ..Default::default() - }, - payload: bytes!(2), - ..Default::default() - }, - Packet { - // Third packet - header: Header { - sequence_number: 5002, - timestamp: 1, - marker: true, - ..Default::default() - }, - payload: bytes!(3), - ..Default::default() - }, - Packet { - // Padding packet 1 - header: Header { - sequence_number: 5003, - timestamp: 1, - ..Default::default() - }, - payload: Bytes::from_static(&[]), - ..Default::default() - }, - Packet { - // Padding packet 2 - header: Header { - sequence_number: 5004, - timestamp: 1, - ..Default::default() - }, - payload: Bytes::from_static(&[]), - ..Default::default() - }, - Packet { - // Sixth packet - header: Header { - sequence_number: 5005, - timestamp: 2, - ..Default::default() - }, - payload: bytes!(1), - ..Default::default() - }, - Packet { - // Seventh packet - header: Header { - sequence_number: 5006, - timestamp: 2, - marker: true, - ..Default::default() - }, - payload: bytes!(7), - ..Default::default() - }, - Packet { - // Seventh packet - header: Header { - sequence_number: 5007, - timestamp: 3, - ..Default::default() - }, - payload: bytes!(1), - ..Default::default() - }, - ], - samples: vec![ - Sample { - // First sample - data: bytes!(1, 2, 3), - duration: Duration::from_secs(0), - packet_timestamp: 1, - prev_dropped_packets: 0, - ..Default::default() - }, - Sample { - // Second sample - data: bytes!(1, 7), - duration: Duration::from_secs(1), - packet_timestamp: 2, - prev_dropped_packets: 2, - prev_padding_packets: 2, - ..Default::default() - }, - ], - with_head_checker: true, - head_bytes: vec![bytes!(1)], - max_late: 50, - max_late_timestamp: Duration::from_secs(2000), - extra_pop_attempts: 1, - ..Default::default() - }, - // This test is based on observed RTP packet streams when screen sharing in Chrome. 
- SampleBuilderTest { - #[rustfmt::skip] - message: "Sample builder should recognise padding packets when combined with max_late_timestamp".into(), - packets: vec![ - Packet { - // First packet - header: Header { - sequence_number: 5000, - timestamp: 1, - ..Default::default() - }, - payload: bytes!(1), - ..Default::default() - }, - Packet { - // Second packet - header: Header { - sequence_number: 5001, - timestamp: 1, - ..Default::default() - }, - payload: bytes!(2), - ..Default::default() - }, - Packet { - // Third packet - header: Header { - sequence_number: 5002, - timestamp: 1, - marker: true, - ..Default::default() - }, - payload: bytes!(3), - ..Default::default() - }, - Packet { - // Padding packet 1 - header: Header { - sequence_number: 5003, - timestamp: 1, - ..Default::default() - }, - payload: Bytes::from_static(&[]), - ..Default::default() - }, - Packet { - // Padding packet 2 - header: Header { - sequence_number: 5004, - timestamp: 1, - ..Default::default() - }, - payload: Bytes::from_static(&[]), - ..Default::default() - }, - Packet { - // Sixth packet - header: Header { - sequence_number: 5005, - timestamp: 3, - ..Default::default() - }, - payload: bytes!(1), - ..Default::default() - }, - Packet { - // Seventh packet - header: Header { - sequence_number: 5006, - timestamp: 3, - marker: true, - ..Default::default() - }, - payload: bytes!(7), - ..Default::default() - }, - Packet { - // Seventh packet - header: Header { - sequence_number: 5007, - timestamp: 4, - ..Default::default() - }, - payload: bytes!(1), - ..Default::default() - }, - ], - samples: vec![ - Sample { - // First sample - data: bytes!(1, 2, 3), - duration: Duration::from_secs(0), - packet_timestamp: 1, - prev_dropped_packets: 0, - ..Default::default() - }, - Sample { - // Second sample - data: bytes!(1, 7), - duration: Duration::from_secs(1), - packet_timestamp: 3, - prev_dropped_packets: 2, - prev_padding_packets: 2, - ..Default::default() - }, - ], - with_head_checker: true, - head_bytes: vec![bytes!(1)], - max_late: 50, - max_late_timestamp: Duration::from_millis(1050), - extra_pop_attempts: 1, - ..Default::default() - }, - // This test is based on observed RTP packet streams when screen sharing in Chrome. - SampleBuilderTest { - #[rustfmt::skip] - message: "Sample builder should build a sample out of a packet that's both start and end".into(), - packets: vec![ - Packet { - header: Header { - sequence_number: 5000, - timestamp: 1, - marker: true, - ..Default::default() - }, - payload: bytes!(1), - ..Default::default() - }, - Packet { - header: Header { - sequence_number: 5001, - timestamp: 2, - ..Default::default() - }, - payload: bytes!(1), - ..Default::default() - }, - ], - samples: vec![Sample { - // First sample - data: bytes!(1), - duration: Duration::from_secs(1), - packet_timestamp: 1, - prev_dropped_packets: 0, - ..Default::default() - }], - with_head_checker: true, - head_bytes: vec![bytes!(1)], - max_late: 50, - max_late_timestamp: Duration::from_millis(1050), - ..Default::default() - }, - // This test is based on observed RTP packet streams when screen sharing in Chrome. In - // particular the scenario used involved no movement on screen which causes Chrome to - // generate padding packets. 
- SampleBuilderTest { - #[rustfmt::skip] - message: "Sample builder should build a sample out of a packet that's both start and end following a run of padding packets".into(), - packets: vec![ - // First valid packet - Packet { - header: Header { - sequence_number: 5000, - timestamp: 1, - ..Default::default() - }, - payload: bytes!(1), - ..Default::default() - }, - // Second valid packet - Packet { - header: Header { - sequence_number: 5001, - timestamp: 1, - marker: true, - ..Default::default() - }, - payload: bytes!(2), - ..Default::default() - }, - // Padding packet 1 - Packet { - header: Header { - sequence_number: 5002, - timestamp: 1, - ..Default::default() - }, - payload: Bytes::default(), - ..Default::default() - }, - // Padding packet 2 - Packet { - header: Header { - sequence_number: 5003, - timestamp: 1, - ..Default::default() - }, - payload: Bytes::default(), - ..Default::default() - }, - // Third valid packet - Packet { - header: Header { - sequence_number: 5004, - timestamp: 2, - marker: true, - ..Default::default() - }, - payload: bytes!(1), - ..Default::default() - }, - // Fourth valid packet, start of next sample - Packet { - header: Header { - sequence_number: 5005, - timestamp: 3, - ..Default::default() - }, - payload: bytes!(1), - ..Default::default() - }, - ], - samples: vec![ - Sample { - // First sample - data: bytes!(1, 2), - duration: Duration::from_secs(0), - packet_timestamp: 1, - prev_dropped_packets: 0, - ..Default::default() - }, - Sample { - // Second sample - data: bytes!(1), - duration: Duration::from_secs(1), - packet_timestamp: 2, - prev_dropped_packets: 2, - prev_padding_packets: 2, - ..Default::default() - }, - ], - with_head_checker: true, - head_bytes: vec![bytes!(1)], - extra_pop_attempts: 1, - max_late: 50, - ..Default::default() - }, - ]; - - for t in test_data { - let d = FakeDepacketizer { - head_checker: t.with_head_checker, - head_bytes: t.head_bytes, - }; - - let mut s = { - let sample_builder = SampleBuilder::new(t.max_late, d, 1); - if t.max_late_timestamp != Duration::from_secs(0) { - sample_builder.with_max_time_delay(t.max_late_timestamp) - } else { - sample_builder - } - }; - - let mut samples = Vec::::new(); - for p in t.packets { - s.push(p) - } - - while let Some(sample) = s.pop() { - samples.push(sample) - } - - for _ in 0..t.extra_pop_attempts { - // Pop some more - while let Some(sample) = s.pop() { - samples.push(sample) - } - } - - // Current problem: Sample does not implement Eq. Either implement myself or find another way of comparison. 
(Derive does not work) - assert_eq!(t.samples, samples, "{}", t.message); - } -} - -// SampleBuilder should respect maxLate if we popped successfully but then have a gap larger then maxLate -#[test] -fn test_sample_builder_max_late() { - let mut s = SampleBuilder::new(50, FakeDepacketizer::new(), 1); - - s.push(Packet { - header: Header { - sequence_number: 0, - timestamp: 1, - ..Default::default() - }, - payload: bytes!(0x01), - }); - s.push(Packet { - header: Header { - sequence_number: 1, - timestamp: 2, - ..Default::default() - }, - payload: bytes!(0x01), - }); - s.push(Packet { - header: Header { - sequence_number: 2, - timestamp: 3, - ..Default::default() - }, - payload: bytes!(0x01), - }); - assert_eq!( - s.pop(), - Some(Sample { - data: bytes!(0x01), - duration: Duration::from_secs(1), - packet_timestamp: 1, - ..Default::default() - }), - "Failed to build samples before gap" - ); - - s.push(Packet { - header: Header { - sequence_number: 5000, - timestamp: 500, - ..Default::default() - }, - payload: bytes!(0x02), - }); - s.push(Packet { - header: Header { - sequence_number: 5001, - timestamp: 501, - ..Default::default() - }, - payload: bytes!(0x02), - }); - s.push(Packet { - header: Header { - sequence_number: 5002, - timestamp: 502, - ..Default::default() - }, - payload: bytes!(0x02), - }); - - assert_eq!( - s.pop(), - Some(Sample { - data: bytes!(0x01), - duration: Duration::from_secs(1), - packet_timestamp: 2, - ..Default::default() - }), - "Failed to build samples after large gap" - ); - assert_eq!(None, s.pop(), "Failed to build samples after large gap"); - - s.push(Packet { - header: Header { - sequence_number: 6000, - timestamp: 600, - ..Default::default() - }, - payload: bytes!(0x03), - }); - assert_eq!( - s.pop(), - Some(Sample { - data: bytes!(0x02), - duration: Duration::from_secs(1), - packet_timestamp: 500, - prev_dropped_packets: 4998, - ..Default::default() - }), - "Failed to build samples after large gap" - ); - assert_eq!( - s.pop(), - Some(Sample { - data: bytes!(0x02), - duration: Duration::from_secs(1), - packet_timestamp: 501, - ..Default::default() - }), - "Failed to build samples after large gap" - ); -} - -#[test] -fn test_seqnum_distance() { - struct TestData { - x: u16, - y: u16, - d: u16, - } - let test_data = vec![ - TestData { - x: 0x0001, - y: 0x0003, - d: 0x0002, - }, - TestData { - x: 0x0003, - y: 0x0001, - d: 0x0002, - }, - TestData { - x: 0xFFF3, - y: 0xFFF1, - d: 0x0002, - }, - TestData { - x: 0xFFF1, - y: 0xFFF3, - d: 0x0002, - }, - TestData { - x: 0xFFFF, - y: 0x0001, - d: 0x0002, - }, - TestData { - x: 0x0001, - y: 0xFFFF, - d: 0x0002, - }, - ]; - - for data in test_data { - assert_eq!( - seqnum_distance(data.x, data.y), - data.d, - "seqnum_distance({}, {}) returned {} which must be {}", - data.x, - data.y, - seqnum_distance(data.x, data.y), - data.d - ); - } -} - -#[test] -fn test_sample_builder_clean_reference() { - for seq_start in [0_u16, 0xfff8, 0xfffe] { - let mut s = SampleBuilder::new(10, FakeDepacketizer::new(), 1); - s.push(Packet { - header: Header { - sequence_number: seq_start, - timestamp: 0, - ..Default::default() - }, - payload: bytes!(0x01), - }); - s.push(Packet { - header: Header { - sequence_number: seq_start.wrapping_add(1), - timestamp: 0, - ..Default::default() - }, - payload: bytes!(0x02), - }); - s.push(Packet { - header: Header { - sequence_number: seq_start.wrapping_add(2), - timestamp: 0, - ..Default::default() - }, - payload: bytes!(0x03), - }); - let pkt4 = Packet { - header: Header { - sequence_number: 
seq_start.wrapping_add(14), - timestamp: 120, - ..Default::default() - }, - payload: bytes!(0x04), - }; - s.push(pkt4.clone()); - let pkt5 = Packet { - header: Header { - sequence_number: seq_start.wrapping_add(12), - timestamp: 120, - ..Default::default() - }, - payload: bytes!(0x05), - }; - s.push(pkt5.clone()); - - for i in 0..3 { - assert_eq!( - s.buffer[seq_start.wrapping_add(i) as usize], - None, - "Old packet ({i}) is not unreferenced (seq_start: {seq_start}, max_late: 10, pushed: 12)" - ); - } - assert_eq!(s.buffer[seq_start.wrapping_add(14) as usize], Some(pkt4)); - assert_eq!(s.buffer[seq_start.wrapping_add(12) as usize], Some(pkt5)); - } -} - -#[test] -fn test_sample_builder_push_max_zero() { - let pkt = Packet { - header: Header { - sequence_number: 0, - timestamp: 0, - marker: true, - ..Default::default() - }, - payload: bytes!(0x01), - }; - let d = FakeDepacketizer { - head_checker: true, - head_bytes: vec![bytes!(0x01)], - }; - let mut s = SampleBuilder::new(0, d, 1); - s.push(pkt); - assert!(s.pop().is_some(), "Should expect a popped sample.") -} - -#[test] -fn test_pop_with_timestamp() { - let mut s = SampleBuilder::new(0, FakeDepacketizer::new(), 1); - assert_eq!(s.pop_with_timestamp(), None); -} - -#[test] -fn test_sample_builder_data() { - let mut s = SampleBuilder::new(10, FakeDepacketizer::new(), 1); - let mut j: usize = 0; - for i in 0..0x20000_usize { - let p = Packet { - header: Header { - sequence_number: i as u16, - timestamp: (i + 42) as u32, - ..Default::default() - }, - payload: Bytes::copy_from_slice(&[i as u8]), - }; - s.push(p); - while let Some((sample, ts)) = s.pop_with_timestamp() { - assert_eq!(ts, (j + 42) as u32, "timestamp"); - assert_eq!(sample.data.len(), 1, "data length"); - assert_eq!(sample.data[0], j as u8, "timestamp"); - j += 1; - } - } - // only the last packet should be dropped - assert_eq!(j, 0x1FFFF); -} diff --git a/media/src/io/sample_builder/sample_sequence_location.rs b/media/src/io/sample_builder/sample_sequence_location.rs deleted file mode 100644 index b4e10de71..000000000 --- a/media/src/io/sample_builder/sample_sequence_location.rs +++ /dev/null @@ -1,81 +0,0 @@ -use super::seqnum_distance; - -#[derive(Debug, PartialEq)] -pub(crate) enum Comparison { - Void, - Before, - Inside, - After, -} - -pub(crate) struct Iterator<'a, T> { - data: &'a [Option], - sample: SampleSequenceLocation, - i: u16, -} - -impl<'a, T> std::iter::Iterator for Iterator<'a, T> { - type Item = Option<&'a T>; - - fn next(&mut self) -> Option { - if self.sample.compare(self.i) == Comparison::Inside { - let old_i = self.i as usize; - self.i = self.i.wrapping_add(1); - return Some(self.data[old_i].as_ref()); - } - - None - } -} - -#[derive(Debug, Clone, Copy)] -pub(crate) struct SampleSequenceLocation { - /// head is the first packet in a sequence - pub(crate) head: u16, - /// tail is always set to one after the final sequence number, - /// so if `head == tail` then the sequence is empty - pub(crate) tail: u16, -} - -impl SampleSequenceLocation { - pub(crate) fn new() -> Self { - Self { head: 0, tail: 0 } - } - - pub(crate) fn empty(&self) -> bool { - self.head == self.tail - } - - pub(crate) fn has_data(&self) -> bool { - self.head != self.tail - } - - pub(crate) fn count(&self) -> u16 { - seqnum_distance(self.head, self.tail) - } - - pub(crate) fn compare(&self, pos: u16) -> Comparison { - if self.head == self.tail { - return Comparison::Void; - } - if self.head < self.tail { - if self.head <= pos && pos < self.tail { - return Comparison::Inside; - } - } else 
if self.head <= pos || pos < self.tail { - return Comparison::Inside; - } - if self.head.wrapping_sub(pos) <= pos.wrapping_sub(self.tail) { - return Comparison::Before; - } - Comparison::After - } - - pub(crate) fn range<'a, T>(&self, data: &'a [Option]) -> Iterator<'a, T> { - Iterator { - data, - sample: *self, - i: self.head, - } - } -} diff --git a/media/src/io/sample_builder/sample_sequence_location_test.rs b/media/src/io/sample_builder/sample_sequence_location_test.rs deleted file mode 100644 index c6e3a3757..000000000 --- a/media/src/io/sample_builder/sample_sequence_location_test.rs +++ /dev/null @@ -1,45 +0,0 @@ -use super::sample_sequence_location::*; - -#[test] -fn test_sample_sequence_location_compare() { - let s1 = SampleSequenceLocation { head: 32, tail: 42 }; - assert_eq!(s1.compare(16), Comparison::Before); - assert_eq!(s1.compare(32), Comparison::Inside); - assert_eq!(s1.compare(38), Comparison::Inside); - assert_eq!(s1.compare(41), Comparison::Inside); - assert_eq!(s1.compare(42), Comparison::After); - assert_eq!(s1.compare(0x57), Comparison::After); - - let s2 = SampleSequenceLocation { - head: 0xffa0, - tail: 32, - }; - assert_eq!(s2.compare(0xff00), Comparison::Before); - assert_eq!(s2.compare(0xffa0), Comparison::Inside); - assert_eq!(s2.compare(0xffff), Comparison::Inside); - assert_eq!(s2.compare(0), Comparison::Inside); - assert_eq!(s2.compare(31), Comparison::Inside); - assert_eq!(s2.compare(32), Comparison::After); - assert_eq!(s2.compare(128), Comparison::After); -} - -#[test] -fn test_sample_sequence_location_range() { - let mut data: Vec> = vec![None; u16::MAX as usize + 1]; - - data[65533] = Some(65533); - data[65535] = Some(65535); - data[0] = Some(0); - data[2] = Some(2); - - let s = SampleSequenceLocation { - head: 65533, - tail: 3, - }; - let reconstructed: Vec<_> = s.range(&data).map(|x| x.cloned()).collect(); - - assert_eq!( - reconstructed, - [Some(65533), None, Some(65535), Some(0), None, Some(2)] - ); -} diff --git a/media/src/lib.rs b/media/src/lib.rs deleted file mode 100644 index 264277c56..000000000 --- a/media/src/lib.rs +++ /dev/null @@ -1,110 +0,0 @@ -#![warn(rust_2018_idioms)] -#![allow(dead_code)] - -pub mod audio; -mod error; -pub mod io; -pub mod video; - -use std::time::{Duration, SystemTime}; - -use bytes::Bytes; -pub use error::Error; - -/// A Sample contains encoded media and timing information -#[derive(Debug)] -pub struct Sample { - /// The assembled data in the sample, as a bitstream. - /// - /// The format is Codec dependant, but is always a bitstream format - /// rather than the packetized format used when carried over RTP. - /// - /// See: [`rtp::packetizer::Depacketizer`] and implementations of it for more details. - pub data: Bytes, - - /// The wallclock time when this sample was generated. - pub timestamp: SystemTime, - - /// The duration of this sample - pub duration: Duration, - - /// The RTP packet timestamp of this sample. - /// - /// For all RTP packets that contributed to a single sample the timestamp is the same. - pub packet_timestamp: u32, - - /// The number of packets that were dropped prior to building this sample. - /// - /// Packets being dropped doesn't necessarily indicate something wrong, e.g., packets are sometimes - /// dropped because they aren't relevant for sample building. - pub prev_dropped_packets: u16, - - /// The number of packets that were identified as padding prior to building this sample. - /// - /// Some implementations, notably libWebRTC, send padding packets to keep the send rate steady. 
- /// These packets don't carry media and aren't useful for building samples. - /// - /// This field can be combined with [`Sample::prev_dropped_packets`] to determine if any - /// dropped packets are likely to have detrimental impact on the steadiness of the RTP stream. - /// - /// ## Example adjustment - /// - /// ```rust - /// # use bytes::Bytes; - /// # use std::time::{SystemTime, Duration}; - /// # use webrtc_media::Sample; - /// # let sample = Sample { - /// # data: Bytes::new(), - /// # timestamp: SystemTime::now(), - /// # duration: Duration::from_secs(0), - /// # packet_timestamp: 0, - /// # prev_dropped_packets: 10, - /// # prev_padding_packets: 15 - /// # }; - /// # - /// let adjusted_dropped = - /// sample.prev_dropped_packets.saturating_sub(sample.prev_padding_packets); - /// ``` - pub prev_padding_packets: u16, -} - -impl Default for Sample { - fn default() -> Self { - Sample { - data: Bytes::new(), - timestamp: SystemTime::now(), - duration: Duration::from_secs(0), - packet_timestamp: 0, - prev_dropped_packets: 0, - prev_padding_packets: 0, - } - } -} - -impl PartialEq for Sample { - fn eq(&self, other: &Self) -> bool { - let mut equal: bool = true; - if self.data != other.data { - equal = false; - } - if self.timestamp.elapsed().unwrap().as_secs() - != other.timestamp.elapsed().unwrap().as_secs() - { - equal = false; - } - if self.duration != other.duration { - equal = false; - } - if self.packet_timestamp != other.packet_timestamp { - equal = false; - } - if self.prev_dropped_packets != other.prev_dropped_packets { - equal = false; - } - if self.prev_padding_packets != other.prev_padding_packets { - equal = false; - } - - equal - } -} diff --git a/media/src/video/mod.rs b/media/src/video/mod.rs deleted file mode 100644 index 8b1378917..000000000 --- a/media/src/video/mod.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/rtcp/.gitignore b/rtcp/.gitignore deleted file mode 100644 index 81561ed32..000000000 --- a/rtcp/.gitignore +++ /dev/null @@ -1,11 +0,0 @@ -# Generated by Cargo -# will have compiled files and executables -/target/ -/.idea/ - -# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries -# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html -Cargo.lock - -# These are backup files generated by rustfmt -**/*.rs.bk diff --git a/rtcp/CHANGELOG.md b/rtcp/CHANGELOG.md deleted file mode 100644 index a9e642a8c..000000000 --- a/rtcp/CHANGELOG.md +++ /dev/null @@ -1,19 +0,0 @@ -# rtcp changelog - -## Unreleased - -## v0.8.0 - -* Fix over-NACK due not resetting lost_packets bitmask [\#372](https://github.com/webrtc-rs/webrtc/pull/372/). -* Increased minimum support rust version to `1.60.0`. -* Increased required `webrtc-util` version to `0.7.0`. - -## v0.7.0 - -* [#14 Prevent crash in RTCP NACK writing](https://github.com/webrtc-rs/rtcp/pull/14) by [@pthatcher](https://github.com/pthatcher). -* Adds `IntoIterator` for `NackPair` which iterates over all the sequence numbers specified by the `NackPair`. This is similar to `packet_list` but without requiring the allocation of a Vec. Added in [#225 Add RTP Stats to stats report](https://github.com/webrtc-rs/webrtc/pull/225) by [@k0nserv](https://github.com/k0nserv). - - -## Prior to 0.7.0 - -Before 0.7.0 there was no changelog, previous changes are sometimes, but not always, available in the [GitHub Releases](https://github.com/webrtc-rs/rtcp/releases). 
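For context on the "Sample does not implement Eq ... (Derive does not work)" comment in the removed tests: the deleted `impl PartialEq for Sample` compares the wallclock `timestamp` field only down to whole elapsed seconds. Below is a minimal, self-contained sketch of that rule using only the standard library; the helper name is invented, and `unwrap_or` stands in for the original's `unwrap` so the sketch cannot panic.

```rust
use std::time::{Duration, SystemTime};

// Not the crate's code: mirrors the removed `impl PartialEq for Sample`,
// which treats two wallclock timestamps as equal when the same whole number
// of seconds has elapsed since each of them.
fn timestamps_roughly_equal(a: SystemTime, b: SystemTime) -> bool {
    let elapsed_a = a.elapsed().unwrap_or(Duration::ZERO).as_secs();
    let elapsed_b = b.elapsed().unwrap_or(Duration::ZERO).as_secs();
    elapsed_a == elapsed_b
}

fn main() {
    let now = SystemTime::now();
    let slightly_later = now + Duration::from_millis(5);
    // Sub-second differences are ignored, so these compare equal.
    assert!(timestamps_roughly_equal(now, slightly_later));
}
```

Deriving `PartialEq`/`Eq` would instead demand exact `SystemTime` equality, which a freshly constructed expected sample in the tests can never satisfy; that is why the hand-written impl exists.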
diff --git a/rtcp/Cargo.toml b/rtcp/Cargo.toml deleted file mode 100644 index e5ed6febc..000000000 --- a/rtcp/Cargo.toml +++ /dev/null @@ -1,16 +0,0 @@ -[package] -name = "rtcp" -version = "0.11.0" -authors = ["Rain Liu ", "Michael Uti "] -edition = "2021" -description = "A pure Rust implementation of RTCP" -license = "MIT OR Apache-2.0" -documentation = "https://docs.rs/rtcp" -homepage = "https://webrtc.rs" -repository = "https://github.com/webrtc-rs/webrtc/tree/master/rtcp" - -[dependencies] -util = { version = "0.9.0", path = "../util", package = "webrtc-util", default-features = false, features = ["marshal"] } - -bytes = "1" -thiserror = "1" diff --git a/rtcp/LICENSE-APACHE b/rtcp/LICENSE-APACHE deleted file mode 100644 index 16fe87b06..000000000 --- a/rtcp/LICENSE-APACHE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - -1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. 
For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - -2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - -3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - -4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. 
The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - -5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - -6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - -8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. 
- -END OF TERMS AND CONDITIONS - -APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - -Copyright [yyyy] [name of copyright owner] - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. diff --git a/rtcp/LICENSE-MIT b/rtcp/LICENSE-MIT deleted file mode 100644 index e11d93bef..000000000 --- a/rtcp/LICENSE-MIT +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2021 WebRTC.rs - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/rtcp/README.md b/rtcp/README.md deleted file mode 100644 index 348f180fb..000000000 --- a/rtcp/README.md +++ /dev/null @@ -1,30 +0,0 @@ -

- WebRTC.rs
-
- [badges: License: MIT/Apache 2.0 · Discord]
-
- A pure Rust implementation of RTCP. Rewrite Pion RTCP in Rust
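The removed README only names the crate; the central invariant it implements is the compound-packet rule enforced by `CompoundPacket::validate` in the deleted sources further down: the first packet must be a SenderReport or ReceiverReport, and an SDES carrying a CNAME must follow. A rough sketch of that, built the same way as the deleted compound_packet tests; module paths are assumed from the deleted source layout rather than from any published re-exports.

```rust
use bytes::Bytes;
use rtcp::compound_packet::CompoundPacket;
use rtcp::receiver_report::ReceiverReport;
use rtcp::source_description::{
    SdesType, SourceDescription, SourceDescriptionChunk, SourceDescriptionItem,
};

fn main() {
    // An SDES packet carrying a CNAME item, constructed as in the deleted tests.
    let cname = SourceDescription {
        chunks: vec![SourceDescriptionChunk {
            source: 1234,
            items: vec![SourceDescriptionItem {
                sdes_type: SdesType::SdesCname,
                text: Bytes::from_static(b"cname"),
            }],
        }],
    };

    // A ReceiverReport first, then the CNAME-bearing SDES: the smallest
    // compound packet the validation rules accept.
    let compound = CompoundPacket(vec![Box::<ReceiverReport>::default(), Box::new(cname)]);
    assert!(compound.validate().is_ok());
}
```

Swapping the order or dropping the SDES item makes `validate()` fail with `BadFirstPacket` or `MissingCname` respectively, exactly as the deleted compound_packet tests exercise.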

diff --git a/rtcp/codecov.yml b/rtcp/codecov.yml deleted file mode 100644 index e72b36629..000000000 --- a/rtcp/codecov.yml +++ /dev/null @@ -1,23 +0,0 @@ -codecov: - require_ci_to_pass: yes - max_report_age: off - token: 2971c79d-6f37-4e06-924b-e2325e3c8a06 - -coverage: - precision: 2 - round: down - range: 50..90 - status: - project: - default: - enabled: no - threshold: 0.2 - if_not_found: success - patch: - default: - enabled: no - if_not_found: success - changes: - default: - enabled: no - if_not_found: success diff --git a/rtcp/doc/webrtc.rs.png b/rtcp/doc/webrtc.rs.png deleted file mode 100644 index 7bf0dda2a..000000000 Binary files a/rtcp/doc/webrtc.rs.png and /dev/null differ diff --git a/rtcp/src/compound_packet/compound_packet_test.rs b/rtcp/src/compound_packet/compound_packet_test.rs deleted file mode 100644 index 84d5fa5a9..000000000 --- a/rtcp/src/compound_packet/compound_packet_test.rs +++ /dev/null @@ -1,329 +0,0 @@ -use super::*; -use crate::goodbye::Goodbye; -use crate::payload_feedbacks::picture_loss_indication::PictureLossIndication; - -// An RTCP packet from a packet dump -const REAL_PACKET: [u8; 116] = [ - // Receiver Report (offset=0) - 0x81, 0xc9, 0x0, 0x7, // v=2, p=0, count=1, RR, len=7 - 0x90, 0x2f, 0x9e, 0x2e, // ssrc=0x902f9e2e - 0xbc, 0x5e, 0x9a, 0x40, // ssrc=0xbc5e9a40 - 0x0, 0x0, 0x0, 0x0, // fracLost=0, totalLost=0 - 0x0, 0x0, 0x46, 0xe1, // lastSeq=0x46e1 - 0x0, 0x0, 0x1, 0x11, // jitter=273 - 0x9, 0xf3, 0x64, 0x32, // lsr=0x9f36432 - 0x0, 0x2, 0x4a, 0x79, // delay=150137 - // Source Description (offset=32) - 0x81, 0xca, 0x0, 0xc, // v=2, p=0, count=1, SDES, len=12 - 0x90, 0x2f, 0x9e, 0x2e, // ssrc=0x902f9e2e - 0x1, 0x26, // CNAME, len=38 - 0x7b, 0x39, 0x63, 0x30, 0x30, 0x65, 0x62, 0x39, 0x32, 0x2d, 0x31, 0x61, 0x66, 0x62, 0x2d, 0x39, - 0x64, 0x34, 0x39, 0x2d, 0x61, 0x34, 0x37, 0x64, 0x2d, 0x39, 0x31, 0x66, 0x36, 0x34, 0x65, 0x65, - 0x65, 0x36, 0x39, 0x66, 0x35, 0x7d, // text="{9c00eb92-1afb-9d49-a47d-91f64eee69f5}" - 0x0, 0x0, 0x0, 0x0, // END + padding - // Goodbye (offset=84) - 0x81, 0xcb, 0x0, 0x1, // v=2, p=0, count=1, BYE, len=1 - 0x90, 0x2f, 0x9e, 0x2e, // source=0x902f9e2e - 0x81, 0xce, 0x0, 0x2, // Picture Loss Indication (offset=92) - 0x90, 0x2f, 0x9e, 0x2e, // sender=0x902f9e2e - 0x90, 0x2f, 0x9e, 0x2e, // media=0x902f9e2e - 0x85, 0xcd, 0x0, 0x2, // RapidResynchronizationRequest (offset=104) - 0x90, 0x2f, 0x9e, 0x2e, // sender=0x902f9e2e - 0x90, 0x2f, 0x9e, 0x2e, // media=0x902f9e2e -]; - -#[test] -fn test_read_eof() { - let mut short_header = Bytes::from_static(&[ - 0x81, 0xc9, // missing type & len - ]); - let result = unmarshal(&mut short_header); - assert!(result.is_err(), "missing type & len"); -} - -#[test] -fn test_bad_compound() { - let mut bad_compound = Bytes::copy_from_slice(&REAL_PACKET[..34]); - let result = unmarshal(&mut bad_compound); - assert!(result.is_err(), "trailing data!"); - - let mut bad_compound = Bytes::copy_from_slice(&REAL_PACKET[84..104]); - let p = unmarshal(&mut bad_compound).expect("Error unmarshalling packet"); - let compound = CompoundPacket(p); - - // this should return an error, - // it violates the "must start with RR or SR" rule - match compound.validate() { - Ok(_) => panic!("validation should return an error"), - - Err(err) => { - let a = Error::BadFirstPacket; - assert_eq!( - Error::BadFirstPacket, - err, - "Unmarshal(badcompound) err={err:?}, want {a:?}", - ); - } - }; - - let compound_len = compound.0.len(); - assert_eq!( - compound_len, 2, - "Unmarshal(badcompound) len={}, want {}", - compound_len, 2 
- ); - - if compound.0[0].as_any().downcast_ref::().is_none() { - panic!("Unmarshal(badcompound), want Goodbye") - } - - if compound.0[1] - .as_any() - .downcast_ref::() - .is_none() - { - panic!("Unmarshal(badcompound), want PictureLossIndication") - } -} - -#[test] -fn test_valid_packet() { - let cname = SourceDescription { - chunks: vec![SourceDescriptionChunk { - source: 1234, - items: vec![SourceDescriptionItem { - sdes_type: SdesType::SdesCname, - text: Bytes::from_static(b"cname"), - }], - }], - }; - - let tests: Vec<(&str, CompoundPacket, Option)> = vec![ - ( - "no cname", - CompoundPacket(vec![Box::::default()]), - Some(Error::MissingCname), - ), - ( - "SDES / no cname", - CompoundPacket(vec![ - Box::::default(), - Box::::default(), - ]), - Some(Error::MissingCname), - ), - ( - "just SR", - CompoundPacket(vec![ - Box::::default(), - Box::new(cname.to_owned()), - ]), - None, - ), - ( - "multiple SRs", - CompoundPacket(vec![ - Box::::default(), - Box::::default(), - Box::new(cname.clone()), - ]), - Some(Error::PacketBeforeCname), - ), - ( - "just RR", - CompoundPacket(vec![ - Box::::default(), - Box::new(cname.clone()), - ]), - None, - ), - ( - "multiple RRs", - CompoundPacket(vec![ - Box::::default(), - Box::new(cname.clone()), - Box::::default(), - ]), - None, - ), - ( - "goodbye", - CompoundPacket(vec![ - Box::::default(), - Box::new(cname), - Box::::default(), - ]), - None, - ), - ]; - - for (name, packet, error) in tests { - let result = packet.validate(); - assert_eq!(result.is_ok(), error.is_none()); - if let (Some(err), Err(got)) = (error, result) { - assert_eq!(err, got, "Valid({name}) = {got:?}, want {err:?}"); - } - } -} - -#[test] -fn test_cname() { - let cname = SourceDescription { - chunks: vec![SourceDescriptionChunk { - source: 1234, - items: vec![SourceDescriptionItem { - sdes_type: SdesType::SdesCname, - text: Bytes::from_static(b"cname"), - }], - }], - }; - - let tests: Vec<(&str, CompoundPacket, Option, &str)> = vec![ - ( - "no cname", - CompoundPacket(vec![Box::::default()]), - Some(Error::MissingCname), - "", - ), - ( - "SDES / no cname", - CompoundPacket(vec![ - Box::::default(), - Box::::default(), - ]), - Some(Error::MissingCname), - "", - ), - ( - "just SR", - CompoundPacket(vec![ - Box::::default(), - Box::new(cname.clone()), - ]), - None, - "cname", - ), - ( - "multiple SRs", - CompoundPacket(vec![ - Box::::default(), - Box::::default(), - Box::new(cname.clone()), - ]), - Some(Error::PacketBeforeCname), - "", - ), - ( - "just RR", - CompoundPacket(vec![ - Box::::default(), - Box::new(cname.clone()), - ]), - None, - "cname", - ), - ( - "multiple RRs", - CompoundPacket(vec![ - Box::::default(), - Box::::default(), - Box::new(cname.clone()), - ]), - None, - "cname", - ), - ( - "goodbye", - CompoundPacket(vec![ - Box::::default(), - Box::new(cname), - Box::::default(), - ]), - None, - "cname", - ), - ]; - - for (name, compound_packet, want_error, text) in tests { - let err = compound_packet.validate(); - assert_eq!(err.is_err(), want_error.is_some()); - if let (Some(want), Err(err)) = (&want_error, err) { - assert_eq!(*want, err, "Valid({name}) = {err:?}, want {want:?}"); - } - - let name_result = compound_packet.cname(); - assert_eq!(name_result.is_err(), want_error.is_some()); - - match name_result { - Ok(e) => { - assert_eq!(e, text, "CNAME({name}) = {e:?}, want {text}",); - } - - Err(err) => { - if let Some(want) = &want_error { - assert_eq!(*want, err, "CNAME({name}) = {err:?}, want {want:?}"); - } - } - } - } -} - -#[test] -fn 
test_compound_packet_roundtrip() { - let cname = SourceDescription { - chunks: vec![SourceDescriptionChunk { - source: 1234, - items: vec![SourceDescriptionItem { - sdes_type: SdesType::SdesCname, - text: Bytes::from_static(b"cname"), - }], - }], - }; - - let tests = vec![ - ( - "goodbye", - CompoundPacket(vec![ - Box::::default(), - Box::new(cname), - Box::new(Goodbye { - sources: vec![1234], - ..Default::default() - }), - ]), - None, - ), - ( - "no cname", - CompoundPacket(vec![Box::::default()]), - Some(Error::MissingCname), - ), - ]; - - for (name, packet, marshal_error) in tests { - let result = packet.marshal(); - if let Some(err) = marshal_error { - if let Err(got) = result { - assert_eq!(err, got, "marshal {name} header: err = {got}, want {err}"); - } else { - panic!("want error in test {name}"); - } - continue; - } else { - assert!(result.is_ok(), "must no error in test {name}"); - } - - let data1 = result.unwrap(); - let c = CompoundPacket::unmarshal(&mut data1.clone()) - .unwrap_or_else(|_| panic!("unmarshal {name} error")); - - let data2 = c - .marshal() - .unwrap_or_else(|_| panic!("marshal {name} error")); - - assert_eq!( - data1, data2, - "Unmarshal(Marshal({name:?})) = {data1:?}, want {data2:?}" - ) - } -} diff --git a/rtcp/src/compound_packet/mod.rs b/rtcp/src/compound_packet/mod.rs deleted file mode 100644 index 6a744f177..000000000 --- a/rtcp/src/compound_packet/mod.rs +++ /dev/null @@ -1,195 +0,0 @@ -#[cfg(test)] -mod compound_packet_test; - -use std::any::Any; -use std::fmt; - -use bytes::{Buf, Bytes}; -use util::marshal::{Marshal, MarshalSize, Unmarshal}; - -use crate::error::Error; -use crate::header::*; -use crate::packet::*; -use crate::receiver_report::*; -use crate::sender_report::*; -use crate::source_description::*; -use crate::util::*; - -type Result = std::result::Result; - -/// A CompoundPacket is a collection of RTCP packets transmitted as a single packet with -/// the underlying protocol (for example UDP). -/// -/// To maximize the resolution of reception statistics, the first Packet in a CompoundPacket -/// must always be either a SenderReport or a ReceiverReport. This is true even if no data -/// has been sent or received, in which case an empty ReceiverReport must be sent, and even -/// if the only other RTCP packet in the compound packet is a Goodbye. -/// -/// Next, a SourceDescription containing a CNAME item must be included in each CompoundPacket -/// to identify the source and to begin associating media for purposes such as lip-sync. -/// -/// Other RTCP packet types may follow in any order. Packet types may appear more than once. -#[derive(Debug, Default, PartialEq, Clone)] -pub struct CompoundPacket(pub Vec>); - -impl fmt::Display for CompoundPacket { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{self:?}") - } -} - -impl Packet for CompoundPacket { - fn header(&self) -> Header { - Header::default() - } - - /// destination_ssrc returns the synchronization sources associated with this - /// CompoundPacket's reception report. 
- fn destination_ssrc(&self) -> Vec { - if self.0.is_empty() { - vec![] - } else { - self.0[0].destination_ssrc() - } - } - - fn raw_size(&self) -> usize { - let mut l = 0; - for packet in &self.0 { - l += packet.marshal_size(); - } - l - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } - - fn equal(&self, other: &(dyn Packet + Send + Sync)) -> bool { - other - .as_any() - .downcast_ref::() - .map_or(false, |a| self == a) - } - - fn cloned(&self) -> Box { - Box::new(self.clone()) - } -} - -impl MarshalSize for CompoundPacket { - fn marshal_size(&self) -> usize { - let l = self.raw_size(); - // align to 32-bit boundary - l + get_padding_size(l) - } -} - -impl Marshal for CompoundPacket { - /// Marshal encodes the CompoundPacket as binary. - fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - self.validate()?; - - for packet in &self.0 { - let n = packet.marshal_to(buf)?; - buf = &mut buf[n..]; - } - - Ok(self.marshal_size()) - } -} - -impl Unmarshal for CompoundPacket { - fn unmarshal(raw_packet: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - let mut packets = vec![]; - - while raw_packet.has_remaining() { - let p = unmarshaller(raw_packet)?; - packets.push(p); - } - - let c = CompoundPacket(packets); - c.validate()?; - - Ok(c) - } -} - -impl CompoundPacket { - /// Validate returns an error if this is not an RFC-compliant CompoundPacket. - pub fn validate(&self) -> Result<()> { - if self.0.is_empty() { - return Err(Error::EmptyCompound.into()); - } - - // SenderReport and ReceiverReport are the only types that - // are allowed to be the first packet in a compound datagram - if self.0[0].as_any().downcast_ref::().is_none() - && self.0[0] - .as_any() - .downcast_ref::() - .is_none() - { - return Err(Error::BadFirstPacket.into()); - } - - for pkt in &self.0[1..] { - // If the number of RecetpionReports exceeds 31 additional ReceiverReports - // can be included here. - if pkt.as_any().downcast_ref::().is_some() { - continue; - // A SourceDescription containing a CNAME must be included in every - // CompoundPacket. - } else if let Some(e) = pkt.as_any().downcast_ref::() { - let mut has_cname = false; - for c in &e.chunks { - for it in &c.items { - if it.sdes_type == SdesType::SdesCname { - has_cname = true - } - } - } - - if !has_cname { - return Err(Error::MissingCname.into()); - } - - return Ok(()); - - // Other packets are not permitted before the CNAME - } else { - return Err(Error::PacketBeforeCname.into()); - } - } - - // CNAME never reached - Err(Error::MissingCname.into()) - } - - /// CNAME returns the CNAME that *must* be present in every CompoundPacket - pub fn cname(&self) -> Result { - if self.0.is_empty() { - return Err(Error::EmptyCompound.into()); - } - - for pkt in &self.0[1..] { - if let Some(sdes) = pkt.as_any().downcast_ref::() { - for c in &sdes.chunks { - for it in &c.items { - if it.sdes_type == SdesType::SdesCname { - return Ok(it.text.clone()); - } - } - } - } else if pkt.as_any().downcast_ref::().is_none() { - return Err(Error::PacketBeforeCname.into()); - } - } - - Err(Error::MissingCname.into()) - } -} diff --git a/rtcp/src/error.rs b/rtcp/src/error.rs deleted file mode 100644 index d46487904..000000000 --- a/rtcp/src/error.rs +++ /dev/null @@ -1,120 +0,0 @@ -use thiserror::Error; - -pub type Result = std::result::Result; - -#[derive(Error, Debug, PartialEq)] -#[non_exhaustive] -pub enum Error { - /// Wrong marshal size. 
- #[error("Wrong marshal size")] - WrongMarshalSize, - /// Packet lost exceeds maximum amount of packets - /// that can possibly be lost. - #[error("Invalid total lost count")] - InvalidTotalLost, - /// Packet contains an invalid header. - #[error("Invalid header")] - InvalidHeader, - /// Packet contains empty compound. - #[error("Empty compound packet")] - EmptyCompound, - /// Invalid first packet in compound packets. First packet - /// should either be a SenderReport packet or ReceiverReport - #[error("First packet in compound must be SR or RR")] - BadFirstPacket, - /// CNAME was not defined. - #[error("Compound missing SourceDescription with CNAME")] - MissingCname, - /// Packet was defined before CNAME. - #[error("Feedback packet seen before CNAME")] - PacketBeforeCname, - /// Too many reports. - #[error("Too many reports")] - TooManyReports, - /// Too many chunks. - #[error("Too many chunks")] - TooManyChunks, - /// Too many sources. - #[error("too many sources")] - TooManySources, - /// Packet received is too short. - #[error("Packet status chunk must be 2 bytes")] - PacketTooShort, - /// Buffer is too short. - #[error("Buffer too short to be written")] - BufferTooShort, - /// Wrong packet type. - #[error("Wrong packet type")] - WrongType, - /// SDES received is too long. - #[error("SDES must be < 255 octets long")] - SdesTextTooLong, - /// SDES type is missing. - #[error("SDES item missing type")] - SdesMissingType, - /// Reason is too long. - #[error("Reason must be < 255 octets long")] - ReasonTooLong, - /// Invalid packet version. - #[error("Invalid packet version")] - BadVersion, - /// Invalid padding value. - #[error("Invalid padding value")] - WrongPadding, - /// Wrong feedback message type. - #[error("Wrong feedback message type")] - WrongFeedbackType, - /// Wrong payload type. - #[error("Wrong payload type")] - WrongPayloadType, - /// Header length is too small. - #[error("Header length is too small")] - HeaderTooSmall, - /// Media ssrc was defined as zero. - #[error("Media SSRC must be 0")] - SsrcMustBeZero, - /// Missing REMB identifier. - #[error("Missing REMB identifier")] - MissingRembIdentifier, - /// SSRC number and length mismatches. - #[error("SSRC num and length do not match")] - SsrcNumAndLengthMismatch, - /// Invalid size or start index. - #[error("Invalid size or startIndex")] - InvalidSizeOrStartIndex, - /// Delta exceeds limit. - #[error("Delta exceed limit")] - DeltaExceedLimit, - /// Packet status chunk is not 2 bytes. - #[error("Packet status chunk must be 2 bytes")] - PacketStatusChunkLength, - #[error("Invalid bitrate")] - InvalidBitrate, - #[error("Wrong chunk type")] - WrongChunkType, - #[error("Struct contains unexpected member type")] - BadStructMemberType, - #[error("Cannot read into non-pointer")] - BadReadParameter, - - #[error("{0}")] - Util(#[from] util::Error), - - #[error("{0}")] - Other(String), -} - -impl From for util::Error { - fn from(e: Error) -> Self { - util::Error::from_std(e) - } -} - -impl PartialEq for Error { - fn eq(&self, other: &util::Error) -> bool { - if let Some(down) = other.downcast_ref::() { - return self == down; - } - false - } -} diff --git a/rtcp/src/extended_report/dlrr.rs b/rtcp/src/extended_report/dlrr.rs deleted file mode 100644 index c5d9693d3..000000000 --- a/rtcp/src/extended_report/dlrr.rs +++ /dev/null @@ -1,151 +0,0 @@ -use super::*; - -const DLRR_REPORT_LENGTH: u16 = 12; - -/// DLRRReport encodes a single report inside a DLRRReportBlock. 
-#[derive(Debug, Default, PartialEq, Eq, Clone)] -pub struct DLRRReport { - pub ssrc: u32, - pub last_rr: u32, - pub dlrr: u32, -} - -impl fmt::Display for DLRRReport { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{self:?}") - } -} - -/// DLRRReportBlock encodes a DLRR Report Block as described in -/// RFC 3611 section 4.5. -/// -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | BT=5 | reserved | block length | -/// +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ -/// | SSRC_1 (ssrc of first receiver) | sub- -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ block -/// | last RR (LRR) | 1 -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | delay since last RR (DLRR) | -/// +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ -/// | SSRC_2 (ssrc of second receiver) | sub- -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ block -/// : ... : 2 -/// +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ -#[derive(Debug, Default, PartialEq, Eq, Clone)] -pub struct DLRRReportBlock { - pub reports: Vec, -} - -impl fmt::Display for DLRRReportBlock { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{self:?}") - } -} - -impl DLRRReportBlock { - pub fn xr_header(&self) -> XRHeader { - XRHeader { - block_type: BlockType::DLRR, - type_specific: 0, - block_length: (self.raw_size() / 4 - 1) as u16, - } - } -} - -impl Packet for DLRRReportBlock { - fn header(&self) -> Header { - Header::default() - } - - /// destination_ssrc returns an array of ssrc values that this report block refers to. - fn destination_ssrc(&self) -> Vec { - let mut ssrc = Vec::with_capacity(self.reports.len()); - for r in &self.reports { - ssrc.push(r.ssrc); - } - ssrc - } - - fn raw_size(&self) -> usize { - XR_HEADER_LENGTH + self.reports.len() * 4 * 3 - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } - fn equal(&self, other: &(dyn Packet + Send + Sync)) -> bool { - other - .as_any() - .downcast_ref::() - .map_or(false, |a| self == a) - } - fn cloned(&self) -> Box { - Box::new(self.clone()) - } -} - -impl MarshalSize for DLRRReportBlock { - fn marshal_size(&self) -> usize { - self.raw_size() - } -} - -impl Marshal for DLRRReportBlock { - /// marshal_to encodes the DLRRReportBlock in binary - fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - if buf.remaining_mut() < self.marshal_size() { - return Err(error::Error::BufferTooShort.into()); - } - - let h = self.xr_header(); - let n = h.marshal_to(buf)?; - buf = &mut buf[n..]; - - for rep in &self.reports { - buf.put_u32(rep.ssrc); - buf.put_u32(rep.last_rr); - buf.put_u32(rep.dlrr); - } - - Ok(self.marshal_size()) - } -} - -impl Unmarshal for DLRRReportBlock { - /// Unmarshal decodes the DLRRReportBlock from binary - fn unmarshal(raw_packet: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - if raw_packet.remaining() < XR_HEADER_LENGTH { - return Err(error::Error::PacketTooShort.into()); - } - - let xr_header = XRHeader::unmarshal(raw_packet)?; - let block_length = xr_header.block_length * 4; - if block_length % DLRR_REPORT_LENGTH != 0 || raw_packet.remaining() < block_length as usize - { - return Err(error::Error::PacketTooShort.into()); - } - - let mut offset = 0; - let mut reports = vec![]; - while offset < block_length { - let ssrc = raw_packet.get_u32(); - let last_rr = 
raw_packet.get_u32(); - let dlrr = raw_packet.get_u32(); - reports.push(DLRRReport { - ssrc, - last_rr, - dlrr, - }); - offset += DLRR_REPORT_LENGTH; - } - - Ok(DLRRReportBlock { reports }) - } -} diff --git a/rtcp/src/extended_report/extended_report_test.rs b/rtcp/src/extended_report/extended_report_test.rs deleted file mode 100644 index 3e42526e1..000000000 --- a/rtcp/src/extended_report/extended_report_test.rs +++ /dev/null @@ -1,184 +0,0 @@ -use super::*; - -fn decoded_packet() -> ExtendedReport { - ExtendedReport { - sender_ssrc: 0x01020304, - reports: vec![ - Box::new(LossRLEReportBlock { - is_loss_rle: true, - t: 12, - - ssrc: 0x12345689, - begin_seq: 5, - end_seq: 12, - chunks: vec![Chunk(0x4006), Chunk(0x0006), Chunk(0x8765), Chunk(0x0000)], - }), - Box::new(DuplicateRLEReportBlock { - is_loss_rle: false, - t: 6, - - ssrc: 0x12345689, - begin_seq: 5, - end_seq: 12, - chunks: vec![Chunk(0x4123), Chunk(0x3FFF), Chunk(0xFFFF), Chunk(0x0000)], - }), - Box::new(PacketReceiptTimesReportBlock { - t: 3, - - ssrc: 0x98765432, - begin_seq: 15432, - end_seq: 15577, - receipt_time: vec![0x11111111, 0x22222222, 0x33333333, 0x44444444, 0x55555555], - }), - Box::new(ReceiverReferenceTimeReportBlock { - ntp_timestamp: 0x0102030405060708, - }), - Box::new(DLRRReportBlock { - reports: vec![ - DLRRReport { - ssrc: 0x88888888, - last_rr: 0x12345678, - dlrr: 0x99999999, - }, - DLRRReport { - ssrc: 0x09090909, - last_rr: 0x12345678, - dlrr: 0x99999999, - }, - DLRRReport { - ssrc: 0x11223344, - last_rr: 0x12345678, - dlrr: 0x99999999, - }, - ], - }), - Box::new(StatisticsSummaryReportBlock { - loss_reports: true, - duplicate_reports: true, - jitter_reports: true, - ttl_or_hop_limit: TTLorHopLimitType::IPv4, - - ssrc: 0xFEDCBA98, - begin_seq: 0x1234, - end_seq: 0x5678, - lost_packets: 0x11111111, - dup_packets: 0x22222222, - min_jitter: 0x33333333, - max_jitter: 0x44444444, - mean_jitter: 0x55555555, - dev_jitter: 0x66666666, - min_ttl_or_hl: 0x01, - max_ttl_or_hl: 0x02, - mean_ttl_or_hl: 0x03, - dev_ttl_or_hl: 0x04, - }), - Box::new(VoIPMetricsReportBlock { - ssrc: 0x89ABCDEF, - loss_rate: 0x05, - discard_rate: 0x06, - burst_density: 0x07, - gap_density: 0x08, - burst_duration: 0x1111, - gap_duration: 0x2222, - round_trip_delay: 0x3333, - end_system_delay: 0x4444, - signal_level: 0x11, - noise_level: 0x22, - rerl: 0x33, - gmin: 0x44, - rfactor: 0x55, - ext_rfactor: 0x66, - mos_lq: 0x77, - mos_cq: 0x88, - rx_config: 0x99, - reserved: 0x00, - jb_nominal: 0x1122, - jb_maximum: 0x3344, - jb_abs_max: 0x5566, - }), - ], - } -} - -fn encoded_packet() -> Bytes { - Bytes::from_static(&[ - // RTP Header - 0x80, 0xCF, 0x00, 0x33, // byte 0 - 3 - // Sender SSRC - 0x01, 0x02, 0x03, 0x04, // Loss RLE Report Block - 0x01, 0x0C, 0x00, 0x04, // byte 8 - 11 - // Source SSRC - 0x12, 0x34, 0x56, 0x89, // Begin & End Seq - 0x00, 0x05, 0x00, 0x0C, // byte 16 - 19 - // Chunks - 0x40, 0x06, 0x00, 0x06, 0x87, 0x65, 0x00, 0x00, // byte 24 - 27 - // Duplicate RLE Report Block - 0x02, 0x06, 0x00, 0x04, // Source SSRC - 0x12, 0x34, 0x56, 0x89, // byte 32 - 35 - // Begin & End Seq - 0x00, 0x05, 0x00, 0x0C, // Chunks - 0x41, 0x23, 0x3F, 0xFF, // byte 40 - 43 - 0xFF, 0xFF, 0x00, 0x00, // Packet Receipt Times Report Block - 0x03, 0x03, 0x00, 0x07, // byte 48 - 51 - // Source SSRC - 0x98, 0x76, 0x54, 0x32, // Begin & End Seq - 0x3C, 0x48, 0x3C, 0xD9, // byte 56 - 59 - // Receipt times - 0x11, 0x11, 0x11, 0x11, 0x22, 0x22, 0x22, 0x22, // byte 64 - 67 - 0x33, 0x33, 0x33, 0x33, 0x44, 0x44, 0x44, 0x44, // byte 72 - 75 - 0x55, 0x55, 0x55, 
0x55, // Receiver Reference Time Report - 0x04, 0x00, 0x00, 0x02, // byte 80 - 83 - // Timestamp - 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, // byte 88 - 91 - // DLRR Report - 0x05, 0x00, 0x00, 0x09, // SSRC 1 - 0x88, 0x88, 0x88, 0x88, // byte 96 - 99 - // LastRR 1 - 0x12, 0x34, 0x56, 0x78, // DLRR 1 - 0x99, 0x99, 0x99, 0x99, // byte 104 - 107 - // SSRC 2 - 0x09, 0x09, 0x09, 0x09, // LastRR 2 - 0x12, 0x34, 0x56, 0x78, // byte 112 - 115 - // DLRR 2 - 0x99, 0x99, 0x99, 0x99, // SSRC 3 - 0x11, 0x22, 0x33, 0x44, // byte 120 - 123 - // LastRR 3 - 0x12, 0x34, 0x56, 0x78, // DLRR 3 - 0x99, 0x99, 0x99, 0x99, // byte 128 - 131 - // Statistics Summary Report - 0x06, 0xE8, 0x00, 0x09, // SSRC - 0xFE, 0xDC, 0xBA, 0x98, // byte 136 - 139 - // Various statistics - 0x12, 0x34, 0x56, 0x78, 0x11, 0x11, 0x11, 0x11, // byte 144 - 147 - 0x22, 0x22, 0x22, 0x22, 0x33, 0x33, 0x33, 0x33, // byte 152 - 155 - 0x44, 0x44, 0x44, 0x44, 0x55, 0x55, 0x55, 0x55, // byte 160 - 163 - 0x66, 0x66, 0x66, 0x66, 0x01, 0x02, 0x03, 0x04, // byte 168 - 171 - // VoIP Metrics Report - 0x07, 0x00, 0x00, 0x08, // SSRC - 0x89, 0xAB, 0xCD, 0xEF, // byte 176 - 179 - // Various statistics - 0x05, 0x06, 0x07, 0x08, 0x11, 0x11, 0x22, 0x22, // byte 184 - 187 - 0x33, 0x33, 0x44, 0x44, 0x11, 0x22, 0x33, 0x44, // byte 192 - 195 - 0x55, 0x66, 0x77, 0x88, 0x99, 0x00, 0x11, 0x22, // byte 200 - 203 - 0x33, 0x44, 0x55, 0x66, // byte 204 - 207 - ]) -} - -#[test] -fn test_encode() -> Result<()> { - let expected = encoded_packet(); - let packet = decoded_packet(); - let actual = packet.marshal()?; - assert_eq!(actual, expected); - Ok(()) -} - -#[test] -fn test_decode() -> Result<()> { - let mut encoded = encoded_packet(); - let expected = decoded_packet(); - let actual = ExtendedReport::unmarshal(&mut encoded)?; - assert_eq!(actual, expected); - assert_eq!(actual.to_string(), expected.to_string()); - Ok(()) -} diff --git a/rtcp/src/extended_report/mod.rs b/rtcp/src/extended_report/mod.rs deleted file mode 100644 index e8a2f3732..000000000 --- a/rtcp/src/extended_report/mod.rs +++ /dev/null @@ -1,302 +0,0 @@ -#[cfg(test)] -mod extended_report_test; - -pub mod dlrr; -pub mod prt; -pub mod rle; -pub mod rrt; -pub mod ssr; -pub mod unknown; -pub mod vm; - -use std::any::Any; -use std::fmt; - -use bytes::{Buf, BufMut, Bytes}; -pub use dlrr::{DLRRReport, DLRRReportBlock}; -pub use prt::PacketReceiptTimesReportBlock; -pub use rle::{Chunk, ChunkType, DuplicateRLEReportBlock, LossRLEReportBlock, RLEReportBlock}; -pub use rrt::ReceiverReferenceTimeReportBlock; -pub use ssr::{StatisticsSummaryReportBlock, TTLorHopLimitType}; -pub use unknown::UnknownReportBlock; -use util::marshal::{Marshal, MarshalSize, Unmarshal}; -pub use vm::VoIPMetricsReportBlock; - -use crate::error; -use crate::header::{Header, PacketType, HEADER_LENGTH, SSRC_LENGTH}; -use crate::packet::Packet; -use crate::util::{get_padding_size, put_padding}; - -type Result = std::result::Result; - -const XR_HEADER_LENGTH: usize = 4; - -/// BlockType specifies the type of report in a report block -/// Extended Report block types from RFC 3611. 
-#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] -pub enum BlockType { - #[default] - Unknown = 0, - LossRLE = 1, // RFC 3611, section 4.1 - DuplicateRLE = 2, // RFC 3611, section 4.2 - PacketReceiptTimes = 3, // RFC 3611, section 4.3 - ReceiverReferenceTime = 4, // RFC 3611, section 4.4 - DLRR = 5, // RFC 3611, section 4.5 - StatisticsSummary = 6, // RFC 3611, section 4.6 - VoIPMetrics = 7, // RFC 3611, section 4.7 -} - -impl From for BlockType { - fn from(v: u8) -> Self { - match v { - 1 => BlockType::LossRLE, - 2 => BlockType::DuplicateRLE, - 3 => BlockType::PacketReceiptTimes, - 4 => BlockType::ReceiverReferenceTime, - 5 => BlockType::DLRR, - 6 => BlockType::StatisticsSummary, - 7 => BlockType::VoIPMetrics, - _ => BlockType::Unknown, - } - } -} - -/// converts the Extended report block types into readable strings -impl fmt::Display for BlockType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - BlockType::LossRLE => "LossRLEReportBlockType", - BlockType::DuplicateRLE => "DuplicateRLEReportBlockType", - BlockType::PacketReceiptTimes => "PacketReceiptTimesReportBlockType", - BlockType::ReceiverReferenceTime => "ReceiverReferenceTimeReportBlockType", - BlockType::DLRR => "DLRRReportBlockType", - BlockType::StatisticsSummary => "StatisticsSummaryReportBlockType", - BlockType::VoIPMetrics => "VoIPMetricsReportBlockType", - _ => "UnknownReportBlockType", - }; - write!(f, "{s}") - } -} - -/// TypeSpecificField as described in RFC 3611 section 4.5. In typical -/// cases, users of ExtendedReports shouldn't need to access this, -/// and should instead use the corresponding fields in the actual -/// report blocks themselves. -pub type TypeSpecificField = u8; - -/// XRHeader defines the common fields that must appear at the start -/// of each report block. In typical cases, users of ExtendedReports -/// shouldn't need to access this. For locally-constructed report -/// blocks, these values will not be accurate until the corresponding -/// packet is marshaled. -#[derive(Debug, Default, PartialEq, Eq, Clone)] -pub struct XRHeader { - pub block_type: BlockType, - pub type_specific: TypeSpecificField, - pub block_length: u16, -} - -impl MarshalSize for XRHeader { - fn marshal_size(&self) -> usize { - XR_HEADER_LENGTH - } -} - -impl Marshal for XRHeader { - /// marshal_to encodes the ExtendedReport in binary - fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - if buf.remaining_mut() < XR_HEADER_LENGTH { - return Err(error::Error::BufferTooShort.into()); - } - - buf.put_u8(self.block_type as u8); - buf.put_u8(self.type_specific); - buf.put_u16(self.block_length); - - Ok(XR_HEADER_LENGTH) - } -} - -impl Unmarshal for XRHeader { - /// Unmarshal decodes the ExtendedReport from binary - fn unmarshal(raw_packet: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - if raw_packet.remaining() < XR_HEADER_LENGTH { - return Err(error::Error::PacketTooShort.into()); - } - - let block_type: BlockType = raw_packet.get_u8().into(); - let type_specific = raw_packet.get_u8(); - let block_length = raw_packet.get_u16(); - - Ok(XRHeader { - block_type, - type_specific, - block_length, - }) - } -} -/// The ExtendedReport packet is an Implementation of RTCP Extended -/// reports defined in RFC 3611. It is used to convey detailed -/// information about an RTP stream. Each packet contains one or -/// more report blocks, each of which conveys a different kind of -/// information. 
-/// -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// |V=2|P|reserved | PT=XR=207 | length | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | ssrc | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// : report blocks : -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -#[derive(Debug, PartialEq, Default, Clone)] -pub struct ExtendedReport { - pub sender_ssrc: u32, - pub reports: Vec>, -} - -impl fmt::Display for ExtendedReport { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{self:?}") - } -} - -impl Packet for ExtendedReport { - /// Header returns the Header associated with this packet. - fn header(&self) -> Header { - Header { - padding: get_padding_size(self.raw_size()) != 0, - count: 0, - packet_type: PacketType::ExtendedReport, - length: ((self.marshal_size() / 4) - 1) as u16, - } - } - - /// destination_ssrc returns an array of ssrc values that this packet refers to. - fn destination_ssrc(&self) -> Vec { - let mut ssrc = vec![]; - for p in &self.reports { - ssrc.extend(p.destination_ssrc()); - } - ssrc - } - - fn raw_size(&self) -> usize { - let mut reps_length = 0; - for rep in &self.reports { - reps_length += rep.marshal_size(); - } - HEADER_LENGTH + SSRC_LENGTH + reps_length - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } - - fn equal(&self, other: &(dyn Packet + Send + Sync)) -> bool { - other - .as_any() - .downcast_ref::() - .map_or(false, |a| self == a) - } - - fn cloned(&self) -> Box { - Box::new(self.clone()) - } -} - -impl MarshalSize for ExtendedReport { - fn marshal_size(&self) -> usize { - let l = self.raw_size(); - // align to 32-bit boundary - l + get_padding_size(l) - } -} - -impl Marshal for ExtendedReport { - /// marshal_to encodes the ExtendedReport in binary - fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - if buf.remaining_mut() < self.marshal_size() { - return Err(error::Error::BufferTooShort.into()); - } - - let h = self.header(); - let n = h.marshal_to(buf)?; - buf = &mut buf[n..]; - - buf.put_u32(self.sender_ssrc); - - for report in &self.reports { - let n = report.marshal_to(buf)?; - buf = &mut buf[n..]; - } - - if h.padding { - put_padding(buf, self.raw_size()); - } - - Ok(self.marshal_size()) - } -} - -impl Unmarshal for ExtendedReport { - /// Unmarshal decodes the ExtendedReport from binary - fn unmarshal(raw_packet: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - let raw_packet_len = raw_packet.remaining(); - if raw_packet_len < (HEADER_LENGTH + SSRC_LENGTH) { - return Err(error::Error::PacketTooShort.into()); - } - - let header = Header::unmarshal(raw_packet)?; - if header.packet_type != PacketType::ExtendedReport { - return Err(error::Error::WrongType.into()); - } - - let sender_ssrc = raw_packet.get_u32(); - - let mut offset = HEADER_LENGTH + SSRC_LENGTH; - let mut reports = vec![]; - while raw_packet.remaining() > 0 { - if offset + XR_HEADER_LENGTH > raw_packet_len { - return Err(error::Error::PacketTooShort.into()); - } - - let block_type: BlockType = raw_packet.chunk()[0].into(); - let report: Box = match block_type { - BlockType::LossRLE => Box::new(LossRLEReportBlock::unmarshal(raw_packet)?), - BlockType::DuplicateRLE => { - Box::new(DuplicateRLEReportBlock::unmarshal(raw_packet)?) - } - BlockType::PacketReceiptTimes => { - Box::new(PacketReceiptTimesReportBlock::unmarshal(raw_packet)?) 
- } - BlockType::ReceiverReferenceTime => { - Box::new(ReceiverReferenceTimeReportBlock::unmarshal(raw_packet)?) - } - BlockType::DLRR => Box::new(DLRRReportBlock::unmarshal(raw_packet)?), - BlockType::StatisticsSummary => { - Box::new(StatisticsSummaryReportBlock::unmarshal(raw_packet)?) - } - BlockType::VoIPMetrics => Box::new(VoIPMetricsReportBlock::unmarshal(raw_packet)?), - _ => Box::new(UnknownReportBlock::unmarshal(raw_packet)?), - }; - - offset += report.marshal_size(); - reports.push(report); - } - - Ok(ExtendedReport { - sender_ssrc, - reports, - }) - } -} diff --git a/rtcp/src/extended_report/prt.rs b/rtcp/src/extended_report/prt.rs deleted file mode 100644 index c9da820cb..000000000 --- a/rtcp/src/extended_report/prt.rs +++ /dev/null @@ -1,150 +0,0 @@ -use super::*; - -const PRT_REPORT_BLOCK_MIN_LENGTH: u16 = 8; - -/// PacketReceiptTimesReportBlock represents a Packet Receipt Times -/// report block, as described in RFC 3611 section 4.3. -/// -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | BT=3 | rsvd. | t | block length | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | ssrc of source | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | begin_seq | end_seq | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | Receipt time of packet begin_seq | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | Receipt time of packet (begin_seq + 1) mod 65536 | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// : ... : -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | Receipt time of packet (end_seq - 1) mod 65536 | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -#[derive(Debug, Default, PartialEq, Eq, Clone)] -pub struct PacketReceiptTimesReportBlock { - //not included in marshal/unmarshal - pub t: u8, - - //marshal/unmarshal - pub ssrc: u32, - pub begin_seq: u16, - pub end_seq: u16, - pub receipt_time: Vec, -} - -impl fmt::Display for PacketReceiptTimesReportBlock { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{self:?}") - } -} - -impl PacketReceiptTimesReportBlock { - pub fn xr_header(&self) -> XRHeader { - XRHeader { - block_type: BlockType::PacketReceiptTimes, - type_specific: self.t & 0x0F, - block_length: (self.raw_size() / 4 - 1) as u16, - } - } -} - -impl Packet for PacketReceiptTimesReportBlock { - fn header(&self) -> Header { - Header::default() - } - - /// destination_ssrc returns an array of ssrc values that this report block refers to. 
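The dispatch in ExtendedReport::unmarshal above peeks at the upcoming block type without consuming it by reading `raw_packet.chunk()[0]`. A minimal standalone sketch of that pattern with the `bytes` crate (illustrative only):

use bytes::Buf;

// chunk() exposes the current contiguous slice, so its first byte is the next
// unread octet; unlike get_u8() this does not advance the read cursor.
fn peek_block_type(buf: &impl Buf) -> Option<u8> {
    buf.chunk().first().copied()
}

fn main() {
    let mut buf: &[u8] = &[4u8, 0, 0, 2]; // BT=4 (receiver reference time), block length=2
    assert_eq!(peek_block_type(&buf), Some(4));
    assert_eq!(buf.remaining(), 4); // nothing consumed yet
    buf.advance(1);
    assert_eq!(buf.remaining(), 3);
}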
- fn destination_ssrc(&self) -> Vec { - vec![self.ssrc] - } - - fn raw_size(&self) -> usize { - XR_HEADER_LENGTH + PRT_REPORT_BLOCK_MIN_LENGTH as usize + self.receipt_time.len() * 4 - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } - fn equal(&self, other: &(dyn Packet + Send + Sync)) -> bool { - other - .as_any() - .downcast_ref::() - .map_or(false, |a| self == a) - } - fn cloned(&self) -> Box { - Box::new(self.clone()) - } -} - -impl MarshalSize for PacketReceiptTimesReportBlock { - fn marshal_size(&self) -> usize { - self.raw_size() - } -} - -impl Marshal for PacketReceiptTimesReportBlock { - /// marshal_to encodes the PacketReceiptTimesReportBlock in binary - fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - if buf.remaining_mut() < self.marshal_size() { - return Err(error::Error::BufferTooShort.into()); - } - - let h = self.xr_header(); - let n = h.marshal_to(buf)?; - buf = &mut buf[n..]; - - buf.put_u32(self.ssrc); - buf.put_u16(self.begin_seq); - buf.put_u16(self.end_seq); - for rt in &self.receipt_time { - buf.put_u32(*rt); - } - - Ok(self.marshal_size()) - } -} - -impl Unmarshal for PacketReceiptTimesReportBlock { - /// Unmarshal decodes the PacketReceiptTimesReportBlock from binary - fn unmarshal(raw_packet: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - if raw_packet.remaining() < XR_HEADER_LENGTH { - return Err(error::Error::PacketTooShort.into()); - } - - let xr_header = XRHeader::unmarshal(raw_packet)?; - let block_length = xr_header.block_length * 4; - if block_length < PRT_REPORT_BLOCK_MIN_LENGTH - || (block_length - PRT_REPORT_BLOCK_MIN_LENGTH) % 4 != 0 - || raw_packet.remaining() < block_length as usize - { - return Err(error::Error::PacketTooShort.into()); - } - - let t = xr_header.type_specific & 0x0F; - - let ssrc = raw_packet.get_u32(); - let begin_seq = raw_packet.get_u16(); - let end_seq = raw_packet.get_u16(); - - let remaining = block_length - PRT_REPORT_BLOCK_MIN_LENGTH; - let mut receipt_time = vec![]; - for _ in 0..remaining / 4 { - receipt_time.push(raw_packet.get_u32()); - } - - Ok(PacketReceiptTimesReportBlock { - t, - - ssrc, - begin_seq, - end_seq, - receipt_time, - }) - } -} diff --git a/rtcp/src/extended_report/rle.rs b/rtcp/src/extended_report/rle.rs deleted file mode 100644 index fe6770b76..000000000 --- a/rtcp/src/extended_report/rle.rs +++ /dev/null @@ -1,246 +0,0 @@ -use super::*; - -const RLE_REPORT_BLOCK_MIN_LENGTH: u16 = 8; - -/// ChunkType enumerates the three kinds of chunks described in RFC 3611 section 4.1. -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub enum ChunkType { - RunLength = 0, - BitVector = 1, - TerminatingNull = 2, -} - -/// Chunk as defined in RFC 3611, section 4.1. These represent information -/// about packet losses and packet duplication. 
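For the Packet Receipt Times block unmarshalled above, the fixed body is 8 bytes (ssrc, begin_seq, end_seq) and each receipt time adds one 32-bit word; a small sketch of that size arithmetic (illustrative, names are not from the sources):

// Number of receipt-time entries carried by a PRT block body of the given size.
fn receipt_time_count(block_body_bytes: usize) -> usize {
    const PRT_FIXED_BYTES: usize = 8; // ssrc (4) + begin_seq (2) + end_seq (2)
    assert!(block_body_bytes >= PRT_FIXED_BYTES && (block_body_bytes - PRT_FIXED_BYTES) % 4 == 0);
    (block_body_bytes - PRT_FIXED_BYTES) / 4
}

// A block length field of 5 means a 20-byte body, i.e. receipt times for
// three packets: assert_eq!(receipt_time_count(5 * 4), 3);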
They have three representations: -/// -/// Run Length Chunk: -/// -/// 0 1 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// |C|R| run length | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// -/// Bit Vector Chunk: -/// -/// 0 1 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// |C| bit vector | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// -/// Terminating Null Chunk: -/// -/// 0 1 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// |0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0| -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -#[derive(Debug, Default, PartialEq, Eq, Clone)] -pub struct Chunk(pub u16); - -impl fmt::Display for Chunk { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self.chunk_type() { - ChunkType::RunLength => { - let run_type = self.run_type().unwrap_or(0); - write!(f, "[RunLength type={}, length={}]", run_type, self.value()) - } - ChunkType::BitVector => write!(f, "[BitVector {:#b}", self.value()), - ChunkType::TerminatingNull => write!(f, "[TerminatingNull]"), - } - } -} -impl Chunk { - /// chunk_type returns the ChunkType that this Chunk represents - pub fn chunk_type(&self) -> ChunkType { - if self.0 == 0 { - ChunkType::TerminatingNull - } else if (self.0 >> 15) == 0 { - ChunkType::RunLength - } else { - ChunkType::BitVector - } - } - - /// run_type returns the run_type that this Chunk represents. It is - /// only valid if ChunkType is RunLengthChunkType. - pub fn run_type(&self) -> error::Result { - if self.chunk_type() != ChunkType::RunLength { - Err(error::Error::WrongChunkType) - } else { - Ok((self.0 >> 14) as u8 & 0x01) - } - } - - /// value returns the value represented in this Chunk - pub fn value(&self) -> u16 { - match self.chunk_type() { - ChunkType::RunLength => self.0 & 0x3FFF, - ChunkType::BitVector => self.0 & 0x7FFF, - ChunkType::TerminatingNull => 0, - } - } -} - -/// RleReportBlock defines the common structure used by both -/// Loss RLE report blocks (RFC 3611 ยง4.1) and Duplicate RLE -/// report blocks (RFC 3611 ยง4.2). -/// -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | BT = 1 or 2 | rsvd. | t | block length | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | ssrc of source | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | begin_seq | end_seq | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | chunk 1 | chunk 2 | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// : ... 
: -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | chunk n-1 | chunk n | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -#[derive(Debug, Default, PartialEq, Eq, Clone)] -pub struct RLEReportBlock { - //not included in marshal/unmarshal - pub is_loss_rle: bool, - pub t: u8, - - //marshal/unmarshal - pub ssrc: u32, - pub begin_seq: u16, - pub end_seq: u16, - pub chunks: Vec, -} - -/// LossRLEReportBlock is used to report information about packet -/// losses, as described in RFC 3611, section 4.1 -/// make sure to set is_loss_rle = true -pub type LossRLEReportBlock = RLEReportBlock; - -/// DuplicateRLEReportBlock is used to report information about packet -/// duplication, as described in RFC 3611, section 4.1 -/// make sure to set is_loss_rle = false -pub type DuplicateRLEReportBlock = RLEReportBlock; - -impl RLEReportBlock { - pub fn xr_header(&self) -> XRHeader { - XRHeader { - block_type: if self.is_loss_rle { - BlockType::LossRLE - } else { - BlockType::DuplicateRLE - }, - type_specific: self.t & 0x0F, - block_length: (self.raw_size() / 4 - 1) as u16, - } - } -} - -impl fmt::Display for RLEReportBlock { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{self:?}") - } -} - -impl Packet for RLEReportBlock { - fn header(&self) -> Header { - Header::default() - } - - /// destination_ssrc returns an array of ssrc values that this report block refers to. - fn destination_ssrc(&self) -> Vec { - vec![self.ssrc] - } - - fn raw_size(&self) -> usize { - XR_HEADER_LENGTH + RLE_REPORT_BLOCK_MIN_LENGTH as usize + self.chunks.len() * 2 - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } - fn equal(&self, other: &(dyn Packet + Send + Sync)) -> bool { - other - .as_any() - .downcast_ref::() - .map_or(false, |a| self == a) - } - fn cloned(&self) -> Box { - Box::new(self.clone()) - } -} - -impl MarshalSize for RLEReportBlock { - fn marshal_size(&self) -> usize { - self.raw_size() - } -} - -impl Marshal for RLEReportBlock { - /// marshal_to encodes the RLEReportBlock in binary - fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - if buf.remaining_mut() < self.marshal_size() { - return Err(error::Error::BufferTooShort.into()); - } - - let h = self.xr_header(); - let n = h.marshal_to(buf)?; - buf = &mut buf[n..]; - - buf.put_u32(self.ssrc); - buf.put_u16(self.begin_seq); - buf.put_u16(self.end_seq); - for chunk in &self.chunks { - buf.put_u16(chunk.0); - } - - Ok(self.marshal_size()) - } -} - -impl Unmarshal for RLEReportBlock { - /// Unmarshal decodes the RLEReportBlock from binary - fn unmarshal(raw_packet: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - if raw_packet.remaining() < XR_HEADER_LENGTH { - return Err(error::Error::PacketTooShort.into()); - } - - let xr_header = XRHeader::unmarshal(raw_packet)?; - let block_length = xr_header.block_length * 4; - if block_length < RLE_REPORT_BLOCK_MIN_LENGTH - || (block_length - RLE_REPORT_BLOCK_MIN_LENGTH) % 2 != 0 - || raw_packet.remaining() < block_length as usize - { - return Err(error::Error::PacketTooShort.into()); - } - - let is_loss_rle = xr_header.block_type == BlockType::LossRLE; - let t = xr_header.type_specific & 0x0F; - - let ssrc = raw_packet.get_u32(); - let begin_seq = raw_packet.get_u16(); - let end_seq = raw_packet.get_u16(); - - let remaining = block_length - RLE_REPORT_BLOCK_MIN_LENGTH; - let mut chunks = vec![]; - for _ in 0..remaining / 2 { - chunks.push(Chunk(raw_packet.get_u16())); - } - - Ok(RLEReportBlock { - is_loss_rle, - t, - ssrc, - 
begin_seq, - end_seq, - chunks, - }) - } -} diff --git a/rtcp/src/extended_report/rrt.rs b/rtcp/src/extended_report/rrt.rs deleted file mode 100644 index fa0b34d65..000000000 --- a/rtcp/src/extended_report/rrt.rs +++ /dev/null @@ -1,111 +0,0 @@ -use super::*; - -const RRT_REPORT_BLOCK_LENGTH: u16 = 8; - -/// ReceiverReferenceTimeReportBlock encodes a Receiver Reference Time -/// report block as described in RFC 3611 section 4.4. -/// -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | BT=4 | reserved | block length = 2 | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | NTP timestamp, most significant word | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | NTP timestamp, least significant word | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -#[derive(Debug, Default, PartialEq, Eq, Clone)] -pub struct ReceiverReferenceTimeReportBlock { - pub ntp_timestamp: u64, -} - -impl fmt::Display for ReceiverReferenceTimeReportBlock { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{self:?}") - } -} - -impl ReceiverReferenceTimeReportBlock { - pub fn xr_header(&self) -> XRHeader { - XRHeader { - block_type: BlockType::ReceiverReferenceTime, - type_specific: 0, - block_length: (self.raw_size() / 4 - 1) as u16, - } - } -} - -impl Packet for ReceiverReferenceTimeReportBlock { - fn header(&self) -> Header { - Header::default() - } - - /// destination_ssrc returns an array of ssrc values that this report block refers to. - fn destination_ssrc(&self) -> Vec { - vec![] - } - - fn raw_size(&self) -> usize { - XR_HEADER_LENGTH + RRT_REPORT_BLOCK_LENGTH as usize - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } - fn equal(&self, other: &(dyn Packet + Send + Sync)) -> bool { - other - .as_any() - .downcast_ref::() - .map_or(false, |a| self == a) - } - fn cloned(&self) -> Box { - Box::new(self.clone()) - } -} - -impl MarshalSize for ReceiverReferenceTimeReportBlock { - fn marshal_size(&self) -> usize { - self.raw_size() - } -} - -impl Marshal for ReceiverReferenceTimeReportBlock { - /// marshal_to encodes the ReceiverReferenceTimeReportBlock in binary - fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - if buf.remaining_mut() < self.marshal_size() { - return Err(error::Error::BufferTooShort.into()); - } - - let h = self.xr_header(); - let n = h.marshal_to(buf)?; - buf = &mut buf[n..]; - - buf.put_u64(self.ntp_timestamp); - - Ok(self.marshal_size()) - } -} - -impl Unmarshal for ReceiverReferenceTimeReportBlock { - /// Unmarshal decodes the ReceiverReferenceTimeReportBlock from binary - fn unmarshal(raw_packet: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - if raw_packet.remaining() < XR_HEADER_LENGTH { - return Err(error::Error::PacketTooShort.into()); - } - - let xr_header = XRHeader::unmarshal(raw_packet)?; - let block_length = xr_header.block_length * 4; - if block_length != RRT_REPORT_BLOCK_LENGTH || raw_packet.remaining() < block_length as usize - { - return Err(error::Error::PacketTooShort.into()); - } - - let ntp_timestamp = raw_packet.get_u64(); - - Ok(ReceiverReferenceTimeReportBlock { ntp_timestamp }) - } -} diff --git a/rtcp/src/extended_report/ssr.rs b/rtcp/src/extended_report/ssr.rs deleted file mode 100644 index c4c29e0b5..000000000 --- a/rtcp/src/extended_report/ssr.rs +++ /dev/null @@ -1,235 +0,0 @@ -use super::*; - -const SSR_REPORT_BLOCK_LENGTH: u16 = 4 + 
2 * 2 + 4 * 6 + 4; - -/// StatisticsSummaryReportBlock encodes a Statistics Summary Report -/// Block as described in RFC 3611, section 4.6. -/// -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | BT=6 |L|D|J|ToH|rsvd.| block length = 9 | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | ssrc of source | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | begin_seq | end_seq | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | lost_packets | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | dup_packets | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | min_jitter | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | max_jitter | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | mean_jitter | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | dev_jitter | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | min_ttl_or_hl | max_ttl_or_hl |mean_ttl_or_hl | dev_ttl_or_hl | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -#[derive(Debug, Default, PartialEq, Eq, Clone)] -pub struct StatisticsSummaryReportBlock { - //not included in marshal/unmarshal - pub loss_reports: bool, - pub duplicate_reports: bool, - pub jitter_reports: bool, - pub ttl_or_hop_limit: TTLorHopLimitType, - - //marshal/unmarshal - pub ssrc: u32, - pub begin_seq: u16, - pub end_seq: u16, - pub lost_packets: u32, - pub dup_packets: u32, - pub min_jitter: u32, - pub max_jitter: u32, - pub mean_jitter: u32, - pub dev_jitter: u32, - pub min_ttl_or_hl: u8, - pub max_ttl_or_hl: u8, - pub mean_ttl_or_hl: u8, - pub dev_ttl_or_hl: u8, -} - -impl fmt::Display for StatisticsSummaryReportBlock { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{self:?}") - } -} - -/// TTLorHopLimitType encodes values for the ToH field in -/// a StatisticsSummaryReportBlock -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] -pub enum TTLorHopLimitType { - #[default] - Missing = 0, - IPv4 = 1, - IPv6 = 2, -} - -impl From for TTLorHopLimitType { - fn from(v: u8) -> Self { - match v { - 1 => TTLorHopLimitType::IPv4, - 2 => TTLorHopLimitType::IPv6, - _ => TTLorHopLimitType::Missing, - } - } -} - -impl fmt::Display for TTLorHopLimitType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - TTLorHopLimitType::Missing => "[ToH Missing]", - TTLorHopLimitType::IPv4 => "[ToH = IPv4]", - TTLorHopLimitType::IPv6 => "[ToH = IPv6]", - }; - write!(f, "{s}") - } -} - -impl StatisticsSummaryReportBlock { - pub fn xr_header(&self) -> XRHeader { - let mut type_specific = 0x00; - if self.loss_reports { - type_specific |= 0x80; - } - if self.duplicate_reports { - type_specific |= 0x40; - } - if self.jitter_reports { - type_specific |= 0x20; - } - type_specific |= (self.ttl_or_hop_limit as u8 & 0x03) << 3; - - XRHeader { - block_type: BlockType::StatisticsSummary, - type_specific, - block_length: (self.raw_size() / 4 - 1) as u16, - } - } -} - -impl Packet for StatisticsSummaryReportBlock { - fn header(&self) -> Header { - Header::default() - } - - /// destination_ssrc returns an array of ssrc values that this report block refers to. 
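The type-specific byte assembled in xr_header above packs the L, D and J flags into the top three bits and the two-bit ToH field just below them; a standalone sketch of that packing (hypothetical helper, not from the sources):

fn pack_ssr_flags(loss: bool, dup: bool, jitter: bool, toh: u8) -> u8 {
    ((loss as u8) << 7) | ((dup as u8) << 6) | ((jitter as u8) << 5) | ((toh & 0x03) << 3)
}

// pack_ssr_flags(true, false, true, 1) == 0b1010_1000; the unmarshal code below
// decodes that back into loss_reports = true, jitter_reports = true and
// TTLorHopLimitType::IPv4.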
- fn destination_ssrc(&self) -> Vec { - vec![self.ssrc] - } - - fn raw_size(&self) -> usize { - XR_HEADER_LENGTH + SSR_REPORT_BLOCK_LENGTH as usize - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } - fn equal(&self, other: &(dyn Packet + Send + Sync)) -> bool { - other - .as_any() - .downcast_ref::() - .map_or(false, |a| self == a) - } - fn cloned(&self) -> Box { - Box::new(self.clone()) - } -} - -impl MarshalSize for StatisticsSummaryReportBlock { - fn marshal_size(&self) -> usize { - self.raw_size() - } -} - -impl Marshal for StatisticsSummaryReportBlock { - /// marshal_to encodes the StatisticsSummaryReportBlock in binary - fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - if buf.remaining_mut() < self.marshal_size() { - return Err(error::Error::BufferTooShort.into()); - } - - let h = self.xr_header(); - let n = h.marshal_to(buf)?; - buf = &mut buf[n..]; - - buf.put_u32(self.ssrc); - buf.put_u16(self.begin_seq); - buf.put_u16(self.end_seq); - buf.put_u32(self.lost_packets); - buf.put_u32(self.dup_packets); - buf.put_u32(self.min_jitter); - buf.put_u32(self.max_jitter); - buf.put_u32(self.mean_jitter); - buf.put_u32(self.dev_jitter); - buf.put_u8(self.min_ttl_or_hl); - buf.put_u8(self.max_ttl_or_hl); - buf.put_u8(self.mean_ttl_or_hl); - buf.put_u8(self.dev_ttl_or_hl); - - Ok(self.marshal_size()) - } -} - -impl Unmarshal for StatisticsSummaryReportBlock { - /// Unmarshal decodes the StatisticsSummaryReportBlock from binary - fn unmarshal(raw_packet: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - if raw_packet.remaining() < XR_HEADER_LENGTH { - return Err(error::Error::PacketTooShort.into()); - } - - let xr_header = XRHeader::unmarshal(raw_packet)?; - let block_length = xr_header.block_length * 4; - if block_length != SSR_REPORT_BLOCK_LENGTH || raw_packet.remaining() < block_length as usize - { - return Err(error::Error::PacketTooShort.into()); - } - - let loss_reports = xr_header.type_specific & 0x80 != 0; - let duplicate_reports = xr_header.type_specific & 0x40 != 0; - let jitter_reports = xr_header.type_specific & 0x20 != 0; - let ttl_or_hop_limit: TTLorHopLimitType = ((xr_header.type_specific & 0x18) >> 3).into(); - - let ssrc = raw_packet.get_u32(); - let begin_seq = raw_packet.get_u16(); - let end_seq = raw_packet.get_u16(); - let lost_packets = raw_packet.get_u32(); - let dup_packets = raw_packet.get_u32(); - let min_jitter = raw_packet.get_u32(); - let max_jitter = raw_packet.get_u32(); - let mean_jitter = raw_packet.get_u32(); - let dev_jitter = raw_packet.get_u32(); - let min_ttl_or_hl = raw_packet.get_u8(); - let max_ttl_or_hl = raw_packet.get_u8(); - let mean_ttl_or_hl = raw_packet.get_u8(); - let dev_ttl_or_hl = raw_packet.get_u8(); - - Ok(StatisticsSummaryReportBlock { - loss_reports, - duplicate_reports, - jitter_reports, - ttl_or_hop_limit, - - ssrc, - begin_seq, - end_seq, - lost_packets, - dup_packets, - min_jitter, - max_jitter, - mean_jitter, - dev_jitter, - min_ttl_or_hl, - max_ttl_or_hl, - mean_ttl_or_hl, - dev_ttl_or_hl, - }) - } -} diff --git a/rtcp/src/extended_report/unknown.rs b/rtcp/src/extended_report/unknown.rs deleted file mode 100644 index 707d77231..000000000 --- a/rtcp/src/extended_report/unknown.rs +++ /dev/null @@ -1,98 +0,0 @@ -use super::*; - -/// UnknownReportBlock is used to store bytes for any report block -/// that has an unknown Report Block Type. 
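For the UnknownReportBlock defined below, the payload is kept as opaque bytes and its size is recovered purely from the XR header; a one-line sketch of that relationship (illustrative):

// The block length field counts 32-bit words after the XR header, so the raw
// body handed to copy_to_bytes() below is simply that value times four.
fn unknown_block_body_bytes(block_length_field: u16) -> usize {
    block_length_field as usize * 4
}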
-#[derive(Debug, Default, PartialEq, Eq, Clone)] -pub struct UnknownReportBlock { - pub bytes: Bytes, -} - -impl fmt::Display for UnknownReportBlock { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{self:?}") - } -} - -impl UnknownReportBlock { - pub fn xr_header(&self) -> XRHeader { - XRHeader { - block_type: BlockType::Unknown, - type_specific: 0, - block_length: (self.raw_size() / 4 - 1) as u16, - } - } -} - -impl Packet for UnknownReportBlock { - fn header(&self) -> Header { - Header::default() - } - - /// destination_ssrc returns an array of ssrc values that this report block refers to. - fn destination_ssrc(&self) -> Vec { - vec![] - } - - fn raw_size(&self) -> usize { - XR_HEADER_LENGTH + self.bytes.len() - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } - fn equal(&self, other: &(dyn Packet + Send + Sync)) -> bool { - other - .as_any() - .downcast_ref::() - .map_or(false, |a| self == a) - } - fn cloned(&self) -> Box { - Box::new(self.clone()) - } -} - -impl MarshalSize for UnknownReportBlock { - fn marshal_size(&self) -> usize { - self.raw_size() - } -} - -impl Marshal for UnknownReportBlock { - /// marshal_to encodes the UnknownReportBlock in binary - fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - if buf.remaining_mut() < self.marshal_size() { - return Err(error::Error::BufferTooShort.into()); - } - - let h = self.xr_header(); - let n = h.marshal_to(buf)?; - buf = &mut buf[n..]; - - buf.put(self.bytes.clone()); - - Ok(self.marshal_size()) - } -} - -impl Unmarshal for UnknownReportBlock { - /// Unmarshal decodes the UnknownReportBlock from binary - fn unmarshal(raw_packet: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - if raw_packet.remaining() < XR_HEADER_LENGTH { - return Err(error::Error::PacketTooShort.into()); - } - - let xr_header = XRHeader::unmarshal(raw_packet)?; - let block_length = xr_header.block_length * 4; - if raw_packet.remaining() < block_length as usize { - return Err(error::Error::PacketTooShort.into()); - } - - let bytes = raw_packet.copy_to_bytes(block_length as usize); - - Ok(UnknownReportBlock { bytes }) - } -} diff --git a/rtcp/src/extended_report/vm.rs b/rtcp/src/extended_report/vm.rs deleted file mode 100644 index 91cb79099..000000000 --- a/rtcp/src/extended_report/vm.rs +++ /dev/null @@ -1,209 +0,0 @@ -use super::*; - -const VM_REPORT_BLOCK_LENGTH: u16 = 4 + 4 + 2 * 4 + 10 + 2 * 3; - -/// VoIPMetricsReportBlock encodes a VoIP Metrics Report Block as described -/// in RFC 3611, section 4.7. -/// -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | BT=7 | reserved | block length = 8 | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | ssrc of source | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | loss rate | discard rate | burst density | gap density | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | burst duration | gap duration | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | round trip delay | end system delay | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | signal level | noise level | RERL | Gmin | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | R factor | ext. 
R factor | MOS-LQ | MOS-CQ | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | RX config | reserved | JB nominal | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | JB maximum | JB abs max | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -#[derive(Debug, Default, PartialEq, Eq, Clone)] -pub struct VoIPMetricsReportBlock { - pub ssrc: u32, - pub loss_rate: u8, - pub discard_rate: u8, - pub burst_density: u8, - pub gap_density: u8, - pub burst_duration: u16, - pub gap_duration: u16, - pub round_trip_delay: u16, - pub end_system_delay: u16, - pub signal_level: u8, - pub noise_level: u8, - pub rerl: u8, - pub gmin: u8, - pub rfactor: u8, - pub ext_rfactor: u8, - pub mos_lq: u8, - pub mos_cq: u8, - pub rx_config: u8, - pub reserved: u8, - pub jb_nominal: u16, - pub jb_maximum: u16, - pub jb_abs_max: u16, -} - -impl fmt::Display for VoIPMetricsReportBlock { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{self:?}") - } -} - -impl VoIPMetricsReportBlock { - pub fn xr_header(&self) -> XRHeader { - XRHeader { - block_type: BlockType::VoIPMetrics, - type_specific: 0, - block_length: (self.raw_size() / 4 - 1) as u16, - } - } -} - -impl Packet for VoIPMetricsReportBlock { - fn header(&self) -> Header { - Header::default() - } - - /// destination_ssrc returns an array of ssrc values that this report block refers to. - fn destination_ssrc(&self) -> Vec { - vec![self.ssrc] - } - - fn raw_size(&self) -> usize { - XR_HEADER_LENGTH + VM_REPORT_BLOCK_LENGTH as usize - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } - fn equal(&self, other: &(dyn Packet + Send + Sync)) -> bool { - other - .as_any() - .downcast_ref::() - .map_or(false, |a| self == a) - } - fn cloned(&self) -> Box { - Box::new(self.clone()) - } -} - -impl MarshalSize for VoIPMetricsReportBlock { - fn marshal_size(&self) -> usize { - self.raw_size() - } -} - -impl Marshal for VoIPMetricsReportBlock { - /// marshal_to encodes the VoIPMetricsReportBlock in binary - fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - if buf.remaining_mut() < self.marshal_size() { - return Err(error::Error::BufferTooShort.into()); - } - - let h = self.xr_header(); - let n = h.marshal_to(buf)?; - buf = &mut buf[n..]; - - buf.put_u32(self.ssrc); - buf.put_u8(self.loss_rate); - buf.put_u8(self.discard_rate); - buf.put_u8(self.burst_density); - buf.put_u8(self.gap_density); - buf.put_u16(self.burst_duration); - buf.put_u16(self.gap_duration); - buf.put_u16(self.round_trip_delay); - buf.put_u16(self.end_system_delay); - buf.put_u8(self.signal_level); - buf.put_u8(self.noise_level); - buf.put_u8(self.rerl); - buf.put_u8(self.gmin); - buf.put_u8(self.rfactor); - buf.put_u8(self.ext_rfactor); - buf.put_u8(self.mos_lq); - buf.put_u8(self.mos_cq); - buf.put_u8(self.rx_config); - buf.put_u8(self.reserved); - buf.put_u16(self.jb_nominal); - buf.put_u16(self.jb_maximum); - buf.put_u16(self.jb_abs_max); - - Ok(self.marshal_size()) - } -} - -impl Unmarshal for VoIPMetricsReportBlock { - /// Unmarshal decodes the VoIPMetricsReportBlock from binary - fn unmarshal(raw_packet: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - if raw_packet.remaining() < XR_HEADER_LENGTH { - return Err(error::Error::PacketTooShort.into()); - } - - let xr_header = XRHeader::unmarshal(raw_packet)?; - let block_length = xr_header.block_length * 4; - if block_length != VM_REPORT_BLOCK_LENGTH || raw_packet.remaining() < block_length as usize - { - return 
Err(error::Error::PacketTooShort.into()); - } - - let ssrc = raw_packet.get_u32(); - let loss_rate = raw_packet.get_u8(); - let discard_rate = raw_packet.get_u8(); - let burst_density = raw_packet.get_u8(); - let gap_density = raw_packet.get_u8(); - let burst_duration = raw_packet.get_u16(); - let gap_duration = raw_packet.get_u16(); - let round_trip_delay = raw_packet.get_u16(); - let end_system_delay = raw_packet.get_u16(); - let signal_level = raw_packet.get_u8(); - let noise_level = raw_packet.get_u8(); - let rerl = raw_packet.get_u8(); - let gmin = raw_packet.get_u8(); - let rfactor = raw_packet.get_u8(); - let ext_rfactor = raw_packet.get_u8(); - let mos_lq = raw_packet.get_u8(); - let mos_cq = raw_packet.get_u8(); - let rx_config = raw_packet.get_u8(); - let reserved = raw_packet.get_u8(); - let jb_nominal = raw_packet.get_u16(); - let jb_maximum = raw_packet.get_u16(); - let jb_abs_max = raw_packet.get_u16(); - - Ok(VoIPMetricsReportBlock { - ssrc, - loss_rate, - discard_rate, - burst_density, - gap_density, - burst_duration, - gap_duration, - round_trip_delay, - end_system_delay, - signal_level, - noise_level, - rerl, - gmin, - rfactor, - ext_rfactor, - mos_lq, - mos_cq, - rx_config, - reserved, - jb_nominal, - jb_maximum, - jb_abs_max, - }) - } -} diff --git a/rtcp/src/goodbye/goodbye_test.rs b/rtcp/src/goodbye/goodbye_test.rs deleted file mode 100644 index 22e3e42bf..000000000 --- a/rtcp/src/goodbye/goodbye_test.rs +++ /dev/null @@ -1,225 +0,0 @@ -use super::*; - -#[test] -fn test_goodbye_unmarshal() { - let tests = vec![ - ( - "valid", - Bytes::from_static(&[ - 0x81, 0xcb, 0x00, 0x0c, // v=2, p=0, count=1, BYE, len=12 - 0x90, 0x2f, 0x9e, 0x2e, // ssrc=0x902f9e2e - 0x03, 0x46, 0x4f, 0x4f, // len=3, text=FOO - ]), - Goodbye { - sources: vec![0x902f9e2e], - reason: Bytes::from_static(b"FOO"), - }, - None, - ), - ( - "invalid octet count", - Bytes::from_static(&[ - 0x81, 0xcb, 0x00, 0x0c, // v=2, p=0, count=1, BYE, len=12 - 0x90, 0x2f, 0x9e, 0x2e, // ssrc=0x902f9e2e - 0x04, 0x46, 0x4f, 0x4f, // len=4, text=FOO - ]), - Goodbye { - sources: vec![], - reason: Bytes::from_static(b""), - }, - Some(Error::PacketTooShort), - ), - ( - "wrong type", - Bytes::from_static(&[ - 0x81, 0xca, 0x00, 0x0c, // v=2, p=0, count=1, SDES, len=12 - 0x90, 0x2f, 0x9e, 0x2e, // ssrc=0x902f9e2e - 0x03, 0x46, 0x4f, 0x4f, // len=3, text=FOO - ]), - Goodbye { - sources: vec![], - reason: Bytes::from_static(b""), - }, - Some(Error::WrongType), - ), - ( - "short reason", - Bytes::from_static(&[ - 0x81, 0xcb, 0x00, 0x0c, // v=2, p=0, count=1, BYE, len=12 - 0x90, 0x2f, 0x9e, 0x2e, // ssrc=0x902f9e2e - 0x01, 0x46, 0x00, 0x00, // len=3, text=F + padding - ]), - Goodbye { - sources: vec![0x902f9e2e], - reason: Bytes::from_static(b"F"), - }, - None, - ), - ( - "not byte aligned", - Bytes::from_static(&[ - 0x81, 0xcb, 0x00, 0x0a, // v=2, p=0, count=1, BYE, len=10 - 0x90, 0x2f, 0x9e, 0x2e, // ssrc=0x902f9e2e - 0x01, 0x46, // len=1, text=F - ]), - Goodbye { - sources: vec![], - reason: Bytes::from_static(b""), - }, - Some(Error::PacketTooShort), - ), - ( - "bad count in header", - Bytes::from_static(&[ - 0x82, 0xcb, 0x00, 0x0c, // v=2, p=0, count=2, BYE, len=8 - 0x90, 0x2f, 0x9e, 0x2e, // ssrc=0x902f9e2e - ]), - Goodbye { - sources: vec![], - reason: Bytes::from_static(b""), - }, - Some(Error::PacketTooShort), - ), - ( - "empty packet", - Bytes::from_static(&[ - // v=2, p=0, count=0, BYE, len=4 - 0x80, 0xcb, 0x00, 0x04, - ]), - Goodbye { - sources: vec![], - reason: Bytes::from_static(b""), - }, - None, - ), - ( - "nil", 
- Bytes::from_static(&[]), - Goodbye { - sources: vec![], - reason: Bytes::from_static(b""), - }, - Some(Error::PacketTooShort), - ), - ]; - - for (name, mut data, want, want_error) in tests { - let got = Goodbye::unmarshal(&mut data); - - assert_eq!( - got.is_err(), - want_error.is_some(), - "Unmarshal {name} bye: err = {got:?}, want {want_error:?}" - ); - - if let Some(err) = want_error { - let got_err = got.err().unwrap(); - assert_eq!( - err, got_err, - "Unmarshal {name} rr: err = {got_err:?}, want {err:?}", - ); - } else { - let actual = got.unwrap(); - assert_eq!( - actual, want, - "Unmarshal {name} rr: got {actual:?}, want {want:?}" - ); - } - } -} - -#[test] -fn test_goodbye_round_trip() { - let too_many_sources = vec![0u32; 1 << 5]; - - let mut too_long_text = String::new(); - for _ in 0..1 << 8 { - too_long_text.push('x'); - } - - let tests = vec![ - ( - "empty", - Goodbye { - sources: vec![], - ..Default::default() - }, - None, - ), - ( - "valid", - Goodbye { - sources: vec![0x01020304, 0x05060708], - reason: Bytes::from_static(b"because"), - }, - None, - ), - ( - "empty reason", - Goodbye { - sources: vec![0x01020304], - reason: Bytes::from_static(b""), - }, - None, - ), - ( - "reason no source", - Goodbye { - sources: vec![], - reason: Bytes::from_static(b"foo"), - }, - None, - ), - ( - "short reason", - Goodbye { - sources: vec![], - reason: Bytes::from_static(b"f"), - }, - None, - ), - ( - "count overflow", - Goodbye { - sources: too_many_sources, - reason: Bytes::from_static(b""), - }, - Some(Error::TooManySources), - ), - ( - "reason too long", - Goodbye { - sources: vec![], - reason: Bytes::copy_from_slice(too_long_text.as_bytes()), - }, - Some(Error::ReasonTooLong), - ), - ]; - - for (name, want, want_error) in tests { - let got = want.marshal(); - - assert_eq!( - got.is_ok(), - want_error.is_none(), - "Marshal {name}: err = {got:?}, want {want_error:?}" - ); - - if let Some(err) = want_error { - let got_err = got.err().unwrap(); - assert_eq!( - err, got_err, - "Unmarshal {name} rr: err = {got_err:?}, want {err:?}", - ); - } else { - let mut data = got.ok().unwrap(); - let actual = - Goodbye::unmarshal(&mut data).unwrap_or_else(|_| panic!("Unmarshal {name}")); - - assert_eq!( - actual, want, - "{name} round trip: got {actual:?}, want {want:?}" - ) - } - } -} diff --git a/rtcp/src/goodbye/mod.rs b/rtcp/src/goodbye/mod.rs deleted file mode 100644 index 9fa44fd77..000000000 --- a/rtcp/src/goodbye/mod.rs +++ /dev/null @@ -1,197 +0,0 @@ -#[cfg(test)] -mod goodbye_test; - -use std::any::Any; -use std::fmt; - -use bytes::{Buf, BufMut, Bytes}; -use util::marshal::{Marshal, MarshalSize, Unmarshal}; - -use crate::error::Error; -use crate::header::*; -use crate::packet::*; -use crate::util::*; - -type Result = std::result::Result; - -/// The Goodbye packet indicates that one or more sources are no longer active. -#[derive(Debug, PartialEq, Eq, Default, Clone)] -pub struct Goodbye { - /// The SSRC/CSRC identifiers that are no longer active - pub sources: Vec, - /// Optional text indicating the reason for leaving, e.g., "camera malfunction" or "RTP loop detected" - pub reason: Bytes, -} - -impl fmt::Display for Goodbye { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut out = "Goodbye:\n\tSources:\n".to_string(); - for s in &self.sources { - out += format!("\t{}\n", *s).as_str(); - } - out += format!("\tReason: {:?}\n", self.reason).as_str(); - - write!(f, "{out}") - } -} - -impl Packet for Goodbye { - /// Header returns the Header associated with this packet. 
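The BYE reason written by marshal_to below is one length octet followed by the text, with zero bytes appended until the packet is 32-bit aligned; a standalone sketch of just that tail (illustrative; the header and SSRC words are already word-aligned, so padding the tail is equivalent):

fn encode_reason(reason: &[u8]) -> Vec<u8> {
    assert!(reason.len() <= 255);
    let mut out = vec![reason.len() as u8];
    out.extend_from_slice(reason);
    while out.len() % 4 != 0 {
        out.push(0);
    }
    out
}

// encode_reason(b"FOO") == [3, 0x46, 0x4f, 0x4f] and encode_reason(b"F") ==
// [1, 0x46, 0, 0], matching the "valid" and "short reason" fixtures above.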
- fn header(&self) -> Header { - Header { - padding: get_padding_size(self.raw_size()) != 0, - count: self.sources.len() as u8, - packet_type: PacketType::Goodbye, - length: ((self.marshal_size() / 4) - 1) as u16, - } - } - - /// destination_ssrc returns an array of SSRC values that this packet refers to. - fn destination_ssrc(&self) -> Vec { - self.sources.to_vec() - } - - fn raw_size(&self) -> usize { - let srcs_length = self.sources.len() * SSRC_LENGTH; - let reason_length = self.reason.len() + 1; - - HEADER_LENGTH + srcs_length + reason_length - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } - - fn equal(&self, other: &(dyn Packet + Send + Sync)) -> bool { - other - .as_any() - .downcast_ref::() - .map_or(false, |a| self == a) - } - - fn cloned(&self) -> Box { - Box::new(self.clone()) - } -} - -impl MarshalSize for Goodbye { - fn marshal_size(&self) -> usize { - let l = self.raw_size(); - // align to 32-bit boundary - l + get_padding_size(l) - } -} - -impl Marshal for Goodbye { - /// marshal_to encodes the packet in binary. - fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - if self.sources.len() > COUNT_MAX { - return Err(Error::TooManySources.into()); - } - - if self.reason.len() > SDES_MAX_OCTET_COUNT { - return Err(Error::ReasonTooLong.into()); - } - - if buf.remaining_mut() < self.marshal_size() { - return Err(Error::BufferTooShort.into()); - } - - /* - * 0 1 2 3 - * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * |V=2|P| SC | PT=BYE=203 | length | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | SSRC/CSRC | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * : ... : - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - * (opt) | length | reason for leaving ... - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - */ - - let h = self.header(); - let n = h.marshal_to(buf)?; - buf = &mut buf[n..]; - - for source in &self.sources { - buf.put_u32(*source); - } - - buf.put_u8(self.reason.len() as u8); - if !self.reason.is_empty() { - buf.put(self.reason.clone()); - } - - if h.padding { - put_padding(buf, self.raw_size()); - } - - Ok(self.marshal_size()) - } -} - -impl Unmarshal for Goodbye { - fn unmarshal(raw_packet: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - /* - * 0 1 2 3 - * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * |V=2|P| SC | PT=BYE=203 | length | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | SSRC/CSRC | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * : ... : - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - * (opt) | length | reason for leaving ... 
- * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - */ - let raw_packet_len = raw_packet.remaining(); - - let header = Header::unmarshal(raw_packet)?; - if header.packet_type != PacketType::Goodbye { - return Err(Error::WrongType.into()); - } - - if get_padding_size(raw_packet_len) != 0 { - return Err(Error::PacketTooShort.into()); - } - - let reason_offset = HEADER_LENGTH + header.count as usize * SSRC_LENGTH; - - if reason_offset > raw_packet_len { - return Err(Error::PacketTooShort.into()); - } - - let mut sources = Vec::with_capacity(header.count as usize); - for _ in 0..header.count { - sources.push(raw_packet.get_u32()); - } - - let reason = if reason_offset < raw_packet_len { - let reason_len = raw_packet.get_u8() as usize; - let reason_end = reason_offset + 1 + reason_len; - - if reason_end > raw_packet_len { - return Err(Error::PacketTooShort.into()); - } - - raw_packet.copy_to_bytes(reason_len) - } else { - Bytes::new() - }; - - if - /*header.padding &&*/ - raw_packet.has_remaining() { - raw_packet.advance(raw_packet.remaining()); - } - - Ok(Goodbye { sources, reason }) - } -} diff --git a/rtcp/src/header.rs b/rtcp/src/header.rs deleted file mode 100644 index 508f87278..000000000 --- a/rtcp/src/header.rs +++ /dev/null @@ -1,315 +0,0 @@ -use bytes::{Buf, BufMut}; -use util::marshal::{Marshal, MarshalSize, Unmarshal}; - -use crate::error::Error; - -/// PacketType specifies the type of an RTCP packet -/// RTCP packet types registered with IANA. See: https://www.iana.org/assignments/rtp-parameters/rtp-parameters.xhtml#rtp-parameters-4 -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] -#[repr(u8)] -pub enum PacketType { - #[default] - Unsupported = 0, - SenderReport = 200, // RFC 3550, 6.4.1 - ReceiverReport = 201, // RFC 3550, 6.4.2 - SourceDescription = 202, // RFC 3550, 6.5 - Goodbye = 203, // RFC 3550, 6.6 - ApplicationDefined = 204, // RFC 3550, 6.7 (unimplemented) - TransportSpecificFeedback = 205, // RFC 4585, 6051 - PayloadSpecificFeedback = 206, // RFC 4585, 6.3 - ExtendedReport = 207, // RFC 3611 -} - -/// Transport and Payload specific feedback messages overload the count field to act as a message type. those are listed here -pub const FORMAT_SLI: u8 = 2; -/// Transport and Payload specific feedback messages overload the count field to act as a message type. those are listed here -pub const FORMAT_PLI: u8 = 1; -/// Transport and Payload specific feedback messages overload the count field to act as a message type. those are listed here -pub const FORMAT_FIR: u8 = 4; -/// Transport and Payload specific feedback messages overload the count field to act as a message type. those are listed here -pub const FORMAT_TLN: u8 = 1; -/// Transport and Payload specific feedback messages overload the count field to act as a message type. those are listed here -pub const FORMAT_RRR: u8 = 5; -/// Transport and Payload specific feedback messages overload the count field to act as a message type. those are listed here -pub const FORMAT_REMB: u8 = 15; -/// Transport and Payload specific feedback messages overload the count field to act as a message type. those are listed here. 
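Note that FORMAT_REMB above and FORMAT_TCC just below both use the value 15; what disambiguates them is the packet type, as the dispatcher in packet.rs further down in this diff shows. A compact sketch of that rule (illustrative constants mirror the enum values above):

const PT_TRANSPORT_SPECIFIC_FEEDBACK: u8 = 205;
const PT_PAYLOAD_SPECIFIC_FEEDBACK: u8 = 206;

// FMT 15 on a payload-specific feedback packet is REMB, while FMT 15 on a
// transport-specific feedback packet is transport-wide congestion control.
fn is_remb(packet_type: u8, fmt: u8) -> bool {
    packet_type == PT_PAYLOAD_SPECIFIC_FEEDBACK && fmt == 15
}
fn is_twcc(packet_type: u8, fmt: u8) -> bool {
    packet_type == PT_TRANSPORT_SPECIFIC_FEEDBACK && fmt == 15
}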
-/// https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions-01#page-5 -pub const FORMAT_TCC: u8 = 15; - -impl std::fmt::Display for PacketType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let s = match self { - PacketType::Unsupported => "Unsupported", - PacketType::SenderReport => "SR", - PacketType::ReceiverReport => "RR", - PacketType::SourceDescription => "SDES", - PacketType::Goodbye => "BYE", - PacketType::ApplicationDefined => "APP", - PacketType::TransportSpecificFeedback => "TSFB", - PacketType::PayloadSpecificFeedback => "PSFB", - PacketType::ExtendedReport => "XR", - }; - write!(f, "{s}") - } -} - -impl From for PacketType { - fn from(b: u8) -> Self { - match b { - 200 => PacketType::SenderReport, // RFC 3550, 6.4.1 - 201 => PacketType::ReceiverReport, // RFC 3550, 6.4.2 - 202 => PacketType::SourceDescription, // RFC 3550, 6.5 - 203 => PacketType::Goodbye, // RFC 3550, 6.6 - 204 => PacketType::ApplicationDefined, // RFC 3550, 6.7 (unimplemented) - 205 => PacketType::TransportSpecificFeedback, // RFC 4585, 6051 - 206 => PacketType::PayloadSpecificFeedback, // RFC 4585, 6.3 - 207 => PacketType::ExtendedReport, // RFC 3611 - _ => PacketType::Unsupported, - } - } -} - -pub const RTP_VERSION: u8 = 2; -pub const VERSION_SHIFT: u8 = 6; -pub const VERSION_MASK: u8 = 0x3; -pub const PADDING_SHIFT: u8 = 5; -pub const PADDING_MASK: u8 = 0x1; -pub const COUNT_SHIFT: u8 = 0; -pub const COUNT_MASK: u8 = 0x1f; - -pub const HEADER_LENGTH: usize = 4; -pub const COUNT_MAX: usize = (1 << 5) - 1; -pub const SSRC_LENGTH: usize = 4; -pub const SDES_MAX_OCTET_COUNT: usize = (1 << 8) - 1; - -/// A Header is the common header shared by all RTCP packets -#[derive(Debug, PartialEq, Eq, Default, Clone)] -pub struct Header { - /// If the padding bit is set, this individual RTCP packet contains - /// some additional padding octets at the end which are not part of - /// the control information but are included in the length field. - pub padding: bool, - /// The number of reception reports, sources contained or FMT in this packet (depending on the Type) - pub count: u8, - /// The RTCP packet type for this packet - pub packet_type: PacketType, - /// The length of this RTCP packet in 32-bit words minus one, - /// including the header and any padding. 
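The version/padding/count constants above are what the Marshal implementation below combines into the first header octet; a self-contained sketch of that packing (illustrative helper, not from the sources):

const RTP_VERSION: u8 = 2;

fn header_first_octet(padding: bool, count: u8) -> u8 {
    assert!(count <= 31); // five-bit count/FMT field
    (RTP_VERSION << 6) | ((padding as u8) << 5) | count
}

// header_first_octet(false, 1) == 0x81 and header_first_octet(true, 1) == 0xa1,
// the leading bytes of the "valid" and "also valid" header test fixtures
// further down.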
- pub length: u16, -} - -/// Marshal encodes the Header in binary -impl MarshalSize for Header { - fn marshal_size(&self) -> usize { - HEADER_LENGTH - } -} - -impl Marshal for Header { - fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - if self.count > 31 { - return Err(Error::InvalidHeader.into()); - } - if buf.remaining_mut() < HEADER_LENGTH { - return Err(Error::BufferTooShort.into()); - } - - /* - * 0 1 2 3 - * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * |V=2|P| RC | PT=SR=200 | length | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - */ - let b0 = (RTP_VERSION << VERSION_SHIFT) - | ((self.padding as u8) << PADDING_SHIFT) - | (self.count << COUNT_SHIFT); - - buf.put_u8(b0); - buf.put_u8(self.packet_type as u8); - buf.put_u16(self.length); - - Ok(HEADER_LENGTH) - } -} - -impl Unmarshal for Header { - /// Unmarshal decodes the Header from binary - fn unmarshal(raw_packet: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - if raw_packet.remaining() < HEADER_LENGTH { - return Err(Error::PacketTooShort.into()); - } - - /* - * 0 1 2 3 - * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * |V=2|P| RC | PT | length | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - */ - let b0 = raw_packet.get_u8(); - let version = (b0 >> VERSION_SHIFT) & VERSION_MASK; - if version != RTP_VERSION { - return Err(Error::BadVersion.into()); - } - - let padding = ((b0 >> PADDING_SHIFT) & PADDING_MASK) > 0; - let count = (b0 >> COUNT_SHIFT) & COUNT_MASK; - let packet_type = PacketType::from(raw_packet.get_u8()); - let length = raw_packet.get_u16(); - - Ok(Header { - padding, - count, - packet_type, - length, - }) - } -} - -#[cfg(test)] -mod test { - use bytes::Bytes; - - use super::*; - - #[test] - fn test_header_unmarshal() { - let tests = vec![ - ( - "valid", - Bytes::from_static(&[ - // v=2, p=0, count=1, RR, len=7 - 0x81u8, 0xc9, 0x00, 0x07, - ]), - Header { - padding: false, - count: 1, - packet_type: PacketType::ReceiverReport, - length: 7, - }, - None, - ), - ( - "also valid", - Bytes::from_static(&[ - // v=2, p=1, count=1, BYE, len=7 - 0xa1, 0xcc, 0x00, 0x07, - ]), - Header { - padding: true, - count: 1, - packet_type: PacketType::ApplicationDefined, - length: 7, - }, - None, - ), - ( - "bad version", - Bytes::from_static(&[ - // v=0, p=0, count=0, RR, len=4 - 0x00, 0xc9, 0x00, 0x04, - ]), - Header { - padding: false, - count: 0, - packet_type: PacketType::Unsupported, - length: 0, - }, - Some(Error::BadVersion), - ), - ]; - - for (name, data, want, want_error) in tests { - let buf = &mut data.clone(); - let got = Header::unmarshal(buf); - - assert_eq!( - got.is_err(), - want_error.is_some(), - "Unmarshal {name}: err = {got:?}, want {want_error:?}" - ); - - if let Some(want_error) = want_error { - let got_err = got.err().unwrap(); - assert_eq!( - want_error, got_err, - "Unmarshal {name}: err = {got_err:?}, want {want_error:?}", - ); - } else { - let actual = got.unwrap(); - assert_eq!( - actual, want, - "Unmarshal {name}: got {actual:?}, want {want:?}" - ); - } - } - } - - #[test] - fn test_header_roundtrip() { - let tests = vec![ - ( - "valid", - Header { - padding: true, - count: 31, - packet_type: PacketType::SenderReport, - length: 4, - }, - None, - ), - ( - "also valid", - Header { - padding: false, - count: 28, - packet_type: PacketType::ReceiverReport, - length: 
65535, - }, - None, - ), - ( - "invalid count", - Header { - padding: false, - count: 40, - packet_type: PacketType::Unsupported, - length: 0, - }, - Some(Error::InvalidHeader), - ), - ]; - - for (name, want, want_error) in tests { - let got = want.marshal(); - - assert_eq!( - got.is_ok(), - want_error.is_none(), - "Marshal {name}: err = {got:?}, want {want_error:?}" - ); - - if let Some(err) = want_error { - let got_err = got.err().unwrap(); - assert_eq!( - err, got_err, - "Unmarshal {name} rr: err = {got_err:?}, want {err:?}", - ); - } else { - let data = got.ok().unwrap(); - let buf = &mut data.clone(); - let actual = Header::unmarshal(buf).unwrap_or_else(|_| panic!("Unmarshal {name}")); - - assert_eq!( - actual, want, - "{name} round trip: got {actual:?}, want {want:?}" - ) - } - } - } -} diff --git a/rtcp/src/lib.rs b/rtcp/src/lib.rs deleted file mode 100644 index 4d99b09d5..000000000 --- a/rtcp/src/lib.rs +++ /dev/null @@ -1,59 +0,0 @@ -#![warn(rust_2018_idioms)] -#![allow(dead_code)] - -//! Package rtcp implements encoding and decoding of RTCP packets according to RFCs 3550 and 5506. -//! -//! RTCP is a sister protocol of the Real-time Transport Protocol (RTP). Its basic functionality -//! and packet structure is defined in RFC 3550. RTCP provides out-of-band statistics and control -//! information for an RTP session. It partners with RTP in the delivery and packaging of multimedia data, -//! but does not transport any media data itself. -//! -//! The primary function of RTCP is to provide feedback on the quality of service (QoS) -//! in media distribution by periodically sending statistics information such as transmitted octet -//! and packet counts, packet loss, packet delay variation, and round-trip delay time to participants -//! in a streaming multimedia session. An application may use this information to control quality of -//! service parameters, perhaps by limiting flow, or using a different codec. -//! -//! Decoding RTCP packets: -//!```nobuild -//! let pkt = rtcp::unmarshal(&rtcp_data).unwrap(); -//! -//! if let Some(e) = pkt -//! .as_any() -//! .downcast_ref::() -//! { -//! -//! } -//! else if let Some(e) = packet -//! .as_any() -//! .downcast_ref::(){} -//! .... -//!``` -//! -//! Encoding RTCP packets: -//!```nobuild -//! let pkt = PictureLossIndication{ -//! sender_ssrc: sender_ssrc, -//! media_ssrc: media_ssrc -//! }; -//! -//! let pli_data = pkt.marshal().unwrap(); -//! // ... 
-//!``` - -pub mod compound_packet; -mod error; -pub mod extended_report; -pub mod goodbye; -pub mod header; -pub mod packet; -pub mod payload_feedbacks; -pub mod raw_packet; -pub mod receiver_report; -pub mod reception_report; -pub mod sender_report; -pub mod source_description; -pub mod transport_feedbacks; -mod util; - -pub use error::Error; diff --git a/rtcp/src/packet.rs b/rtcp/src/packet.rs deleted file mode 100644 index 0a2aff805..000000000 --- a/rtcp/src/packet.rs +++ /dev/null @@ -1,276 +0,0 @@ -use std::any::Any; -use std::fmt; - -use bytes::{Buf, BufMut, Bytes, BytesMut}; -use util::marshal::{Marshal, Unmarshal}; - -use crate::error::{Error, Result}; -use crate::extended_report::ExtendedReport; -use crate::goodbye::*; -use crate::header::*; -use crate::payload_feedbacks::full_intra_request::*; -use crate::payload_feedbacks::picture_loss_indication::*; -use crate::payload_feedbacks::receiver_estimated_maximum_bitrate::*; -use crate::payload_feedbacks::slice_loss_indication::*; -use crate::raw_packet::*; -use crate::receiver_report::*; -use crate::sender_report::*; -use crate::source_description::*; -use crate::transport_feedbacks::rapid_resynchronization_request::*; -use crate::transport_feedbacks::transport_layer_cc::*; -use crate::transport_feedbacks::transport_layer_nack::*; - -/// Packet represents an RTCP packet, a protocol used for out-of-band statistics and -/// control information for an RTP session -pub trait Packet: Marshal + Unmarshal + fmt::Display + fmt::Debug { - fn header(&self) -> Header; - fn destination_ssrc(&self) -> Vec; - fn raw_size(&self) -> usize; - fn as_any(&self) -> &(dyn Any + Send + Sync); - fn equal(&self, other: &(dyn Packet + Send + Sync)) -> bool; - fn cloned(&self) -> Box; -} - -impl PartialEq for dyn Packet + Send + Sync { - fn eq(&self, other: &Self) -> bool { - self.equal(other) - } -} - -impl Clone for Box { - fn clone(&self) -> Box { - self.cloned() - } -} - -/// marshal takes an array of Packets and serializes them to a single buffer -pub fn marshal(packets: &[Box]) -> Result { - let mut out = BytesMut::new(); - for p in packets { - let data = p.marshal()?; - out.put(data); - } - Ok(out.freeze()) -} - -/// Unmarshal takes an entire udp datagram (which may consist of multiple RTCP packets) and -/// returns the unmarshaled packets it contains. -/// -/// If this is a reduced-size RTCP packet a feedback packet (Goodbye, SliceLossIndication, etc) -/// will be returned. Otherwise, the underlying type of the returned packet will be -/// CompoundPacket. -pub fn unmarshal(raw_data: &mut B) -> Result>> -where - B: Buf, -{ - let mut packets = vec![]; - - while raw_data.has_remaining() { - let p = unmarshaller(raw_data)?; - packets.push(p); - } - - match packets.len() { - // Empty Packet - 0 => Err(Error::InvalidHeader), - - // Multiple Packet - _ => Ok(packets), - } -} - -/// unmarshaller is a factory which pulls the first RTCP packet from a bytestream, -/// and returns it's parsed representation, and the amount of data that was processed. 
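The header length field consumed by unmarshaller below is expressed in 32-bit words minus one, so each packet's payload is `h.length * 4` bytes and the full packet occupies `(h.length + 1) * 4` bytes; a tiny sketch of that arithmetic (illustrative):

fn rtcp_packet_len_bytes(length_field: u16) -> usize {
    (length_field as usize + 1) * 4
}

// The receiver report at the start of the test_packet_unmarshal fixture below
// has length field 7, i.e. (7 + 1) * 4 = 32 bytes before the next packet begins.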
-pub(crate) fn unmarshaller(raw_data: &mut B) -> Result> -where - B: Buf, -{ - let h = Header::unmarshal(raw_data)?; - - let length = (h.length as usize) * 4; - if length > raw_data.remaining() { - return Err(Error::PacketTooShort); - } - - let mut in_packet = h.marshal()?.chain(raw_data.take(length)); - - let p: Box = match h.packet_type { - PacketType::SenderReport => Box::new(SenderReport::unmarshal(&mut in_packet)?), - PacketType::ReceiverReport => Box::new(ReceiverReport::unmarshal(&mut in_packet)?), - PacketType::SourceDescription => Box::new(SourceDescription::unmarshal(&mut in_packet)?), - PacketType::Goodbye => Box::new(Goodbye::unmarshal(&mut in_packet)?), - - PacketType::TransportSpecificFeedback => match h.count { - FORMAT_TLN => Box::new(TransportLayerNack::unmarshal(&mut in_packet)?), - FORMAT_RRR => Box::new(RapidResynchronizationRequest::unmarshal(&mut in_packet)?), - FORMAT_TCC => Box::new(TransportLayerCc::unmarshal(&mut in_packet)?), - _ => Box::new(RawPacket::unmarshal(&mut in_packet)?), - }, - PacketType::PayloadSpecificFeedback => match h.count { - FORMAT_PLI => Box::new(PictureLossIndication::unmarshal(&mut in_packet)?), - FORMAT_SLI => Box::new(SliceLossIndication::unmarshal(&mut in_packet)?), - FORMAT_REMB => Box::new(ReceiverEstimatedMaximumBitrate::unmarshal(&mut in_packet)?), - FORMAT_FIR => Box::new(FullIntraRequest::unmarshal(&mut in_packet)?), - _ => Box::new(RawPacket::unmarshal(&mut in_packet)?), - }, - PacketType::ExtendedReport => Box::new(ExtendedReport::unmarshal(&mut in_packet)?), - _ => Box::new(RawPacket::unmarshal(&mut in_packet)?), - }; - - Ok(p) -} - -#[cfg(test)] -mod test { - use bytes::Bytes; - - use super::*; - use crate::reception_report::*; - - #[test] - fn test_packet_unmarshal() { - let mut data = Bytes::from_static(&[ - // Receiver Report (offset=0) - 0x81, 0xc9, 0x0, 0x7, // v=2, p=0, count=1, RR, len=7 - 0x90, 0x2f, 0x9e, 0x2e, // ssrc=0x902f9e2e - 0xbc, 0x5e, 0x9a, 0x40, // ssrc=0xbc5e9a40 - 0x0, 0x0, 0x0, 0x0, // fracLost=0, totalLost=0 - 0x0, 0x0, 0x46, 0xe1, // lastSeq=0x46e1 - 0x0, 0x0, 0x1, 0x11, // jitter=273 - 0x9, 0xf3, 0x64, 0x32, // lsr=0x9f36432 - 0x0, 0x2, 0x4a, 0x79, // delay=150137 - // Source Description (offset=32) - 0x81, 0xca, 0x0, 0xc, // v=2, p=0, count=1, SDES, len=12 - 0x90, 0x2f, 0x9e, 0x2e, // ssrc=0x902f9e2e - 0x1, 0x26, // CNAME, len=38 - 0x7b, 0x39, 0x63, 0x30, 0x30, 0x65, 0x62, 0x39, 0x32, 0x2d, 0x31, 0x61, 0x66, 0x62, - 0x2d, 0x39, 0x64, 0x34, 0x39, 0x2d, 0x61, 0x34, 0x37, 0x64, 0x2d, 0x39, 0x31, 0x66, - 0x36, 0x34, 0x65, 0x65, 0x65, 0x36, 0x39, 0x66, 0x35, - 0x7d, // text="{9c00eb92-1afb-9d49-a47d-91f64eee69f5}" - 0x0, 0x0, 0x0, 0x0, // END + padding - // Goodbye (offset=84) - 0x81, 0xcb, 0x0, 0x1, // v=2, p=0, count=1, BYE, len=1 - 0x90, 0x2f, 0x9e, 0x2e, // source=0x902f9e2e - 0x81, 0xce, 0x0, 0x2, // Picture Loss Indication (offset=92) - 0x90, 0x2f, 0x9e, 0x2e, // sender=0x902f9e2e - 0x90, 0x2f, 0x9e, 0x2e, // media=0x902f9e2e - 0x85, 0xcd, 0x0, 0x2, // RapidResynchronizationRequest (offset=104) - 0x90, 0x2f, 0x9e, 0x2e, // sender=0x902f9e2e - 0x90, 0x2f, 0x9e, 0x2e, // media=0x902f9e2e - ]); - - let packet = unmarshal(&mut data).expect("Error unmarshalling packets"); - - let a = ReceiverReport { - ssrc: 0x902f9e2e, - reports: vec![ReceptionReport { - ssrc: 0xbc5e9a40, - fraction_lost: 0, - total_lost: 0, - last_sequence_number: 0x46e1, - jitter: 273, - last_sender_report: 0x9f36432, - delay: 150137, - }], - ..Default::default() - }; - - let b = SourceDescription { - chunks: vec![SourceDescriptionChunk 
{ - source: 0x902f9e2e, - items: vec![SourceDescriptionItem { - sdes_type: SdesType::SdesCname, - text: Bytes::from_static(b"{9c00eb92-1afb-9d49-a47d-91f64eee69f5}"), - }], - }], - }; - - let c = Goodbye { - sources: vec![0x902f9e2e], - ..Default::default() - }; - - let d = PictureLossIndication { - sender_ssrc: 0x902f9e2e, - media_ssrc: 0x902f9e2e, - }; - - let e = RapidResynchronizationRequest { - sender_ssrc: 0x902f9e2e, - media_ssrc: 0x902f9e2e, - }; - - let expected: Vec> = vec![ - Box::new(a), - Box::new(b), - Box::new(c), - Box::new(d), - Box::new(e), - ]; - - assert!(packet == expected, "Invalid packets"); - } - - #[test] - fn test_packet_unmarshal_empty() -> Result<()> { - let result = unmarshal(&mut Bytes::new()); - if let Err(got) = result { - let want = Error::InvalidHeader; - assert_eq!(got, want, "Unmarshal(nil) err = {got}, want {want}"); - } else { - panic!("want error"); - } - - Ok(()) - } - - #[test] - fn test_packet_invalid_header_length() -> Result<()> { - let mut data = Bytes::from_static(&[ - // Goodbye (offset=84) - // v=2, p=0, count=1, BYE, len=100 - 0x81, 0xcb, 0x0, 0x64, - ]); - - let result = unmarshal(&mut data); - if let Err(got) = result { - let want = Error::PacketTooShort; - assert_eq!( - got, want, - "Unmarshal(invalid_header_length) err = {got}, want {want}" - ); - } else { - panic!("want error"); - } - - Ok(()) - } - #[test] - fn test_packet_unmarshal_firefox() -> Result<()> { - // issue report from https://github.com/webrtc-rs/srtp/issues/7 - let tests = vec![ - Bytes::from_static(&[ - 143, 205, 0, 6, 65, 227, 184, 49, 118, 243, 78, 96, 42, 63, 0, 5, 12, 162, 166, 0, - 32, 5, 200, 4, 0, 4, 0, 0, - ]), - Bytes::from_static(&[ - 143, 205, 0, 9, 65, 227, 184, 49, 118, 243, 78, 96, 42, 68, 0, 17, 12, 162, 167, 1, - 32, 17, 88, 0, 4, 0, 4, 8, 108, 0, 4, 0, 4, 12, 0, 4, 0, 4, 4, 0, - ]), - Bytes::from_static(&[ - 143, 205, 0, 8, 65, 227, 184, 49, 118, 243, 78, 96, 42, 91, 0, 12, 12, 162, 168, 3, - 32, 12, 220, 4, 0, 4, 0, 8, 128, 4, 0, 4, 0, 8, 0, 0, - ]), - Bytes::from_static(&[ - 143, 205, 0, 7, 65, 227, 184, 49, 118, 243, 78, 96, 42, 103, 0, 8, 12, 162, 169, 4, - 32, 8, 232, 4, 0, 4, 0, 4, 4, 0, 0, 0, - ]), - ]; - - for mut test in tests { - unmarshal(&mut test)?; - } - - Ok(()) - } -} diff --git a/rtcp/src/payload_feedbacks/full_intra_request/full_intra_request_test.rs b/rtcp/src/payload_feedbacks/full_intra_request/full_intra_request_test.rs deleted file mode 100644 index 2d0dce6e2..000000000 --- a/rtcp/src/payload_feedbacks/full_intra_request/full_intra_request_test.rs +++ /dev/null @@ -1,215 +0,0 @@ -use bytes::Bytes; - -use super::*; - -#[test] -fn test_full_intra_request_unmarshal() { - let tests = vec![ - ( - "valid", - Bytes::from_static(&[ - 0x84, 0xce, 0x00, 0x03, // v=2, p=0, FMT=4, PSFB, len=3 - 0x00, 0x00, 0x00, 0x00, // ssrc=0x0 - 0x4b, 0xc4, 0xfc, 0xb4, // ssrc=0x4bc4fcb4 - 0x12, 0x34, 0x56, 0x78, // ssrc=0x12345678 - 0x42, 0x00, 0x00, 0x00, // Seqno=0x42 - ]), - FullIntraRequest { - sender_ssrc: 0x0, - media_ssrc: 0x4bc4fcb4, - fir: vec![FirEntry { - ssrc: 0x12345678, - sequence_number: 0x42, - }], - }, - None, - ), - ( - "also valid", - Bytes::from_static(&[ - 0x84, 0xce, 0x00, 0x05, // v=2, p=0, FMT=4, PSFB, len=3 - 0x00, 0x00, 0x00, 0x00, // ssrc=0x0 - 0x4b, 0xc4, 0xfc, 0xb4, // ssrc=0x4bc4fcb4 - 0x12, 0x34, 0x56, 0x78, // ssrc=0x12345678 - 0x42, 0x00, 0x00, 0x00, // Seqno=0x42 - 0x98, 0x76, 0x54, 0x32, // ssrc=0x98765432 - 0x57, 0x00, 0x00, 0x00, // Seqno=0x57 - ]), - FullIntraRequest { - sender_ssrc: 0x0, - media_ssrc: 0x4bc4fcb4, - 
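// Editor's sketch (not original code): how the first octet of the Firefox
// captures above breaks down, assuming the usual RFC 3550 bit layout of a
// 2-bit version, a 1-bit padding flag and a 5-bit count/FMT field.
fn decode_first_octet(b: u8) -> (u8, bool, u8) {
    let version = b >> 6;
    let padding = (b >> 5) & 0b1 == 1;
    let count_or_fmt = b & 0b1_1111;
    (version, padding, count_or_fmt)
}

#[test]
fn first_octet_sketch() {
    // 143 = 0b1000_1111: version 2, no padding, FMT 15. Paired with packet
    // type 205 (transport-layer feedback) that is the transport-wide CC
    // format, which is why the dispatcher routes these buffers to
    // TransportLayerCc (assuming FORMAT_TCC is 15, as in the TWCC draft).
    assert_eq!(decode_first_octet(143), (2, false, 15));
    // 0x81: version 2, no padding, one report block / FMT 1.
    assert_eq!(decode_first_octet(0x81), (2, false, 1));
}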
fir: vec![ - FirEntry { - ssrc: 0x12345678, - sequence_number: 0x42, - }, - FirEntry { - ssrc: 0x98765432, - sequence_number: 0x57, - }, - ], - }, - None, - ), - ( - "packet too short", - Bytes::from_static(&[0x00, 0x00, 0x00, 0x00]), - FullIntraRequest::default(), - Some(Error::PacketTooShort), - ), - ( - "invalid header", - Bytes::from_static(&[ - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - ]), - FullIntraRequest::default(), - Some(Error::BadVersion), - ), - ( - "wrong type", - Bytes::from_static(&[ - 0x84, 0xc9, 0x00, 0x03, // v=2, p=0, FMT=4, RR, len=3 - 0x00, 0x00, 0x00, 0x00, // ssrc=0x0 - 0x4b, 0xc4, 0xfc, 0xb4, // ssrc=0x4bc4fcb4 - 0x12, 0x34, 0x56, 0x78, // ssrc=0x12345678 - 0x42, 0x00, 0x00, 0x00, // Seqno=0x42 - ]), - FullIntraRequest::default(), - Some(Error::WrongType), - ), - ( - "wrong fmt", - Bytes::from_static(&[ - 0x82, 0xce, 0x00, 0x03, // v=2, p=0, FMT=2, PSFB, len=3 - 0x00, 0x00, 0x00, 0x00, // ssrc=0x0 - 0x4b, 0xc4, 0xfc, 0xb4, // ssrc=0x4bc4fcb4 - 0x12, 0x34, 0x56, 0x78, // ssrc=0x12345678 - 0x42, 0x00, 0x00, 0x00, // Seqno=0x42 - ]), - FullIntraRequest::default(), - Some(Error::WrongType), - ), - ]; - - for (name, mut data, want, want_error) in tests { - let got = FullIntraRequest::unmarshal(&mut data); - - assert_eq!( - got.is_err(), - want_error.is_some(), - "Unmarshal {name} rr: err = {got:?}, want {want_error:?}" - ); - - if let Some(err) = want_error { - let got_err = got.err().unwrap(); - assert_eq!( - err, got_err, - "Unmarshal {name} rr: err = {got_err:?}, want {err:?}", - ); - } else { - let actual = got.unwrap(); - assert_eq!( - actual, want, - "Unmarshal {name} rr: got {actual:?}, want {want:?}" - ); - } - } -} - -#[test] -fn test_full_intra_request_round_trip() { - let tests: Vec<(&str, FullIntraRequest, Option)> = vec![ - ( - "valid", - FullIntraRequest { - sender_ssrc: 1, - media_ssrc: 2, - fir: vec![FirEntry { - ssrc: 3, - sequence_number: 42, - }], - }, - None, - ), - ( - "also valid", - FullIntraRequest { - sender_ssrc: 5000, - media_ssrc: 6000, - fir: vec![FirEntry { - ssrc: 3, - sequence_number: 57, - }], - }, - None, - ), - ]; - - for (name, want, want_error) in tests { - let got = want.marshal(); - - assert_eq!( - got.is_ok(), - want_error.is_none(), - "Marshal {name}: err = {got:?}, want {want_error:?}" - ); - - if let Some(err) = want_error { - let got_err = got.err().unwrap(); - assert_eq!( - err, got_err, - "Unmarshal {name} rr: err = {got_err:?}, want {err:?}", - ); - } else { - let mut data = got.ok().unwrap(); - let actual = FullIntraRequest::unmarshal(&mut data) - .unwrap_or_else(|_| panic!("Unmarshal {name}")); - - assert_eq!( - actual, want, - "{name} round trip: got {actual:?}, want {want:?}" - ) - } - } -} - -#[test] -fn test_full_intra_request_unmarshal_header() { - let tests = vec![( - "valid header", - Bytes::from_static(&[ - 0x84, 0xce, 0x00, 0x02, // v=2, p=0, FMT=1, PSFB, len=1 - 0x00, 0x00, 0x00, 0x00, // ssrc=0x0 - 0x4b, 0xc4, 0xfc, 0xb4, 0x00, 0x00, 0x00, 0x00, // ssrc=0x4bc4fcb4 - ]), - Header { - count: FORMAT_FIR, - packet_type: PacketType::PayloadSpecificFeedback, - length: 2, - ..Default::default() - }, - )]; - - for (name, mut data, want) in tests { - let result = FullIntraRequest::unmarshal(&mut data); - - assert!( - result.is_ok(), - "Unmarshal header {name} rr: want {result:?}", - ); - - match result { - Err(_) => continue, - - Ok(fir) => { - let h = fir.header(); - - assert_eq!( - h, want, - "Unmarshal header {name} rr: got {h:?}, want 
{want:?}" - ) - } - } - } -} diff --git a/rtcp/src/payload_feedbacks/full_intra_request/mod.rs b/rtcp/src/payload_feedbacks/full_intra_request/mod.rs deleted file mode 100644 index fa8440889..000000000 --- a/rtcp/src/payload_feedbacks/full_intra_request/mod.rs +++ /dev/null @@ -1,172 +0,0 @@ -#[cfg(test)] -mod full_intra_request_test; - -use std::any::Any; -use std::fmt; - -use bytes::{Buf, BufMut}; -use util::marshal::{Marshal, MarshalSize, Unmarshal}; - -use crate::error::Error; -use crate::header::*; -use crate::packet::*; -use crate::util::*; - -type Result = std::result::Result; - -/// A FIREntry is a (ssrc, seqno) pair, as carried by FullIntraRequest. -#[derive(Debug, PartialEq, Eq, Default, Clone)] -pub struct FirEntry { - pub ssrc: u32, - pub sequence_number: u8, -} - -/// The FullIntraRequest packet is used to reliably request an Intra frame -/// in a video stream. See RFC 5104 Section 3.5.1. This is not for loss -/// recovery, which should use PictureLossIndication (PLI) instead. -#[derive(Debug, PartialEq, Eq, Default, Clone)] -pub struct FullIntraRequest { - pub sender_ssrc: u32, - pub media_ssrc: u32, - pub fir: Vec, -} - -const FIR_OFFSET: usize = 8; - -impl fmt::Display for FullIntraRequest { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut out = format!("FullIntraRequest {} {}", self.sender_ssrc, self.media_ssrc); - for e in &self.fir { - out += format!(" ({} {})", e.ssrc, e.sequence_number).as_str(); - } - write!(f, "{out}") - } -} - -impl Packet for FullIntraRequest { - fn header(&self) -> Header { - Header { - padding: get_padding_size(self.raw_size()) != 0, - count: FORMAT_FIR, - packet_type: PacketType::PayloadSpecificFeedback, - length: ((self.marshal_size() / 4) - 1) as u16, - } - } - - /// destination_ssrc returns an array of SSRC values that this packet refers to. 
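// Editor's note, illustrative only: on the wire each FIR entry is 8 bytes
// (a 4-byte SSRC, a 1-byte sequence number, 3 reserved bytes) and sits after
// the 4-byte common header plus the 8-byte sender/media SSRC pair, which is
// exactly what `raw_size` below computes via FIR_OFFSET.
fn fir_wire_size(entries: usize) -> usize {
    let header = 4; // RTCP common header
    let ssrc_pair = 8; // sender SSRC + media SSRC (FIR_OFFSET)
    header + ssrc_pair + entries * 8
}

#[test]
fn fir_size_matches_test_vectors() {
    // "valid" above: one entry, 20 bytes, length field 3 ((3 + 1) * 4 bytes).
    assert_eq!(fir_wire_size(1), 20);
    // "also valid": two entries, 28 bytes, length field 5.
    assert_eq!(fir_wire_size(2), 28);
}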
- fn destination_ssrc(&self) -> Vec { - let mut ssrcs: Vec = Vec::with_capacity(self.fir.len()); - for entry in &self.fir { - ssrcs.push(entry.ssrc); - } - ssrcs - } - - fn raw_size(&self) -> usize { - HEADER_LENGTH + FIR_OFFSET + self.fir.len() * 8 - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } - - fn equal(&self, other: &(dyn Packet + Send + Sync)) -> bool { - other - .as_any() - .downcast_ref::() - .map_or(false, |a| self == a) - } - - fn cloned(&self) -> Box { - Box::new(self.clone()) - } -} - -impl MarshalSize for FullIntraRequest { - fn marshal_size(&self) -> usize { - let l = self.raw_size(); - // align to 32-bit boundary - l + get_padding_size(l) - } -} - -impl Marshal for FullIntraRequest { - /// Marshal encodes the FullIntraRequest - fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - if buf.remaining_mut() < self.marshal_size() { - return Err(Error::BufferTooShort.into()); - } - - let h = self.header(); - let n = h.marshal_to(buf)?; - buf = &mut buf[n..]; - - buf.put_u32(self.sender_ssrc); - buf.put_u32(self.media_ssrc); - - for fir in self.fir.iter() { - buf.put_u32(fir.ssrc); - buf.put_u8(fir.sequence_number); - buf.put_u8(0); - buf.put_u16(0); - } - - if h.padding { - put_padding(buf, self.raw_size()); - } - - Ok(self.marshal_size()) - } -} - -impl Unmarshal for FullIntraRequest { - /// Unmarshal decodes the FullIntraRequest - fn unmarshal(raw_packet: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - let raw_packet_len = raw_packet.remaining(); - if raw_packet_len < (HEADER_LENGTH + SSRC_LENGTH) { - return Err(Error::PacketTooShort.into()); - } - - let h = Header::unmarshal(raw_packet)?; - - if raw_packet_len < (HEADER_LENGTH + (4 * h.length) as usize) { - return Err(Error::PacketTooShort.into()); - } - - if h.packet_type != PacketType::PayloadSpecificFeedback || h.count != FORMAT_FIR { - return Err(Error::WrongType.into()); - } - - let sender_ssrc = raw_packet.get_u32(); - let media_ssrc = raw_packet.get_u32(); - - let mut i = HEADER_LENGTH + FIR_OFFSET; - let mut fir = vec![]; - while i < HEADER_LENGTH + (h.length * 4) as usize { - fir.push(FirEntry { - ssrc: raw_packet.get_u32(), - sequence_number: raw_packet.get_u8(), - }); - raw_packet.get_u8(); - raw_packet.get_u16(); - - i += 8; - } - - if - /*h.padding &&*/ - raw_packet.has_remaining() { - raw_packet.advance(raw_packet.remaining()); - } - - Ok(FullIntraRequest { - sender_ssrc, - media_ssrc, - fir, - }) - } -} diff --git a/rtcp/src/payload_feedbacks/mod.rs b/rtcp/src/payload_feedbacks/mod.rs deleted file mode 100644 index e7d02f177..000000000 --- a/rtcp/src/payload_feedbacks/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -pub mod full_intra_request; -pub mod picture_loss_indication; -pub mod receiver_estimated_maximum_bitrate; -pub mod slice_loss_indication; diff --git a/rtcp/src/payload_feedbacks/picture_loss_indication/mod.rs b/rtcp/src/payload_feedbacks/picture_loss_indication/mod.rs deleted file mode 100644 index 1092e0afc..000000000 --- a/rtcp/src/payload_feedbacks/picture_loss_indication/mod.rs +++ /dev/null @@ -1,141 +0,0 @@ -#[cfg(test)] -mod picture_loss_indication_test; - -use std::any::Any; -use std::fmt; - -use bytes::{Buf, BufMut}; -use util::marshal::{Marshal, MarshalSize, Unmarshal}; - -use crate::error::Error; -use crate::header::*; -use crate::packet::*; -use crate::util::*; - -type Result = std::result::Result; - -const PLI_LENGTH: usize = 2; - -/// The PictureLossIndication packet informs the encoder about the loss of an undefined amount of coded video data belonging to one or more 
pictures -#[derive(Debug, PartialEq, Eq, Default, Clone)] -pub struct PictureLossIndication { - /// SSRC of sender - pub sender_ssrc: u32, - /// SSRC where the loss was experienced - pub media_ssrc: u32, -} - -impl fmt::Display for PictureLossIndication { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "PictureLossIndication {:x} {:x}", - self.sender_ssrc, self.media_ssrc - ) - } -} - -impl Packet for PictureLossIndication { - /// Header returns the Header associated with this packet. - fn header(&self) -> Header { - Header { - padding: get_padding_size(self.raw_size()) != 0, - count: FORMAT_PLI, - packet_type: PacketType::PayloadSpecificFeedback, - length: ((self.marshal_size() / 4) - 1) as u16, - } - } - - /// destination_ssrc returns an array of SSRC values that this packet refers to. - fn destination_ssrc(&self) -> Vec { - vec![self.media_ssrc] - } - - fn raw_size(&self) -> usize { - HEADER_LENGTH + SSRC_LENGTH * 2 - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } - - fn equal(&self, other: &(dyn Packet + Send + Sync)) -> bool { - other - .as_any() - .downcast_ref::() - .map_or(false, |a| self == a) - } - - fn cloned(&self) -> Box { - Box::new(self.clone()) - } -} - -impl MarshalSize for PictureLossIndication { - fn marshal_size(&self) -> usize { - let l = self.raw_size(); - // align to 32-bit boundary - l + get_padding_size(l) - } -} - -impl Marshal for PictureLossIndication { - /// Marshal encodes the PictureLossIndication in binary - fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - /* - * PLI does not require parameters. Therefore, the length field MUST be - * 2, and there MUST NOT be any Feedback Control Information. - * - * The semantics of this FB message is independent of the payload type. 
- */ - if buf.remaining_mut() < self.marshal_size() { - return Err(Error::BufferTooShort.into()); - } - - let h = self.header(); - let n = h.marshal_to(buf)?; - buf = &mut buf[n..]; - - buf.put_u32(self.sender_ssrc); - buf.put_u32(self.media_ssrc); - - if h.padding { - put_padding(buf, self.raw_size()); - } - - Ok(self.marshal_size()) - } -} - -impl Unmarshal for PictureLossIndication { - /// Unmarshal decodes the PictureLossIndication from binary - fn unmarshal(raw_packet: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - let raw_packet_len = raw_packet.remaining(); - if raw_packet_len < (HEADER_LENGTH + (SSRC_LENGTH * 2)) { - return Err(Error::PacketTooShort.into()); - } - - let h = Header::unmarshal(raw_packet)?; - if h.packet_type != PacketType::PayloadSpecificFeedback || h.count != FORMAT_PLI { - return Err(Error::WrongType.into()); - } - - let sender_ssrc = raw_packet.get_u32(); - let media_ssrc = raw_packet.get_u32(); - - if - /*h.padding &&*/ - raw_packet.has_remaining() { - raw_packet.advance(raw_packet.remaining()); - } - - Ok(PictureLossIndication { - sender_ssrc, - media_ssrc, - }) - } -} diff --git a/rtcp/src/payload_feedbacks/picture_loss_indication/picture_loss_indication_test.rs b/rtcp/src/payload_feedbacks/picture_loss_indication/picture_loss_indication_test.rs deleted file mode 100644 index c40a8984b..000000000 --- a/rtcp/src/payload_feedbacks/picture_loss_indication/picture_loss_indication_test.rs +++ /dev/null @@ -1,162 +0,0 @@ -use bytes::Bytes; - -use super::*; - -#[test] -fn test_picture_loss_indication_unmarshal() { - let tests = vec![ - ( - "valid", - Bytes::from_static(&[ - 0x81, 0xce, 0x00, 0x02, // v=2, p=0, FMT=1, PSFB, len=1 - 0x00, 0x00, 0x00, 0x00, // ssrc=0x0 - 0x4b, 0xc4, 0xfc, 0xb4, // ssrc=0x4bc4fcb4 - ]), - PictureLossIndication { - sender_ssrc: 0x0, - media_ssrc: 0x4bc4fcb4, - }, - None, - ), - ( - "packet too short", - Bytes::from_static(&[0x81, 0xce, 0x00, 0x00]), - PictureLossIndication::default(), - Some(Error::PacketTooShort), - ), - ( - "invalid header", - Bytes::from_static(&[ - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - ]), - PictureLossIndication::default(), - Some(Error::BadVersion), - ), - ( - "wrong type", - Bytes::from_static(&[ - 0x81, 0xc9, 0x00, 0x02, // v=2, p=0, FMT=1, RR, len=1 - 0x00, 0x00, 0x00, 0x00, // ssrc=0x0 - 0x4b, 0xc4, 0xfc, 0xb4, // ssrc=0x4bc4fcb4 - ]), - PictureLossIndication::default(), - Some(Error::WrongType), - ), - ( - "wrong fmt", - Bytes::from_static(&[ - 0x82, 0xc9, 0x00, 0x02, // v=2, p=0, FMT=2, RR, len=1 - 0x00, 0x00, 0x00, 0x00, // ssrc=0x0 - 0x4b, 0xc4, 0xfc, 0xb4, // ssrc=0x4bc4fcb4 - ]), - PictureLossIndication::default(), - Some(Error::WrongType), - ), - ]; - - for (name, mut data, want, want_error) in tests { - let got = PictureLossIndication::unmarshal(&mut data); - - assert_eq!( - got.is_err(), - want_error.is_some(), - "Unmarshal {name} rr: err = {got:?}, want {want_error:?}" - ); - - if let Some(err) = want_error { - let got_err = got.err().unwrap(); - assert_eq!( - err, got_err, - "Unmarshal {name} rr: err = {got_err:?}, want {err:?}", - ); - } else { - let actual = got.unwrap(); - assert_eq!( - actual, want, - "Unmarshal {name} rr: got {actual:?}, want {want:?}" - ); - } - } -} - -#[test] -fn test_picture_loss_indication_roundtrip() { - let tests: Vec<(&str, PictureLossIndication, Option)> = vec![ - ( - "valid", - PictureLossIndication { - sender_ssrc: 1, - media_ssrc: 2, - }, - None, - ), - ( - "also valid", - PictureLossIndication { - sender_ssrc: 5000, - 
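// Editor's usage sketch, not from the original sources: building a PLI to ask
// the remote encoder for a keyframe and serialising it with the crate's
// Marshal trait. The import paths assume the standalone `rtcp` and `util`
// crates of this workspace and that `payload_feedbacks` is a public module of
// the crate root; adjust them if the module layout differs.
use bytes::Bytes;
use rtcp::payload_feedbacks::picture_loss_indication::PictureLossIndication;
use util::marshal::Marshal;

fn request_keyframe(sender_ssrc: u32, media_ssrc: u32) -> Bytes {
    let pli = PictureLossIndication { sender_ssrc, media_ssrc };
    // A PLI always marshals to its fixed 12-byte form: common header,
    // sender SSRC, media SSRC, with the length field pinned at 2.
    pli.marshal().expect("a PLI always fits its fixed 12-byte layout")
}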
media_ssrc: 6000, - }, - None, - ), - ]; - - for (name, want, want_error) in tests { - let got = want.marshal(); - - assert_eq!( - got.is_ok(), - want_error.is_none(), - "Marshal {name}: err = {got:?}, want {want_error:?}" - ); - - if let Some(err) = want_error { - let got_err = got.err().unwrap(); - assert_eq!( - err, got_err, - "Unmarshal {name} rr: err = {got_err:?}, want {err:?}", - ); - } else { - let mut data = got.ok().unwrap(); - let actual = PictureLossIndication::unmarshal(&mut data) - .unwrap_or_else(|_| panic!("Unmarshal {name}")); - - assert_eq!( - actual, want, - "{name} round trip: got {actual:?}, want {want:?}" - ) - } - } -} - -#[test] -fn test_picture_loss_indication_unmarshal_header() -> Result<()> { - let tests = vec![( - "valid header", - Bytes::from_static(&[ - 0x81u8, 0xce, 0x00, 0x02, // v=2, p=0, FMT=1, PSFB, len=1 - 0x00, 0x00, 0x00, 0x00, // ssrc=0x0 - 0x4b, 0xc4, 0xfc, 0xb4, // ssrc=0x4bc4fcb4 - ]), - Header { - count: FORMAT_PLI, - packet_type: PacketType::PayloadSpecificFeedback, - length: PLI_LENGTH as u16, - ..Default::default() - }, - )]; - - for (name, mut data, header) in tests { - let pli = PictureLossIndication::unmarshal(&mut data)?; - - assert_eq!( - pli.header(), - header, - "Unmarshal header {} rr: got {:?}, want {:?}", - name, - pli.header(), - header - ); - } - - Ok(()) -} diff --git a/rtcp/src/payload_feedbacks/receiver_estimated_maximum_bitrate/mod.rs b/rtcp/src/payload_feedbacks/receiver_estimated_maximum_bitrate/mod.rs deleted file mode 100644 index 17703ae29..000000000 --- a/rtcp/src/payload_feedbacks/receiver_estimated_maximum_bitrate/mod.rs +++ /dev/null @@ -1,287 +0,0 @@ -#[cfg(test)] -mod receiver_estimated_maximum_bitrate_test; - -use std::any::Any; -use std::fmt; - -use bytes::{Buf, BufMut}; -use util::marshal::{Marshal, MarshalSize, Unmarshal}; - -use crate::error::Error; -use crate::header::*; -use crate::packet::*; -use crate::util::*; - -type Result = std::result::Result; - -/// ReceiverEstimatedMaximumBitrate contains the receiver's estimated maximum bitrate. -/// see: https://tools.ietf.org/html/draft-alvestrand-rmcat-remb-03 -#[derive(Debug, PartialEq, Default, Clone)] -pub struct ReceiverEstimatedMaximumBitrate { - /// SSRC of sender - pub sender_ssrc: u32, - - /// Estimated maximum bitrate - pub bitrate: f32, - - /// SSRC entries which this packet applies to - pub ssrcs: Vec, -} - -const REMB_OFFSET: usize = 16; - -/// Keep a table of powers to units for fast conversion. -const BIT_UNITS: [&str; 7] = ["b", "Kb", "Mb", "Gb", "Tb", "Pb", "Eb"]; -const UNIQUE_IDENTIFIER: [u8; 4] = [b'R', b'E', b'M', b'B']; - -/// String prints the REMB packet in a human-readable format. -impl fmt::Display for ReceiverEstimatedMaximumBitrate { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - // Do some unit conversions because b/s is far too difficult to read. - let mut bitrate = self.bitrate; - let mut powers = 0; - - // Keep dividing the bitrate until it's under 1000 - while bitrate >= 1000.0 && powers < BIT_UNITS.len() { - bitrate /= 1000.0; - powers += 1; - } - - let unit = BIT_UNITS[powers]; - - write!( - f, - "ReceiverEstimatedMaximumBitrate {:x} {:.2} {}/s", - self.sender_ssrc, bitrate, unit, - ) - } -} - -impl Packet for ReceiverEstimatedMaximumBitrate { - /// Header returns the Header associated with this packet. 
- fn header(&self) -> Header { - Header { - padding: get_padding_size(self.raw_size()) != 0, - count: FORMAT_REMB, - packet_type: PacketType::PayloadSpecificFeedback, - length: ((self.marshal_size() / 4) - 1) as u16, - } - } - - /// destination_ssrc returns an array of SSRC values that this packet refers to. - fn destination_ssrc(&self) -> Vec { - self.ssrcs.clone() - } - - fn raw_size(&self) -> usize { - HEADER_LENGTH + REMB_OFFSET + self.ssrcs.len() * 4 - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } - - fn equal(&self, other: &(dyn Packet + Send + Sync)) -> bool { - other - .as_any() - .downcast_ref::() - .map_or(false, |a| self == a) - } - - fn cloned(&self) -> Box { - Box::new(self.clone()) - } -} - -impl MarshalSize for ReceiverEstimatedMaximumBitrate { - fn marshal_size(&self) -> usize { - let l = self.raw_size(); - // align to 32-bit boundary - l + get_padding_size(l) - } -} - -impl Marshal for ReceiverEstimatedMaximumBitrate { - /// Marshal serializes the packet and returns a byte slice. - fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - const BITRATE_MAX: f32 = 2.417_842_4e24; //0x3FFFFp+63; - - /* - 0 1 2 3 - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - |V=2|P| FMT=15 | PT=206 | length | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | SSRC of packet sender | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | SSRC of media source | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Unique identifier 'R' 'E' 'M' 'B' | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Num SSRC | BR Exp | BR Mantissa | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | SSRC feedback | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | ... | - */ - - if buf.remaining_mut() < self.marshal_size() { - return Err(Error::BufferTooShort.into()); - } - - let h = self.header(); - let n = h.marshal_to(buf)?; - buf = &mut buf[n..]; - - buf.put_u32(self.sender_ssrc); - buf.put_u32(0); // always zero - - buf.put_slice(&UNIQUE_IDENTIFIER); - - // Write the length of the ssrcs to follow at the end - buf.put_u8(self.ssrcs.len() as u8); - - let mut exp = 0; - let mut bitrate = self.bitrate; - if bitrate >= BITRATE_MAX { - bitrate = BITRATE_MAX - } - - if bitrate < 0.0 { - return Err(Error::InvalidBitrate.into()); - } - - while bitrate >= (1 << 18) as f32 { - bitrate /= 2.0; - exp += 1; - } - - if exp >= (1 << 6) { - return Err(Error::InvalidBitrate.into()); - } - - let mantissa = bitrate.floor() as u32; - - // We can't quite use the binary package because - // a) it's a uint24 and b) the exponent is only 6-bits - // Just trust me; this is big-endian encoding. - buf.put_u8((exp << 2) as u8 | (mantissa >> 16) as u8); - buf.put_u8((mantissa >> 8) as u8); - buf.put_u8(mantissa as u8); - - // Write the SSRCs at the very end. - for ssrc in &self.ssrcs { - buf.put_u32(*ssrc); - } - - if h.padding { - put_padding(buf, self.raw_size()); - } - - Ok(self.marshal_size()) - } -} - -impl Unmarshal for ReceiverEstimatedMaximumBitrate { - /// Unmarshal reads a REMB packet from the given byte slice. 
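// Editor's sketch (not part of the module): the REMB bitrate is carried as a
// 6-bit exponent and an 18-bit mantissa, i.e. bitrate = mantissa * 2^exp.
// This integer version mirrors the halving loop in `marshal_to` above; the
// crate itself stores the bitrate as an f32.
fn remb_encode(mut bitrate: u64) -> (u8, u8, u8) {
    let mut exp = 0u8;
    while bitrate >= (1 << 18) {
        bitrate >>= 1;
        exp += 1;
    }
    (
        (exp << 2) | (bitrate >> 16) as u8, // 6-bit exponent, top 2 mantissa bits
        (bitrate >> 8) as u8,
        bitrate as u8,
    )
}

fn remb_decode(b17: u8, b18: u8, b19: u8) -> u128 {
    // u128 because a 6-bit exponent can push the value past u64::MAX
    let exp = b17 >> 2;
    let mantissa = ((b17 & 0b11) as u128) << 16 | (b18 as u128) << 8 | b19 as u128;
    mantissa << exp
}

#[test]
fn remb_bitrate_field_sketch() {
    // The Chrome capture in the tests below encodes 8_927_168 b/s as
    // exp = 6, mantissa = 139_487 (bytes 26, 32, 223).
    assert_eq!(remb_encode(8_927_168), (26, 32, 223));
    assert_eq!(remb_decode(26, 32, 223), 8_927_168);
}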
- fn unmarshal(raw_packet: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - let raw_packet_len = raw_packet.remaining(); - // 20 bytes is the size of the packet with no SSRCs - if raw_packet_len < 20 { - return Err(Error::PacketTooShort.into()); - } - - const MANTISSA_MAX: u32 = 0x7FFFFF; - /* - 0 1 2 3 - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - |V=2|P| FMT=15 | PT=206 | length | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | SSRC of packet sender | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | SSRC of media source | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Unique identifier 'R' 'E' 'M' 'B' | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Num SSRC | BR Exp | BR Mantissa | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | SSRC feedback | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | ... | - */ - let header = Header::unmarshal(raw_packet)?; - - if header.packet_type != PacketType::PayloadSpecificFeedback || header.count != FORMAT_REMB - { - return Err(Error::WrongType.into()); - } - - let sender_ssrc = raw_packet.get_u32(); - let media_ssrc = raw_packet.get_u32(); - if media_ssrc != 0 { - return Err(Error::SsrcMustBeZero.into()); - } - - // REMB rules all around me - let mut unique_identifier = [0; 4]; - unique_identifier[0] = raw_packet.get_u8(); - unique_identifier[1] = raw_packet.get_u8(); - unique_identifier[2] = raw_packet.get_u8(); - unique_identifier[3] = raw_packet.get_u8(); - if unique_identifier[0] != UNIQUE_IDENTIFIER[0] - || unique_identifier[1] != UNIQUE_IDENTIFIER[1] - || unique_identifier[2] != UNIQUE_IDENTIFIER[2] - || unique_identifier[3] != UNIQUE_IDENTIFIER[3] - { - return Err(Error::MissingRembIdentifier.into()); - } - - // The next byte is the number of SSRC entries at the end. - let ssrcs_len = raw_packet.get_u8() as usize; - - // Get the 6-bit exponent value. - let b17 = raw_packet.get_u8(); - let mut exp = (b17 as u64) >> 2; - exp += 127; // bias for IEEE754 - exp += 23; // IEEE754 biases the decimal to the left, abs-send-time biases it to the right - - // The remaining 2-bits plus the next 16-bits are the mantissa. 
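// Editor's sketch, not original code: the unmarshal path above avoids float
// math by normalising the mantissa until its bit 23 is set and then building
// the IEEE 754 bit pattern directly, because a binary32 float encodes
// (1.fraction) * 2^(E - 127) with a 23-bit fraction. This standalone version
// shows the same trick; the zero case is handled explicitly here.
fn remb_bits_to_f32(mut mantissa: u32, mut exp: i32) -> f32 {
    if mantissa == 0 {
        return 0.0;
    }
    // shift left until the implicit leading bit (bit 23) is set
    while mantissa & 0x0080_0000 == 0 {
        mantissa <<= 1;
        exp -= 1;
    }
    // biased exponent: +127 for IEEE 754, +23 to cancel the fraction width
    f32::from_bits((((exp + 127 + 23) as u32) << 23) | (mantissa & 0x007F_FFFF))
}

#[test]
fn remb_float_trick_sketch() {
    // exp = 6, mantissa = 139_487, as in the Chrome capture used by the tests.
    assert_eq!(remb_bits_to_f32(139_487, 6), 8_927_168.0);
}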
- let b18 = raw_packet.get_u8(); - let b19 = raw_packet.get_u8(); - let mut mantissa = ((b17 & 3) as u32) << 16 | (b18 as u32) << 8 | b19 as u32; - - if mantissa != 0 { - // ieee754 requires an implicit leading bit - while (mantissa & (MANTISSA_MAX + 1)) == 0 { - exp -= 1; - mantissa *= 2; - } - } - - // bitrate = mantissa * 2^exp - let bitrate = f32::from_bits(((exp as u32) << 23) | (mantissa & MANTISSA_MAX)); - - let mut ssrcs = vec![]; - for _i in 0..ssrcs_len { - ssrcs.push(raw_packet.get_u32()); - } - - if - /*header.padding &&*/ - raw_packet.has_remaining() { - raw_packet.advance(raw_packet.remaining()); - } - - Ok(ReceiverEstimatedMaximumBitrate { - sender_ssrc, - //media_ssrc, - bitrate, - ssrcs, - }) - } -} diff --git a/rtcp/src/payload_feedbacks/receiver_estimated_maximum_bitrate/receiver_estimated_maximum_bitrate_test.rs b/rtcp/src/payload_feedbacks/receiver_estimated_maximum_bitrate/receiver_estimated_maximum_bitrate_test.rs deleted file mode 100644 index 4ea3efb99..000000000 --- a/rtcp/src/payload_feedbacks/receiver_estimated_maximum_bitrate/receiver_estimated_maximum_bitrate_test.rs +++ /dev/null @@ -1,121 +0,0 @@ -use bytes::Bytes; - -use super::*; - -#[test] -fn test_receiver_estimated_maximum_bitrate_marshal() { - let input = ReceiverEstimatedMaximumBitrate { - sender_ssrc: 1, - bitrate: 8927168.0, - ssrcs: vec![1215622422], - }; - - let expected = Bytes::from_static(&[ - 143, 206, 0, 5, 0, 0, 0, 1, 0, 0, 0, 0, 82, 69, 77, 66, 1, 26, 32, 223, 72, 116, 237, 22, - ]); - - let output = input.marshal().unwrap(); - assert_eq!(output, expected); -} - -#[test] -fn test_receiver_estimated_maximum_bitrate_unmarshal() { - // Real data sent by Chrome while watching a 6Mb/s stream - let mut input = Bytes::from_static(&[ - 143, 206, 0, 5, 0, 0, 0, 1, 0, 0, 0, 0, 82, 69, 77, 66, 1, 26, 32, 223, 72, 116, 237, 22, - ]); - - // mantissa = []byte{26 & 3, 32, 223} = []byte{2, 32, 223} = 139487 - // exp = 26 >> 2 = 6 - // bitrate = 139487 * 2^6 = 139487 * 64 = 8927168 = 8.9 Mb/s - let expected = ReceiverEstimatedMaximumBitrate { - sender_ssrc: 1, - bitrate: 8927168.0, - ssrcs: vec![1215622422], - }; - - let packet = ReceiverEstimatedMaximumBitrate::unmarshal(&mut input).unwrap(); - assert_eq!(packet, expected); -} - -#[test] -fn test_receiver_estimated_maximum_bitrate_truncate() { - let input = Bytes::from_static(&[ - 143, 206, 0, 5, 0, 0, 0, 1, 0, 0, 0, 0, 82, 69, 77, 66, 1, 26, 32, 223, 72, 116, 237, 22, - ]); - - // Make sure that we're interpreting the bitrate correctly. - // For the above example, we have: - - // mantissa = 139487 - // exp = 6 - // bitrate = 8927168 - - let mut buf = input.clone(); - let mut packet = ReceiverEstimatedMaximumBitrate::unmarshal(&mut buf).unwrap(); - assert_eq!(packet.bitrate, 8927168.0); - - // Just verify marshal produces the same input. 
- let output = packet.marshal().unwrap(); - assert_eq!(output, input); - - // If we subtract the bitrate by 1, we'll round down a lower mantissa - packet.bitrate -= 1.0; - - // bitrate = 8927167 - // mantissa = 139486 - // exp = 6 - - let mut output = packet.marshal().unwrap(); - assert_ne!(output, input); - let expected = Bytes::from_static(&[ - 143, 206, 0, 5, 0, 0, 0, 1, 0, 0, 0, 0, 82, 69, 77, 66, 1, 26, 32, 222, 72, 116, 237, 22, - ]); - assert_eq!(output, expected); - - // Which if we actually unmarshal again, we'll find that it's actually decreased by 63 (which is exp) - // mantissa = 139486 - // exp = 6 - // bitrate = 8927104 - - let packet = ReceiverEstimatedMaximumBitrate::unmarshal(&mut output).unwrap(); - assert_eq!(8927104.0, packet.bitrate); -} - -#[test] -fn test_receiver_estimated_maximum_bitrate_overflow() { - // Marshal a packet with the maximum possible bitrate. - let packet = ReceiverEstimatedMaximumBitrate { - bitrate: f32::MAX, - ..Default::default() - }; - - // mantissa = 262143 = 0x3FFFF - // exp = 63 - - let expected = Bytes::from_static(&[ - 143, 206, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 82, 69, 77, 66, 0, 255, 255, 255, - ]); - - let output = packet.marshal().unwrap(); - assert_eq!(output, expected); - - // mantissa = 262143 - // exp = 63 - // bitrate = 0xFFFFC00000000000 - - let mut buf = output; - let packet = ReceiverEstimatedMaximumBitrate::unmarshal(&mut buf).unwrap(); - assert_eq!(packet.bitrate, f32::from_bits(0x67FFFFC0)); - - // Make sure we marshal to the same result again. - let output = packet.marshal().unwrap(); - assert_eq!(output, expected); - - // Finally, try unmarshalling one number higher than we used to be able to handle. - let mut input = Bytes::from_static(&[ - 143, 206, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 82, 69, 77, 66, 0, 188, 0, 0, - ]); - let packet = ReceiverEstimatedMaximumBitrate::unmarshal(&mut input).unwrap(); - assert_eq!(packet.bitrate, f32::from_bits(0x62800000)); -} diff --git a/rtcp/src/payload_feedbacks/slice_loss_indication/mod.rs b/rtcp/src/payload_feedbacks/slice_loss_indication/mod.rs deleted file mode 100644 index c424f5c60..000000000 --- a/rtcp/src/payload_feedbacks/slice_loss_indication/mod.rs +++ /dev/null @@ -1,180 +0,0 @@ -#[cfg(test)] -mod slice_loss_indication_test; - -use std::any::Any; -use std::fmt; - -use bytes::{Buf, BufMut}; -use util::marshal::{Marshal, MarshalSize, Unmarshal}; - -use crate::error::Error; -use crate::header::*; -use crate::packet::*; -use crate::util::*; - -type Result = std::result::Result; - -const SLI_LENGTH: usize = 2; -const SLI_OFFSET: usize = 8; - -/// SLIEntry represents a single entry to the SLI packet's -/// list of lost slices. -#[derive(Debug, PartialEq, Eq, Default, Clone)] -pub struct SliEntry { - /// ID of first lost slice - pub first: u16, - /// Number of lost slices - pub number: u16, - /// ID of related picture - pub picture: u8, -} - -/// The SliceLossIndication packet informs the encoder about the loss of a picture slice -#[derive(Debug, PartialEq, Eq, Default, Clone)] -pub struct SliceLossIndication { - /// SSRC of sender - pub sender_ssrc: u32, - /// SSRC of the media source - pub media_ssrc: u32, - - pub sli_entries: Vec, -} - -impl fmt::Display for SliceLossIndication { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "SliceLossIndication {:x} {:x} {:?}", - self.sender_ssrc, self.media_ssrc, self.sli_entries, - ) - } -} - -impl Packet for SliceLossIndication { - /// Header returns the Header associated with this packet. 
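// Editor's sketch (not part of the module): each SLI entry is packed into one
// 32-bit word as a 13-bit index of the first lost unit, a 13-bit count of
// lost units and a 6-bit picture ID, which is the layout `marshal_to` and
// `unmarshal` below implement.
fn pack_sli(first: u16, number: u16, picture: u8) -> u32 {
    ((first as u32 & 0x1FFF) << 19) | ((number as u32 & 0x1FFF) << 6) | (picture as u32 & 0x3F)
}

fn unpack_sli(word: u32) -> (u16, u16, u8) {
    (
        ((word >> 19) & 0x1FFF) as u16,
        ((word >> 6) & 0x1FFF) as u16,
        (word & 0x3F) as u8,
    )
}

#[test]
fn sli_entry_packing_sketch() {
    // Values taken from the round-trip test in slice_loss_indication_test.rs.
    let word = pack_sli(1034, 0x05, 0x6);
    assert_eq!(unpack_sli(word), (1034, 0x05, 0x6));
}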
- fn header(&self) -> Header { - Header { - padding: get_padding_size(self.raw_size()) != 0, - count: FORMAT_SLI, - packet_type: PacketType::TransportSpecificFeedback, - length: ((self.marshal_size() / 4) - 1) as u16, - } - } - - /// destination_ssrc returns an array of SSRC values that this packet refers to. - fn destination_ssrc(&self) -> Vec { - vec![self.media_ssrc] - } - - fn raw_size(&self) -> usize { - HEADER_LENGTH + SLI_OFFSET + self.sli_entries.len() * 4 - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } - - fn equal(&self, other: &(dyn Packet + Send + Sync)) -> bool { - other - .as_any() - .downcast_ref::() - .map_or(false, |a| self == a) - } - - fn cloned(&self) -> Box { - Box::new(self.clone()) - } -} - -impl MarshalSize for SliceLossIndication { - fn marshal_size(&self) -> usize { - let l = self.raw_size(); - // align to 32-bit boundary - l + get_padding_size(l) - } -} - -impl Marshal for SliceLossIndication { - /// Marshal encodes the SliceLossIndication in binary - fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - if (self.sli_entries.len() + SLI_LENGTH) as u8 > u8::MAX { - return Err(Error::TooManyReports.into()); - } - if buf.remaining_mut() < self.marshal_size() { - return Err(Error::BufferTooShort.into()); - } - - let h = self.header(); - let n = h.marshal_to(buf)?; - buf = &mut buf[n..]; - - buf.put_u32(self.sender_ssrc); - buf.put_u32(self.media_ssrc); - - for s in &self.sli_entries { - let sli = ((s.first as u32 & 0x1FFF) << 19) - | ((s.number as u32 & 0x1FFF) << 6) - | (s.picture as u32 & 0x3F); - - buf.put_u32(sli); - } - - if h.padding { - put_padding(buf, self.raw_size()); - } - - Ok(self.marshal_size()) - } -} - -impl Unmarshal for SliceLossIndication { - /// Unmarshal decodes the SliceLossIndication from binary - fn unmarshal(raw_packet: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - let raw_packet_len = raw_packet.remaining(); - if raw_packet_len < (HEADER_LENGTH + SSRC_LENGTH) { - return Err(Error::PacketTooShort.into()); - } - - let h = Header::unmarshal(raw_packet)?; - - if raw_packet_len < (HEADER_LENGTH + (4 * h.length as usize)) { - return Err(Error::PacketTooShort.into()); - } - - if h.packet_type != PacketType::TransportSpecificFeedback || h.count != FORMAT_SLI { - return Err(Error::WrongType.into()); - } - - let sender_ssrc = raw_packet.get_u32(); - let media_ssrc = raw_packet.get_u32(); - - let mut i = HEADER_LENGTH + SLI_OFFSET; - let mut sli_entries = vec![]; - while i < HEADER_LENGTH + h.length as usize * 4 { - let sli = raw_packet.get_u32(); - sli_entries.push(SliEntry { - first: ((sli >> 19) & 0x1FFF) as u16, - number: ((sli >> 6) & 0x1FFF) as u16, - picture: (sli & 0x3F) as u8, - }); - - i += 4; - } - - if - /*h.padding &&*/ - raw_packet.has_remaining() { - raw_packet.advance(raw_packet.remaining()); - } - - Ok(SliceLossIndication { - sender_ssrc, - media_ssrc, - sli_entries, - }) - } -} diff --git a/rtcp/src/payload_feedbacks/slice_loss_indication/slice_loss_indication_test.rs b/rtcp/src/payload_feedbacks/slice_loss_indication/slice_loss_indication_test.rs deleted file mode 100644 index c60b3be96..000000000 --- a/rtcp/src/payload_feedbacks/slice_loss_indication/slice_loss_indication_test.rs +++ /dev/null @@ -1,135 +0,0 @@ -use bytes::Bytes; - -use super::*; - -#[test] -fn test_slice_loss_indication_unmarshal() { - let tests = vec![ - ( - "valid", - Bytes::from_static(&[ - 0x82u8, 0xcd, 0x0, 0x3, // SliceLossIndication - 0x90, 0x2f, 0x9e, 0x2e, // sender=0x902f9e2e - 0x90, 0x2f, 0x9e, 0x2e, // media=0x902f9e2e - 
0x55, 0x50, 0x00, 0x2C, // nack 0xAAAA, 0x5555 - ]), - SliceLossIndication { - sender_ssrc: 0x902f9e2e, - media_ssrc: 0x902f9e2e, - sli_entries: vec![SliEntry { - first: 0xaaa, - number: 0, - picture: 0x2C, - }], - }, - None, - ), - ( - "short report", - Bytes::from_static(&[ - 0x82, 0xcd, 0x0, 0x2, // ssrc=0x902f9e2e - 0x90, 0x2f, 0x9e, 0x2e, - // report ends early - ]), - SliceLossIndication::default(), - Some(Error::PacketTooShort), - ), - ( - "wrong type", - Bytes::from_static(&[ - // v=2, p=0, count=1, SR, len=7 - 0x81, 0xc8, 0x0, 0x7, // ssrc=0x902f9e2e - 0x90, 0x2f, 0x9e, 0x2e, // ssrc=0xbc5e9a40 - 0xbc, 0x5e, 0x9a, 0x40, // fracLost=0, totalLost=0 - 0x0, 0x0, 0x0, 0x0, // lastSeq=0x46e1 - 0x0, 0x0, 0x46, 0xe1, // jitter=273 - 0x0, 0x0, 0x1, 0x11, // lsr=0x9f36432 - 0x9, 0xf3, 0x64, 0x32, // delay=150137 - 0x0, 0x2, 0x4a, 0x79, - ]), - SliceLossIndication::default(), - Some(Error::WrongType), - ), - ( - "nil", - Bytes::from_static(&[]), - SliceLossIndication::default(), - Some(Error::PacketTooShort), - ), - ]; - - for (name, mut data, want, want_error) in tests { - let got = SliceLossIndication::unmarshal(&mut data); - - assert_eq!( - got.is_err(), - want_error.is_some(), - "Unmarshal {name} rr: err = {got:?}, want {want_error:?}" - ); - - if let Some(err) = want_error { - let got_err = got.err().unwrap(); - assert_eq!( - err, got_err, - "Unmarshal {name} rr: err = {got_err:?}, want {err:?}", - ); - } else { - let actual = got.unwrap(); - assert_eq!( - actual, want, - "Unmarshal {name} rr: got {actual:?}, want {want:?}" - ); - } - } -} - -#[test] -fn test_slice_loss_indication_roundtrip() { - let tests: Vec<(&str, SliceLossIndication, Option)> = vec![( - "valid", - SliceLossIndication { - sender_ssrc: 0x902f9e2e, - media_ssrc: 0x902f9e2e, - sli_entries: vec![ - SliEntry { - first: 1, - number: 0xAA, - picture: 0x1F, - }, - SliEntry { - first: 1034, - number: 0x05, - picture: 0x6, - }, - ], - }, - None, - )]; - - for (name, want, want_error) in tests { - let got = want.marshal(); - - assert_eq!( - got.is_ok(), - want_error.is_none(), - "Marshal {name}: err = {got:?}, want {want_error:?}" - ); - - if let Some(err) = want_error { - let got_err = got.err().unwrap(); - assert_eq!( - err, got_err, - "Unmarshal {name} rr: err = {got_err:?}, want {err:?}", - ); - } else { - let mut data = got.ok().unwrap(); - let actual = SliceLossIndication::unmarshal(&mut data) - .unwrap_or_else(|_| panic!("Unmarshal {name}")); - - assert_eq!( - actual, want, - "{name} round trip: got {actual:?}, want {want:?}" - ) - } - } -} diff --git a/rtcp/src/raw_packet.rs b/rtcp/src/raw_packet.rs deleted file mode 100644 index 68b8cd870..000000000 --- a/rtcp/src/raw_packet.rs +++ /dev/null @@ -1,168 +0,0 @@ -use std::any::Any; -use std::fmt; - -use bytes::{Buf, BufMut, Bytes, BytesMut}; -use util::marshal::{Marshal, MarshalSize, Unmarshal}; - -use crate::error::Error; -use crate::header::*; -use crate::packet::Packet; -use crate::util::*; - -/// RawPacket represents an unparsed RTCP packet. It's returned by Unmarshal when -/// a packet with an unknown type is encountered. -#[derive(Debug, PartialEq, Eq, Default, Clone)] -pub struct RawPacket(pub Bytes); - -impl fmt::Display for RawPacket { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "RawPacket: {self:?}") - } -} - -impl Packet for RawPacket { - /// Header returns the Header associated with this packet. 
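// Editor's usage sketch, not from the original sources: `unmarshal` hands back
// boxed `dyn Packet` trait objects, and RawPacket is the fallback for packet
// types the parser does not recognise. Callers can recover concrete types via
// `as_any()` downcasting; the paths below assume `packet` and `raw_packet`
// are public modules of the `rtcp` crate in this workspace.
use rtcp::packet::Packet;
use rtcp::raw_packet::RawPacket;

fn is_unknown_packet_type(pkt: &(dyn Packet + Send + Sync)) -> bool {
    pkt.as_any().downcast_ref::<RawPacket>().is_some()
}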
- fn header(&self) -> Header { - match Header::unmarshal(&mut self.0.clone()) { - Ok(h) => h, - Err(_) => Header::default(), - } - } - - /// destination_ssrc returns an array of SSRC values that this packet refers to. - fn destination_ssrc(&self) -> Vec { - vec![] - } - - fn raw_size(&self) -> usize { - self.0.len() - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } - - fn equal(&self, other: &(dyn Packet + Send + Sync)) -> bool { - other - .as_any() - .downcast_ref::() - .map_or(false, |a| self == a) - } - - fn cloned(&self) -> Box { - Box::new(self.clone()) - } -} - -impl MarshalSize for RawPacket { - fn marshal_size(&self) -> usize { - let l = self.raw_size(); - // align to 32-bit boundary - l + get_padding_size(l) - } -} - -impl Marshal for RawPacket { - /// Marshal encodes the packet in binary. - fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - let h = Header::unmarshal(&mut self.0.clone())?; - buf.put(self.0.clone()); - if h.padding { - put_padding(buf, self.raw_size()); - } - Ok(self.marshal_size()) - } -} - -impl Unmarshal for RawPacket { - /// Unmarshal decodes the packet from binary. - fn unmarshal(raw_packet: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - let raw_packet_len = raw_packet.remaining(); - if raw_packet_len < HEADER_LENGTH { - return Err(Error::PacketTooShort.into()); - } - - let h = Header::unmarshal(raw_packet)?; - - let raw_hdr = h.marshal()?; - let raw_body = raw_packet.copy_to_bytes(raw_packet.remaining()); - let mut raw = BytesMut::new(); - raw.extend(raw_hdr); - raw.extend(raw_body); - - Ok(RawPacket(raw.freeze())) - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_raw_packet_roundtrip() -> Result<(), Error> { - let tests: Vec<(&str, RawPacket, Option)> = vec![ - ( - "valid", - RawPacket(Bytes::from_static(&[ - 0x81, 0xcb, 0x00, 0x0c, // v=2, p=0, count=1, BYE, len=12 - 0x90, 0x2f, 0x9e, 0x2e, // ssrc=0x902f9e2e - 0x03, 0x46, 0x4f, 0x4f, // len=3, text=FOO - ])), - None, - ), - ( - "short header", - RawPacket(Bytes::from_static(&[0x80])), - Some(Error::PacketTooShort), - ), - ( - "invalid header", - RawPacket( - // v=0, p=0, count=0, RR, len=4 - Bytes::from_static(&[0x00, 0xc9, 0x00, 0x04]), - ), - Some(Error::BadVersion), - ), - ]; - - for (name, pkt, unmarshal_error) in tests { - let result = pkt.marshal(); - assert_eq!( - result.is_err(), - unmarshal_error.is_some(), - "Unmarshal {name}: err = {result:?}, want {unmarshal_error:?}" - ); - - if result.is_err() { - continue; - } - - let mut data = result.unwrap(); - - let result = RawPacket::unmarshal(&mut data); - - assert_eq!( - result.is_err(), - unmarshal_error.is_some(), - "Unmarshal {name}: err = {result:?}, want {unmarshal_error:?}" - ); - - if result.is_err() { - continue; - } - - let decoded = result.unwrap(); - - assert_eq!( - decoded, pkt, - "{name} raw round trip: got {decoded:?}, want {pkt:?}" - ) - } - - Ok(()) - } -} diff --git a/rtcp/src/receiver_report/mod.rs b/rtcp/src/receiver_report/mod.rs deleted file mode 100644 index 369c5d2c1..000000000 --- a/rtcp/src/receiver_report/mod.rs +++ /dev/null @@ -1,230 +0,0 @@ -#[cfg(test)] -mod receiver_report_test; - -use std::any::Any; -use std::fmt; - -use bytes::{Buf, BufMut, Bytes}; -use util::marshal::{Marshal, MarshalSize, Unmarshal}; - -use crate::error::Error; -use crate::header::*; -use crate::packet::*; -use crate::reception_report::*; -use crate::util::*; - -type Result = std::result::Result; - -pub(super) const RR_SSRC_OFFSET: usize = HEADER_LENGTH; -pub(super) const RR_REPORT_OFFSET: usize = 
RR_SSRC_OFFSET + SSRC_LENGTH; - -/// A ReceiverReport (RR) packet provides reception quality feedback for an RTP stream -#[derive(Debug, PartialEq, Eq, Default, Clone)] -pub struct ReceiverReport { - /// The synchronization source identifier for the originator of this RR packet. - pub ssrc: u32, - /// Zero or more reception report blocks depending on the number of other - /// sources heard by this sender since the last report. Each reception report - /// block conveys statistics on the reception of RTP packets from a - /// single synchronization source. - pub reports: Vec, - /// Extension contains additional, payload-specific information that needs to - /// be reported regularly about the receiver. - pub profile_extensions: Bytes, -} - -impl fmt::Display for ReceiverReport { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut out = format!("ReceiverReport from {}\n", self.ssrc); - out += "\tSSRC \tLost\tLastSequence\n"; - for rep in &self.reports { - out += format!( - "\t{:x}\t{}/{}\t{}\n", - rep.ssrc, rep.fraction_lost, rep.total_lost, rep.last_sequence_number - ) - .as_str(); - } - out += format!("\tProfile Extension Data: {:?}\n", self.profile_extensions).as_str(); - - write!(f, "{out}") - } -} - -impl Packet for ReceiverReport { - /// Header returns the Header associated with this packet. - fn header(&self) -> Header { - Header { - padding: get_padding_size(self.raw_size()) != 0, - count: self.reports.len() as u8, - packet_type: PacketType::ReceiverReport, - length: ((self.marshal_size() / 4) - 1) as u16, - } - } - - /// destination_ssrc returns an array of SSRC values that this packet refers to. - fn destination_ssrc(&self) -> Vec { - self.reports.iter().map(|x| x.ssrc).collect() - } - - fn raw_size(&self) -> usize { - let mut reps_length = 0; - for rep in &self.reports { - reps_length += rep.marshal_size(); - } - - HEADER_LENGTH + SSRC_LENGTH + reps_length + self.profile_extensions.len() - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } - - fn equal(&self, other: &(dyn Packet + Send + Sync)) -> bool { - other - .as_any() - .downcast_ref::() - .map_or(false, |a| self == a) - } - - fn cloned(&self) -> Box { - Box::new(self.clone()) - } -} - -impl MarshalSize for ReceiverReport { - fn marshal_size(&self) -> usize { - let l = self.raw_size(); - // align to 32-bit boundary - l + get_padding_size(l) - } -} - -impl Marshal for ReceiverReport { - /// marshal_to encodes the packet in binary. 
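// Editor's sketch (illustrative only): a receiver report is the 4-byte common
// header, the reporter's SSRC, one 24-byte reception report block per source
// (RECEPTION_REPORT_LENGTH), plus any trailing profile-specific extension
// bytes, which is what `raw_size` above adds up.
fn rr_wire_size(report_blocks: usize, extension_bytes: usize) -> usize {
    4 + 4 + report_blocks * 24 + extension_bytes
}

#[test]
fn rr_size_matches_test_vectors() {
    // The "valid" vector in receiver_report_test.rs is one block and no
    // extension: 32 bytes, so its length field is 32 / 4 - 1 = 7.
    assert_eq!(rr_wire_size(1, 0), 32);
    assert_eq!(rr_wire_size(1, 0) / 4 - 1, 7);
    // The "valid with extension data" vector adds 8 extension bytes (len = 9).
    assert_eq!(rr_wire_size(1, 8) / 4 - 1, 9);
}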
- fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - if self.reports.len() > COUNT_MAX { - return Err(Error::TooManyReports.into()); - } - - if buf.remaining_mut() < self.marshal_size() { - return Err(Error::BufferTooShort.into()); - } - - /* - * 0 1 2 3 - * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * header |V=2|P| RC | PT=RR=201 | length | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | SSRC of packet sender | - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - * report | SSRC_1 (SSRC of first source) | - * block +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * 1 | fraction lost | cumulative number of packets lost | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | extended highest sequence number received | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | interarrival jitter | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | last SR (LSR) | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | delay since last SR (DLSR) | - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - * report | SSRC_2 (SSRC of second source) | - * block +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * 2 : ... : - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - * | profile-specific extensions | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - */ - let h = self.header(); - let n = h.marshal_to(buf)?; - buf = &mut buf[n..]; - - buf.put_u32(self.ssrc); - - for report in &self.reports { - let n = report.marshal_to(buf)?; - buf = &mut buf[n..]; - } - - buf.put(self.profile_extensions.clone()); - - if h.padding { - put_padding(buf, self.raw_size()); - } - - Ok(self.marshal_size()) - } -} - -impl Unmarshal for ReceiverReport { - /// Unmarshal decodes the ReceiverReport from binary - fn unmarshal(raw_packet: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - /* - * 0 1 2 3 - * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * header |V=2|P| RC | PT=RR=201 | length | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | SSRC of packet sender | - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - * report | SSRC_1 (SSRC of first source) | - * block +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * 1 | fraction lost | cumulative number of packets lost | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | extended highest sequence number received | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | interarrival jitter | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | last SR (LSR) | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | delay since last SR (DLSR) | - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - * report | SSRC_2 (SSRC of second source) | - * block +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * 2 : ... 
: - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - * | profile-specific extensions | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - */ - let raw_packet_len = raw_packet.remaining(); - if raw_packet_len < (HEADER_LENGTH + SSRC_LENGTH) { - return Err(Error::PacketTooShort.into()); - } - - let header = Header::unmarshal(raw_packet)?; - if header.packet_type != PacketType::ReceiverReport { - return Err(Error::WrongType.into()); - } - - let ssrc = raw_packet.get_u32(); - - let mut offset = RR_REPORT_OFFSET; - let mut reports = Vec::with_capacity(header.count as usize); - for _ in 0..header.count { - if offset + RECEPTION_REPORT_LENGTH > raw_packet_len { - return Err(Error::PacketTooShort.into()); - } - let reception_report = ReceptionReport::unmarshal(raw_packet)?; - reports.push(reception_report); - offset += RECEPTION_REPORT_LENGTH; - } - let profile_extensions = raw_packet.copy_to_bytes(raw_packet.remaining()); - /* - if header.padding && raw_packet.has_remaining() { - raw_packet.advance(raw_packet.remaining()); - } - */ - - Ok(ReceiverReport { - ssrc, - reports, - profile_extensions, - }) - } -} diff --git a/rtcp/src/receiver_report/receiver_report_test.rs b/rtcp/src/receiver_report/receiver_report_test.rs deleted file mode 100644 index 22450a54a..000000000 --- a/rtcp/src/receiver_report/receiver_report_test.rs +++ /dev/null @@ -1,242 +0,0 @@ -use super::*; - -#[test] -fn test_receiver_report_unmarshal() { - let tests = vec![ - ( - "valid", - Bytes::from_static(&[ - 0x81u8, 0xc9, 0x0, 0x7, // v=2, p=0, count=1, RR, len=7 - 0x90, 0x2f, 0x9e, 0x2e, // ssrc=0x902f9e2e - 0xbc, 0x5e, 0x9a, 0x40, // ssrc=0xbc5e9a40 - 0x0, 0x0, 0x0, 0x0, // fracLost=0, totalLost=0 - 0x0, 0x0, 0x46, 0xe1, // lastSeq=0x46e1 - 0x0, 0x0, 0x1, 0x11, // jitter=273 - 0x9, 0xf3, 0x64, 0x32, // lsr=0x9f36432 - 0x0, 0x2, 0x4a, 0x79, // delay=150137 - ]), - ReceiverReport { - ssrc: 0x902f9e2e, - reports: vec![ReceptionReport { - ssrc: 0xbc5e9a40, - fraction_lost: 0, - total_lost: 0, - last_sequence_number: 0x46e1, - jitter: 273, - last_sender_report: 0x9f36432, - delay: 150137, - }], - profile_extensions: Bytes::new(), - }, - None, - ), - ( - "valid with extension data", - Bytes::from_static(&[ - 0x81, 0xc9, 0x0, 0x9, // v=2, p=0, count=1, RR, len=9 - 0x90, 0x2f, 0x9e, 0x2e, // ssrc=0x902f9e2e - 0xbc, 0x5e, 0x9a, 0x40, // ssrc=0xbc5e9a40 - 0x0, 0x0, 0x0, 0x0, // fracLost=0, totalLost=0 - 0x0, 0x0, 0x46, 0xe1, // lastSeq=0x46e1 - 0x0, 0x0, 0x1, 0x11, // jitter=273 - 0x9, 0xf3, 0x64, 0x32, // lsr=0x9f36432 - 0x0, 0x2, 0x4a, 0x79, // delay=150137 - 0x54, 0x45, 0x53, 0x54, 0x44, 0x41, 0x54, - 0x41, // profile-specific extension data - ]), - ReceiverReport { - ssrc: 0x902f9e2e, - reports: vec![ReceptionReport { - ssrc: 0xbc5e9a40, - fraction_lost: 0, - total_lost: 0, - last_sequence_number: 0x46e1, - jitter: 273, - last_sender_report: 0x9f36432, - delay: 150137, - }], - profile_extensions: Bytes::from_static(&[ - 0x54, 0x45, 0x53, 0x54, 0x44, 0x41, 0x54, 0x41, - ]), - }, - None, - ), - ( - "short report", - Bytes::from_static(&[ - 0x81, 0xc9, 0x00, 0x0c, // v=2, p=0, count=1, RR, len=7 - 0x90, 0x2f, 0x9e, 0x2e, // ssrc=0x902f9e2e - 0x00, 0x00, 0x00, - 0x00, // fracLost=0, totalLost=0 - // report ends early - ]), - ReceiverReport::default(), - Some(Error::PacketTooShort), - ), - ( - "wrong type", - Bytes::from_static(&[ - // v=2, p=0, count=1, SR, len=7 - 0x81, 0xc8, 0x0, 0x7, // ssrc=0x902f9e2e - 0x90, 0x2f, 0x9e, 0x2e, // ssrc=0xbc5e9a40 - 0xbc, 0x5e, 0x9a, 0x40, // fracLost=0, 
totalLost=0 - 0x0, 0x0, 0x0, 0x0, // lastSeq=0x46e1 - 0x0, 0x0, 0x46, 0xe1, // jitter=273 - 0x0, 0x0, 0x1, 0x11, // lsr=0x9f36432 - 0x9, 0xf3, 0x64, 0x32, // delay=150137 - 0x0, 0x2, 0x4a, 0x79, - ]), - ReceiverReport::default(), - Some(Error::WrongType), - ), - ( - "bad count in header", - Bytes::from_static(&[ - 0x82, 0xc9, 0x0, 0x7, // v=2, p=0, count=2, RR, len=7 - 0x90, 0x2f, 0x9e, 0x2e, // ssrc=0x902f9e2e - 0xbc, 0x5e, 0x9a, 0x40, // ssrc=0xbc5e9a40 - 0x0, 0x0, 0x0, 0x0, // fracLost=0, totalLost=0 - 0x0, 0x0, 0x46, 0xe1, // lastSeq=0x46e1 - 0x0, 0x0, 0x1, 0x11, // jitter=273 - 0x9, 0xf3, 0x64, 0x32, // lsr=0x9f36432 - 0x0, 0x2, 0x4a, 0x79, // delay=150137 - ]), - ReceiverReport::default(), - Some(Error::PacketTooShort), - ), - ( - "nil", - Bytes::from_static(&[]), - ReceiverReport::default(), - Some(Error::PacketTooShort), - ), - ]; - - for (name, mut data, want, want_error) in tests { - let got = ReceiverReport::unmarshal(&mut data); - - assert_eq!( - got.is_err(), - want_error.is_some(), - "Unmarshal {name}: err = {got:?}, want {want_error:?}" - ); - - if let Some(err) = want_error { - let got_err = got.err().unwrap(); - assert_eq!( - err, got_err, - "Unmarshal {name}: err = {got_err:?}, want {err:?}", - ); - } else { - let actual = got.unwrap(); - assert_eq!( - actual, want, - "Unmarshal {name}: got {actual:?}, want {want:?}" - ); - } - } -} - -#[test] -fn test_receiver_report_roundtrip() { - let mut too_many_reports = vec![]; - for _i in 0..(1 << 5) { - too_many_reports.push(ReceptionReport { - ssrc: 2, - fraction_lost: 2, - total_lost: 3, - last_sequence_number: 4, - jitter: 5, - last_sender_report: 6, - delay: 7, - }); - } - - let tests = vec![ - ( - "valid", - ReceiverReport { - ssrc: 1, - reports: vec![ - ReceptionReport { - ssrc: 2, - fraction_lost: 2, - total_lost: 3, - last_sequence_number: 4, - jitter: 5, - last_sender_report: 6, - delay: 7, - }, - ReceptionReport::default(), - ], - profile_extensions: Bytes::from_static(&[]), - }, - None, - ), - ( - "also valid", - ReceiverReport { - ssrc: 2, - reports: vec![ReceptionReport { - ssrc: 999, - fraction_lost: 30, - total_lost: 12345, - last_sequence_number: 99, - jitter: 22, - last_sender_report: 92, - delay: 46, - }], - ..Default::default() - }, - None, - ), - ( - "totallost overflow", - ReceiverReport { - ssrc: 1, - reports: vec![ReceptionReport { - total_lost: 1 << 25, - ..Default::default() - }], - ..Default::default() - }, - Some(Error::InvalidTotalLost), - ), - ( - "count overflow", - ReceiverReport { - ssrc: 1, - reports: too_many_reports, - ..Default::default() - }, - Some(Error::TooManyReports), - ), - ]; - - for (name, want, want_error) in tests { - let got = want.marshal(); - - assert_eq!( - got.is_ok(), - want_error.is_none(), - "Marshal {name}: err = {got:?}, want {want_error:?}" - ); - - if let Some(err) = want_error { - let got_err = got.err().unwrap(); - assert_eq!( - err, got_err, - "Unmarshal {name} rr: err = {got_err:?}, want {err:?}", - ); - } else { - let mut data = got.ok().unwrap(); - let actual = - ReceiverReport::unmarshal(&mut data).unwrap_or_else(|_| panic!("Unmarshal {name}")); - - assert_eq!( - actual, want, - "{name} round trip: got {actual:?}, want {want:?}" - ) - } - } -} diff --git a/rtcp/src/reception_report.rs b/rtcp/src/reception_report.rs deleted file mode 100644 index b9c49f2ad..000000000 --- a/rtcp/src/reception_report.rs +++ /dev/null @@ -1,206 +0,0 @@ -use std::any::Any; -use std::fmt; - -use bytes::{Buf, BufMut}; -use util::marshal::{Marshal, MarshalSize, Unmarshal}; - -use 
crate::error::Error; -use crate::header::*; -use crate::packet::*; -use crate::util::*; - -pub(crate) const RECEPTION_REPORT_LENGTH: usize = 24; -pub(crate) const FRACTION_LOST_OFFSET: usize = 4; -pub(crate) const TOTAL_LOST_OFFSET: usize = 5; -pub(crate) const LAST_SEQ_OFFSET: usize = 8; -pub(crate) const JITTER_OFFSET: usize = 12; -pub(crate) const LAST_SR_OFFSET: usize = 16; -pub(crate) const DELAY_OFFSET: usize = 20; - -/// A ReceptionReport block conveys statistics on the reception of RTP packets -/// from a single synchronization source. -#[derive(Debug, PartialEq, Eq, Default, Clone)] -pub struct ReceptionReport { - /// The SSRC identifier of the source to which the information in this - /// reception report block pertains. - pub ssrc: u32, - /// The fraction of RTP data packets from source SSRC lost since the - /// previous SR or RR packet was sent, expressed as a fixed point - /// number with the binary point at the left edge of the field. - pub fraction_lost: u8, - /// The total number of RTP data packets from source SSRC that have - /// been lost since the beginning of reception. - pub total_lost: u32, - /// The least significant 16 bits contain the highest sequence number received - /// in an RTP data packet from source SSRC, and the most significant 16 bits extend - /// that sequence number with the corresponding count of sequence number cycles. - pub last_sequence_number: u32, - /// An estimate of the statistical variance of the RTP data packet - /// interarrival time, measured in timestamp units and expressed as an - /// unsigned integer. - pub jitter: u32, - /// The middle 32 bits out of 64 in the NTP timestamp received as part of - /// the most recent RTCP sender report (SR) packet from source SSRC. If no - /// SR has been received yet, the field is set to zero. - pub last_sender_report: u32, - /// The delay, expressed in units of 1/65536 seconds, between receiving the - /// last SR packet from source SSRC and sending this reception report block. - /// If no SR packet has been received yet from SSRC, the field is set to zero. 
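// Editor's sketch, not part of the original module: the LSR and DLSR fields
// documented above exist so the sender can estimate round-trip time
// (RFC 3550, section 6.4.1): RTT = arrival time - LSR - DLSR, where every
// term is the middle 32 bits of an NTP timestamp, i.e. units of 1/65536 s.
fn rtt_seconds(arrival_ntp_mid32: u32, last_sender_report: u32, delay: u32) -> f64 {
    let rtt_units = arrival_ntp_mid32
        .wrapping_sub(last_sender_report)
        .wrapping_sub(delay);
    rtt_units as f64 / 65_536.0
}

#[test]
fn rtt_sketch() {
    let lsr = 0x9f3_6432; // the LSR value used throughout the tests in this crate
    let dlsr = 150_137;
    // If the report arrives exactly 65_536 units (one second) later, RTT = 1 s.
    assert_eq!(rtt_seconds(lsr + dlsr + 65_536, lsr, dlsr), 1.0);
}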
- pub delay: u32, -} - -impl fmt::Display for ReceptionReport { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{self:?}") - } -} - -impl Packet for ReceptionReport { - fn header(&self) -> Header { - Header::default() - } - - fn destination_ssrc(&self) -> Vec<u32> { - vec![] - } - - fn raw_size(&self) -> usize { - RECEPTION_REPORT_LENGTH - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } - - fn equal(&self, other: &(dyn Packet + Send + Sync)) -> bool { - other - .as_any() - .downcast_ref::<ReceptionReport>() - .map_or(false, |a| self == a) - } - - fn cloned(&self) -> Box<dyn Packet + Send + Sync> { - Box::new(self.clone()) - } -} - -impl MarshalSize for ReceptionReport { - fn marshal_size(&self) -> usize { - let l = self.raw_size(); - // align to 32-bit boundary - l + get_padding_size(l) - } -} - -impl Marshal for ReceptionReport { - /// marshal_to encodes the ReceptionReport in binary - fn marshal_to(&self, mut buf: &mut [u8]) -> Result<usize> { - /* - * 0 1 2 3 - * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - * | SSRC | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | fraction lost | cumulative number of packets lost | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | extended highest sequence number received | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | interarrival jitter | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | last SR (LSR) | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | delay since last SR (DLSR) | - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - */ - if buf.remaining_mut() < self.marshal_size() { - return Err(Error::BufferTooShort.into()); - } - - buf.put_u32(self.ssrc); - - buf.put_u8(self.fraction_lost); - - // pack TotalLost into 24 bits - if self.total_lost >= (1 << 25) { - return Err(Error::InvalidTotalLost.into()); - } - - buf.put_u8(((self.total_lost >> 16) & 0xFF) as u8); - buf.put_u8(((self.total_lost >> 8) & 0xFF) as u8); - buf.put_u8((self.total_lost & 0xFF) as u8); - - buf.put_u32(self.last_sequence_number); - buf.put_u32(self.jitter); - buf.put_u32(self.last_sender_report); - buf.put_u32(self.delay); - - put_padding(buf, self.raw_size()); - - Ok(self.marshal_size()) - } -} - -impl Unmarshal for ReceptionReport { - /// unmarshal decodes the ReceptionReport from binary - fn unmarshal<B>(raw_packet: &mut B) -> Result<Self> - where - Self: Sized, - B: Buf, - { - let raw_packet_len = raw_packet.remaining(); - if raw_packet_len < RECEPTION_REPORT_LENGTH { - return Err(Error::PacketTooShort.into()); - } - - /* - * 0 1 2 3 - * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - * | SSRC | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | fraction lost | cumulative number of packets lost | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | extended highest sequence number received | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | interarrival jitter | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | last SR (LSR) | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | delay since last SR (DLSR) | - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - */ - let ssrc = raw_packet.get_u32(); - let fraction_lost =
raw_packet.get_u8(); - - let t0 = raw_packet.get_u8(); - let t1 = raw_packet.get_u8(); - let t2 = raw_packet.get_u8(); - // TODO: The type of `total_lost` should be `i32`, per the RFC: - // The total number of RTP data packets from source SSRC_n that have - // been lost since the beginning of reception. This number is - // defined to be the number of packets expected less the number of - // packets actually received, where the number of packets received - // includes any which are late or duplicates. Thus, packets that - // arrive late are not counted as lost, and the loss may be negative - // if there are duplicates. The number of packets expected is - // defined to be the extended last sequence number received, as - // defined next, less the initial sequence number received. This may - // be calculated as shown in Appendix A.3. - let total_lost = (t2 as u32) | (t1 as u32) << 8 | (t0 as u32) << 16; - - let last_sequence_number = raw_packet.get_u32(); - let jitter = raw_packet.get_u32(); - let last_sender_report = raw_packet.get_u32(); - let delay = raw_packet.get_u32(); - - Ok(ReceptionReport { - ssrc, - fraction_lost, - total_lost, - last_sequence_number, - jitter, - last_sender_report, - delay, - }) - } -} diff --git a/rtcp/src/sender_report/mod.rs b/rtcp/src/sender_report/mod.rs deleted file mode 100644 index 1e66d9b83..000000000 --- a/rtcp/src/sender_report/mod.rs +++ /dev/null @@ -1,299 +0,0 @@ -#[cfg(test)] -mod sender_report_test; - -use std::any::Any; -use std::fmt; - -use bytes::{Buf, BufMut, Bytes}; -use util::marshal::{Marshal, MarshalSize, Unmarshal}; - -use crate::error::Error; -use crate::header::*; -use crate::packet::*; -use crate::reception_report::*; -use crate::util::*; - -type Result = std::result::Result; - -pub(crate) const SR_HEADER_LENGTH: usize = 24; -pub(crate) const SR_SSRC_OFFSET: usize = HEADER_LENGTH; -pub(crate) const SR_REPORT_OFFSET: usize = SR_SSRC_OFFSET + SR_HEADER_LENGTH; - -pub(crate) const SR_NTP_OFFSET: usize = SR_SSRC_OFFSET + SSRC_LENGTH; -pub(crate) const NTP_TIME_LENGTH: usize = 8; -pub(crate) const SR_RTP_OFFSET: usize = SR_NTP_OFFSET + NTP_TIME_LENGTH; -pub(crate) const RTP_TIME_LENGTH: usize = 4; -pub(crate) const SR_PACKET_COUNT_OFFSET: usize = SR_RTP_OFFSET + RTP_TIME_LENGTH; -pub(crate) const SR_PACKET_COUNT_LENGTH: usize = 4; -pub(crate) const SR_OCTET_COUNT_OFFSET: usize = SR_PACKET_COUNT_OFFSET + SR_PACKET_COUNT_LENGTH; -pub(crate) const SR_OCTET_COUNT_LENGTH: usize = 4; - -/// A SenderReport (SR) packet provides reception quality feedback for an RTP stream -#[derive(Debug, PartialEq, Eq, Default, Clone)] -pub struct SenderReport { - /// The synchronization source identifier for the originator of this SR packet. - pub ssrc: u32, - /// The wallclock time when this report was sent so that it may be used in - /// combination with timestamps returned in reception reports from other - /// receivers to measure round-trip propagation to those receivers. - pub ntp_time: u64, - /// Corresponds to the same time as the NTP timestamp (above), but in - /// the same units and with the same random offset as the RTP - /// timestamps in data packets. This correspondence may be used for - /// intra- and inter-media synchronization for sources whose NTP - /// timestamps are synchronized, and may be used by media-independent - /// receivers to estimate the nominal RTP clock frequency. 
- pub rtp_time: u32, - /// The total number of RTP data packets transmitted by the sender - /// since starting transmission up until the time this SR packet was - /// generated. - pub packet_count: u32, - /// The total number of payload octets (i.e., not including header or - /// padding) transmitted in RTP data packets by the sender since - /// starting transmission up until the time this SR packet was - /// generated. - pub octet_count: u32, - /// Zero or more reception report blocks depending on the number of other - /// sources heard by this sender since the last report. Each reception report - /// block conveys statistics on the reception of RTP packets from a - /// single synchronization source. - pub reports: Vec<ReceptionReport>, - - /// ProfileExtensions contains additional, payload-specific information that needs to - /// be reported regularly about the sender. - pub profile_extensions: Bytes, -} - -impl fmt::Display for SenderReport { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut out = format!("SenderReport from {}\n", self.ssrc); - out += format!("\tNTPTime:\t{}\n", self.ntp_time).as_str(); - out += format!("\tRTPTIme:\t{}\n", self.rtp_time).as_str(); - out += format!("\tPacketCount:\t{}\n", self.packet_count).as_str(); - out += format!("\tOctetCount:\t{}\n", self.octet_count).as_str(); - out += "\tSSRC \tLost\tLastSequence\n"; - for rep in &self.reports { - out += format!( - "\t{:x}\t{}/{}\t{}\n", - rep.ssrc, rep.fraction_lost, rep.total_lost, rep.last_sequence_number - ) - .as_str(); - } - out += format!("\tProfile Extension Data: {:?}\n", self.profile_extensions).as_str(); - - write!(f, "{out}") - } -} - -impl Packet for SenderReport { - /// Header returns the Header associated with this packet. - fn header(&self) -> Header { - Header { - padding: get_padding_size(self.raw_size()) != 0, - count: self.reports.len() as u8, - packet_type: PacketType::SenderReport, - length: ((self.marshal_size() / 4) - 1) as u16, - } - } - - /// destination_ssrc returns an array of SSRC values that this packet refers to. - fn destination_ssrc(&self) -> Vec<u32> { - let mut out: Vec<u32> = self.reports.iter().map(|x| x.ssrc).collect(); - out.push(self.ssrc); - out - } - - fn raw_size(&self) -> usize { - let mut reps_length = 0; - for rep in &self.reports { - reps_length += rep.marshal_size(); - } - - HEADER_LENGTH + SR_HEADER_LENGTH + reps_length + self.profile_extensions.len() - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } - - fn equal(&self, other: &(dyn Packet + Send + Sync)) -> bool { - other - .as_any() - .downcast_ref::<SenderReport>() - .map_or(false, |a| self == a) - } - - fn cloned(&self) -> Box<dyn Packet + Send + Sync> { - Box::new(self.clone()) - } -} - -impl MarshalSize for SenderReport { - fn marshal_size(&self) -> usize { - let l = self.raw_size(); - // align to 32-bit boundary - l + get_padding_size(l) - } -} - -impl Marshal for SenderReport { - /// Marshal encodes the packet in binary.
- fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - if self.reports.len() > COUNT_MAX { - return Err(Error::TooManyReports.into()); - } - - if buf.remaining_mut() < self.marshal_size() { - return Err(Error::BufferTooShort.into()); - } - - /* - * 0 1 2 3 - * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * header |V=2|P| RC | PT=SR=200 | length | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | SSRC of sender | - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - * sender | NTP timestamp, most significant word | - * info +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | NTP timestamp, least significant word | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | RTP timestamp | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | sender's packet count | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | sender's octet count | - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - * report | SSRC_1 (SSRC of first source) | - * block +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * 1 | fraction lost | cumulative number of packets lost | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | extended highest sequence number received | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | interarrival jitter | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | last SR (LSR) | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | delay since last SR (DLSR) | - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - * report | SSRC_2 (SSRC of second source) | - * block +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * 2 : ... 
: - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - * | profile-specific extensions | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - */ - let h = self.header(); - let n = h.marshal_to(buf)?; - buf = &mut buf[n..]; - - buf.put_u32(self.ssrc); - buf.put_u64(self.ntp_time); - buf.put_u32(self.rtp_time); - buf.put_u32(self.packet_count); - buf.put_u32(self.octet_count); - - for report in &self.reports { - let n = report.marshal_to(buf)?; - buf = &mut buf[n..]; - } - - buf.put(self.profile_extensions.clone()); - - if h.padding { - put_padding(buf, self.raw_size()); - } - - Ok(self.marshal_size()) - } -} - -impl Unmarshal for SenderReport { - /// Unmarshal decodes the SenderReport from binary - fn unmarshal(raw_packet: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - /* - * 0 1 2 3 - * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * header |V=2|P| RC | PT=SR=200 | length | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | SSRC of sender | - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - * sender | NTP timestamp, most significant word | - * info +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | NTP timestamp, least significant word | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | RTP timestamp | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | sender's packet count | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | sender's octet count | - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - * report | SSRC_1 (SSRC of first source) | - * block +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * 1 | fraction lost | cumulative number of packets lost | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | extended highest sequence number received | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | interarrival jitter | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | last SR (LSR) | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | delay since last SR (DLSR) | - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - * report | SSRC_2 (SSRC of second source) | - * block +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * 2 : ... 
: - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - * | profile-specific extensions | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - */ - let raw_packet_len = raw_packet.remaining(); - if raw_packet_len < (HEADER_LENGTH + SR_HEADER_LENGTH) { - return Err(Error::PacketTooShort.into()); - } - - let header = Header::unmarshal(raw_packet)?; - if header.packet_type != PacketType::SenderReport { - return Err(Error::WrongType.into()); - } - - let ssrc = raw_packet.get_u32(); - let ntp_time = raw_packet.get_u64(); - let rtp_time = raw_packet.get_u32(); - let packet_count = raw_packet.get_u32(); - let octet_count = raw_packet.get_u32(); - - let mut offset = SR_REPORT_OFFSET; - let mut reports = Vec::with_capacity(header.count as usize); - for _ in 0..header.count { - if offset + RECEPTION_REPORT_LENGTH > raw_packet_len { - return Err(Error::PacketTooShort.into()); - } - let reception_report = ReceptionReport::unmarshal(raw_packet)?; - reports.push(reception_report); - offset += RECEPTION_REPORT_LENGTH; - } - let profile_extensions = raw_packet.copy_to_bytes(raw_packet.remaining()); - /* - if header.padding && raw_packet.has_remaining() { - raw_packet.advance(raw_packet.remaining()); - } - */ - - Ok(SenderReport { - ssrc, - ntp_time, - rtp_time, - packet_count, - octet_count, - reports, - profile_extensions, - }) - } -} diff --git a/rtcp/src/sender_report/sender_report_test.rs b/rtcp/src/sender_report/sender_report_test.rs deleted file mode 100644 index 7cd7617ed..000000000 --- a/rtcp/src/sender_report/sender_report_test.rs +++ /dev/null @@ -1,252 +0,0 @@ -use super::*; - -#[test] -fn test_sender_report_unmarshal() { - let tests = vec![ - ( - "nil", - Bytes::from_static(&[]), - SenderReport::default(), - Some(Error::PacketTooShort), - ), - ( - "valid", - Bytes::from_static(&[ - 0x81u8, 0xc8, 0x0, 0x7, // v=2, p=0, count=1, SR, len=7 - 0x90, 0x2f, 0x9e, 0x2e, // ssrc=0x902f9e2e - 0xda, 0x8b, 0xd1, 0xfc, 0xdd, 0xdd, 0xa0, 0x5a, // ntp=0xda8bd1fcdddda05a - 0xaa, 0xf4, 0xed, 0xd5, // rtp=0xaaf4edd5 - 0x00, 0x00, 0x00, 0x01, // packetCount=1 - 0x00, 0x00, 0x00, 0x02, // octetCount=2 - 0xbc, 0x5e, 0x9a, 0x40, // ssrc=0xbc5e9a40 - 0x0, 0x0, 0x0, 0x0, // fracLost=0, totalLost=0 - 0x0, 0x0, 0x46, 0xe1, // lastSeq=0x46e1 - 0x0, 0x0, 0x1, 0x11, // jitter=273 - 0x9, 0xf3, 0x64, 0x32, // lsr=0x9f36432 - 0x0, 0x2, 0x4a, 0x79, // delay=150137 - ]), - SenderReport { - ssrc: 0x902f9e2e, - ntp_time: 0xda8bd1fcdddda05a, - rtp_time: 0xaaf4edd5, - packet_count: 1, - octet_count: 2, - reports: vec![ReceptionReport { - ssrc: 0xbc5e9a40, - fraction_lost: 0, - total_lost: 0, - last_sequence_number: 0x46e1, - jitter: 273, - last_sender_report: 0x9f36432, - delay: 150137, - }], - profile_extensions: Bytes::from_static(&[]), - }, - None, - ), - ( - "wrong type", - Bytes::from_static(&[ - 0x81, 0xc9, 0x0, 0x7, // v=2, p=0, count=1, RR, len=7 - 0x90, 0x2f, 0x9e, 0x2e, // ssrc=0x902f9e2e - 0xda, 0x8b, 0xd1, 0xfc, 0xdd, 0xdd, 0xa0, 0x5a, // ntp=0xda8bd1fcdddda05a - 0xaa, 0xf4, 0xed, 0xd5, // rtp=0xaaf4edd5 - 0x00, 0x00, 0x00, 0x01, // packetCount=1 - 0x00, 0x00, 0x00, 0x02, // octetCount=2 - 0xbc, 0x5e, 0x9a, 0x40, // ssrc=0xbc5e9a40 - 0x0, 0x0, 0x0, 0x0, // fracLost=0, totalLost=0 - 0x0, 0x0, 0x46, 0xe1, // jitter=273 - 0x0, 0x0, 0x1, 0x11, // lastSeq=0x46e1 - 0x9, 0xf3, 0x64, 0x32, // lsr=0x9f36432 - 0x0, 0x2, 0x4a, 0x79, // delay=150137 - ]), - SenderReport::default(), - Some(Error::WrongType), - ), - ( - "bad count in header", - Bytes::from_static(&[ - 0x82, 0xc8, 0x0, 0x7, // 
v=2, p=0, count=1, SR, len=7 - 0x90, 0x2f, 0x9e, 0x2e, // ssrc=0x902f9e2e - 0xda, 0x8b, 0xd1, 0xfc, 0xdd, 0xdd, 0xa0, 0x5a, // ntp=0xda8bd1fcdddda05a - 0xaa, 0xf4, 0xed, 0xd5, // rtp=0xaaf4edd5 - 0x00, 0x00, 0x00, 0x01, // packetCount=1 - 0x00, 0x00, 0x00, 0x02, // octetCount=2 - 0xbc, 0x5e, 0x9a, 0x40, // ssrc=0xbc5e9a40 - 0x0, 0x0, 0x0, 0x0, // fracLost=0, totalLost=0 - 0x0, 0x0, 0x46, 0xe1, // lastSeq=0x46e1 - 0x0, 0x0, 0x1, 0x11, // jitter=273 - 0x9, 0xf3, 0x64, 0x32, // lsr=0x9f36432 - 0x0, 0x2, 0x4a, 0x79, // delay=150137 - ]), - SenderReport::default(), - Some(Error::PacketTooShort), - ), - ( - "with extension", // issue #447 - Bytes::from_static(&[ - 0x80, 0xc8, 0x0, 0x6, // v=2, p=0, count=0, SR, len=6 - 0x2b, 0x7e, 0xc0, 0xc5, // ssrc=0x2b7ec0c5 - 0xe0, 0x20, 0xa2, 0xa9, 0x52, 0xa5, 0x3f, 0xc0, // ntp=0xe020a2a952a53fc0 - 0x2e, 0x48, 0xa5, 0x52, // rtp=0x2e48a552 - 0x0, 0x0, 0x0, 0x46, // packetCount=70 - 0x0, 0x0, 0x12, 0x1d, // octetCount=4637 - 0x81, 0xca, 0x0, 0x6, 0x2b, 0x7e, 0xc0, 0xc5, 0x1, 0x10, 0x4c, 0x63, 0x49, 0x66, - 0x7a, 0x58, 0x6f, 0x6e, 0x44, 0x6f, 0x72, 0x64, 0x53, 0x65, 0x57, 0x36, 0x0, - 0x0, // profile-specific extension - ]), - SenderReport { - ssrc: 0x2b7ec0c5, - ntp_time: 0xe020a2a952a53fc0, - rtp_time: 0x2e48a552, - packet_count: 70, - octet_count: 4637, - reports: vec![], - profile_extensions: Bytes::from_static(&[ - 0x81, 0xca, 0x0, 0x6, 0x2b, 0x7e, 0xc0, 0xc5, 0x1, 0x10, 0x4c, 0x63, 0x49, - 0x66, 0x7a, 0x58, 0x6f, 0x6e, 0x44, 0x6f, 0x72, 0x64, 0x53, 0x65, 0x57, 0x36, - 0x0, 0x0, - ]), - }, - None, - ), - ]; - - for (name, mut data, want, want_error) in tests { - let got = SenderReport::unmarshal(&mut data); - - assert_eq!( - got.is_err(), - want_error.is_some(), - "Unmarshal {name}: err = {got:?}, want {want_error:?}" - ); - - if let Some(err) = want_error { - let got_err = got.err().unwrap(); - assert_eq!( - err, got_err, - "Unmarshal {name}: err = {got_err:?}, want {err:?}", - ); - } else { - let actual = got.unwrap(); - assert_eq!( - actual, want, - "Unmarshal {name}: got {actual:?}, want {want:?}" - ); - } - } -} - -#[test] -fn test_sender_report_roundtrip() { - let mut too_many_reports = vec![]; - for _i in 0..(1 << 5) { - too_many_reports.push(ReceptionReport { - ssrc: 2, - fraction_lost: 2, - total_lost: 3, - last_sequence_number: 4, - jitter: 5, - last_sender_report: 6, - delay: 7, - }); - } - - let tests = vec![ - ( - "valid", - SenderReport { - ssrc: 1, - ntp_time: 999, - rtp_time: 555, - packet_count: 32, - octet_count: 11, - reports: vec![ - ReceptionReport { - ssrc: 2, - fraction_lost: 2, - total_lost: 3, - last_sequence_number: 4, - jitter: 5, - last_sender_report: 6, - delay: 7, - }, - ReceptionReport::default(), - ], - profile_extensions: Bytes::from_static(&[]), - }, - None, - ), - ( - "also valid", - SenderReport { - ssrc: 2, - reports: vec![ReceptionReport { - ssrc: 999, - fraction_lost: 30, - total_lost: 12345, - last_sequence_number: 99, - jitter: 22, - last_sender_report: 92, - delay: 46, - }], - ..Default::default() - }, - None, - ), - ( - "extension", - SenderReport { - ssrc: 2, - reports: vec![ReceptionReport { - ssrc: 999, - fraction_lost: 30, - total_lost: 12345, - last_sequence_number: 99, - jitter: 22, - last_sender_report: 92, - delay: 46, - }], - profile_extensions: Bytes::from_static(&[1, 2, 3, 4]), - ..Default::default() - }, - None, - ), - ( - "count overflow", - SenderReport { - ssrc: 1, - reports: too_many_reports, - ..Default::default() - }, - Some(Error::TooManyReports), - ), - ]; - - for (name, want, want_error) in 
tests { - let got = want.marshal(); - - assert_eq!( - got.is_ok(), - want_error.is_none(), - "Marshal {name}: err = {got:?}, want {want_error:?}" - ); - - if let Some(err) = want_error { - let got_err = got.err().unwrap(); - assert_eq!( - err, got_err, - "Unmarshal {name} rr: err = {got_err:?}, want {err:?}", - ); - } else { - let mut data = got.ok().unwrap(); - let actual = - SenderReport::unmarshal(&mut data).unwrap_or_else(|_| panic!("Unmarshal {name}")); - - assert_eq!( - actual, want, - "{name} round trip: got {actual:?}, want {want:?}" - ) - } - } -} diff --git a/rtcp/src/source_description/mod.rs b/rtcp/src/source_description/mod.rs deleted file mode 100644 index 5cf66545b..000000000 --- a/rtcp/src/source_description/mod.rs +++ /dev/null @@ -1,440 +0,0 @@ -#[cfg(test)] -mod source_description_test; - -use std::any::Any; -use std::fmt; - -use bytes::{Buf, BufMut, Bytes}; -use util::marshal::{Marshal, MarshalSize, Unmarshal}; - -use crate::error::Error; -use crate::header::*; -use crate::packet::*; -use crate::util::*; - -type Result = std::result::Result; - -const SDES_SOURCE_LEN: usize = 4; -const SDES_TYPE_LEN: usize = 1; -const SDES_TYPE_OFFSET: usize = 0; -const SDES_OCTET_COUNT_LEN: usize = 1; -const SDES_OCTET_COUNT_OFFSET: usize = 1; -const SDES_MAX_OCTET_COUNT: usize = (1 << 8) - 1; -const SDES_TEXT_OFFSET: usize = 2; - -/// SDESType is the item type used in the RTCP SDES control packet. -/// RTP SDES item types registered with IANA. See: https://www.iana.org/assignments/rtp-parameters/rtp-parameters.xhtml#rtp-parameters-5 -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] -#[repr(u8)] -pub enum SdesType { - #[default] - SdesEnd = 0, // end of SDES list RFC 3550, 6.5 - SdesCname = 1, // canonical name RFC 3550, 6.5.1 - SdesName = 2, // user name RFC 3550, 6.5.2 - SdesEmail = 3, // user's electronic mail address RFC 3550, 6.5.3 - SdesPhone = 4, // user's phone number RFC 3550, 6.5.4 - SdesLocation = 5, // geographic user location RFC 3550, 6.5.5 - SdesTool = 6, // name of application or tool RFC 3550, 6.5.6 - SdesNote = 7, // notice about the source RFC 3550, 6.5.7 - SdesPrivate = 8, // private extensions RFC 3550, 6.5.8 (not implemented) -} - -impl fmt::Display for SdesType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match self { - SdesType::SdesEnd => "END", - SdesType::SdesCname => "CNAME", - SdesType::SdesName => "NAME", - SdesType::SdesEmail => "EMAIL", - SdesType::SdesPhone => "PHONE", - SdesType::SdesLocation => "LOC", - SdesType::SdesTool => "TOOL", - SdesType::SdesNote => "NOTE", - SdesType::SdesPrivate => "PRIV", - }; - write!(f, "{s}") - } -} - -impl From for SdesType { - fn from(b: u8) -> Self { - match b { - 1 => SdesType::SdesCname, - 2 => SdesType::SdesName, - 3 => SdesType::SdesEmail, - 4 => SdesType::SdesPhone, - 5 => SdesType::SdesLocation, - 6 => SdesType::SdesTool, - 7 => SdesType::SdesNote, - 8 => SdesType::SdesPrivate, - _ => SdesType::SdesEnd, - } - } -} - -/// A SourceDescriptionChunk contains items describing a single RTP source -#[derive(Debug, PartialEq, Eq, Default, Clone)] -pub struct SourceDescriptionChunk { - /// The source (ssrc) or contributing source (csrc) identifier this packet describes - pub source: u32, - pub items: Vec, -} - -impl SourceDescriptionChunk { - fn raw_size(&self) -> usize { - let mut len = SDES_SOURCE_LEN; - for it in &self.items { - len += it.marshal_size(); - } - len += SDES_TYPE_LEN; // for terminating null octet - len - } -} - -impl MarshalSize for SourceDescriptionChunk { - fn 
marshal_size(&self) -> usize { - let l = self.raw_size(); - // align to 32-bit boundary - l + get_padding_size(l) - } -} - -impl Marshal for SourceDescriptionChunk { - /// Marshal encodes the SourceDescriptionChunk in binary - fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - if buf.remaining_mut() < self.marshal_size() { - return Err(Error::BufferTooShort.into()); - } - /* - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - * | SSRC/CSRC_1 | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | SDES items | - * | ... | - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - */ - - buf.put_u32(self.source); - - for it in &self.items { - let n = it.marshal_to(buf)?; - buf = &mut buf[n..]; - } - - // The list of items in each chunk MUST be terminated by one or more null octets - buf.put_u8(SdesType::SdesEnd as u8); - - // additional null octets MUST be included if needed to pad until the next 32-bit boundary - put_padding(buf, self.raw_size()); - Ok(self.marshal_size()) - } -} - -impl Unmarshal for SourceDescriptionChunk { - /// Unmarshal decodes the SourceDescriptionChunk from binary - fn unmarshal(raw_packet: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - /* - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - * | SSRC/CSRC_1 | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | SDES items | - * | ... | - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - */ - let raw_packet_len = raw_packet.remaining(); - if raw_packet_len < (SDES_SOURCE_LEN + SDES_TYPE_LEN) { - return Err(Error::PacketTooShort.into()); - } - - let source = raw_packet.get_u32(); - - let mut offset = SDES_SOURCE_LEN; - let mut items = vec![]; - while offset < raw_packet_len { - let item = SourceDescriptionItem::unmarshal(raw_packet)?; - if item.sdes_type == SdesType::SdesEnd { - // offset + 1 (one byte for SdesEnd) - let padding_len = get_padding_size(offset + 1); - if raw_packet.remaining() >= padding_len { - raw_packet.advance(padding_len); - return Ok(SourceDescriptionChunk { source, items }); - } else { - return Err(Error::PacketTooShort.into()); - } - } - offset += item.marshal_size(); - items.push(item); - } - - Err(Error::PacketTooShort.into()) - } -} - -/// A SourceDescriptionItem is a part of a SourceDescription that describes a stream. -#[derive(Debug, PartialEq, Eq, Default, Clone)] -pub struct SourceDescriptionItem { - /// The type identifier for this item. eg, SDESCNAME for canonical name description. - /// - /// Type zero or SDESEnd is interpreted as the end of an item list and cannot be used. - pub sdes_type: SdesType, - /// Text is a unicode text blob associated with the item. Its meaning varies based on the item's Type. - pub text: Bytes, -} - -impl MarshalSize for SourceDescriptionItem { - fn marshal_size(&self) -> usize { - /* - * 0 1 2 3 - * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | CNAME=1 | length | user and domain name ... 
- * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - */ - SDES_TYPE_LEN + SDES_OCTET_COUNT_LEN + self.text.len() - } -} - -impl Marshal for SourceDescriptionItem { - /// Marshal encodes the SourceDescriptionItem in binary - fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - /* - * 0 1 2 3 - * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | CNAME=1 | length | user and domain name ... - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - */ - - if self.sdes_type == SdesType::SdesEnd { - return Err(Error::SdesMissingType.into()); - } - - if buf.remaining_mut() < self.marshal_size() { - return Err(Error::BufferTooShort.into()); - } - - buf.put_u8(self.sdes_type as u8); - - if self.text.len() > SDES_MAX_OCTET_COUNT { - return Err(Error::SdesTextTooLong.into()); - } - buf.put_u8(self.text.len() as u8); - buf.put(self.text.clone()); - - //no padding for each SourceDescriptionItem - Ok(self.marshal_size()) - } -} - -impl Unmarshal for SourceDescriptionItem { - /// Unmarshal decodes the SourceDescriptionItem from binary - fn unmarshal(raw_packet: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - /* - * 0 1 2 3 - * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | CNAME=1 | length | user and domain name ... - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - */ - let raw_packet_len = raw_packet.remaining(); - if raw_packet_len < SDES_TYPE_LEN { - return Err(Error::PacketTooShort.into()); - } - - let sdes_type = SdesType::from(raw_packet.get_u8()); - if sdes_type == SdesType::SdesEnd { - return Ok(SourceDescriptionItem { - sdes_type, - text: Bytes::new(), - }); - } - - if raw_packet_len < (SDES_TYPE_LEN + SDES_OCTET_COUNT_LEN) { - return Err(Error::PacketTooShort.into()); - } - - let octet_count = raw_packet.get_u8() as usize; - if SDES_TEXT_OFFSET + octet_count > raw_packet_len { - return Err(Error::PacketTooShort.into()); - } - - let text = raw_packet.copy_to_bytes(octet_count); - - Ok(SourceDescriptionItem { sdes_type, text }) - } -} - -/// A SourceDescription (SDES) packet describes the sources in an RTP stream. -#[derive(Debug, Default, PartialEq, Eq, Clone)] -pub struct SourceDescription { - pub chunks: Vec, -} - -impl fmt::Display for SourceDescription { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut out = "Source Description:\n".to_string(); - for c in &self.chunks { - out += format!("\t{:x}\n", c.source).as_str(); - for it in &c.items { - out += format!("\t\t{it:?}\n").as_str(); - } - } - write!(f, "{out}") - } -} - -impl Packet for SourceDescription { - /// Header returns the Header associated with this packet. - fn header(&self) -> Header { - Header { - padding: get_padding_size(self.raw_size()) != 0, - count: self.chunks.len() as u8, - packet_type: PacketType::SourceDescription, - length: ((self.marshal_size() / 4) - 1) as u16, - } - } - - /// destination_ssrc returns an array of SSRC values that this packet refers to. 
- fn destination_ssrc(&self) -> Vec { - self.chunks.iter().map(|x| x.source).collect() - } - - fn raw_size(&self) -> usize { - let mut chunks_length = 0; - for c in &self.chunks { - chunks_length += c.marshal_size(); - } - - HEADER_LENGTH + chunks_length - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } - - fn equal(&self, other: &(dyn Packet + Send + Sync)) -> bool { - other - .as_any() - .downcast_ref::() - .map_or(false, |a| self == a) - } - - fn cloned(&self) -> Box { - Box::new(self.clone()) - } -} - -impl MarshalSize for SourceDescription { - fn marshal_size(&self) -> usize { - let l = self.raw_size(); - // align to 32-bit boundary - l + get_padding_size(l) - } -} - -impl Marshal for SourceDescription { - /// Marshal encodes the SourceDescription in binary - fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - if self.chunks.len() > COUNT_MAX { - return Err(Error::TooManyChunks.into()); - } - - if buf.remaining_mut() < self.marshal_size() { - return Err(Error::BufferTooShort.into()); - } - - /* - * 0 1 2 3 - * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * header |V=2|P| SC | PT=SDES=202 | length | - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - * chunk | SSRC/CSRC_1 | - * 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | SDES items | - * | ... | - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - * chunk | SSRC/CSRC_2 | - * 2 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | SDES items | - * | ... | - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - */ - - let h = self.header(); - let n = h.marshal_to(buf)?; - buf = &mut buf[n..]; - - for c in &self.chunks { - let n = c.marshal_to(buf)?; - buf = &mut buf[n..]; - } - - if h.padding { - put_padding(buf, self.raw_size()); - } - - Ok(self.marshal_size()) - } -} - -impl Unmarshal for SourceDescription { - /// Unmarshal decodes the SourceDescription from binary - fn unmarshal(raw_packet: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - /* - * 0 1 2 3 - * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * header |V=2|P| SC | PT=SDES=202 | length | - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - * chunk | SSRC/CSRC_1 | - * 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | SDES items | - * | ... | - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - * chunk | SSRC/CSRC_2 | - * 2 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | SDES items | - * | ... 
| - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - */ - let raw_packet_len = raw_packet.remaining(); - - let h = Header::unmarshal(raw_packet)?; - if h.packet_type != PacketType::SourceDescription { - return Err(Error::WrongType.into()); - } - - let mut offset = HEADER_LENGTH; - let mut chunks = vec![]; - while offset < raw_packet_len { - let chunk = SourceDescriptionChunk::unmarshal(raw_packet)?; - offset += chunk.marshal_size(); - chunks.push(chunk); - } - - if chunks.len() != h.count as usize { - return Err(Error::InvalidHeader.into()); - } - - if - /*h.padding &&*/ - raw_packet.has_remaining() { - raw_packet.advance(raw_packet.remaining()); - } - - Ok(SourceDescription { chunks }) - } -} diff --git a/rtcp/src/source_description/source_description_test.rs b/rtcp/src/source_description/source_description_test.rs deleted file mode 100644 index cdc0d68fb..000000000 --- a/rtcp/src/source_description/source_description_test.rs +++ /dev/null @@ -1,359 +0,0 @@ -use super::*; - -#[test] -fn test_source_description_unmarshal() { - let tests = vec![ - ( - "nil", - Bytes::from_static(&[]), - SourceDescription::default(), - Some(Error::PacketTooShort), - ), - ( - "no chunks", - Bytes::from_static(&[ - // v=2, p=0, count=1, SDES, len=8 - 0x80, 0xca, 0x00, 0x04, - ]), - SourceDescription::default(), - None, - ), - ( - "missing type", - Bytes::from_static(&[ - // v=2, p=0, count=1, SDES, len=8 - 0x81, 0xca, 0x00, 0x08, // ssrc=0x00000000 - 0x00, 0x00, 0x00, 0x00, - ]), - SourceDescription::default(), - Some(Error::PacketTooShort), - ), - ( - "bad cname length", - Bytes::from_static(&[ - // v=2, p=0, count=1, SDES, len=10 - 0x81, 0xca, 0x00, 0x0a, // ssrc=0x00000000 - 0x00, 0x00, 0x00, 0x00, // CNAME, len = 1 - 0x01, 0x01, - ]), - SourceDescription::default(), - Some(Error::PacketTooShort), - ), - ( - "short cname", - Bytes::from_static(&[ - // v=2, p=0, count=1, SDES, len=9 - 0x81, 0xca, 0x00, 0x09, // ssrc=0x00000000 - 0x00, 0x00, 0x00, 0x00, // CNAME, Missing length - 0x01, - ]), - SourceDescription::default(), - Some(Error::PacketTooShort), - ), - ( - "no end", - Bytes::from_static(&[ - // v=2, p=0, count=1, SDES, len=11 - 0x81, 0xca, 0x00, 0x0b, // ssrc=0x00000000 - 0x00, 0x00, 0x00, 0x00, // CNAME, len=1, content=A - 0x01, 0x02, 0x41, - // Missing END - ]), - SourceDescription::default(), - Some(Error::PacketTooShort), - ), - ( - "bad octet count", - Bytes::from_static(&[ - // v=2, p=0, count=1, SDES, len=10 - 0x81, 0xca, 0x00, 0x0a, // ssrc=0x00000000 - 0x00, 0x00, 0x00, 0x00, // CNAME, len=1 - 0x01, 0x01, - ]), - SourceDescription::default(), - Some(Error::PacketTooShort), - ), - ( - "zero item chunk", - Bytes::from_static(&[ - // v=2, p=0, count=1, SDES, len=12 - 0x81, 0xca, 0x00, 0x0c, // ssrc=0x01020304 - 0x01, 0x02, 0x03, 0x04, // END + padding - 0x00, 0x00, 0x00, 0x00, - ]), - SourceDescription { - chunks: vec![SourceDescriptionChunk { - source: 0x01020304, - items: vec![], - }], - }, - None, - ), - ( - "wrong type", - Bytes::from_static(&[ - // v=2, p=0, count=1, SR, len=12 - 0x81, 0xc8, 0x00, 0x0c, // ssrc=0x01020304 - 0x01, 0x02, 0x03, 0x04, // END + padding - 0x00, 0x00, 0x00, 0x00, - ]), - SourceDescription::default(), - Some(Error::WrongType), - ), - ( - "bad count in header", - Bytes::from_static(&[ - // v=2, p=0, count=1, SDES, len=12 - 0x81, 0xca, 0x00, 0x0c, - ]), - SourceDescription::default(), - Some(Error::InvalidHeader), - ), - ( - "empty string", - Bytes::from_static(&[ - // v=2, p=0, count=1, SDES, len=12 - 0x81, 0xca, 0x00, 0x0c, // ssrc=0x01020304 - 
0x01, 0x02, 0x03, 0x04, // CNAME, len=0 - 0x01, 0x00, // END + padding - 0x00, 0x00, - ]), - SourceDescription { - chunks: vec![SourceDescriptionChunk { - source: 0x01020304, - items: vec![SourceDescriptionItem { - sdes_type: SdesType::SdesCname, - text: Bytes::from_static(b""), - }], - }], - }, - None, - ), - ( - "two items", - Bytes::from_static(&[ - // v=2, p=0, count=1, SDES, len=16 - 0x81, 0xca, 0x00, 0x10, // ssrc=0x10000000 - 0x10, 0x00, 0x00, 0x00, // CNAME, len=1, content=A - 0x01, 0x01, 0x41, // PHONE, len=1, content=B - 0x04, 0x01, 0x42, // END + padding - 0x00, 0x00, - ]), - SourceDescription { - chunks: vec![SourceDescriptionChunk { - source: 0x10000000, - items: vec![ - SourceDescriptionItem { - sdes_type: SdesType::SdesCname, - text: Bytes::from_static(b"A"), - }, - SourceDescriptionItem { - sdes_type: SdesType::SdesPhone, - text: Bytes::from_static(b"B"), - }, - ], - }], - }, - None, - ), - ( - "two chunks", - Bytes::from_static(&[ - // v=2, p=0, count=2, SDES, len=24 - 0x82, 0xca, 0x00, 0x18, // ssrc=0x01020304 - 0x01, 0x02, 0x03, 0x04, - // Chunk 1 - // CNAME, len=1, content=A - 0x01, 0x01, 0x41, // END - 0x00, // Chunk 2 - // SSRC 0x05060708 - 0x05, 0x06, 0x07, 0x08, // CNAME, len=3, content=BCD - 0x01, 0x03, 0x42, 0x43, 0x44, // END - 0x00, 0x00, 0x00, - ]), - SourceDescription { - chunks: vec![ - SourceDescriptionChunk { - source: 0x01020304, - items: vec![SourceDescriptionItem { - sdes_type: SdesType::SdesCname, - text: Bytes::from_static(b"A"), - }], - }, - SourceDescriptionChunk { - source: 0x05060708, - items: vec![SourceDescriptionItem { - sdes_type: SdesType::SdesCname, - text: Bytes::from_static(b"BCD"), - }], - }, - ], - }, - None, - ), - ]; - - for (name, mut data, want, want_error) in tests { - let got = SourceDescription::unmarshal(&mut data); - - assert_eq!( - got.is_err(), - want_error.is_some(), - "Unmarshal {name}: err = {got:?}, want {want_error:?}" - ); - - if let Some(err) = want_error { - let got_err = got.err().unwrap(); - assert_eq!( - err, got_err, - "Unmarshal {name}: err = {got_err:?}, want {err:?}", - ); - } else { - let actual = got.unwrap(); - assert_eq!( - actual, want, - "Unmarshal {name}: got {actual:?}, want {want:?}" - ); - } - } -} - -#[test] -fn test_source_description_roundtrip() { - let mut too_long_text = String::new(); - for _ in 0..(1 << 8) { - too_long_text += "x"; - } - - let mut too_many_chunks = vec![]; - for _ in 0..(1 << 5) { - too_many_chunks.push(SourceDescriptionChunk::default()); - } - - let tests = vec![ - ( - "valid", - SourceDescription { - chunks: vec![ - SourceDescriptionChunk { - source: 1, - items: vec![SourceDescriptionItem { - sdes_type: SdesType::SdesCname, - text: Bytes::from_static(b"test@example.com"), - }], - }, - SourceDescriptionChunk { - source: 2, - items: vec![ - SourceDescriptionItem { - sdes_type: SdesType::SdesNote, - text: Bytes::from_static(b"some note"), - }, - SourceDescriptionItem { - sdes_type: SdesType::SdesNote, - text: Bytes::from_static(b"another note"), - }, - ], - }, - ], - }, - None, - ), - ( - "item without type", - SourceDescription { - chunks: vec![SourceDescriptionChunk { - source: 1, - items: vec![SourceDescriptionItem { - sdes_type: SdesType::SdesEnd, - text: Bytes::from_static(b"test@example.com"), - }], - }], - }, - Some(Error::SdesMissingType), - ), - ( - "zero items", - SourceDescription { - chunks: vec![SourceDescriptionChunk { - source: 1, - items: vec![], - }], - }, - None, - ), - ( - "email item", - SourceDescription { - chunks: vec![SourceDescriptionChunk { - source: 1, - 
items: vec![SourceDescriptionItem { - sdes_type: SdesType::SdesEmail, - text: Bytes::from_static(b"test@example.com"), - }], - }], - }, - None, - ), - ( - "empty text", - SourceDescription { - chunks: vec![SourceDescriptionChunk { - source: 1, - items: vec![SourceDescriptionItem { - sdes_type: SdesType::SdesCname, - text: Bytes::from_static(b""), - }], - }], - }, - None, - ), - ( - "text too long", - SourceDescription { - chunks: vec![SourceDescriptionChunk { - source: 1, - items: vec![SourceDescriptionItem { - sdes_type: SdesType::SdesCname, - text: Bytes::copy_from_slice(too_long_text.as_bytes()), - }], - }], - }, - Some(Error::SdesTextTooLong), - ), - ( - "count overflow", - SourceDescription { - chunks: too_many_chunks, - }, - Some(Error::TooManyChunks), - ), - ]; - - for (name, want, want_error) in tests { - let got = want.marshal(); - - assert_eq!( - got.is_ok(), - want_error.is_none(), - "Marshal {name}: err = {got:?}, want {want_error:?}" - ); - - if let Some(err) = want_error { - let got_err = got.err().unwrap(); - assert_eq!( - err, got_err, - "Unmarshal {name} rr: err = {got_err:?}, want {err:?}", - ); - } else { - let mut data = got.ok().unwrap(); - let actual = SourceDescription::unmarshal(&mut data) - .unwrap_or_else(|_| panic!("Unmarshal {name}")); - - assert_eq!( - actual, want, - "{name} round trip: got {actual:?}, want {want:?}" - ) - } - } -} diff --git a/rtcp/src/transport_feedbacks/mod.rs b/rtcp/src/transport_feedbacks/mod.rs deleted file mode 100644 index f59db6075..000000000 --- a/rtcp/src/transport_feedbacks/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod rapid_resynchronization_request; -pub mod transport_layer_cc; -pub mod transport_layer_nack; diff --git a/rtcp/src/transport_feedbacks/rapid_resynchronization_request/mod.rs b/rtcp/src/transport_feedbacks/rapid_resynchronization_request/mod.rs deleted file mode 100644 index 7ea106755..000000000 --- a/rtcp/src/transport_feedbacks/rapid_resynchronization_request/mod.rs +++ /dev/null @@ -1,144 +0,0 @@ -#[cfg(test)] -mod rapid_resynchronization_request_test; - -use std::any::Any; -use std::fmt; - -use bytes::{Buf, BufMut}; -use util::marshal::{Marshal, MarshalSize, Unmarshal}; - -use crate::error::Error; -use crate::header::*; -use crate::packet::*; -use crate::util::*; - -type Result = std::result::Result; - -const RRR_LENGTH: usize = 2; -const RRR_HEADER_LENGTH: usize = SSRC_LENGTH * 2; -const RRR_MEDIA_OFFSET: usize = 4; - -/// The RapidResynchronizationRequest packet informs the encoder about the loss of an undefined amount of coded video data belonging to one or more pictures -#[derive(Debug, PartialEq, Eq, Default, Clone)] -pub struct RapidResynchronizationRequest { - /// SSRC of sender - pub sender_ssrc: u32, - /// SSRC of the media source - pub media_ssrc: u32, -} - -impl fmt::Display for RapidResynchronizationRequest { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "RapidResynchronizationRequest {:x} {:x}", - self.sender_ssrc, self.media_ssrc - ) - } -} - -impl Packet for RapidResynchronizationRequest { - /// Header returns the Header associated with this packet. - fn header(&self) -> Header { - Header { - padding: get_padding_size(self.raw_size()) != 0, - count: FORMAT_RRR, - packet_type: PacketType::TransportSpecificFeedback, - length: ((self.marshal_size() / 4) - 1) as u16, - } - } - - /// Destination SSRC returns an array of SSRC values that this packet refers to. 
- fn destination_ssrc(&self) -> Vec<u32> { - vec![self.media_ssrc] - } - - fn raw_size(&self) -> usize { - HEADER_LENGTH + RRR_HEADER_LENGTH - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } - - fn equal(&self, other: &(dyn Packet + Send + Sync)) -> bool { - other - .as_any() - .downcast_ref::<RapidResynchronizationRequest>() - .map_or(false, |a| self == a) - } - - fn cloned(&self) -> Box<dyn Packet + Send + Sync> { - Box::new(self.clone()) - } -} - -impl MarshalSize for RapidResynchronizationRequest { - fn marshal_size(&self) -> usize { - let l = self.raw_size(); - // align to 32-bit boundary - l + get_padding_size(l) - } -} - -impl Marshal for RapidResynchronizationRequest { - /// Marshal encodes the RapidResynchronizationRequest in binary - fn marshal_to(&self, mut buf: &mut [u8]) -> Result<usize> { - /* - * RRR does not require parameters. Therefore, the length field MUST be - * 2, and there MUST NOT be any Feedback Control Information. - * - * The semantics of this FB message is independent of the payload type. - */ - if buf.remaining_mut() < self.marshal_size() { - return Err(Error::BufferTooShort.into()); - } - - let h = self.header(); - let n = h.marshal_to(buf)?; - buf = &mut buf[n..]; - - buf.put_u32(self.sender_ssrc); - buf.put_u32(self.media_ssrc); - - if h.padding { - put_padding(buf, self.raw_size()); - } - - Ok(self.marshal_size()) - } -} - -impl Unmarshal for RapidResynchronizationRequest { - /// Unmarshal decodes the RapidResynchronizationRequest from binary - fn unmarshal<B>(raw_packet: &mut B) -> Result<Self> - where - Self: Sized, - B: Buf, - { - let raw_packet_len = raw_packet.remaining(); - if raw_packet_len < (HEADER_LENGTH + (SSRC_LENGTH * 2)) { - return Err(Error::PacketTooShort.into()); - } - - let h = Header::unmarshal(raw_packet)?; - - if h.packet_type != PacketType::TransportSpecificFeedback || h.count != FORMAT_RRR { - return Err(Error::WrongType.into()); - } - - let sender_ssrc = raw_packet.get_u32(); - let media_ssrc = raw_packet.get_u32(); - - if - /*h.padding &&*/ - raw_packet.has_remaining() { - raw_packet.advance(raw_packet.remaining()); - } - - Ok(RapidResynchronizationRequest { - sender_ssrc, - media_ssrc, - }) - } -} diff --git a/rtcp/src/transport_feedbacks/rapid_resynchronization_request/rapid_resynchronization_request_test.rs b/rtcp/src/transport_feedbacks/rapid_resynchronization_request/rapid_resynchronization_request_test.rs deleted file mode 100644 index 6eb86e6a4..000000000 --- a/rtcp/src/transport_feedbacks/rapid_resynchronization_request/rapid_resynchronization_request_test.rs +++ /dev/null @@ -1,116 +0,0 @@ -use bytes::Bytes; - -use super::*; - -#[test] -fn test_rapid_resynchronization_request_unmarshal() { - let tests = vec![ - ( - "valid", - Bytes::from_static(&[ - 0x85, 0xcd, 0x0, 0x2, // RapidResynchronizationRequest - 0x90, 0x2f, 0x9e, 0x2e, // sender=0x902f9e2e - 0x90, 0x2f, 0x9e, 0x2e, // media=0x902f9e2e - ]), - RapidResynchronizationRequest { - sender_ssrc: 0x902f9e2e, - media_ssrc: 0x902f9e2e, - }, - None, - ), - ( - "short report", - Bytes::from_static(&[ - 0x85, 0xcd, 0x0, 0x2, // ssrc=0x902f9e2e - 0x90, 0x2f, 0x9e, 0x2e, - // report ends early - ]), - RapidResynchronizationRequest::default(), - Some(Error::PacketTooShort), - ), - ( - "wrong type", - Bytes::from_static(&[ - 0x81, 0xc8, 0x0, 0x7, // v=2, p=0, count=1, SR, len=7 - 0x90, 0x2f, 0x9e, 0x2e, // ssrc=0x902f9e2e - 0xbc, 0x5e, 0x9a, 0x40, // ssrc=0xbc5e9a40 - 0x0, 0x0, 0x0, 0x0, // fracLost=0, totalLost=0 - 0x0, 0x0, 0x46, 0xe1, // lastSeq=0x46e1 - 0x0, 0x0, 0x1, 0x11, // jitter=273 - 0x9, 0xf3, 0x64, 0x32, // lsr=0x9f36432 - 0x0,
0x2, 0x4a, 0x79, // delay=150137 - ]), - RapidResynchronizationRequest::default(), - Some(Error::WrongType), - ), - ( - "nil", - Bytes::from_static(&[]), - RapidResynchronizationRequest::default(), - Some(Error::PacketTooShort), - ), - ]; - - for (name, mut data, want, want_error) in tests { - let got = RapidResynchronizationRequest::unmarshal(&mut data); - - assert_eq!( - got.is_err(), - want_error.is_some(), - "Unmarshal {name} rr: err = {got:?}, want {want_error:?}" - ); - - if let Some(err) = want_error { - let got_err = got.err().unwrap(); - assert_eq!( - err, got_err, - "Unmarshal {name} rr: err = {got_err:?}, want {err:?}", - ); - } else { - let actual = got.unwrap(); - assert_eq!( - actual, want, - "Unmarshal {name} rr: got {actual:?}, want {want:?}" - ); - } - } -} - -#[test] -fn test_rapid_resynchronization_request_roundtrip() { - let tests: Vec<(&str, RapidResynchronizationRequest, Option)> = vec![( - "valid", - RapidResynchronizationRequest { - sender_ssrc: 0x902f9e2e, - media_ssrc: 0x902f9e2e, - }, - None, - )]; - - for (name, want, want_error) in tests { - let got = want.marshal(); - - assert_eq!( - got.is_ok(), - want_error.is_none(), - "Marshal {name}: err = {got:?}, want {want_error:?}" - ); - - if let Some(err) = want_error { - let got_err = got.err().unwrap(); - assert_eq!( - err, got_err, - "Unmarshal {name} rr: err = {got_err:?}, want {err:?}", - ); - } else { - let mut data = got.ok().unwrap(); - let actual = RapidResynchronizationRequest::unmarshal(&mut data) - .unwrap_or_else(|_| panic!("Unmarshal {name}")); - - assert_eq!( - actual, want, - "{name} round trip: got {actual:?}, want {want:?}" - ) - } - } -} diff --git a/rtcp/src/transport_feedbacks/transport_layer_cc/mod.rs b/rtcp/src/transport_feedbacks/transport_layer_cc/mod.rs deleted file mode 100644 index 5256cd83d..000000000 --- a/rtcp/src/transport_feedbacks/transport_layer_cc/mod.rs +++ /dev/null @@ -1,722 +0,0 @@ -#[cfg(test)] -mod transport_layer_cc_test; - -use std::any::Any; -use std::fmt; - -use bytes::{Buf, BufMut}; -use util::marshal::{Marshal, MarshalSize, Unmarshal}; - -use crate::error::Error; -use crate::header::*; -use crate::packet::*; -use crate::util::*; - -type Result = std::result::Result; - -/// https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions-01#page-5 -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// |V=2|P| FMT=15 | PT=205 | length | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | SSRC of packet sender | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | SSRC of media source | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | base sequence number | packet status count | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | reference time | fb pkt. count | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | packet chunk | packet chunk | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// . . -/// . . -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | packet chunk | recv delta | recv delta | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// . . -/// . . 
-/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | recv delta | recv delta | zero padding | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - -// for packet status chunk -/// type of packet status chunk -#[derive(Default, PartialEq, Eq, Debug, Clone)] -#[repr(u16)] -pub enum StatusChunkTypeTcc { - #[default] - RunLengthChunk = 0, - StatusVectorChunk = 1, -} - -/// type of packet status symbol and recv delta -#[derive(Default, PartialEq, Eq, Debug, Copy, Clone)] -#[repr(u16)] -pub enum SymbolTypeTcc { - /// https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions-01#section-3.1.1 - #[default] - PacketNotReceived = 0, - /// https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions-01#section-3.1.1 - PacketReceivedSmallDelta = 1, - /// https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions-01#section-3.1.1 - PacketReceivedLargeDelta = 2, - /// https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions-01#page-7 - /// see Example 2: "packet received, w/o recv delta" - PacketReceivedWithoutDelta = 3, -} - -/// for status vector chunk -#[derive(Default, PartialEq, Eq, Debug, Copy, Clone)] -#[repr(u16)] -pub enum SymbolSizeTypeTcc { - /// https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions-01#section-3.1.4 - #[default] - OneBit = 0, - TwoBit = 1, -} - -impl From for SymbolSizeTypeTcc { - fn from(val: u16) -> Self { - match val { - 0 => SymbolSizeTypeTcc::OneBit, - _ => SymbolSizeTypeTcc::TwoBit, - } - } -} - -impl From for StatusChunkTypeTcc { - fn from(val: u16) -> Self { - match val { - 0 => StatusChunkTypeTcc::RunLengthChunk, - _ => StatusChunkTypeTcc::StatusVectorChunk, - } - } -} - -impl From for SymbolTypeTcc { - fn from(val: u16) -> Self { - match val { - 0 => SymbolTypeTcc::PacketNotReceived, - 1 => SymbolTypeTcc::PacketReceivedSmallDelta, - 2 => SymbolTypeTcc::PacketReceivedLargeDelta, - _ => SymbolTypeTcc::PacketReceivedWithoutDelta, - } - } -} - -/// PacketStatusChunk has two kinds: -/// RunLengthChunk and StatusVectorChunk -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum PacketStatusChunk { - RunLengthChunk(RunLengthChunk), - StatusVectorChunk(StatusVectorChunk), -} - -impl MarshalSize for PacketStatusChunk { - fn marshal_size(&self) -> usize { - match self { - PacketStatusChunk::RunLengthChunk(c) => c.marshal_size(), - PacketStatusChunk::StatusVectorChunk(c) => c.marshal_size(), - } - } -} - -impl Marshal for PacketStatusChunk { - /// Marshal .. - fn marshal_to(&self, buf: &mut [u8]) -> Result { - match self { - PacketStatusChunk::RunLengthChunk(c) => c.marshal_to(buf), - PacketStatusChunk::StatusVectorChunk(c) => c.marshal_to(buf), - } - } -} - -/// RunLengthChunk T=TypeTCCRunLengthChunk -/// 0 1 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// |T| S | Run Length | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -#[derive(Debug, Clone, PartialEq, Eq, Default)] -pub struct RunLengthChunk { - /// T = TypeTCCRunLengthChunk - pub type_tcc: StatusChunkTypeTcc, - /// S: type of packet status - /// kind: TypeTCCPacketNotReceived or... - pub packet_status_symbol: SymbolTypeTcc, - /// run_length: count of S - pub run_length: u16, -} - -impl MarshalSize for RunLengthChunk { - fn marshal_size(&self) -> usize { - PACKET_STATUS_CHUNK_LENGTH - } -} - -impl Marshal for RunLengthChunk { - /// Marshal .. 
- fn marshal_to(&self, mut buf: &mut [u8]) -> Result<usize> { - // append 1 bit '0' - let mut dst = set_nbits_of_uint16(0, 1, 0, 0)?; - - // append 2 bit packet_status_symbol - dst = set_nbits_of_uint16(dst, 2, 1, self.packet_status_symbol as u16)?; - - // append 13 bit run_length - dst = set_nbits_of_uint16(dst, 13, 3, self.run_length)?; - - buf.put_u16(dst); - - Ok(PACKET_STATUS_CHUNK_LENGTH) - } -} - -impl Unmarshal for RunLengthChunk { - /// Unmarshal .. - fn unmarshal<B>(raw_packet: &mut B) -> Result<Self> - where - Self: Sized, - B: Buf, - { - let raw_packet_len = raw_packet.remaining(); - if raw_packet_len < PACKET_STATUS_CHUNK_LENGTH { - return Err(Error::PacketStatusChunkLength.into()); - } - - // record type - let type_tcc = StatusChunkTypeTcc::RunLengthChunk; - - let b0 = raw_packet.get_u8(); - let b1 = raw_packet.get_u8(); - - // get PacketStatusSymbol - let packet_status_symbol = get_nbits_from_byte(b0, 1, 2).into(); - - // get RunLength - let run_length = ((get_nbits_from_byte(b0, 3, 5) as usize) << 8) as u16 + (b1 as u16); - - Ok(RunLengthChunk { - type_tcc, - packet_status_symbol, - run_length, - }) - } -} - -/// StatusVectorChunk T=typeStatusVectorChunk -/// 0 1 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// |T|S| symbol list | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -#[derive(Debug, Clone, PartialEq, Eq, Default)] -pub struct StatusVectorChunk { - /// T = TypeTCCRunLengthChunk - pub type_tcc: StatusChunkTypeTcc, - - /// TypeTCCSymbolSizeOneBit or TypeTCCSymbolSizeTwoBit - pub symbol_size: SymbolSizeTypeTcc, - - /// when symbol_size = TypeTCCSymbolSizeOneBit, symbol_list is 14*1bit: - /// TypeTCCSymbolListPacketReceived or TypeTCCSymbolListPacketNotReceived - /// when symbol_size = TypeTCCSymbolSizeTwoBit, symbol_list is 7*2bit: - /// TypeTCCPacketNotReceived TypeTCCPacketReceivedSmallDelta TypeTCCPacketReceivedLargeDelta or typePacketReserved - pub symbol_list: Vec<SymbolTypeTcc>, -} - -impl MarshalSize for StatusVectorChunk { - fn marshal_size(&self) -> usize { - PACKET_STATUS_CHUNK_LENGTH - } -} - -impl Marshal for StatusVectorChunk { - /// Marshal .. - fn marshal_to(&self, mut buf: &mut [u8]) -> Result<usize> { - // set first bit '1' - let mut dst = set_nbits_of_uint16(0, 1, 0, 1)?; - - // set second bit symbol_size - dst = set_nbits_of_uint16(dst, 1, 1, self.symbol_size as u16)?; - - let num_of_bits = NUM_OF_BITS_OF_SYMBOL_SIZE[self.symbol_size as usize]; - // append 14 bit symbol_list - for (i, s) in self.symbol_list.iter().enumerate() { - let index = num_of_bits * (i as u16) + 2; - dst = set_nbits_of_uint16(dst, num_of_bits, index, *s as u16)?; - } - - buf.put_u16(dst); - - Ok(PACKET_STATUS_CHUNK_LENGTH) - } -} - -impl Unmarshal for StatusVectorChunk { - /// Unmarshal ..
- fn unmarshal(raw_packet: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - let raw_packet_len = raw_packet.remaining(); - if raw_packet_len < PACKET_STATUS_CHUNK_LENGTH { - return Err(Error::PacketBeforeCname.into()); - } - - let type_tcc = StatusChunkTypeTcc::StatusVectorChunk; - - let b0 = raw_packet.get_u8(); - let b1 = raw_packet.get_u8(); - - let symbol_size = get_nbits_from_byte(b0, 1, 1).into(); - - let mut symbol_list: Vec = vec![]; - match symbol_size { - SymbolSizeTypeTcc::OneBit => { - for i in 0..6u16 { - symbol_list.push(get_nbits_from_byte(b0, 2 + i, 1).into()); - } - - for i in 0..8u16 { - symbol_list.push(get_nbits_from_byte(b1, i, 1).into()) - } - } - - SymbolSizeTypeTcc::TwoBit => { - for i in 0..3u16 { - symbol_list.push(get_nbits_from_byte(b0, 2 + i * 2, 2).into()); - } - - for i in 0..4u16 { - symbol_list.push(get_nbits_from_byte(b1, i * 2, 2).into()); - } - } - } - - Ok(StatusVectorChunk { - type_tcc, - symbol_size, - symbol_list, - }) - } -} - -/// RecvDelta are represented as multiples of 250us -/// small delta is 1 byte: [0๏ผŒ63.75]ms = [0, 63750]us = [0, 255]*250us -/// big delta is 2 bytes: [-8192.0, 8191.75]ms = [-8192000, 8191750]us = [-32768, 32767]*250us -/// https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions-01#section-3.1.5 -#[derive(Debug, Clone, PartialEq, Eq, Default)] -pub struct RecvDelta { - pub type_tcc_packet: SymbolTypeTcc, - /// us - pub delta: i64, -} - -impl MarshalSize for RecvDelta { - fn marshal_size(&self) -> usize { - let delta = self.delta / TYPE_TCC_DELTA_SCALE_FACTOR; - - // small delta - if self.type_tcc_packet == SymbolTypeTcc::PacketReceivedSmallDelta - && delta >= 0 - && delta <= u8::MAX as i64 - { - return 1; - } - - // big delta - if self.type_tcc_packet == SymbolTypeTcc::PacketReceivedLargeDelta - && delta >= i16::MIN as i64 - && delta <= i16::MAX as i64 - { - return 2; - } - - 0 - } -} - -impl Marshal for RecvDelta { - /// Marshal .. - fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - let delta = self.delta / TYPE_TCC_DELTA_SCALE_FACTOR; - - // small delta - if self.type_tcc_packet == SymbolTypeTcc::PacketReceivedSmallDelta - && delta >= 0 - && delta <= u8::MAX as i64 - && buf.remaining_mut() >= 1 - { - buf.put_u8(delta as u8); - return Ok(1); - } - - // big delta - if self.type_tcc_packet == SymbolTypeTcc::PacketReceivedLargeDelta - && delta >= i16::MIN as i64 - && delta <= i16::MAX as i64 - && buf.remaining_mut() >= 2 - { - buf.put_i16(delta as i16); - return Ok(2); - } - - // overflow - Err(Error::DeltaExceedLimit.into()) - } -} - -impl Unmarshal for RecvDelta { - /// Unmarshal .. 
- fn unmarshal(raw_packet: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - let chunk_len = raw_packet.remaining(); - - // must be 1 or 2 bytes - if chunk_len != 1 && chunk_len != 2 { - return Err(Error::DeltaExceedLimit.into()); - } - - let (type_tcc_packet, delta) = if chunk_len == 1 { - ( - SymbolTypeTcc::PacketReceivedSmallDelta, - TYPE_TCC_DELTA_SCALE_FACTOR * raw_packet.get_u8() as i64, - ) - } else { - ( - SymbolTypeTcc::PacketReceivedLargeDelta, - TYPE_TCC_DELTA_SCALE_FACTOR * raw_packet.get_i16() as i64, - ) - }; - - Ok(RecvDelta { - type_tcc_packet, - delta, - }) - } -} - -/// The offset after header -const BASE_SEQUENCE_NUMBER_OFFSET: usize = 8; -/// The offset after header -const PACKET_STATUS_COUNT_OFFSET: usize = 10; -/// The offset after header -const REFERENCE_TIME_OFFSET: usize = 12; -/// The offset after header -const FB_PKT_COUNT_OFFSET: usize = 15; -/// The offset after header -const PACKET_CHUNK_OFFSET: usize = 16; -/// len of packet status chunk -const TYPE_TCC_STATUS_VECTOR_CHUNK: usize = 1; - -/// https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions-01#section-3.1.5 -pub const TYPE_TCC_DELTA_SCALE_FACTOR: i64 = 250; - -// Notice: RFC is wrong: "packet received" (0) and "packet not received" (1) -// if S == TYPE_TCCSYMBOL_SIZE_ONE_BIT, symbol list will be: TypeTCCPacketNotReceived TypeTCCPacketReceivedSmallDelta -// if S == TYPE_TCCSYMBOL_SIZE_TWO_BIT, symbol list will be same as above: - -static NUM_OF_BITS_OF_SYMBOL_SIZE: [u16; 2] = [1, 2]; - -/// len of packet status chunk -const PACKET_STATUS_CHUNK_LENGTH: usize = 2; - -/// TransportLayerCC for sender-BWE -/// https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions-01#page-5 -#[derive(Debug, Default, PartialEq, Eq, Clone)] -pub struct TransportLayerCc { - /// SSRC of sender - pub sender_ssrc: u32, - /// SSRC of the media source - pub media_ssrc: u32, - /// Transport wide sequence of rtp extension - pub base_sequence_number: u16, - /// packet_status_count - pub packet_status_count: u16, - /// reference_time - pub reference_time: u32, - /// fb_pkt_count - pub fb_pkt_count: u8, - /// packet_chunks - pub packet_chunks: Vec, - /// recv_deltas - pub recv_deltas: Vec, -} - -impl fmt::Display for TransportLayerCc { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut out = String::new(); - out += format!("TransportLayerCC:\n\tSender Ssrc {}\n", self.sender_ssrc).as_str(); - out += format!("\tMedia Ssrc {}\n", self.media_ssrc).as_str(); - out += format!("\tBase Sequence Number {}\n", self.base_sequence_number).as_str(); - out += format!("\tStatus Count {}\n", self.packet_status_count).as_str(); - out += format!("\tReference Time {}\n", self.reference_time).as_str(); - out += format!("\tFeedback Packet Count {}\n", self.fb_pkt_count).as_str(); - out += "\tpacket_chunks "; - out += "\n\trecv_deltas "; - for delta in &self.recv_deltas { - out += format!("{delta:?} ").as_str(); - } - out += "\n"; - - write!(f, "{out}") - } -} - -impl Packet for TransportLayerCc { - fn header(&self) -> Header { - Header { - padding: get_padding_size(self.raw_size()) != 0, - count: FORMAT_TCC, - packet_type: PacketType::TransportSpecificFeedback, - length: ((self.marshal_size() / 4) - 1) as u16, - } - } - - /// destination_ssrc returns an array of SSRC values that this packet refers to. 
- fn destination_ssrc(&self) -> Vec { - vec![self.media_ssrc] - } - - fn raw_size(&self) -> usize { - let mut n = HEADER_LENGTH + PACKET_CHUNK_OFFSET + self.packet_chunks.len() * 2; - for d in &self.recv_deltas { - // small delta - if d.type_tcc_packet == SymbolTypeTcc::PacketReceivedSmallDelta { - n += 1; - } else { - n += 2 - } - } - n - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } - - fn equal(&self, other: &(dyn Packet + Send + Sync)) -> bool { - other - .as_any() - .downcast_ref::() - .map_or(false, |a| self == a) - } - - fn cloned(&self) -> Box { - Box::new(self.clone()) - } -} - -impl MarshalSize for TransportLayerCc { - fn marshal_size(&self) -> usize { - let l = self.raw_size(); - // align to 32-bit boundary - l + get_padding_size(l) - } -} - -impl Marshal for TransportLayerCc { - fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - if buf.remaining_mut() < self.marshal_size() { - return Err(Error::BufferTooShort.into()); - } - - let h = self.header(); - let n = h.marshal_to(buf)?; - buf = &mut buf[n..]; - - buf.put_u32(self.sender_ssrc); - buf.put_u32(self.media_ssrc); - buf.put_u16(self.base_sequence_number); - buf.put_u16(self.packet_status_count); - - let reference_time_and_fb_pkt_count = append_nbits_to_uint32(0, 24, self.reference_time); - let reference_time_and_fb_pkt_count = - append_nbits_to_uint32(reference_time_and_fb_pkt_count, 8, self.fb_pkt_count as u32); - - buf.put_u32(reference_time_and_fb_pkt_count); - - for chunk in &self.packet_chunks { - let n = chunk.marshal_to(buf)?; - buf = &mut buf[n..]; - } - - for delta in &self.recv_deltas { - let n = delta.marshal_to(buf)?; - buf = &mut buf[n..]; - } - - if h.padding { - put_padding(buf, self.raw_size()); - } - - Ok(self.marshal_size()) - } -} - -impl Unmarshal for TransportLayerCc { - /// Unmarshal .. 
- fn unmarshal(raw_packet: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - let raw_packet_len = raw_packet.remaining(); - if raw_packet_len < (HEADER_LENGTH + SSRC_LENGTH) { - return Err(Error::PacketTooShort.into()); - } - - let h = Header::unmarshal(raw_packet)?; - - // https://tools.ietf.org/html/rfc4585#page-33 - // header's length + payload's length - let total_length = 4 * (h.length + 1) as usize; - - if total_length < HEADER_LENGTH + PACKET_CHUNK_OFFSET { - return Err(Error::PacketTooShort.into()); - } - - if raw_packet_len < total_length { - return Err(Error::PacketTooShort.into()); - } - - if h.packet_type != PacketType::TransportSpecificFeedback || h.count != FORMAT_TCC { - return Err(Error::WrongType.into()); - } - - let sender_ssrc = raw_packet.get_u32(); - let media_ssrc = raw_packet.get_u32(); - let base_sequence_number = raw_packet.get_u16(); - let packet_status_count = raw_packet.get_u16(); - - let mut buf = vec![0u8; 3]; - buf[0] = raw_packet.get_u8(); - buf[1] = raw_packet.get_u8(); - buf[2] = raw_packet.get_u8(); - let reference_time = get_24bits_from_bytes(&buf); - let fb_pkt_count = raw_packet.get_u8(); - let mut packet_chunks = vec![]; - let mut recv_deltas = vec![]; - - let mut packet_status_pos = HEADER_LENGTH + PACKET_CHUNK_OFFSET; - let mut processed_packet_num = 0u16; - while processed_packet_num < packet_status_count { - if packet_status_pos + PACKET_STATUS_CHUNK_LENGTH >= total_length { - return Err(Error::PacketTooShort.into()); - } - - let mut chunk_reader = raw_packet.copy_to_bytes(PACKET_STATUS_CHUNK_LENGTH); - let b0 = chunk_reader[0]; - - let typ = get_nbits_from_byte(b0, 0, 1); - let initial_packet_status: PacketStatusChunk; - match typ.into() { - StatusChunkTypeTcc::RunLengthChunk => { - let packet_status = RunLengthChunk::unmarshal(&mut chunk_reader)?; - - let packet_number_to_process = - (packet_status_count - processed_packet_num).min(packet_status.run_length); - - if packet_status.packet_status_symbol == SymbolTypeTcc::PacketReceivedSmallDelta - || packet_status.packet_status_symbol - == SymbolTypeTcc::PacketReceivedLargeDelta - { - let mut j = 0u16; - - while j < packet_number_to_process { - recv_deltas.push(RecvDelta { - type_tcc_packet: packet_status.packet_status_symbol, - ..Default::default() - }); - - j += 1; - } - } - - initial_packet_status = PacketStatusChunk::RunLengthChunk(packet_status); - processed_packet_num += packet_number_to_process; - } - - StatusChunkTypeTcc::StatusVectorChunk => { - let packet_status = StatusVectorChunk::unmarshal(&mut chunk_reader)?; - - match packet_status.symbol_size { - SymbolSizeTypeTcc::OneBit => { - for sym in &packet_status.symbol_list { - if *sym == SymbolTypeTcc::PacketReceivedSmallDelta { - recv_deltas.push(RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - ..Default::default() - }) - } - } - } - - SymbolSizeTypeTcc::TwoBit => { - for sym in &packet_status.symbol_list { - if *sym == SymbolTypeTcc::PacketReceivedSmallDelta - || *sym == SymbolTypeTcc::PacketReceivedLargeDelta - { - recv_deltas.push(RecvDelta { - type_tcc_packet: *sym, - ..Default::default() - }) - } - } - } - } - - processed_packet_num += packet_status.symbol_list.len() as u16; - initial_packet_status = PacketStatusChunk::StatusVectorChunk(packet_status); - } - } - - packet_status_pos += PACKET_STATUS_CHUNK_LENGTH; - packet_chunks.push(initial_packet_status); - } - - let mut recv_deltas_pos = packet_status_pos; - - for delta in &mut recv_deltas { - if recv_deltas_pos >= total_length { - return 
Err(Error::PacketTooShort.into()); - } - - if delta.type_tcc_packet == SymbolTypeTcc::PacketReceivedSmallDelta { - let mut delta_reader = raw_packet.take(1); - *delta = RecvDelta::unmarshal(&mut delta_reader)?; - recv_deltas_pos += 1; - } - - if delta.type_tcc_packet == SymbolTypeTcc::PacketReceivedLargeDelta { - let mut delta_reader = raw_packet.take(2); - *delta = RecvDelta::unmarshal(&mut delta_reader)?; - recv_deltas_pos += 2; - } - } - - if - /*h.padding &&*/ - raw_packet.has_remaining() { - raw_packet.advance(raw_packet.remaining()); - } - - Ok(TransportLayerCc { - sender_ssrc, - media_ssrc, - base_sequence_number, - packet_status_count, - reference_time, - fb_pkt_count, - packet_chunks, - recv_deltas, - }) - } -} diff --git a/rtcp/src/transport_feedbacks/transport_layer_cc/transport_layer_cc_test.rs b/rtcp/src/transport_feedbacks/transport_layer_cc/transport_layer_cc_test.rs deleted file mode 100644 index e1cc83fb9..000000000 --- a/rtcp/src/transport_feedbacks/transport_layer_cc/transport_layer_cc_test.rs +++ /dev/null @@ -1,927 +0,0 @@ -use bytes::Bytes; - -use super::*; - -#[test] -fn test_transport_layer_cc_run_length_chunk_unmarshal() -> Result<()> { - let tests = vec![ - ( - // 3.1.3 example1: https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions-01#page-7 - "example1", - Bytes::from_static(&[0, 0xDD]), - RunLengthChunk { - type_tcc: StatusChunkTypeTcc::RunLengthChunk, - packet_status_symbol: SymbolTypeTcc::PacketNotReceived, - run_length: 221, - }, - ), - ( - // 3.1.3 example2: https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions-01#page-7 - "example2", - Bytes::from_static(&[0x60, 0x18]), - RunLengthChunk { - type_tcc: StatusChunkTypeTcc::RunLengthChunk, - packet_status_symbol: SymbolTypeTcc::PacketReceivedWithoutDelta, - run_length: 24, - }, - ), - ]; - - for (name, mut data, want) in tests { - let got = RunLengthChunk::unmarshal(&mut data)?; - assert_eq!(got, want, "Unmarshal {name} : err",); - } - - Ok(()) -} - -#[test] -fn test_transport_layer_cc_run_length_chunk_marshal() -> Result<()> { - let tests = vec![ - ( - // 3.1.3 example1: https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions-01#page-7 - "example1", - RunLengthChunk { - type_tcc: StatusChunkTypeTcc::RunLengthChunk, - packet_status_symbol: SymbolTypeTcc::PacketNotReceived, - run_length: 221, - }, - Bytes::from_static(&[0, 0xDD]), - ), - ( - // 3.1.3 example2: https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions-01#page-7 - "example2", - RunLengthChunk { - type_tcc: StatusChunkTypeTcc::RunLengthChunk, - packet_status_symbol: SymbolTypeTcc::PacketReceivedWithoutDelta, - run_length: 24, - }, - Bytes::from_static(&[0x60, 0x18]), - ), - ]; - - for (name, chunk, want) in tests { - let got = chunk.marshal()?; - assert_eq!(got, want, "Marshal {name}: err",); - } - - Ok(()) -} - -#[test] -fn test_transport_layer_cc_status_vector_chunk_unmarshal() -> Result<()> { - let tests = vec![ - ( - // 3.1.4 example1: https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions-01#page-7 - "example1", - Bytes::from_static(&[0x9F, 0x1C]), - StatusVectorChunk { - type_tcc: StatusChunkTypeTcc::StatusVectorChunk, - symbol_size: SymbolSizeTypeTcc::OneBit, - symbol_list: vec![ - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - 
SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - ], - }, - ), - ( - // 3.1.4 example2: https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions-01#page-7 - "example2", - Bytes::from_static(&[0xCD, 0x50]), - StatusVectorChunk { - type_tcc: StatusChunkTypeTcc::StatusVectorChunk, - symbol_size: SymbolSizeTypeTcc::TwoBit, - symbol_list: vec![ - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketReceivedWithoutDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - ], - }, - ), - ]; - - for (name, mut data, want) in tests { - let got = StatusVectorChunk::unmarshal(&mut data)?; - assert_eq!(got, want, "Unmarshal {name} : err",); - } - - Ok(()) -} - -#[test] -fn test_transport_layer_cc_status_vector_chunk_marshal() -> Result<()> { - let tests = vec![ - ( - //3.1.4 example1: https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions-01#page-7 - "example1", - StatusVectorChunk { - type_tcc: StatusChunkTypeTcc::StatusVectorChunk, - symbol_size: SymbolSizeTypeTcc::OneBit, - symbol_list: vec![ - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - ], - }, - Bytes::from_static(&[0x9F, 0x1C]), - ), - ( - //3.1.4 example2: https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions-01#page-7 - "example2", - StatusVectorChunk { - type_tcc: StatusChunkTypeTcc::StatusVectorChunk, - symbol_size: SymbolSizeTypeTcc::TwoBit, - symbol_list: vec![ - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketReceivedWithoutDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - ], - }, - Bytes::from_static(&[0xCD, 0x50]), - ), - ]; - - for (name, chunk, want) in tests { - let got = chunk.marshal()?; - assert_eq!(got, want, "Marshal {name}: err",); - } - - Ok(()) -} - -#[test] -fn test_transport_layer_cc_recv_delta_unmarshal() -> Result<()> { - let tests = vec![ - ( - "small delta 63.75ms", - Bytes::from_static(&[0xFF]), - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - // 255 * 250 - delta: 63750, - }, - ), - ( - "big delta 8191.75ms", - Bytes::from_static(&[0x7F, 0xFF]), - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedLargeDelta, - // 32767 * 250 - delta: 8191750, - }, - ), - ( - "big delta -8192ms", - Bytes::from_static(&[0x80, 0x00]), - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedLargeDelta, - // -32768 * 250 - delta: -8192000, - }, - ), - ]; - - for (name, mut data, want) in tests { - let got = RecvDelta::unmarshal(&mut data)?; - assert_eq!(got, 
want, "Unmarshal {name} : err",); - } - - Ok(()) -} - -#[test] -fn test_transport_layer_cc_recv_delta_marshal() -> Result<()> { - let tests = vec![ - ( - "small delta 63.75ms", - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - // 255 * 250 - delta: 63750, - }, - Bytes::from_static(&[0xFF]), - ), - ( - "big delta 8191.75ms", - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedLargeDelta, - // 32767 * 250 - delta: 8191750, - }, - Bytes::from_static(&[0x7F, 0xFF]), - ), - ( - "big delta -8192ms", - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedLargeDelta, - // -32768 * 250 - delta: -8192000, - }, - Bytes::from_static(&[0x80, 0x00]), - ), - ]; - - for (name, chunk, want) in tests { - let got = chunk.marshal()?; - assert_eq!(got, want, "Marshal {name}: err",); - } - - Ok(()) -} - -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// |V=2|P| FMT=15 | PT=205 | length | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | SSRC of packet sender | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | SSRC of media source | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | base sequence number | packet status count | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | reference time | fb pkt. count | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | packet chunk | recv delta | recv delta | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// 0b10101111,0b11001101,0b00000000,0b00000101, -/// 0b11111010,0b00010111,0b11111010,0b00010111, -/// 0b01000011,0b00000011,0b00101111,0b10100000, -/// 0b00000000,0b10011001,0b00000000,0b00000001, -/// 0b00111101,0b11101000,0b00000010,0b00010111, -/// 0b00100000,0b00000001,0b10010100,0b00000001, -#[test] -fn test_transport_layer_cc_unmarshal() -> Result<()> { - let tests = vec![ - ( - "example1", - Bytes::from_static(&[ - 0xaf, 0xcd, 0x0, 0x5, 0xfa, 0x17, 0xfa, 0x17, 0x43, 0x3, 0x2f, 0xa0, 0x0, 0x99, - 0x0, 0x1, 0x3d, 0xe8, 0x2, 0x17, 0x20, 0x1, 0x94, 0x1, - ]), - TransportLayerCc { - sender_ssrc: 4195875351, - media_ssrc: 1124282272, - base_sequence_number: 153, - packet_status_count: 1, - reference_time: 4057090, - fb_pkt_count: 23, - // 0b00100000, 0b00000001 - packet_chunks: vec![PacketStatusChunk::RunLengthChunk(RunLengthChunk { - type_tcc: StatusChunkTypeTcc::RunLengthChunk, - packet_status_symbol: SymbolTypeTcc::PacketReceivedSmallDelta, - run_length: 1, - })], - // 0b10010100 - recv_deltas: vec![RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 37000, - }], - }, - ), - ( - "example2", - Bytes::from_static(&[ - 0xaf, 0xcd, 0x0, 0x6, 0xfa, 0x17, 0xfa, 0x17, 0x19, 0x3d, 0xd8, 0xbb, 0x1, 0x74, - 0x0, 0xe, 0x45, 0xb1, 0x5a, 0x40, 0xd8, 0x0, 0xf0, 0xff, 0xd0, 0x0, 0x0, 0x3, - ]), - TransportLayerCc { - sender_ssrc: 4195875351, - media_ssrc: 423483579, - base_sequence_number: 372, - packet_status_count: 14, - reference_time: 4567386, - fb_pkt_count: 64, - packet_chunks: vec![ - PacketStatusChunk::StatusVectorChunk(StatusVectorChunk { - type_tcc: StatusChunkTypeTcc::StatusVectorChunk, - symbol_size: SymbolSizeTypeTcc::TwoBit, - symbol_list: vec![ - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedLargeDelta, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - 
SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - ], - }), - PacketStatusChunk::StatusVectorChunk(StatusVectorChunk { - type_tcc: StatusChunkTypeTcc::StatusVectorChunk, - symbol_size: SymbolSizeTypeTcc::TwoBit, - symbol_list: vec![ - SymbolTypeTcc::PacketReceivedWithoutDelta, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketReceivedWithoutDelta, - SymbolTypeTcc::PacketReceivedWithoutDelta, - SymbolTypeTcc::PacketReceivedWithoutDelta, - SymbolTypeTcc::PacketReceivedWithoutDelta, - ], - }), - ], - // 0b10010100 - recv_deltas: vec![ - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 52000, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedLargeDelta, - delta: 0, - }, - ], - }, - ), - ( - "example3", - Bytes::from_static(&[ - 0x8f, 0xcd, 0x0, 0x7, 0xfa, 0x17, 0xfa, 0x17, 0x19, 0x3d, 0xd8, 0xbb, 0x1, 0x74, - 0x0, 0x6, 0x45, 0xb1, 0x5a, 0x40, 0x40, 0x2, 0x20, 0x04, 0x1f, 0xfe, 0x1f, 0x9a, - 0xd0, 0x0, 0xd0, 0x0, - ]), - TransportLayerCc { - sender_ssrc: 4195875351, - media_ssrc: 423483579, - base_sequence_number: 372, - packet_status_count: 6, - reference_time: 4567386, - fb_pkt_count: 64, - packet_chunks: vec![ - PacketStatusChunk::RunLengthChunk(RunLengthChunk { - type_tcc: StatusChunkTypeTcc::RunLengthChunk, - packet_status_symbol: SymbolTypeTcc::PacketReceivedLargeDelta, - run_length: 2, - }), - PacketStatusChunk::RunLengthChunk(RunLengthChunk { - type_tcc: StatusChunkTypeTcc::RunLengthChunk, - packet_status_symbol: SymbolTypeTcc::PacketReceivedSmallDelta, - run_length: 4, - }), - ], - recv_deltas: vec![ - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedLargeDelta, - delta: 2047500, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedLargeDelta, - delta: 2022500, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 52000, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 0, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 52000, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 0, - }, - ], - }, - ), - ( - "example4", - Bytes::from_static(&[ - 0xaf, 0xcd, 0x0, 0x7, 0xfa, 0x17, 0xfa, 0x17, 0x19, 0x3d, 0xd8, 0xbb, 0x0, 0x4, - 0x0, 0x7, 0x10, 0x63, 0x6e, 0x1, 0x20, 0x7, 0x4c, 0x24, 0x24, 0x10, 0xc, 0xc, 0x10, - 0x0, 0x0, 0x3, - ]), - TransportLayerCc { - sender_ssrc: 4195875351, - media_ssrc: 423483579, - base_sequence_number: 4, - packet_status_count: 7, - reference_time: 1074030, - fb_pkt_count: 1, - packet_chunks: vec![PacketStatusChunk::RunLengthChunk(RunLengthChunk { - type_tcc: StatusChunkTypeTcc::RunLengthChunk, - packet_status_symbol: SymbolTypeTcc::PacketReceivedSmallDelta, - run_length: 7, - })], - recv_deltas: vec![ - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 19000, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 9000, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 9000, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 4000, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 3000, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 3000, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 4000, - 
}, - ], - }, - ), - ( - "example5", - Bytes::from_static(&[ - 0xaf, 0xcd, 0x0, 0x6, 0xfa, 0x17, 0xfa, 0x17, 0x19, 0x3d, 0xd8, 0xbb, 0x0, 0x1, - 0x0, 0xe, 0x10, 0x63, 0x6d, 0x0, 0xba, 0x0, 0x10, 0xc, 0xc, 0x10, 0x0, 0x3, - ]), - TransportLayerCc { - sender_ssrc: 4195875351, - media_ssrc: 423483579, - base_sequence_number: 1, - packet_status_count: 14, - reference_time: 1074029, - fb_pkt_count: 0, - packet_chunks: vec![PacketStatusChunk::StatusVectorChunk(StatusVectorChunk { - type_tcc: StatusChunkTypeTcc::StatusVectorChunk, - symbol_size: SymbolSizeTypeTcc::OneBit, - symbol_list: vec![ - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - ], - })], - recv_deltas: vec![ - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 4000, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 3000, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 3000, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 4000, - }, - ], - }, - ), - ( - "example6", - Bytes::from_static(&[ - 0xaf, 0xcd, 0x0, 0x7, 0x9b, 0x74, 0xf6, 0x1f, 0x93, 0x71, 0xdc, 0xbc, 0x85, 0x3c, - 0x0, 0x9, 0x63, 0xf9, 0x16, 0xb3, 0xd5, 0x52, 0x0, 0x30, 0x9b, 0xaa, 0x6a, 0xaa, - 0x7b, 0x1, 0x9, 0x1, - ]), - TransportLayerCc { - sender_ssrc: 2608133663, - media_ssrc: 2473712828, - base_sequence_number: 34108, - packet_status_count: 9, - reference_time: 6551830, - fb_pkt_count: 179, - packet_chunks: vec![ - PacketStatusChunk::StatusVectorChunk(StatusVectorChunk { - type_tcc: StatusChunkTypeTcc::StatusVectorChunk, - symbol_size: SymbolSizeTypeTcc::TwoBit, - symbol_list: vec![ - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketReceivedLargeDelta, - ], - }), - PacketStatusChunk::RunLengthChunk(RunLengthChunk { - type_tcc: StatusChunkTypeTcc::RunLengthChunk, - packet_status_symbol: SymbolTypeTcc::PacketNotReceived, - run_length: 48, - }), - ], - recv_deltas: vec![ - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 38750, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 42500, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 26500, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 42500, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 30750, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedLargeDelta, - delta: 66250, - }, - ], - }, - ), - ( - "example3", - Bytes::from_static(&[ - 0x8f, 0xcd, 0x0, 0x4, 0x9a, 0xcb, 0x4, 0x42, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x0, 0x0, - ]), - TransportLayerCc { - sender_ssrc: 2596996162, - media_ssrc: 0, - base_sequence_number: 0, - packet_status_count: 0, - reference_time: 
0, - fb_pkt_count: 0, - packet_chunks: vec![], - recv_deltas: vec![], - }, - ), - ]; - - for (name, mut data, want) in tests { - let got = TransportLayerCc::unmarshal(&mut data)?; - assert!(got == want, "Unmarshal {name} : err",); - } - - Ok(()) -} - -#[test] -fn test_transport_layer_cc_marshal() -> Result<()> { - let tests = vec![ - ( - "example1", - TransportLayerCc { - sender_ssrc: 4195875351, - media_ssrc: 1124282272, - base_sequence_number: 153, - packet_status_count: 1, - reference_time: 4057090, - fb_pkt_count: 23, - // 0b00100000, 0b00000001 - packet_chunks: vec![PacketStatusChunk::RunLengthChunk(RunLengthChunk { - type_tcc: StatusChunkTypeTcc::RunLengthChunk, - packet_status_symbol: SymbolTypeTcc::PacketReceivedSmallDelta, - run_length: 1, - })], - // 0b10010100 - recv_deltas: vec![RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 37000, - }], - }, - Bytes::from_static(&[ - 0xaf, 0xcd, 0x0, 0x5, 0xfa, 0x17, 0xfa, 0x17, 0x43, 0x3, 0x2f, 0xa0, 0x0, 0x99, - 0x0, 0x1, 0x3d, 0xe8, 0x2, 0x17, 0x20, 0x1, 0x94, 0x1, - ]), - ), - ( - "example2", - TransportLayerCc { - sender_ssrc: 4195875351, - media_ssrc: 423483579, - base_sequence_number: 372, - packet_status_count: 2, - reference_time: 4567386, - fb_pkt_count: 64, - packet_chunks: vec![ - PacketStatusChunk::StatusVectorChunk(StatusVectorChunk { - type_tcc: StatusChunkTypeTcc::StatusVectorChunk, - symbol_size: SymbolSizeTypeTcc::TwoBit, - symbol_list: vec![ - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedLargeDelta, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - ], - }), - PacketStatusChunk::StatusVectorChunk(StatusVectorChunk { - type_tcc: StatusChunkTypeTcc::StatusVectorChunk, - symbol_size: SymbolSizeTypeTcc::TwoBit, - symbol_list: vec![ - SymbolTypeTcc::PacketReceivedWithoutDelta, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketReceivedWithoutDelta, - SymbolTypeTcc::PacketReceivedWithoutDelta, - SymbolTypeTcc::PacketReceivedWithoutDelta, - SymbolTypeTcc::PacketReceivedWithoutDelta, - ], - }), - ], - // 0b10010100 - recv_deltas: vec![ - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 52000, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedLargeDelta, - delta: 0, - }, - ], - }, - Bytes::from_static(&[ - 0xaf, 0xcd, 0x0, 0x6, 0xfa, 0x17, 0xfa, 0x17, 0x19, 0x3d, 0xd8, 0xbb, 0x1, 0x74, - 0x0, 0x2, 0x45, 0xb1, 0x5a, 0x40, 0xd8, 0x0, 0xf0, 0xff, 0xd0, 0x0, 0x0, 0x1, - ]), - ), - ( - "example3", - TransportLayerCc { - sender_ssrc: 4195875351, - media_ssrc: 423483579, - base_sequence_number: 372, - packet_status_count: 6, - reference_time: 4567386, - fb_pkt_count: 64, - packet_chunks: vec![ - PacketStatusChunk::RunLengthChunk(RunLengthChunk { - type_tcc: StatusChunkTypeTcc::RunLengthChunk, - packet_status_symbol: SymbolTypeTcc::PacketReceivedLargeDelta, - run_length: 2, - }), - PacketStatusChunk::RunLengthChunk(RunLengthChunk { - type_tcc: StatusChunkTypeTcc::RunLengthChunk, - packet_status_symbol: SymbolTypeTcc::PacketReceivedSmallDelta, - run_length: 4, - }), - ], - recv_deltas: vec![ - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedLargeDelta, - delta: 2047500, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedLargeDelta, - delta: 2022500, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 
52000, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 0, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 52000, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 0, - }, - ], - }, - Bytes::from_static(&[ - 0x8f, 0xcd, 0x0, 0x7, 0xfa, 0x17, 0xfa, 0x17, 0x19, 0x3d, 0xd8, 0xbb, 0x1, 0x74, - 0x0, 0x6, 0x45, 0xb1, 0x5a, 0x40, 0x40, 0x2, 0x20, 0x04, 0x1f, 0xfe, 0x1f, 0x9a, - 0xd0, 0x0, 0xd0, 0x0, - ]), - ), - ( - "example4", - TransportLayerCc { - sender_ssrc: 4195875351, - media_ssrc: 423483579, - base_sequence_number: 4, - packet_status_count: 7, - reference_time: 1074030, - fb_pkt_count: 1, - packet_chunks: vec![PacketStatusChunk::RunLengthChunk(RunLengthChunk { - type_tcc: StatusChunkTypeTcc::RunLengthChunk, - packet_status_symbol: SymbolTypeTcc::PacketReceivedSmallDelta, - run_length: 7, - })], - recv_deltas: vec![ - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 19000, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 9000, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 9000, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 4000, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 3000, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 3000, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 4000, - }, - ], - }, - Bytes::from_static(&[ - 0xaf, 0xcd, 0x0, 0x7, 0xfa, 0x17, 0xfa, 0x17, 0x19, 0x3d, 0xd8, 0xbb, 0x0, 0x4, - 0x0, 0x7, 0x10, 0x63, 0x6e, 0x1, 0x20, 0x7, 0x4c, 0x24, 0x24, 0x10, 0xc, 0xc, 0x10, - 0x0, 0x0, 0x3, - ]), - ), - ( - "example5", - TransportLayerCc { - sender_ssrc: 4195875351, - media_ssrc: 423483579, - base_sequence_number: 1, - packet_status_count: 14, - reference_time: 1074029, - fb_pkt_count: 0, - packet_chunks: vec![PacketStatusChunk::StatusVectorChunk(StatusVectorChunk { - type_tcc: StatusChunkTypeTcc::StatusVectorChunk, - symbol_size: SymbolSizeTypeTcc::OneBit, - symbol_list: vec![ - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - ], - })], - recv_deltas: vec![ - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 4000, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 3000, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 3000, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 4000, - }, - ], - }, - Bytes::from_static(&[ - 0xaf, 0xcd, 0x0, 0x6, 0xfa, 0x17, 0xfa, 0x17, 0x19, 0x3d, 0xd8, 0xbb, 0x0, 0x1, - 0x0, 0xe, 0x10, 0x63, 0x6d, 0x0, 0xba, 0x0, 0x10, 0xc, 0xc, 0x10, 0x0, 0x2, - ]), - ), - ( - "example6", - TransportLayerCc { - sender_ssrc: 4195875351, - media_ssrc: 1124282272, - base_sequence_number: 39956, - packet_status_count: 12, - reference_time: 7701536, - 
fb_pkt_count: 0, - packet_chunks: vec![PacketStatusChunk::StatusVectorChunk(StatusVectorChunk { - type_tcc: StatusChunkTypeTcc::StatusVectorChunk, - symbol_size: SymbolSizeTypeTcc::OneBit, - symbol_list: vec![ - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketReceivedSmallDelta, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - SymbolTypeTcc::PacketNotReceived, - ], - })], - recv_deltas: vec![ - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 48250, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 15750, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 14750, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 15750, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 20750, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 36000, - }, - RecvDelta { - type_tcc_packet: SymbolTypeTcc::PacketReceivedSmallDelta, - delta: 14750, - }, - ], - }, - Bytes::from_static(&[ - 0xaf, 0xcd, 0x0, 0x7, 0xfa, 0x17, 0xfa, 0x17, 0x43, 0x3, 0x2f, 0xa0, 0x9c, 0x14, - 0x0, 0xc, 0x75, 0x84, 0x20, 0x0, 0xbe, 0xc0, 0xc1, 0x3f, 0x3b, 0x3f, 0x53, 0x90, - 0x3b, 0x0, 0x0, 0x3, - ]), - ), - ]; - - for (name, chunk, want) in tests { - let got = chunk.marshal()?; - assert_eq!(got, want, "Marshal {name}: err"); - } - - Ok(()) -} diff --git a/rtcp/src/transport_feedbacks/transport_layer_nack/mod.rs b/rtcp/src/transport_feedbacks/transport_layer_nack/mod.rs deleted file mode 100644 index 05330017d..000000000 --- a/rtcp/src/transport_feedbacks/transport_layer_nack/mod.rs +++ /dev/null @@ -1,275 +0,0 @@ -#[cfg(test)] -mod transport_layer_nack_test; - -use std::any::Any; -use std::fmt; - -use bytes::{Buf, BufMut}; -use util::marshal::{Marshal, MarshalSize, Unmarshal}; - -use crate::error::Error; -use crate::header::*; -use crate::packet::*; -use crate::util::*; - -/// PacketBitmap shouldn't be used like a normal integral, -/// so it's type is masked here. Access it with PacketList(). 
-type PacketBitmap = u16; - -/// NackPair is a wire-representation of a collection of -/// Lost RTP packets -#[derive(Debug, PartialEq, Eq, Default, Clone, Copy)] -pub struct NackPair { - /// ID of lost packets - pub packet_id: u16, - /// Bitmask of following lost packets - pub lost_packets: PacketBitmap, -} - -pub struct NackIterator { - packet_id: u16, - bitfield: PacketBitmap, - has_yielded_packet_id: bool, -} - -impl Iterator for NackIterator { - type Item = u16; - - fn next(&mut self) -> Option { - if !self.has_yielded_packet_id { - self.has_yielded_packet_id = true; - - Some(self.packet_id) - } else { - let mut i = 0; - - while self.bitfield != 0 { - if (self.bitfield & (1 << i)) != 0 { - self.bitfield &= !(1 << i); - - return Some(self.packet_id.wrapping_add(i + 1)); - } - - i += 1; - } - - None - } - } -} - -impl NackPair { - pub fn new(seq: u16) -> Self { - Self { - packet_id: seq, - lost_packets: Default::default(), - } - } - - /// PacketList returns a list of Nack'd packets that's referenced by a NackPair - pub fn packet_list(&self) -> Vec { - self.into_iter().collect() - } - - pub fn range(&self, f: F) - where - F: Fn(u16) -> bool, - { - for packet_id in self.into_iter() { - if !f(packet_id) { - return; - } - } - } -} - -/// Create an iterator over all the packet sequence numbers expressed by this NACK pair. -impl IntoIterator for NackPair { - type Item = u16; - - type IntoIter = NackIterator; - - fn into_iter(self) -> Self::IntoIter { - NackIterator { - packet_id: self.packet_id, - bitfield: self.lost_packets, - has_yielded_packet_id: false, - } - } -} - -const TLN_LENGTH: usize = 2; -const NACK_OFFSET: usize = 8; - -// The TransportLayerNack packet informs the encoder about the loss of a transport packet -// IETF RFC 4585, Section 6.2.1 -// https://tools.ietf.org/html/rfc4585#section-6.2.1 -#[derive(Debug, PartialEq, Eq, Default, Clone)] -pub struct TransportLayerNack { - /// SSRC of sender - pub sender_ssrc: u32, - /// SSRC of the media source - pub media_ssrc: u32, - - pub nacks: Vec, -} - -impl fmt::Display for TransportLayerNack { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut out = format!("TransportLayerNack from {:x}\n", self.sender_ssrc); - out += format!("\tMedia Ssrc {:x}\n", self.media_ssrc).as_str(); - out += "\tID\tLostPackets\n"; - for nack in &self.nacks { - out += format!("\t{}\t{:b}\n", nack.packet_id, nack.lost_packets).as_str(); - } - write!(f, "{out}") - } -} - -impl Packet for TransportLayerNack { - /// returns the Header associated with this packet. - fn header(&self) -> Header { - Header { - padding: get_padding_size(self.raw_size()) != 0, - count: FORMAT_TLN, - packet_type: PacketType::TransportSpecificFeedback, - length: ((self.marshal_size() / 4) - 1) as u16, - } - } - - /// destination_ssrc returns an array of SSRC values that this packet refers to. - fn destination_ssrc(&self) -> Vec { - vec![self.media_ssrc] - } - - fn raw_size(&self) -> usize { - HEADER_LENGTH + NACK_OFFSET + self.nacks.len() * 4 - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } - - fn equal(&self, other: &(dyn Packet + Send + Sync)) -> bool { - other - .as_any() - .downcast_ref::() - .map_or(false, |a| self == a) - } - - fn cloned(&self) -> Box { - Box::new(self.clone()) - } -} - -impl MarshalSize for TransportLayerNack { - fn marshal_size(&self) -> usize { - let l = self.raw_size(); - // align to 32-bit boundary - l + get_padding_size(l) - } -} - -impl Marshal for TransportLayerNack { - /// Marshal encodes the packet in binary. 
- fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - if self.nacks.len() + TLN_LENGTH > u8::MAX as usize { - return Err(Error::TooManyReports.into()); - } - if buf.remaining_mut() < self.marshal_size() { - return Err(Error::BufferTooShort.into()); - } - - let h = self.header(); - let n = h.marshal_to(buf)?; - buf = &mut buf[n..]; - - buf.put_u32(self.sender_ssrc); - buf.put_u32(self.media_ssrc); - - for i in 0..self.nacks.len() { - buf.put_u16(self.nacks[i].packet_id); - buf.put_u16(self.nacks[i].lost_packets); - } - - if h.padding { - put_padding(buf, self.raw_size()); - } - - Ok(self.marshal_size()) - } -} - -impl Unmarshal for TransportLayerNack { - /// Unmarshal decodes the ReceptionReport from binary - fn unmarshal(raw_packet: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - let raw_packet_len = raw_packet.remaining(); - if raw_packet_len < (HEADER_LENGTH + SSRC_LENGTH) { - return Err(Error::PacketTooShort.into()); - } - - let h = Header::unmarshal(raw_packet)?; - - if raw_packet_len < (HEADER_LENGTH + (4 * h.length) as usize) { - return Err(Error::PacketTooShort.into()); - } - - if h.packet_type != PacketType::TransportSpecificFeedback || h.count != FORMAT_TLN { - return Err(Error::WrongType.into()); - } - - let sender_ssrc = raw_packet.get_u32(); - let media_ssrc = raw_packet.get_u32(); - - let mut nacks = vec![]; - for _i in 0..(h.length as i32 - NACK_OFFSET as i32 / 4) { - nacks.push(NackPair { - packet_id: raw_packet.get_u16(), - lost_packets: raw_packet.get_u16(), - }); - } - - if - /*h.padding &&*/ - raw_packet.has_remaining() { - raw_packet.advance(raw_packet.remaining()); - } - - Ok(TransportLayerNack { - sender_ssrc, - media_ssrc, - nacks, - }) - } -} - -pub fn nack_pairs_from_sequence_numbers(seq_nos: &[u16]) -> Vec { - if seq_nos.is_empty() { - return vec![]; - } - - let mut nack_pair = NackPair::new(seq_nos[0]); - let mut pairs = vec![]; - - for &seq in seq_nos.iter().skip(1) { - if seq == nack_pair.packet_id { - continue; - } - if seq <= nack_pair.packet_id || seq > nack_pair.packet_id.saturating_add(16) { - pairs.push(nack_pair); - nack_pair = NackPair::new(seq); - continue; - } - - // Subtraction here is safe because the above checks that seqnum > nack_pair.packet_id. 
- nack_pair.lost_packets |= 1 << (seq - nack_pair.packet_id - 1); - } - - pairs.push(nack_pair); - - pairs -} diff --git a/rtcp/src/transport_feedbacks/transport_layer_nack/transport_layer_nack_test.rs b/rtcp/src/transport_feedbacks/transport_layer_nack/transport_layer_nack_test.rs deleted file mode 100644 index 40d833b86..000000000 --- a/rtcp/src/transport_feedbacks/transport_layer_nack/transport_layer_nack_test.rs +++ /dev/null @@ -1,361 +0,0 @@ -use std::sync::{Arc, Mutex}; - -use bytes::Bytes; - -use super::*; - -#[test] -fn test_transport_layer_nack_unmarshal() { - let tests = vec![ - ( - "valid", - Bytes::from_static(&[ - // TransportLayerNack - 0x81, 0xcd, 0x0, 0x3, // sender=0x902f9e2e - 0x90, 0x2f, 0x9e, 0x2e, // media=0x902f9e2e - 0x90, 0x2f, 0x9e, 0x2e, // nack 0xAAAA, 0x5555 - 0xaa, 0xaa, 0x55, 0x55, - ]), - TransportLayerNack { - sender_ssrc: 0x902f9e2e, - media_ssrc: 0x902f9e2e, - nacks: vec![NackPair { - packet_id: 0xaaaa, - lost_packets: 0x5555, - }], - }, - None, - ), - ( - "short report", - Bytes::from_static(&[ - 0x81, 0xcd, 0x0, 0x2, // ssrc=0x902f9e2e - 0x90, 0x2f, 0x9e, 0x2e, - // report ends early - ]), - TransportLayerNack::default(), - Some(Error::PacketTooShort), - ), - ( - "wrong type", - Bytes::from_static(&[ - // v=2, p=0, count=1, SR, len=7 - 0x81, 0xc8, 0x0, 0x7, // ssrc=0x902f9e2e - 0x90, 0x2f, 0x9e, 0x2e, // ssrc=0xbc5e9a40 - 0xbc, 0x5e, 0x9a, 0x40, // fracLost=0, totalLost=0 - 0x0, 0x0, 0x0, 0x0, // lastSeq=0x46e1 - 0x0, 0x0, 0x46, 0xe1, // jitter=273 - 0x0, 0x0, 0x1, 0x11, // lsr=0x9f36432 - 0x9, 0xf3, 0x64, 0x32, // delay=150137 - 0x0, 0x2, 0x4a, 0x79, - ]), - TransportLayerNack::default(), - Some(Error::WrongType), - ), - ( - "nil", - Bytes::from_static(&[]), - TransportLayerNack::default(), - Some(Error::PacketTooShort), - ), - ]; - - for (name, mut data, want, want_error) in tests { - let got = TransportLayerNack::unmarshal(&mut data); - - assert_eq!( - got.is_err(), - want_error.is_some(), - "Unmarshal {name} rr: err = {got:?}, want {want_error:?}" - ); - - if let Some(err) = want_error { - let got_err = got.err().unwrap(); - assert_eq!( - err, got_err, - "Unmarshal {name} rr: err = {got_err:?}, want {err:?}", - ); - } else { - let actual = got.unwrap(); - assert_eq!( - actual, want, - "Unmarshal {name} rr: got {actual:?}, want {want:?}" - ); - } - } -} - -#[test] -fn test_transport_layer_nack_roundtrip() { - let tests: Vec<(&str, TransportLayerNack, Option)> = vec![( - "valid", - TransportLayerNack { - sender_ssrc: 0x902f9e2e, - media_ssrc: 0x902f9e2e, - nacks: vec![ - NackPair { - packet_id: 1, - lost_packets: 0xAA, - }, - NackPair { - packet_id: 1034, - lost_packets: 0x05, - }, - ], - }, - None, - )]; - - for (name, want, want_error) in tests { - let got = want.marshal(); - - assert_eq!( - got.is_ok(), - want_error.is_none(), - "Marshal {name}: err = {got:?}, want {want_error:?}" - ); - - if let Some(err) = want_error { - let got_err = got.err().unwrap(); - assert_eq!( - err, got_err, - "Unmarshal {name} rr: err = {got_err:?}, want {err:?}", - ); - } else { - let mut data = got.ok().unwrap(); - let actual = TransportLayerNack::unmarshal(&mut data) - .unwrap_or_else(|_| panic!("Unmarshal {name}")); - - assert_eq!( - actual, want, - "{name} round trip: got {actual:?}, want {want:?}" - ) - } - } -} - -#[test] -fn test_nack_pair() { - let test_nack = |s: Vec, n: NackPair| { - let l = n.packet_list(); - - assert_eq!(s, l, "{n:?}: expected {s:?}, got {l:?}"); - }; - - test_nack( - vec![42], - NackPair { - packet_id: 42, - lost_packets: 0, - }, - ); - - 
test_nack( - vec![42, 43], - NackPair { - packet_id: 42, - lost_packets: 1, - }, - ); - - test_nack( - vec![42, 44], - NackPair { - packet_id: 42, - lost_packets: 2, - }, - ); - - test_nack( - vec![42, 43, 44], - NackPair { - packet_id: 42, - lost_packets: 3, - }, - ); - - test_nack( - vec![42, 42 + 16], - NackPair { - packet_id: 42, - lost_packets: 0x8000, - }, - ); - - // Wrap around - test_nack( - vec![65534, 65535, 0, 1], - NackPair { - packet_id: 65534, - lost_packets: 0b0000_0111, - }, - ); - - // Gap - test_nack( - vec![123, 125, 127, 129], - NackPair { - packet_id: 123, - lost_packets: 0b0010_1010, - }, - ); -} - -#[test] -fn test_nack_pair_range() { - let n = NackPair { - packet_id: 42, - lost_packets: 2, - }; - - let out = Arc::new(Mutex::new(vec![])); - let out1 = Arc::clone(&out); - n.range(move |s: u16| -> bool { - let out2 = Arc::clone(&out1); - let mut o = out2.lock().unwrap(); - o.push(s); - true - }); - - { - let o = out.lock().unwrap(); - assert_eq!(*o, &[42, 44]); - } - - let out = Arc::new(Mutex::new(vec![])); - let out1 = Arc::clone(&out); - n.range(move |s: u16| -> bool { - let out2 = Arc::clone(&out1); - let mut o = out2.lock().unwrap(); - o.push(s); - false - }); - - { - let o = out.lock().unwrap(); - assert_eq!(*o, &[42]); - } -} - -#[test] -fn test_transport_layer_nack_pair_generation() { - let test = vec![ - ("No Sequence Numbers", vec![], vec![]), - ( - "Single Sequence Number", - vec![100u16], - vec![NackPair { - packet_id: 100, - lost_packets: 0x0, - }], - ), - // Make sure it doesn't crash. - ( - "Single Sequence Number (duplicates)", - vec![100u16, 100], - vec![NackPair { - packet_id: 100, - lost_packets: 0x0, - }], - ), - ( - "Multiple in range, Single NACKPair", - vec![100, 101, 105, 115], - vec![NackPair { - packet_id: 100, - lost_packets: 0x4011, - }], - ), - ( - "Multiple Ranges, Multiple NACKPair", - vec![100, 117, 500, 501, 502], - vec![ - NackPair { - packet_id: 100, - lost_packets: 0, - }, - NackPair { - packet_id: 117, - lost_packets: 0, - }, - NackPair { - packet_id: 500, - lost_packets: 0x3, - }, - ], - ), - ( - "Multiple Ranges, Multiple NACKPair", - vec![100, 117, 500, 501, 502], - vec![ - NackPair { - packet_id: 100, - lost_packets: 0, - }, - NackPair { - packet_id: 117, - lost_packets: 0, - }, - NackPair { - packet_id: 500, - lost_packets: 0x3, - }, - ], - ), - ( - "Multiple Ranges, Multiple NACKPair (with rollover)", - vec![100, 117, 65534, 65535, 0, 1, 99], - vec![ - NackPair { - packet_id: 100, - lost_packets: 0, - }, - NackPair { - packet_id: 117, - lost_packets: 0, - }, - NackPair { - packet_id: 65534, - lost_packets: 1, - }, - NackPair { - packet_id: 0, - lost_packets: 1, - }, - NackPair { - packet_id: 99, - lost_packets: 0, - }, - ], - ), - ]; - - for (name, seq_numbers, expected) in test { - let actual = nack_pairs_from_sequence_numbers(&seq_numbers); - - assert_eq!( - actual, expected, - "{name} NackPair generation mismatch: got {actual:#?}, want {expected:#?}" - ) - } -} - -/// This test case reproduced a bug in the implementation -#[test] -fn test_lost_packets_is_reset_when_crossing_16_bit_boundary() { - let seq: Vec<_> = (0u16..=17u16).collect(); - assert_eq!( - nack_pairs_from_sequence_numbers(&seq), - vec![ - NackPair { - packet_id: 0, - lost_packets: 0b1111_1111_1111_1111, - }, - NackPair { - packet_id: 17, - // Was 0xffff before fixing the bug - lost_packets: 0b0000_0000_0000_0000, - } - ], - ) -} diff --git a/rtcp/src/util.rs b/rtcp/src/util.rs deleted file mode 100644 index 44c1096de..000000000 --- a/rtcp/src/util.rs +++ 
/dev/null @@ -1,118 +0,0 @@ -use bytes::BufMut; - -use crate::error::{Error, Result}; - -// returns the padding required to make the length a multiple of 4 -pub(crate) fn get_padding_size(len: usize) -> usize { - if len % 4 == 0 { - 0 - } else { - 4 - (len % 4) - } -} - -pub(crate) fn put_padding(mut buf: &mut [u8], len: usize) { - let padding_size = get_padding_size(len); - for i in 0..padding_size { - if i == padding_size - 1 { - buf.put_u8(padding_size as u8); - } else { - buf.put_u8(0); - } - } -} - -// set_nbits_of_uint16 will truncate the value to size, left-shift to start_index position and set -pub(crate) fn set_nbits_of_uint16( - src: u16, - size: u16, - start_index: u16, - mut val: u16, -) -> Result { - if start_index + size > 16 { - return Err(Error::InvalidSizeOrStartIndex); - } - - // truncate val to size bits - val &= (1 << size) - 1; - - Ok(src | (val << (16 - size - start_index))) -} - -// appendBit32 will left-shift and append n bits of val -pub(crate) fn append_nbits_to_uint32(src: u32, n: u32, val: u32) -> u32 { - (src << n) | (val & (0xFFFFFFFF >> (32 - n))) -} - -// getNBit get n bits from 1 byte, begin with a position -pub(crate) fn get_nbits_from_byte(b: u8, begin: u16, n: u16) -> u16 { - let end_shift = 8 - (begin + n); - let mask = (0xFF >> begin) & (0xFF << end_shift) as u8; - (b & mask) as u16 >> end_shift -} - -// get24BitFromBytes get 24bits from `[3]byte` slice -pub(crate) fn get_24bits_from_bytes(b: &[u8]) -> u32 { - ((b[0] as u32) << 16) + ((b[1] as u32) << 8) + (b[2] as u32) -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_get_padding() -> Result<()> { - let tests = vec![(0, 0), (1, 3), (2, 2), (3, 1), (4, 0), (100, 0), (500, 0)]; - - for (n, p) in tests { - assert_eq!( - get_padding_size(n), - p, - "Test case returned wrong value for input {n}" - ); - } - - Ok(()) - } - - #[test] - fn test_set_nbits_of_uint16() -> Result<()> { - let tests = vec![ - ("setOneBit", 0, 1, 8, 1, 128, None), - ("setStatusVectorBit", 0, 1, 0, 1, 32768, None), - ("setStatusVectorSecondBit", 32768, 1, 1, 1, 49152, None), - ( - "setStatusVectorInnerBitsAndCutValue", - 49152, - 2, - 6, - 11111, - 49920, - None, - ), - ("setRunLengthSecondTwoBit", 32768, 2, 1, 1, 40960, None), - ( - "setOneBitOutOfBounds", - 32768, - 2, - 15, - 1, - 0, - Some("invalid size or startIndex"), - ), - ]; - - for (name, source, size, index, value, result, err) in tests { - let res = set_nbits_of_uint16(source, size, index, value); - if err.is_some() { - assert!(res.is_err(), "setNBitsOfUint16 {name} : should be error"); - } else if let Ok(got) = res { - assert_eq!(got, result, "setNBitsOfUint16 {name}"); - } else { - panic!("setNBitsOfUint16 {name} :unexpected error result"); - } - } - - Ok(()) - } -} diff --git a/rtp/.gitignore b/rtp/.gitignore deleted file mode 100644 index 81561ed32..000000000 --- a/rtp/.gitignore +++ /dev/null @@ -1,11 +0,0 @@ -# Generated by Cargo -# will have compiled files and executables -/target/ -/.idea/ - -# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries -# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html -Cargo.lock - -# These are backup files generated by rustfmt -**/*.rs.bk diff --git a/rtp/CHANGELOG.md b/rtp/CHANGELOG.md deleted file mode 100644 index 52d87874c..000000000 --- a/rtp/CHANGELOG.md +++ /dev/null @@ -1,20 +0,0 @@ -# rtp changelog - -## Unreleased - -## v0.6.8 - -* Increased minimum support rust version to `1.60.0`. 
-* Adds a new generic header extensions type `rtp::extension::HeaderExtension` which allows abstracting over all known extensions as well as custom extensions. [#336](https://github.com/webrtc-rs/webrtc/pull/336) by [@k0nserv](https://github.com/k0nserv). -* Added video orientation(`urn:3gpp:video-orientation`) extension support. [#331](https://github.com/webrtc-rs/webrtc/pull/331) by [@algesten](https://github.com/algesten). -* Allow RTP extensions to be serialized and deserialized via serder. [#332](https://github.com/webrtc-rs/webrtc/pull/332) by [@algesten](https://github.com/algesten). -* Increased required `webrtc-util` version to `0.7.0`. - -## v0.6.7 - -* Bumped util dependency to `0.6.0`. - -## Prior to 0.6.7 - -Before 0.6.7 there was no changelog, previous changes are sometimes, but not always, available in the [GitHub Releases](https://github.com/webrtc-rs/rtp/releases). - diff --git a/rtp/Cargo.toml b/rtp/Cargo.toml deleted file mode 100644 index 6c54464aa..000000000 --- a/rtp/Cargo.toml +++ /dev/null @@ -1,29 +0,0 @@ -[package] -name = "rtp" -version = "0.11.0" -authors = ["Rain Liu ", "Michael Uti "] -edition = "2021" -description = "A pure Rust implementation of RTP" -license = "MIT OR Apache-2.0" -documentation = "https://docs.rs/rtp" -homepage = "https://webrtc.rs" -repository = "https://github.com/webrtc-rs/webrtc/tree/master/rtp" - -[dependencies] -util = { version = "0.9.0", path = "../util", package = "webrtc-util", default-features = false, features = ["marshal"] } - -bytes = "1" -rand = "0.8" -thiserror = "1" -serde = { version = "1", features = ["derive"] } -portable-atomic = "1.6" - -memchr = "2.1.1" - -[dev-dependencies] -chrono = "0.4.28" -criterion = "0.5" - -[[bench]] -name = "packet_bench" -harness = false diff --git a/rtp/LICENSE-APACHE b/rtp/LICENSE-APACHE deleted file mode 100644 index 16fe87b06..000000000 --- a/rtp/LICENSE-APACHE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - -1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. 
- - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - -2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - -3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - -4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - -5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - -6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - -8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - -END OF TERMS AND CONDITIONS - -APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - -Copyright [yyyy] [name of copyright owner] - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. diff --git a/rtp/LICENSE-MIT b/rtp/LICENSE-MIT deleted file mode 100644 index e11d93bef..000000000 --- a/rtp/LICENSE-MIT +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2021 WebRTC.rs - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/rtp/README.md b/rtp/README.md deleted file mode 100644 index c11c2f972..000000000 --- a/rtp/README.md +++ /dev/null @@ -1,30 +0,0 @@ -

- WebRTC.rs -
-

-

- [badge images: License: MIT/Apache 2.0, Discord]

-

- A pure Rust implementation of RTP. Rewrite Pion RTP in Rust -
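A minimal round-trip sketch, assuming the `Packet`/`Header` types and the `Marshal`/`Unmarshal` traits exercised by `benches/packet_bench.rs` later in this diff (inside this workspace the `webrtc-util` crate is imported under the name `util`, as its Cargo.toml shows):

```rust
use bytes::Bytes;
use rtp::header::Header;
use rtp::packet::Packet;
// `util` is the `webrtc-util` crate, renamed in this workspace's Cargo.toml.
use util::marshal::{Marshal, Unmarshal};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Build a packet with a default header and a small payload.
    let pkt = Packet {
        header: Header::default(),
        payload: Bytes::from_static(&[0x01, 0x02, 0x03]),
    };

    // Serialize to bytes, then parse the bytes back and compare.
    let raw = pkt.marshal()?;
    let parsed = Packet::unmarshal(&mut raw.clone())?;
    assert_eq!(pkt, parsed);
    Ok(())
}
```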

diff --git a/rtp/benches/packet_bench.rs b/rtp/benches/packet_bench.rs deleted file mode 100644 index 128728e73..000000000 --- a/rtp/benches/packet_bench.rs +++ /dev/null @@ -1,62 +0,0 @@ -// Silence warning on `..Default::default()` with no effect: -#![allow(clippy::needless_update)] - -use bytes::{Bytes, BytesMut}; -use criterion::{criterion_group, criterion_main, Criterion}; -use rtp::header::*; -use rtp::packet::*; -use util::marshal::{Marshal, MarshalSize, Unmarshal}; - -fn benchmark_packet(c: &mut Criterion) { - let pkt = Packet { - header: Header { - extension: true, - csrc: vec![1, 2], - extension_profile: EXTENSION_PROFILE_TWO_BYTE, - extensions: vec![ - Extension { - id: 1, - payload: Bytes::from_static(&[3, 4]), - }, - Extension { - id: 2, - payload: Bytes::from_static(&[5, 6]), - }, - ], - ..Default::default() - }, - payload: Bytes::from_static(&[0xFFu8; 15]), //vec![0x07, 0x08, 0x09, 0x0a], //MTU=1500 - ..Default::default() - }; - let raw = pkt.marshal().unwrap(); - let buf = &mut raw.clone(); - let p = Packet::unmarshal(buf).unwrap(); - if pkt != p { - panic!("marshal or unmarshal not correct: \npkt: {pkt:?} \nvs \np: {p:?}"); - } - - /////////////////////////////////////////////////////////////////////////////////////////////// - let mut buf = BytesMut::with_capacity(pkt.marshal_size()); - buf.resize(pkt.marshal_size(), 0); - c.bench_function("Benchmark MarshalTo", |b| { - b.iter(|| { - let _ = pkt.marshal_to(&mut buf).unwrap(); - }) - }); - - c.bench_function("Benchmark Marshal", |b| { - b.iter(|| { - let _ = pkt.marshal().unwrap(); - }) - }); - - c.bench_function("Benchmark Unmarshal ", |b| { - b.iter(|| { - let buf = &mut raw.clone(); - let _ = Packet::unmarshal(buf).unwrap(); - }) - }); -} - -criterion_group!(benches, benchmark_packet); -criterion_main!(benches); diff --git a/rtp/codecov.yml b/rtp/codecov.yml deleted file mode 100644 index 93bdf2942..000000000 --- a/rtp/codecov.yml +++ /dev/null @@ -1,23 +0,0 @@ -codecov: - require_ci_to_pass: yes - max_report_age: off - token: cd48e18f-3916-4a20-ba56-81354d68a5d2 - -coverage: - precision: 2 - round: down - range: 50..90 - status: - project: - default: - enabled: no - threshold: 0.2 - if_not_found: success - patch: - default: - enabled: no - if_not_found: success - changes: - default: - enabled: no - if_not_found: success diff --git a/rtp/doc/webrtc.rs.png b/rtp/doc/webrtc.rs.png deleted file mode 100644 index 7bf0dda2a..000000000 Binary files a/rtp/doc/webrtc.rs.png and /dev/null differ diff --git a/rtp/src/codecs/av1/av1_test.rs b/rtp/src/codecs/av1/av1_test.rs deleted file mode 100644 index 8734cae31..000000000 --- a/rtp/src/codecs/av1/av1_test.rs +++ /dev/null @@ -1,454 +0,0 @@ -use crate::codecs::av1::obu::{ - OBU_HAS_EXTENSION_BIT, OBU_TYPE_FRAME, OBU_TYPE_FRAME_HEADER, OBU_TYPE_METADATA, - OBU_TYPE_SEQUENCE_HEADER, OBU_TYPE_TEMPORAL_DELIMITER, OBU_TYPE_TILE_GROUP, OBU_TYPE_TILE_LIST, -}; -use crate::error::Result; - -use super::*; - -const OBU_EXTENSION_S1T1: u8 = 0b0010_1000; -const NEW_CODED_VIDEO_SEQUENCE_BIT: u8 = 0b0000_1000; - -struct Av1Obu { - header: u8, - extension: u8, - payload: Vec, -} - -impl Av1Obu { - pub fn new(obu_type: u8) -> Self { - Self { - header: obu_type << 3 | OBU_HAS_SIZE_BIT, - extension: 0, - payload: vec![], - } - } - - pub fn with_extension(mut self, extension: u8) -> Self { - self.extension = extension; - self.header |= OBU_HAS_EXTENSION_BIT; - self - } - - pub fn without_size(mut self) -> Self { - self.header &= !OBU_HAS_SIZE_BIT; - self - } - - pub fn with_payload(mut self, 
payload: Vec) -> Self { - self.payload = payload; - self - } -} - -fn build_av1_frame(obus: &Vec) -> Bytes { - let mut raw = vec![]; - for obu in obus { - raw.push(obu.header); - if obu.header & OBU_HAS_EXTENSION_BIT != 0 { - raw.push(obu.extension); - } - if obu.header & OBU_HAS_SIZE_BIT != 0 { - // write size in leb128 format. - let mut payload_size = obu.payload.len(); - while payload_size >= 0b1000_0000 { - raw.push(0b1000_0000 | (payload_size & 0b0111_1111) as u8); - payload_size >>= 7; - } - raw.push(payload_size as u8); - } - raw.extend_from_slice(&obu.payload); - } - Bytes::from(raw) -} - -#[test] -fn test_packetize_one_obu_without_size_and_extension() -> Result<()> { - let frame = build_av1_frame(&vec![Av1Obu::new(OBU_TYPE_FRAME) - .without_size() - .with_payload(vec![1, 2, 3, 4, 5, 6, 7])]); - let mut payloader = Av1Payloader {}; - assert_eq!( - payloader.payload(1200, &frame)?, - vec![vec![ - 0b0001_0000, // aggregation header - OBU_TYPE_FRAME << 3, // header - 1, - 2, - 3, - 4, - 5, - 6, - 7 - ]] - ); - Ok(()) -} - -#[test] -fn test_packetize_one_obu_without_size_with_extension() -> Result<()> { - let frame = build_av1_frame(&vec![Av1Obu::new(OBU_TYPE_FRAME) - .without_size() - .with_extension(OBU_EXTENSION_S1T1) - .with_payload(vec![2, 3, 4, 5, 6, 7])]); - let mut payloader = Av1Payloader {}; - assert_eq!( - payloader.payload(1200, &frame)?, - vec![vec![ - 0b0001_0000, // aggregation header - OBU_TYPE_FRAME << 3 | OBU_HAS_EXTENSION_BIT, // header - OBU_EXTENSION_S1T1, // extension header - 2, - 3, - 4, - 5, - 6, - 7 - ]] - ); - Ok(()) -} - -#[test] -fn removes_obu_size_field_without_extension() -> Result<()> { - let frame = build_av1_frame(&vec![ - Av1Obu::new(OBU_TYPE_FRAME).with_payload(vec![11, 12, 13, 14, 15, 16, 17]) - ]); - let mut payloader = Av1Payloader {}; - assert_eq!( - payloader.payload(1200, &frame)?, - vec![vec![ - 0b0001_0000, // aggregation header - OBU_TYPE_FRAME << 3, // header - 11, - 12, - 13, - 14, - 15, - 16, - 17 - ]] - ); - Ok(()) -} - -#[test] -fn removes_obu_size_field_with_extension() -> Result<()> { - let frame = build_av1_frame(&vec![Av1Obu::new(OBU_TYPE_FRAME) - .with_extension(OBU_EXTENSION_S1T1) - .with_payload(vec![1, 2, 3, 4, 5, 6, 7])]); - let mut payloader = Av1Payloader {}; - assert_eq!( - payloader.payload(1200, &frame)?, - vec![vec![ - 0b0001_0000, // aggregation header - OBU_TYPE_FRAME << 3 | OBU_HAS_EXTENSION_BIT, // header - OBU_EXTENSION_S1T1, // extension header - 1, - 2, - 3, - 4, - 5, - 6, - 7 - ]] - ); - Ok(()) -} - -#[test] -fn test_omits_size_for_last_obu_when_three_obus_fits_into_the_packet() -> Result<()> { - let frame = build_av1_frame(&vec![ - Av1Obu::new(OBU_TYPE_SEQUENCE_HEADER).with_payload(vec![1, 2, 3, 4, 5, 6]), - Av1Obu::new(OBU_TYPE_METADATA).with_payload(vec![11, 12, 13, 14]), - Av1Obu::new(OBU_TYPE_FRAME).with_payload(vec![21, 22, 23, 24, 25, 26]), - ]); - let mut payloader = Av1Payloader {}; - assert_eq!( - payloader.payload(1200, &frame)?, - vec![vec![ - 0b0011_1000, // aggregation header - 7, // size of the first OBU - OBU_TYPE_SEQUENCE_HEADER << 3, // header of the first OBU - 1, - 2, - 3, - 4, - 5, - 6, - 5, // size of the second OBU - OBU_TYPE_METADATA << 3, // header of the second OBU - 11, - 12, - 13, - 14, - OBU_TYPE_FRAME << 3, // header of the third OBU - 21, - 22, - 23, - 24, - 25, - 26, - ]] - ); - Ok(()) -} - -#[test] -fn test_use_size_for_all_obus_when_four_obus_fits_into_the_packet() -> Result<()> { - let frame = build_av1_frame(&vec![ - Av1Obu::new(OBU_TYPE_SEQUENCE_HEADER).with_payload(vec![1, 2, 
3, 4, 5, 6]), - Av1Obu::new(OBU_TYPE_METADATA).with_payload(vec![11, 12, 13, 14]), - Av1Obu::new(OBU_TYPE_FRAME).with_payload(vec![21, 22, 23]), - Av1Obu::new(OBU_TYPE_TILE_GROUP).with_payload(vec![31, 32, 33, 34, 35, 36]), - ]); - let mut payloader = Av1Payloader {}; - assert_eq!( - payloader.payload(1200, &frame)?, - vec![vec![ - 0b0000_1000, // aggregation header - 7, // size of the first OBU - OBU_TYPE_SEQUENCE_HEADER << 3, // header of the first OBU - 1, - 2, - 3, - 4, - 5, - 6, - 5, // size of the second OBU - OBU_TYPE_METADATA << 3, // header of the second OBU - 11, - 12, - 13, - 14, - 4, // size of the third OBU - OBU_TYPE_FRAME << 3, // header of the third OBU - 21, - 22, - 23, - 7, // size of the fourth OBU - OBU_TYPE_TILE_GROUP << 3, // header of the fourth OBU - 31, - 32, - 33, - 34, - 35, - 36 - ]] - ); - Ok(()) -} - -#[test] -fn test_discards_temporal_delimiter_and_tile_list_obu() -> Result<()> { - let frame = build_av1_frame(&vec![ - Av1Obu::new(OBU_TYPE_TEMPORAL_DELIMITER), - Av1Obu::new(OBU_TYPE_METADATA), - Av1Obu::new(OBU_TYPE_TILE_LIST).with_payload(vec![1, 2, 3, 4, 5, 6]), - Av1Obu::new(OBU_TYPE_FRAME_HEADER).with_payload(vec![21, 22, 23]), - Av1Obu::new(OBU_TYPE_TILE_GROUP).with_payload(vec![31, 32, 33, 34, 35, 36]), - ]); - let mut payloader = Av1Payloader {}; - assert_eq!( - payloader.payload(1200, &frame)?, - vec![vec![ - 0b0011_0000, // aggregation header - 1, // size of the first OBU - OBU_TYPE_METADATA << 3, // header of the first OBU - 4, // size of the second OBU - OBU_TYPE_FRAME_HEADER << 3, // header of the second OBU - 21, - 22, - 23, - OBU_TYPE_TILE_GROUP << 3, // header of the fourth OBU - 31, - 32, - 33, - 34, - 35, - 36 - ]] - ); - Ok(()) -} - -#[test] -fn test_split_two_obus_into_two_packet_force_split_obu_header() -> Result<()> { - let frame = build_av1_frame(&vec![ - Av1Obu::new(OBU_TYPE_FRAME_HEADER) - .with_extension(OBU_EXTENSION_S1T1) - .with_payload(vec![21]), - Av1Obu::new(OBU_TYPE_TILE_GROUP) - .with_extension(OBU_EXTENSION_S1T1) - .with_payload(vec![11, 12, 13, 14]), - ]); - let mut payloader = Av1Payloader {}; - - // Craft expected payloads so that there is only one way to split original - // frame into two packets. 
- assert_eq!( - payloader.payload(6, &frame)?, - vec![ - vec![ - 0b0110_0000, // aggregation header - 3, // size of the first OBU - OBU_TYPE_FRAME_HEADER << 3 | OBU_HAS_EXTENSION_BIT, // header of the first OBU - OBU_EXTENSION_S1T1, // extension header - 21, - OBU_TYPE_TILE_GROUP << 3 | OBU_HAS_EXTENSION_BIT, // header of the second OBU - ], - vec![ - 0b1001_0000, // aggregation header - OBU_EXTENSION_S1T1, - 11, - 12, - 13, - 14 - ] - ] - ); - Ok(()) -} - -#[test] -fn test_sets_n_bit_at_the_first_packet_of_a_key_frame_with_sequence_header() -> Result<()> { - let frame = build_av1_frame(&vec![ - Av1Obu::new(OBU_TYPE_SEQUENCE_HEADER).with_payload(vec![1, 2, 3, 4, 5, 6, 7]) - ]); - let mut payloader = Av1Payloader {}; - let result = payloader.payload(6, &frame)?; - assert_eq!(result.len(), 2); - assert_eq!( - result[0][0] & NEW_CODED_VIDEO_SEQUENCE_BIT, - NEW_CODED_VIDEO_SEQUENCE_BIT - ); - assert_eq!(result[1][0] & NEW_CODED_VIDEO_SEQUENCE_BIT, 0); - Ok(()) -} - -#[test] -fn test_doesnt_set_n_bit_at_the_packets_of_a_key_frame_without_sequence_header() -> Result<()> { - let frame = build_av1_frame(&vec![ - Av1Obu::new(OBU_TYPE_FRAME).with_payload(vec![1, 2, 3, 4, 5, 6, 7]) - ]); - let mut payloader = Av1Payloader {}; - let result = payloader.payload(6, &frame)?; - assert_eq!(result.len(), 2); - assert_eq!(result[0][0] & NEW_CODED_VIDEO_SEQUENCE_BIT, 0); - assert_eq!(result[1][0] & NEW_CODED_VIDEO_SEQUENCE_BIT, 0); - Ok(()) -} - -#[test] -fn test_doesnt_set_n_bit_at_the_packets_of_a_delta_frame() -> Result<()> { - // TODO: implement delta frame detection. - Ok(()) -} - -#[test] -fn test_split_single_obu_into_two_packets() -> Result<()> { - let frame = build_av1_frame(&vec![ - Av1Obu::new(OBU_TYPE_FRAME).with_payload(vec![11, 12, 13, 14, 15, 16, 17, 18, 19]) - ]); - let mut payloader = Av1Payloader {}; - // let result = payloader.payload(8, &frame)?; - // println!("{:?}", result[0].to_vec()); - // println!("{:?}", result[1].to_vec()); - assert_eq!( - payloader.payload(8, &frame)?, - vec![ - vec![ - 0b0101_0000, // aggregation header - OBU_TYPE_FRAME << 3, // header - 11, - 12, - 13, - 14, - 15, - 16 - ], - vec![ - 0b1001_0000, // aggregation header - 17, - 18, - 19 - ], - ] - ); - - Ok(()) -} - -#[test] -fn test_split_single_obu_into_many_packets() -> Result<()> { - let frame = build_av1_frame(&vec![ - Av1Obu::new(OBU_TYPE_FRAME).with_payload(vec![27; 1200]) - ]); - let mut payloader = Av1Payloader {}; - let result = payloader.payload(100, &frame)?; - assert_eq!(result.len(), 13); - assert_eq!(result[0], { - let mut ret = vec![ - 0b0101_0000, // aggregation header - OBU_TYPE_FRAME << 3, // header - ]; - ret.extend(vec![27; 98]); - ret - }); - for packet in result.iter().take(12).skip(1) { - assert_eq!(packet.to_vec(), { - let mut ret = vec![ - 0b1101_0000, // aggregation header - ]; - ret.extend(vec![27; 99]); - ret - }); - } - assert_eq!(result[12], { - let mut ret = vec![ - 0b1001_0000, // aggregation header - ]; - ret.extend(vec![27; 13]); - ret - }); - - Ok(()) -} - -#[test] -fn test_split_two_obus_into_two_packets() -> Result<()> { - // 2nd OBU is too large to fit into one packet, so its head would be in the - // same packet as the 1st OBU. 
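 - // (With MTU 8 here, the first packet carries the whole sequence header OBU plus the frame OBU's header and its first two payload bytes, as the expected payloads below show.)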
- let frame = build_av1_frame(&vec![ - Av1Obu::new(OBU_TYPE_SEQUENCE_HEADER).with_payload(vec![11, 12]), - Av1Obu::new(OBU_TYPE_FRAME).with_payload(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]), - ]); - let mut payloader = Av1Payloader {}; - let result = payloader.payload(8, &frame)?; - assert_eq!( - result, - vec![ - vec![ - 0b0110_1000, // aggregation header - 3, // size of the first OBU - OBU_TYPE_SEQUENCE_HEADER << 3, // header - 11, - 12, - OBU_TYPE_FRAME << 3, // header of the second OBU - 1, - 2 - ], - vec![ - 0b1001_0000, // aggregation header - 3, - 4, - 5, - 6, - 7, - 8, - 9 - ] - ] - ); - Ok(()) -} diff --git a/rtp/src/codecs/av1/leb128.rs b/rtp/src/codecs/av1/leb128.rs deleted file mode 100644 index 677524f5d..000000000 --- a/rtp/src/codecs/av1/leb128.rs +++ /dev/null @@ -1,64 +0,0 @@ -use bytes::{BufMut, Bytes, BytesMut}; - -pub fn encode_leb128(mut val: u32) -> u32 { - let mut b = 0; - loop { - b |= val & 0b_0111_1111; - val >>= 7; - if val != 0 { - b |= 0b_1000_0000; - b <<= 8; - } else { - return b; - } - } -} - -pub fn decode_leb128(mut val: u32) -> u32 { - let mut b = 0; - loop { - b |= val & 0b_0111_1111; - val >>= 8; - if val == 0 { - return b; - } - b <<= 7; - } -} - -pub fn read_leb128(bytes: &Bytes) -> (u32, usize) { - let mut encoded = 0; - for i in 0..bytes.len() { - encoded |= bytes[i] as u32; - if bytes[i] & 0b_1000_0000 == 0 { - return (decode_leb128(encoded), i + 1); - } - encoded <<= 8; - } - (0, 0) -} - -pub fn leb128_size(value: u32) -> usize { - let mut size = 0; - let mut value = value; - while value >= 0b_1000_0000 { - size += 1; - value >>= 7; - } - size + 1 -} - -pub trait BytesMutExt { - fn put_leb128(&mut self, n: u32); -} - -impl BytesMutExt for BytesMut { - fn put_leb128(&mut self, n: u32) { - let mut encoded = encode_leb128(n); - while encoded >= 0b_1000_0000 { - self.put_u8(0b_1000_0000 | (encoded & 0b_0111_1111) as u8); - encoded >>= 7; - } - self.put_u8(encoded as u8); - } -} diff --git a/rtp/src/codecs/av1/mod.rs b/rtp/src/codecs/av1/mod.rs deleted file mode 100644 index 42358c3b0..000000000 --- a/rtp/src/codecs/av1/mod.rs +++ /dev/null @@ -1,122 +0,0 @@ -use bytes::{BufMut, Bytes, BytesMut}; - -use crate::codecs::av1::leb128::BytesMutExt; -use crate::codecs::av1::obu::{obu_has_extension, parse_obus, OBU_HAS_SIZE_BIT}; -use crate::codecs::av1::packetizer::{ - get_aggregation_header, packetize, AGGREGATION_HEADER_SIZE, MAX_NUM_OBUS_TO_OMIT_SIZE, -}; -use crate::packetizer::Payloader; - -#[cfg(test)] -mod av1_test; -mod leb128; -mod obu; -mod packetizer; - -#[derive(Default, Clone, Debug)] -pub struct Av1Payloader {} - -impl Payloader for Av1Payloader { - /// Based on - /// Reference: - fn payload(&mut self, mtu: usize, payload: &Bytes) -> crate::error::Result> { - // 0 1 2 3 - // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // |Z|Y|1 0|N|-|-|-| OBU element 1 size (leb128) | | - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | - // | | - // : : - // : OBU element 1 data : - // : : - // | | - // | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // | | | - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | - // | | - // : : - // : OBU element 2 data : - // : : - // | | - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - - // Parse the payload into series of OBUs. - let obus = parse_obus(payload)?; - - // Packetize the OBUs, possibly aggregating multiple OBUs into a single packet, - // or splitting a single OBU across multiple packets. 
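 - // (For example, in av1_test.rs three small OBUs share a single packet at MTU 1200, while a single 1200-byte frame OBU is split across 13 packets at MTU 100.)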
- let packets_metadata = packetize(&obus, mtu); - - let mut payloads = vec![]; - - // Split the payload into RTP packets according to the packetization scheme. - for packet_index in 0..packets_metadata.len() { - let packet = &packets_metadata[packet_index]; - let mut obu_offset = packet.first_obu_offset; - let aggregation_header = get_aggregation_header(&obus, &packets_metadata, packet_index); - - let mut out = BytesMut::with_capacity(AGGREGATION_HEADER_SIZE + packet.packet_size); - out.put_u8(aggregation_header); - - // Store all OBU elements except the last one. - for i in 0..packet.num_obu_elements - 1 { - let obu = &obus[packet.first_obu_index + i]; - let fragment_size = obu.size - obu_offset; - out.put_leb128(fragment_size as u32); - if obu_offset == 0 { - out.put_u8(obu.header & !OBU_HAS_SIZE_BIT); - } - if obu_offset <= 1 && obu_has_extension(obu.header) { - out.put_u8(obu.extension_header); - } - let payload_offset = if obu_offset > obu.header_size() { - obu_offset - obu.header_size() - } else { - 0 - }; - let payload_size = obu.payload.len() - payload_offset; - out.put_slice( - obu.payload - .slice(payload_offset..payload_offset + payload_size) - .as_ref(), - ); - // All obus are stored from the beginning, except, may be, the first one. - obu_offset = 0; - } - - // Store the last OBU element. - let last_obu = &obus[packet.first_obu_index + packet.num_obu_elements - 1]; - let mut fragment_size = packet.last_obu_size; - if packet.num_obu_elements > MAX_NUM_OBUS_TO_OMIT_SIZE { - out.put_leb128(fragment_size as u32); - } - if obu_offset == 0 && fragment_size > 0 { - out.put_u8(last_obu.header & !OBU_HAS_SIZE_BIT); - fragment_size -= 1; - } - if obu_offset <= 1 && obu_has_extension(last_obu.header) && fragment_size > 0 { - out.put_u8(last_obu.extension_header); - fragment_size -= 1; - } - let payload_offset = if obu_offset > last_obu.header_size() { - obu_offset - last_obu.header_size() - } else { - 0 - }; - out.put_slice( - last_obu - .payload - .slice(payload_offset..payload_offset + fragment_size) - .as_ref(), - ); - - payloads.push(out.freeze()); - } - - Ok(payloads) - } - - fn clone_to(&self) -> Box { - Box::new(self.clone()) - } -} diff --git a/rtp/src/codecs/av1/obu.rs b/rtp/src/codecs/av1/obu.rs deleted file mode 100644 index f28936e60..000000000 --- a/rtp/src/codecs/av1/obu.rs +++ /dev/null @@ -1,114 +0,0 @@ -//! Based on https://chromium.googlesource.com/external/webrtc/+/4e513346ec56c829b3a6010664998469fc237b35/modules/rtp_rtcp/source/rtp_packetizer_av1.cc -//! Reference: https://aomediacodec.github.io/av1-spec/#obu-syntax - -use bytes::Bytes; - -use crate::codecs::av1::leb128::read_leb128; -use crate::error::Result; -use crate::Error::{ErrPayloadTooSmallForObuExtensionHeader, ErrPayloadTooSmallForObuPayloadSize}; - -pub const OBU_HAS_EXTENSION_BIT: u8 = 0b0000_0100; -pub const OBU_HAS_SIZE_BIT: u8 = 0b0000_0010; -pub const OBU_TYPE_MASK: u8 = 0b0111_1000; - -pub const OBU_TYPE_SEQUENCE_HEADER: u8 = 1; -pub const OBU_TYPE_TEMPORAL_DELIMITER: u8 = 2; -pub const OBU_TYPE_FRAME_HEADER: u8 = 3; -pub const OBU_TYPE_TILE_GROUP: u8 = 4; -pub const OBU_TYPE_METADATA: u8 = 5; -pub const OBU_TYPE_FRAME: u8 = 6; -pub const OBU_TYPE_TILE_LIST: u8 = 8; -pub const OBU_TYPE_PADDING: u8 = 15; - -#[derive(Debug, Clone)] -pub struct Obu { - pub header: u8, - pub extension_header: u8, - pub payload: Bytes, - /// size of the header and payload combined. 
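 - /// For example, a frame OBU with a one-byte header, no extension header, and a five-byte payload has `size == 6`; the bytes of the leb128 size field itself are not counted.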
- pub size: usize, -} - -impl Obu { - pub fn header_size(&self) -> usize { - if obu_has_extension(self.header) { - 2 - } else { - 1 - } - } -} - -/// Parses the payload into series of OBUs. -/// Reference: https://aomediacodec.github.io/av1-spec/#obu-syntax -pub fn parse_obus(payload: &Bytes) -> Result> { - let mut obus = vec![]; - let mut payload_data_remaining = payload.len() as isize; - let mut payload_data_index: usize = 0; - - while payload_data_remaining > 0 { - // Read OBU header. - let header = payload[payload_data_index]; - let has_extension = obu_has_extension(header); - let has_size = obu_has_size(header); - let obu_type = obu_type(header); - - // Read OBU extension header. - let extension_header = if has_extension { - if payload_data_remaining < 2 { - return Err(ErrPayloadTooSmallForObuExtensionHeader); - } - payload[payload_data_index + 1] - } else { - 0 - }; - let obu_header_size = if has_extension { 2 } else { 1 }; - let payload_without_header = payload.slice(payload_data_index + obu_header_size..); - - // Read OBU payload. - let obu_payload = if !has_size { - payload_without_header - } else { - if payload_without_header.is_empty() { - return Err(ErrPayloadTooSmallForObuPayloadSize); - } - let (obu_payload_size, leb128_size) = read_leb128(&payload_without_header); - payload_data_remaining -= leb128_size as isize; - payload_data_index += leb128_size; - payload_without_header.slice(leb128_size..leb128_size + obu_payload_size as usize) - }; - - let obu_size = obu_header_size + obu_payload.len(); - if !should_ignore_obu_type(obu_type) { - obus.push(Obu { - header, - extension_header, - payload: obu_payload, - size: obu_size, - }); - } - - payload_data_remaining -= obu_size as isize; - payload_data_index += obu_size; - } - - Ok(obus) -} - -pub fn obu_has_extension(header: u8) -> bool { - header & OBU_HAS_EXTENSION_BIT != 0 -} - -pub fn obu_has_size(header: u8) -> bool { - header & OBU_HAS_SIZE_BIT != 0 -} - -pub fn obu_type(header: u8) -> u8 { - (header & OBU_TYPE_MASK) >> 3 -} - -fn should_ignore_obu_type(obu_type: u8) -> bool { - obu_type == OBU_TYPE_TEMPORAL_DELIMITER - || obu_type == OBU_TYPE_TILE_LIST - || obu_type == OBU_TYPE_PADDING -} diff --git a/rtp/src/codecs/av1/packetizer.rs b/rtp/src/codecs/av1/packetizer.rs deleted file mode 100644 index 119a91f48..000000000 --- a/rtp/src/codecs/av1/packetizer.rs +++ /dev/null @@ -1,258 +0,0 @@ -//! Based on https://chromium.googlesource.com/external/webrtc/+/4e513346ec56c829b3a6010664998469fc237b35/modules/rtp_rtcp/source/rtp_packetizer_av1.cc -//! Reference: https://aomediacodec.github.io/av1-rtp-spec - -use std::cmp::min; - -use crate::codecs::av1::leb128::leb128_size; -use crate::codecs::av1::obu::{obu_type, Obu, OBU_TYPE_SEQUENCE_HEADER}; - -/// When there are 3 or less OBU (fragments) in a packet, size of the last one -/// can be omitted. -pub const MAX_NUM_OBUS_TO_OMIT_SIZE: usize = 3; -pub const AGGREGATION_HEADER_SIZE: usize = 1; - -pub struct PacketMetadata { - pub first_obu_index: usize, - pub num_obu_elements: usize, - pub first_obu_offset: usize, - pub last_obu_size: usize, - /// Total size consumed by the packet. - pub packet_size: usize, -} - -impl PacketMetadata { - fn new(first_obu_index: usize) -> Self { - Self { - first_obu_index, - num_obu_elements: 0, - first_obu_offset: 0, - last_obu_size: 0, - packet_size: 0, - } - } -} - -/// Returns the scheme for how to aggregate or split the OBUs across RTP packets. 
-/// Reference: https://aomediacodec.github.io/av1-rtp-spec/#45-payload-structure -/// https://aomediacodec.github.io/av1-rtp-spec/#5-packetization-rules -pub fn packetize(obus: &[Obu], mtu: usize) -> Vec { - if obus.is_empty() { - return vec![]; - } - // Ignore certain edge cases where packets should be very small. They are - // impractical but adds complexity to handle. - if mtu < 3 { - return vec![]; - } - - let mut packets = vec![]; - - // Aggregation header will be present in all packets. - let max_payload_size = mtu - AGGREGATION_HEADER_SIZE; - - // Assemble packets. Push to current packet as much as it can hold before - // considering next one. That would normally cause uneven distribution across - // packets, specifically last one would be generally smaller. - packets.push(PacketMetadata::new(0)); - let mut packet_remaining_bytes = max_payload_size; - - for obu_index in 0..obus.len() { - let is_last_obu = obu_index == obus.len() - 1; - let obu = &obus[obu_index]; - - // Putting |obu| into the last packet would make last obu element stored in - // that packet not last. All not last OBU elements must be prepend with the - // element length. AdditionalBytesForPreviousObuElement calculates how many - // bytes are needed to store that length. - let mut packet = packets.pop().unwrap(); - let mut previous_obu_extra_size = additional_bytes_for_previous_obu_element(&packet); - let min_required_size = if packet.num_obu_elements >= MAX_NUM_OBUS_TO_OMIT_SIZE { - 2 - } else { - 1 - }; - if packet_remaining_bytes < previous_obu_extra_size + min_required_size { - // Start a new packet. - packets.push(packet); - packet = PacketMetadata::new(obu_index); - packet_remaining_bytes = max_payload_size; - previous_obu_extra_size = 0; - } - packet.packet_size += previous_obu_extra_size; - packet_remaining_bytes -= previous_obu_extra_size; - packet.num_obu_elements += 1; - let must_write_obu_element_size = packet.num_obu_elements > MAX_NUM_OBUS_TO_OMIT_SIZE; - - // Can fit all of the obu into the packet? - let mut required_bytes = obu.size; - if must_write_obu_element_size { - required_bytes += leb128_size(obu.size as u32); - } - if required_bytes < packet_remaining_bytes { - // Insert the obu into the packet unfragmented. - packet.last_obu_size = obu.size; - packet.packet_size += required_bytes; - packet_remaining_bytes -= required_bytes; - packets.push(packet); - continue; - } - - // Fragment the obu. - let max_first_fragment_size = if must_write_obu_element_size { - max_fragment_size(packet_remaining_bytes) - } else { - packet_remaining_bytes - }; - // Because available_bytes might be different than - // packet_remaining_bytes it might happen that max_first_fragment_size >= - // obu.size. Also, since checks above verified |obu| should not be put - // completely into the |packet|, leave at least 1 byte for later packet. - let first_fragment_size = min(obu.size - 1, max_first_fragment_size); - if first_fragment_size == 0 { - // Rather than writing 0-size element at the tail of the packet, - // 'uninsert' the |obu| from the |packet|. - packet.num_obu_elements -= 1; - packet.packet_size -= previous_obu_extra_size; - } else { - packet.packet_size += first_fragment_size; - if must_write_obu_element_size { - packet.packet_size += leb128_size(first_fragment_size as u32); - } - packet.last_obu_size = first_fragment_size; - } - packets.push(packet); - - // Add middle fragments that occupy all of the packet. - // These are easy because - // - one obu per packet imply no need to store the size of the obu. 
- // - this packets are nor the first nor the last packets of the frame, so - // packet capacity is always limits.max_payload_len. - let mut obu_offset = first_fragment_size; - while obu_offset + max_payload_size < obu.size { - let mut packet = PacketMetadata::new(obu_index); - packet.num_obu_elements = 1; - packet.first_obu_offset = obu_offset; - let middle_fragment_size = max_payload_size; - packet.last_obu_size = middle_fragment_size; - packet.packet_size = middle_fragment_size; - packets.push(packet); - obu_offset += max_payload_size; - } - - // Add the last fragment of the obu. - let mut last_fragment_size = obu.size - obu_offset; - // Check for corner case where last fragment of the last obu is too large - // to fit into last packet, but may fully fit into semi-last packet. - if is_last_obu && last_fragment_size > max_payload_size { - // Split last fragments into two. - // Try to even packet sizes rather than payload sizes across the last - // two packets. - let mut semi_last_fragment_size = last_fragment_size / 2; - // But leave at least one payload byte for the last packet to avoid - // weird scenarios where size of the fragment is zero and rtp payload has - // nothing except for an aggregation header. - if semi_last_fragment_size >= last_fragment_size { - semi_last_fragment_size = last_fragment_size - 1; - } - last_fragment_size -= semi_last_fragment_size; - let mut packet = PacketMetadata::new(obu_index); - packet.first_obu_offset = obu_offset; - packet.last_obu_size = semi_last_fragment_size; - packet.packet_size = semi_last_fragment_size; - packets.push(packet); - obu_offset += semi_last_fragment_size - } - let mut last_packet = PacketMetadata::new(obu_index); - last_packet.num_obu_elements = 1; - last_packet.first_obu_offset = obu_offset; - last_packet.last_obu_size = last_fragment_size; - last_packet.packet_size = last_fragment_size; - packets.push(last_packet); - packet_remaining_bytes = max_payload_size - last_fragment_size; - } - - packets -} - -/// Returns the aggregation header for the packet. -/// Reference: https://aomediacodec.github.io/av1-rtp-spec/#44-av1-aggregation-header -pub fn get_aggregation_header(obus: &[Obu], packets: &[PacketMetadata], packet_index: usize) -> u8 { - let packet = &packets[packet_index]; - let mut header: u8 = 0; - - // Set Z flag: first obu element is continuation of the previous OBU. - let first_obu_element_is_fragment = packet.first_obu_offset > 0; - if first_obu_element_is_fragment { - header |= 1 << 7; - } - - // Set Y flag: last obu element will be continuated in the next packet. - let last_obu_offset = if packet.num_obu_elements == 1 { - packet.first_obu_offset - } else { - 0 - }; - let last_obu_is_fragment = last_obu_offset + packet.last_obu_size - < obus[packet.first_obu_index + packet.num_obu_elements - 1].size; - if last_obu_is_fragment { - header |= 1 << 6; - } - - // Set W field: number of obu elements in the packet (when not too large). - if packet.num_obu_elements <= MAX_NUM_OBUS_TO_OMIT_SIZE { - header |= (packet.num_obu_elements as u8) << 4; - } - - // Set N flag: beginning of a new coded video sequence. - // Encoder may produce key frame without a sequence header, thus double check - // incoming frame includes the sequence header. Since Temporal delimiter is - // already filtered out, sequence header should be the first obu when present. - // - // TODO: This is technically incorrect, since sequence headers may be present in delta frames. 
- // However, unlike the Chromium implementation: https://chromium.googlesource.com/external/webrtc/+/4e513346ec56c829b3a6010664998469fc237b35/modules/rtp_rtcp/source/rtp_packetizer_av1.cc#345, - // we do not have direct access to the whether this is a keyframe or a delta frame. - // Thus for now we assume that every frame that starts with a sequence header is a keyframe, - // which is not always true. This is the best we can do for now until implementing - // a proper frame type detection, perhaps by parsing the FRAME_HEADER OBUs according to - // https://aomediacodec.github.io/av1-spec/#ordering-of-obus: - // A new coded video sequence is defined to start at each temporal unit which - // satisfies both of the following conditions: - // - A sequence header OBU appears before the first frame header. - // - The first frame header has frame_type equal to KEY_FRAME, show_frame equal - // to 1, show_existing_frame equal to 0, and temporal_id equal to 0. - if packet_index == 0 && obu_type(obus.first().unwrap().header) == OBU_TYPE_SEQUENCE_HEADER { - header |= 1 << 3; - } - header -} - -/// Returns the number of additional bytes needed to store the previous OBU -/// element if an additional OBU element is added to the packet. -fn additional_bytes_for_previous_obu_element(packet: &PacketMetadata) -> usize { - if packet.packet_size == 0 || packet.num_obu_elements > MAX_NUM_OBUS_TO_OMIT_SIZE { - // Packet is still empty => no last OBU element, no need to reserve space for it. - // OR - // There are so many obu elements in the packet, all of them must be - // prepended with the length field. That imply space for the length of the - // last obu element is already reserved. - 0 - } else { - leb128_size(packet.last_obu_size as u32) - } -} - -/// Given |remaining_bytes| free bytes left in a packet, returns max size of an -/// OBU fragment that can fit into the packet. -/// i.e. MaxFragmentSize + Leb128Size(MaxFragmentSize) <= remaining_bytes. 
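 -/// For example, with 100 bytes remaining the largest fragment is 99 bytes (99 + leb128_size(99) == 100), and with 200 bytes remaining it is 198 bytes (198 + leb128_size(198) == 200).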
-fn max_fragment_size(remaining_bytes: usize) -> usize { - if remaining_bytes <= 1 { - return 0; - } - let mut i = 1; - loop { - if remaining_bytes < (1 << (7 * i)) + i { - return remaining_bytes - i; - } - i += 1; - } -} diff --git a/rtp/src/codecs/g7xx/g7xx_test.rs b/rtp/src/codecs/g7xx/g7xx_test.rs deleted file mode 100644 index a172944d1..000000000 --- a/rtp/src/codecs/g7xx/g7xx_test.rs +++ /dev/null @@ -1,50 +0,0 @@ -use super::*; - -#[test] -fn test_g7xx_payload() -> Result<()> { - let mut pck = G711Payloader::default(); - - const TEST_LEN: usize = 10000; - const TEST_MTU: usize = 1500; - - //generate random 8-bit g722 samples - let samples: Vec = (0..TEST_LEN).map(|_| rand::random::()).collect(); - - //make a copy, for payloader input - let mut samples_in = vec![0u8; TEST_LEN]; - samples_in.clone_from_slice(&samples); - let samples_in = Bytes::copy_from_slice(&samples_in); - - //split our samples into payloads - let payloads = pck.payload(TEST_MTU, &samples_in)?; - - let outcnt = ((TEST_LEN as f64) / (TEST_MTU as f64)).ceil() as usize; - assert_eq!( - outcnt, - payloads.len(), - "Generated {} payloads instead of {}", - payloads.len(), - outcnt - ); - assert_eq!(&samples, &samples_in, "Modified input samples"); - - let samples_out = payloads.concat(); - assert_eq!(&samples_out, &samples_in, "Output samples don't match"); - - let empty = Bytes::from_static(&[]); - let payload = Bytes::from_static(&[0x90, 0x90, 0x90]); - - // Positive MTU, empty payload - let result = pck.payload(1, &empty)?; - assert!(result.is_empty(), "Generated payload should be empty"); - - // 0 MTU, small payload - let result = pck.payload(0, &payload)?; - assert_eq!(result.len(), 0, "Generated payload should be empty"); - - // Positive MTU, small payload - let result = pck.payload(10, &payload)?; - assert_eq!(result.len(), 1, "Generated payload should be the 1"); - - Ok(()) -} diff --git a/rtp/src/codecs/g7xx/mod.rs b/rtp/src/codecs/g7xx/mod.rs deleted file mode 100644 index e0c69020f..000000000 --- a/rtp/src/codecs/g7xx/mod.rs +++ /dev/null @@ -1,43 +0,0 @@ -#[cfg(test)] -mod g7xx_test; - -use bytes::Bytes; - -use crate::error::Result; -use crate::packetizer::Payloader; - -/// G711Payloader payloads G711 packets -pub type G711Payloader = G7xxPayloader; -/// G722Payloader payloads G722 packets -pub type G722Payloader = G7xxPayloader; - -#[derive(Default, Debug, Copy, Clone)] -pub struct G7xxPayloader; - -impl Payloader for G7xxPayloader { - /// Payload fragments an G7xx packet across one or more byte arrays - fn payload(&mut self, mtu: usize, payload: &Bytes) -> Result> { - if payload.is_empty() || mtu == 0 { - return Ok(vec![]); - } - - let mut payload_data_remaining = payload.len(); - let mut payload_data_index = 0; - let mut payloads = Vec::with_capacity(payload_data_remaining / mtu); - while payload_data_remaining > 0 { - let current_fragment_size = std::cmp::min(mtu, payload_data_remaining); - payloads.push( - payload.slice(payload_data_index..payload_data_index + current_fragment_size), - ); - - payload_data_remaining -= current_fragment_size; - payload_data_index += current_fragment_size; - } - - Ok(payloads) - } - - fn clone_to(&self) -> Box { - Box::new(*self) - } -} diff --git a/rtp/src/codecs/h264/h264_test.rs b/rtp/src/codecs/h264/h264_test.rs deleted file mode 100644 index bc8aa4641..000000000 --- a/rtp/src/codecs/h264/h264_test.rs +++ /dev/null @@ -1,263 +0,0 @@ -// Silence warning on `for i in 0..vec.len() { โ€ฆ }`: -#![allow(clippy::needless_range_loop)] - -use super::*; - -#[test] -fn 
test_h264_payload() -> Result<()> { - let empty = Bytes::from_static(&[]); - let small_payload = Bytes::from_static(&[0x90, 0x90, 0x90]); - let multiple_payload = Bytes::from_static(&[0x00, 0x00, 0x01, 0x90, 0x00, 0x00, 0x01, 0x90]); - let large_payload = Bytes::from_static(&[ - 0x00, 0x00, 0x01, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x10, 0x11, - 0x12, 0x13, 0x14, 0x15, - ]); - let large_payload_packetized = vec![ - Bytes::from_static(&[0x1c, 0x80, 0x01, 0x02, 0x03]), - Bytes::from_static(&[0x1c, 0x00, 0x04, 0x05, 0x06]), - Bytes::from_static(&[0x1c, 0x00, 0x07, 0x08, 0x09]), - Bytes::from_static(&[0x1c, 0x00, 0x10, 0x11, 0x12]), - Bytes::from_static(&[0x1c, 0x40, 0x13, 0x14, 0x15]), - ]; - - let mut pck = H264Payloader::default(); - - // Positive MTU, empty payload - let result = pck.payload(1, &empty)?; - assert!(result.is_empty(), "Generated payload should be empty"); - - // 0 MTU, small payload - let result = pck.payload(0, &small_payload)?; - assert_eq!(result.len(), 0, "Generated payload should be empty"); - - // Positive MTU, small payload - let result = pck.payload(1, &small_payload)?; - assert_eq!(result.len(), 0, "Generated payload should be empty"); - - // Positive MTU, small payload - let result = pck.payload(5, &small_payload)?; - assert_eq!(result.len(), 1, "Generated payload should be the 1"); - assert_eq!( - result[0].len(), - small_payload.len(), - "Generated payload should be the same size as original payload size" - ); - - // Multiple NALU in a single payload - let result = pck.payload(5, &multiple_payload)?; - assert_eq!(result.len(), 2, "2 nal units should be broken out"); - for i in 0..2 { - assert_eq!( - result[i].len(), - 1, - "Payload {} of 2 is packed incorrectly", - i + 1, - ); - } - - // Large Payload split across multiple RTP Packets - let result = pck.payload(5, &large_payload)?; - assert_eq!( - result, large_payload_packetized, - "FU-A packetization failed" - ); - - // Nalu type 9 or 12 - let small_payload2 = Bytes::from_static(&[0x09, 0x00, 0x00]); - let result = pck.payload(5, &small_payload2)?; - assert_eq!(result.len(), 0, "Generated payload should be empty"); - - Ok(()) -} - -#[test] -fn test_h264_packet_unmarshal() -> Result<()> { - let single_payload = Bytes::from_static(&[0x90, 0x90, 0x90]); - let single_payload_unmarshaled = - Bytes::from_static(&[0x00, 0x00, 0x00, 0x01, 0x90, 0x90, 0x90]); - let single_payload_unmarshaled_avc = - Bytes::from_static(&[0x00, 0x00, 0x00, 0x03, 0x90, 0x90, 0x90]); - - let large_payload = Bytes::from_static(&[ - 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x10, - 0x11, 0x12, 0x13, 0x14, 0x15, - ]); - let large_payload_avc = Bytes::from_static(&[ - 0x00, 0x00, 0x00, 0x10, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x10, - 0x11, 0x12, 0x13, 0x14, 0x15, - ]); - let large_payload_packetized = vec![ - Bytes::from_static(&[0x1c, 0x80, 0x01, 0x02, 0x03]), - Bytes::from_static(&[0x1c, 0x00, 0x04, 0x05, 0x06]), - Bytes::from_static(&[0x1c, 0x00, 0x07, 0x08, 0x09]), - Bytes::from_static(&[0x1c, 0x00, 0x10, 0x11, 0x12]), - Bytes::from_static(&[0x1c, 0x40, 0x13, 0x14, 0x15]), - ]; - - let single_payload_multi_nalu = Bytes::from_static(&[ - 0x78, 0x00, 0x0f, 0x67, 0x42, 0xc0, 0x1f, 0x1a, 0x32, 0x35, 0x01, 0x40, 0x7a, 0x40, 0x3c, - 0x22, 0x11, 0xa8, 0x00, 0x05, 0x68, 0x1a, 0x34, 0xe3, 0xc8, - ]); - let single_payload_multi_nalu_unmarshaled = Bytes::from_static(&[ - 0x00, 0x00, 0x00, 0x01, 0x67, 0x42, 0xc0, 0x1f, 0x1a, 0x32, 0x35, 0x01, 0x40, 0x7a, 0x40, - 
0x3c, 0x22, 0x11, 0xa8, 0x00, 0x00, 0x00, 0x01, 0x68, 0x1a, 0x34, 0xe3, 0xc8, - ]); - let single_payload_multi_nalu_unmarshaled_avc = Bytes::from_static(&[ - 0x00, 0x00, 0x00, 0x0f, 0x67, 0x42, 0xc0, 0x1f, 0x1a, 0x32, 0x35, 0x01, 0x40, 0x7a, 0x40, - 0x3c, 0x22, 0x11, 0xa8, 0x00, 0x00, 0x00, 0x05, 0x68, 0x1a, 0x34, 0xe3, 0xc8, - ]); - - let incomplete_single_payload_multi_nalu = Bytes::from_static(&[ - 0x78, 0x00, 0x0f, 0x67, 0x42, 0xc0, 0x1f, 0x1a, 0x32, 0x35, 0x01, 0x40, 0x7a, 0x40, 0x3c, - 0x22, 0x11, - ]); - - let mut pkt = H264Packet::default(); - let mut avc_pkt = H264Packet { - is_avc: true, - ..Default::default() - }; - - let data = Bytes::from_static(&[]); - let result = pkt.depacketize(&data); - assert!(result.is_err(), "Unmarshal did not fail on nil payload"); - - let data = Bytes::from_static(&[0x00, 0x00]); - let result = pkt.depacketize(&data); - assert!( - result.is_err(), - "Unmarshal accepted a packet that is too small for a payload and header" - ); - - let data = Bytes::from_static(&[0xFF, 0x00, 0x00]); - let result = pkt.depacketize(&data); - assert!( - result.is_err(), - "Unmarshal accepted a packet with a NALU Type we don't handle" - ); - - let result = pkt.depacketize(&incomplete_single_payload_multi_nalu); - assert!( - result.is_err(), - "Unmarshal accepted a STAP-A packet with insufficient data" - ); - - let payload = pkt.depacketize(&single_payload)?; - assert_eq!( - payload, single_payload_unmarshaled, - "Unmarshalling a single payload shouldn't modify the payload" - ); - - let payload = avc_pkt.depacketize(&single_payload)?; - assert_eq!( - payload, single_payload_unmarshaled_avc, - "Unmarshalling a single payload into avc stream shouldn't modify the payload" - ); - - let mut large_payload_result = BytesMut::new(); - for p in &large_payload_packetized { - let payload = pkt.depacketize(p)?; - large_payload_result.put(&*payload.clone()); - } - assert_eq!( - large_payload_result.freeze(), - large_payload, - "Failed to unmarshal a large payload" - ); - - let mut large_payload_result_avc = BytesMut::new(); - for p in &large_payload_packetized { - let payload = avc_pkt.depacketize(p)?; - large_payload_result_avc.put(&*payload.clone()); - } - assert_eq!( - large_payload_result_avc.freeze(), - large_payload_avc, - "Failed to unmarshal a large payload into avc stream" - ); - - let payload = pkt.depacketize(&single_payload_multi_nalu)?; - assert_eq!( - payload, single_payload_multi_nalu_unmarshaled, - "Failed to unmarshal a single packet with multiple NALUs" - ); - - let payload = avc_pkt.depacketize(&single_payload_multi_nalu)?; - assert_eq!( - payload, single_payload_multi_nalu_unmarshaled_avc, - "Failed to unmarshal a single packet with multiple NALUs into avc stream" - ); - - Ok(()) -} - -#[test] -fn test_h264_partition_head_checker_is_partition_head() -> Result<()> { - let h264 = H264Packet::default(); - let empty_nalu = Bytes::from_static(&[]); - assert!( - !h264.is_partition_head(&empty_nalu), - "empty nalu must not be a partition head" - ); - - let single_nalu = Bytes::from_static(&[1, 0]); - assert!( - h264.is_partition_head(&single_nalu), - "single nalu must be a partition head" - ); - - let stapa_nalu = Bytes::from_static(&[STAPA_NALU_TYPE, 0]); - assert!( - h264.is_partition_head(&stapa_nalu), - "stapa nalu must be a partition head" - ); - - let fua_start_nalu = Bytes::from_static(&[FUA_NALU_TYPE, FU_START_BITMASK]); - assert!( - h264.is_partition_head(&fua_start_nalu), - "fua start nalu must be a partition head" - ); - - let fua_end_nalu = 
Bytes::from_static(&[FUA_NALU_TYPE, FU_END_BITMASK]); - assert!( - !h264.is_partition_head(&fua_end_nalu), - "fua end nalu must not be a partition head" - ); - - let fub_start_nalu = Bytes::from_static(&[FUB_NALU_TYPE, FU_START_BITMASK]); - assert!( - h264.is_partition_head(&fub_start_nalu), - "fub start nalu must be a partition head" - ); - - let fub_end_nalu = Bytes::from_static(&[FUB_NALU_TYPE, FU_END_BITMASK]); - assert!( - !h264.is_partition_head(&fub_end_nalu), - "fub end nalu must not be a partition head" - ); - - Ok(()) -} - -#[test] -fn test_h264_payloader_payload_sps_and_pps_handling() -> Result<()> { - let mut pck = H264Payloader::default(); - let expected = vec![ - Bytes::from_static(&[ - 0x78, 0x00, 0x03, 0x07, 0x00, 0x01, 0x00, 0x03, 0x08, 0x02, 0x03, - ]), - Bytes::from_static(&[0x05, 0x04, 0x05]), - ]; - - // When packetizing SPS and PPS are emitted with following NALU - let res = pck.payload(1500, &Bytes::from_static(&[0x07, 0x00, 0x01]))?; - assert!(res.is_empty(), "Generated payload should be empty"); - - let res = pck.payload(1500, &Bytes::from_static(&[0x08, 0x02, 0x03]))?; - assert!(res.is_empty(), "Generated payload should be empty"); - - let actual = pck.payload(1500, &Bytes::from_static(&[0x05, 0x04, 0x05]))?; - assert_eq!(actual, expected, "SPS and PPS aren't packed together"); - - Ok(()) -} diff --git a/rtp/src/codecs/h264/mod.rs b/rtp/src/codecs/h264/mod.rs deleted file mode 100644 index 104b6ddde..000000000 --- a/rtp/src/codecs/h264/mod.rs +++ /dev/null @@ -1,310 +0,0 @@ -#[cfg(test)] -mod h264_test; - -use bytes::{BufMut, Bytes, BytesMut}; - -use crate::error::{Error, Result}; -use crate::packetizer::{Depacketizer, Payloader}; - -/// H264Payloader payloads H264 packets -#[derive(Default, Debug, Clone)] -pub struct H264Payloader { - sps_nalu: Option, - pps_nalu: Option, -} - -pub const STAPA_NALU_TYPE: u8 = 24; -pub const FUA_NALU_TYPE: u8 = 28; -pub const FUB_NALU_TYPE: u8 = 29; -pub const SPS_NALU_TYPE: u8 = 7; -pub const PPS_NALU_TYPE: u8 = 8; -pub const AUD_NALU_TYPE: u8 = 9; -pub const FILLER_NALU_TYPE: u8 = 12; - -pub const FUA_HEADER_SIZE: usize = 2; -pub const STAPA_HEADER_SIZE: usize = 1; -pub const STAPA_NALU_LENGTH_SIZE: usize = 2; - -pub const NALU_TYPE_BITMASK: u8 = 0x1F; -pub const NALU_REF_IDC_BITMASK: u8 = 0x60; -pub const FU_START_BITMASK: u8 = 0x80; -pub const FU_END_BITMASK: u8 = 0x40; - -pub const OUTPUT_STAP_AHEADER: u8 = 0x78; - -pub static ANNEXB_NALUSTART_CODE: Bytes = Bytes::from_static(&[0x00, 0x00, 0x00, 0x01]); - -impl H264Payloader { - fn next_ind(nalu: &Bytes, start: usize) -> (isize, isize) { - let mut zero_count = 0; - - for (i, &b) in nalu[start..].iter().enumerate() { - if b == 0 { - zero_count += 1; - continue; - } else if b == 1 && zero_count >= 2 { - return ((start + i - zero_count) as isize, zero_count as isize + 1); - } - zero_count = 0 - } - (-1, -1) - } - - fn emit(&mut self, nalu: &Bytes, mtu: usize, payloads: &mut Vec) { - if nalu.is_empty() { - return; - } - - let nalu_type = nalu[0] & NALU_TYPE_BITMASK; - let nalu_ref_idc = nalu[0] & NALU_REF_IDC_BITMASK; - - if nalu_type == AUD_NALU_TYPE || nalu_type == FILLER_NALU_TYPE { - return; - } else if nalu_type == SPS_NALU_TYPE { - self.sps_nalu = Some(nalu.clone()); - return; - } else if nalu_type == PPS_NALU_TYPE { - self.pps_nalu = Some(nalu.clone()); - return; - } else if let (Some(sps_nalu), Some(pps_nalu)) = (&self.sps_nalu, &self.pps_nalu) { - // Pack current NALU with SPS and PPS as STAP-A - let sps_len = (sps_nalu.len() as u16).to_be_bytes(); - let pps_len = 
(pps_nalu.len() as u16).to_be_bytes(); - - let mut stap_a_nalu = Vec::with_capacity(1 + 2 + sps_nalu.len() + 2 + pps_nalu.len()); - stap_a_nalu.push(OUTPUT_STAP_AHEADER); - stap_a_nalu.extend(sps_len); - stap_a_nalu.extend_from_slice(sps_nalu); - stap_a_nalu.extend(pps_len); - stap_a_nalu.extend_from_slice(pps_nalu); - if stap_a_nalu.len() <= mtu { - payloads.push(Bytes::from(stap_a_nalu)); - } - } - - if self.sps_nalu.is_some() && self.pps_nalu.is_some() { - self.sps_nalu = None; - self.pps_nalu = None; - } - - // Single NALU - if nalu.len() <= mtu { - payloads.push(nalu.clone()); - return; - } - - // FU-A - let max_fragment_size = mtu as isize - FUA_HEADER_SIZE as isize; - - // The FU payload consists of fragments of the payload of the fragmented - // NAL unit so that if the fragmentation unit payloads of consecutive - // FUs are sequentially concatenated, the payload of the fragmented NAL - // unit can be reconstructed. The NAL unit type octet of the fragmented - // NAL unit is not included as such in the fragmentation unit payload, - // but rather the information of the NAL unit type octet of the - // fragmented NAL unit is conveyed in the F and NRI fields of the FU - // indicator octet of the fragmentation unit and in the type field of - // the FU header. An FU payload MAY have any number of octets and MAY - // be empty. - - let nalu_data = nalu; - // According to the RFC, the first octet is skipped due to redundant information - let mut nalu_data_index = 1; - let nalu_data_length = nalu.len() as isize - nalu_data_index; - let mut nalu_data_remaining = nalu_data_length; - - if std::cmp::min(max_fragment_size, nalu_data_remaining) <= 0 { - return; - } - - while nalu_data_remaining > 0 { - let current_fragment_size = std::cmp::min(max_fragment_size, nalu_data_remaining); - //out: = make([]byte, fuaHeaderSize + currentFragmentSize) - let mut out = BytesMut::with_capacity(FUA_HEADER_SIZE + current_fragment_size as usize); - // +---------------+ - // |0|1|2|3|4|5|6|7| - // +-+-+-+-+-+-+-+-+ - // |F|NRI| Type | - // +---------------+ - let b0 = FUA_NALU_TYPE | nalu_ref_idc; - out.put_u8(b0); - - // +---------------+ - //|0|1|2|3|4|5|6|7| - //+-+-+-+-+-+-+-+-+ - //|S|E|R| Type | - //+---------------+ - - let mut b1 = nalu_type; - if nalu_data_remaining == nalu_data_length { - // Set start bit - b1 |= 1 << 7; - } else if nalu_data_remaining - current_fragment_size == 0 { - // Set end bit - b1 |= 1 << 6; - } - out.put_u8(b1); - - out.put( - &nalu_data - [nalu_data_index as usize..(nalu_data_index + current_fragment_size) as usize], - ); - payloads.push(out.freeze()); - - nalu_data_remaining -= current_fragment_size; - nalu_data_index += current_fragment_size; - } - } -} - -impl Payloader for H264Payloader { - /// Payload fragments a H264 packet across one or more byte arrays - fn payload(&mut self, mtu: usize, payload: &Bytes) -> Result> { - if payload.is_empty() || mtu == 0 { - return Ok(vec![]); - } - - let mut payloads = vec![]; - - let (mut next_ind_start, mut next_ind_len) = H264Payloader::next_ind(payload, 0); - if next_ind_start == -1 { - self.emit(payload, mtu, &mut payloads); - } else { - while next_ind_start != -1 { - let prev_start = (next_ind_start + next_ind_len) as usize; - let (next_ind_start2, next_ind_len2) = H264Payloader::next_ind(payload, prev_start); - next_ind_start = next_ind_start2; - next_ind_len = next_ind_len2; - if next_ind_start != -1 { - self.emit( - &payload.slice(prev_start..next_ind_start as usize), - mtu, - &mut payloads, - ); - } else { - // Emit until end of 
stream, no end indicator found - self.emit(&payload.slice(prev_start..), mtu, &mut payloads); - } - } - } - - Ok(payloads) - } - - fn clone_to(&self) -> Box { - Box::new(self.clone()) - } -} - -/// H264Packet represents the H264 header that is stored in the payload of an RTP Packet -#[derive(PartialEq, Eq, Debug, Default, Clone)] -pub struct H264Packet { - pub is_avc: bool, - fua_buffer: Option, -} - -impl Depacketizer for H264Packet { - /// depacketize parses the passed byte slice and stores the result in the H264Packet this method is called upon - fn depacketize(&mut self, packet: &Bytes) -> Result { - if packet.len() <= 2 { - return Err(Error::ErrShortPacket); - } - - let mut payload = BytesMut::new(); - - // NALU Types - // https://tools.ietf.org/html/rfc6184#section-5.4 - let b0 = packet[0]; - let nalu_type = b0 & NALU_TYPE_BITMASK; - - match nalu_type { - 1..=23 => { - if self.is_avc { - payload.put_u32(packet.len() as u32); - } else { - payload.put(&*ANNEXB_NALUSTART_CODE); - } - payload.put(&*packet.clone()); - Ok(payload.freeze()) - } - STAPA_NALU_TYPE => { - let mut curr_offset = STAPA_HEADER_SIZE; - while curr_offset < packet.len() { - let nalu_size = - ((packet[curr_offset] as usize) << 8) | packet[curr_offset + 1] as usize; - curr_offset += STAPA_NALU_LENGTH_SIZE; - - if packet.len() < curr_offset + nalu_size { - return Err(Error::StapASizeLargerThanBuffer( - nalu_size, - packet.len() - curr_offset, - )); - } - - if self.is_avc { - payload.put_u32(nalu_size as u32); - } else { - payload.put(&*ANNEXB_NALUSTART_CODE); - } - payload.put(&*packet.slice(curr_offset..curr_offset + nalu_size)); - curr_offset += nalu_size; - } - - Ok(payload.freeze()) - } - FUA_NALU_TYPE => { - if packet.len() < FUA_HEADER_SIZE { - return Err(Error::ErrShortPacket); - } - - if self.fua_buffer.is_none() { - self.fua_buffer = Some(BytesMut::new()); - } - - if let Some(fua_buffer) = &mut self.fua_buffer { - fua_buffer.put(&*packet.slice(FUA_HEADER_SIZE..)); - } - - let b1 = packet[1]; - if b1 & FU_END_BITMASK != 0 { - let nalu_ref_idc = b0 & NALU_REF_IDC_BITMASK; - let fragmented_nalu_type = b1 & NALU_TYPE_BITMASK; - - if let Some(fua_buffer) = self.fua_buffer.take() { - if self.is_avc { - payload.put_u32((fua_buffer.len() + 1) as u32); - } else { - payload.put(&*ANNEXB_NALUSTART_CODE); - } - payload.put_u8(nalu_ref_idc | fragmented_nalu_type); - payload.put(fua_buffer); - } - - Ok(payload.freeze()) - } else { - Ok(Bytes::new()) - } - } - _ => Err(Error::NaluTypeIsNotHandled(nalu_type)), - } - } - - /// is_partition_head checks if this is the head of a packetized nalu stream. 
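// Illustrative sketch, not part of this patch: it mirrors the FU-A branch of the
// depacketizer above — fragments are buffered until the E (end) bit is seen, then
// one Annex B NALU is emitted with a header rebuilt from the NRI bits of the FU
// indicator and the type carried in the FU header (RFC 6184). The helper name
// `reassemble_fua` is hypothetical.
fn reassemble_fua(fragments: &[&[u8]]) -> Option<Vec<u8>> {
    let first = fragments.first()?;
    if first.len() < 2 || first[1] & 0x80 == 0 {
        return None; // the first fragment must carry the S (start) bit
    }
    // Rebuild the original NALU header: NRI from the FU indicator, type from the
    // low 5 bits of the FU header.
    let nalu_header = (first[0] & 0x60) | (first[1] & 0x1f);
    let mut nalu = vec![0x00, 0x00, 0x00, 0x01, nalu_header]; // Annex B start code
    for frag in fragments {
        if frag.len() < 2 {
            return None;
        }
        nalu.extend_from_slice(&frag[2..]); // skip FU indicator + FU header
    }
    Some(nalu)
}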
- fn is_partition_head(&self, payload: &Bytes) -> bool { - if payload.len() < 2 { - return false; - } - - if payload[0] & NALU_TYPE_BITMASK == FUA_NALU_TYPE - || payload[0] & NALU_TYPE_BITMASK == FUB_NALU_TYPE - { - (payload[1] & FU_START_BITMASK) != 0 - } else { - true - } - } - - fn is_partition_tail(&self, marker: bool, _payload: &Bytes) -> bool { - marker - } -} diff --git a/rtp/src/codecs/h265/h265_test.rs b/rtp/src/codecs/h265/h265_test.rs deleted file mode 100644 index 6ace5ff89..000000000 --- a/rtp/src/codecs/h265/h265_test.rs +++ /dev/null @@ -1,889 +0,0 @@ -use super::*; - -#[test] -fn test_h265_nalu_header() -> Result<()> { - #[derive(Default)] - struct TestType { - raw_header: Bytes, - - fbit: bool, - typ: u8, - layer_id: u8, - tid: u8, - - is_ap: bool, - is_fu: bool, - is_paci: bool, - } - - let tests = vec![ - // fbit - TestType { - raw_header: Bytes::from_static(&[0x80, 0x00]), - typ: 0, - layer_id: 0, - tid: 0, - fbit: true, - ..Default::default() - }, - // VPS_NUT - TestType { - raw_header: Bytes::from_static(&[0x40, 0x01]), - typ: 32, - layer_id: 0, - tid: 1, - ..Default::default() - }, - // SPS_NUT - TestType { - raw_header: Bytes::from_static(&[0x42, 0x01]), - typ: 33, - layer_id: 0, - tid: 1, - ..Default::default() - }, - // PPS_NUT - TestType { - raw_header: Bytes::from_static(&[0x44, 0x01]), - typ: 34, - layer_id: 0, - tid: 1, - ..Default::default() - }, - // PREFIX_SEI_NUT - TestType { - raw_header: Bytes::from_static(&[0x4e, 0x01]), - typ: 39, - layer_id: 0, - tid: 1, - ..Default::default() - }, - // Fragmentation Unit - TestType { - raw_header: Bytes::from_static(&[0x62, 0x01]), - typ: H265NALU_FRAGMENTATION_UNIT_TYPE, - layer_id: 0, - tid: 1, - is_fu: true, - ..Default::default() - }, - ]; - - for cur in tests { - let header = H265NALUHeader::new(cur.raw_header[0], cur.raw_header[1]); - - assert_eq!(header.f(), cur.fbit, "invalid F bit"); - assert_eq!(header.nalu_type(), cur.typ, "invalid type"); - - // For any type < 32, NAL is a VLC NAL unit. 
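// Illustrative sketch, not part of this patch: how the 16-bit H265 NAL unit header
// exercised by the vectors above breaks down — F (1 bit), Type (6 bits),
// LayerId (6 bits), TID (3 bits). The helper name is hypothetical; the deleted
// H265NALUHeader type further below applies the same masks.
fn parse_h265_nalu_header(b0: u8, b1: u8) -> (bool, u8, u8, u8) {
    let v = ((b0 as u16) << 8) | b1 as u16;
    let f = (v >> 15) != 0; // forbidden_zero_bit
    let nalu_type = ((v >> 9) & 0x3f) as u8;
    let layer_id = ((v >> 3) & 0x3f) as u8;
    let tid = (v & 0x07) as u8;
    (f, nalu_type, layer_id, tid)
}
// For the VPS_NUT case above, 0x40 0x01 parses to (false, 32, 0, 1).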
- assert_eq!( - header.is_type_vcl_unit(), - (header.nalu_type() < 32), - "invalid IsTypeVCLUnit" - ); - assert_eq!( - header.is_aggregation_packet(), - cur.is_ap, - "invalid type (aggregation packet)" - ); - assert_eq!( - header.is_fragmentation_unit(), - cur.is_fu, - "invalid type (fragmentation unit)" - ); - assert_eq!(header.is_paci_packet(), cur.is_paci, "invalid type (PACI)"); - assert_eq!(header.layer_id(), cur.layer_id, "invalid layer_id"); - assert_eq!(header.tid(), cur.tid, "invalid tid"); - } - - Ok(()) -} - -#[test] -fn test_h265_fu_header() -> Result<()> { - #[derive(Default)] - struct TestType { - header: H265FragmentationUnitHeader, - - s: bool, - e: bool, - typ: u8, - } - - let tests = vec![ - // Start | IDR_W_RADL - TestType { - header: H265FragmentationUnitHeader(0x93), - s: true, - e: false, - typ: 19, - }, - // Continuation | IDR_W_RADL - TestType { - header: H265FragmentationUnitHeader(0x13), - s: false, - e: false, - typ: 19, - }, - // End | IDR_W_RADL - TestType { - header: H265FragmentationUnitHeader(0x53), - s: false, - e: true, - typ: 19, - }, - // Start | TRAIL_R - TestType { - header: H265FragmentationUnitHeader(0x81), - s: true, - e: false, - typ: 1, - }, - // Continuation | TRAIL_R - TestType { - header: H265FragmentationUnitHeader(0x01), - s: false, - e: false, - typ: 1, - }, - // End | TRAIL_R - TestType { - header: H265FragmentationUnitHeader(0x41), - s: false, - e: true, - typ: 1, - }, - ]; - - for cur in tests { - assert_eq!(cur.header.s(), cur.s, "invalid s field"); - assert_eq!(cur.header.e(), cur.e, "invalid e field"); - assert_eq!(cur.header.fu_type(), cur.typ, "invalid FuType field"); - } - - Ok(()) -} - -#[test] -fn test_h265_single_nalunit_packet() -> Result<()> { - #[derive(Default)] - struct TestType { - raw: Bytes, - with_donl: bool, - expected_packet: Option, - expected_err: Option, - } - - let tests = vec![ - TestType { - raw: Bytes::from_static(&[]), - expected_err: Some(Error::ErrShortPacket), - ..Default::default() - }, - TestType { - raw: Bytes::from_static(&[0x62]), - expected_err: Some(Error::ErrShortPacket), - ..Default::default() - }, - TestType { - raw: Bytes::from_static(&[0x62, 0x01, 0x93]), - expected_err: Some(Error::ErrShortPacket), - ..Default::default() - }, - // FBit enabled in H265NALUHeader - TestType { - raw: Bytes::from_static(&[0x80, 0x01, 0x93, 0xaf, 0xaf, 0xaf, 0xaf]), - expected_err: Some(Error::ErrH265CorruptedPacket), - ..Default::default() - }, - // Type '49' in H265NALUHeader - TestType { - raw: Bytes::from_static(&[0x62, 0x01, 0x93, 0xaf, 0xaf, 0xaf, 0xaf]), - expected_err: Some(Error::ErrInvalidH265PacketType), - ..Default::default() - }, - // Type '50' in H265NALUHeader - TestType { - raw: Bytes::from_static(&[0x64, 0x01, 0x93, 0xaf, 0xaf, 0xaf, 0xaf]), - expected_err: Some(Error::ErrInvalidH265PacketType), - ..Default::default() - }, - TestType { - raw: Bytes::from_static(&[0x01, 0x01, 0xab, 0xcd, 0xef]), - expected_packet: Some(H265SingleNALUnitPacket { - payload_header: H265NALUHeader::new(0x01, 0x01), - payload: Bytes::from_static(&[0xab, 0xcd, 0xef]), - ..Default::default() - }), - ..Default::default() - }, - // DONL, payload too small - TestType { - raw: Bytes::from_static(&[0x01, 0x01, 0x93, 0xaf]), - expected_err: Some(Error::ErrShortPacket), - with_donl: true, - ..Default::default() - }, - TestType { - raw: Bytes::from_static(&[0x01, 0x01, 0xaa, 0xbb, 0xcc]), - expected_packet: Some(H265SingleNALUnitPacket { - payload_header: H265NALUHeader::new(0x01, 0x01), - donl: Some((0xaa << 8) | 0xbb), - payload: 
Bytes::from_static(&[0xcc]), - ..Default::default() - }), - with_donl: true, - ..Default::default() - }, - ]; - - for cur in tests { - let mut parsed = H265SingleNALUnitPacket::default(); - if cur.with_donl { - parsed.with_donl(cur.with_donl); - } - - let result = parsed.depacketize(&cur.raw); - - if cur.expected_err.is_some() && result.is_ok() { - panic!("should error"); - } else if cur.expected_err.is_none() && result.is_err() { - panic!("should not error"); - } - - if let Some(expected_packet) = cur.expected_packet { - assert_eq!( - parsed.payload_header(), - expected_packet.payload_header(), - "invalid payload header" - ); - assert_eq!(parsed.donl(), expected_packet.donl(), "invalid DONL"); - - assert_eq!( - parsed.payload(), - expected_packet.payload(), - "invalid payload" - ); - } - } - - Ok(()) -} - -#[test] -fn test_h265_aggregation_packet() -> Result<()> { - #[derive(Default)] - struct TestType { - raw: Bytes, - with_donl: bool, - expected_packet: Option, - expected_err: Option, - } - - let tests = vec![ - TestType { - raw: Bytes::from_static(&[]), - expected_err: Some(Error::ErrShortPacket), - ..Default::default() - }, - TestType { - raw: Bytes::from_static(&[0x62]), - expected_err: Some(Error::ErrShortPacket), - ..Default::default() - }, - TestType { - raw: Bytes::from_static(&[0x62, 0x01, 0x93]), - expected_err: Some(Error::ErrShortPacket), - ..Default::default() - }, - // FBit enabled in H265NALUHeader - TestType { - raw: Bytes::from_static(&[0x80, 0x01, 0x93, 0xaf, 0xaf, 0xaf, 0xaf]), - expected_err: Some(Error::ErrH265CorruptedPacket), - ..Default::default() - }, - // Type '48' in H265NALUHeader - TestType { - raw: Bytes::from_static(&[0xE0, 0x01, 0x93, 0xaf, 0xaf, 0xaf, 0xaf]), - expected_err: Some(Error::ErrInvalidH265PacketType), - ..Default::default() - }, - // Small payload - TestType { - raw: Bytes::from_static(&[0x60, 0x01, 0x00, 0x1]), - expected_err: Some(Error::ErrShortPacket), - ..Default::default() - }, - // Small payload - TestType { - raw: Bytes::from_static(&[0x60, 0x01, 0x00]), - expected_err: Some(Error::ErrShortPacket), - with_donl: true, - ..Default::default() - }, - // Small payload - TestType { - raw: Bytes::from_static(&[0x60, 0x01, 0x00, 0x1]), - expected_err: Some(Error::ErrShortPacket), - with_donl: true, - ..Default::default() - }, - // Small payload - TestType { - raw: Bytes::from_static(&[0x60, 0x01, 0x00, 0x01, 0x02]), - expected_err: Some(Error::ErrShortPacket), - with_donl: true, - ..Default::default() - }, - // Single Aggregation Unit - TestType { - raw: Bytes::from_static(&[0x60, 0x01, 0x00, 0x01, 0x00, 0x02, 0x00, 0x00]), - expected_err: Some(Error::ErrShortPacket), - with_donl: true, - ..Default::default() - }, - // Incomplete second Aggregation Unit - TestType { - raw: Bytes::from_static(&[ - 0x60, 0x01, 0x00, 0x01, 0x00, 0x02, 0x00, 0x00, // DONL - 0x00, - ]), - expected_err: Some(Error::ErrShortPacket), - with_donl: true, - ..Default::default() - }, - // Incomplete second Aggregation Unit - TestType { - raw: Bytes::from_static(&[ - 0x60, 0x01, 0x00, 0x01, 0x00, 0x02, 0x00, 0x00, - // DONL, NAL Unit size (2 bytes) - 0x00, 0x55, 0x55, - ]), - expected_err: Some(Error::ErrShortPacket), - with_donl: true, - ..Default::default() - }, - // Valid Second Aggregation Unit - TestType { - raw: Bytes::from_static(&[ - 0x60, 0x01, 0xcc, 0xdd, 0x00, 0x02, 0xff, 0xee, - // DONL, NAL Unit size (2 bytes), Payload - 0x77, 0x00, 0x01, 0xaa, - ]), - with_donl: true, - expected_packet: Some(H265AggregationPacket { - first_unit: Some(H265AggregationUnitFirst 
{ - donl: Some(0xccdd), - nal_unit_size: 2, - nal_unit: Bytes::from_static(&[0xff, 0xee]), - }), - other_units: vec![H265AggregationUnit { - dond: Some(0x77), - nal_unit_size: 1, - nal_unit: Bytes::from_static(&[0xaa]), - }], - might_need_donl: false, - }), - ..Default::default() - }, - ]; - - for cur in tests { - let mut parsed = H265AggregationPacket::default(); - if cur.with_donl { - parsed.with_donl(cur.with_donl); - } - - let result = parsed.depacketize(&cur.raw); - - if cur.expected_err.is_some() && result.is_ok() { - panic!("should error"); - } else if cur.expected_err.is_none() && result.is_err() { - panic!("should not error"); - } - - if let Some(expected_packet) = cur.expected_packet { - if let (Some(first_unit), Some(parsed_first_unit)) = - (expected_packet.first_unit(), parsed.first_unit()) - { - assert_eq!( - parsed_first_unit.nal_unit_size, first_unit.nal_unit_size, - "invalid first unit NALUSize" - ); - assert_eq!( - parsed_first_unit.donl(), - first_unit.donl(), - "invalid first unit DONL" - ); - assert_eq!( - parsed_first_unit.nal_unit(), - first_unit.nal_unit(), - "invalid first unit NalUnit" - ); - } - - assert_eq!( - parsed.other_units().len(), - expected_packet.other_units().len(), - "number of other units mismatch" - ); - - for ndx in 0..expected_packet.other_units().len() { - assert_eq!( - parsed.other_units()[ndx].nalu_size(), - expected_packet.other_units()[ndx].nalu_size(), - "invalid unit NALUSize" - ); - - assert_eq!( - parsed.other_units()[ndx].dond(), - expected_packet.other_units()[ndx].dond(), - "invalid unit DOND" - ); - - assert_eq!( - parsed.other_units()[ndx].nal_unit(), - expected_packet.other_units()[ndx].nal_unit(), - "invalid first unit NalUnit" - ); - } - - assert_eq!( - parsed.other_units(), - expected_packet.other_units(), - "invalid payload" - ); - } - } - - Ok(()) -} - -#[test] -fn test_h265_fragmentation_unit_packet() -> Result<()> { - #[derive(Default)] - struct TestType { - raw: Bytes, - with_donl: bool, - expected_fu: Option, - expected_err: Option, - } - let tests = vec![ - TestType { - raw: Bytes::from_static(&[]), - expected_err: Some(Error::ErrShortPacket), - ..Default::default() - }, - TestType { - raw: Bytes::from_static(&[0x62]), - expected_err: Some(Error::ErrShortPacket), - ..Default::default() - }, - TestType { - raw: Bytes::from_static(&[0x62, 0x01]), - expected_err: Some(Error::ErrShortPacket), - ..Default::default() - }, - TestType { - raw: Bytes::from_static(&[0x62, 0x01, 0x93]), - expected_err: Some(Error::ErrShortPacket), - ..Default::default() - }, - // FBit enabled in H265NALUHeader - TestType { - raw: Bytes::from_static(&[0x80, 0x01, 0x93, 0xaf]), - expected_err: Some(Error::ErrH265CorruptedPacket), - ..Default::default() - }, - // Type not '49' in H265NALUHeader - TestType { - raw: Bytes::from_static(&[0x40, 0x01, 0x93, 0xaf]), - expected_err: Some(Error::ErrInvalidH265PacketType), - ..Default::default() - }, - TestType { - raw: Bytes::from_static(&[0x62, 0x01, 0x93, 0xaf]), - expected_fu: Some(H265FragmentationUnitPacket { - payload_header: H265NALUHeader::new(0x62, 0x01), - fu_header: H265FragmentationUnitHeader(0x93), - donl: None, - payload: Bytes::from_static(&[0xaf]), - might_need_donl: false, - }), - ..Default::default() - }, - TestType { - raw: Bytes::from_static(&[0x62, 0x01, 0x93, 0xcc]), - with_donl: true, - expected_err: Some(Error::ErrShortPacket), - ..Default::default() - }, - TestType { - raw: Bytes::from_static(&[0x62, 0x01, 0x93, 0xcc, 0xdd, 0xaf, 0x0d, 0x5a]), - with_donl: true, - expected_fu: 
Some(H265FragmentationUnitPacket { - payload_header: H265NALUHeader::new(0x62, 0x01), - fu_header: H265FragmentationUnitHeader(0x93), - donl: Some((0xcc << 8) | 0xdd), - payload: Bytes::from_static(&[0xaf, 0x0d, 0x5a]), - might_need_donl: false, - }), - ..Default::default() - }, - ]; - - for cur in tests { - let mut parsed = H265FragmentationUnitPacket::default(); - if cur.with_donl { - parsed.with_donl(cur.with_donl); - } - - let result = parsed.depacketize(&cur.raw); - - if cur.expected_err.is_some() && result.is_ok() { - panic!("should error"); - } else if cur.expected_err.is_none() && result.is_err() { - panic!("should not error"); - } - - if let Some(expected_fu) = &cur.expected_fu { - assert_eq!( - parsed.payload_header(), - expected_fu.payload_header(), - "invalid payload header" - ); - assert_eq!( - parsed.fu_header(), - expected_fu.fu_header(), - "invalid FU header" - ); - assert_eq!(parsed.donl(), expected_fu.donl(), "invalid DONL"); - assert_eq!(parsed.payload(), expected_fu.payload(), "invalid Payload"); - } - } - - Ok(()) -} - -#[test] -fn test_h265_temporal_scalability_control_information() -> Result<()> { - #[derive(Default)] - struct TestType { - value: H265TSCI, - expected_tl0picidx: u8, - expected_irap_pic_id: u8, - expected_s: bool, - expected_e: bool, - expected_res: u8, - } - - let tests = vec![ - TestType { - value: H265TSCI(((0xCA) << 24) | ((0xFE) << 16)), - expected_tl0picidx: 0xCA, - expected_irap_pic_id: 0xFE, - ..Default::default() - }, - TestType { - value: H265TSCI((1) << 15), - expected_s: true, - ..Default::default() - }, - TestType { - value: H265TSCI((1) << 14), - expected_e: true, - ..Default::default() - }, - TestType { - value: H265TSCI((0x0A) << 8), - expected_res: 0x0A, - ..Default::default() - }, - // Sets RES, and force sets S and E to 0. 
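// Illustrative sketch, not part of this patch: how the H265TSCI accessors used by
// this test read the 32-bit word — TL0PICIDX in bits 31..24, IrapPicID in bits
// 23..16, S at bit 15, E at bit 14, RES in bits 13..8. The helper name is
// hypothetical.
fn tsci_fields(word: u32) -> (u8, u8, bool, bool, u8) {
    let tl0picidx = (word >> 24) as u8;
    let irap_pic_id = (word >> 16) as u8;
    let s = (word >> 15) & 1 != 0;
    let e = (word >> 14) & 1 != 0;
    let res = ((word >> 8) & 0x3f) as u8;
    (tl0picidx, irap_pic_id, s, e, res)
}
// e.g. tsci_fields((0xCA << 24) | (0xFE << 16)) == (0xCA, 0xFE, false, false, 0).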
- TestType { - value: H265TSCI(((0xAA) << 8) & (u32::MAX ^ ((1) << 15)) & (u32::MAX ^ ((1) << 14))), - expected_res: 0xAA & 0b00111111, - ..Default::default() - }, - ]; - - for cur in tests { - assert_eq!( - cur.value.tl0picidx(), - cur.expected_tl0picidx, - "invalid TL0PICIDX" - ); - assert_eq!( - cur.value.irap_pic_id(), - cur.expected_irap_pic_id, - "invalid IrapPicID" - ); - assert_eq!(cur.value.s(), cur.expected_s, "invalid S"); - assert_eq!(cur.value.e(), cur.expected_e, "invalid E"); - assert_eq!(cur.value.res(), cur.expected_res, "invalid RES"); - } - - Ok(()) -} - -#[test] -fn test_h265_paci_packet() -> Result<()> { - #[derive(Default)] - struct TestType { - raw: Bytes, - expected_fu: Option, - expected_err: Option, - } - - let tests = vec![ - TestType { - raw: Bytes::from_static(&[]), - expected_err: Some(Error::ErrShortPacket), - ..Default::default() - }, - TestType { - raw: Bytes::from_static(&[0x62, 0x01, 0x93]), - expected_err: Some(Error::ErrShortPacket), - ..Default::default() - }, - // FBit enabled in H265NALUHeader - TestType { - raw: Bytes::from_static(&[0x80, 0x01, 0x93, 0xaf, 0xaf, 0xaf, 0xaf]), - expected_err: Some(Error::ErrH265CorruptedPacket), - ..Default::default() - }, - // Type not '50' in H265NALUHeader - TestType { - raw: Bytes::from_static(&[0x40, 0x01, 0x93, 0xaf, 0xaf, 0xaf, 0xaf]), - expected_err: Some(Error::ErrInvalidH265PacketType), - ..Default::default() - }, - // Invalid header extension size - TestType { - raw: Bytes::from_static(&[0x64, 0x01, 0x93, 0xaf, 0xaf, 0xaf, 0xaf]), - expected_err: Some(Error::ErrInvalidH265PacketType), - ..Default::default() - }, - // No Header Extension - TestType { - raw: Bytes::from_static(&[0x64, 0x01, 0x64, 0x00, 0xab, 0xcd, 0xef]), - expected_fu: Some(H265PACIPacket { - payload_header: H265NALUHeader::new(0x64, 0x01), - paci_header_fields: ((0x64) << 8), - phes: Bytes::from_static(&[]), - payload: Bytes::from_static(&[0xab, 0xcd, 0xef]), - }), - ..Default::default() - }, - // Header Extension 1 byte - TestType { - raw: Bytes::from_static(&[0x64, 0x01, 0x64, 0x10, 0xff, 0xab, 0xcd, 0xef]), - expected_fu: Some(H265PACIPacket { - payload_header: H265NALUHeader::new(0x64, 0x01), - paci_header_fields: ((0x64) << 8) | (0x10), - phes: Bytes::from_static(&[0xff]), - payload: Bytes::from_static(&[0xab, 0xcd, 0xef]), - }), - ..Default::default() - }, - // Header Extension TSCI - TestType { - raw: Bytes::from_static(&[ - 0x64, 0x01, 0x64, 0b00111000, 0xaa, 0xbb, 0x80, 0xab, 0xcd, 0xef, - ]), - expected_fu: Some(H265PACIPacket { - payload_header: H265NALUHeader::new(0x64, 0x01), - paci_header_fields: ((0x64) << 8) | (0b00111000), - phes: Bytes::from_static(&[0xaa, 0xbb, 0x80]), - payload: Bytes::from_static(&[0xab, 0xcd, 0xef]), - }), - ..Default::default() - }, - ]; - - for cur in tests { - let mut parsed = H265PACIPacket::default(); - - let result = parsed.depacketize(&cur.raw); - - if cur.expected_err.is_some() && result.is_ok() { - panic!("should error"); - } else if cur.expected_err.is_none() && result.is_err() { - panic!("should not error"); - } - - if let Some(expected_fu) = &cur.expected_fu { - assert_eq!( - parsed.payload_header(), - expected_fu.payload_header(), - "invalid PayloadHeader" - ); - assert_eq!(parsed.a(), expected_fu.a(), "invalid A"); - assert_eq!(parsed.ctype(), expected_fu.ctype(), "invalid CType"); - assert_eq!(parsed.phs_size(), expected_fu.phs_size(), "invalid PHSsize"); - assert_eq!(parsed.f0(), expected_fu.f0(), "invalid F0"); - assert_eq!(parsed.f1(), expected_fu.f1(), "invalid F1"); - 
assert_eq!(parsed.f2(), expected_fu.f2(), "invalid F2"); - assert_eq!(parsed.y(), expected_fu.y(), "invalid Y"); - assert_eq!(parsed.phes(), expected_fu.phes(), "invalid PHES"); - assert_eq!(parsed.payload(), expected_fu.payload(), "invalid Payload"); - assert_eq!(parsed.tsci(), expected_fu.tsci(), "invalid TSCI"); - } - } - - Ok(()) -} - -#[test] -fn test_h265_packet() -> Result<()> { - #[derive(Default)] - struct TestType { - raw: Bytes, - with_donl: bool, - expected_packet_type: Option, - expected_err: Option, - } - let tests = vec![ - TestType { - raw: Bytes::from_static(&[]), - expected_err: Some(Error::ErrShortPacket), - ..Default::default() - }, - TestType { - raw: Bytes::from_static(&[0x62, 0x01, 0x93]), - expected_err: Some(Error::ErrShortPacket), - ..Default::default() - }, - TestType { - raw: Bytes::from_static(&[0x64, 0x01, 0x93, 0xaf]), - expected_err: Some(Error::ErrShortPacket), - ..Default::default() - }, - TestType { - raw: Bytes::from_static(&[0x01, 0x01]), - with_donl: true, - expected_err: Some(Error::ErrShortPacket), - ..Default::default() - }, - // FBit enabled in H265NALUHeader - TestType { - raw: Bytes::from_static(&[0x80, 0x01, 0x93, 0xaf, 0xaf, 0xaf, 0xaf]), - expected_err: Some(Error::ErrH265CorruptedPacket), - ..Default::default() - }, - // Valid H265SingleNALUnitPacket - TestType { - raw: Bytes::from_static(&[0x01, 0x01, 0xab, 0xcd, 0xef]), - expected_packet_type: Some(H265Payload::H265SingleNALUnitPacket( - H265SingleNALUnitPacket::default(), - )), - ..Default::default() - }, - // Invalid H265SingleNALUnitPacket - TestType { - raw: Bytes::from_static(&[0x01, 0x01, 0x93, 0xaf]), - expected_err: Some(Error::ErrShortPacket), - with_donl: true, - ..Default::default() - }, - // Valid H265PACIPacket - TestType { - raw: Bytes::from_static(&[ - 0x64, 0x01, 0x64, 0b00111000, 0xaa, 0xbb, 0x80, 0xab, 0xcd, 0xef, - ]), - expected_packet_type: Some(H265Payload::H265PACIPacket(H265PACIPacket::default())), - ..Default::default() - }, - // Valid H265FragmentationUnitPacket - TestType { - raw: Bytes::from_static(&[0x62, 0x01, 0x93, 0xcc, 0xdd, 0xaf, 0x0d, 0x5a]), - expected_packet_type: Some(H265Payload::H265FragmentationUnitPacket( - H265FragmentationUnitPacket::default(), - )), - with_donl: true, - ..Default::default() - }, - // Valid H265AggregationPacket - TestType { - raw: Bytes::from_static(&[ - 0x60, 0x01, 0xcc, 0xdd, 0x00, 0x02, 0xff, 0xee, 0x77, 0x00, 0x01, 0xaa, - ]), - expected_packet_type: Some(H265Payload::H265AggregationPacket( - H265AggregationPacket::default(), - )), - with_donl: true, - ..Default::default() - }, - // Invalid H265AggregationPacket - TestType { - raw: Bytes::from_static(&[0x60, 0x01, 0x00, 0x01, 0x00, 0x02, 0x00, 0x00]), - expected_err: Some(Error::ErrShortPacket), - with_donl: true, - ..Default::default() - }, - ]; - - for cur in tests { - let mut pck = H265Packet::default(); - if cur.with_donl { - pck.with_donl(true); - } - - let result = pck.depacketize(&cur.raw); - - if cur.expected_err.is_some() && result.is_ok() { - panic!("should error"); - } else if cur.expected_err.is_none() && result.is_err() { - panic!("should not error"); - } - - if cur.expected_err.is_some() { - continue; - } - - if let Some(expected_packet_type) = &cur.expected_packet_type { - //TODO: assert_eq!(pck.packet(), expected_packet_type, "invalid packet type"); - let pck_packet = pck.payload(); - match (pck_packet, expected_packet_type) { - ( - &H265Payload::H265SingleNALUnitPacket(_), - &H265Payload::H265SingleNALUnitPacket(_), - ) => {} - ( - 
&H265Payload::H265FragmentationUnitPacket(_), - &H265Payload::H265FragmentationUnitPacket(_), - ) => {} - ( - &H265Payload::H265AggregationPacket(_), - &H265Payload::H265AggregationPacket(_), - ) => {} - (&H265Payload::H265PACIPacket(_), &H265Payload::H265PACIPacket(_)) => {} - _ => panic!(), - }; - } - } - - Ok(()) -} - -#[test] -fn test_h265_packet_real() -> Result<()> { - // Tests decoding of real H265 payloads extracted from a Wireshark dump. - let tests = vec![ - b"\x40\x01\x0c\x01\xff\xff\x01\x60\x00\x00\x03\x00\xb0\x00\x00\x03\x00\x00\x03\x00\x7b\xac\x09".to_vec(), - b"\x42\x01\x01\x01\x60\x00\x00\x03\x00\xb0\x00\x00\x03\x00\x00\x03\x00\x7b\xa0\x03\xc0\x80\x10\xe5\x8d\xae\x49\x32\xf4\xdc\x04\x04\x04\x02".to_vec(), - b"\x44\x01\xc0\xf2\xf0\x3c\x90".to_vec(), - b"\x4e\x01\xe5\x04\x61\x0c\x00\x00\x80".to_vec(), - b"\x62\x01\x93\xaf\x0d\x5a\xfe\x67\x77\x29\xc0\x74\xf3\x57\x4c\x16\x94\xaa\x7c\x2a\x64\x5f\xe9\xa5\xb7\x2a\xa3\x95\x9d\x94\xa7\xb4\xd3\xc4\x4a\xb1\xb7\x69\xca\xbe\x75\xc5\x64\xa8\x97\x4b\x8a\xbf\x7e\xf0\x0f\xc3\x22\x60\x67\xab\xae\x96\xd6\x99\xca\x7a\x8d\x35\x93\x1a\x67\x60\xe7\xbe\x7e\x13\x95\x3c\xe0\x11\xc1\xc1\xa7\x48\xef\xf7\x7b\xb0\xeb\x35\x49\x81\x4e\x4e\x54\xf7\x31\x6a\x38\xa1\xa7\x0c\xd6\xbe\x3b\x25\xba\x08\x19\x0b\x49\xfd\x90\xbb\x73\x7a\x45\x8c\xb9\x73\x43\x04\xc5\x5f\xda\x0f\xd5\x70\x4c\x11\xee\x72\xb8\x6a\xb4\x95\x62\x64\xb6\x23\x14\x7e\xdb\x0e\xa5\x0f\x86\x31\xe4\xd1\x64\x56\x43\xf6\xb7\xe7\x1b\x93\x4a\xeb\xd0\xa6\xe3\x1f\xce\xda\x15\x67\x05\xb6\x77\x36\x8b\x27\x5b\xc6\xf2\x95\xb8\x2b\xcc\x9b\x0a\x03\x05\xbe\xc3\xd3\x85\xf5\x69\xb6\x19\x1f\x63\x2d\x8b\x65\x9e\xc3\x9d\xd2\x44\xb3\x7c\x86\x3b\xea\xa8\x5d\x02\xe5\x40\x03\x20\x76\x48\xff\xf6\x2b\x0d\x18\xd6\x4d\x49\x70\x1a\x5e\xb2\x89\xca\xec\x71\x41\x79\x4e\x94\x17\x0c\x57\x51\x55\x14\x61\x40\x46\x4b\x3e\x17\xb2\xc8\xbd\x1c\x06\x13\x91\x72\xf8\xc8\xfc\x6f\xb0\x30\x9a\xec\x3b\xa6\xc9\x33\x0b\xa5\xe5\xf4\x65\x7a\x29\x8b\x76\x62\x81\x12\xaf\x20\x4c\xd9\x21\x23\x9e\xeb\xc9\x0e\x5b\x29\x35\x7f\x41\xcd\xce\xa1\xc4\xbe\x01\x30\xb9\x11\xc3\xb1\xe4\xce\x45\xd2\x5c\xb3\x1e\x69\x78\xba\xb1\x72\xe4\x88\x54\xd8\x5d\xd0\xa8\x3a\x74\xad\xe5\xc7\xc1\x59\x7c\x78\x15\x26\x37\x3d\x50\xae\xb3\xa4\x5b\x6c\x7d\x65\x66\x85\x4d\x16\x9a\x67\x74\xad\x55\x32\x3a\x84\x85\x0b\x6a\xeb\x24\x97\xb4\x20\x4d\xca\x41\x61\x7a\xd1\x7b\x60\xdb\x7f\xd5\x61\x22\xcf\xd1\x7e\x4c\xf3\x85\xfd\x13\x63\xe4\x9d\xed\xac\x13\x0a\xa0\x92\xb7\x34\xde\x65\x0f\xd9\x0f\x9b\xac\xe2\x47\xe8\x5c\xb3\x11\x8e\xc6\x08\x19\xd0\xb0\x85\x52\xc8\x5c\x1b\x08\x0a\xce\xc9\x6b\xa7\xef\x95\x2f\xd0\xb8\x63\xe5\x4c\xd4\xed\x6e\x87\xe9\xd4\x0a\xe6\x11\x44\x63\x00\x94\x18\xe9\x28\xba\xcf\x92\x43\x06\x59\xdd\x37\x4f\xd3\xef\x9d\x31\x5e\x9b\x48\xf9\x1f\x3e\x7b\x95\x3a\xbd\x1f\x71\x55\x0c\x06\xf9\x86\xf8\x3d\x39\x16\x50\xb3\x21\x11\x19\x6f\x70\xa9\x48\xe8\xbb\x0a\x11\x23\xf8\xab\xfe\x44\xe0\xbb\xe8\x64\xfa\x85\xe4\x02\x55\x88\x41\xc6\x30\x7f\x10\xad\x75\x02\x4b\xef\xe1\x0b\x06\x3c\x10\x49\x83\xf9\xd1\x3e\x3e\x67\x86\x4c\xf8\x9d\xde\x5a\xc4\xc8\xcf\xb6\xf4\xb0\xd3\x34\x58\xd4\x7b\x4d\xd3\x37\x63\xb2\x48\x8a\x7e\x20\x00\xde\xb4\x42\x8f\xda\xe9\x43\x9e\x0c\x16\xce\x79\xac\x2c\x70\xc1\x89\x05\x36\x62\x6e\xd9\xbc\xfb\x63\xc6\x79\x89\x3c\x90\x89\x2b\xd1\x8c\xe0\xc2\x54\xc7\xd6\xb4\xe8\x9e\x96\x55\x6e\x7b\xd5\x7f\xac\xd4\xa7\x1c\xa0\xdf\x01\x30\xad\xc0\x9f\x69\x06\x10\x43\x7f\xf4\x5d\x62\xa3\xea\x73\xf2\x14\x79\x19\x13\xea\x59\x14\x79\xa8\xe7\xce\xce\x44\x25\x13\x41\x18\x57\xdd\xce\xe4\xbe\xcc\x20\x80\x29\x71\x73\xa7\x7c\x86\x39\x76\xf4\xa7\x1c\x63\x24\x21\x93\x1e\xb5\x9a\x5c\x8a\x9e\xda\x8b\x9d\x88
\x97\xfc\x98\x7d\x26\x74\x04\x1f\xa8\x10\x4f\x45\xcd\x46\xe8\x28\xe4\x8e\x59\x67\x63\x4a\xcf\x1e\xed\xdd\xbb\x79\x2f\x8d\x94\xab\xfc\xdb\xc5\x79\x1a\x4d\xcd\x53\x41\xdf\xd1\x7a\x8f\x46\x3e\x1f\x79\x88\xe3\xee\x9f\xc4\xc1\xe6\x2e\x89\x4d\x28\xc9\xca\x28\xc2\x0a\xc5\xc7\xf1\x22\xcd\xb3\x36\xfa\xe3\x7e\xa6\xcd\x95\x55\x5e\x0e\x1a\x75\x7f\x65\x27\xd3\x37\x4f\x23\xc5\xab\x49\x68\x4e\x02\xb5\xbf\xd7\x95\xc0\x78\x67\xbc\x1a\xe9\xae\x6f\x44\x58\x8a\xc2\xce\x42\x98\x4e\x77\xc7\x2a\xa0\xa7\x7d\xe4\x3b\xd1\x20\x82\x1a\xd3\xe2\xc7\x76\x5d\x06\x46\xb5\x24\xd7\xfb\x57\x63\x2b\x19\x51\x48\x65\x6d\xfb\xe0\x98\xd1\x14\x0e\x17\x64\x29\x34\x6f\x6e\x66\x9e\x8d\xc9\x89\x49\x69\xee\x74\xf3\x35\xe6\x8b\x67\x56\x95\x7f\x1b\xe9\xed\x8c\x0f\xe2\x19\x59\xbf\x03\x35\x55\x3c\x04\xbc\x40\x52\x90\x10\x08\xad\xa7\x65\xe0\x31\xcb\xcf\x3d\xd4\x62\x68\x01\x0d\xed\xf5\x28\x64\x2d\xaa\x7c\x99\x15\x8d\x70\x32\x53\xb8\x9d\x0a\x3c\xbf\x91\x02\x04\xd0\xee\x87\xce\x04\xcc\x3e\xa8\x20\xfd\x97\xdf\xbf\x4a\xbc\xfc\xc9\x7c\x77\x21\xcc\x23\x6f\x59\x38\xd8\xd9\xa0\x0e\xb1\x23\x4e\x04\x3f\x14\x9e\xcc\x05\x54\xab\x20\x69\xed\xa4\xd5\x1d\xb4\x1b\x52\xed\x6a\xea\xeb\x7f\xd1\xbc\xfd\x75\x20\xa0\x1c\x59\x8c\x5a\xa1\x2a\x70\x64\x11\xb1\x7b\xc1\x24\x80\x28\x51\x4c\x94\xa1\x95\x64\x72\xe8\x90\x67\x38\x74\x2b\xab\x38\x46\x12\x71\xce\x19\x98\x98\xf7\x89\xd4\xfe\x2f\x2a\xc5\x61\x20\xd0\xa4\x1a\x51\x3c\x82\xc8\x18\x31\x7a\x10\xe8\x1c\xc6\x95\x5a\xa0\x82\x88\xce\x8f\x4b\x47\x85\x7e\x89\x95\x95\x52\x1e\xac\xce\x45\x57\x61\x38\x97\x2b\x62\xa5\x14\x6f\xc3\xaa\x6c\x35\x83\xc9\xa3\x1e\x30\x89\xf4\xb1\xea\x4f\x39\xde\xde\xc7\x46\x5c\x0e\x85\x41\xec\x6a\xa4\xcb\xee\x70\x9c\x57\xd9\xf4\xa1\xc3\x9c\x2a\x0a\xf0\x5d\x58\xb0\xae\xd4\xdc\xc5\x6a\xa8\x34\xfa\x23\xef\xef\x08\x39\xc3\x3d\xea\x11\x6e\x6a\xe0\x1e\xd0\x52\xa8\xc3\x6e\xc9\x1c\xfc\xd0\x0c\x4c\xea\x0d\x82\xcb\xdd\x29\x1a\xc4\x4f\x6e\xa3\x4d\xcb\x7a\x38\x77\xe5\x15\x6e\xad\xfa\x9d\x2f\x02\xb6\x39\x84\x3a\x60\x8f\x71\x9f\x92\xe5\x24\x4f\xbd\x18\x49\xd5\xef\xbf\x70\xfb\xd1\x4c\x2e\xfc\x2f\x36\xf3\x00\x31\x2e\x90\x18\xcc\xf4\x71\xb9\xe4\xf9\xbe\xcb\x5e\xff\xf3\xe7\xf8\xca\x03\x60\x66\xb3\xc9\x5a\xf9\x74\x09\x02\x57\xb6\x90\x94\xfc\x41\x35\xdc\x35\x3f\x32\x7a\xa6\xa5\xcd\x8a\x8f\xc8\x3d\xc8\x81\xc3\xec\x37\x74\x86\x61\x41\x0d\xc5\xe2\xc8\x0c\x84\x2b\x3b\x71\x58\xde\x1b\xe3\x20\x65\x2e\x76\xf4\x98\xd8\xaa\x78\xe6\xeb\xb8\x85\x0d\xa0\xd0\xf5\x57\x64\x01\x58\x55\x82\xd5\x0f\x2d\x9c\x3e\x2a\xa0\x7e\xaf\x42\xf3\x37\xd1\xb3\xaf\xda\x5b\xa9\xda\xe3\x89\x5d\xf1\xca\xa5\x12\x3d\xe7\x91\x95\x53\x21\x72\xca\x7f\xf6\x79\x59\x21\xcf\x30\x18\xfb\x78\x55\x40\x59\xc3\xf9\xf1\xdd\x58\x44\x5e\x83\x11\x5c\x2d\x1d\x91\xf6\x01\x3d\x3f\xd4\x33\x81\x66\x6c\x40\x7a\x9d\x70\x10\x58\xe6\x53\xad\x85\x11\x99\x3e\x4b\xbc\x31\xc6\x78\x9d\x79\xc5\xde\x9f\x2e\x43\xfa\x76\x84\x2f\xfd\x28\x75\x12\x48\x25\xfd\x15\x8c\x29\x6a\x91\xa4\x63\xc0\xa2\x8c\x41\x3c\xf1\xb0\xf8\xdf\x66\xeb\xbd\x14\x88\xa9\x81\xa7\x35\xc4\x41\x40\x6c\x10\x3f\x09\xbd\xb5\xd3\x7a\xee\x4b\xd5\x86\xff\x36\x03\x6b\x78\xde".to_vec(), - b"\x62\x01\x53\x8a\xe9\x25\xe1\x06\x09\x8e\xba\x12\x74\x87\x09\x9a\x95\xe4\x86\x62\x2b\x4b\xf9\xa6\x2e\x7b\x35\x43\xf7\x39\x99\x0f\x3b\x6f\xfd\x1a\x6e\x23\x54\x70\xb5\x1d\x10\x1c\x63\x40\x96\x99\x41\xb6\x96\x0b\x70\x98\xec\x17\xb0\xaa\xdc\x4a\xab\xe8\x3b\xb7\x6b\x00\x1c\x5b\xc3\xe0\xa2\x8b\x7c\x17\xc8\x92\xc9\xb0\x92\xb6\x70\x84\x95\x30".to_vec(), - b"\x4e\x01\xe5\x04\x35\xac\x00\x00\x80".to_vec(), - 
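// Note, not part of this patch: in order, the raw payloads in this test are a VPS
// (type 32), an SPS (33), a PPS (34), a prefix SEI (39), two fragmentation units
// (type 49) carrying slice data, another prefix SEI, and a final fragmentation
// unit; the H265 NAL unit type is read from bits 1..6 of the first byte,
// e.g. (0x40 >> 1) & 0x3f == 32.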
b"\x62\x01\x41\xb0\x75\x5c\x27\x46\xef\x8a\xe7\x1d\x50\x38\xb2\x13\x33\xe0\x79\x35\x1b\xc2\xb5\x79\x73\xe7\xc2\x6f\xb9\x1a\x8c\x21\x0e\xa9\x54\x17\x6c\x41\xab\xc8\x16\x57\xec\x5e\xeb\x89\x3b\xa9\x90\x8c\xff\x4d\x46\x8b\xf0\xd9\xc0\xd0\x51\xcf\x8b\x88\xf1\x5f\x1e\x9e\xc1\xb9\x1f\xe3\x06\x45\x35\x8a\x47\xe8\x9a\xf2\x4f\x19\x4c\xf8\xce\x68\x1b\x63\x34\x11\x75\xea\xe5\xb1\x0f\x38\xcc\x05\x09\x8b\x3e\x2b\x88\x84\x9d\xc5\x03\xc3\xc0\x90\x32\xe2\x45\x69\xb1\xe5\xf7\x68\x6b\x16\x90\xa0\x40\xe6\x18\x74\xd8\x68\xf3\x34\x38\x99\xf2\x6c\xb7\x1a\x35\x21\xca\x52\x56\x4c\x7f\xb2\xa3\xd5\xb8\x40\x50\x48\x3e\xdc\xdf\x0b\xf5\x54\x5a\x15\x1a\xe2\xc3\xb4\x94\xda\x3f\xb5\x34\xa2\xca\xbc\x2f\xe0\xa4\xe5\x69\xf4\xbf\x62\x4d\x15\x21\x1b\x11\xfc\x39\xaa\x86\x74\x96\x63\xfd\x07\x53\x26\xf6\x34\x72\xeb\x14\x37\x98\x0d\xf4\x68\x91\x2c\x6b\x46\x83\x88\x82\x04\x8b\x9f\xb8\x32\x73\x75\x8b\xf9\xac\x71\x42\xd1\x2d\xb4\x28\x28\xf5\x78\xe0\x32\xf3\xe1\xfc\x43\x6b\xf9\x92\xf7\x48\xfe\x7f\xc0\x17\xbd\xfd\xba\x2f\x58\x6f\xee\x84\x03\x18\xce\xb0\x9d\x8d\xeb\x22\xf1\xfc\xb1\xcf\xff\x2f\xb2\x9f\x6c\xe5\xb4\x69\xdc\xdd\x20\x93\x00\x30\xad\x56\x04\x66\x7e\xa3\x3c\x18\x4b\x43\x66\x00\x27\x1e\x1c\x09\x11\xd8\xf4\x8a\x9e\xc5\x6a\x94\xe5\xae\x0b\x8a\xbe\x84\xda\xe5\x44\x7f\x38\x1c\xe7\xbb\x03\x19\x66\xe1\x5d\x1d\xc1\xbd\x3d\xc6\xb7\xe3\xff\x7f\x8e\xff\x1e\xf6\x9e\x6f\x58\x27\x74\x65\xef\x02\x5d\xa4\xde\x27\x7f\x51\xe3\x4b\x9e\x3f\x79\x83\xbd\x1b\x8f\x0d\x77\xfb\xbc\xc5\x9f\x15\xa7\x4e\x05\x8a\x24\x97\x66\xb2\x7c\xf6\xe1\x84\x54\xdb\x39\x5e\xf6\x1b\x8f\x05\x73\x1d\xb6\x8e\xd7\x09\x9a\xc5\x92\x80".to_vec(), - ]; - - for cur in tests { - let mut pck = H265Packet::default(); - let _ = pck.depacketize(&Bytes::from(cur))?; - } - - Ok(()) -} diff --git a/rtp/src/codecs/h265/mod.rs b/rtp/src/codecs/h265/mod.rs deleted file mode 100644 index 8faeb6256..000000000 --- a/rtp/src/codecs/h265/mod.rs +++ /dev/null @@ -1,1020 +0,0 @@ -use bytes::{BufMut, Bytes, BytesMut}; - -use super::h264::ANNEXB_NALUSTART_CODE; -use crate::error::{Error, Result}; -use crate::packetizer::{Depacketizer, Payloader}; - -#[cfg(test)] -mod h265_test; - -pub static ANNEXB_3_NALUSTART_CODE: Bytes = Bytes::from_static(&[0x00, 0x00, 0x01]); -pub static SING_PAYLOAD_HDR: Bytes = Bytes::from_static(&[0x1C, 0x01]); -pub static AGGR_PAYLOAD_HDR: Bytes = Bytes::from_static(&[0x60, 0x01]); -pub static FRAG_PAYLOAD_HDR: Bytes = Bytes::from_static(&[0x62, 0x01]); -pub static FU_HDR_IDR_S: u8 = 0x93; -pub static FU_HDR_IDR_M: u8 = 0x13; -pub static FU_HDR_IDR_E: u8 = 0x53; -pub static FU_HDR_P_S: u8 = 0x81; -pub static FU_HDR_P_M: u8 = 0x01; -pub static FU_HDR_P_E: u8 = 0x41; -pub static FU_HDR_B_S: u8 = 0x80; -pub static FU_HDR_B_M: u8 = 0x00; -pub static FU_HDR_B_E: u8 = 0x40; -pub const RTP_OUTBOUND_MTU: usize = 1200; -pub const H265FRAGMENTATION_UNIT_HEADER_SIZE: usize = 1; -pub const NAL_HEADER_SIZE: usize = 2; - -#[derive(PartialEq, Hash, Debug, Copy, Clone)] -pub enum UnitType { - VPS = 32, - SPS = 33, - PPS = 34, - CRA = 21, - SEI = 39, - IDR = 19, - PFR = 1, - BFR = 0, - IGNORE = -1, -} -impl UnitType { - pub fn for_id(id: u8) -> Result { - if id > 64 { - Err(Error::ErrUnhandledNaluType) - } else { - let t = match id { - 32 => UnitType::VPS, - 33 => UnitType::SPS, - 34 => UnitType::PPS, - 21 => UnitType::CRA, - 39 => UnitType::SEI, - 19 => UnitType::IDR, - 1 => UnitType::PFR, - 0 => UnitType::BFR, - _ => UnitType::IGNORE, // shouldn't happen - }; - Ok(t) - } - } -} - -#[derive(Default, Debug, Clone)] -pub struct HevcPayloader { - vps_nalu: Option, - sps_nalu: Option, - 
pps_nalu: Option, -} - -impl HevcPayloader { - pub fn parse(nalu: &Bytes) -> (Vec, usize) { - let finder = memchr::memmem::Finder::new(&ANNEXB_NALUSTART_CODE); - let nals = finder.find_iter(nalu).collect::>(); - if nals.is_empty() { - let finder = memchr::memmem::Finder::new(&ANNEXB_3_NALUSTART_CODE); - return (finder.find_iter(nalu).collect::>(), 3); - } - (nals, 4) - } - - fn emit(&mut self, nalu: &Bytes, mtu: usize, payloads: &mut Vec) { - if nalu.is_empty() { - return; - } - let payload_header = H265NALUHeader::new(nalu[0], nalu[1]); - let payload_nalu_type = payload_header.nalu_type(); - let nalu_type = UnitType::for_id(payload_nalu_type).unwrap_or(UnitType::IGNORE); - if nalu_type == UnitType::IGNORE { - return; - } else if nalu_type == UnitType::VPS { - self.vps_nalu.replace(nalu.clone()); - } else if nalu_type == UnitType::SPS { - self.sps_nalu.replace(nalu.clone()); - } else if nalu_type == UnitType::PPS { - self.pps_nalu.replace(nalu.clone()); - } - if let (Some(vps_nalu), Some(sps_nalu), Some(pps_nalu)) = - (&self.vps_nalu, &self.sps_nalu, &self.pps_nalu) - { - // Pack current NALU with SPS and PPS as STAP-A - let vps_len = (vps_nalu.len() as u16).to_be_bytes(); - let sps_len = (sps_nalu.len() as u16).to_be_bytes(); - let pps_len = (pps_nalu.len() as u16).to_be_bytes(); - - // TODO DONL not impl yet - let mut aggr_nalu = BytesMut::new(); - aggr_nalu.extend_from_slice(&AGGR_PAYLOAD_HDR); - aggr_nalu.extend_from_slice(&vps_len); - aggr_nalu.extend_from_slice(vps_nalu); - aggr_nalu.extend_from_slice(&sps_len); - aggr_nalu.extend_from_slice(sps_nalu); - aggr_nalu.extend_from_slice(&pps_len); - aggr_nalu.extend_from_slice(pps_nalu); - if aggr_nalu.len() <= mtu { - payloads.push(Bytes::from(aggr_nalu)); - self.vps_nalu.take(); - self.sps_nalu.take(); - self.pps_nalu.take(); - return; - } - } else if nalu_type == UnitType::VPS - || nalu_type == UnitType::SPS - || nalu_type == UnitType::PPS - { - return; - } - // if self.sps_nalu.is_some() && self.pps_nalu.is_some() { - // self.sps_nalu = None; - // self.pps_nalu = None; - // } - - // Single NALU - if nalu.len() <= mtu { - payloads.push(nalu.clone()); - return; - } - let max_fragment_size = - mtu as isize - NAL_HEADER_SIZE as isize - H265FRAGMENTATION_UNIT_HEADER_SIZE as isize; - let nalu_data = nalu; - let mut nalu_data_index = 2; - let nalu_data_length = nalu.len() as isize - nalu_data_index; - let mut nalu_data_remaining = nalu_data_length; - if std::cmp::min(max_fragment_size, nalu_data_remaining) <= 0 { - return; - } - while nalu_data_remaining > 0 { - let current_fragment_size = std::cmp::min(max_fragment_size, nalu_data_remaining); - //out: = make([]byte, fuaHeaderSize + currentFragmentSize) - let mut out = BytesMut::with_capacity( - H265FRAGMENTATION_UNIT_HEADER_SIZE + current_fragment_size as usize, - ); - out.extend_from_slice(&FRAG_PAYLOAD_HDR); - let is_first = nalu_data_index == 2; - let is_last = !is_first && current_fragment_size < max_fragment_size; - /* - +---------------+ - |0|1|2|3|4|5|6|7| - +-+-+-+-+-+-+-+-+ - |S|E| fu_type | - +---------------+ - */ - if nalu_type == UnitType::IDR { - if is_first { - out.put_u8(FU_HDR_IDR_S); - } else if is_last { - out.put_u8(FU_HDR_IDR_E); - } else { - out.put_u8(FU_HDR_IDR_M); - } - } else if nalu_type == UnitType::PFR { - if is_first { - out.put_u8(FU_HDR_P_S); - } else if is_last { - out.put_u8(FU_HDR_P_E); - } else { - out.put_u8(FU_HDR_P_M); - } - } else if nalu_type == UnitType::BFR { - if is_first { - out.put_u8(FU_HDR_B_S); - } else if is_last { - out.put_u8(FU_HDR_B_E); 
- } else { - out.put_u8(FU_HDR_B_M); - } - } - - out.extend_from_slice( - &nalu_data - [nalu_data_index as usize..(nalu_data_index + current_fragment_size) as usize], - ); - // println!("pkt payload {:?}", &out[0..5]); - payloads.push(out.freeze()); - - nalu_data_remaining -= current_fragment_size; - nalu_data_index += current_fragment_size; - } - } -} - -impl Payloader for HevcPayloader { - /// Payload fragments a H264 packet across one or more byte arrays - fn payload(&mut self, mtu: usize, payload: &Bytes) -> Result> { - if payload.is_empty() || mtu == 0 { - return Ok(vec![]); - } - - let mut payloads = vec![]; - - let (nal_idxs, offset) = HevcPayloader::parse(payload); - let nal_len = nal_idxs.len(); - for (i, start) in nal_idxs.iter().enumerate() { - let end = if (i + 1) < nal_len { - nal_idxs[i + 1] - } else { - payload.len() - }; - // println!( - // "start {}, end {} payload {:?}", - // start, - // end, - // &payload - // .slice((start + offset)..(start + offset + 5)) - // .to_vec() - // ); - self.emit(&payload.slice((start + offset)..end), mtu, &mut payloads); - } - - Ok(payloads) - } - - fn clone_to(&self) -> Box { - Box::new(self.clone()) - } -} - -/// -/// Network Abstraction Unit Header implementation -/// - -const H265NALU_HEADER_SIZE: usize = 2; -/// -const H265NALU_AGGREGATION_PACKET_TYPE: u8 = 48; -/// -const H265NALU_FRAGMENTATION_UNIT_TYPE: u8 = 49; -/// -const H265NALU_PACI_PACKET_TYPE: u8 = 50; - -/// H265NALUHeader is a H265 NAL Unit Header -/// -/// +---------------+---------------+ -/// |0|1|2|3|4|5|6|7|0|1|2|3|4|5|6|7| -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// |F| Type | layer_id | tid | -/// +-------------+-----------------+ -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] -pub struct H265NALUHeader(pub u16); - -impl H265NALUHeader { - pub fn new(high_byte: u8, low_byte: u8) -> Self { - H265NALUHeader(((high_byte as u16) << 8) | low_byte as u16) - } - - /// f is the forbidden bit, should always be 0. - pub fn f(&self) -> bool { - (self.0 >> 15) != 0 - } - - /// nalu_type of NAL Unit. - pub fn nalu_type(&self) -> u8 { - // 01111110 00000000 - const MASK: u16 = 0b01111110 << 8; - ((self.0 & MASK) >> (8 + 1)) as u8 - } - - /// is_type_vcl_unit returns whether or not the NAL Unit type is a VCL NAL unit. - pub fn is_type_vcl_unit(&self) -> bool { - // Type is coded on 6 bits - const MSB_MASK: u8 = 0b00100000; - (self.nalu_type() & MSB_MASK) == 0 - } - - /// layer_id should always be 0 in non-3D HEVC context. - pub fn layer_id(&self) -> u8 { - // 00000001 11111000 - const MASK: u16 = (0b00000001 << 8) | 0b11111000; - ((self.0 & MASK) >> 3) as u8 - } - - /// tid is the temporal identifier of the NAL unit +1. - pub fn tid(&self) -> u8 { - const MASK: u16 = 0b00000111; - (self.0 & MASK) as u8 - } - - /// is_aggregation_packet returns whether or not the packet is an Aggregation packet. - pub fn is_aggregation_packet(&self) -> bool { - self.nalu_type() == H265NALU_AGGREGATION_PACKET_TYPE - } - - /// is_fragmentation_unit returns whether or not the packet is a Fragmentation Unit packet. - pub fn is_fragmentation_unit(&self) -> bool { - self.nalu_type() == H265NALU_FRAGMENTATION_UNIT_TYPE - } - - /// is_paci_packet returns whether or not the packet is a PACI packet. - pub fn is_paci_packet(&self) -> bool { - self.nalu_type() == H265NALU_PACI_PACKET_TYPE - } -} - -/// -/// Single NAL Unit Packet implementation -/// -/// H265SingleNALUnitPacket represents a NALU packet, containing exactly one NAL unit. 
-/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | PayloadHdr | DONL (conditional) | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | | -/// | NAL unit payload data | -/// | | -/// | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | :...OPTIONAL RTP padding | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// -/// Reference: -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub struct H265SingleNALUnitPacket { - /// payload_header is the header of the H265 packet. - payload_header: H265NALUHeader, - /// donl is a 16-bit field, that may or may not be present. - donl: Option, - /// payload of the fragmentation unit. - payload: Bytes, - - might_need_donl: bool, -} - -impl H265SingleNALUnitPacket { - /// with_donl can be called to specify whether or not DONL might be parsed. - /// DONL may need to be parsed if `sprop-max-don-diff` is greater than 0 on the RTP stream. - pub fn with_donl(&mut self, value: bool) { - self.might_need_donl = value; - } - - /// depacketize parses the passed byte slice and stores the result in the H265SingleNALUnitPacket this method is called upon. - fn depacketize(&mut self, payload: &Bytes) -> Result<()> { - if payload.len() <= H265NALU_HEADER_SIZE { - return Err(Error::ErrShortPacket); - } - - let payload_header = H265NALUHeader::new(payload[0], payload[1]); - if payload_header.f() { - return Err(Error::ErrH265CorruptedPacket); - } - if payload_header.is_fragmentation_unit() - || payload_header.is_paci_packet() - || payload_header.is_aggregation_packet() - { - return Err(Error::ErrInvalidH265PacketType); - } - - let mut payload = payload.slice(2..); - - if self.might_need_donl { - // sizeof(uint16) - if payload.len() <= 2 { - return Err(Error::ErrShortPacket); - } - - let donl = ((payload[0] as u16) << 8) | (payload[1] as u16); - self.donl = Some(donl); - payload = payload.slice(2..); - } - - self.payload_header = payload_header; - self.payload = payload; - - Ok(()) - } - - /// payload_header returns the NALU header of the packet. - pub fn payload_header(&self) -> H265NALUHeader { - self.payload_header - } - - /// donl returns the DONL of the packet. - pub fn donl(&self) -> Option { - self.donl - } - - /// payload returns the Fragmentation Unit packet payload. - pub fn payload(&self) -> Bytes { - self.payload.clone() - } -} - -/// -/// Aggregation Packets implementation -/// -/// H265AggregationUnitFirst represent the First Aggregation Unit in an AP. -/// -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// : DONL (conditional) | NALU size | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | NALU size | | -/// +-+-+-+-+-+-+-+-+ NAL unit | -/// | | -/// | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | : -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// -/// Reference: -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub struct H265AggregationUnitFirst { - donl: Option, - nal_unit_size: u16, - nal_unit: Bytes, -} - -impl H265AggregationUnitFirst { - /// donl field, when present, specifies the value of the 16 least - /// significant bits of the decoding order number of the aggregated NAL - /// unit. - pub fn donl(&self) -> Option { - self.donl - } - - /// nalu_size represents the size, in bytes, of the nal_unit. - pub fn nalu_size(&self) -> u16 { - self.nal_unit_size - } - - /// nal_unit payload. 
- pub fn nal_unit(&self) -> Bytes { - self.nal_unit.clone() - } -} - -/// H265AggregationUnit represent the an Aggregation Unit in an AP, which is not the first one. -/// -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// : DOND (cond) | NALU size | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | | -/// | NAL unit | -/// | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | : -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// -/// Reference: -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub struct H265AggregationUnit { - dond: Option, - nal_unit_size: u16, - nal_unit: Bytes, -} - -impl H265AggregationUnit { - /// dond field plus 1 specifies the difference between - /// the decoding order number values of the current aggregated NAL unit - /// and the preceding aggregated NAL unit in the same AP. - pub fn dond(&self) -> Option { - self.dond - } - - /// nalu_size represents the size, in bytes, of the nal_unit. - pub fn nalu_size(&self) -> u16 { - self.nal_unit_size - } - - /// nal_unit payload. - pub fn nal_unit(&self) -> Bytes { - self.nal_unit.clone() - } -} - -/// H265AggregationPacket represents an Aggregation packet. -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | PayloadHdr (Type=48) | | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | -/// | | -/// | two or more aggregation units | -/// | | -/// | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | :...OPTIONAL RTP padding | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// -/// Reference: -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub struct H265AggregationPacket { - first_unit: Option, - other_units: Vec, - - might_need_donl: bool, -} - -impl H265AggregationPacket { - /// with_donl can be called to specify whether or not DONL might be parsed. - /// DONL may need to be parsed if `sprop-max-don-diff` is greater than 0 on the RTP stream. - pub fn with_donl(&mut self, value: bool) { - self.might_need_donl = value; - } - - /// depacketize parses the passed byte slice and stores the result in the H265AggregationPacket this method is called upon. 
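// Illustrative sketch, not part of this patch: the depacketizer below walks the AP
// wire format — PayloadHdr (Type=48), a first aggregation unit (optional 16-bit
// DONL, 16-bit NALU size, NALU bytes), then further units (optional 8-bit DOND,
// 16-bit NALU size, NALU bytes). A minimal reader for one such unit, with a
// hypothetical name:
fn read_aggregation_unit(buf: &[u8], with_dond: bool) -> Option<((Option<u8>, &[u8]), &[u8])> {
    let (dond, rest) = if with_dond {
        (Some(*buf.first()?), buf.get(1..)?)
    } else {
        (None, buf)
    };
    let size = ((*rest.first()? as usize) << 8) | *rest.get(1)? as usize;
    let nal_unit = rest.get(2..2 + size)?;
    Some(((dond, nal_unit), &rest[2 + size..]))
}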
- fn depacketize(&mut self, payload: &Bytes) -> Result<()> { - if payload.len() <= H265NALU_HEADER_SIZE { - return Err(Error::ErrShortPacket); - } - - let payload_header = H265NALUHeader::new(payload[0], payload[1]); - if payload_header.f() { - return Err(Error::ErrH265CorruptedPacket); - } - if !payload_header.is_aggregation_packet() { - return Err(Error::ErrInvalidH265PacketType); - } - - // First parse the first aggregation unit - let mut payload = payload.slice(2..); - let mut first_unit = H265AggregationUnitFirst::default(); - - if self.might_need_donl { - if payload.len() < 2 { - return Err(Error::ErrShortPacket); - } - - let donl = ((payload[0] as u16) << 8) | (payload[1] as u16); - first_unit.donl = Some(donl); - - payload = payload.slice(2..); - } - if payload.len() < 2 { - return Err(Error::ErrShortPacket); - } - first_unit.nal_unit_size = ((payload[0] as u16) << 8) | (payload[1] as u16); - payload = payload.slice(2..); - - if payload.len() < first_unit.nal_unit_size as usize { - return Err(Error::ErrShortPacket); - } - - first_unit.nal_unit = payload.slice(..first_unit.nal_unit_size as usize); - payload = payload.slice(first_unit.nal_unit_size as usize..); - - // Parse remaining Aggregation Units - let mut units = vec![]; //H265AggregationUnit - loop { - let mut unit = H265AggregationUnit::default(); - - if self.might_need_donl { - if payload.is_empty() { - break; - } - - let dond = payload[0]; - unit.dond = Some(dond); - - payload = payload.slice(1..); - } - - if payload.len() < 2 { - break; - } - unit.nal_unit_size = ((payload[0] as u16) << 8) | (payload[1] as u16); - payload = payload.slice(2..); - - if payload.len() < unit.nal_unit_size as usize { - break; - } - - unit.nal_unit = payload.slice(..unit.nal_unit_size as usize); - payload = payload.slice(unit.nal_unit_size as usize..); - - units.push(unit); - } - - // There need to be **at least** two Aggregation Units (first + another one) - if units.is_empty() { - return Err(Error::ErrShortPacket); - } - - self.first_unit = Some(first_unit); - self.other_units = units; - - Ok(()) - } - - /// first_unit returns the first Aggregated Unit of the packet. - pub fn first_unit(&self) -> Option<&H265AggregationUnitFirst> { - self.first_unit.as_ref() - } - - /// other_units returns the all the other Aggregated Unit of the packet (excluding the first one). - pub fn other_units(&self) -> &[H265AggregationUnit] { - self.other_units.as_slice() - } -} - -/// -/// Fragmentation Unit implementation -/// - -/// H265FragmentationUnitHeader is a H265 FU Header -/// +---------------+ -/// |0|1|2|3|4|5|6|7| -/// +-+-+-+-+-+-+-+-+ -/// |S|E| fu_type | -/// +---------------+ -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] -pub struct H265FragmentationUnitHeader(pub u8); - -impl H265FragmentationUnitHeader { - /// s represents the start of a fragmented NAL unit. - pub fn s(&self) -> bool { - const MASK: u8 = 0b10000000; - ((self.0 & MASK) >> 7) != 0 - } - - /// e represents the end of a fragmented NAL unit. - pub fn e(&self) -> bool { - const MASK: u8 = 0b01000000; - ((self.0 & MASK) >> 6) != 0 - } - - /// fu_type MUST be equal to the field Type of the fragmented NAL unit. - pub fn fu_type(&self) -> u8 { - const MASK: u8 = 0b00111111; - self.0 & MASK - } -} - -/// H265FragmentationUnitPacket represents a single Fragmentation Unit packet. 
-/// -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | PayloadHdr (Type=49) | FU header | DONL (cond) | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-| -/// | DONL (cond) | | -/// |-+-+-+-+-+-+-+-+ | -/// | FU payload | -/// | | -/// | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | :...OPTIONAL RTP padding | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// -/// Reference: -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub struct H265FragmentationUnitPacket { - /// payload_header is the header of the H265 packet. - payload_header: H265NALUHeader, - /// fu_header is the header of the fragmentation unit - fu_header: H265FragmentationUnitHeader, - /// donl is a 16-bit field, that may or may not be present. - donl: Option, - /// payload of the fragmentation unit. - payload: Bytes, - - might_need_donl: bool, -} - -impl H265FragmentationUnitPacket { - /// with_donl can be called to specify whether or not DONL might be parsed. - /// DONL may need to be parsed if `sprop-max-don-diff` is greater than 0 on the RTP stream. - pub fn with_donl(&mut self, value: bool) { - self.might_need_donl = value; - } - - /// depacketize parses the passed byte slice and stores the result in the H265FragmentationUnitPacket this method is called upon. - fn depacketize(&mut self, payload: &Bytes) -> Result<()> { - const TOTAL_HEADER_SIZE: usize = H265NALU_HEADER_SIZE + H265FRAGMENTATION_UNIT_HEADER_SIZE; - if payload.len() <= TOTAL_HEADER_SIZE { - return Err(Error::ErrShortPacket); - } - - let payload_header = H265NALUHeader::new(payload[0], payload[1]); - if payload_header.f() { - return Err(Error::ErrH265CorruptedPacket); - } - if !payload_header.is_fragmentation_unit() { - return Err(Error::ErrInvalidH265PacketType); - } - - let fu_header = H265FragmentationUnitHeader(payload[2]); - let mut payload = payload.slice(3..); - - if fu_header.s() && self.might_need_donl { - if payload.len() <= 2 { - return Err(Error::ErrShortPacket); - } - - let donl = ((payload[0] as u16) << 8) | (payload[1] as u16); - self.donl = Some(donl); - payload = payload.slice(2..); - } - - self.payload_header = payload_header; - self.fu_header = fu_header; - self.payload = payload; - - Ok(()) - } - - /// payload_header returns the NALU header of the packet. - pub fn payload_header(&self) -> H265NALUHeader { - self.payload_header - } - - /// fu_header returns the Fragmentation Unit Header of the packet. - pub fn fu_header(&self) -> H265FragmentationUnitHeader { - self.fu_header - } - - /// donl returns the DONL of the packet. - pub fn donl(&self) -> Option { - self.donl - } - - /// payload returns the Fragmentation Unit packet payload. - pub fn payload(&self) -> Bytes { - self.payload.clone() - } -} - -/// -/// PACI implementation -/// - -/// H265PACIPacket represents a single H265 PACI packet. -/// -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | PayloadHdr (Type=50) |A| cType | phssize |F0..2|Y| -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | payload Header Extension Structure (phes) | -/// |=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=| -/// | | -/// | PACI payload: NAL unit | -/// | . . . 
| -/// | | -/// | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | :...OPTIONAL RTP padding | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// -/// Reference: -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub struct H265PACIPacket { - /// payload_header is the header of the H265 packet. - payload_header: H265NALUHeader, - - /// Field which holds value for `A`, `cType`, `phssize`, `F0`, `F1`, `F2` and `Y` fields. - paci_header_fields: u16, - - /// phes is a header extension, of byte length `phssize` - phes: Bytes, - - /// payload contains NAL units & optional padding - payload: Bytes, -} - -impl H265PACIPacket { - /// payload_header returns the NAL Unit Header. - pub fn payload_header(&self) -> H265NALUHeader { - self.payload_header - } - - /// a copies the F bit of the PACI payload NALU. - pub fn a(&self) -> bool { - const MASK: u16 = 0b10000000 << 8; - (self.paci_header_fields & MASK) != 0 - } - - /// ctype copies the Type field of the PACI payload NALU. - pub fn ctype(&self) -> u8 { - const MASK: u16 = 0b01111110 << 8; - ((self.paci_header_fields & MASK) >> (8 + 1)) as u8 - } - - /// phs_size indicates the size of the phes field. - pub fn phs_size(&self) -> u8 { - const MASK: u16 = (0b00000001 << 8) | 0b11110000; - ((self.paci_header_fields & MASK) >> 4) as u8 - } - - /// f0 indicates the presence of a Temporal Scalability support extension in the phes. - pub fn f0(&self) -> bool { - const MASK: u16 = 0b00001000; - (self.paci_header_fields & MASK) != 0 - } - - /// f1 must be zero, reserved for future extensions. - pub fn f1(&self) -> bool { - const MASK: u16 = 0b00000100; - (self.paci_header_fields & MASK) != 0 - } - - /// f2 must be zero, reserved for future extensions. - pub fn f2(&self) -> bool { - const MASK: u16 = 0b00000010; - (self.paci_header_fields & MASK) != 0 - } - - /// y must be zero, reserved for future extensions. - pub fn y(&self) -> bool { - const MASK: u16 = 0b00000001; - (self.paci_header_fields & MASK) != 0 - } - - /// phes contains header extensions. Its size is indicated by phssize. - pub fn phes(&self) -> Bytes { - self.phes.clone() - } - - /// payload is a single NALU or NALU-like struct, not including the first two octets (header). - pub fn payload(&self) -> Bytes { - self.payload.clone() - } - - /// tsci returns the Temporal Scalability Control Information extension, if present. - pub fn tsci(&self) -> Option { - if !self.f0() || self.phs_size() < 3 { - return None; - } - - Some(H265TSCI( - ((self.phes[0] as u32) << 16) | ((self.phes[1] as u32) << 8) | self.phes[0] as u32, - )) - } - - /// depacketize parses the passed byte slice and stores the result in the H265PACIPacket this method is called upon. 
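// Illustrative sketch, not part of this patch: the accessors above all read the
// 16-bit `paci_header_fields` word parsed by the depacketizer below — A at bit 15,
// cType in bits 14..9, phssize in bits 8..4, F0/F1/F2 at bits 3/2/1, Y at bit 0.
// A hypothetical stand-alone version:
fn paci_fields(w: u16) -> (bool, u8, u8, bool, bool, bool, bool) {
    let a = (w >> 15) & 1 != 0;
    let ctype = ((w >> 9) & 0x3f) as u8;
    let phs_size = ((w >> 4) & 0x1f) as u8;
    let f0 = (w >> 3) & 1 != 0;
    let f1 = (w >> 2) & 1 != 0;
    let f2 = (w >> 1) & 1 != 0;
    let y = w & 1 != 0;
    (a, ctype, phs_size, f0, f1, f2, y)
}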
- fn depacketize(&mut self, payload: &Bytes) -> Result<()> { - const TOTAL_HEADER_SIZE: usize = H265NALU_HEADER_SIZE + 2; - if payload.len() <= TOTAL_HEADER_SIZE { - return Err(Error::ErrShortPacket); - } - - let payload_header = H265NALUHeader::new(payload[0], payload[1]); - if payload_header.f() { - return Err(Error::ErrH265CorruptedPacket); - } - if !payload_header.is_paci_packet() { - return Err(Error::ErrInvalidH265PacketType); - } - - let paci_header_fields = ((payload[2] as u16) << 8) | (payload[3] as u16); - let mut payload = payload.slice(4..); - - self.paci_header_fields = paci_header_fields; - let header_extension_size = self.phs_size(); - - if payload.len() < header_extension_size as usize + 1 { - self.paci_header_fields = 0; - return Err(Error::ErrShortPacket); - } - - self.payload_header = payload_header; - - if header_extension_size > 0 { - self.phes = payload.slice(..header_extension_size as usize); - } - - payload = payload.slice(header_extension_size as usize..); - self.payload = payload; - - Ok(()) - } -} - -/// -/// Temporal Scalability Control Information -/// - -/// H265TSCI is a Temporal Scalability Control Information header extension. -/// Reference: -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] -pub struct H265TSCI(pub u32); - -impl H265TSCI { - /// tl0picidx see RFC7798 for more details. - pub fn tl0picidx(&self) -> u8 { - const M1: u32 = 0xFFFF0000; - const M2: u32 = 0xFF00; - ((((self.0 & M1) >> 16) & M2) >> 8) as u8 - } - - /// irap_pic_id see RFC7798 for more details. - pub fn irap_pic_id(&self) -> u8 { - const M1: u32 = 0xFFFF0000; - const M2: u32 = 0x00FF; - (((self.0 & M1) >> 16) & M2) as u8 - } - - /// s see RFC7798 for more details. - pub fn s(&self) -> bool { - const M1: u32 = 0xFF00; - const M2: u32 = 0b10000000; - (((self.0 & M1) >> 8) & M2) != 0 - } - - /// e see RFC7798 for more details. - pub fn e(&self) -> bool { - const M1: u32 = 0xFF00; - const M2: u32 = 0b01000000; - (((self.0 & M1) >> 8) & M2) != 0 - } - - /// res see RFC7798 for more details. - pub fn res(&self) -> u8 { - const M1: u32 = 0xFF00; - const M2: u32 = 0b00111111; - (((self.0 & M1) >> 8) & M2) as u8 - } -} - -/// -/// H265 Payload Enum -/// -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum H265Payload { - H265SingleNALUnitPacket(H265SingleNALUnitPacket), - H265FragmentationUnitPacket(H265FragmentationUnitPacket), - H265AggregationPacket(H265AggregationPacket), - H265PACIPacket(H265PACIPacket), -} - -impl Default for H265Payload { - fn default() -> Self { - H265Payload::H265SingleNALUnitPacket(H265SingleNALUnitPacket::default()) - } -} - -/// -/// Packet implementation -/// - -/// H265Packet represents a H265 packet, stored in the payload of an RTP packet. -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub struct H265Packet { - payload: H265Payload, - might_need_donl: bool, -} - -impl H265Packet { - /// with_donl can be called to specify whether or not DONL might be parsed. - /// DONL may need to be parsed if `sprop-max-don-diff` is greater than 0 on the RTP stream. - pub fn with_donl(&mut self, value: bool) { - self.might_need_donl = value; - } - - /// payload returns the populated payload. 
-    /// Must be cast to one of:
-    /// - H265SingleNALUnitPacket
-    /// - H265FragmentationUnitPacket
-    /// - H265AggregationPacket
-    /// - H265PACIPacket
-    pub fn payload(&self) -> &H265Payload {
-        &self.payload
-    }
-}
-
-impl Depacketizer for H265Packet {
-    /// depacketize parses the passed byte slice and stores the result in the H265Packet this method is called upon
-    fn depacketize(&mut self, payload: &Bytes) -> Result<Bytes> {
-        if payload.len() <= H265NALU_HEADER_SIZE {
-            return Err(Error::ErrShortPacket);
-        }
-
-        let payload_header = H265NALUHeader::new(payload[0], payload[1]);
-        if payload_header.f() {
-            return Err(Error::ErrH265CorruptedPacket);
-        }
-
-        if payload_header.is_paci_packet() {
-            let mut decoded = H265PACIPacket::default();
-            decoded.depacketize(payload)?;
-
-            self.payload = H265Payload::H265PACIPacket(decoded);
-        } else if payload_header.is_fragmentation_unit() {
-            let mut decoded = H265FragmentationUnitPacket::default();
-            decoded.with_donl(self.might_need_donl);
-
-            decoded.depacketize(payload)?;
-
-            self.payload = H265Payload::H265FragmentationUnitPacket(decoded);
-        } else if payload_header.is_aggregation_packet() {
-            let mut decoded = H265AggregationPacket::default();
-            decoded.with_donl(self.might_need_donl);
-
-            decoded.depacketize(payload)?;
-
-            self.payload = H265Payload::H265AggregationPacket(decoded);
-        } else {
-            let mut decoded = H265SingleNALUnitPacket::default();
-            decoded.with_donl(self.might_need_donl);
-
-            decoded.depacketize(payload)?;
-
-            self.payload = H265Payload::H265SingleNALUnitPacket(decoded);
-        }
-
-        Ok(payload.clone())
-    }
-
-    /// is_partition_head checks if this is the head of a packetized NALU stream.
-    fn is_partition_head(&self, _payload: &Bytes) -> bool {
-        //TODO:
-        true
-    }
-
-    fn is_partition_tail(&self, marker: bool, _payload: &Bytes) -> bool {
-        marker
-    }
-}
diff --git a/rtp/src/codecs/mod.rs b/rtp/src/codecs/mod.rs
deleted file mode 100644
index 0296e20c0..000000000
--- a/rtp/src/codecs/mod.rs
+++ /dev/null
@@ -1,7 +0,0 @@
-pub mod av1;
-pub mod g7xx;
-pub mod h264;
-pub mod h265;
-pub mod opus;
-pub mod vp8;
-pub mod vp9;
diff --git a/rtp/src/codecs/opus/mod.rs b/rtp/src/codecs/opus/mod.rs
deleted file mode 100644
index 3ae991e6e..000000000
--- a/rtp/src/codecs/opus/mod.rs
+++ /dev/null
@@ -1,46 +0,0 @@
-#[cfg(test)]
-mod opus_test;
-
-use bytes::Bytes;
-
-use crate::error::{Error, Result};
-use crate::packetizer::{Depacketizer, Payloader};
-
-#[derive(Default, Debug, Copy, Clone)]
-pub struct OpusPayloader;
-
-impl Payloader for OpusPayloader {
-    fn payload(&mut self, mtu: usize, payload: &Bytes) -> Result<Vec<Bytes>> {
-        if payload.is_empty() || mtu == 0 {
-            return Ok(vec![]);
-        }
-
-        Ok(vec![payload.clone()])
-    }
-
-    fn clone_to(&self) -> Box<dyn Payloader + Send + Sync> {
-        Box::new(*self)
-    }
-}
-
-/// OpusPacket represents the Opus header that is stored in the payload of an RTP Packet
-#[derive(PartialEq, Eq, Debug, Default, Clone)]
-pub struct OpusPacket;
-
-impl Depacketizer for OpusPacket {
-    fn depacketize(&mut self, packet: &Bytes) -> Result<Bytes> {
-        if packet.is_empty() {
-            Err(Error::ErrShortPacket)
-        } else {
-            Ok(packet.clone())
-        }
-    }
-
-    fn is_partition_head(&self, _payload: &Bytes) -> bool {
-        true
-    }
-
-    fn is_partition_tail(&self, _marker: bool, _payload: &Bytes) -> bool {
-        true
-    }
-}
diff --git a/rtp/src/codecs/opus/opus_test.rs b/rtp/src/codecs/opus/opus_test.rs
deleted file mode 100644
index a3966d0ff..000000000
--- a/rtp/src/codecs/opus/opus_test.rs
+++ /dev/null
@@ -1,51 +0,0 @@
-use super::*;
-
-#[test]
-fn test_opus_unmarshal() -> Result<()> {
-    let
mut pck = OpusPacket; - - // Empty packet - let empty_bytes = Bytes::from_static(&[]); - let result = pck.depacketize(&empty_bytes); - assert!(result.is_err(), "Result should be err in case of error"); - - // Normal packet - let raw_bytes = Bytes::from_static(&[0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x90]); - let payload = pck.depacketize(&raw_bytes)?; - assert_eq!(&raw_bytes, &payload, "Payload must be same"); - - Ok(()) -} - -#[test] -fn test_opus_payload() -> Result<()> { - let mut pck = OpusPayloader; - let empty = Bytes::from_static(&[]); - let payload = Bytes::from_static(&[0x90, 0x90, 0x90]); - - // Positive MTU, empty payload - let result = pck.payload(1, &empty)?; - assert!(result.is_empty(), "Generated payload should be empty"); - - // Positive MTU, small payload - let result = pck.payload(1, &payload)?; - assert_eq!(result.len(), 1, "Generated payload should be the 1"); - - // Positive MTU, small payload - let result = pck.payload(2, &payload)?; - assert_eq!(result.len(), 1, "Generated payload should be the 1"); - - Ok(()) -} - -#[test] -fn test_opus_is_partition_head() -> Result<()> { - let opus = OpusPacket; - //"NormalPacket" - assert!( - opus.is_partition_head(&Bytes::from_static(&[0x00, 0x00])), - "All OPUS RTP packet should be the head of a new partition" - ); - - Ok(()) -} diff --git a/rtp/src/codecs/vp8/mod.rs b/rtp/src/codecs/vp8/mod.rs deleted file mode 100644 index 20b2d3d04..000000000 --- a/rtp/src/codecs/vp8/mod.rs +++ /dev/null @@ -1,246 +0,0 @@ -#[cfg(test)] -mod vp8_test; - -use bytes::{Buf, BufMut, Bytes, BytesMut}; - -use crate::error::{Error, Result}; -use crate::packetizer::{Depacketizer, Payloader}; - -pub const VP8_HEADER_SIZE: usize = 1; - -/// Vp8Payloader payloads VP8 packets -#[derive(Default, Debug, Copy, Clone)] -pub struct Vp8Payloader { - pub enable_picture_id: bool, - picture_id: u16, -} - -impl Payloader for Vp8Payloader { - /// Payload fragments a VP8 packet across one or more byte arrays - fn payload(&mut self, mtu: usize, payload: &Bytes) -> Result> { - if payload.is_empty() || mtu == 0 { - return Ok(vec![]); - } - - /* - * https://tools.ietf.org/html/rfc7741#section-4.2 - * - * 0 1 2 3 4 5 6 7 - * +-+-+-+-+-+-+-+-+ - * |X|R|N|S|R| PID | (REQUIRED) - * +-+-+-+-+-+-+-+-+ - * X: |I|L|T|K| RSV | (OPTIONAL) - * +-+-+-+-+-+-+-+-+ - * I: |M| PictureID | (OPTIONAL) - * +-+-+-+-+-+-+-+-+ - * L: | tl0picidx | (OPTIONAL) - * +-+-+-+-+-+-+-+-+ - * T/K: |tid|Y| KEYIDX | (OPTIONAL) - * +-+-+-+-+-+-+-+-+ - * S: Start of VP8 partition. SHOULD be set to 1 when the first payload - * octet of the RTP packet is the beginning of a new VP8 partition, - * and MUST NOT be 1 otherwise. The S bit MUST be set to 1 for the - * first packet of each encoded frame. 
- */ - let using_header_size = if self.enable_picture_id { - if self.picture_id == 0 || self.picture_id < 128 { - VP8_HEADER_SIZE + 2 - } else { - VP8_HEADER_SIZE + 3 - } - } else { - VP8_HEADER_SIZE - }; - - let max_fragment_size = mtu as isize - using_header_size as isize; - let mut payload_data_remaining = payload.len() as isize; - let mut payload_data_index: usize = 0; - let mut payloads = vec![]; - - // Make sure the fragment/payload size is correct - if std::cmp::min(max_fragment_size, payload_data_remaining) <= 0 { - return Ok(payloads); - } - - let mut first = true; - while payload_data_remaining > 0 { - let current_fragment_size = - std::cmp::min(max_fragment_size, payload_data_remaining) as usize; - let mut out = BytesMut::with_capacity(using_header_size + current_fragment_size); - let mut buf = [0u8; 4]; - if first { - buf[0] = 0x10; - first = false; - } - - if self.enable_picture_id { - if using_header_size == VP8_HEADER_SIZE + 2 { - buf[0] |= 0x80; - buf[1] |= 0x80; - buf[2] |= (self.picture_id & 0x7F) as u8; - } else if using_header_size == VP8_HEADER_SIZE + 3 { - buf[0] |= 0x80; - buf[1] |= 0x80; - buf[2] |= 0x80 | ((self.picture_id >> 8) & 0x7F) as u8; - buf[3] |= (self.picture_id & 0xFF) as u8; - } - } - - out.put(&buf[..using_header_size]); - - out.put( - &*payload.slice(payload_data_index..payload_data_index + current_fragment_size), - ); - payloads.push(out.freeze()); - - payload_data_remaining -= current_fragment_size as isize; - payload_data_index += current_fragment_size; - } - - self.picture_id += 1; - self.picture_id &= 0x7FFF; - - Ok(payloads) - } - - fn clone_to(&self) -> Box { - Box::new(*self) - } -} - -/// Vp8Packet represents the VP8 header that is stored in the payload of an RTP Packet -#[derive(PartialEq, Eq, Debug, Default, Clone)] -pub struct Vp8Packet { - /// Required Header - /// extended controlbits present - pub x: u8, - /// when set to 1 this frame can be discarded - pub n: u8, - /// start of VP8 partition - pub s: u8, - /// partition index - pub pid: u8, - - /// Extended control bits - /// 1 if PictureID is present - pub i: u8, - /// 1 if tl0picidx is present - pub l: u8, - /// 1 if tid is present - pub t: u8, - /// 1 if KEYIDX is present - pub k: u8, - - /// Optional extension - /// 8 or 16 bits, picture ID - pub picture_id: u16, - /// 8 bits temporal level zero index - pub tl0_pic_idx: u8, - /// 2 bits temporal layer index - pub tid: u8, - /// 1 bit layer sync bit - pub y: u8, - /// 5 bits temporal key frame index - pub key_idx: u8, -} - -impl Depacketizer for Vp8Packet { - /// depacketize parses the passed byte slice and stores the result in the VP8Packet this method is called upon - fn depacketize(&mut self, packet: &Bytes) -> Result { - let payload_len = packet.len(); - if payload_len < 4 { - return Err(Error::ErrShortPacket); - } - // 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 - // +-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+ - // |X|R|N|S|R| PID | (REQUIRED) |X|R|N|S|R| PID | (REQUIRED) - // +-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+ - // X: |I|L|T|K| RSV | (OPTIONAL) X: |I|L|T|K| RSV | (OPTIONAL) - // +-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+ - // I: |M| PictureID | (OPTIONAL) I: |M| PictureID | (OPTIONAL) - // +-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+ - // L: | tl0picidx | (OPTIONAL) | PictureID | - // +-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+ - //T/K:|tid|Y| KEYIDX | (OPTIONAL) L: | tl0picidx | (OPTIONAL) - // +-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+ - //T/K:|tid|Y| KEYIDX | (OPTIONAL) - // +-+-+-+-+-+-+-+-+ - - let reader = &mut packet.clone(); - let mut payload_index = 0; - - let mut b 
= reader.get_u8(); - payload_index += 1; - - self.x = (b & 0x80) >> 7; - self.n = (b & 0x20) >> 5; - self.s = (b & 0x10) >> 4; - self.pid = b & 0x07; - - if self.x == 1 { - b = reader.get_u8(); - payload_index += 1; - self.i = (b & 0x80) >> 7; - self.l = (b & 0x40) >> 6; - self.t = (b & 0x20) >> 5; - self.k = (b & 0x10) >> 4; - } - - if self.i == 1 { - b = reader.get_u8(); - payload_index += 1; - // PID present? - if b & 0x80 > 0 { - // M == 1, PID is 16bit - self.picture_id = (((b & 0x7f) as u16) << 8) | (reader.get_u8() as u16); - payload_index += 1; - } else { - self.picture_id = b as u16; - } - } - - if payload_index >= payload_len { - return Err(Error::ErrShortPacket); - } - - if self.l == 1 { - self.tl0_pic_idx = reader.get_u8(); - payload_index += 1; - } - - if payload_index >= payload_len { - return Err(Error::ErrShortPacket); - } - - if self.t == 1 || self.k == 1 { - let b = reader.get_u8(); - if self.t == 1 { - self.tid = b >> 6; - self.y = (b >> 5) & 0x1; - } - if self.k == 1 { - self.key_idx = b & 0x1F; - } - payload_index += 1; - } - - if payload_index >= packet.len() { - return Err(Error::ErrShortPacket); - } - - Ok(packet.slice(payload_index..)) - } - - /// is_partition_head checks whether if this is a head of the VP8 partition - fn is_partition_head(&self, payload: &Bytes) -> bool { - if payload.is_empty() { - false - } else { - (payload[0] & 0x10) != 0 - } - } - - fn is_partition_tail(&self, marker: bool, _payload: &Bytes) -> bool { - marker - } -} diff --git a/rtp/src/codecs/vp8/vp8_test.rs b/rtp/src/codecs/vp8/vp8_test.rs deleted file mode 100644 index 32fd7343d..000000000 --- a/rtp/src/codecs/vp8/vp8_test.rs +++ /dev/null @@ -1,225 +0,0 @@ -use super::*; - -#[test] -fn test_vp8_unmarshal() -> Result<()> { - let mut pck = Vp8Packet::default(); - - // Empty packet - let empty_bytes = Bytes::from_static(&[]); - let result = pck.depacketize(&empty_bytes); - assert!(result.is_err(), "Result should be err in case of error"); - - // Payload smaller than header size - let small_bytes = Bytes::from_static(&[0x00, 0x11, 0x22]); - let result = pck.depacketize(&small_bytes); - assert!(result.is_err(), "Result should be err in case of error"); - - // Payload smaller than header size - let small_bytes = Bytes::from_static(&[0x00, 0x11]); - let result = pck.depacketize(&small_bytes); - assert!(result.is_err(), "Result should be err in case of error"); - - // Normal packet - let raw_bytes = Bytes::from_static(&[0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x90]); - let payload = pck.depacketize(&raw_bytes).expect("Normal packet"); - assert!(!payload.is_empty(), "Payload must be not empty"); - - // Header size, only X - let raw_bytes = Bytes::from_static(&[0x80, 0x00, 0x00, 0x00]); - let payload = pck.depacketize(&raw_bytes).expect("Only X"); - assert!(!payload.is_empty(), "Payload must be not empty"); - assert_eq!(pck.x, 1, "X must be 1"); - assert_eq!(pck.i, 0, "I must be 0"); - assert_eq!(pck.l, 0, "L must be 0"); - assert_eq!(pck.t, 0, "T must be 0"); - assert_eq!(pck.k, 0, "K must be 0"); - - // Header size, X and I, PID 16bits - let raw_bytes = Bytes::from_static(&[0x80, 0x80, 0x81, 0x00, 0x00]); - let payload = pck.depacketize(&raw_bytes).expect("X and I, PID 16bits"); - assert!(!payload.is_empty(), "Payload must be not empty"); - assert_eq!(pck.x, 1, "X must be 1"); - assert_eq!(pck.i, 1, "I must be 1"); - assert_eq!(pck.l, 0, "L must be 0"); - assert_eq!(pck.t, 0, "T must be 0"); - assert_eq!(pck.k, 0, "K must be 0"); - - // Header size, X and L - let raw_bytes = 
Bytes::from_static(&[0x80, 0x40, 0x00, 0x00]); - let payload = pck.depacketize(&raw_bytes).expect("X and L"); - assert!(!payload.is_empty(), "Payload must be not empty"); - assert_eq!(pck.x, 1, "X must be 1"); - assert_eq!(pck.i, 0, "I must be 0"); - assert_eq!(pck.l, 1, "L must be 1"); - assert_eq!(pck.t, 0, "T must be 0"); - assert_eq!(pck.k, 0, "K must be 0"); - - // Header size, X and T - let raw_bytes = Bytes::from_static(&[0x80, 0x20, 0x00, 0x00]); - let payload = pck.depacketize(&raw_bytes).expect("X and T"); - assert!(!payload.is_empty(), "Payload must be not empty"); - assert_eq!(pck.x, 1, "X must be 1"); - assert_eq!(pck.i, 0, "I must be 0"); - assert_eq!(pck.l, 0, "L must be 0"); - assert_eq!(pck.t, 1, "T must be 1"); - assert_eq!(pck.k, 0, "K must be 0"); - - // Header size, X and K - let raw_bytes = Bytes::from_static(&[0x80, 0x10, 0x00, 0x00]); - let payload = pck.depacketize(&raw_bytes).expect("X and K"); - assert!(!payload.is_empty(), "Payload must be not empty"); - assert_eq!(pck.x, 1, "X must be 1"); - assert_eq!(pck.i, 0, "I must be 0"); - assert_eq!(pck.l, 0, "L must be 0"); - assert_eq!(pck.t, 0, "T must be 0"); - assert_eq!(pck.k, 1, "K must be 1"); - - // Header size, all flags and 8bit picture_id - let raw_bytes = Bytes::from_static(&[0xff, 0xff, 0x00, 0x00, 0x00, 0x00]); - let payload = pck - .depacketize(&raw_bytes) - .expect("all flags and 8bit picture_id"); - assert!(!payload.is_empty(), "Payload must be not empty"); - assert_eq!(pck.x, 1, "X must be 1"); - assert_eq!(pck.i, 1, "I must be 1"); - assert_eq!(pck.l, 1, "L must be 1"); - assert_eq!(pck.t, 1, "T must be 1"); - assert_eq!(pck.k, 1, "K must be 1"); - - // Header size, all flags and 16bit picture_id - let raw_bytes = Bytes::from_static(&[0xff, 0xff, 0x80, 0x00, 0x00, 0x00, 0x00]); - let payload = pck - .depacketize(&raw_bytes) - .expect("all flags and 16bit picture_id"); - assert!(!payload.is_empty(), "Payload must be not empty"); - assert_eq!(pck.x, 1, "X must be 1"); - assert_eq!(pck.i, 1, "I must be 1"); - assert_eq!(pck.l, 1, "L must be 1"); - assert_eq!(pck.t, 1, "T must be 1"); - assert_eq!(pck.k, 1, "K must be 1"); - - Ok(()) -} - -#[test] -fn test_vp8_payload() -> Result<()> { - let tests = vec![ - ( - "WithoutPictureID", - Vp8Payloader::default(), - 2, - vec![ - Bytes::from_static(&[0x90, 0x90, 0x90]), - Bytes::from_static(&[0x91, 0x91]), - ], - vec![ - vec![ - Bytes::from_static(&[0x10, 0x90]), - Bytes::from_static(&[0x00, 0x90]), - Bytes::from_static(&[0x00, 0x90]), - ], - vec![ - Bytes::from_static(&[0x10, 0x91]), - Bytes::from_static(&[0x00, 0x91]), - ], - ], - ), - ( - "WithPictureID_1byte", - Vp8Payloader { - enable_picture_id: true, - picture_id: 0x20, - }, - 5, - vec![ - Bytes::from_static(&[0x90, 0x90, 0x90]), - Bytes::from_static(&[0x91, 0x91]), - ], - vec![ - vec![ - Bytes::from_static(&[0x90, 0x80, 0x20, 0x90, 0x90]), - Bytes::from_static(&[0x80, 0x80, 0x20, 0x90]), - ], - vec![Bytes::from_static(&[0x90, 0x80, 0x21, 0x91, 0x91])], - ], - ), - ( - "WithPictureID_2bytes", - Vp8Payloader { - enable_picture_id: true, - picture_id: 0x120, - }, - 6, - vec![ - Bytes::from_static(&[0x90, 0x90, 0x90]), - Bytes::from_static(&[0x91, 0x91]), - ], - vec![ - vec![ - Bytes::from_static(&[0x90, 0x80, 0x81, 0x20, 0x90, 0x90]), - Bytes::from_static(&[0x80, 0x80, 0x81, 0x20, 0x90]), - ], - vec![Bytes::from_static(&[0x90, 0x80, 0x81, 0x21, 0x91, 0x91])], - ], - ), - ]; - - for (name, mut pck, mtu, payloads, expected) in tests { - for (i, payload) in payloads.iter().enumerate() { - let actual = 
pck.payload(mtu, payload)?; - assert_eq!(expected[i], actual, "{name}: Generated packet[{i}] differs"); - } - } - - Ok(()) -} - -#[test] -fn test_vp8_payload_error() -> Result<()> { - let mut pck = Vp8Payloader::default(); - let empty = Bytes::from_static(&[]); - let payload = Bytes::from_static(&[0x90, 0x90, 0x90]); - - // Positive MTU, empty payload - let result = pck.payload(1, &empty)?; - assert!(result.is_empty(), "Generated payload should be empty"); - - // Positive MTU, small payload - let result = pck.payload(1, &payload)?; - assert_eq!(result.len(), 0, "Generated payload should be empty"); - - // Positive MTU, small payload - let result = pck.payload(2, &payload)?; - assert_eq!( - result.len(), - payload.len(), - "Generated payload should be the same size as original payload size" - ); - - Ok(()) -} - -#[test] -fn test_vp8_partition_head_checker_is_partition_head() -> Result<()> { - let vp8 = Vp8Packet::default(); - - //"SmallPacket" - assert!( - !vp8.is_partition_head(&Bytes::from_static(&[0x00])), - "Small packet should not be the head of a new partition" - ); - - //"SFlagON", - assert!( - vp8.is_partition_head(&Bytes::from_static(&[0x10, 0x00, 0x00, 0x00])), - "Packet with S flag should be the head of a new partition" - ); - - //"SFlagOFF" - assert!( - !vp8.is_partition_head(&Bytes::from_static(&[0x00, 0x00, 0x00, 0x00])), - "Packet without S flag should not be the head of a new partition" - ); - - Ok(()) -} diff --git a/rtp/src/codecs/vp9/mod.rs b/rtp/src/codecs/vp9/mod.rs deleted file mode 100644 index 3a3f130a3..000000000 --- a/rtp/src/codecs/vp9/mod.rs +++ /dev/null @@ -1,467 +0,0 @@ -#[cfg(test)] -mod vp9_test; - -use std::fmt; -use std::sync::Arc; - -use bytes::{Buf, BufMut, Bytes, BytesMut}; - -use crate::error::{Error, Result}; -use crate::packetizer::{Depacketizer, Payloader}; - -/// Flexible mode 15 bit picture ID -const VP9HEADER_SIZE: usize = 3; -const MAX_SPATIAL_LAYERS: u8 = 5; -const MAX_VP9REF_PICS: usize = 3; - -/// InitialPictureIDFn is a function that returns random initial picture ID. -pub type InitialPictureIDFn = Arc u16) + Send + Sync>; - -/// Vp9Payloader payloads VP9 packets -#[derive(Default, Clone)] -pub struct Vp9Payloader { - picture_id: u16, - initialized: bool, - - pub initial_picture_id_fn: Option, -} - -impl fmt::Debug for Vp9Payloader { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Vp9Payloader") - .field("picture_id", &self.picture_id) - .field("initialized", &self.initialized) - .finish() - } -} - -impl Payloader for Vp9Payloader { - /// Payload fragments an Vp9Payloader packet across one or more byte arrays - fn payload(&mut self, mtu: usize, payload: &Bytes) -> Result> { - /* - * https://www.ietf.org/id/draft-ietf-payload-vp9-13.txt - * - * Flexible mode (F=1) - * 0 1 2 3 4 5 6 7 - * +-+-+-+-+-+-+-+-+ - * |I|P|L|F|B|E|V|Z| (REQUIRED) - * +-+-+-+-+-+-+-+-+ - * I: |M| PICTURE ID | (REQUIRED) - * +-+-+-+-+-+-+-+-+ - * M: | EXTENDED PID | (RECOMMENDED) - * +-+-+-+-+-+-+-+-+ - * L: | tid |U| SID |D| (CONDITIONALLY RECOMMENDED) - * +-+-+-+-+-+-+-+-+ -\ - * P,F: | P_DIFF |N| (CONDITIONALLY REQUIRED) - up to 3 times - * +-+-+-+-+-+-+-+-+ -/ - * V: | SS | - * | .. 
| - * +-+-+-+-+-+-+-+-+ - * - * Non-flexible mode (F=0) - * 0 1 2 3 4 5 6 7 - * +-+-+-+-+-+-+-+-+ - * |I|P|L|F|B|E|V|Z| (REQUIRED) - * +-+-+-+-+-+-+-+-+ - * I: |M| PICTURE ID | (RECOMMENDED) - * +-+-+-+-+-+-+-+-+ - * M: | EXTENDED PID | (RECOMMENDED) - * +-+-+-+-+-+-+-+-+ - * L: | tid |U| SID |D| (CONDITIONALLY RECOMMENDED) - * +-+-+-+-+-+-+-+-+ - * | tl0picidx | (CONDITIONALLY REQUIRED) - * +-+-+-+-+-+-+-+-+ - * V: | SS | - * | .. | - * +-+-+-+-+-+-+-+-+ - */ - - if payload.is_empty() || mtu == 0 { - return Ok(vec![]); - } - - if !self.initialized { - if self.initial_picture_id_fn.is_none() { - self.initial_picture_id_fn = - Some(Arc::new(|| -> u16 { rand::random::() & 0x7FFF })); - } - self.picture_id = if let Some(f) = &self.initial_picture_id_fn { - f() - } else { - 0 - }; - self.initialized = true; - } - - let max_fragment_size = mtu as isize - VP9HEADER_SIZE as isize; - let mut payloads = vec![]; - let mut payload_data_remaining = payload.len(); - let mut payload_data_index = 0; - - if std::cmp::min(max_fragment_size, payload_data_remaining as isize) <= 0 { - return Ok(vec![]); - } - - while payload_data_remaining > 0 { - let current_fragment_size = - std::cmp::min(max_fragment_size as usize, payload_data_remaining); - let mut out = BytesMut::with_capacity(VP9HEADER_SIZE + current_fragment_size); - let mut buf = [0u8; VP9HEADER_SIZE]; - buf[0] = 0x90; // F=1 I=1 - if payload_data_index == 0 { - buf[0] |= 0x08; // B=1 - } - if payload_data_remaining == current_fragment_size { - buf[0] |= 0x04; // E=1 - } - buf[1] = (self.picture_id >> 8) as u8 | 0x80; - buf[2] = (self.picture_id & 0xFF) as u8; - - out.put(&buf[..]); - - out.put( - &*payload.slice(payload_data_index..payload_data_index + current_fragment_size), - ); - - payloads.push(out.freeze()); - - payload_data_remaining -= current_fragment_size; - payload_data_index += current_fragment_size; - } - - self.picture_id += 1; - self.picture_id &= 0x7FFF; - - Ok(payloads) - } - - fn clone_to(&self) -> Box { - Box::new(self.clone()) - } -} - -/// Vp9Packet represents the VP9 header that is stored in the payload of an RTP Packet -#[derive(PartialEq, Eq, Debug, Default, Clone)] -pub struct Vp9Packet { - /// picture ID is present - pub i: bool, - /// inter-picture predicted frame. - pub p: bool, - /// layer indices present - pub l: bool, - /// flexible mode - pub f: bool, - /// start of frame. beginning of new vp9 frame - pub b: bool, - /// end of frame - pub e: bool, - /// scalability structure (SS) present - pub v: bool, - /// Not a reference frame for upper spatial layers - pub z: bool, - - /// Recommended headers - /// 7 or 16 bits, picture ID. - pub picture_id: u16, - - /// Conditionally recommended headers - /// Temporal layer ID - pub tid: u8, - /// Switching up point - pub u: bool, - /// Spatial layer ID - pub sid: u8, - /// Inter-layer dependency used - pub d: bool, - - /// Conditionally required headers - /// Reference index (F=1) - pub pdiff: Vec, - /// Temporal layer zero index (F=0) - pub tl0picidx: u8, - - /// Scalability structure headers - /// N_S + 1 indicates the number of spatial layers present in the VP9 stream - pub ns: u8, - /// Each spatial layer's frame resolution present - pub y: bool, - /// PG description present flag. 
-    pub g: bool,
-    /// N_G indicates the number of pictures in a Picture Group (PG)
-    pub ng: u8,
-    pub width: Vec<u16>,
-    pub height: Vec<u16>,
-    /// Temporal layer ID of pictures in a Picture Group
-    pub pgtid: Vec<u8>,
-    /// Switching up point of pictures in a Picture Group
-    pub pgu: Vec<bool>,
-    /// Reference indices of pictures in a Picture Group
-    pub pgpdiff: Vec<Vec<u8>>,
-}
-
-impl Depacketizer for Vp9Packet {
-    /// depacketize parses the passed byte slice and stores the result in the Vp9Packet this method is called upon
-    fn depacketize(&mut self, packet: &Bytes) -> Result<Bytes> {
-        if packet.is_empty() {
-            return Err(Error::ErrShortPacket);
-        }
-
-        let reader = &mut packet.clone();
-        let b = reader.get_u8();
-
-        self.i = (b & 0x80) != 0;
-        self.p = (b & 0x40) != 0;
-        self.l = (b & 0x20) != 0;
-        self.f = (b & 0x10) != 0;
-        self.b = (b & 0x08) != 0;
-        self.e = (b & 0x04) != 0;
-        self.v = (b & 0x02) != 0;
-        self.z = (b & 0x01) != 0;
-
-        let mut payload_index = 1;
-
-        if self.i {
-            payload_index = self.parse_picture_id(reader, payload_index)?;
-        }
-
-        if self.l {
-            payload_index = self.parse_layer_info(reader, payload_index)?;
-        }
-
-        if self.f && self.p {
-            payload_index = self.parse_ref_indices(reader, payload_index)?;
-        }
-
-        if self.v {
-            payload_index = self.parse_ssdata(reader, payload_index)?;
-        }
-
-        Ok(packet.slice(payload_index..))
-    }
-
-    /// is_partition_head checks whether this is the head of a VP9 partition
-    fn is_partition_head(&self, payload: &Bytes) -> bool {
-        if payload.is_empty() {
-            false
-        } else {
-            (payload[0] & 0x08) != 0
-        }
-    }
-
-    fn is_partition_tail(&self, marker: bool, _payload: &Bytes) -> bool {
-        marker
-    }
-}
-
-impl Vp9Packet {
-    // Picture ID:
-    //
-    //      +-+-+-+-+-+-+-+-+
-    // I:   |M| PICTURE ID  |   M:0 => picture id is 7 bits.
-    //      +-+-+-+-+-+-+-+-+   M:1 => picture id is 15 bits.
-    // M:   | EXTENDED PID  |
-    //      +-+-+-+-+-+-+-+-+
-    //
-    fn parse_picture_id(
-        &mut self,
-        reader: &mut dyn Buf,
-        mut payload_index: usize,
-    ) -> Result<usize> {
-        if reader.remaining() == 0 {
-            return Err(Error::ErrShortPacket);
-        }
-        let b = reader.get_u8();
-        payload_index += 1;
-        // PID present?
- if (b & 0x80) != 0 { - if reader.remaining() == 0 { - return Err(Error::ErrShortPacket); - } - // M == 1, PID is 15bit - self.picture_id = (((b & 0x7f) as u16) << 8) | (reader.get_u8() as u16); - payload_index += 1; - } else { - self.picture_id = (b & 0x7F) as u16; - } - - Ok(payload_index) - } - - fn parse_layer_info( - &mut self, - reader: &mut dyn Buf, - mut payload_index: usize, - ) -> Result { - payload_index = self.parse_layer_info_common(reader, payload_index)?; - - if self.f { - Ok(payload_index) - } else { - self.parse_layer_info_non_flexible_mode(reader, payload_index) - } - } - - // Layer indices (flexible mode): - // - // +-+-+-+-+-+-+-+-+ - // L: | T |U| S |D| - // +-+-+-+-+-+-+-+-+ - // - fn parse_layer_info_common( - &mut self, - reader: &mut dyn Buf, - mut payload_index: usize, - ) -> Result { - if reader.remaining() == 0 { - return Err(Error::ErrShortPacket); - } - let b = reader.get_u8(); - payload_index += 1; - - self.tid = b >> 5; - self.u = b & 0x10 != 0; - self.sid = (b >> 1) & 0x7; - self.d = b & 0x01 != 0; - - if self.sid >= MAX_SPATIAL_LAYERS { - Err(Error::ErrTooManySpatialLayers) - } else { - Ok(payload_index) - } - } - - // Layer indices (non-flexible mode): - // - // +-+-+-+-+-+-+-+-+ - // L: | T |U| S |D| - // +-+-+-+-+-+-+-+-+ - // | tl0picidx | - // +-+-+-+-+-+-+-+-+ - // - fn parse_layer_info_non_flexible_mode( - &mut self, - reader: &mut dyn Buf, - mut payload_index: usize, - ) -> Result { - if reader.remaining() == 0 { - return Err(Error::ErrShortPacket); - } - self.tl0picidx = reader.get_u8(); - payload_index += 1; - Ok(payload_index) - } - - // Reference indices: - // - // +-+-+-+-+-+-+-+-+ P=1,F=1: At least one reference index - // P,F: | P_DIFF |N| up to 3 times has to be specified. - // +-+-+-+-+-+-+-+-+ N=1: An additional P_DIFF follows - // current P_DIFF. - // - fn parse_ref_indices( - &mut self, - reader: &mut dyn Buf, - mut payload_index: usize, - ) -> Result { - let mut b = 1u8; - while (b & 0x1) != 0 { - if reader.remaining() == 0 { - return Err(Error::ErrShortPacket); - } - b = reader.get_u8(); - payload_index += 1; - - self.pdiff.push(b >> 1); - if self.pdiff.len() >= MAX_VP9REF_PICS { - return Err(Error::ErrTooManyPDiff); - } - } - - Ok(payload_index) - } - - // Scalability structure (SS): - // - // +-+-+-+-+-+-+-+-+ - // V: | N_S |Y|G|-|-|-| - // +-+-+-+-+-+-+-+-+ -| - // Y: | WIDTH | (OPTIONAL) . - // + + . - // | | (OPTIONAL) . - // +-+-+-+-+-+-+-+-+ . N_S + 1 times - // | HEIGHT | (OPTIONAL) . - // + + . - // | | (OPTIONAL) . - // +-+-+-+-+-+-+-+-+ -| - // G: | N_G | (OPTIONAL) - // +-+-+-+-+-+-+-+-+ -| - // N_G: | T |U| R |-|-| (OPTIONAL) . - // +-+-+-+-+-+-+-+-+ -| . N_G times - // | P_DIFF | (OPTIONAL) . R times . 
- // +-+-+-+-+-+-+-+-+ -| -| - // - fn parse_ssdata(&mut self, reader: &mut dyn Buf, mut payload_index: usize) -> Result { - if reader.remaining() == 0 { - return Err(Error::ErrShortPacket); - } - - let b = reader.get_u8(); - payload_index += 1; - - self.ns = b >> 5; - self.y = b & 0x10 != 0; - self.g = (b >> 1) & 0x7 != 0; - - let ns = (self.ns + 1) as usize; - self.ng = 0; - - if self.y { - if reader.remaining() < 4 * ns { - return Err(Error::ErrShortPacket); - } - - self.width = vec![0u16; ns]; - self.height = vec![0u16; ns]; - for i in 0..ns { - self.width[i] = reader.get_u16(); - self.height[i] = reader.get_u16(); - } - payload_index += 4 * ns; - } - - if self.g { - if reader.remaining() == 0 { - return Err(Error::ErrShortPacket); - } - - self.ng = reader.get_u8(); - payload_index += 1; - } - - for i in 0..self.ng as usize { - if reader.remaining() == 0 { - return Err(Error::ErrShortPacket); - } - let b = reader.get_u8(); - payload_index += 1; - - self.pgtid.push(b >> 5); - self.pgu.push(b & 0x10 != 0); - - let r = ((b >> 2) & 0x3) as usize; - if reader.remaining() < r { - return Err(Error::ErrShortPacket); - } - - self.pgpdiff.push(vec![]); - for _ in 0..r { - let b = reader.get_u8(); - payload_index += 1; - - self.pgpdiff[i].push(b); - } - } - - Ok(payload_index) - } -} diff --git a/rtp/src/codecs/vp9/vp9_test.rs b/rtp/src/codecs/vp9/vp9_test.rs deleted file mode 100644 index 596d423ba..000000000 --- a/rtp/src/codecs/vp9/vp9_test.rs +++ /dev/null @@ -1,364 +0,0 @@ -use super::*; - -#[test] -fn test_vp9_packet_unmarshal() -> Result<()> { - let tests = vec![ - ( - "Empty", - Bytes::from_static(&[]), - Vp9Packet::default(), - Bytes::new(), - Some(Error::ErrShortPacket), - ), - ( - "NonFlexible", - Bytes::from_static(&[0x00, 0xAA]), - Vp9Packet::default(), - Bytes::from_static(&[0xAA]), - None, - ), - ( - "NonFlexiblePictureID", - Bytes::from_static(&[0x80, 0x02, 0xAA]), - Vp9Packet { - i: true, - picture_id: 0x02, - ..Default::default() - }, - Bytes::from_static(&[0xAA]), - None, - ), - ( - "NonFlexiblePictureIDExt", - Bytes::from_static(&[0x80, 0x81, 0xFF, 0xAA]), - Vp9Packet { - i: true, - picture_id: 0x01FF, - ..Default::default() - }, - Bytes::from_static(&[0xAA]), - None, - ), - ( - "NonFlexiblePictureIDExt_ShortPacket0", - Bytes::from_static(&[0x80, 0x81]), - Vp9Packet::default(), - Bytes::new(), - Some(Error::ErrShortPacket), - ), - ( - "NonFlexiblePictureIDExt_ShortPacket1", - Bytes::from_static(&[0x80]), - Vp9Packet::default(), - Bytes::new(), - Some(Error::ErrShortPacket), - ), - ( - "NonFlexibleLayerIndicePictureID", - Bytes::from_static(&[0xA0, 0x02, 0x23, 0x01, 0xAA]), - Vp9Packet { - i: true, - l: true, - picture_id: 0x02, - tid: 0x01, - sid: 0x01, - d: true, - tl0picidx: 0x01, - ..Default::default() - }, - Bytes::from_static(&[0xAA]), - None, - ), - ( - "FlexibleLayerIndicePictureID", - Bytes::from_static(&[0xB0, 0x02, 0x23, 0x01, 0xAA]), - Vp9Packet { - f: true, - i: true, - l: true, - picture_id: 0x02, - tid: 0x01, - sid: 0x01, - d: true, - ..Default::default() - }, - Bytes::from_static(&[0x01, 0xAA]), - None, - ), - ( - "NonFlexibleLayerIndicePictureID_ShortPacket0", - Bytes::from_static(&[0xA0, 0x02, 0x23]), - Vp9Packet::default(), - Bytes::new(), - Some(Error::ErrShortPacket), - ), - ( - "NonFlexibleLayerIndicePictureID_ShortPacket1", - Bytes::from_static(&[0xA0, 0x02]), - Vp9Packet::default(), - Bytes::new(), - Some(Error::ErrShortPacket), - ), - ( - "FlexiblePictureIDRefIndex", - Bytes::from_static(&[0xD0, 0x02, 0x03, 0x04, 0xAA]), - Vp9Packet { - i: true, - p: 
true, - f: true, - picture_id: 0x02, - pdiff: vec![0x01, 0x02], - ..Default::default() - }, - Bytes::from_static(&[0xAA]), - None, - ), - ( - "FlexiblePictureIDRefIndex_TooManyPDiff", - Bytes::from_static(&[0xD0, 0x02, 0x03, 0x05, 0x07, 0x09, 0x10, 0xAA]), - Vp9Packet::default(), - Bytes::new(), - Some(Error::ErrTooManyPDiff), - ), - ( - "FlexiblePictureIDRefIndexNoPayload", - Bytes::from_static(&[0xD0, 0x02, 0x03, 0x04]), - Vp9Packet { - i: true, - p: true, - f: true, - picture_id: 0x02, - pdiff: vec![0x01, 0x02], - ..Default::default() - }, - Bytes::from_static(&[]), - None, - ), - ( - "FlexiblePictureIDRefIndex_ShortPacket0", - Bytes::from_static(&[0xD0, 0x02, 0x03]), - Vp9Packet::default(), - Bytes::new(), - Some(Error::ErrShortPacket), - ), - ( - "FlexiblePictureIDRefIndex_ShortPacket1", - Bytes::from_static(&[0xD0, 0x02]), - Vp9Packet::default(), - Bytes::new(), - Some(Error::ErrShortPacket), - ), - ( - "FlexiblePictureIDRefIndex_ShortPacket2", - Bytes::from_static(&[0xD0]), - Vp9Packet::default(), - Bytes::new(), - Some(Error::ErrShortPacket), - ), - ( - "ScalabilityStructureResolutionsNoPayload", - Bytes::from_static(&[ - 0x0A, - (1 << 5) | (1 << 4), // NS:1 Y:1 G:0 - (640 >> 8) as u8, - (640 & 0xff) as u8, - (360 >> 8) as u8, - (360 & 0xff) as u8, - (1280 >> 8) as u8, - (1280 & 0xff) as u8, - (720 >> 8) as u8, - (720 & 0xff) as u8, - ]), - Vp9Packet { - b: true, - v: true, - ns: 1, - y: true, - g: false, - ng: 0, - width: vec![640, 1280], - height: vec![360, 720], - ..Default::default() - }, - Bytes::new(), - None, - ), - ( - "ScalabilityStructureNoPayload", - Bytes::from_static(&[ - 0x0A, - (1 << 5) | (1 << 3), // NS:1 Y:0 G:1 - 2, - (1 << 4), // T:0 U:1 R:0 - - (2 << 5) | (1 << 2), // T:2 U:0 R:1 - - 33, - ]), - Vp9Packet { - b: true, - v: true, - ns: 1, - y: false, - g: true, - ng: 2, - pgtid: vec![0, 2], - pgu: vec![true, false], - pgpdiff: vec![vec![], vec![33]], - ..Default::default() - }, - Bytes::new(), - None, - ), - ]; - - for (name, b, pkt, expected, err) in tests { - let mut p = Vp9Packet::default(); - - if let Some(expected) = err { - if let Err(actual) = p.depacketize(&b) { - assert_eq!( - expected, actual, - "{name}: expected {expected}, but got {actual}" - ); - } else { - panic!("{name}: expected error, but got passed"); - } - } else { - let payload = p.depacketize(&b)?; - assert_eq!(pkt, p, "{name}: expected {pkt:?}, but got {p:?}"); - assert_eq!(payload, expected); - } - } - - Ok(()) -} - -#[test] -fn test_vp9_payloader_payload() -> Result<()> { - let mut r0 = 8692; - let mut rands = vec![]; - for _ in 0..10 { - rands.push(vec![(r0 >> 8) as u8 | 0x80, (r0 & 0xFF) as u8]); - r0 += 1; - } - - let tests = vec![ - ("NilPayload", vec![Bytes::new()], 100, vec![]), - ("SmallMTU", vec![Bytes::from(vec![0x00, 0x00])], 1, vec![]), - ( - "NegativeMTU", - vec![Bytes::from(vec![0x00, 0x00])], - 0, - vec![], - ), - ( - "OnePacket", - vec![Bytes::from(vec![0x01, 0x02])], - 10, - vec![Bytes::from(vec![ - 0x9C, - rands[0][0], - rands[0][1], - 0x01, - 0x02, - ])], - ), - ( - "TwoPackets", - vec![Bytes::from(vec![0x01, 0x02])], - 4, - vec![ - Bytes::from(vec![0x98, rands[0][0], rands[0][1], 0x01]), - Bytes::from(vec![0x94, rands[0][0], rands[0][1], 0x02]), - ], - ), - ( - "ThreePackets", - vec![Bytes::from(vec![0x01, 0x02, 0x03])], - 4, - vec![ - Bytes::from(vec![0x98, rands[0][0], rands[0][1], 0x01]), - Bytes::from(vec![0x90, rands[0][0], rands[0][1], 0x02]), - Bytes::from(vec![0x94, rands[0][0], rands[0][1], 0x03]), - ], - ), - ( - "TwoFramesFourPackets", - 
vec![Bytes::from(vec![0x01, 0x02, 0x03]), Bytes::from(vec![0x04])], - 5, - vec![ - Bytes::from(vec![0x98, rands[0][0], rands[0][1], 0x01, 0x02]), - Bytes::from(vec![0x94, rands[0][0], rands[0][1], 0x03]), - Bytes::from(vec![0x9C, rands[1][0], rands[1][1], 0x04]), - ], - ), - ]; - - for (name, bs, mtu, expected) in tests { - let mut pck = Vp9Payloader { - initial_picture_id_fn: Some(Arc::new(|| -> u16 { 8692 })), - ..Default::default() - }; - - let mut actual = vec![]; - for b in &bs { - actual.extend(pck.payload(mtu, b)?); - } - assert_eq!(actual, expected, "{name}: Payloaded packet"); - } - - //"PictureIDOverflow" - { - let mut pck = Vp9Payloader { - initial_picture_id_fn: Some(Arc::new(|| -> u16 { 8692 })), - ..Default::default() - }; - let mut p_prev = Vp9Packet::default(); - for i in 0..0x8000 { - let res = pck.payload(4, &Bytes::from_static(&[0x01]))?; - let mut p = Vp9Packet::default(); - p.depacketize(&res[0])?; - - if i > 0 { - if p_prev.picture_id == 0x7FFF { - assert_eq!( - p.picture_id, 0, - "Picture ID next to 0x7FFF must be 0, got {}", - p.picture_id - ); - } else if p_prev.picture_id + 1 != p.picture_id { - panic!( - "Picture ID next must be incremented by 1: {} -> {}", - p_prev.picture_id, p.picture_id, - ); - } - } - - p_prev = p; - } - } - - Ok(()) -} - -#[test] -fn test_vp9_partition_head_checker_is_partition_head() -> Result<()> { - let vp9 = Vp9Packet::default(); - - //"SmallPacket" - assert!( - !vp9.is_partition_head(&Bytes::new()), - "Small packet should not be the head of a new partition" - ); - - //"NormalPacket" - assert!( - vp9.is_partition_head(&Bytes::from_static(&[0x18, 0x00, 0x00])), - "VP9 RTP packet with B flag should be head of a new partition" - ); - assert!( - !vp9.is_partition_head(&Bytes::from_static(&[0x10, 0x00, 0x00])), - "VP9 RTP packet without B flag should not be head of a new partition" - ); - - Ok(()) -} diff --git a/rtp/src/error.rs b/rtp/src/error.rs deleted file mode 100644 index 72612f744..000000000 --- a/rtp/src/error.rs +++ /dev/null @@ -1,86 +0,0 @@ -use thiserror::Error; - -pub type Result = std::result::Result; - -#[derive(Error, Debug, PartialEq)] -#[non_exhaustive] -pub enum Error { - #[error("RTP header size insufficient")] - ErrHeaderSizeInsufficient, - #[error("RTP header size insufficient for extension")] - ErrHeaderSizeInsufficientForExtension, - #[error("buffer too small")] - ErrBufferTooSmall, - #[error("extension not enabled")] - ErrHeaderExtensionsNotEnabled, - #[error("extension not found")] - ErrHeaderExtensionNotFound, - - #[error("header extension id must be between 1 and 14 for RFC 5285 extensions")] - ErrRfc8285oneByteHeaderIdrange, - #[error("header extension payload must be 16bytes or less for RFC 5285 one byte extensions")] - ErrRfc8285oneByteHeaderSize, - - #[error("header extension id must be between 1 and 255 for RFC 5285 extensions")] - ErrRfc8285twoByteHeaderIdrange, - #[error("header extension payload must be 255bytes or less for RFC 5285 two byte extensions")] - ErrRfc8285twoByteHeaderSize, - - #[error("header extension id must be 0 for none RFC 5285 extensions")] - ErrRfc3550headerIdrange, - - #[error("packet is not large enough")] - ErrShortPacket, - #[error("invalid nil packet")] - ErrNilPacket, - #[error("too many PDiff")] - ErrTooManyPDiff, - #[error("too many spatial layers")] - ErrTooManySpatialLayers, - #[error("NALU Type is unhandled")] - ErrUnhandledNaluType, - - #[error("corrupted h265 packet")] - ErrH265CorruptedPacket, - #[error("invalid h265 packet type")] - ErrInvalidH265PacketType, - - 
#[error("payload is too small for OBU extension header")] - ErrPayloadTooSmallForObuExtensionHeader, - #[error("payload is too small for OBU payload size")] - ErrPayloadTooSmallForObuPayloadSize, - - #[error("extension_payload must be in 32-bit words")] - HeaderExtensionPayloadNot32BitWords, - #[error("audio level overflow")] - AudioLevelOverflow, - #[error("playout delay overflow")] - PlayoutDelayOverflow, - #[error("payload is not large enough")] - PayloadIsNotLargeEnough, - #[error("STAP-A declared size({0}) is larger than buffer({1})")] - StapASizeLargerThanBuffer(usize, usize), - #[error("nalu type {0} is currently not handled")] - NaluTypeIsNotHandled(u8), - #[error("{0}")] - Util(#[from] util::Error), - - #[error("{0}")] - Other(String), -} - -impl From for util::Error { - fn from(e: Error) -> Self { - util::Error::from_std(e) - } -} - -impl PartialEq for Error { - fn eq(&self, other: &util::Error) -> bool { - if let Some(down) = other.downcast_ref::() { - self == down - } else { - false - } - } -} diff --git a/rtp/src/extension/abs_send_time_extension/abs_send_time_extension_test.rs b/rtp/src/extension/abs_send_time_extension/abs_send_time_extension_test.rs deleted file mode 100644 index 152ee5f95..000000000 --- a/rtp/src/extension/abs_send_time_extension/abs_send_time_extension_test.rs +++ /dev/null @@ -1,121 +0,0 @@ -use std::time::Duration; - -use bytes::BytesMut; -use chrono::prelude::*; - -use super::*; -use crate::error::Result; - -const ABS_SEND_TIME_RESOLUTION: i128 = 1000; - -#[test] -fn test_ntp_conversion() -> Result<()> { - let loc = FixedOffset::west_opt(5 * 60 * 60).unwrap(); // UTC-5 - let tests = vec![ - ( - loc.with_ymd_and_hms(1985, 6, 23, 4, 0, 0).unwrap(), - 0xa0c65b1000000000_u64, - ), - ( - // TODO: fix this. MA: There's only so long I will stare at - // APIs that sacrifice convenience for correctness. 
- #[allow(deprecated)] - loc.ymd(1999, 12, 31) - .and_hms_nano_opt(23, 59, 59, 500000) - .unwrap(), - 0xbc18084f0020c49b_u64, - ), - ( - #[allow(deprecated)] - loc.ymd(2019, 3, 27) - .and_hms_nano_opt(13, 39, 30, 8675309) - .unwrap(), - 0xe04641e202388b88_u64, - ), - ]; - - for (t, n) in &tests { - let st = UNIX_EPOCH - .checked_add(Duration::from_nanos(t.timestamp_nanos_opt().unwrap() as u64)) - .unwrap_or(UNIX_EPOCH); - let ntp = unix2ntp(st); - - if cfg!(target_os = "windows") { - let actual = ntp as i128; - let expected = *n as i128; - let diff = actual - expected; - if !(-ABS_SEND_TIME_RESOLUTION..=ABS_SEND_TIME_RESOLUTION).contains(&diff) { - panic!("unix2ntp error, expected: {:?}, got: {:?}", ntp, *n,); - } - } else { - assert_eq!(ntp, *n, "unix2ntp error"); - } - } - - for (t, n) in &tests { - let output = ntp2unix(*n); - let input = UNIX_EPOCH - .checked_add(Duration::from_nanos(t.timestamp_nanos_opt().unwrap() as u64)) - .unwrap_or(UNIX_EPOCH); - let diff = input.duration_since(output).unwrap().as_nanos() as i128; - if !(-ABS_SEND_TIME_RESOLUTION..=ABS_SEND_TIME_RESOLUTION).contains(&diff) { - panic!( - "Converted time.Time from NTP time differs, expected: {input:?}, got: {output:?}", - ); - } - } - - Ok(()) -} - -#[test] -fn test_abs_send_time_extension_roundtrip() -> Result<()> { - let tests = vec![ - AbsSendTimeExtension { timestamp: 123456 }, - AbsSendTimeExtension { timestamp: 654321 }, - ]; - - for test in &tests { - let mut raw = BytesMut::with_capacity(test.marshal_size()); - raw.resize(test.marshal_size(), 0); - test.marshal_to(&mut raw)?; - let raw = raw.freeze(); - let buf = &mut raw.clone(); - let out = AbsSendTimeExtension::unmarshal(buf)?; - assert_eq!(test.timestamp, out.timestamp); - } - - Ok(()) -} - -#[test] -fn test_abs_send_time_extension_estimate() -> Result<()> { - let tests = vec![ - //FFFFFFC000000000 mask of second - (0xa0c65b1000100000, 0xa0c65b1001000000), // not carried - (0xa0c65b3f00000000, 0xa0c65b4001000000), // carried during transmission - ]; - - for (send_ntp, receive_ntp) in tests { - let in_time = ntp2unix(send_ntp); - let send = AbsSendTimeExtension { - timestamp: send_ntp >> 14, - }; - let mut raw = BytesMut::with_capacity(send.marshal_size()); - raw.resize(send.marshal_size(), 0); - send.marshal_to(&mut raw)?; - let raw = raw.freeze(); - let buf = &mut raw.clone(); - let receive = AbsSendTimeExtension::unmarshal(buf)?; - - let estimated = receive.estimate(ntp2unix(receive_ntp)); - let diff = estimated.duration_since(in_time).unwrap().as_nanos() as i128; - if !(-ABS_SEND_TIME_RESOLUTION..=ABS_SEND_TIME_RESOLUTION).contains(&diff) { - panic!( - "Converted time.Time from NTP time differs, expected: {in_time:?}, got: {estimated:?}", - ); - } - } - - Ok(()) -} diff --git a/rtp/src/extension/abs_send_time_extension/mod.rs b/rtp/src/extension/abs_send_time_extension/mod.rs deleted file mode 100644 index 33000b47a..000000000 --- a/rtp/src/extension/abs_send_time_extension/mod.rs +++ /dev/null @@ -1,110 +0,0 @@ -#[cfg(test)] -mod abs_send_time_extension_test; - -use std::time::{Duration, SystemTime, UNIX_EPOCH}; - -use bytes::{Buf, BufMut}; -use util::marshal::{Marshal, MarshalSize, Unmarshal}; - -use crate::error::Error; - -pub const ABS_SEND_TIME_EXTENSION_SIZE: usize = 3; - -/// AbsSendTimeExtension is a extension payload format in -/// http://www.webrtc.org/experiments/rtp-hdrext/abs-send-time -#[derive(PartialEq, Eq, Debug, Default, Copy, Clone)] -pub struct AbsSendTimeExtension { - pub timestamp: u64, -} - -impl Unmarshal for 
AbsSendTimeExtension { - /// Unmarshal parses the passed byte slice and stores the result in the members. - fn unmarshal(raw_packet: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - if raw_packet.remaining() < ABS_SEND_TIME_EXTENSION_SIZE { - return Err(Error::ErrBufferTooSmall.into()); - } - - let b0 = raw_packet.get_u8(); - let b1 = raw_packet.get_u8(); - let b2 = raw_packet.get_u8(); - let timestamp = (b0 as u64) << 16 | (b1 as u64) << 8 | b2 as u64; - - Ok(AbsSendTimeExtension { timestamp }) - } -} - -impl MarshalSize for AbsSendTimeExtension { - /// MarshalSize returns the size of the AbsSendTimeExtension once marshaled. - fn marshal_size(&self) -> usize { - ABS_SEND_TIME_EXTENSION_SIZE - } -} - -impl Marshal for AbsSendTimeExtension { - /// MarshalTo serializes the members to buffer. - fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - if buf.remaining_mut() < ABS_SEND_TIME_EXTENSION_SIZE { - return Err(Error::ErrBufferTooSmall.into()); - } - - buf.put_u8(((self.timestamp & 0xFF0000) >> 16) as u8); - buf.put_u8(((self.timestamp & 0xFF00) >> 8) as u8); - buf.put_u8((self.timestamp & 0xFF) as u8); - - Ok(ABS_SEND_TIME_EXTENSION_SIZE) - } -} - -impl AbsSendTimeExtension { - /// Estimate absolute send time according to the receive time. - /// Note that if the transmission delay is larger than 64 seconds, estimated time will be wrong. - pub fn estimate(&self, receive: SystemTime) -> SystemTime { - let receive_ntp = unix2ntp(receive); - let mut ntp = receive_ntp & 0xFFFFFFC000000000 | (self.timestamp & 0xFFFFFF) << 14; - if receive_ntp < ntp { - // Receive time must be always later than send time - ntp -= 0x1000000 << 14; - } - - ntp2unix(ntp) - } - - /// NewAbsSendTimeExtension makes new AbsSendTimeExtension from time.Time. - pub fn new(send_time: SystemTime) -> Self { - AbsSendTimeExtension { - timestamp: unix2ntp(send_time) >> 14, - } - } -} - -pub fn unix2ntp(st: SystemTime) -> u64 { - let u = st - .duration_since(UNIX_EPOCH) - .unwrap_or_else(|_| Duration::from_secs(0)) - .as_nanos() as u64; - let mut s = u / 1_000_000_000; - s += 0x83AA7E80; //offset in seconds between unix epoch and ntp epoch - let mut f = u % 1_000_000_000; - f <<= 32; - f /= 1_000_000_000; - s <<= 32; - - s | f -} - -pub fn ntp2unix(t: u64) -> SystemTime { - let mut s = t >> 32; - let mut f = t & 0xFFFFFFFF; - f *= 1_000_000_000; - f >>= 32; - s -= 0x83AA7E80; - let u = s * 1_000_000_000 + f; - - UNIX_EPOCH - .checked_add(Duration::new(u / 1_000_000_000, (u % 1_000_000_000) as u32)) - .unwrap_or(UNIX_EPOCH) -} diff --git a/rtp/src/extension/audio_level_extension/audio_level_extension_test.rs b/rtp/src/extension/audio_level_extension/audio_level_extension_test.rs deleted file mode 100644 index 89814d763..000000000 --- a/rtp/src/extension/audio_level_extension/audio_level_extension_test.rs +++ /dev/null @@ -1,66 +0,0 @@ -use bytes::{Bytes, BytesMut}; - -use super::*; -use crate::error::Result; - -#[test] -fn test_audio_level_extension_too_small() -> Result<()> { - let mut buf = &vec![0u8; 0][..]; - let result = AudioLevelExtension::unmarshal(&mut buf); - assert!(result.is_err()); - - Ok(()) -} - -#[test] -fn test_audio_level_extension_voice_true() -> Result<()> { - let raw = Bytes::from_static(&[0x88]); - let buf = &mut raw.clone(); - let a1 = AudioLevelExtension::unmarshal(buf)?; - let a2 = AudioLevelExtension { - level: 8, - voice: true, - }; - assert_eq!(a1, a2); - - let mut dst = BytesMut::with_capacity(a2.marshal_size()); - dst.resize(a2.marshal_size(), 0); - a2.marshal_to(&mut dst)?; - 
assert_eq!(raw, dst.freeze()); - - Ok(()) -} - -#[test] -fn test_audio_level_extension_voice_false() -> Result<()> { - let raw = Bytes::from_static(&[0x8]); - let buf = &mut raw.clone(); - let a1 = AudioLevelExtension::unmarshal(buf)?; - let a2 = AudioLevelExtension { - level: 8, - voice: false, - }; - assert_eq!(a1, a2); - - let mut dst = BytesMut::with_capacity(a2.marshal_size()); - dst.resize(a2.marshal_size(), 0); - a2.marshal_to(&mut dst)?; - assert_eq!(raw, dst.freeze()); - - Ok(()) -} - -#[test] -fn test_audio_level_extension_level_overflow() -> Result<()> { - let a = AudioLevelExtension { - level: 128, - voice: false, - }; - - let mut dst = BytesMut::with_capacity(a.marshal_size()); - dst.resize(a.marshal_size(), 0); - let result = a.marshal_to(&mut dst); - assert!(result.is_err()); - - Ok(()) -} diff --git a/rtp/src/extension/audio_level_extension/mod.rs b/rtp/src/extension/audio_level_extension/mod.rs deleted file mode 100644 index 2dce2be10..000000000 --- a/rtp/src/extension/audio_level_extension/mod.rs +++ /dev/null @@ -1,80 +0,0 @@ -#[cfg(test)] -mod audio_level_extension_test; - -use bytes::{Buf, BufMut}; -use serde::{Deserialize, Serialize}; -use util::marshal::{Marshal, MarshalSize, Unmarshal}; - -use crate::error::Error; - -// AUDIO_LEVEL_EXTENSION_SIZE One byte header size -pub const AUDIO_LEVEL_EXTENSION_SIZE: usize = 1; - -/// AudioLevelExtension is a extension payload format described in -/// https://tools.ietf.org/html/rfc6464 -/// -/// Implementation based on: -/// https://chromium.googlesource.com/external/webrtc/+/e2a017725570ead5946a4ca8235af27470ca0df9/webrtc/modules/rtp_rtcp/source/rtp_header_extensions.cc#49 -/// -/// One byte format: -/// 0 1 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | ID | len=0 |V| level | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// -/// Two byte format: -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | ID | len=1 |V| level | 0 (pad) | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -#[derive(PartialEq, Eq, Debug, Default, Copy, Clone, Serialize, Deserialize)] -pub struct AudioLevelExtension { - pub level: u8, - pub voice: bool, -} - -impl Unmarshal for AudioLevelExtension { - /// Unmarshal parses the passed byte slice and stores the result in the members - fn unmarshal(raw_packet: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - if raw_packet.remaining() < AUDIO_LEVEL_EXTENSION_SIZE { - return Err(Error::ErrBufferTooSmall.into()); - } - - let b = raw_packet.get_u8(); - - Ok(AudioLevelExtension { - level: b & 0x7F, - voice: (b & 0x80) != 0, - }) - } -} - -impl MarshalSize for AudioLevelExtension { - /// MarshalSize returns the size of the AudioLevelExtension once marshaled. 
- fn marshal_size(&self) -> usize { - AUDIO_LEVEL_EXTENSION_SIZE - } -} - -impl Marshal for AudioLevelExtension { - /// MarshalTo serializes the members to buffer - fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - if buf.remaining_mut() < AUDIO_LEVEL_EXTENSION_SIZE { - return Err(Error::ErrBufferTooSmall.into()); - } - if self.level > 127 { - return Err(Error::AudioLevelOverflow.into()); - } - let voice = if self.voice { 0x80u8 } else { 0u8 }; - - buf.put_u8(voice | self.level); - - Ok(AUDIO_LEVEL_EXTENSION_SIZE) - } -} diff --git a/rtp/src/extension/mod.rs b/rtp/src/extension/mod.rs deleted file mode 100644 index 6e2817409..000000000 --- a/rtp/src/extension/mod.rs +++ /dev/null @@ -1,97 +0,0 @@ -use std::borrow::Cow; -use std::fmt; - -use util::{Marshal, MarshalSize}; - -pub mod abs_send_time_extension; -pub mod audio_level_extension; -pub mod playout_delay_extension; -pub mod transport_cc_extension; -pub mod video_orientation_extension; - -/// A generic RTP header extension. -pub enum HeaderExtension { - AbsSendTime(abs_send_time_extension::AbsSendTimeExtension), - AudioLevel(audio_level_extension::AudioLevelExtension), - PlayoutDelay(playout_delay_extension::PlayoutDelayExtension), - TransportCc(transport_cc_extension::TransportCcExtension), - VideoOrientation(video_orientation_extension::VideoOrientationExtension), - - /// A custom extension - Custom { - uri: Cow<'static, str>, - extension: Box, - }, -} - -impl HeaderExtension { - pub fn uri(&self) -> Cow<'static, str> { - use HeaderExtension::*; - - match self { - AbsSendTime(_) => "http://www.webrtc.org/experiments/rtp-hdrext/abs-send-time".into(), - AudioLevel(_) => "urn:ietf:params:rtp-hdrext:ssrc-audio-level".into(), - PlayoutDelay(_) => "http://www.webrtc.org/experiments/rtp-hdrext/playout-delay".into(), - TransportCc(_) => { - "http://www.ietf.org/id/draft-holmer-rmcat-transport-wide-cc-extensions-01".into() - } - VideoOrientation(_) => "urn:3gpp:video-orientation".into(), - Custom { uri, .. } => uri.clone(), - } - } - - pub fn is_same(&self, other: &Self) -> bool { - use HeaderExtension::*; - match (self, other) { - (AbsSendTime(_), AbsSendTime(_)) => true, - (AudioLevel(_), AudioLevel(_)) => true, - (TransportCc(_), TransportCc(_)) => true, - (VideoOrientation(_), VideoOrientation(_)) => true, - (Custom { uri, .. }, Custom { uri: other_uri, .. }) => uri == other_uri, - _ => false, - } - } -} - -impl MarshalSize for HeaderExtension { - fn marshal_size(&self) -> usize { - use HeaderExtension::*; - match self { - AbsSendTime(ext) => ext.marshal_size(), - AudioLevel(ext) => ext.marshal_size(), - PlayoutDelay(ext) => ext.marshal_size(), - TransportCc(ext) => ext.marshal_size(), - VideoOrientation(ext) => ext.marshal_size(), - Custom { extension: ext, .. } => ext.marshal_size(), - } - } -} - -impl Marshal for HeaderExtension { - fn marshal_to(&self, buf: &mut [u8]) -> util::Result { - use HeaderExtension::*; - match self { - AbsSendTime(ext) => ext.marshal_to(buf), - AudioLevel(ext) => ext.marshal_to(buf), - PlayoutDelay(ext) => ext.marshal_to(buf), - TransportCc(ext) => ext.marshal_to(buf), - VideoOrientation(ext) => ext.marshal_to(buf), - Custom { extension: ext, .. 
} => ext.marshal_to(buf), - } - } -} - -impl fmt::Debug for HeaderExtension { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - use HeaderExtension::*; - - match self { - AbsSendTime(ext) => f.debug_tuple("AbsSendTime").field(ext).finish(), - AudioLevel(ext) => f.debug_tuple("AudioLevel").field(ext).finish(), - PlayoutDelay(ext) => f.debug_tuple("PlayoutDelay").field(ext).finish(), - TransportCc(ext) => f.debug_tuple("TransportCc").field(ext).finish(), - VideoOrientation(ext) => f.debug_tuple("VideoOrientation").field(ext).finish(), - Custom { uri, extension: _ } => f.debug_struct("Custom").field("uri", uri).finish(), - } - } -} diff --git a/rtp/src/extension/playout_delay_extension/mod.rs b/rtp/src/extension/playout_delay_extension/mod.rs deleted file mode 100644 index 55e264cc8..000000000 --- a/rtp/src/extension/playout_delay_extension/mod.rs +++ /dev/null @@ -1,82 +0,0 @@ -#[cfg(test)] -mod playout_delay_extension_test; - -use bytes::BufMut; -use util::marshal::{Marshal, MarshalSize, Unmarshal}; - -use crate::error::Error; - -pub const PLAYOUT_DELAY_EXTENSION_SIZE: usize = 3; -pub const PLAYOUT_DELAY_MAX_VALUE: u16 = (1 << 12) - 1; - -/// PlayoutDelayExtension is an extension payload format described in -/// http://www.webrtc.org/experiments/rtp-hdrext/playout-delay -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | ID | len=2 | MIN delay | MAX delay | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -#[derive(PartialEq, Eq, Debug, Default, Copy, Clone)] -pub struct PlayoutDelayExtension { - pub min_delay: u16, - pub max_delay: u16, -} - -impl Unmarshal for PlayoutDelayExtension { - /// Unmarshal parses the passed byte slice and stores the result in the members. - fn unmarshal(buf: &mut B) -> util::Result - where - Self: Sized, - B: bytes::Buf, - { - if buf.remaining() < PLAYOUT_DELAY_EXTENSION_SIZE { - return Err(Error::ErrBufferTooSmall.into()); - } - - let b0 = buf.get_u8(); - let b1 = buf.get_u8(); - let b2 = buf.get_u8(); - - let min_delay = u16::from_be_bytes([b0, b1]) >> 4; - let max_delay = u16::from_be_bytes([b1, b2]) & 0x0FFF; - - Ok(PlayoutDelayExtension { - min_delay, - max_delay, - }) - } -} - -impl MarshalSize for PlayoutDelayExtension { - /// MarshalSize returns the size of the PlayoutDelayExtension once marshaled. 
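The HeaderExtension enum shown above ties each known extension to the URI that gets negotiated for it and forwards MarshalSize/Marshal to the wrapped value. A short usage sketch, assuming the crate renames and module paths exactly as they appear in this diff (`rtp`, plus the workspace's `util` crate for the marshal traits):

```rust
// Sketch: wrap a concrete extension in the generic enum and inspect it.
// Paths and the `util` alias are taken from the imports in this diff.
use rtp::extension::audio_level_extension::AudioLevelExtension;
use rtp::extension::HeaderExtension;
use util::MarshalSize;

fn main() {
    let ext = HeaderExtension::AudioLevel(AudioLevelExtension { level: 10, voice: true });
    assert_eq!(ext.uri(), "urn:ietf:params:rtp-hdrext:ssrc-audio-level");
    assert_eq!(ext.marshal_size(), 1); // delegates to the wrapped AudioLevelExtension
}
```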
- fn marshal_size(&self) -> usize { - PLAYOUT_DELAY_EXTENSION_SIZE - } -} - -impl Marshal for PlayoutDelayExtension { - /// MarshalTo serializes the members to buffer - fn marshal_to(&self, mut buf: &mut [u8]) -> util::Result { - if buf.remaining_mut() < PLAYOUT_DELAY_EXTENSION_SIZE { - return Err(Error::ErrBufferTooSmall.into()); - } - if self.min_delay > PLAYOUT_DELAY_MAX_VALUE || self.max_delay > PLAYOUT_DELAY_MAX_VALUE { - return Err(Error::PlayoutDelayOverflow.into()); - } - - buf.put_u8((self.min_delay >> 4) as u8); - buf.put_u8(((self.min_delay << 4) as u8) | (self.max_delay >> 8) as u8); - buf.put_u8(self.max_delay as u8); - - Ok(PLAYOUT_DELAY_EXTENSION_SIZE) - } -} - -impl PlayoutDelayExtension { - pub fn new(min_delay: u16, max_delay: u16) -> Self { - PlayoutDelayExtension { - min_delay, - max_delay, - } - } -} diff --git a/rtp/src/extension/playout_delay_extension/playout_delay_extension_test.rs b/rtp/src/extension/playout_delay_extension/playout_delay_extension_test.rs deleted file mode 100644 index 1cb7a5faf..000000000 --- a/rtp/src/extension/playout_delay_extension/playout_delay_extension_test.rs +++ /dev/null @@ -1,38 +0,0 @@ -use bytes::BytesMut; - -use crate::error::Result; - -use super::*; - -#[test] -fn test_playout_delay_extension_roundtrip() -> Result<()> { - let test = PlayoutDelayExtension { - max_delay: 2345, - min_delay: 1234, - }; - - let mut raw = BytesMut::with_capacity(test.marshal_size()); - raw.resize(test.marshal_size(), 0); - test.marshal_to(&mut raw)?; - let raw = raw.freeze(); - let buf = &mut raw.clone(); - let out = PlayoutDelayExtension::unmarshal(buf)?; - assert_eq!(test, out); - - Ok(()) -} - -#[test] -fn test_playout_delay_value_overflow() -> Result<()> { - let test = PlayoutDelayExtension { - max_delay: u16::MAX, - min_delay: u16::MAX, - }; - - let mut dst = BytesMut::with_capacity(test.marshal_size()); - dst.resize(test.marshal_size(), 0); - let result = test.marshal_to(&mut dst); - assert!(result.is_err()); - - Ok(()) -} diff --git a/rtp/src/extension/transport_cc_extension/mod.rs b/rtp/src/extension/transport_cc_extension/mod.rs deleted file mode 100644 index cb72ed1df..000000000 --- a/rtp/src/extension/transport_cc_extension/mod.rs +++ /dev/null @@ -1,61 +0,0 @@ -#[cfg(test)] -mod transport_cc_extension_test; - -use bytes::{Buf, BufMut}; -use serde::{Deserialize, Serialize}; -use util::marshal::{Marshal, MarshalSize, Unmarshal}; - -use crate::error::Error; - -// transport-wide sequence -pub const TRANSPORT_CC_EXTENSION_SIZE: usize = 2; - -/// TransportCCExtension is a extension payload format in -/// https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions-01 -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | 0xBE | 0xDE | length=1 | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | ID | L=1 |transport-wide sequence number | zero padding | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -#[derive(PartialEq, Eq, Debug, Default, Copy, Clone, Serialize, Deserialize)] -pub struct TransportCcExtension { - pub transport_sequence: u16, -} - -impl Unmarshal for TransportCcExtension { - /// Unmarshal parses the passed byte slice and stores the result in the members - fn unmarshal(raw_packet: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - if raw_packet.remaining() < TRANSPORT_CC_EXTENSION_SIZE { - return Err(Error::ErrBufferTooSmall.into()); - } - let b0 = 
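PlayoutDelayExtension squeezes two 12-bit delay values into three bytes: the minimum delay occupies the high 12 bits and the maximum delay the low 12. A standalone sketch mirroring the marshal/unmarshal arithmetic above:

```rust
// Sketch: pack/unpack two 12-bit delays into 3 bytes, std only.
fn pack_playout_delay(min_delay: u16, max_delay: u16) -> Option<[u8; 3]> {
    const MAX: u16 = (1 << 12) - 1;
    if min_delay > MAX || max_delay > MAX {
        return None; // mirrors PlayoutDelayOverflow
    }
    Some([
        (min_delay >> 4) as u8,
        ((min_delay << 4) as u8) | (max_delay >> 8) as u8,
        max_delay as u8,
    ])
}

fn unpack_playout_delay(b: [u8; 3]) -> (u16, u16) {
    let min = u16::from_be_bytes([b[0], b[1]]) >> 4;
    let max = u16::from_be_bytes([b[1], b[2]]) & 0x0FFF;
    (min, max)
}

fn main() {
    let raw = pack_playout_delay(1234, 2345).unwrap();
    assert_eq!(unpack_playout_delay(raw), (1234, 2345));
}
```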
raw_packet.get_u8(); - let b1 = raw_packet.get_u8(); - - let transport_sequence = ((b0 as u16) << 8) | b1 as u16; - Ok(TransportCcExtension { transport_sequence }) - } -} - -impl MarshalSize for TransportCcExtension { - /// MarshalSize returns the size of the TransportCcExtension once marshaled. - fn marshal_size(&self) -> usize { - TRANSPORT_CC_EXTENSION_SIZE - } -} - -impl Marshal for TransportCcExtension { - /// Marshal serializes the members to buffer - fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - if buf.remaining_mut() < TRANSPORT_CC_EXTENSION_SIZE { - return Err(Error::ErrBufferTooSmall.into()); - } - buf.put_u16(self.transport_sequence); - Ok(TRANSPORT_CC_EXTENSION_SIZE) - } -} diff --git a/rtp/src/extension/transport_cc_extension/transport_cc_extension_test.rs b/rtp/src/extension/transport_cc_extension/transport_cc_extension_test.rs deleted file mode 100644 index f647618b3..000000000 --- a/rtp/src/extension/transport_cc_extension/transport_cc_extension_test.rs +++ /dev/null @@ -1,44 +0,0 @@ -use bytes::{Bytes, BytesMut}; - -use super::*; -use crate::error::Result; - -#[test] -fn test_transport_cc_extension_too_small() -> Result<()> { - let mut buf = &vec![0u8; 0][..]; - let result = TransportCcExtension::unmarshal(&mut buf); - assert!(result.is_err()); - - Ok(()) -} - -#[test] -fn test_transport_cc_extension() -> Result<()> { - let raw = Bytes::from_static(&[0x00, 0x02]); - let buf = &mut raw.clone(); - let t1 = TransportCcExtension::unmarshal(buf)?; - let t2 = TransportCcExtension { - transport_sequence: 2, - }; - assert_eq!(t1, t2); - - let mut dst = BytesMut::with_capacity(t2.marshal_size()); - dst.resize(t2.marshal_size(), 0); - t2.marshal_to(&mut dst)?; - assert_eq!(raw, dst.freeze()); - - Ok(()) -} - -#[test] -fn test_transport_cc_extension_extra_bytes() -> Result<()> { - let mut raw = Bytes::from_static(&[0x00, 0x02, 0x00, 0xff, 0xff]); - let buf = &mut raw; - let t1 = TransportCcExtension::unmarshal(buf)?; - let t2 = TransportCcExtension { - transport_sequence: 2, - }; - assert_eq!(t1, t2); - - Ok(()) -} diff --git a/rtp/src/extension/video_orientation_extension/mod.rs b/rtp/src/extension/video_orientation_extension/mod.rs deleted file mode 100644 index d03ced95a..000000000 --- a/rtp/src/extension/video_orientation_extension/mod.rs +++ /dev/null @@ -1,132 +0,0 @@ -#[cfg(test)] -mod video_orientation_extension_test; - -use std::convert::{TryFrom, TryInto}; - -use bytes::BufMut; -use serde::{Deserialize, Serialize}; -use util::marshal::Unmarshal; -use util::{Marshal, MarshalSize}; - -use crate::Error; - -// One byte header size -pub const VIDEO_ORIENTATION_EXTENSION_SIZE: usize = 1; - -/// Coordination of Video Orientation in RTP streams. -/// -/// Coordination of Video Orientation consists in signaling of the current -/// orientation of the image captured on the sender side to the receiver for -/// appropriate rendering and displaying. -/// -/// C = Camera: indicates the direction of the camera used for this video -/// stream. It can be used by the MTSI client in receiver to e.g. display -/// the received video differently depending on the source camera. -/// -/// 0: Front-facing camera, facing the user. If camera direction is -/// unknown by the sending MTSI client in the terminal then this is the -/// default value used. -/// 1: Back-facing camera, facing away from the user. -/// -/// F = Flip: indicates a horizontal (left-right flip) mirror operation on -/// the video as sent on the link. 
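The transport-wide CC element is nothing more than a big-endian u16 sequence number, which is why marshal and unmarshal above reduce to two byte reads and one put_u16. In plain Rust:

```rust
// Sketch: the transport-wide CC element is a single big-endian u16.
fn main() {
    let transport_sequence: u16 = 2;
    let raw = transport_sequence.to_be_bytes();
    assert_eq!(raw, [0x00, 0x02]);
    assert_eq!(u16::from_be_bytes(raw), transport_sequence);
}
```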
-/// -/// 0 1 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | ID | len=0 |0 0 0 0 C F R R| -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -#[derive(PartialEq, Eq, Debug, Default, Copy, Clone, Serialize, Deserialize)] -pub struct VideoOrientationExtension { - pub direction: CameraDirection, - pub flip: bool, - pub rotation: VideoRotation, -} - -#[derive(Default, PartialEq, Eq, Debug, Copy, Clone, Serialize, Deserialize)] -pub enum CameraDirection { - #[default] - Front = 0, - Back = 1, -} - -#[derive(Default, PartialEq, Eq, Debug, Copy, Clone, Serialize, Deserialize)] -pub enum VideoRotation { - #[default] - Degree0 = 0, - Degree90 = 1, - Degree180 = 2, - Degree270 = 3, -} - -impl MarshalSize for VideoOrientationExtension { - fn marshal_size(&self) -> usize { - VIDEO_ORIENTATION_EXTENSION_SIZE - } -} - -impl Unmarshal for VideoOrientationExtension { - fn unmarshal(buf: &mut B) -> util::Result - where - Self: Sized, - B: bytes::Buf, - { - if buf.remaining() < VIDEO_ORIENTATION_EXTENSION_SIZE { - return Err(Error::ErrBufferTooSmall.into()); - } - - let b = buf.get_u8(); - - let c = (b & 0b1000) >> 3; - let f = b & 0b0100; - let r = b & 0b0011; - - Ok(VideoOrientationExtension { - direction: c.try_into()?, - flip: f > 0, - rotation: r.try_into()?, - }) - } -} - -impl Marshal for VideoOrientationExtension { - fn marshal_to(&self, mut buf: &mut [u8]) -> util::Result { - let c = (self.direction as u8) << 3; - let f = if self.flip { 0b0100 } else { 0 }; - let r = self.rotation as u8; - - buf.put_u8(c | f | r); - - Ok(VIDEO_ORIENTATION_EXTENSION_SIZE) - } -} - -impl TryFrom for CameraDirection { - type Error = util::Error; - - fn try_from(value: u8) -> Result { - match value { - 0 => Ok(CameraDirection::Front), - 1 => Ok(CameraDirection::Back), - _ => Err(util::Error::Other(format!( - "Unhandled camera direction: {value}" - ))), - } - } -} - -impl TryFrom for VideoRotation { - type Error = util::Error; - - fn try_from(value: u8) -> Result { - match value { - 0 => Ok(VideoRotation::Degree0), - 1 => Ok(VideoRotation::Degree90), - 2 => Ok(VideoRotation::Degree180), - 3 => Ok(VideoRotation::Degree270), - _ => Err(util::Error::Other(format!( - "Unhandled video rotation: {value}" - ))), - } - } -} diff --git a/rtp/src/extension/video_orientation_extension/video_orientation_extension_test.rs b/rtp/src/extension/video_orientation_extension/video_orientation_extension_test.rs deleted file mode 100644 index c8f680958..000000000 --- a/rtp/src/extension/video_orientation_extension/video_orientation_extension_test.rs +++ /dev/null @@ -1,113 +0,0 @@ -use bytes::{Bytes, BytesMut}; - -use super::*; -use crate::error::Result; - -#[test] -fn test_video_orientation_extension_too_small() -> Result<()> { - let mut buf = &vec![0u8; 0][..]; - let result = VideoOrientationExtension::unmarshal(&mut buf); - assert!(result.is_err()); - - Ok(()) -} - -#[test] -fn test_video_orientation_extension_back_facing_camera() -> Result<()> { - let raw = Bytes::from_static(&[0b1000]); - let buf = &mut raw.clone(); - let a1 = VideoOrientationExtension::unmarshal(buf)?; - let a2 = VideoOrientationExtension { - direction: CameraDirection::Back, - flip: false, - rotation: VideoRotation::Degree0, - }; - assert_eq!(a1, a2); - - let mut dst = BytesMut::with_capacity(a2.marshal_size()); - dst.resize(a2.marshal_size(), 0); - a2.marshal_to(&mut dst)?; - assert_eq!(raw, dst.freeze()); - - Ok(()) -} - -#[test] -fn test_video_orientation_extension_flip_true() -> Result<()> { - let raw = Bytes::from_static(&[0b0100]); - let 
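The CVO byte places the camera bit at position 3, the flip bit at position 2, and the rotation in the two low bits, so encoding is a handful of shifts and ORs. A standalone sketch that reproduces the raw bytes used in the tests around it:

```rust
// Sketch: pack the CVO byte (0 0 0 0 C F R R) from its three fields.
fn pack_cvo(back_camera: bool, flip: bool, rotation_degrees: u16) -> Option<u8> {
    let r = match rotation_degrees {
        0 => 0,
        90 => 1,
        180 => 2,
        270 => 3,
        _ => return None, // only the four coarse rotations are representable
    };
    Some(((back_camera as u8) << 3) | ((flip as u8) << 2) | r)
}

fn main() {
    assert_eq!(pack_cvo(true, false, 0), Some(0b1000));
    assert_eq!(pack_cvo(false, true, 0), Some(0b0100));
    assert_eq!(pack_cvo(false, false, 270), Some(0b0011));
}
```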
buf = &mut raw.clone(); - let a1 = VideoOrientationExtension::unmarshal(buf)?; - let a2 = VideoOrientationExtension { - direction: CameraDirection::Front, - flip: true, - rotation: VideoRotation::Degree0, - }; - assert_eq!(a1, a2); - - let mut dst = BytesMut::with_capacity(a2.marshal_size()); - dst.resize(a2.marshal_size(), 0); - a2.marshal_to(&mut dst)?; - assert_eq!(raw, dst.freeze()); - - Ok(()) -} - -#[test] -fn test_video_orientation_extension_degree_90() -> Result<()> { - let raw = Bytes::from_static(&[0b0001]); - let buf = &mut raw.clone(); - let a1 = VideoOrientationExtension::unmarshal(buf)?; - let a2 = VideoOrientationExtension { - direction: CameraDirection::Front, - flip: false, - rotation: VideoRotation::Degree90, - }; - assert_eq!(a1, a2); - - let mut dst = BytesMut::with_capacity(a2.marshal_size()); - dst.resize(a2.marshal_size(), 0); - a2.marshal_to(&mut dst)?; - assert_eq!(raw, dst.freeze()); - - Ok(()) -} - -#[test] -fn test_video_orientation_extension_degree_180() -> Result<()> { - let raw = Bytes::from_static(&[0b0010]); - let buf = &mut raw.clone(); - let a1 = VideoOrientationExtension::unmarshal(buf)?; - let a2 = VideoOrientationExtension { - direction: CameraDirection::Front, - flip: false, - rotation: VideoRotation::Degree180, - }; - assert_eq!(a1, a2); - - let mut dst = BytesMut::with_capacity(a2.marshal_size()); - dst.resize(a2.marshal_size(), 0); - a2.marshal_to(&mut dst)?; - assert_eq!(raw, dst.freeze()); - - Ok(()) -} - -#[test] -fn test_video_orientation_extension_degree_270() -> Result<()> { - let raw = Bytes::from_static(&[0b0011]); - let buf = &mut raw.clone(); - let a1 = VideoOrientationExtension::unmarshal(buf)?; - let a2 = VideoOrientationExtension { - direction: CameraDirection::Front, - flip: false, - rotation: VideoRotation::Degree270, - }; - assert_eq!(a1, a2); - - let mut dst = BytesMut::with_capacity(a2.marshal_size()); - dst.resize(a2.marshal_size(), 0); - a2.marshal_to(&mut dst)?; - assert_eq!(raw, dst.freeze()); - - Ok(()) -} diff --git a/rtp/src/header.rs b/rtp/src/header.rs deleted file mode 100644 index 14097c794..000000000 --- a/rtp/src/header.rs +++ /dev/null @@ -1,477 +0,0 @@ -use bytes::{Buf, BufMut, Bytes}; -use util::marshal::{Marshal, MarshalSize, Unmarshal}; - -use crate::error::Error; - -pub const HEADER_LENGTH: usize = 4; -pub const VERSION_SHIFT: u8 = 6; -pub const VERSION_MASK: u8 = 0x3; -pub const PADDING_SHIFT: u8 = 5; -pub const PADDING_MASK: u8 = 0x1; -pub const EXTENSION_SHIFT: u8 = 4; -pub const EXTENSION_MASK: u8 = 0x1; -pub const EXTENSION_PROFILE_ONE_BYTE: u16 = 0xBEDE; -pub const EXTENSION_PROFILE_TWO_BYTE: u16 = 0x1000; -pub const EXTENSION_ID_RESERVED: u8 = 0xF; -pub const CC_MASK: u8 = 0xF; -pub const MARKER_SHIFT: u8 = 7; -pub const MARKER_MASK: u8 = 0x1; -pub const PT_MASK: u8 = 0x7F; -pub const SEQ_NUM_OFFSET: usize = 2; -pub const SEQ_NUM_LENGTH: usize = 2; -pub const TIMESTAMP_OFFSET: usize = 4; -pub const TIMESTAMP_LENGTH: usize = 4; -pub const SSRC_OFFSET: usize = 8; -pub const SSRC_LENGTH: usize = 4; -pub const CSRC_OFFSET: usize = 12; -pub const CSRC_LENGTH: usize = 4; - -#[derive(Debug, Eq, PartialEq, Default, Clone)] -pub struct Extension { - pub id: u8, - pub payload: Bytes, -} - -/// Header represents an RTP packet header -/// NOTE: PayloadOffset is populated by Marshal/Unmarshal and should not be modified -#[derive(Debug, Eq, PartialEq, Default, Clone)] -pub struct Header { - pub version: u8, - pub padding: bool, - pub extension: bool, - pub marker: bool, - pub payload_type: u8, - pub sequence_number: 
u16, - pub timestamp: u32, - pub ssrc: u32, - pub csrc: Vec, - pub extension_profile: u16, - pub extensions: Vec, - pub extensions_padding: usize, -} - -impl Unmarshal for Header { - /// Unmarshal parses the passed byte slice and stores the result in the Header this method is called upon - fn unmarshal(raw_packet: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - let raw_packet_len = raw_packet.remaining(); - if raw_packet_len < HEADER_LENGTH { - return Err(Error::ErrHeaderSizeInsufficient.into()); - } - /* - * 0 1 2 3 - * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * |V=2|P|X| CC |M| PT | sequence number | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | timestamp | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | synchronization source (SSRC) identifier | - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - * | contributing source (CSRC) identifiers | - * | .... | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - */ - let b0 = raw_packet.get_u8(); - let version = b0 >> VERSION_SHIFT & VERSION_MASK; - let padding = (b0 >> PADDING_SHIFT & PADDING_MASK) > 0; - let extension = (b0 >> EXTENSION_SHIFT & EXTENSION_MASK) > 0; - let cc = (b0 & CC_MASK) as usize; - - let mut curr_offset = CSRC_OFFSET + (cc * CSRC_LENGTH); - if raw_packet_len < curr_offset { - return Err(Error::ErrHeaderSizeInsufficient.into()); - } - - let b1 = raw_packet.get_u8(); - let marker = (b1 >> MARKER_SHIFT & MARKER_MASK) > 0; - let payload_type = b1 & PT_MASK; - - let sequence_number = raw_packet.get_u16(); - let timestamp = raw_packet.get_u32(); - let ssrc = raw_packet.get_u32(); - - let mut csrc = Vec::with_capacity(cc); - for _ in 0..cc { - csrc.push(raw_packet.get_u32()); - } - let mut extensions_padding: usize = 0; - let (extension_profile, extensions) = if extension { - let expected = curr_offset + 4; - if raw_packet_len < expected { - return Err(Error::ErrHeaderSizeInsufficientForExtension.into()); - } - let extension_profile = raw_packet.get_u16(); - curr_offset += 2; - let extension_length = raw_packet.get_u16() as usize * 4; - curr_offset += 2; - - let expected = curr_offset + extension_length; - if raw_packet_len < expected { - return Err(Error::ErrHeaderSizeInsufficientForExtension.into()); - } - - let mut extensions = vec![]; - match extension_profile { - // RFC 8285 RTP One Byte Header Extension - EXTENSION_PROFILE_ONE_BYTE => { - let end = curr_offset + extension_length; - while curr_offset < end { - let b = raw_packet.get_u8(); - if b == 0x00 { - // padding - curr_offset += 1; - extensions_padding += 1; - continue; - } - - let extid = b >> 4; - let len = ((b & (0xFF ^ 0xF0)) + 1) as usize; - curr_offset += 1; - - if extid == EXTENSION_ID_RESERVED { - break; - } - - extensions.push(Extension { - id: extid, - payload: raw_packet.copy_to_bytes(len), - }); - curr_offset += len; - } - } - // RFC 8285 RTP Two Byte Header Extension - EXTENSION_PROFILE_TWO_BYTE => { - let end = curr_offset + extension_length; - while curr_offset < end { - let b = raw_packet.get_u8(); - if b == 0x00 { - // padding - curr_offset += 1; - extensions_padding += 1; - continue; - } - - let extid = b; - curr_offset += 1; - - let len = raw_packet.get_u8() as usize; - curr_offset += 1; - - extensions.push(Extension { - id: extid, - payload: raw_packet.copy_to_bytes(len), - }); - curr_offset += len; - } - } - // RFC3550 Extension - _ => { - if 
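The fixed-header parsing above is all shifts and masks on the first two octets, and the RFC 8285 one-byte extension elements use the same trick: the element id lives in the high nibble and the payload length (minus one) in the low nibble. A small standalone check of those bit layouts:

```rust
// Sketch: decode the fixed-header flag byte and a one-byte extension element
// header with the same shifts and masks as the Unmarshal impl above.
fn main() {
    // 0x90 = version 2, no padding, extension bit set, CC = 0.
    let b0: u8 = 0x90;
    assert_eq!(b0 >> 6 & 0x3, 2); // version
    assert_eq!(b0 >> 5 & 0x1, 0); // padding
    assert_eq!(b0 >> 4 & 0x1, 1); // extension
    assert_eq!(b0 & 0xF, 0);      // CSRC count

    // One-byte element header 0x50: id in the high nibble, len = low nibble + 1.
    let eb: u8 = 0x50;
    assert_eq!(eb >> 4, 5);                  // extension id
    assert_eq!((eb & 0x0F) as usize + 1, 1); // payload length in bytes
}
```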
raw_packet_len < curr_offset + extension_length { - return Err(Error::ErrHeaderSizeInsufficientForExtension.into()); - } - extensions.push(Extension { - id: 0, - payload: raw_packet.copy_to_bytes(extension_length), - }); - } - }; - - (extension_profile, extensions) - } else { - (0, vec![]) - }; - - Ok(Header { - version, - padding, - extension, - marker, - payload_type, - sequence_number, - timestamp, - ssrc, - csrc, - extension_profile, - extensions, - extensions_padding, - }) - } -} - -impl MarshalSize for Header { - /// MarshalSize returns the size of the packet once marshaled. - fn marshal_size(&self) -> usize { - let mut head_size = 12 + (self.csrc.len() * CSRC_LENGTH); - if self.extension { - let extension_payload_len = self.get_extension_payload_len() + self.extensions_padding; - let extension_payload_size = (extension_payload_len + 3) / 4; - head_size += 4 + extension_payload_size * 4; - } - head_size - } -} - -impl Marshal for Header { - /// Marshal serializes the header and writes to the buffer. - fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - /* - * 0 1 2 3 - * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * |V=2|P|X| CC |M| PT | sequence number | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | timestamp | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | synchronization source (SSRC) identifier | - * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ - * | contributing source (CSRC) identifiers | - * | .... | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - */ - let remaining_before = buf.remaining_mut(); - if remaining_before < self.marshal_size() { - return Err(Error::ErrBufferTooSmall.into()); - } - - // The first byte contains the version, padding bit, extension bit, and csrc size - let mut b0 = (self.version << VERSION_SHIFT) | self.csrc.len() as u8; - if self.padding { - b0 |= 1 << PADDING_SHIFT; - } - - if self.extension { - b0 |= 1 << EXTENSION_SHIFT; - } - buf.put_u8(b0); - - // The second byte contains the marker bit and payload type. - let mut b1 = self.payload_type; - if self.marker { - b1 |= 1 << MARKER_SHIFT; - } - buf.put_u8(b1); - - buf.put_u16(self.sequence_number); - buf.put_u32(self.timestamp); - buf.put_u32(self.ssrc); - - for csrc in &self.csrc { - buf.put_u32(*csrc); - } - - if self.extension { - buf.put_u16(self.extension_profile); - - // calculate extensions size and round to 4 bytes boundaries - let extension_payload_len = self.get_extension_payload_len(); - if self.extension_profile != EXTENSION_PROFILE_ONE_BYTE - && self.extension_profile != EXTENSION_PROFILE_TWO_BYTE - && extension_payload_len % 4 != 0 - { - //the payload must be in 32-bit words. 
- return Err(Error::HeaderExtensionPayloadNot32BitWords.into()); - } - let extension_payload_size = (extension_payload_len as u16 + 3) / 4; - buf.put_u16(extension_payload_size); - - match self.extension_profile { - // RFC 8285 RTP One Byte Header Extension - EXTENSION_PROFILE_ONE_BYTE => { - for extension in &self.extensions { - buf.put_u8((extension.id << 4) | (extension.payload.len() as u8 - 1)); - buf.put(&*extension.payload); - } - } - // RFC 8285 RTP Two Byte Header Extension - EXTENSION_PROFILE_TWO_BYTE => { - for extension in &self.extensions { - buf.put_u8(extension.id); - buf.put_u8(extension.payload.len() as u8); - buf.put(&*extension.payload); - } - } - // RFC3550 Extension - _ => { - if self.extensions.len() != 1 { - return Err(Error::ErrRfc3550headerIdrange.into()); - } - - if let Some(extension) = self.extensions.first() { - let ext_len = extension.payload.len(); - if ext_len % 4 != 0 { - return Err(Error::HeaderExtensionPayloadNot32BitWords.into()); - } - buf.put(&*extension.payload); - } - } - }; - - // add padding to reach 4 bytes boundaries - for _ in extension_payload_len..extension_payload_size as usize * 4 { - buf.put_u8(0); - } - } - - let remaining_after = buf.remaining_mut(); - Ok(remaining_before - remaining_after) - } -} - -impl Header { - pub fn get_extension_payload_len(&self) -> usize { - let payload_len: usize = self - .extensions - .iter() - .map(|extension| extension.payload.len()) - .sum(); - - let profile_len = self.extensions.len() - * match self.extension_profile { - EXTENSION_PROFILE_ONE_BYTE => 1, - EXTENSION_PROFILE_TWO_BYTE => 2, - _ => 0, - }; - - payload_len + profile_len - } - - /// SetExtension sets an RTP header extension - pub fn set_extension(&mut self, id: u8, payload: Bytes) -> Result<(), Error> { - let payload_len = payload.len() as isize; - if self.extension { - let extension_profile_len = match self.extension_profile { - EXTENSION_PROFILE_ONE_BYTE => { - if !(1..=14).contains(&id) { - return Err(Error::ErrRfc8285oneByteHeaderIdrange); - } - if payload_len > 16 { - return Err(Error::ErrRfc8285oneByteHeaderSize); - } - 1 - } - EXTENSION_PROFILE_TWO_BYTE => { - if id < 1 { - return Err(Error::ErrRfc8285twoByteHeaderIdrange); - } - if payload_len > 255 { - return Err(Error::ErrRfc8285twoByteHeaderSize); - } - 2 - } - _ => { - if id != 0 { - return Err(Error::ErrRfc3550headerIdrange); - } - 0 - } - }; - - let delta; - // Update existing if it exists else add new extension - if let Some(extension) = self - .extensions - .iter_mut() - .find(|extension| extension.id == id) - { - delta = payload_len - extension.payload.len() as isize; - extension.payload = payload; - } else { - delta = payload_len + extension_profile_len; - self.extensions.push(Extension { id, payload }); - } - - match delta.cmp(&0) { - std::cmp::Ordering::Less => { - self.extensions_padding = - ((self.extensions_padding as isize - delta) % 4) as usize; - } - std::cmp::Ordering::Greater => { - let extension_padding = (delta % 4) as usize; - if self.extensions_padding < extension_padding { - self.extensions_padding = (self.extensions_padding + 4) - extension_padding; - } else { - self.extensions_padding -= extension_padding - } - } - _ => {} - } - } else { - // No existing header extensions - self.extension = true; - let mut extension_profile_len = 0; - self.extension_profile = match payload_len { - 0..=16 => { - extension_profile_len = 1; - EXTENSION_PROFILE_ONE_BYTE - } - 17..=255 => { - extension_profile_len = 2; - EXTENSION_PROFILE_TWO_BYTE - } - _ => self.extension_profile, - 
}; - - let extension_padding = (payload.len() + extension_profile_len) % 4; - if self.extensions_padding < extension_padding { - self.extensions_padding = self.extensions_padding + 4 - extension_padding; - } else { - self.extensions_padding -= extension_padding - } - self.extensions.push(Extension { id, payload }); - } - Ok(()) - } - - /// returns an extension id array - pub fn get_extension_ids(&self) -> Vec { - if self.extension { - self.extensions.iter().map(|e| e.id).collect() - } else { - vec![] - } - } - - /// returns an RTP header extension - pub fn get_extension(&self, id: u8) -> Option { - if self.extension { - self.extensions - .iter() - .find(|extension| extension.id == id) - .map(|extension| extension.payload.clone()) - } else { - None - } - } - - /// Removes an RTP Header extension - pub fn del_extension(&mut self, id: u8) -> Result<(), Error> { - if self.extension { - if let Some(index) = self - .extensions - .iter() - .position(|extension| extension.id == id) - { - let extension = self.extensions.remove(index); - - let extension_profile_len = match self.extension_profile { - EXTENSION_PROFILE_ONE_BYTE => 1, - EXTENSION_PROFILE_TWO_BYTE => 2, - _ => 0, - }; - - let extension_padding = (extension.payload.len() + extension_profile_len) % 4; - self.extensions_padding = (self.extensions_padding + extension_padding) % 4; - - Ok(()) - } else { - Err(Error::ErrHeaderExtensionNotFound) - } - } else { - Err(Error::ErrHeaderExtensionsNotEnabled) - } - } -} diff --git a/rtp/src/lib.rs b/rtp/src/lib.rs deleted file mode 100644 index 18066e7e5..000000000 --- a/rtp/src/lib.rs +++ /dev/null @@ -1,12 +0,0 @@ -#![warn(rust_2018_idioms)] -#![allow(dead_code)] - -pub mod codecs; -mod error; -pub mod extension; -pub mod header; -pub mod packet; -pub mod packetizer; -pub mod sequence; - -pub use error::Error; diff --git a/rtp/src/packet/mod.rs b/rtp/src/packet/mod.rs deleted file mode 100644 index 3a76b7eaf..000000000 --- a/rtp/src/packet/mod.rs +++ /dev/null @@ -1,122 +0,0 @@ -#[cfg(test)] -mod packet_test; - -use std::fmt; - -use bytes::{Buf, BufMut, Bytes}; -use util::marshal::{Marshal, MarshalSize, Unmarshal}; - -use crate::error::Error; -use crate::header::*; - -/// Packet represents an RTP Packet -/// NOTE: Raw is populated by Marshal/Unmarshal and should not be modified -#[derive(Debug, Eq, PartialEq, Default, Clone)] -pub struct Packet { - pub header: Header, - pub payload: Bytes, -} - -impl fmt::Display for Packet { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut out = "RTP PACKET:\n".to_string(); - - out += format!("\tVersion: {}\n", self.header.version).as_str(); - out += format!("\tMarker: {}\n", self.header.marker).as_str(); - out += format!("\tPayload Type: {}\n", self.header.payload_type).as_str(); - out += format!("\tSequence Number: {}\n", self.header.sequence_number).as_str(); - out += format!("\tTimestamp: {}\n", self.header.timestamp).as_str(); - out += format!("\tSSRC: {} ({:x})\n", self.header.ssrc, self.header.ssrc).as_str(); - out += format!("\tPayload Length: {}\n", self.payload.len()).as_str(); - - write!(f, "{out}") - } -} - -impl Unmarshal for Packet { - /// Unmarshal parses the passed byte slice and stores the result in the Header this method is called upon - fn unmarshal(raw_packet: &mut B) -> Result - where - Self: Sized, - B: Buf, - { - let header = Header::unmarshal(raw_packet)?; - let payload_len = raw_packet.remaining(); - let payload = raw_packet.copy_to_bytes(payload_len); - if header.padding { - if payload_len > 0 { - let padding_len = 
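set_extension picks the extension profile from the payload size when no extensions exist yet (up to 16 bytes selects the one-byte profile, up to 255 the two-byte profile) and keeps extensions_padding aligned to 32-bit words. A hedged usage sketch, assuming the Header API and crate layout exactly as shown in this diff:

```rust
// Sketch: add, read, and remove a header extension via the API above.
// Crate and module names are taken from the paths in this diff.
use bytes::Bytes;
use rtp::header::Header;

fn main() {
    let mut header = Header::default();
    // A 1-byte payload selects the RFC 8285 one-byte profile (0xBEDE).
    assert!(header.set_extension(1, Bytes::from_static(&[0xAA])).is_ok());
    assert!(header.extension);
    assert_eq!(header.extension_profile, 0xBEDE);
    assert_eq!(header.get_extension(1), Some(Bytes::from_static(&[0xAA])));

    assert!(header.del_extension(1).is_ok());
    assert!(header.get_extension(1).is_none());
}
```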
payload[payload_len - 1] as usize; - if padding_len <= payload_len { - Ok(Packet { - header, - payload: payload.slice(..payload_len - padding_len), - }) - } else { - Err(Error::ErrShortPacket.into()) - } - } else { - Err(Error::ErrShortPacket.into()) - } - } else { - Ok(Packet { header, payload }) - } - } -} - -impl MarshalSize for Packet { - /// MarshalSize returns the size of the packet once marshaled. - fn marshal_size(&self) -> usize { - let payload_len = self.payload.len(); - let padding_len = if self.header.padding { - let padding_len = get_padding(payload_len); - if padding_len == 0 { - 4 - } else { - padding_len - } - } else { - 0 - }; - self.header.marshal_size() + payload_len + padding_len - } -} - -impl Marshal for Packet { - /// MarshalTo serializes the packet and writes to the buffer. - fn marshal_to(&self, mut buf: &mut [u8]) -> Result { - if buf.remaining_mut() < self.marshal_size() { - return Err(Error::ErrBufferTooSmall.into()); - } - - let n = self.header.marshal_to(buf)?; - buf = &mut buf[n..]; - buf.put(&*self.payload); - let padding_len = if self.header.padding { - let mut padding_len = get_padding(self.payload.len()); - if padding_len == 0 { - padding_len = 4; - } - for i in 0..padding_len { - if i != padding_len - 1 { - buf.put_u8(0); - } else { - buf.put_u8(padding_len as u8); - } - } - padding_len - } else { - 0 - }; - - Ok(n + self.payload.len() + padding_len) - } -} - -/// getPadding Returns the padding required to make the length a multiple of 4 -fn get_padding(len: usize) -> usize { - if len % 4 == 0 { - 0 - } else { - 4 - (len % 4) - } -} diff --git a/rtp/src/packet/packet_test.rs b/rtp/src/packet/packet_test.rs deleted file mode 100644 index 926f922e6..000000000 --- a/rtp/src/packet/packet_test.rs +++ /dev/null @@ -1,1243 +0,0 @@ -// Silence warning on `..Default::default()` with no effect: -#![allow(clippy::needless_update)] - -use bytes::{Bytes, BytesMut}; - -use super::*; -use crate::error::Result; - -#[test] -fn test_basic() -> Result<()> { - let mut empty_bytes = &vec![0u8; 0][..]; - let result = Packet::unmarshal(&mut empty_bytes); - assert!( - result.is_err(), - "Unmarshal did not error on zero length packet" - ); - - let raw_pkt = Bytes::from_static(&[ - 0x90, 0xe0, 0x69, 0x8f, 0xd9, 0xc2, 0x93, 0xda, 0x1c, 0x64, 0x27, 0x82, 0x00, 0x01, 0x00, - 0x01, 0xFF, 0xFF, 0xFF, 0xFF, 0x98, 0x36, 0xbe, 0x88, 0x9e, - ]); - let parsed_packet = Packet { - header: Header { - version: 2, - padding: false, - extension: true, - marker: true, - payload_type: 96, - sequence_number: 27023, - timestamp: 3653407706, - ssrc: 476325762, - csrc: vec![], - extension_profile: 1, - extensions: vec![Extension { - id: 0, - payload: Bytes::from_static(&[0xFF, 0xFF, 0xFF, 0xFF]), - }], - ..Default::default() - }, - payload: Bytes::from_static(&[0x98, 0x36, 0xbe, 0x88, 0x9e]), - }; - let buf = &mut raw_pkt.clone(); - let packet = Packet::unmarshal(buf)?; - assert_eq!( - packet, parsed_packet, - "TestBasic unmarshal: got {packet}, want {parsed_packet}" - ); - assert_eq!( - packet.header.marshal_size(), - 20, - "wrong computed header marshal size" - ); - assert_eq!( - packet.marshal_size(), - raw_pkt.len(), - "wrong computed marshal size" - ); - - let raw = packet.marshal()?; - let n = raw.len(); - assert_eq!(n, raw_pkt.len(), "wrong marshal size"); - - assert_eq!( - raw.len(), - raw_pkt.len(), - "wrong raw marshal size {} vs {}", - raw.len(), - raw_pkt.len() - ); - assert_eq!( - raw, raw_pkt, - "TestBasic marshal: got {raw:?}, want {raw_pkt:?}" - ); - - Ok(()) -} - -#[test] -fn 
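Padding handling is symmetric: unmarshal trims the number of octets named by the final byte, and marshal pads the payload out to a 4-byte boundary (emitting a full 4-byte block when the payload is already aligned) with the pad count stored in the last octet. A standalone sketch of the size arithmetic:

```rust
// Sketch: padding rules used above, std only.
fn get_padding(len: usize) -> usize {
    if len % 4 == 0 { 0 } else { 4 - (len % 4) }
}

fn padded_payload_len(payload_len: usize, padding: bool) -> usize {
    if !padding {
        return payload_len;
    }
    let pad = match get_padding(payload_len) {
        0 => 4, // emit a full block so the count byte has somewhere to live
        n => n,
    };
    payload_len + pad
}

fn main() {
    assert_eq!(padded_payload_len(25, true), 28);  // 3 pad bytes, last one = 3
    assert_eq!(padded_payload_len(24, true), 28);  // aligned input still gets 4
    assert_eq!(padded_payload_len(25, false), 25); // no padding bit, no padding
}
```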
test_extension() -> Result<()> { - let mut missing_extension_pkt = Bytes::from_static(&[ - 0x90, 0x60, 0x69, 0x8f, 0xd9, 0xc2, 0x93, 0xda, 0x1c, 0x64, 0x27, 0x82, - ]); - let buf = &mut missing_extension_pkt; - let result = Packet::unmarshal(buf); - assert!( - result.is_err(), - "Unmarshal did not error on packet with missing extension data" - ); - - let mut invalid_extension_length_pkt = Bytes::from_static(&[ - 0x90, 0x60, 0x69, 0x8f, 0xd9, 0xc2, 0x93, 0xda, 0x1c, 0x64, 0x27, 0x82, 0x99, 0x99, 0x99, - 0x99, - ]); - let buf = &mut invalid_extension_length_pkt; - let result = Packet::unmarshal(buf); - assert!( - result.is_err(), - "Unmarshal did not error on packet with invalid extension length" - ); - - let packet = Packet { - header: Header { - extension: true, - extension_profile: 3, - extensions: vec![Extension { - id: 0, - payload: Bytes::from_static(&[0]), - }], - ..Default::default() - }, - payload: Bytes::from_static(&[]), - }; - - let mut raw = BytesMut::new(); - let result = packet.marshal_to(&mut raw); - assert!( - result.is_err(), - "Marshal did not error on packet with invalid extension length" - ); - if let Err(err) = result { - assert_eq!(Error::ErrBufferTooSmall, err); - } - - Ok(()) -} - -#[test] -fn test_padding() -> Result<()> { - let raw_pkt = Bytes::from_static(&[ - 0xa0, 0x60, 0x19, 0x58, 0x63, 0xff, 0x7d, 0x7c, 0x4b, 0x98, 0xd4, 0x0a, 0x67, 0x4d, 0x00, - 0x29, 0x9a, 0x64, 0x03, 0xc0, 0x11, 0x3f, 0x2c, 0xd4, 0x04, 0x04, 0x05, 0x00, 0x00, 0x03, - 0x03, 0xe8, 0x00, 0x00, 0xea, 0x60, 0x04, 0x00, 0x00, 0x03, - ]); - let buf = &mut raw_pkt.clone(); - let packet = Packet::unmarshal(buf)?; - assert_eq!(&packet.payload[..], &raw_pkt[12..12 + 25]); - - let raw = packet.marshal()?; - assert_eq!(raw, raw_pkt); - - Ok(()) -} - -#[test] -fn test_packet_marshal_unmarshal() -> Result<()> { - let pkt = Packet { - header: Header { - extension: true, - csrc: vec![1, 2], - extension_profile: EXTENSION_PROFILE_TWO_BYTE, - extensions: vec![ - Extension { - id: 1, - payload: Bytes::from_static(&[3, 4]), - }, - Extension { - id: 2, - payload: Bytes::from_static(&[5, 6]), - }, - ], - ..Default::default() - }, - payload: Bytes::from_static(&[0xFFu8; 15]), - ..Default::default() - }; - let mut raw = pkt.marshal()?; - let p = Packet::unmarshal(&mut raw)?; - - assert_eq!(pkt, p); - - Ok(()) -} - -#[test] -fn test_rfc_8285_one_byte_extension() -> Result<()> { - let raw_pkt = Bytes::from_static(&[ - 0x90, 0xe0, 0x69, 0x8f, 0xd9, 0xc2, 0x93, 0xda, 0x1c, 0x64, 0x27, 0x82, 0xBE, 0xDE, 0x00, - 0x01, 0x50, 0xAA, 0x00, 0x00, 0x98, 0x36, 0xbe, 0x88, 0x9e, - ]); - let buf = &mut raw_pkt.clone(); - Packet::unmarshal(buf)?; - - let p = Packet { - header: Header { - marker: true, - extension: true, - extension_profile: 0xBEDE, - extensions: vec![Extension { - id: 5, - payload: Bytes::from_static(&[0xAA]), - }], - version: 2, - payload_type: 96, - sequence_number: 27023, - timestamp: 3653407706, - ssrc: 476325762, - csrc: vec![], - ..Default::default() - }, - payload: raw_pkt.slice(20..), - }; - - let dst = p.marshal()?; - assert_eq!(dst, raw_pkt); - - Ok(()) -} - -#[test] -fn test_rfc_8285_one_byte_two_extension_of_two_bytes() -> Result<()> { - // 0 1 2 3 - // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // | 0xBE | 0xDE | length=1 | - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // | ID | L=0 | data | ID | L=0 | data... 
- // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - let raw_pkt = Bytes::from_static(&[ - 0x90, 0xe0, 0x69, 0x8f, 0xd9, 0xc2, 0x93, 0xda, 0x1c, 0x64, 0x27, 0x82, 0xBE, 0xDE, 0x00, - 0x01, 0x10, 0xAA, 0x20, 0xBB, // Payload - 0x98, 0x36, 0xbe, 0x88, 0x9e, - ]); - let buf = &mut raw_pkt.clone(); - let p = Packet::unmarshal(buf)?; - - let ext1 = p.header.get_extension(1); - let ext1_expect = Bytes::from_static(&[0xAA]); - if let Some(ext1) = ext1 { - assert_eq!(ext1, ext1_expect); - } else { - panic!("ext1 is none"); - } - - let ext2 = p.header.get_extension(2); - let ext2_expect = Bytes::from_static(&[0xBB]); - if let Some(ext2) = ext2 { - assert_eq!(ext2, ext2_expect); - } else { - panic!("ext2 is none"); - } - - // Test Marshal - let p = Packet { - header: Header { - marker: true, - extension: true, - extension_profile: 0xBEDE, - extensions: vec![ - Extension { - id: 1, - payload: Bytes::from_static(&[0xAA]), - }, - Extension { - id: 2, - payload: Bytes::from_static(&[0xBB]), - }, - ], - version: 2, - payload_type: 96, - sequence_number: 27023, - timestamp: 3653407706, - ssrc: 476325762, - csrc: vec![], - ..Default::default() - }, - payload: raw_pkt.slice(20..), - }; - - let dst = p.marshal()?; - assert_eq!(dst, raw_pkt); - - Ok(()) -} - -#[test] -fn test_rfc_8285_one_byte_multiple_extensions_with_padding() -> Result<()> { - // 0 1 2 3 - // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // | 0xBE | 0xDE | length=3 | - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // | ID | L=0 | data | ID | L=1 | data... - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // ...data | 0 (pad) | 0 (pad) | ID | L=3 | - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // | data | - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - - let mut raw_pkt = Bytes::from_static(&[ - 0x90, 0xe0, 0x69, 0x8f, 0xd9, 0xc2, 0x93, 0xda, 0x1c, 0x64, 0x27, 0x82, 0xBE, 0xDE, 0x00, - 0x03, 0x10, 0xAA, 0x21, 0xBB, 0xBB, 0x00, 0x00, 0x33, 0xCC, 0xCC, 0xCC, 0xCC, - // Payload - 0x98, 0x36, 0xbe, 0x88, 0x9e, - ]); - let buf = &mut raw_pkt; - let packet = Packet::unmarshal(buf)?; - let ext1 = packet - .header - .get_extension(1) - .expect("Error getting header extension."); - - let ext1_expect = Bytes::from_static(&[0xAA]); - assert_eq!(ext1, ext1_expect); - - let ext2 = packet - .header - .get_extension(2) - .expect("Error getting header extension."); - - let ext2_expect = Bytes::from_static(&[0xBB, 0xBB]); - assert_eq!(ext2, ext2_expect); - - let ext3 = packet - .header - .get_extension(3) - .expect("Error getting header extension."); - - let ext3_expect = Bytes::from_static(&[0xCC, 0xCC, 0xCC, 0xCC]); - assert_eq!(ext3, ext3_expect); - - let mut dst_buf: Vec> = vec![vec![0u8; 1000], vec![0xFF; 1000], vec![0xAA; 2]]; - - let raw_pkg_marshal: [u8; 33] = [ - // padding is moved to the end by re-marshaling - 0x90, 0xe0, 0x69, 0x8f, 0xd9, 0xc2, 0x93, 0xda, 0x1c, 0x64, 0x27, 0x82, 0xBE, 0xDE, 0x00, - 0x03, 0x10, 0xAA, 0x21, 0xBB, 0xBB, 0x33, 0xCC, 0xCC, 0xCC, 0xCC, 0x00, 0x00, - // Payload - 0x98, 0x36, 0xbe, 0x88, 0x9e, - ]; - - let checker = |name: &str, buf: &mut [u8], p: &Packet| -> Result<()> { - let size = p.marshal_to(buf)?; - - assert_eq!( - &buf[..size], - &raw_pkg_marshal[..], - "Marshalled fields are not equal for {name}." 
- ); - - Ok(()) - }; - - checker("CleanBuffer", &mut dst_buf[0], &packet)?; - checker("DirtyBuffer", &mut dst_buf[1], &packet)?; - - let result = packet.marshal_to(&mut dst_buf[2]); - assert!(result.is_err()); - if let Err(err) = result { - assert_eq!(Error::ErrBufferTooSmall, err); - } - - Ok(()) -} - -fn test_rfc_8285_one_byte_multiple_extension() -> Result<()> { - // 0 1 2 3 - // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // | 0xBE | 0xDE | length=3 | - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // | ID=1 | L=0 | data | ID=2 | L=1 | data... - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // ...data | ID=3 | L=3 | data... - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // ...data | - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - let raw_pkt = &[ - 0x90u8, 0xe0, 0x69, 0x8f, 0xd9, 0xc2, 0x93, 0xda, 0x1c, 0x64, 0x27, 0x82, 0xBE, 0xDE, 0x00, - 0x03, 0x10, 0xAA, 0x21, 0xBB, 0xBB, 0x33, 0xCC, 0xCC, 0xCC, 0xCC, 0x00, 0x00, - // Payload - 0x98, 0x36, 0xbe, 0x88, 0x9e, - ]; - - let p = Packet { - header: Header { - marker: true, - extension: true, - extension_profile: 0xBEDE, - extensions: vec![ - Extension { - id: 1, - payload: Bytes::from_static(&[0xAA]), - }, - Extension { - id: 2, - payload: Bytes::from_static(&[0xBB, 0xBB]), - }, - Extension { - id: 3, - payload: Bytes::from_static(&[0xCC, 0xCC]), - }, - ], - version: 2, - payload_type: 96, - sequence_number: 27023, - timestamp: 3653407706, - ssrc: 476325762, - ..Default::default() - }, - payload: raw_pkt[28..].into(), - }; - - let dst_data = p.marshal()?; - assert_eq!( - &dst_data[..], - raw_pkt, - "Marshal failed raw \nMarshaled:\n{dst_data:?}\nrawPkt:\n{raw_pkt:?}", - ); - - Ok(()) -} - -fn test_rfc_8285_two_byte_extension() -> Result<()> { - let raw_pkt = Bytes::from_static(&[ - 0x90, 0xe0, 0x69, 0x8f, 0xd9, 0xc2, 0x93, 0xda, 0x1c, 0x64, 0x27, 0x82, 0x10, 0x00, 0x00, - 0x07, 0x05, 0x18, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, - 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0x00, 0x00, 0x98, - 0x36, 0xbe, 0x88, 0x9e, - ]); - - let _ = Packet::unmarshal(&mut raw_pkt.clone())?; - - let p = Packet { - header: Header { - marker: true, - extension: true, - extension_profile: 0x1000, - extensions: vec![Extension { - id: 5, - payload: Bytes::from_static(&[ - 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, - 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, - ]), - }], - version: 2, - payload_type: 96, - sequence_number: 27023, - timestamp: 3653407706, - ssrc: 476325762, - ..Default::default() - }, - payload: raw_pkt.slice(44..), - }; - - let dst_data = p.marshal()?; - assert_eq!( - dst_data, raw_pkt, - "Marshal failed raw \nMarshaled:\n{dst_data:?}\nrawPkt:\n{raw_pkt:?}" - ); - Ok(()) -} - -fn test_rfc8285_two_byte_multiple_extension_with_padding() -> Result<()> { - // 0 1 2 3 - // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // | 0x10 | 0x00 | length=3 | - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // | ID=1 | L=0 | ID=2 | L=1 | - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // | data | 0 (pad) | ID=3 | L=4 | - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // | data | - // 
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - - let mut raw_pkt = Bytes::from_static(&[ - 0x90u8, 0xe0, 0x69, 0x8f, 0xd9, 0xc2, 0x93, 0xda, 0x1c, 0x64, 0x27, 0x82, 0x10, 0x00, 0x00, - 0x03, 0x01, 0x00, 0x02, 0x01, 0xBB, 0x00, 0x03, 0x04, 0xCC, 0xCC, 0xCC, 0xCC, 0x98, 0x36, - 0xbe, 0x88, 0x9e, - ]); - - let p = Packet::unmarshal(&mut raw_pkt)?; - - let ext = p.header.get_extension(1); - let ext_expect = Some(Bytes::from_static(&[])); - assert_eq!( - ext, ext_expect, - "Extension has incorrect data. Got: {ext:?}, Expected: {ext_expect:?}" - ); - - let ext = p.header.get_extension(2); - let ext_expect = Some(Bytes::from_static(&[0xBB])); - assert_eq!( - ext, ext_expect, - "Extension has incorrect data. Got: {ext:?}, Expected: {ext_expect:?}" - ); - - let ext = p.header.get_extension(3); - let ext_expect = Some(Bytes::from_static(&[0xCC, 0xCC, 0xCC, 0xCC])); - assert_eq!( - ext, ext_expect, - "Extension has incorrect data. Got: {ext:?}, Expected: {ext_expect:?}" - ); - - Ok(()) -} - -fn test_rfc8285_two_byte_multiple_extension_with_large_extension() -> Result<()> { - // 0 1 2 3 - // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // | 0x10 | 0x00 | length=3 | - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // | ID=1 | L=0 | ID=2 | L=1 | - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // | data | ID=3 | L=17 | data... - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // ...data... - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // ...data... - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // ...data... - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // ...data... 
| - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - - let raw_pkt = Bytes::from_static(&[ - 0x90u8, 0xe0, 0x69, 0x8f, 0xd9, 0xc2, 0x93, 0xda, 0x1c, 0x64, 0x27, 0x82, 0x10, 0x00, 0x00, - 0x06, 0x01, 0x00, 0x02, 0x01, 0xBB, 0x03, 0x11, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, - 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, // Payload - 0x98, 0x36, 0xbe, 0x88, 0x9e, - ]); - - let p = Packet { - header: Header { - marker: true, - extension: true, - extension_profile: 0x1000, - extensions: vec![ - Extension { - id: 1, - payload: Bytes::from_static(&[]), - }, - Extension { - id: 2, - payload: Bytes::from_static(&[0xBB]), - }, - Extension { - id: 3, - payload: Bytes::from_static(&[ - 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, - 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, - ]), - }, - ], - version: 2, - payload_type: 96, - sequence_number: 27023, - timestamp: 3653407706, - ssrc: 476325762, - ..Default::default() - }, - payload: raw_pkt.slice(40..), - }; - - let dst_data = p.marshal()?; - assert_eq!( - dst_data, - raw_pkt[..], - "Marshal failed raw \nMarshaled: {dst_data:?}, \nraw_pkt:{raw_pkt:?}" - ); - - Ok(()) -} - -fn test_rfc8285_get_extension_returns_nil_when_extension_disabled() -> Result<()> { - let payload = Bytes::from_static(&[ - // Payload - 0x98u8, 0x36, 0xbe, 0x88, 0x9e, - ]); - - let p = Packet { - header: Header { - marker: true, - version: 2, - payload_type: 96, - sequence_number: 27023, - timestamp: 3653407706, - ssrc: 476325762, - ..Default::default() - }, - payload, - ..Default::default() - }; - - let res = p.header.get_extension(1); - assert!( - res.is_none(), - "Should return none on get_extension when header extension is false" - ); - - Ok(()) -} - -fn test_rfc8285_del_extension() -> Result<()> { - let payload = Bytes::from_static(&[ - // Payload - 0x98u8, 0x36, 0xbe, 0x88, 0x9e, - ]); - let mut p = Packet { - header: Header { - marker: true, - extension: true, - extension_profile: 0xBEDE, - extensions: vec![Extension { - id: 1, - payload: Bytes::from_static(&[0xAA]), - }], - version: 2, - payload_type: 96, - sequence_number: 27023, - timestamp: 3653407706, - ssrc: 476325762, - ..Default::default() - }, - payload, - ..Default::default() - }; - - let ext = p.header.get_extension(1); - assert!(ext.is_some(), "Extension should exist"); - - p.header.del_extension(1)?; - - let ext = p.header.get_extension(1); - assert!(ext.is_none(), "Extension should not exist"); - - let err = p.header.del_extension(1); - assert!( - err.is_err(), - "Should return error when deleting extension that doesnt exist" - ); - - Ok(()) -} - -fn test_rfc8285_get_extension_ids() { - let payload = Bytes::from_static(&[0x98u8, 0x36, 0xbe, 0x88, 0x9e]); - - let p = Packet { - header: Header { - marker: true, - extension: true, - extension_profile: 0xBEDE, - extensions: vec![ - Extension { - id: 1, - payload: Bytes::from_static(&[0xAA]), - }, - Extension { - id: 2, - payload: Bytes::from_static(&[0xBB]), - }, - ], - version: 2, - payload_type: 96, - sequence_number: 27023, - timestamp: 3653407706, - ssrc: 476325762, - ..Default::default() - }, - payload, - ..Default::default() - }; - - let ids = p.header.get_extension_ids(); - assert!(!ids.is_empty(), "Extensions should exist"); - - assert_eq!( - ids.len(), - p.header.extensions.len(), - "The number of IDs should be equal to the number of extensions, want={}, hanve{}", - ids.len(), - p.header.extensions.len() - ); - - for id in ids { - let ext = p.header.get_extension(id); - assert!(ext.is_some(), "Extension 
should exist for id: {id}") - } -} - -fn test_rfc8285_get_extension_ids_return_empty_when_extension_disabled() { - let payload = Bytes::from_static(&[0x98u8, 0x36, 0xbe, 0x88, 0x9e]); - - let p = Packet { - header: Header { - marker: true, - extension: false, - version: 2, - payload_type: 96, - sequence_number: 27023, - timestamp: 3653407706, - ssrc: 476325762, - ..Default::default() - }, - payload, - ..Default::default() - }; - - let ids = p.header.get_extension_ids(); - assert!(ids.is_empty(), "Extensions should not exist"); -} - -fn test_rfc8285_del_extension_returns_error_when_extensions_disabled() { - let payload = Bytes::from_static(&[0x98u8, 0x36, 0xbe, 0x88, 0x9e]); - - let mut p = Packet { - header: Header { - marker: true, - extension: false, - version: 2, - payload_type: 96, - sequence_number: 27023, - timestamp: 3653407706, - ssrc: 476325762, - ..Default::default() - }, - payload, - ..Default::default() - }; - - let ids = p.header.del_extension(1); - assert!( - ids.is_err(), - "Should return error on del_extension when header extension field is false" - ); -} - -fn test_rfc8285_one_byte_set_extension_should_enable_extension_when_adding() { - let payload = Bytes::from_static(&[0x98u8, 0x36, 0xbe, 0x88, 0x9e]); - - let mut p = Packet { - header: Header { - marker: true, - extension: false, - version: 2, - payload_type: 96, - sequence_number: 27023, - timestamp: 3653407706, - ssrc: 476325762, - ..Default::default() - }, - payload, - ..Default::default() - }; - - let extension = Bytes::from_static(&[0xAAu8, 0xAA]); - let result = p.header.set_extension(1, extension.clone()); - assert!(result.is_ok(), "Error setting extension"); - - assert!(p.header.extension, "Extension should be set to true"); - assert_eq!( - p.header.extension_profile, 0xBEDE, - "Extension profile should be set to 0xBEDE" - ); - assert_eq!( - p.header.extensions.len(), - 1, - "Extensions len should be set to 1" - ); - assert_eq!( - p.header.get_extension(1), - Some(extension), - "Extension value is not set" - ) -} - -fn test_rfc8285_set_extension_should_set_correct_extension_profile_for_16_byte_extension() { - let payload = Bytes::from_static(&[0x98u8, 0x36, 0xbe, 0x88, 0x9e]); - - let mut p = Packet { - header: Header { - marker: true, - extension: false, - version: 2, - payload_type: 96, - sequence_number: 27023, - timestamp: 3653407706, - ssrc: 476325762, - ..Default::default() - }, - payload, - ..Default::default() - }; - - let extension = Bytes::from_static(&[ - 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, - 0xAA, - ]); - - let res = p.header.set_extension(1, extension); - assert!(res.is_ok(), "Error setting extension"); - - assert_eq!( - p.header.extension_profile, 0xBEDE, - "Extension profile should be 0xBEDE" - ); -} - -fn test_rfc8285_set_extension_should_update_existing_extension() -> Result<()> { - let payload = Bytes::from_static(&[0x98u8, 0x36, 0xbe, 0x88, 0x9e]); - - let mut p = Packet { - header: Header { - marker: true, - extension: true, - extension_profile: 0xBEDE, - extensions: vec![Extension { - id: 1, - payload: Bytes::from_static(&[0xAA]), - }], - version: 2, - payload_type: 96, - sequence_number: 27023, - timestamp: 3653407706, - ssrc: 476325762, - ..Default::default() - }, - payload, - ..Default::default() - }; - - assert_eq!( - p.header.get_extension(1), - Some([0xAA][..].into()), - "Extension value not initialized properly" - ); - - let extension = Bytes::from_static(&[0xBBu8]); - p.header.set_extension(1, extension.clone())?; - - assert_eq!( 
- p.header.get_extension(1), - Some(extension), - "Extension value was not set" - ); - - Ok(()) -} - -fn test_rfc8285_one_byte_set_extension_should_error_when_invalid_id_provided() { - let payload = Bytes::from_static(&[0x98u8, 0x36, 0xbe, 0x88, 0x9e]); - - let mut p = Packet { - header: Header { - marker: true, - extension: true, - extension_profile: 0xBEDE, - extensions: vec![Extension { - id: 1, - payload: Bytes::from_static(&[0xAA]), - }], - version: 2, - payload_type: 96, - sequence_number: 27023, - timestamp: 3653407706, - ssrc: 476325762, - ..Default::default() - }, - payload, - ..Default::default() - }; - - assert!( - p.header - .set_extension(0, Bytes::from_static(&[0xBBu8])) - .is_err(), - "set_extension did not error on invalid id" - ); - assert!( - p.header - .set_extension(15, Bytes::from_static(&[0xBBu8])) - .is_err(), - "set_extension did not error on invalid id" - ); -} - -fn test_rfc8285_one_byte_extension_terminate_processing_when_reserved_id_encountered() -> Result<()> -{ - let reserved_id_pkt = Bytes::from_static(&[ - 0x90u8, 0xe0, 0x69, 0x8f, 0xd9, 0xc2, 0x93, 0xda, 0x1c, 0x64, 0x27, 0x82, 0xBE, 0xDE, 0x00, - 0x01, 0xF0, 0xAA, 0x98, 0x36, 0xbe, 0x88, 0x9e, - ]); - - let p = Packet::unmarshal(&mut reserved_id_pkt.clone())?; - - assert_eq!( - p.header.extensions.len(), - 0, - "Extension should be empty for invalid ID" - ); - - let payload = reserved_id_pkt.slice(17..); - assert_eq!(p.payload, payload, "p.payload must be same as payload"); - - Ok(()) -} - -fn test_rfc8285_one_byte_set_extension_should_error_when_payload_too_large() { - let payload = Bytes::from_static(&[0x98u8, 0x36, 0xbe, 0x88, 0x9e]); - - let mut p = Packet { - header: Header { - marker: true, - extension: true, - extension_profile: 0xBEDE, - extensions: vec![Extension { - id: 1, - payload: Bytes::from_static(&[0xAAu8]), - }], - version: 2, - payload_type: 96, - sequence_number: 27023, - timestamp: 3653407706, - ssrc: 476325762, - ..Default::default() - }, - payload, - ..Default::default() - }; - - let res = p.header.set_extension( - 1, - Bytes::from_static(&[ - 0xBBu8, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, - 0xBB, 0xBB, 0xBB, - ]), - ); - - assert!( - res.is_err(), - "set_extension did not error on too large payload" - ); -} - -fn test_rfc8285_two_bytes_set_extension_should_enable_extension_when_adding() -> Result<()> { - let payload = Bytes::from_static(&[0x98u8, 0x36, 0xbe, 0x88, 0x9e]); - - let mut p = Packet { - header: Header { - marker: true, - extension: true, - extension_profile: 0xBEDE, - version: 2, - payload_type: 96, - sequence_number: 27023, - timestamp: 3653407706, - ssrc: 476325762, - ..Default::default() - }, - payload, - ..Default::default() - }; - - let extension = Bytes::from_static(&[ - 0xAAu8, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, - 0xAA, 0xAA, - ]); - - p.header.set_extension(1, extension.clone())?; - - assert!(p.header.extension, "Extension should be set to true"); - assert_eq!( - p.header.extension_profile, 0x1000, - "Extension profile should be set to 0xBEDE" - ); - assert_eq!( - p.header.extensions.len(), - 1, - "Extensions should be set to 1" - ); - assert_eq!( - p.header.get_extension(1), - Some(extension), - "Extension value is not set" - ); - - Ok(()) -} - -fn test_rfc8285_two_byte_set_extension_should_update_existing_extension() -> Result<()> { - let payload = Bytes::from_static(&[0x98u8, 0x36, 0xbe, 0x88, 0x9e]); - - let mut p = Packet { - header: Header { - marker: true, - extension: 
true, - extension_profile: 0x1000, - extensions: vec![Extension { - id: 1, - payload: Bytes::from_static(&[0xAA]), - }], - version: 2, - payload_type: 96, - sequence_number: 27023, - timestamp: 3653407706, - ssrc: 476325762, - ..Default::default() - }, - payload, - ..Default::default() - }; - - assert_eq!( - p.header.get_extension(1), - Some(Bytes::from_static(&[0xAA])), - "Extension value not initialized properly" - ); - - let extension = Bytes::from_static(&[ - 0xBBu8, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, - 0xBB, 0xBB, - ]); - - p.header.set_extension(1, extension.clone())?; - - assert_eq!(p.header.get_extension(1), Some(extension)); - - Ok(()) -} - -fn test_rfc8285_two_byte_set_extension_should_error_when_payload_too_large() { - let payload = Bytes::from_static(&[0x98u8, 0x36, 0xbe, 0x88, 0x9e]); - - let mut p = Packet { - header: Header { - marker: true, - extension: true, - extension_profile: 0xBEDE, - extensions: vec![Extension { - id: 1, - payload: Bytes::from_static(&[0xAA]), - }], - version: 2, - payload_type: 96, - sequence_number: 27023, - timestamp: 3653407706, - ssrc: 476325762, - ..Default::default() - }, - payload, - ..Default::default() - }; - - let res = p.header.set_extension( - 1, - Bytes::from_static(&[ - 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, - 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, - 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, - 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, - 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, - 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, - 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, - 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, - 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, - 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, - 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, - 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, - 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, - 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, - 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, - 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, - 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, - 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, - 0xBB, 0xBB, 0xBB, 0xBB, - ]), - ); - - assert!( - res.is_err(), - "Set extension did not error on too large payload" - ); -} - -fn test_rfc3550_set_extension_should_error_when_non_zero() -> Result<()> { - let payload = Bytes::from_static(&[0x98u8, 0x36, 0xbe, 0x88, 0x9e]); - - let mut p = Packet { - header: Header { - marker: true, - extension: true, - extension_profile: 0x1111, - extensions: vec![Extension { - id: 1, - payload: Bytes::from_static(&[0xAA]), - }], - version: 2, - payload_type: 96, - sequence_number: 27023, - timestamp: 3653407706, - ssrc: 476325762, - ..Default::default() - }, - payload, - ..Default::default() - }; - - p.header.set_extension(0, Bytes::from_static(&[0xBB]))?; - let res = 
p.header.get_extension(0); - assert_eq!( - res, - Some(Bytes::from_static(&[0xBB])), - "p.get_extension returned incorrect value" - ); - - Ok(()) -} - -fn test_rfc3550_set_extension_should_error_when_setting_non_zero_id() { - let payload = Bytes::from_static(&[0x98u8, 0x36, 0xbe, 0x88, 0x9e]); - - let mut p = Packet { - header: Header { - marker: true, - extension: true, - extension_profile: 0x1111, - version: 2, - payload_type: 96, - sequence_number: 27023, - timestamp: 3653407706, - ssrc: 476325762, - ..Default::default() - }, - payload, - ..Default::default() - }; - - let res = p.header.set_extension(1, Bytes::from_static(&[0xBB])); - assert!(res.is_err(), "set_extension did not error on invalid id"); -} - -use std::collections::HashMap; - -struct Cases { - input: Bytes, - err: Error, -} - -fn test_unmarshal_error_handling() { - let mut cases = HashMap::new(); - - cases.insert( - "ShortHeader", - Cases { - input: Bytes::from_static(&[ - 0x80, 0xe0, 0x69, 0x8f, 0xd9, 0xc2, 0x93, 0xda, // timestamp - 0x1c, 0x64, 0x27, // SSRC (one byte missing) - ]), - err: Error::ErrHeaderSizeInsufficient, - }, - ); - - cases.insert( - "MissingCSRC", - Cases { - input: Bytes::from_static(&[ - 0x81, 0xe0, 0x69, 0x8f, 0xd9, 0xc2, 0x93, 0xda, // timestamp - 0x1c, 0x64, 0x27, 0x82, // SSRC - ]), - err: Error::ErrHeaderSizeInsufficient, - }, - ); - - cases.insert( - "MissingExtension", - Cases { - input: Bytes::from_static(&[ - 0x90, 0xe0, 0x69, 0x8f, 0xd9, 0xc2, 0x93, 0xda, // timestamp - 0x1c, 0x64, 0x27, 0x82, // SSRC - ]), - err: Error::ErrHeaderSizeInsufficientForExtension, - }, - ); - - cases.insert( - "MissingExtensionData", - Cases { - input: Bytes::from_static(&[ - 0x90, 0xe0, 0x69, 0x8f, 0xd9, 0xc2, 0x93, 0xda, // timestamp - 0x1c, 0x64, 0x27, 0x82, // SSRC - 0xBE, 0xDE, 0x00, 0x03, // specified to have 3 extensions, but actually not - ]), - err: Error::ErrHeaderSizeInsufficientForExtension, - }, - ); - - cases.insert( - "MissingExtensionDataPayload", - Cases { - input: Bytes::from_static(&[ - 0x90, 0xe0, 0x69, 0x8f, 0xd9, 0xc2, 0x93, 0xda, // timestamp - 0x1c, 0x64, 0x27, 0x82, // SSRC - 0xBE, 0xDE, 0x00, 0x01, // have 1 extension - 0x12, - 0x00, // length of the payload is expected to be 3, but actually have only 1 - ]), - err: Error::ErrHeaderSizeInsufficientForExtension, - }, - ); - - for (name, mut test_case) in cases.drain() { - let result = Header::unmarshal(&mut test_case.input); - let err = result.err().unwrap(); - assert_eq!( - test_case.err, err, - "Expected :{:?}, found: {:?} for testcase {}", - test_case.err, err, name - ) - } -} - -fn test_round_trip() -> Result<()> { - let raw_pkt = Bytes::from_static(&[ - 0x00u8, 0x10, 0x23, 0x45, 0x12, 0x34, 0x45, 0x67, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x11, 0x22, - 0x33, 0x44, 0x55, 0x66, 0x77, - ]); - - let payload = raw_pkt.slice(12..); - - let p = Packet::unmarshal(&mut raw_pkt.clone())?; - - assert_eq!( - payload, p.payload, - "p.payload must be same as payload.\n p.payload: {:?},\nraw_pkt: {:?}", - p.payload, payload - ); - - let buf = p.marshal()?; - - assert_eq!( - raw_pkt, buf, - "buf must be the same as raw_pkt. \n buf: {buf:?},\nraw_pkt: {raw_pkt:?}", - ); - assert_eq!( - payload, p.payload, - "p.payload must be the same as payload. 
\n payload: {:?},\np.payload: {:?}", - payload, p.payload, - ); - - Ok(()) -} diff --git a/rtp/src/packetizer/mod.rs b/rtp/src/packetizer/mod.rs deleted file mode 100644 index c925bc708..000000000 --- a/rtp/src/packetizer/mod.rs +++ /dev/null @@ -1,165 +0,0 @@ -#[cfg(test)] -mod packetizer_test; - -use std::fmt; -use std::sync::Arc; -use std::time::SystemTime; - -use bytes::{Bytes, BytesMut}; -use util::marshal::{Marshal, MarshalSize}; - -use crate::error::Result; -use crate::extension::abs_send_time_extension::*; -use crate::header::*; -use crate::packet::*; -use crate::sequence::*; - -/// Payloader payloads a byte array for use as rtp.Packet payloads -pub trait Payloader: fmt::Debug { - fn payload(&mut self, mtu: usize, b: &Bytes) -> Result>; - fn clone_to(&self) -> Box; -} - -impl Clone for Box { - fn clone(&self) -> Box { - self.clone_to() - } -} - -/// Packetizer packetizes a payload -pub trait Packetizer: fmt::Debug { - fn enable_abs_send_time(&mut self, value: u8); - fn packetize(&mut self, payload: &Bytes, samples: u32) -> Result>; - fn skip_samples(&mut self, skipped_samples: u32); - fn clone_to(&self) -> Box; -} - -impl Clone for Box { - fn clone(&self) -> Box { - self.clone_to() - } -} - -/// Depacketizer depacketizes a RTP payload, removing any RTP specific data from the payload -pub trait Depacketizer { - fn depacketize(&mut self, b: &Bytes) -> Result; - - /// Checks if the packet is at the beginning of a partition. This - /// should return false if the result could not be determined, in - /// which case the caller will detect timestamp discontinuities. - fn is_partition_head(&self, payload: &Bytes) -> bool; - - /// Checks if the packet is at the end of a partition. This should - /// return false if the result could not be determined. - fn is_partition_tail(&self, marker: bool, payload: &Bytes) -> bool; -} - -//TODO: SystemTime vs Instant? 
-// non-monotonic clock vs monotonically non-decreasing clock -/// FnTimeGen provides current SystemTime -pub type FnTimeGen = Arc SystemTime) + Send + Sync>; - -#[derive(Clone)] -pub(crate) struct PacketizerImpl { - pub(crate) mtu: usize, - pub(crate) payload_type: u8, - pub(crate) ssrc: u32, - pub(crate) payloader: Box, - pub(crate) sequencer: Box, - pub(crate) timestamp: u32, - pub(crate) clock_rate: u32, - pub(crate) abs_send_time: u8, //http://www.webrtc.org/experiments/rtp-hdrext/abs-send-time - pub(crate) time_gen: Option, -} - -impl fmt::Debug for PacketizerImpl { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("PacketizerImpl") - .field("mtu", &self.mtu) - .field("payload_type", &self.payload_type) - .field("ssrc", &self.ssrc) - .field("timestamp", &self.timestamp) - .field("clock_rate", &self.clock_rate) - .field("abs_send_time", &self.abs_send_time) - .finish() - } -} - -pub fn new_packetizer( - mtu: usize, - payload_type: u8, - ssrc: u32, - payloader: Box, - sequencer: Box, - clock_rate: u32, -) -> impl Packetizer { - PacketizerImpl { - mtu, - payload_type, - ssrc, - payloader, - sequencer, - timestamp: rand::random::(), - clock_rate, - abs_send_time: 0, - time_gen: None, - } -} - -impl Packetizer for PacketizerImpl { - fn enable_abs_send_time(&mut self, value: u8) { - self.abs_send_time = value - } - - fn packetize(&mut self, payload: &Bytes, samples: u32) -> Result> { - let payloads = self.payloader.payload(self.mtu - 12, payload)?; - let payloads_len = payloads.len(); - let mut packets = Vec::with_capacity(payloads_len); - for (i, payload) in payloads.into_iter().enumerate() { - packets.push(Packet { - header: Header { - version: 2, - padding: false, - extension: false, - marker: i == payloads_len - 1, - payload_type: self.payload_type, - sequence_number: self.sequencer.next_sequence_number(), - timestamp: self.timestamp, //TODO: Figure out how to do timestamps - ssrc: self.ssrc, - ..Default::default() - }, - payload, - }); - } - - self.timestamp = self.timestamp.wrapping_add(samples); - - if payloads_len != 0 && self.abs_send_time != 0 { - let st = if let Some(fn_time_gen) = &self.time_gen { - fn_time_gen() - } else { - SystemTime::now() - }; - let send_time = AbsSendTimeExtension::new(st); - //apply http://www.webrtc.org/experiments/rtp-hdrext/abs-send-time - let mut raw = BytesMut::with_capacity(send_time.marshal_size()); - raw.resize(send_time.marshal_size(), 0); - let _ = send_time.marshal_to(&mut raw)?; - packets[payloads_len - 1] - .header - .set_extension(self.abs_send_time, raw.freeze())?; - } - - Ok(packets) - } - - /// skip_samples causes a gap in sample count between Packetize requests so the - /// RTP payloads produced have a gap in timestamps - fn skip_samples(&mut self, skipped_samples: u32) { - self.timestamp = self.timestamp.wrapping_add(skipped_samples); - } - - fn clone_to(&self) -> Box { - Box::new(self.clone()) - } -} diff --git a/rtp/src/packetizer/packetizer_test.rs b/rtp/src/packetizer/packetizer_test.rs deleted file mode 100644 index 077658638..000000000 --- a/rtp/src/packetizer/packetizer_test.rs +++ /dev/null @@ -1,110 +0,0 @@ -use std::time::{Duration, UNIX_EPOCH}; - -use chrono::prelude::*; - -use super::*; -use crate::codecs::*; -use crate::error::Result; - -#[test] -fn test_packetizer() -> Result<()> { - let multiple_payload = Bytes::from_static(&[0; 128]); - let g722 = Box::new(g7xx::G722Payloader {}); - let seq = Box::new(new_random_sequencer()); - - //use the G722 payloader here, because it's very simple and all 
0s is valid G722 data. - let mut packetizer = new_packetizer(100, 98, 0x1234ABCD, g722, seq, 90000); - let packets = packetizer.packetize(&multiple_payload, 2000)?; - - if packets.len() != 2 { - let mut packet_lengths = String::new(); - #[allow(clippy::needless_range_loop)] - for i in 0..packets.len() { - packet_lengths += - format!("Packet {} length {}\n", i, packets[i].payload.len()).as_str(); - } - panic!( - "Generated {} packets instead of 2\n{}", - packets.len(), - packet_lengths, - ); - } - Ok(()) -} - -#[test] -fn test_packetizer_abs_send_time() -> Result<()> { - let g722 = Box::new(g7xx::G722Payloader {}); - let sequencer = Box::new(new_fixed_sequencer(1234)); - - let time_gen: Option = Some(Arc::new(|| -> SystemTime { - let loc = FixedOffset::west_opt(5 * 60 * 60).unwrap(); // UTC-5 - let t = loc.with_ymd_and_hms(1985, 6, 23, 4, 0, 0).unwrap(); - UNIX_EPOCH - .checked_add(Duration::from_nanos(t.timestamp_nanos_opt().unwrap() as u64)) - .unwrap_or(UNIX_EPOCH) - })); - - //use the G722 payloader here, because it's very simple and all 0s is valid G722 data. - let mut pktizer = PacketizerImpl { - mtu: 100, - payload_type: 98, - ssrc: 0x1234ABCD, - payloader: g722, - sequencer, - timestamp: 45678, - clock_rate: 90000, - abs_send_time: 0, - time_gen, - }; - pktizer.enable_abs_send_time(1); - - let payload = Bytes::from_static(&[0x11, 0x12, 0x13, 0x14]); - let packets = pktizer.packetize(&payload, 2000)?; - - let expected = Packet { - header: Header { - version: 2, - padding: false, - extension: true, - marker: true, - payload_type: 98, - sequence_number: 1234, - timestamp: 45678, - ssrc: 0x1234ABCD, - csrc: vec![], - extension_profile: 0xBEDE, - extensions: vec![Extension { - id: 1, - payload: Bytes::from_static(&[0x40, 0, 0]), - }], - extensions_padding: 0, - }, - payload: Bytes::from_static(&[0x11, 0x12, 0x13, 0x14]), - }; - - if packets.len() != 1 { - panic!("Generated {} packets instead of 1", packets.len()) - } - - assert_eq!(packets[0], expected); - - Ok(()) -} - -#[test] -fn test_packetizer_timestamp_rollover_does_not_panic() -> Result<()> { - let g722 = Box::new(g7xx::G722Payloader {}); - let seq = Box::new(new_random_sequencer()); - - let payload = Bytes::from_static(&[0; 128]); - let mut packetizer = new_packetizer(100, 98, 0x1234ABCD, g722, seq, 90000); - - packetizer.packetize(&payload, 10)?; - - packetizer.packetize(&payload, u32::MAX)?; - - packetizer.skip_samples(u32::MAX); - - Ok(()) -} diff --git a/rtp/src/sequence.rs b/rtp/src/sequence.rs deleted file mode 100644 index bf61c29ed..000000000 --- a/rtp/src/sequence.rs +++ /dev/null @@ -1,74 +0,0 @@ -use std::fmt; -use std::sync::atomic::Ordering; -use std::sync::Arc; - -use portable_atomic::{AtomicU16, AtomicU64}; - -/// Sequencer generates sequential sequence numbers for building RTP packets -pub trait Sequencer: fmt::Debug { - fn next_sequence_number(&self) -> u16; - fn roll_over_count(&self) -> u64; - fn clone_to(&self) -> Box; -} - -impl Clone for Box { - fn clone(&self) -> Box { - self.clone_to() - } -} - -/// NewRandomSequencer returns a new sequencer starting from a random sequence -/// number -pub fn new_random_sequencer() -> impl Sequencer { - let c = Counters { - sequence_number: Arc::new(AtomicU16::new(rand::random::())), - roll_over_count: Arc::new(AtomicU64::new(0)), - }; - SequencerImpl(c) -} - -/// NewFixedSequencer returns a new sequencer starting from a specific -/// sequence number -pub fn new_fixed_sequencer(s: u16) -> impl Sequencer { - let sequence_number = if s == 0 { u16::MAX } else { s - 1 }; - - 
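// The stored value is the *previous* sequence number: next_sequence_number()
// increments before returning, so the first number handed out is exactly `s`.
// For s == 0 the counter starts at u16::MAX, and the first call wraps back to
// 0 (which also increments roll_over_count once).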
let c = Counters { - sequence_number: Arc::new(AtomicU16::new(sequence_number)), - roll_over_count: Arc::new(AtomicU64::new(0)), - }; - - SequencerImpl(c) -} - -#[derive(Debug, Clone)] -struct SequencerImpl(Counters); - -#[derive(Debug, Clone)] -struct Counters { - sequence_number: Arc, - roll_over_count: Arc, -} - -impl Sequencer for SequencerImpl { - /// NextSequenceNumber increment and returns a new sequence number for - /// building RTP packets - fn next_sequence_number(&self) -> u16 { - if self.0.sequence_number.load(Ordering::SeqCst) == u16::MAX { - self.0.roll_over_count.fetch_add(1, Ordering::SeqCst); - self.0.sequence_number.store(0, Ordering::SeqCst); - 0 - } else { - self.0.sequence_number.fetch_add(1, Ordering::SeqCst) + 1 - } - } - - /// RollOverCount returns the amount of times the 16bit sequence number - /// has wrapped - fn roll_over_count(&self) -> u64 { - self.0.roll_over_count.load(Ordering::SeqCst) - } - - fn clone_to(&self) -> Box { - Box::new(self.clone()) - } -} diff --git a/sctp/.gitignore b/sctp/.gitignore deleted file mode 100644 index 87ca02a58..000000000 --- a/sctp/.gitignore +++ /dev/null @@ -1,11 +0,0 @@ -# Generated by Cargo -# will have compiled files and executables -/target/ -/.idea/ - -# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries -# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html -Cargo.lock - -# These are backup files generated by rustfmt -**/*.rs.bk \ No newline at end of file diff --git a/sctp/CHANGELOG.md b/sctp/CHANGELOG.md deleted file mode 100644 index ee7acd552..000000000 --- a/sctp/CHANGELOG.md +++ /dev/null @@ -1,39 +0,0 @@ -# webrtc-sctp changelog - -## Unreleased - -* Use the new algorithm in crc crate for better throughput [#569](https://github.com/webrtc-rs/webrtc/pull/569) - -## v0.8.0 - -* Fix 'attempt to add with overflow' panic in dev profile [#393](https://github.com/webrtc-rs/webrtc/pull/393) -* Limit the bytes in the PendingQueue to avoid packets accumulating there uncontrollably [#367](https://github.com/webrtc-rs/webrtc/pull/367). -* Improve algorithm used to push to pending queue from O(n*log(n)) to O(log(n)) [#365](https://github.com/webrtc-rs/webrtc/pull/365). -* Reuse as many allocations as possible when marshaling [#364](https://github.com/webrtc-rs/webrtc/pull/364). -* The lock for the internal association was contended badly because marshaling was done while still in a critical section and also tokio was scheduling tasks badly [#363](https://github.com/webrtc-rs/webrtc/pull/363). - -### Breaking - -* Make `sctp::Stream::write` & `sctp::Stream::write_sctp` async again [#367](https://github.com/webrtc-rs/webrtc/pull/367). - -## v0.7.0 - -* Increased minimum support rust version to `1.60.0`. -* Do not loose data in `PollStream::poll_write` [#341](https://github.com/webrtc-rs/webrtc/pull/341). -* `PollStream::poll_shutdown`: make sure to flush any writes before shutting down [#340](https://github.com/webrtc-rs/webrtc/pull/340). -* Fixed a possible bug when adding chunks to pending queue [#345](https://github.com/webrtc-rs/webrtc/pull/345). -* Increased required `webrtc-util` version to `0.7.0`. - -### Breaking changes - -* Make `Stream::on_buffered_amount_low` function non-async [#338](https://github.com/webrtc-rs/webrtc/pull/338). -* Make `sctp::Stream::write` & `sctp::Stream::write_sctp` sync [#344](https://github.com/webrtc-rs/webrtc/pull/344). - -## v0.6.1 - -* Increased min version of `log` dependency to `0.4.16`. 
[#250 Fix log at ^0.4.16 to make tests compile](https://github.com/webrtc-rs/webrtc/pull/250) by [@k0nserv](https://github.com/k0nserv). -* [#245 Fix incorrect chunk type Display for CWR](https://github.com/webrtc-rs/webrtc/pull/245) by [@k0nserv](https://github.com/k0nserv). - -## Prior to 0.6.1 - -Before 0.6.1 there was no changelog, previous changes are sometimes, but not always, available in the [GitHub Releases](https://github.com/webrtc-rs/sctp/releases). diff --git a/sctp/Cargo.toml b/sctp/Cargo.toml deleted file mode 100644 index 061419603..000000000 --- a/sctp/Cargo.toml +++ /dev/null @@ -1,51 +0,0 @@ -[package] -name = "webrtc-sctp" -version = "0.10.0" -authors = ["Rain Liu "] -edition = "2021" -description = "A pure Rust implementation of SCTP" -license = "MIT OR Apache-2.0" -documentation = "https://docs.rs/webrtc-sctp" -homepage = "https://webrtc.rs" -repository = "https://github.com/webrtc-rs/webrtc/tree/master/sctp" - -[dependencies] -util = { version = "0.9.0", path = "../util", package = "webrtc-util", default-features = false, features = ["conn"] } - -arc-swap = "1" -tokio = { version = "1.32.0", features = [ - "fs", - "io-util", - "io-std", - "macros", - "net", - "parking_lot", - "rt", - "rt-multi-thread", - "sync", - "time", -] } -bytes = "1" -rand = "0.8" -crc = "3.2.1" -async-trait = "0.1" -log = "0.4" -thiserror = "1" -portable-atomic = "1.6" - -[dev-dependencies] -tokio-test = "0.4" -lazy_static = "1" -env_logger = "0.10" -chrono = "0.4.28" -clap = "3" - -[[example]] -name = "ping" -path = "examples/ping.rs" -bench = false - -[[example]] -name = "pong" -path = "examples/pong.rs" -bench = false diff --git a/sctp/LICENSE-APACHE b/sctp/LICENSE-APACHE deleted file mode 100644 index 16fe87b06..000000000 --- a/sctp/LICENSE-APACHE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - -1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). 
- - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - -2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - -3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - -4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - -5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - -6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - -8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - -END OF TERMS AND CONDITIONS - -APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - -Copyright [yyyy] [name of copyright owner] - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. diff --git a/sctp/LICENSE-MIT b/sctp/LICENSE-MIT deleted file mode 100644 index e11d93bef..000000000 --- a/sctp/LICENSE-MIT +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2021 WebRTC.rs - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/sctp/README.md b/sctp/README.md deleted file mode 100644 index c856c7c14..000000000 --- a/sctp/README.md +++ /dev/null @@ -1,30 +0,0 @@ -

- WebRTC.rs
-
- [badges: License: MIT/Apache 2.0, Discord]
-
- A pure Rust implementation of SCTP. Rewrite Pion SCTP in Rust

diff --git a/sctp/codecov.yml b/sctp/codecov.yml deleted file mode 100644 index ac2e61cff..000000000 --- a/sctp/codecov.yml +++ /dev/null @@ -1,23 +0,0 @@ -codecov: - require_ci_to_pass: yes - max_report_age: off - token: 9ec6d495-dfa7-4250-afeb-dbf342009340 - -coverage: - precision: 2 - round: down - range: 50..90 - status: - project: - default: - enabled: no - threshold: 0.2 - if_not_found: success - patch: - default: - enabled: no - if_not_found: success - changes: - default: - enabled: no - if_not_found: success diff --git a/sctp/doc/webrtc.rs.png b/sctp/doc/webrtc.rs.png deleted file mode 100644 index 7bf0dda2a..000000000 Binary files a/sctp/doc/webrtc.rs.png and /dev/null differ diff --git a/sctp/examples/ping.rs b/sctp/examples/ping.rs deleted file mode 100644 index d8566ce39..000000000 --- a/sctp/examples/ping.rs +++ /dev/null @@ -1,118 +0,0 @@ -use std::net::Shutdown; -use std::sync::Arc; - -use bytes::Bytes; -use clap::{App, AppSettings, Arg}; -use tokio::net::UdpSocket; -use tokio::signal; -use tokio::sync::mpsc; -use webrtc_sctp::association::*; -use webrtc_sctp::chunk::chunk_payload_data::PayloadProtocolIdentifier; -use webrtc_sctp::stream::*; -use webrtc_sctp::Error; - -// RUST_LOG=trace cargo run --color=always --package webrtc-sctp --example ping -- --server 0.0.0.0:5678 - -#[tokio::main] -async fn main() -> Result<(), Error> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - let mut app = App::new("SCTP Ping") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of SCTP Client") - .setting(AppSettings::DeriveDisplayOrder) - .setting(AppSettings::SubcommandsNegateReqs) - .arg( - Arg::with_name("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::with_name("server") - .required_unless("FULLHELP") - .takes_value(true) - .long("server") - .help("SCTP Server name."), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let server = matches.value_of("server").unwrap(); - - let conn = Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap()); - conn.connect(server).await.unwrap(); - println!("connecting {server}.."); - - let config = Config { - net_conn: conn, - max_receive_buffer_size: 0, - max_message_size: 0, - name: "client".to_owned(), - }; - let a = Association::client(config).await?; - println!("created a client"); - - let stream = a.open_stream(0, PayloadProtocolIdentifier::String).await?; - println!("opened a stream"); - - // set unordered = true and 10ms threshold for dropping packets - stream.set_reliability_params(true, ReliabilityType::Timed, 10); - - let stream_tx = Arc::clone(&stream); - tokio::spawn(async move { - let mut ping_seq_num = 0; - while ping_seq_num < 10 { - let ping_msg = format!("ping {ping_seq_num}"); - println!("sent: {ping_msg}"); - stream_tx.write(&Bytes::from(ping_msg)).await?; - - ping_seq_num += 1; - } - - println!("finished send ping"); - Result::<(), Error>::Ok(()) - }); - - let (done_tx, mut done_rx) = mpsc::channel::<()>(1); - let stream_rx = Arc::clone(&stream); - tokio::spawn(async move { - let mut buff = vec![0u8; 1024]; - while let Ok(n) = stream_rx.read(&mut buff).await { - let pong_msg = 
String::from_utf8(buff[..n].to_vec()).unwrap(); - println!("received: {pong_msg}"); - } - - println!("finished recv pong"); - drop(done_tx); - }); - - println!("Waiting for Ctrl-C..."); - signal::ctrl_c().await.expect("failed to listen for event"); - println!("Closing stream and association..."); - - stream.shutdown(Shutdown::Both).await?; - a.close().await?; - - let _ = done_rx.recv().await; - - Ok(()) -} diff --git a/sctp/examples/pong.rs b/sctp/examples/pong.rs deleted file mode 100644 index c46b1b6db..000000000 --- a/sctp/examples/pong.rs +++ /dev/null @@ -1,110 +0,0 @@ -use std::net::Shutdown; -use std::sync::Arc; -use std::time::Duration; - -use bytes::Bytes; -use clap::{App, AppSettings, Arg}; -use tokio::net::UdpSocket; -use tokio::signal; -use tokio::sync::mpsc; -use util::conn::conn_disconnected_packet::DisconnectedPacketConn; -use util::Conn; -use webrtc_sctp::association::*; -use webrtc_sctp::stream::*; -use webrtc_sctp::Error; - -// RUST_LOG=trace cargo run --color=always --package webrtc-sctp --example pong -- --host 0.0.0.0:5678 - -#[tokio::main] -async fn main() -> Result<(), Error> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - let mut app = App::new("SCTP Pong") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of SCTP Server") - .setting(AppSettings::DeriveDisplayOrder) - .setting(AppSettings::SubcommandsNegateReqs) - .arg( - Arg::with_name("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::with_name("host") - .required_unless("FULLHELP") - .takes_value(true) - .long("host") - .help("SCTP host name."), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let host = matches.value_of("host").unwrap(); - let conn = DisconnectedPacketConn::new(Arc::new(UdpSocket::bind(host).await.unwrap())); - println!("listening {}...", conn.local_addr().unwrap()); - - let config = Config { - net_conn: Arc::new(conn), - max_receive_buffer_size: 0, - max_message_size: 0, - name: "server".to_owned(), - }; - let a = Association::server(config).await?; - println!("created a server"); - - let stream = a.accept_stream().await.unwrap(); - println!("accepted a stream"); - - // set unordered = true and 10ms threshold for dropping packets - stream.set_reliability_params(true, ReliabilityType::Timed, 10); - - let (done_tx, mut done_rx) = mpsc::channel::<()>(1); - let stream2 = Arc::clone(&stream); - tokio::spawn(async move { - let mut buff = vec![0u8; 1024]; - while let Ok(n) = stream2.read(&mut buff).await { - let ping_msg = String::from_utf8(buff[..n].to_vec()).unwrap(); - println!("received: {ping_msg}"); - - let pong_msg = format!("pong [{ping_msg}]"); - println!("sent: {pong_msg}"); - stream2.write(&Bytes::from(pong_msg)).await?; - - tokio::time::sleep(Duration::from_secs(1)).await; - } - println!("finished ping-pong"); - drop(done_tx); - - Result::<(), Error>::Ok(()) - }); - - println!("Waiting for Ctrl-C..."); - signal::ctrl_c().await.expect("failed to listen for event"); - println!("Closing stream and association..."); - - stream.shutdown(Shutdown::Both).await?; - a.close().await?; - - let _ = done_rx.recv().await; - - Ok(()) -} diff --git 
a/sctp/examples/throughput.rs b/sctp/examples/throughput.rs deleted file mode 100644 index 31f2100e9..000000000 --- a/sctp/examples/throughput.rs +++ /dev/null @@ -1,147 +0,0 @@ -use std::io::Write; -use std::sync::Arc; - -use clap::{App, AppSettings, Arg}; -use tokio::net::UdpSocket; -use util::conn::conn_disconnected_packet::DisconnectedPacketConn; -use util::Conn; -use webrtc_sctp::association::*; -use webrtc_sctp::chunk::chunk_payload_data::PayloadProtocolIdentifier; -use webrtc_sctp::stream::*; -use webrtc_sctp::Error; - -fn main() -> Result<(), Error> { - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Warn) - .init(); - - let mut app = App::new("SCTP Throughput") - .version("0.1.0") - .about("An example of SCTP Server") - .setting(AppSettings::DeriveDisplayOrder) - .setting(AppSettings::SubcommandsNegateReqs) - .arg( - Arg::with_name("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::with_name("port") - .required_unless("FULLHELP") - .takes_value(true) - .long("port") - .help("use port ."), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let port1 = matches.value_of("port").unwrap().to_owned(); - let port2 = port1.clone(); - - std::thread::spawn(|| { - tokio::runtime::Runtime::new() - .unwrap() - .block_on(async move { - let conn = DisconnectedPacketConn::new(Arc::new( - UdpSocket::bind(format!("127.0.0.1:{port1}")).await.unwrap(), - )); - println!("listening {}...", conn.local_addr().unwrap()); - - let config = Config { - net_conn: Arc::new(conn), - max_receive_buffer_size: 0, - max_message_size: 0, - name: "recver".to_owned(), - }; - let a = Association::server(config).await?; - println!("created a server"); - - let stream = a.accept_stream().await.unwrap(); - println!("accepted a stream"); - - // set unordered = true and 10ms threshold for dropping packets - stream.set_reliability_params(true, ReliabilityType::Rexmit, 0); - - let mut buff = [0u8; 65535]; - let mut recv = 0; - let mut pkt_num = 0; - let mut loop_num = 0; - let mut now = tokio::time::Instant::now(); - while let Ok(n) = stream.read(&mut buff).await { - recv += n; - if n != 0 { - pkt_num += 1; - } - loop_num += 1; - if now.elapsed().as_secs() == 1 { - println!("Throughput: {recv} Bytes/s, {pkt_num} pkts, {loop_num} loops"); - now = tokio::time::Instant::now(); - recv = 0; - loop_num = 0; - pkt_num = 0; - } - } - Result::<(), Error>::Ok(()) - }) - }); - - std::thread::spawn(|| { - tokio::runtime::Runtime::new() - .unwrap() - .block_on(async move { - let conn = Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap()); - conn.connect(format!("127.0.0.1:{port2}")).await.unwrap(); - println!("connecting 127.0.0.1:{port2}.."); - - let config = Config { - net_conn: conn, - max_receive_buffer_size: 0, - max_message_size: 0, - name: "sender".to_owned(), - }; - let a = Association::client(config).await.unwrap(); - println!("created a client"); - - let stream = a - .open_stream(0, PayloadProtocolIdentifier::Binary) - .await - .unwrap(); - println!("opened a stream"); - - //const LEN: usize = 1200; - const LEN: usize = 65535; - let buf = vec![0; LEN]; - let bytes = bytes::Bytes::from(buf); - - let mut now = tokio::time::Instant::now(); - 
let mut pkt_num = 0; - while stream.write(&bytes).await.is_ok() { - pkt_num += 1; - if now.elapsed().as_secs() == 1 { - println!("Send {pkt_num} pkts"); - now = tokio::time::Instant::now(); - pkt_num = 0; - } - } - Result::<(), Error>::Ok(()) - }) - }); - #[allow(clippy::empty_loop)] - loop {} -} diff --git a/sctp/fuzz/.gitignore b/sctp/fuzz/.gitignore deleted file mode 100644 index 80894b1a2..000000000 --- a/sctp/fuzz/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -target -corpus diff --git a/sctp/fuzz/Cargo.toml b/sctp/fuzz/Cargo.toml deleted file mode 100644 index e22f7d11f..000000000 --- a/sctp/fuzz/Cargo.toml +++ /dev/null @@ -1,32 +0,0 @@ -[package] -name = "webrtc-sctp-fuzz" -version = "0.0.0" -authors = ["Automatically generated"] -publish = false -edition = "2021" - -[package.metadata] -cargo-fuzz = true - -[dependencies] -libfuzzer-sys = "0.4" -bytes = "*" - -[dependencies.webrtc-sctp] -path = ".." - -# Prevent this from interfering with workspaces -[workspace] -members = ["."] - -[[bin]] -name = "packet" -path = "fuzz_targets/packet.rs" -test = false -doc = false - -[[bin]] -name = "param" -path = "fuzz_targets/param.rs" -test = false -doc = false diff --git a/sctp/fuzz/artifacts/packet/crash-16cad30042bc4791bd62c630a780add5d1220779 b/sctp/fuzz/artifacts/packet/crash-16cad30042bc4791bd62c630a780add5d1220779 deleted file mode 100644 index 2f6df735b..000000000 Binary files a/sctp/fuzz/artifacts/packet/crash-16cad30042bc4791bd62c630a780add5d1220779 and /dev/null differ diff --git a/sctp/fuzz/artifacts/packet/crash-8b9b318a6b66ea23232a4e2aec91deeeca470af8 b/sctp/fuzz/artifacts/packet/crash-8b9b318a6b66ea23232a4e2aec91deeeca470af8 deleted file mode 100644 index aac36a233..000000000 Binary files a/sctp/fuzz/artifacts/packet/crash-8b9b318a6b66ea23232a4e2aec91deeeca470af8 and /dev/null differ diff --git a/sctp/fuzz/artifacts/packet/crash-8d90dfc8fc34fa06f161f69617ee8f48dec434cd b/sctp/fuzz/artifacts/packet/crash-8d90dfc8fc34fa06f161f69617ee8f48dec434cd deleted file mode 100644 index 92d545a3d..000000000 Binary files a/sctp/fuzz/artifacts/packet/crash-8d90dfc8fc34fa06f161f69617ee8f48dec434cd and /dev/null differ diff --git a/sctp/fuzz/artifacts/packet/crash-b836a20af7f8af85423dbe80565465b16bb7a16f b/sctp/fuzz/artifacts/packet/crash-b836a20af7f8af85423dbe80565465b16bb7a16f deleted file mode 100644 index bd51fac25..000000000 Binary files a/sctp/fuzz/artifacts/packet/crash-b836a20af7f8af85423dbe80565465b16bb7a16f and /dev/null differ diff --git a/sctp/fuzz/artifacts/packet/crash-f940d9879efc88872145955bae11ca6ad6a4c044 b/sctp/fuzz/artifacts/packet/crash-f940d9879efc88872145955bae11ca6ad6a4c044 deleted file mode 100644 index 745cf533b..000000000 Binary files a/sctp/fuzz/artifacts/packet/crash-f940d9879efc88872145955bae11ca6ad6a4c044 and /dev/null differ diff --git a/sctp/fuzz/artifacts/param/crash-216833e417069f431d0617fb4e9f8abe6c9a6c1d b/sctp/fuzz/artifacts/param/crash-216833e417069f431d0617fb4e9f8abe6c9a6c1d deleted file mode 100644 index a0d02fce8..000000000 Binary files a/sctp/fuzz/artifacts/param/crash-216833e417069f431d0617fb4e9f8abe6c9a6c1d and /dev/null differ diff --git a/sctp/fuzz/artifacts/param/crash-fb1b644bc0d365ce2dc3c3ff77cd3c4cd8da528d b/sctp/fuzz/artifacts/param/crash-fb1b644bc0d365ce2dc3c3ff77cd3c4cd8da528d deleted file mode 100644 index a597b341f..000000000 Binary files a/sctp/fuzz/artifacts/param/crash-fb1b644bc0d365ce2dc3c3ff77cd3c4cd8da528d and /dev/null differ diff --git a/sctp/fuzz/fuzz_targets/packet.rs b/sctp/fuzz/fuzz_targets/packet.rs deleted file mode 100644 index 
65a92c5b7..000000000 --- a/sctp/fuzz/fuzz_targets/packet.rs +++ /dev/null @@ -1,10 +0,0 @@ -#![no_main] -use libfuzzer_sys::fuzz_target; - -use webrtc_sctp::packet::Packet; -use bytes::Bytes; - -fuzz_target!(|data: &[u8]| { - let bytes = Bytes::from(data.to_vec()); - Packet::unmarshal(&bytes); -}); diff --git a/sctp/fuzz/fuzz_targets/param.rs b/sctp/fuzz/fuzz_targets/param.rs deleted file mode 100644 index 7677edf6e..000000000 --- a/sctp/fuzz/fuzz_targets/param.rs +++ /dev/null @@ -1,10 +0,0 @@ -#![no_main] -use libfuzzer_sys::fuzz_target; - -use webrtc_sctp::param::build_param; -use bytes::Bytes; - -fuzz_target!(|data: &[u8]| { - let bytes = Bytes::from(data.to_vec()); - build_param(&bytes); -}); diff --git a/sctp/src/association/association_internal.rs b/sctp/src/association/association_internal.rs deleted file mode 100644 index 747349c94..000000000 --- a/sctp/src/association/association_internal.rs +++ /dev/null @@ -1,2421 +0,0 @@ -#[cfg(test)] -mod association_internal_test; - -use async_trait::async_trait; -use portable_atomic::AtomicBool; - -use super::*; -use crate::param::param_forward_tsn_supported::ParamForwardTsnSupported; -use crate::param::param_type::ParamType; -use crate::param::param_unrecognized::ParamUnrecognized; - -#[derive(Default)] -pub struct AssociationInternal { - pub(crate) name: String, - pub(crate) state: Arc, - pub(crate) max_message_size: Arc, - pub(crate) inflight_queue_length: Arc, - pub(crate) will_send_shutdown: Arc, - awake_write_loop_ch: Option>>, - - peer_verification_tag: u32, - pub(crate) my_verification_tag: u32, - - pub(crate) my_next_tsn: u32, // nextTSN - peer_last_tsn: u32, // lastRcvdTSN - min_tsn2measure_rtt: u32, // for RTT measurement - will_send_forward_tsn: bool, - will_retransmit_fast: bool, - will_retransmit_reconfig: bool, - - will_send_shutdown_ack: bool, - will_send_shutdown_complete: bool, - - // Reconfig - my_next_rsn: u32, - reconfigs: HashMap, - reconfig_requests: HashMap, - - // Non-RFC internal data - source_port: u16, - destination_port: u16, - pub(crate) my_max_num_inbound_streams: u16, - pub(crate) my_max_num_outbound_streams: u16, - my_cookie: Option, - payload_queue: PayloadQueue, - inflight_queue: PayloadQueue, - pending_queue: Arc, - control_queue: ControlQueue, - pub(crate) mtu: u32, - max_payload_size: u32, // max DATA chunk payload size - cumulative_tsn_ack_point: u32, - advanced_peer_tsn_ack_point: u32, - use_forward_tsn: bool, - - // Congestion control parameters - pub(crate) max_receive_buffer_size: u32, - pub(crate) cwnd: u32, // my congestion window size - rwnd: u32, // calculated peer's receiver windows size - pub(crate) ssthresh: u32, // slow start threshold - partial_bytes_acked: u32, - pub(crate) in_fast_recovery: bool, - fast_recover_exit_point: u32, - - // RTX & Ack timer - pub(crate) rto_mgr: RtoManager, - pub(crate) t1init: Option>, - pub(crate) t1cookie: Option>, - pub(crate) t2shutdown: Option>, - pub(crate) t3rtx: Option>, - pub(crate) treconfig: Option>, - pub(crate) ack_timer: Option>, - - // Chunks stored for retransmission - pub(crate) stored_init: Option, - stored_cookie_echo: Option, - - streams: HashMap>, - - close_loop_ch_tx: Option>, - accept_ch_tx: Option>>, - handshake_completed_ch_tx: Option>>, - - // local error - silent_error: Option, - - // per inbound packet context - delayed_ack_triggered: bool, - immediate_ack_triggered: bool, - - pub(crate) stats: Arc, - ack_state: AckState, - pub(crate) ack_mode: AckMode, // for testing -} - -impl AssociationInternal { - pub(crate) fn new( - config: 
Config, - close_loop_ch_tx: broadcast::Sender<()>, - accept_ch_tx: mpsc::Sender>, - handshake_completed_ch_tx: mpsc::Sender>, - awake_write_loop_ch: Arc>, - ) -> Self { - let max_receive_buffer_size = if config.max_receive_buffer_size == 0 { - INITIAL_RECV_BUF_SIZE - } else { - config.max_receive_buffer_size - }; - - let max_message_size = if config.max_message_size == 0 { - DEFAULT_MAX_MESSAGE_SIZE - } else { - config.max_message_size - }; - - let inflight_queue_length = Arc::new(AtomicUsize::new(0)); - - let mut tsn = random::(); - if tsn == 0 { - tsn += 1; - } - let mut a = AssociationInternal { - name: config.name, - max_receive_buffer_size, - max_message_size: Arc::new(AtomicU32::new(max_message_size)), - - my_max_num_outbound_streams: u16::MAX, - my_max_num_inbound_streams: u16::MAX, - payload_queue: PayloadQueue::new(Arc::new(AtomicUsize::new(0))), - inflight_queue: PayloadQueue::new(Arc::clone(&inflight_queue_length)), - inflight_queue_length, - pending_queue: Arc::new(PendingQueue::new()), - control_queue: ControlQueue::new(), - mtu: INITIAL_MTU, - max_payload_size: INITIAL_MTU - (COMMON_HEADER_SIZE + DATA_CHUNK_HEADER_SIZE), - my_verification_tag: random::(), - my_next_tsn: tsn, - my_next_rsn: tsn, - min_tsn2measure_rtt: tsn, - state: Arc::new(AtomicU8::new(AssociationState::Closed as u8)), - rto_mgr: RtoManager::new(), - streams: HashMap::new(), - reconfigs: HashMap::new(), - reconfig_requests: HashMap::new(), - accept_ch_tx: Some(accept_ch_tx), - close_loop_ch_tx: Some(close_loop_ch_tx), - handshake_completed_ch_tx: Some(handshake_completed_ch_tx), - cumulative_tsn_ack_point: tsn - 1, - advanced_peer_tsn_ack_point: tsn - 1, - silent_error: Some(Error::ErrSilentlyDiscard), - stats: Arc::new(AssociationStats::default()), - awake_write_loop_ch: Some(awake_write_loop_ch), - ..Default::default() - }; - - // RFC 4690 Sec 7.2.1 - // o The initial cwnd before DATA transmission or after a sufficiently - // long idle period MUST be set to min(4*MTU, max (2*MTU, 4380 - // bytes)). - // TODO: Consider whether this should use `clamp` - #[allow(clippy::manual_clamp)] - { - a.cwnd = std::cmp::min(4 * a.mtu, std::cmp::max(2 * a.mtu, 4380)); - } - log::trace!( - "[{}] updated cwnd={} ssthresh={} inflight={} (INI)", - a.name, - a.cwnd, - a.ssthresh, - a.inflight_queue.get_num_bytes() - ); - - a - } - - /// caller must hold self.lock - pub(crate) fn send_init(&mut self) -> Result<()> { - if let Some(stored_init) = self.stored_init.clone() { - log::debug!("[{}] sending INIT", self.name); - - self.source_port = 5000; // Spec?? - self.destination_port = 5000; // Spec?? 
- - let outbound = Packet { - source_port: self.source_port, - destination_port: self.destination_port, - verification_tag: 0, - chunks: vec![Box::new(stored_init)], - }; - - self.control_queue.push_back(outbound); - self.awake_write_loop(); - - Ok(()) - } else { - Err(Error::ErrInitNotStoredToSend) - } - } - - /// caller must hold self.lock - fn send_cookie_echo(&mut self) -> Result<()> { - if let Some(stored_cookie_echo) = &self.stored_cookie_echo { - log::debug!("[{}] sending COOKIE-ECHO", self.name); - - let outbound = Packet { - source_port: self.source_port, - destination_port: self.destination_port, - verification_tag: self.peer_verification_tag, - chunks: vec![Box::new(stored_cookie_echo.clone())], - }; - - self.control_queue.push_back(outbound); - self.awake_write_loop(); - Ok(()) - } else { - Err(Error::ErrCookieEchoNotStoredToSend) - } - } - - pub(crate) async fn close(&mut self) -> Result<()> { - if self.get_state() != AssociationState::Closed { - self.set_state(AssociationState::Closed); - - log::debug!("[{}] closing association..", self.name); - - self.close_all_timers().await; - - // awake read/write_loop to exit - self.close_loop_ch_tx.take(); - - for si in self.streams.keys().cloned().collect::>() { - self.unregister_stream(si); - } - - // Wait for read_loop to end - //if let Some(read_loop_close_ch) = &mut self.read_loop_close_ch { - // let _ = read_loop_close_ch.recv().await; - //} - - log::debug!("[{}] association closed", self.name); - log::debug!( - "[{}] stats nDATAs (in) : {}", - self.name, - self.stats.get_num_datas() - ); - log::debug!( - "[{}] stats nSACKs (in) : {}", - self.name, - self.stats.get_num_sacks() - ); - log::debug!( - "[{}] stats nT3Timeouts : {}", - self.name, - self.stats.get_num_t3timeouts() - ); - log::debug!( - "[{}] stats nAckTimeouts: {}", - self.name, - self.stats.get_num_ack_timeouts() - ); - log::debug!( - "[{}] stats nFastRetrans: {}", - self.name, - self.stats.get_num_fast_retrans() - ); - } - - Ok(()) - } - - async fn close_all_timers(&mut self) { - // Close all retransmission & ack timers - if let Some(t1init) = &self.t1init { - t1init.stop().await; - } - if let Some(t1cookie) = &self.t1cookie { - t1cookie.stop().await; - } - if let Some(t2shutdown) = &self.t2shutdown { - t2shutdown.stop().await; - } - if let Some(t3rtx) = &self.t3rtx { - t3rtx.stop().await; - } - if let Some(treconfig) = &self.treconfig { - treconfig.stop().await; - } - if let Some(ack_timer) = &mut self.ack_timer { - ack_timer.stop(); - } - } - - fn awake_write_loop(&self) { - //log::debug!("[{}] awake_write_loop_ch.notify_one", self.name); - if let Some(awake_write_loop_ch) = &self.awake_write_loop_ch { - let _ = awake_write_loop_ch.try_send(()); - } - } - - /// unregister_stream un-registers a stream from the association - /// The caller should hold the association write lock. - fn unregister_stream(&mut self, stream_identifier: u16) { - let s = self.streams.remove(&stream_identifier); - if let Some(s) = s { - // NOTE: shutdown is not used here because it resets the stream. 
- if !s.read_shutdown.swap(true, Ordering::SeqCst) { - s.read_notifier.notify_waiters(); - } - s.write_shutdown.store(true, Ordering::SeqCst); - } - } - - /// handle_inbound parses incoming raw packets - pub(crate) async fn handle_inbound(&mut self, raw: &Bytes) -> Result<()> { - let p = match Packet::unmarshal(raw) { - Ok(p) => p, - Err(err) => { - log::warn!("[{}] unable to parse SCTP packet {}", self.name, err); - return Ok(()); - } - }; - - if let Err(err) = p.check_packet() { - log::warn!("[{}] failed validating packet {}", self.name, err); - return Ok(()); - } - - self.handle_chunk_start(); - - for c in &p.chunks { - self.handle_chunk(&p, c).await?; - } - - self.handle_chunk_end(); - Ok(()) - } - - fn gather_data_packets_to_retransmit(&mut self, mut raw_packets: Vec) -> Vec { - for p in self.get_data_packets_to_retransmit() { - raw_packets.push(p); - } - - raw_packets - } - - async fn gather_outbound_data_and_reconfig_packets( - &mut self, - mut raw_packets: Vec, - ) -> Vec { - // Pop unsent data chunks from the pending queue to send as much as - // cwnd and rwnd allow. - let (chunks, sis_to_reset) = self.pop_pending_data_chunks_to_send().await; - if !chunks.is_empty() { - // Start timer. (noop if already started) - log::trace!("[{}] T3-rtx timer start (pt1)", self.name); - if let Some(t3rtx) = &self.t3rtx { - t3rtx.start(self.rto_mgr.get_rto()).await; - } - for p in self.bundle_data_chunks_into_packets(chunks) { - raw_packets.push(p); - } - } - - if !sis_to_reset.is_empty() || self.will_retransmit_reconfig { - if self.will_retransmit_reconfig { - self.will_retransmit_reconfig = false; - log::debug!( - "[{}] retransmit {} RECONFIG chunk(s)", - self.name, - self.reconfigs.len() - ); - for c in self.reconfigs.values() { - let p = self.create_packet(vec![Box::new(c.clone())]); - raw_packets.push(p); - } - } - - if !sis_to_reset.is_empty() { - let rsn = self.generate_next_rsn(); - let tsn = self.my_next_tsn - 1; - log::debug!( - "[{}] sending RECONFIG: rsn={} tsn={} streams={:?}", - self.name, - rsn, - self.my_next_tsn - 1, - sis_to_reset - ); - - let c = ChunkReconfig { - param_a: Some(Box::new(ParamOutgoingResetRequest { - reconfig_request_sequence_number: rsn, - sender_last_tsn: tsn, - stream_identifiers: sis_to_reset, - ..Default::default() - })), - ..Default::default() - }; - self.reconfigs.insert(rsn, c.clone()); // store in the map for retransmission - - let p = self.create_packet(vec![Box::new(c)]); - raw_packets.push(p); - } - - if !self.reconfigs.is_empty() { - if let Some(treconfig) = &self.treconfig { - treconfig.start(self.rto_mgr.get_rto()).await; - } - } - } - - raw_packets - } - - fn gather_outbound_fast_retransmission_packets( - &mut self, - mut raw_packets: Vec, - ) -> Vec { - if self.will_retransmit_fast { - self.will_retransmit_fast = false; - - let mut to_fast_retrans: Vec> = vec![]; - let mut fast_retrans_size = COMMON_HEADER_SIZE; - - let mut i = 0; - loop { - let tsn = self.cumulative_tsn_ack_point + i + 1; - if let Some(c) = self.inflight_queue.get_mut(tsn) { - if c.acked || c.abandoned() || c.nsent > 1 || c.miss_indicator < 3 { - i += 1; - continue; - } - - // RFC 4960 Sec 7.2.4 Fast Retransmit on Gap Reports - // 3) Determine how many of the earliest (i.e., lowest TSN) DATA chunks - // marked for retransmission will fit into a single packet, subject - // to constraint of the path MTU of the destination transport - // address to which the packet is being sent. Call this value K. - // Retransmit those K DATA chunks in a single packet. 
When a Fast - // Retransmit is being performed, the sender SHOULD ignore the value - // of cwnd and SHOULD NOT delay retransmission for this single - // packet. - - let data_chunk_size = DATA_CHUNK_HEADER_SIZE + c.user_data.len() as u32; - if self.mtu < fast_retrans_size + data_chunk_size { - break; - } - - fast_retrans_size += data_chunk_size; - self.stats.inc_fast_retrans(); - c.nsent += 1; - } else { - break; // end of pending data - } - - if let Some(c) = self.inflight_queue.get(tsn) { - self.check_partial_reliability_status(c); - to_fast_retrans.push(Box::new(c.clone())); - log::trace!( - "[{}] fast-retransmit: tsn={} sent={} htna={}", - self.name, - c.tsn, - c.nsent, - self.fast_recover_exit_point - ); - } - i += 1; - } - - if !to_fast_retrans.is_empty() { - let p = self.create_packet(to_fast_retrans); - raw_packets.push(p); - } - } - - raw_packets - } - - async fn gather_outbound_sack_packets(&mut self, mut raw_packets: Vec) -> Vec { - if self.ack_state == AckState::Immediate { - self.ack_state = AckState::Idle; - let sack = self.create_selective_ack_chunk().await; - log::debug!("[{}] sending SACK: {}", self.name, sack); - let p = self.create_packet(vec![Box::new(sack)]); - raw_packets.push(p); - } - - raw_packets - } - - fn gather_outbound_forward_tsn_packets(&mut self, mut raw_packets: Vec) -> Vec { - /*log::debug!( - "[{}] gatherOutboundForwardTSNPackets {}", - self.name, - self.will_send_forward_tsn - );*/ - if self.will_send_forward_tsn { - self.will_send_forward_tsn = false; - if sna32gt( - self.advanced_peer_tsn_ack_point, - self.cumulative_tsn_ack_point, - ) { - let fwd_tsn = self.create_forward_tsn(); - let p = self.create_packet(vec![Box::new(fwd_tsn)]); - raw_packets.push(p); - } - } - - raw_packets - } - - async fn gather_outbound_shutdown_packets( - &mut self, - mut raw_packets: Vec, - ) -> (Vec, bool) { - let mut ok = true; - - if self.will_send_shutdown.load(Ordering::SeqCst) { - self.will_send_shutdown.store(false, Ordering::SeqCst); - - let shutdown = ChunkShutdown { - cumulative_tsn_ack: self.cumulative_tsn_ack_point, - }; - - let p = self.create_packet(vec![Box::new(shutdown)]); - if let Some(t2shutdown) = &self.t2shutdown { - t2shutdown.start(self.rto_mgr.get_rto()).await; - } - raw_packets.push(p); - } else if self.will_send_shutdown_ack { - self.will_send_shutdown_ack = false; - - let shutdown_ack = ChunkShutdownAck {}; - - let p = self.create_packet(vec![Box::new(shutdown_ack)]); - if let Some(t2shutdown) = &self.t2shutdown { - t2shutdown.start(self.rto_mgr.get_rto()).await; - } - raw_packets.push(p); - } else if self.will_send_shutdown_complete { - self.will_send_shutdown_complete = false; - - let shutdown_complete = ChunkShutdownComplete {}; - ok = false; - let p = self.create_packet(vec![Box::new(shutdown_complete)]); - - raw_packets.push(p); - } - - (raw_packets, ok) - } - - /// gather_outbound gathers outgoing packets. The returned bool value set to - /// false means the association should be closed down after the final send. - pub(crate) async fn gather_outbound(&mut self) -> (Vec, bool) { - let mut raw_packets = Vec::with_capacity(16); - - if !self.control_queue.is_empty() { - for p in self.control_queue.drain(..) 
{ - raw_packets.push(p); - } - } - - let state = self.get_state(); - match state { - AssociationState::Established => { - raw_packets = self.gather_data_packets_to_retransmit(raw_packets); - raw_packets = self - .gather_outbound_data_and_reconfig_packets(raw_packets) - .await; - raw_packets = self.gather_outbound_fast_retransmission_packets(raw_packets); - raw_packets = self.gather_outbound_sack_packets(raw_packets).await; - raw_packets = self.gather_outbound_forward_tsn_packets(raw_packets); - (raw_packets, true) - } - AssociationState::ShutdownPending - | AssociationState::ShutdownSent - | AssociationState::ShutdownReceived => { - raw_packets = self.gather_data_packets_to_retransmit(raw_packets); - raw_packets = self.gather_outbound_fast_retransmission_packets(raw_packets); - raw_packets = self.gather_outbound_sack_packets(raw_packets).await; - self.gather_outbound_shutdown_packets(raw_packets).await - } - AssociationState::ShutdownAckSent => { - self.gather_outbound_shutdown_packets(raw_packets).await - } - _ => (raw_packets, true), - } - } - - /// set_state atomically sets the state of the Association. - pub(crate) fn set_state(&self, new_state: AssociationState) { - let old_state = AssociationState::from(self.state.swap(new_state as u8, Ordering::SeqCst)); - if new_state != old_state { - log::debug!( - "[{}] state change: '{}' => '{}'", - self.name, - old_state, - new_state, - ); - } - } - - /// get_state atomically returns the state of the Association. - fn get_state(&self) -> AssociationState { - self.state.load(Ordering::SeqCst).into() - } - - async fn handle_init(&mut self, p: &Packet, i: &ChunkInit) -> Result> { - let state = self.get_state(); - log::debug!("[{}] chunkInit received in state '{}'", self.name, state); - - // https://tools.ietf.org/html/rfc4960#section-5.2.1 - // Upon receipt of an INIT in the COOKIE-WAIT state, an endpoint MUST - // respond with an INIT ACK using the same parameters it sent in its - // original INIT chunk (including its Initiate Tag, unchanged). When - // responding, the endpoint MUST send the INIT ACK back to the same - // address that the original INIT (sent by this endpoint) was sent. - - if state != AssociationState::Closed - && state != AssociationState::CookieWait - && state != AssociationState::CookieEchoed - { - // 5.2.2. Unexpected INIT in States Other than CLOSED, COOKIE-ECHOED, - // COOKIE-WAIT, and SHUTDOWN-ACK-SENT - return Err(Error::ErrHandleInitState); - } - - // Should we be setting any of these permanently until we've ACKed further? - self.my_max_num_inbound_streams = - std::cmp::min(i.num_inbound_streams, self.my_max_num_inbound_streams); - self.my_max_num_outbound_streams = - std::cmp::min(i.num_outbound_streams, self.my_max_num_outbound_streams); - self.peer_verification_tag = i.initiate_tag; - self.source_port = p.destination_port; - self.destination_port = p.source_port; - - // 13.2 This is the last TSN received in sequence. This value - // is set initially by taking the peer's initial TSN, - // received in the INIT or INIT ACK chunk, and - // subtracting one from it. 
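// The "initial TSN minus one" rule quoted above is implemented below with an
// explicit branch for initial_tsn == 0. An equivalent formulation using
// wrapping arithmetic, as a minimal sketch (hypothetical helper, not part of
// this module):
fn last_tsn_from_initial(initial_tsn: u32) -> u32 {
    // u32 wrap-around yields u32::MAX for an initial TSN of 0, matching the
    // explicit if/else used by handle_init and handle_init_ack.
    initial_tsn.wrapping_sub(1)
}
// last_tsn_from_initial(0) == u32::MAX, last_tsn_from_initial(1234) == 1233.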
- self.peer_last_tsn = if i.initial_tsn == 0 { - u32::MAX - } else { - i.initial_tsn - 1 - }; - - for param in &i.params { - if let Some(v) = param.as_any().downcast_ref::() { - for t in &v.chunk_types { - if *t == CT_FORWARD_TSN { - log::debug!("[{}] use ForwardTSN (on init)", self.name); - self.use_forward_tsn = true; - } - } - } - } - if !self.use_forward_tsn { - log::warn!("[{}] not using ForwardTSN (on init)", self.name); - } - - let mut outbound = Packet { - verification_tag: self.peer_verification_tag, - source_port: self.source_port, - destination_port: self.destination_port, - ..Default::default() - }; - - // According to RFC https://datatracker.ietf.org/doc/html/rfc4960#section-3.2.2 - // We report unknown parameters with a paramtype with bit 14 set as unrecognized - let unrecognized_params_from_init = i - .params - .iter() - .filter_map(|param| { - if let ParamType::Unknown { param_type } = param.header().typ { - let needs_to_be_reported = ((param_type >> 14) & 0x01) == 1; - if needs_to_be_reported { - let wrapped: Box = - Box::new(ParamUnrecognized::wrap(param.clone())); - Some(wrapped) - } else { - None - } - } else { - None - } - }) - .collect(); - - let mut init_ack = ChunkInit { - is_ack: true, - initial_tsn: self.my_next_tsn, - num_outbound_streams: self.my_max_num_outbound_streams, - num_inbound_streams: self.my_max_num_inbound_streams, - initiate_tag: self.my_verification_tag, - advertised_receiver_window_credit: self.max_receive_buffer_size, - params: unrecognized_params_from_init, - }; - - if self.my_cookie.is_none() { - self.my_cookie = Some(ParamStateCookie::new()); - } - - if let Some(my_cookie) = &self.my_cookie { - init_ack.params = vec![Box::new(my_cookie.clone())]; - } - - init_ack.set_supported_extensions(); - - outbound.chunks = vec![Box::new(init_ack)]; - - Ok(vec![outbound]) - } - - async fn handle_init_ack(&mut self, p: &Packet, i: &ChunkInit) -> Result> { - let state = self.get_state(); - log::debug!("[{}] chunkInitAck received in state '{}'", self.name, state); - if state != AssociationState::CookieWait { - // RFC 4960 - // 5.2.3. Unexpected INIT ACK - // If an INIT ACK is received by an endpoint in any state other than the - // COOKIE-WAIT state, the endpoint should discard the INIT ACK chunk. - // An unexpected INIT ACK usually indicates the processing of an old or - // duplicated INIT chunk. - return Ok(vec![]); - } - - self.my_max_num_inbound_streams = - std::cmp::min(i.num_inbound_streams, self.my_max_num_inbound_streams); - self.my_max_num_outbound_streams = - std::cmp::min(i.num_outbound_streams, self.my_max_num_outbound_streams); - self.peer_verification_tag = i.initiate_tag; - self.peer_last_tsn = if i.initial_tsn == 0 { - u32::MAX - } else { - i.initial_tsn - 1 - }; - if self.source_port != p.destination_port || self.destination_port != p.source_port { - log::warn!("[{}] handle_init_ack: port mismatch", self.name); - return Ok(vec![]); - } - - self.rwnd = i.advertised_receiver_window_credit; - log::debug!("[{}] initial rwnd={}", self.name, self.rwnd); - - // RFC 4690 Sec 7.2.1 - // o The initial value of ssthresh MAY be arbitrarily high (for - // example, implementations MAY use the size of the receiver - // advertised window). 
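// RFC 4960 sec 7.2.1 pairs the ssthresh rule quoted above with an initial cwnd
// of min(4*MTU, max(2*MTU, 4380 bytes)). A sketch of both initial values,
// assuming an `mtu` in bytes and the a_rwnd carried by the received INIT ACK
// (illustration only; the crate initializes these elsewhere):
fn initial_congestion_state(mtu: u32, a_rwnd: u32) -> (u32, u32) {
    let cwnd = std::cmp::min(4 * mtu, std::cmp::max(2 * mtu, 4380));
    let ssthresh = a_rwnd; // "arbitrarily high", e.g. the advertised window
    (cwnd, ssthresh)
}
// For example, mtu = 1228 and a_rwnd = 1024 * 1024 gives cwnd = 4380 and
// ssthresh = 1_048_576; the assignment just below sets ssthresh from rwnd in
// the same spirit.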
- self.ssthresh = self.rwnd; - log::trace!( - "[{}] updated cwnd={} ssthresh={} inflight={} (INI)", - self.name, - self.cwnd, - self.ssthresh, - self.inflight_queue.get_num_bytes() - ); - - if let Some(t1init) = &self.t1init { - t1init.stop().await; - } - self.stored_init = None; - - let mut cookie_param = None; - for param in &i.params { - if let Some(v) = param.as_any().downcast_ref::() { - cookie_param = Some(v); - } else if let Some(v) = param.as_any().downcast_ref::() { - for t in &v.chunk_types { - if *t == CT_FORWARD_TSN { - log::debug!("[{}] use ForwardTSN (on initAck)", self.name); - self.use_forward_tsn = true; - } - } - } else if param - .as_any() - .downcast_ref::() - .is_some() - { - self.use_forward_tsn = true; - } - } - if !self.use_forward_tsn { - log::warn!("[{}] not using ForwardTSN (on initAck)", self.name); - } - - if let Some(v) = cookie_param { - self.stored_cookie_echo = Some(ChunkCookieEcho { - cookie: v.cookie.clone(), - }); - - self.send_cookie_echo()?; - - if let Some(t1cookie) = &self.t1cookie { - t1cookie.start(self.rto_mgr.get_rto()).await; - } - - self.set_state(AssociationState::CookieEchoed); - - Ok(vec![]) - } else { - Err(Error::ErrInitAckNoCookie) - } - } - - async fn handle_heartbeat(&self, c: &ChunkHeartbeat) -> Result> { - log::trace!("[{}] chunkHeartbeat", self.name); - if let Some(p) = c.params.first() { - if let Some(hbi) = p.as_any().downcast_ref::() { - return Ok(vec![Packet { - verification_tag: self.peer_verification_tag, - source_port: self.source_port, - destination_port: self.destination_port, - chunks: vec![Box::new(ChunkHeartbeatAck { - params: vec![Box::new(ParamHeartbeatInfo { - heartbeat_information: hbi.heartbeat_information.clone(), - })], - })], - }]); - } else { - log::warn!( - "[{}] failed to handle Heartbeat, no ParamHeartbeatInfo", - self.name, - ); - } - } - - Ok(vec![]) - } - - async fn handle_cookie_echo(&mut self, c: &ChunkCookieEcho) -> Result> { - let state = self.get_state(); - log::debug!("[{}] COOKIE-ECHO received in state '{}'", self.name, state); - - if let Some(my_cookie) = &self.my_cookie { - match state { - AssociationState::Established => { - if my_cookie.cookie != c.cookie { - return Ok(vec![]); - } - } - AssociationState::Closed - | AssociationState::CookieWait - | AssociationState::CookieEchoed => { - if my_cookie.cookie != c.cookie { - return Ok(vec![]); - } - - if let Some(t1init) = &self.t1init { - t1init.stop().await; - } - self.stored_init = None; - - if let Some(t1cookie) = &self.t1cookie { - t1cookie.stop().await; - } - self.stored_cookie_echo = None; - - self.set_state(AssociationState::Established); - if let Some(handshake_completed_ch) = &self.handshake_completed_ch_tx { - let _ = handshake_completed_ch.send(None).await; - } - } - _ => return Ok(vec![]), - }; - } else { - log::debug!("[{}] COOKIE-ECHO received before initialization", self.name); - return Ok(vec![]); - } - - Ok(vec![Packet { - verification_tag: self.peer_verification_tag, - source_port: self.source_port, - destination_port: self.destination_port, - chunks: vec![Box::new(ChunkCookieAck {})], - }]) - } - - async fn handle_cookie_ack(&mut self) -> Result> { - let state = self.get_state(); - log::debug!("[{}] COOKIE-ACK received in state '{}'", self.name, state); - if state != AssociationState::CookieEchoed { - // RFC 4960 - // 5.2.5. Handle Duplicate COOKIE-ACK. - // At any state other than COOKIE-ECHOED, an endpoint should silently - // discard a received COOKIE ACK chunk. 
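// The INIT, INIT ACK, COOKIE ECHO, and COOKIE ACK handlers in this file walk
// the association through the RFC 4960 four-way handshake. A compact sketch of
// the initiating side's transitions, with simplified event names chosen purely
// for illustration (not the chunk or state types used by this crate):
#[derive(Debug)]
enum HandshakeState {
    Closed,
    CookieWait,
    CookieEchoed,
    Established,
}

enum HandshakeEvent {
    SendInit,      // send INIT, start T1-init
    RecvInitAck,   // reply with COOKIE ECHO, start T1-cookie
    RecvCookieAck, // handshake complete
}

fn next_state(state: HandshakeState, ev: HandshakeEvent) -> HandshakeState {
    match (state, ev) {
        (HandshakeState::Closed, HandshakeEvent::SendInit) => HandshakeState::CookieWait,
        (HandshakeState::CookieWait, HandshakeEvent::RecvInitAck) => HandshakeState::CookieEchoed,
        (HandshakeState::CookieEchoed, HandshakeEvent::RecvCookieAck) => HandshakeState::Established,
        // Anything unexpected (e.g. a duplicate COOKIE ACK outside COOKIE-ECHOED)
        // is silently ignored and leaves the state unchanged.
        (s, _) => s,
    }
}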
- return Ok(vec![]); - } - - if let Some(t1cookie) = &self.t1cookie { - t1cookie.stop().await; - } - self.stored_cookie_echo = None; - - self.set_state(AssociationState::Established); - if let Some(handshake_completed_ch) = &self.handshake_completed_ch_tx { - let _ = handshake_completed_ch.send(None).await; - } - - Ok(vec![]) - } - - async fn handle_data(&mut self, d: &ChunkPayloadData) -> Result> { - log::trace!( - "[{}] DATA: tsn={} immediateSack={} len={}", - self.name, - d.tsn, - d.immediate_sack, - d.user_data.len() - ); - self.stats.inc_datas(); - - let can_push = self.payload_queue.can_push(d, self.peer_last_tsn); - let mut stream_handle_data = false; - if can_push { - if let Some(_s) = self.get_or_create_stream(d.stream_identifier) { - if self.get_my_receiver_window_credit().await > 0 { - // Pass the new chunk to stream level as soon as it arrives - self.payload_queue.push(d.clone(), self.peer_last_tsn); - stream_handle_data = true; - } else { - // Receive buffer is full - if let Some(last_tsn) = self.payload_queue.get_last_tsn_received() { - if sna32lt(d.tsn, *last_tsn) { - log::debug!("[{}] receive buffer full, but accepted as this is a missing chunk with tsn={} ssn={}", self.name, d.tsn, d.stream_sequence_number); - self.payload_queue.push(d.clone(), self.peer_last_tsn); - stream_handle_data = true; //s.handle_data(d.clone()); - } - } else { - log::debug!( - "[{}] receive buffer full. dropping DATA with tsn={} ssn={}", - self.name, - d.tsn, - d.stream_sequence_number - ); - } - } - } else { - // silently discard the data. (sender will retry on T3-rtx timeout) - // see pion/sctp#30 - log::debug!("discard {}", d.stream_sequence_number); - return Ok(vec![]); - } - } - - let immediate_sack = d.immediate_sack; - - if stream_handle_data { - if let Some(s) = self.streams.get_mut(&d.stream_identifier) { - s.handle_data(d.clone()).await; - } - } - - self.handle_peer_last_tsn_and_acknowledgement(immediate_sack) - } - - /// A common routine for handle_data and handle_forward_tsn routines - fn handle_peer_last_tsn_and_acknowledgement( - &mut self, - sack_immediately: bool, - ) -> Result> { - let mut reply = vec![]; - - // Try to advance peer_last_tsn - - // From RFC 3758 Sec 3.6: - // .. and then MUST further advance its cumulative TSN point locally - // if possible - // Meaning, if peer_last_tsn+1 points to a chunk that is received, - // advance peer_last_tsn until peer_last_tsn+1 points to unreceived chunk. 
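// TSNs are 32-bit serial numbers, so the sna32lt/sna32lte/sna32gt comparisons
// used throughout this file follow RFC 1982 serial number arithmetic rather
// than plain integer ordering. A minimal sketch of how such helpers are
// commonly written (assumed shapes; the crate's own utilities live elsewhere):
fn serial_lt_32(lhs: u32, rhs: u32) -> bool {
    // lhs precedes rhs iff the forward distance from lhs to rhs is in (0, 2^31).
    let diff = rhs.wrapping_sub(lhs);
    diff != 0 && diff < (1 << 31)
}

fn serial_gt_32(lhs: u32, rhs: u32) -> bool {
    serial_lt_32(rhs, lhs)
}
// With this definition u32::MAX precedes 1, so cumulative TSN advancement keeps
// working across the wrap-around point: serial_lt_32(u32::MAX, 1) == true while
// serial_lt_32(10, 5) == false.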
- log::debug!("[{}] peer_last_tsn = {}", self.name, self.peer_last_tsn); - while self.payload_queue.pop(self.peer_last_tsn + 1).is_some() { - self.peer_last_tsn += 1; - log::debug!("[{}] peer_last_tsn = {}", self.name, self.peer_last_tsn); - - let rst_reqs: Vec = - self.reconfig_requests.values().cloned().collect(); - for rst_req in rst_reqs { - self.reset_streams_if_any(&rst_req, false, &mut reply)?; - } - } - - let has_packet_loss = !self.payload_queue.is_empty(); - if has_packet_loss { - log::trace!( - "[{}] packetloss: {}", - self.name, - self.payload_queue - .get_gap_ack_blocks_string(self.peer_last_tsn) - ); - } - - if (self.ack_state != AckState::Immediate - && !sack_immediately - && !has_packet_loss - && self.ack_mode == AckMode::Normal) - || self.ack_mode == AckMode::AlwaysDelay - { - if self.ack_state == AckState::Idle { - self.delayed_ack_triggered = true; - } else { - self.immediate_ack_triggered = true; - } - } else { - self.immediate_ack_triggered = true; - } - - Ok(reply) - } - - pub(crate) async fn get_my_receiver_window_credit(&self) -> u32 { - let mut bytes_queued = 0; - for s in self.streams.values() { - bytes_queued += s.get_num_bytes_in_reassembly_queue().await as u32; - } - - if bytes_queued >= self.max_receive_buffer_size { - 0 - } else { - self.max_receive_buffer_size - bytes_queued - } - } - - pub(crate) fn open_stream( - &mut self, - stream_identifier: u16, - default_payload_type: PayloadProtocolIdentifier, - ) -> Result> { - if self.streams.contains_key(&stream_identifier) { - return Err(Error::ErrStreamAlreadyExist); - } - - if let Some(s) = self.create_stream(stream_identifier, false) { - s.set_default_payload_type(default_payload_type); - Ok(Arc::clone(&s)) - } else { - Err(Error::ErrStreamCreateFailed) - } - } - - /// create_stream creates a stream. The caller should hold the lock and check no stream exists for this id. - fn create_stream(&mut self, stream_identifier: u16, accept: bool) -> Option> { - let s = Arc::new(Stream::new( - format!("{}:{}", stream_identifier, self.name), - stream_identifier, - self.max_payload_size, - Arc::clone(&self.max_message_size), - Arc::clone(&self.state), - self.awake_write_loop_ch.clone(), - Arc::clone(&self.pending_queue), - )); - - if accept { - if let Some(accept_ch) = &self.accept_ch_tx { - if accept_ch.try_send(Arc::clone(&s)).is_ok() { - log::debug!( - "[{}] accepted a new stream (streamIdentifier: {})", - self.name, - stream_identifier - ); - } else { - log::debug!("[{}] dropped a new stream due to accept_ch full", self.name); - return None; - } - } else { - log::debug!( - "[{}] dropped a new stream due to accept_ch_tx is None", - self.name - ); - return None; - } - } - self.streams.insert(stream_identifier, Arc::clone(&s)); - Some(s) - } - - /// get_or_create_stream gets or creates a stream. The caller should hold the lock. 
- fn get_or_create_stream(&mut self, stream_identifier: u16) -> Option> { - if self.streams.contains_key(&stream_identifier) { - self.streams.get(&stream_identifier).cloned() - } else { - self.create_stream(stream_identifier, true) - } - } - - async fn process_selective_ack( - &mut self, - d: &ChunkSelectiveAck, - ) -> Result<(HashMap, u32)> { - let mut bytes_acked_per_stream = HashMap::new(); - - // New ack point, so pop all ACKed packets from inflight_queue - // We add 1 because the "currentAckPoint" has already been popped from the inflight queue - // For the first SACK we take care of this by setting the ackpoint to cumAck - 1 - let mut i = self.cumulative_tsn_ack_point + 1; - //log::debug!("[{}] i={} d={}", self.name, i, d.cumulative_tsn_ack); - while sna32lte(i, d.cumulative_tsn_ack) { - if let Some(c) = self.inflight_queue.pop(i) { - if !c.acked { - // RFC 4096 sec 6.3.2. Retransmission Timer Rules - // R3) Whenever a SACK is received that acknowledges the DATA chunk - // with the earliest outstanding TSN for that address, restart the - // T3-rtx timer for that address with its current RTO (if there is - // still outstanding data on that address). - if i == self.cumulative_tsn_ack_point + 1 { - // T3 timer needs to be reset. Stop it for now. - if let Some(t3rtx) = &self.t3rtx { - t3rtx.stop().await; - } - } - - let n_bytes_acked = c.user_data.len() as i64; - - // Sum the number of bytes acknowledged per stream - if let Some(amount) = bytes_acked_per_stream.get_mut(&c.stream_identifier) { - *amount += n_bytes_acked; - } else { - bytes_acked_per_stream.insert(c.stream_identifier, n_bytes_acked); - } - - // RFC 4960 sec 6.3.1. RTO Calculation - // C4) When data is in flight and when allowed by rule C5 below, a new - // RTT measurement MUST be made each round trip. Furthermore, new - // RTT measurements SHOULD be made no more than once per round trip - // for a given destination transport address. 
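// The RTT samples taken under the C4/C5 rules quoted here feed the RTO
// calculation of RFC 4960 sec 6.3.1, which this file delegates to rto_mgr. A
// sketch of that update in milliseconds, using the RFC's recommended constants
// (alpha = 1/8, beta = 1/4, RTO.Initial = 3 s, RTO.Min = 1 s, RTO.Max = 60 s):
struct RtoSketch {
    srtt: f64,
    rttvar: f64,
    rto: f64,
    has_measurement: bool,
}

impl RtoSketch {
    fn new() -> Self {
        RtoSketch { srtt: 0.0, rttvar: 0.0, rto: 3_000.0, has_measurement: false }
    }

    fn on_rtt(&mut self, r: f64) {
        const ALPHA: f64 = 1.0 / 8.0;
        const BETA: f64 = 1.0 / 4.0;
        if !self.has_measurement {
            // C2) first measurement: SRTT <- R, RTTVAR <- R/2
            self.srtt = r;
            self.rttvar = r / 2.0;
            self.has_measurement = true;
        } else {
            // C3) subsequent measurements
            self.rttvar = (1.0 - BETA) * self.rttvar + BETA * (self.srtt - r).abs();
            self.srtt = (1.0 - ALPHA) * self.srtt + ALPHA * r;
        }
        // RTO <- SRTT + 4 * RTTVAR, clamped to [RTO.Min, RTO.Max]
        self.rto = (self.srtt + 4.0 * self.rttvar).clamp(1_000.0, 60_000.0);
    }
}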
- // C5) Karn's algorithm: RTT measurements MUST NOT be made using - // packets that were retransmitted (and thus for which it is - // ambiguous whether the reply was for the first instance of the - // chunk or for a later instance) - if c.nsent == 1 && sna32gte(c.tsn, self.min_tsn2measure_rtt) { - self.min_tsn2measure_rtt = self.my_next_tsn; - let rtt = match SystemTime::now().duration_since(c.since) { - Ok(rtt) => rtt, - Err(_) => return Err(Error::ErrInvalidSystemTime), - }; - let srtt = self.rto_mgr.set_new_rtt(rtt.as_millis() as u64); - log::trace!( - "[{}] SACK: measured-rtt={} srtt={} new-rto={}", - self.name, - rtt.as_millis(), - srtt, - self.rto_mgr.get_rto() - ); - } - } - - if self.in_fast_recovery && c.tsn == self.fast_recover_exit_point { - log::debug!("[{}] exit fast-recovery", self.name); - self.in_fast_recovery = false; - } - } else { - return Err(Error::ErrInflightQueueTsnPop); - } - - i += 1; - } - - let mut htna = d.cumulative_tsn_ack; - - // Mark selectively acknowledged chunks as "acked" - for g in &d.gap_ack_blocks { - for i in g.start..=g.end { - let tsn = d.cumulative_tsn_ack + i as u32; - - let (is_existed, is_acked) = if let Some(c) = self.inflight_queue.get(tsn) { - (true, c.acked) - } else { - (false, false) - }; - let n_bytes_acked = if is_existed && !is_acked { - self.inflight_queue.mark_as_acked(tsn) as i64 - } else { - 0 - }; - - if let Some(c) = self.inflight_queue.get(tsn) { - if !is_acked { - // Sum the number of bytes acknowledged per stream - if let Some(amount) = bytes_acked_per_stream.get_mut(&c.stream_identifier) { - *amount += n_bytes_acked; - } else { - bytes_acked_per_stream.insert(c.stream_identifier, n_bytes_acked); - } - - log::trace!("[{}] tsn={} has been sacked", self.name, c.tsn); - - if c.nsent == 1 { - self.min_tsn2measure_rtt = self.my_next_tsn; - let rtt = match SystemTime::now().duration_since(c.since) { - Ok(rtt) => rtt, - Err(_) => return Err(Error::ErrInvalidSystemTime), - }; - let srtt = self.rto_mgr.set_new_rtt(rtt.as_millis() as u64); - log::trace!( - "[{}] SACK: measured-rtt={} srtt={} new-rto={}", - self.name, - rtt.as_millis(), - srtt, - self.rto_mgr.get_rto() - ); - } - - if sna32lt(htna, tsn) { - htna = tsn; - } - } - } else { - return Err(Error::ErrTsnRequestNotExist); - } - } - } - - Ok((bytes_acked_per_stream, htna)) - } - - async fn on_cumulative_tsn_ack_point_advanced(&mut self, total_bytes_acked: i64) { - // RFC 4096, sec 6.3.2. Retransmission Timer Rules - // R2) Whenever all outstanding data sent to an address have been - // acknowledged, turn off the T3-rtx timer of that address. - if self.inflight_queue.is_empty() { - log::trace!( - "[{}] SACK: no more packet in-flight (pending={})", - self.name, - self.pending_queue.len() - ); - if let Some(t3rtx) = &self.t3rtx { - t3rtx.stop().await; - } - } else { - log::trace!("[{}] T3-rtx timer start (pt2)", self.name); - if let Some(t3rtx) = &self.t3rtx { - t3rtx.start(self.rto_mgr.get_rto()).await; - } - } - - // Update congestion control parameters - if self.cwnd <= self.ssthresh { - // RFC 4096, sec 7.2.1. Slow-Start - // o When cwnd is less than or equal to ssthresh, an SCTP endpoint MUST - // use the slow-start algorithm to increase cwnd only if the current - // congestion window is being fully utilized, an incoming SACK - // advances the Cumulative TSN Ack Point, and the data sender is not - // in Fast Recovery. Only when these three conditions are met can - // the cwnd be increased; otherwise, the cwnd MUST not be increased. 
- // If these conditions are met, then cwnd MUST be increased by, at - // most, the lesser of 1) the total size of the previously - // outstanding DATA chunk(s) acknowledged, and 2) the destination's - // path MTU. - if !self.in_fast_recovery && self.pending_queue.len() > 0 { - self.cwnd += std::cmp::min(total_bytes_acked as u32, self.cwnd); // TCP way - // self.cwnd += min32(uint32(total_bytes_acked), self.mtu) // SCTP way (slow) - log::trace!( - "[{}] updated cwnd={} ssthresh={} acked={} (SS)", - self.name, - self.cwnd, - self.ssthresh, - total_bytes_acked - ); - } else { - log::trace!( - "[{}] cwnd did not grow: cwnd={} ssthresh={} acked={} FR={} pending={}", - self.name, - self.cwnd, - self.ssthresh, - total_bytes_acked, - self.in_fast_recovery, - self.pending_queue.len() - ); - } - } else { - // RFC 4096, sec 7.2.2. Congestion Avoidance - // o Whenever cwnd is greater than ssthresh, upon each SACK arrival - // that advances the Cumulative TSN Ack Point, increase - // partial_bytes_acked by the total number of bytes of all new chunks - // acknowledged in that SACK including chunks acknowledged by the new - // Cumulative TSN Ack and by Gap Ack Blocks. - self.partial_bytes_acked += total_bytes_acked as u32; - - // o When partial_bytes_acked is equal to or greater than cwnd and - // before the arrival of the SACK the sender had cwnd or more bytes - // of data outstanding (i.e., before arrival of the SACK, flight size - // was greater than or equal to cwnd), increase cwnd by MTU, and - // reset partial_bytes_acked to (partial_bytes_acked - cwnd). - if self.partial_bytes_acked >= self.cwnd && self.pending_queue.len() > 0 { - self.partial_bytes_acked -= self.cwnd; - self.cwnd += self.mtu; - log::trace!( - "[{}] updated cwnd={} ssthresh={} acked={} (CA)", - self.name, - self.cwnd, - self.ssthresh, - total_bytes_acked - ); - } - } - } - - fn process_fast_retransmission( - &mut self, - cum_tsn_ack_point: u32, - htna: u32, - cum_tsn_ack_point_advanced: bool, - ) -> Result<()> { - // HTNA algorithm - RFC 4960 Sec 7.2.4 - // Increment missIndicator of each chunks that the SACK reported missing - // when either of the following is met: - // a) Not in fast-recovery - // miss indications are incremented only for missing TSNs prior to the - // highest TSN newly acknowledged in the SACK. - // b) In fast-recovery AND the Cumulative TSN Ack Point advanced - // the miss indications are incremented for all TSNs reported missing - // in the SACK. - if !self.in_fast_recovery || cum_tsn_ack_point_advanced { - let max_tsn = if !self.in_fast_recovery { - // a) increment only for missing TSNs prior to the HTNA - htna - } else { - // b) increment for all TSNs reported missing - cum_tsn_ack_point + (self.inflight_queue.len() as u32) + 1 - }; - - let mut tsn = cum_tsn_ack_point + 1; - while sna32lt(tsn, max_tsn) { - if let Some(c) = self.inflight_queue.get_mut(tsn) { - if !c.acked && !c.abandoned() && c.miss_indicator < 3 { - c.miss_indicator += 1; - if c.miss_indicator == 3 && !self.in_fast_recovery { - // 2) If not in Fast Recovery, adjust the ssthresh and cwnd of the - // destination address(es) to which the missing DATA chunks were - // last sent, according to the formula described in Section 7.2.3. 
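// The "formula described in Section 7.2.3" referenced above is the same
// multiplicative-decrease step applied in two places in this file: on entering
// fast recovery (just below) and on T3-rtx expiry (in the retransmission
// timeout handler further down). A sketch of both cases over byte-valued
// cwnd/ssthresh/mtu:
fn on_enter_fast_recovery(cwnd: &mut u32, ssthresh: &mut u32, mtu: u32) {
    // ssthresh = max(cwnd/2, 4*MTU); cwnd = ssthresh
    *ssthresh = std::cmp::max(*cwnd / 2, 4 * mtu);
    *cwnd = *ssthresh;
}

fn on_t3_rtx_expiry(cwnd: &mut u32, ssthresh: &mut u32, mtu: u32) {
    // ssthresh = max(cwnd/2, 4*MTU); cwnd = 1*MTU, i.e. restart in slow start
    *ssthresh = std::cmp::max(*cwnd / 2, 4 * mtu);
    *cwnd = mtu;
}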
- self.in_fast_recovery = true; - self.fast_recover_exit_point = htna; - self.ssthresh = std::cmp::max(self.cwnd / 2, 4 * self.mtu); - self.cwnd = self.ssthresh; - self.partial_bytes_acked = 0; - self.will_retransmit_fast = true; - - log::trace!( - "[{}] updated cwnd={} ssthresh={} inflight={} (FR)", - self.name, - self.cwnd, - self.ssthresh, - self.inflight_queue.get_num_bytes() - ); - } - } - } else { - return Err(Error::ErrTsnRequestNotExist); - } - - tsn += 1; - } - } - - if self.in_fast_recovery && cum_tsn_ack_point_advanced { - self.will_retransmit_fast = true; - } - - Ok(()) - } - - async fn handle_sack(&mut self, d: &ChunkSelectiveAck) -> Result> { - log::trace!( - "[{}] {}, SACK: cumTSN={} a_rwnd={}", - self.name, - self.cumulative_tsn_ack_point, - d.cumulative_tsn_ack, - d.advertised_receiver_window_credit - ); - let state = self.get_state(); - if state != AssociationState::Established - && state != AssociationState::ShutdownPending - && state != AssociationState::ShutdownReceived - { - return Ok(vec![]); - } - - self.stats.inc_sacks(); - - if sna32gt(self.cumulative_tsn_ack_point, d.cumulative_tsn_ack) { - // RFC 4960 sec 6.2.1. Processing a Received SACK - // D) - // i) If Cumulative TSN Ack is less than the Cumulative TSN Ack - // Point, then drop the SACK. Since Cumulative TSN Ack is - // monotonically increasing, a SACK whose Cumulative TSN Ack is - // less than the Cumulative TSN Ack Point indicates an out-of- - // order SACK. - - log::debug!( - "[{}] SACK Cumulative ACK {} is older than ACK point {}", - self.name, - d.cumulative_tsn_ack, - self.cumulative_tsn_ack_point - ); - - return Ok(vec![]); - } - - // Process selective ack - let (bytes_acked_per_stream, htna) = self.process_selective_ack(d).await?; - - let mut total_bytes_acked = 0; - for n_bytes_acked in bytes_acked_per_stream.values() { - total_bytes_acked += *n_bytes_acked; - } - - let mut cum_tsn_ack_point_advanced = false; - if sna32lt(self.cumulative_tsn_ack_point, d.cumulative_tsn_ack) { - log::trace!( - "[{}] SACK: cumTSN advanced: {} -> {}", - self.name, - self.cumulative_tsn_ack_point, - d.cumulative_tsn_ack - ); - - self.cumulative_tsn_ack_point = d.cumulative_tsn_ack; - cum_tsn_ack_point_advanced = true; - self.on_cumulative_tsn_ack_point_advanced(total_bytes_acked) - .await; - } - - for (si, n_bytes_acked) in &bytes_acked_per_stream { - if let Some(s) = self.streams.get_mut(si) { - s.on_buffer_released(*n_bytes_acked).await; - } - } - - // New rwnd value - // RFC 4960 sec 6.2.1. Processing a Received SACK - // D) - // ii) Set rwnd equal to the newly received a_rwnd minus the number - // of bytes still outstanding after processing the Cumulative - // TSN Ack and the Gap Ack Blocks. 
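// The rwnd rule quoted here is implemented just below with an explicit zero
// floor. An equivalent one-liner using saturating arithmetic, shown only as an
// illustrative alternative:
fn rwnd_after_sack(a_rwnd: u32, bytes_outstanding: u32) -> u32 {
    // rwnd = a_rwnd - outstanding, floored at 0 when more bytes are in flight
    // than the peer currently advertises.
    a_rwnd.saturating_sub(bytes_outstanding)
}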
- - // bytes acked were already subtracted by markAsAcked() method - let bytes_outstanding = self.inflight_queue.get_num_bytes() as u32; - if bytes_outstanding >= d.advertised_receiver_window_credit { - self.rwnd = 0; - } else { - self.rwnd = d.advertised_receiver_window_credit - bytes_outstanding; - } - - self.process_fast_retransmission(d.cumulative_tsn_ack, htna, cum_tsn_ack_point_advanced)?; - - if self.use_forward_tsn { - // RFC 3758 Sec 3.5 C1 - if sna32lt( - self.advanced_peer_tsn_ack_point, - self.cumulative_tsn_ack_point, - ) { - self.advanced_peer_tsn_ack_point = self.cumulative_tsn_ack_point - } - - // RFC 3758 Sec 3.5 C2 - let mut i = self.advanced_peer_tsn_ack_point + 1; - while let Some(c) = self.inflight_queue.get(i) { - if !c.abandoned() { - break; - } - self.advanced_peer_tsn_ack_point = i; - i += 1; - } - - // RFC 3758 Sec 3.5 C3 - if sna32gt( - self.advanced_peer_tsn_ack_point, - self.cumulative_tsn_ack_point, - ) { - self.will_send_forward_tsn = true; - log::debug!( - "[{}] handleSack {}: sna32GT({}, {})", - self.name, - self.will_send_forward_tsn, - self.advanced_peer_tsn_ack_point, - self.cumulative_tsn_ack_point - ); - } - self.awake_write_loop(); - } - - self.postprocess_sack(state, cum_tsn_ack_point_advanced) - .await; - - Ok(vec![]) - } - - /// The caller must hold the lock. This method was only added because the - /// linter was complaining about the "cognitive complexity" of handle_sack. - async fn postprocess_sack( - &mut self, - state: AssociationState, - mut should_awake_write_loop: bool, - ) { - if !self.inflight_queue.is_empty() { - // Start timer. (noop if already started) - log::trace!("[{}] T3-rtx timer start (pt3)", self.name); - if let Some(t3rtx) = &self.t3rtx { - t3rtx.start(self.rto_mgr.get_rto()).await; - } - } else if state == AssociationState::ShutdownPending { - // No more outstanding, send shutdown. - should_awake_write_loop = true; - self.will_send_shutdown.store(true, Ordering::SeqCst); - self.set_state(AssociationState::ShutdownSent); - } else if state == AssociationState::ShutdownReceived { - // No more outstanding, send shutdown ack. - should_awake_write_loop = true; - self.will_send_shutdown_ack = true; - self.set_state(AssociationState::ShutdownAckSent); - } - - if should_awake_write_loop { - self.awake_write_loop(); - } - } - - async fn handle_shutdown(&mut self, _: &ChunkShutdown) -> Result> { - let state = self.get_state(); - - if state == AssociationState::Established { - if !self.inflight_queue.is_empty() { - self.set_state(AssociationState::ShutdownReceived); - } else { - // No more outstanding, send shutdown ack. 
- self.will_send_shutdown_ack = true; - self.set_state(AssociationState::ShutdownAckSent); - - self.awake_write_loop(); - } - } else if state == AssociationState::ShutdownSent { - // self.cumulative_tsn_ack_point = c.cumulative_tsn_ack - - self.will_send_shutdown_ack = true; - self.set_state(AssociationState::ShutdownAckSent); - - self.awake_write_loop(); - } - - Ok(vec![]) - } - - async fn handle_shutdown_ack(&mut self, _: &ChunkShutdownAck) -> Result> { - let state = self.get_state(); - if state == AssociationState::ShutdownSent || state == AssociationState::ShutdownAckSent { - if let Some(t2shutdown) = &self.t2shutdown { - t2shutdown.stop().await; - } - self.will_send_shutdown_complete = true; - - self.awake_write_loop(); - } - - Ok(vec![]) - } - - async fn handle_shutdown_complete(&mut self, _: &ChunkShutdownComplete) -> Result> { - let state = self.get_state(); - if state == AssociationState::ShutdownAckSent { - if let Some(t2shutdown) = &self.t2shutdown { - t2shutdown.stop().await; - } - self.close().await?; - } - - Ok(vec![]) - } - - /// create_forward_tsn generates ForwardTSN chunk. - /// This method will be be called if use_forward_tsn is set to false. - fn create_forward_tsn(&self) -> ChunkForwardTsn { - // RFC 3758 Sec 3.5 C4 - let mut stream_map: HashMap = HashMap::new(); // to report only once per SI - let mut i = self.cumulative_tsn_ack_point + 1; - while sna32lte(i, self.advanced_peer_tsn_ack_point) { - if let Some(c) = self.inflight_queue.get(i) { - if let Some(ssn) = stream_map.get(&c.stream_identifier) { - if sna16lt(*ssn, c.stream_sequence_number) { - // to report only once with greatest SSN - stream_map.insert(c.stream_identifier, c.stream_sequence_number); - } - } else { - stream_map.insert(c.stream_identifier, c.stream_sequence_number); - } - } else { - break; - } - - i += 1; - } - - let mut fwd_tsn = ChunkForwardTsn { - new_cumulative_tsn: self.advanced_peer_tsn_ack_point, - streams: vec![], - }; - - let mut stream_str = String::new(); - for (si, ssn) in &stream_map { - stream_str += format!("(si={si} ssn={ssn})").as_str(); - fwd_tsn.streams.push(ChunkForwardTsnStream { - identifier: *si, - sequence: *ssn, - }); - } - log::trace!( - "[{}] building fwd_tsn: newCumulativeTSN={} cumTSN={} - {}", - self.name, - fwd_tsn.new_cumulative_tsn, - self.cumulative_tsn_ack_point, - stream_str - ); - - fwd_tsn - } - - /// create_packet wraps chunks in a packet. - /// The caller should hold the read lock. 
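// An SCTP packet on the wire is a 12-byte common header (source port,
// destination port, verification tag, CRC32c checksum) followed by one or more
// chunks; the COMMON_HEADER_SIZE constant used above accounts for that fixed
// header, and the Packet assembled by create_packet below carries the same
// fields until it is marshalled elsewhere in the crate. A sketch of the header
// layout in network byte order, with the checksum left for a later CRC32c pass
// (illustration only, not the crate's marshalling code):
fn write_common_header(buf: &mut Vec<u8>, src_port: u16, dst_port: u16, verification_tag: u32) {
    buf.extend_from_slice(&src_port.to_be_bytes());
    buf.extend_from_slice(&dst_port.to_be_bytes());
    buf.extend_from_slice(&verification_tag.to_be_bytes());
    buf.extend_from_slice(&0u32.to_be_bytes()); // checksum placeholder
}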
- pub(crate) fn create_packet(&self, chunks: Vec>) -> Packet { - Packet { - verification_tag: self.peer_verification_tag, - source_port: self.source_port, - destination_port: self.destination_port, - chunks, - } - } - - async fn handle_reconfig(&mut self, c: &ChunkReconfig) -> Result> { - log::trace!("[{}] handle_reconfig", self.name); - - let mut pp = vec![]; - - if let Some(param_a) = &c.param_a { - self.handle_reconfig_param(param_a, &mut pp).await?; - } - - if let Some(param_b) = &c.param_b { - self.handle_reconfig_param(param_b, &mut pp).await?; - } - - Ok(pp) - } - - async fn handle_forward_tsn(&mut self, c: &ChunkForwardTsn) -> Result> { - log::trace!("[{}] FwdTSN: {}", self.name, c.to_string()); - - if !self.use_forward_tsn { - log::warn!("[{}] received FwdTSN but not enabled", self.name); - // Return an error chunk - let cerr = ChunkError { - error_causes: vec![ErrorCauseUnrecognizedChunkType::default()], - }; - - let outbound = Packet { - verification_tag: self.peer_verification_tag, - source_port: self.source_port, - destination_port: self.destination_port, - chunks: vec![Box::new(cerr)], - }; - return Ok(vec![outbound]); - } - - // From RFC 3758 Sec 3.6: - // Note, if the "New Cumulative TSN" value carried in the arrived - // FORWARD TSN chunk is found to be behind or at the current cumulative - // TSN point, the data receiver MUST treat this FORWARD TSN as out-of- - // date and MUST NOT update its Cumulative TSN. The receiver SHOULD - // send a SACK to its peer (the sender of the FORWARD TSN) since such a - // duplicate may indicate the previous SACK was lost in the network. - - log::trace!( - "[{}] should send ack? newCumTSN={} peer_last_tsn={}", - self.name, - c.new_cumulative_tsn, - self.peer_last_tsn - ); - if sna32lte(c.new_cumulative_tsn, self.peer_last_tsn) { - log::trace!("[{}] sending ack on Forward TSN", self.name); - self.ack_state = AckState::Immediate; - if let Some(ack_timer) = &mut self.ack_timer { - ack_timer.stop(); - } - self.awake_write_loop(); - return Ok(vec![]); - } - - // From RFC 3758 Sec 3.6: - // the receiver MUST perform the same TSN handling, including duplicate - // detection, gap detection, SACK generation, cumulative TSN - // advancement, etc. as defined in RFC 2960 [2]---with the following - // exceptions and additions. - - // When a FORWARD TSN chunk arrives, the data receiver MUST first update - // its cumulative TSN point to the value carried in the FORWARD TSN - // chunk, - - // Advance peer_last_tsn - while sna32lt(self.peer_last_tsn, c.new_cumulative_tsn) { - self.payload_queue.pop(self.peer_last_tsn + 1); // may not exist - self.peer_last_tsn += 1; - } - - // Report new peer_last_tsn value and abandoned largest SSN value to - // corresponding streams so that the abandoned chunks can be removed - // from the reassemblyQueue. - for forwarded in &c.streams { - if let Some(s) = self.streams.get_mut(&forwarded.identifier) { - s.handle_forward_tsn_for_ordered(forwarded.sequence).await; - } - } - - // TSN may be forewared for unordered chunks. ForwardTSN chunk does not - // report which stream identifier it skipped for unordered chunks. - // Therefore, we need to broadcast this event to all existing streams for - // unordered chunks. 
- // See https://github.com/pion/sctp/issues/106 - for s in self.streams.values_mut() { - s.handle_forward_tsn_for_unordered(c.new_cumulative_tsn) - .await; - } - - self.handle_peer_last_tsn_and_acknowledgement(false) - } - - async fn send_reset_request(&mut self, stream_identifier: u16) -> Result<()> { - let state = self.get_state(); - if state != AssociationState::Established { - return Err(Error::ErrResetPacketInStateNotExist); - } - - // Create DATA chunk which only contains valid stream identifier with - // nil userData and use it as a EOS from the stream. - let c = ChunkPayloadData { - stream_identifier, - beginning_fragment: true, - ending_fragment: true, - user_data: Bytes::new(), - ..Default::default() - }; - - self.pending_queue.push(c).await; - self.awake_write_loop(); - - Ok(()) - } - - #[allow(clippy::borrowed_box)] - async fn handle_reconfig_param( - &mut self, - raw: &Box, - reply: &mut Vec, - ) -> Result<()> { - if let Some(p) = raw.as_any().downcast_ref::() { - self.reconfig_requests - .insert(p.reconfig_request_sequence_number, p.clone()); - self.reset_streams_if_any(p, true, reply)?; - Ok(()) - } else if let Some(p) = raw.as_any().downcast_ref::() { - self.reconfigs.remove(&p.reconfig_response_sequence_number); - if self.reconfigs.is_empty() { - if let Some(treconfig) = &self.treconfig { - treconfig.stop().await; - } - } - Ok(()) - } else { - Err(Error::ErrParameterType) - } - } - - fn reset_streams_if_any( - &mut self, - p: &ParamOutgoingResetRequest, - respond: bool, - reply: &mut Vec, - ) -> Result<()> { - let mut result = ReconfigResult::SuccessPerformed; - let mut sis_to_reset = vec![]; - - if sna32lte(p.sender_last_tsn, self.peer_last_tsn) { - log::debug!( - "[{}] resetStream(): senderLastTSN={} <= peer_last_tsn={}", - self.name, - p.sender_last_tsn, - self.peer_last_tsn - ); - for id in &p.stream_identifiers { - if let Some(s) = self.streams.get(id) { - let stream_identifier = s.stream_identifier; - if respond { - sis_to_reset.push(*id); - } - self.unregister_stream(stream_identifier); - } - } - self.reconfig_requests - .remove(&p.reconfig_request_sequence_number); - } else { - log::debug!( - "[{}] resetStream(): senderLastTSN={} > peer_last_tsn={}", - self.name, - p.sender_last_tsn, - self.peer_last_tsn - ); - result = ReconfigResult::InProgress; - } - - // Answer incoming reset requests with the same reset request, but with - // reconfig_response_sequence_number. - if !sis_to_reset.is_empty() { - let rsn = self.generate_next_rsn(); - let tsn = self.my_next_tsn - 1; - - let c = ChunkReconfig { - param_a: Some(Box::new(ParamOutgoingResetRequest { - reconfig_request_sequence_number: rsn, - reconfig_response_sequence_number: p.reconfig_request_sequence_number, - sender_last_tsn: tsn, - stream_identifiers: sis_to_reset, - })), - ..Default::default() - }; - - self.reconfigs.insert(rsn, c.clone()); // store in the map for retransmission - - let p = self.create_packet(vec![Box::new(c)]); - reply.push(p); - } - - let packet = self.create_packet(vec![Box::new(ChunkReconfig { - param_a: Some(Box::new(ParamReconfigResponse { - reconfig_response_sequence_number: p.reconfig_request_sequence_number, - result, - })), - param_b: None, - })]); - - log::debug!("[{}] RESET RESPONSE: {}", self.name, packet); - - reply.push(packet); - - Ok(()) - } - - /// Move the chunk peeked with self.pending_queue.peek() to the inflight_queue. 
- async fn move_pending_data_chunk_to_inflight_queue( - &mut self, - beginning_fragment: bool, - unordered: bool, - ) -> Option { - if let Some(mut c) = self.pending_queue.pop(beginning_fragment, unordered) { - // Mark all fragments are in-flight now - if c.ending_fragment { - c.set_all_inflight(); - } - - // Assign TSN - c.tsn = self.generate_next_tsn(); - - c.since = SystemTime::now(); // use to calculate RTT and also for maxPacketLifeTime - c.nsent = 1; // being sent for the first time - - self.check_partial_reliability_status(&c); - - log::trace!( - "[{}] sending ppi={} tsn={} ssn={} sent={} len={} ({},{})", - self.name, - c.payload_type as u32, - c.tsn, - c.stream_sequence_number, - c.nsent, - c.user_data.len(), - c.beginning_fragment, - c.ending_fragment - ); - - self.inflight_queue.push_no_check(c.clone()); - - Some(c) - } else { - log::error!("[{}] failed to pop from pending queue", self.name); - None - } - } - - /// pop_pending_data_chunks_to_send pops chunks from the pending queues as many as - /// the cwnd and rwnd allows to send. - async fn pop_pending_data_chunks_to_send(&mut self) -> (Vec, Vec) { - let mut chunks = vec![]; - let mut sis_to_reset = vec![]; // stream identifiers to reset - - if self.pending_queue.len() == 0 { - return (chunks, sis_to_reset); - } - - // RFC 4960 sec 6.1. Transmission of DATA Chunks - // A) At any given time, the data sender MUST NOT transmit new data to - // any destination transport address if its peer's rwnd indicates - // that the peer has no buffer space (i.e., rwnd is 0; see Section - // 6.2.1). However, regardless of the value of rwnd (including if it - // is 0), the data sender can always have one DATA chunk in flight to - // the receiver if allowed by cwnd (see rule B, below). - while let Some(c) = self.pending_queue.peek() { - let (beginning_fragment, unordered, data_len, stream_identifier) = ( - c.beginning_fragment, - c.unordered, - c.user_data.len(), - c.stream_identifier, - ); - - if data_len == 0 { - sis_to_reset.push(stream_identifier); - if self - .pending_queue - .pop(beginning_fragment, unordered) - .is_none() - { - log::error!("failed to pop from pending queue"); - } - continue; - } - - if self.inflight_queue.get_num_bytes() + data_len > self.cwnd as usize { - break; // would exceed cwnd - } - - if data_len > self.rwnd as usize { - break; // no more rwnd - } - - self.rwnd -= data_len as u32; - - if let Some(chunk) = self - .move_pending_data_chunk_to_inflight_queue(beginning_fragment, unordered) - .await - { - chunks.push(chunk); - } - } - - // the data sender can always have one DATA chunk in flight to the receiver - if chunks.is_empty() && self.inflight_queue.is_empty() { - // Send zero window probe - if let Some(c) = self.pending_queue.peek() { - let (beginning_fragment, unordered) = (c.beginning_fragment, c.unordered); - - if let Some(chunk) = self - .move_pending_data_chunk_to_inflight_queue(beginning_fragment, unordered) - .await - { - chunks.push(chunk); - } - } - } - - (chunks, sis_to_reset) - } - - /// bundle_data_chunks_into_packets packs DATA chunks into packets. It tries to bundle - /// DATA chunks into a packet so long as the resulting packet size does not exceed - /// the path MTU. - fn bundle_data_chunks_into_packets(&self, chunks: Vec) -> Vec { - let mut packets = vec![]; - let mut chunks_to_send = vec![]; - let mut bytes_in_packet = COMMON_HEADER_SIZE; - - for c in chunks { - // RFC 4960 sec 6.1. 
Transmission of DATA Chunks - // Multiple DATA chunks committed for transmission MAY be bundled in a - // single packet. Furthermore, DATA chunks being retransmitted MAY be - // bundled with new DATA chunks, as long as the resulting packet size - // does not exceed the path MTU. - if bytes_in_packet + c.user_data.len() as u32 > self.mtu { - packets.push(self.create_packet(chunks_to_send)); - chunks_to_send = vec![]; - bytes_in_packet = COMMON_HEADER_SIZE; - } - - bytes_in_packet += DATA_CHUNK_HEADER_SIZE + c.user_data.len() as u32; - chunks_to_send.push(Box::new(c)); - } - - if !chunks_to_send.is_empty() { - packets.push(self.create_packet(chunks_to_send)); - } - - packets - } - - fn check_partial_reliability_status(&self, c: &ChunkPayloadData) { - if !self.use_forward_tsn { - return; - } - - // draft-ietf-rtcweb-data-protocol-09.txt section 6 - // 6. Procedures - // All Data Channel Establishment Protocol messages MUST be sent using - // ordered delivery and reliable transmission. - // - if c.payload_type == PayloadProtocolIdentifier::Dcep { - return; - } - - // PR-SCTP - if let Some(s) = self.streams.get(&c.stream_identifier) { - let reliability_type: ReliabilityType = - s.reliability_type.load(Ordering::SeqCst).into(); - let reliability_value = s.reliability_value.load(Ordering::SeqCst); - - if reliability_type == ReliabilityType::Rexmit { - if c.nsent >= reliability_value { - c.set_abandoned(true); - log::trace!( - "[{}] marked as abandoned: tsn={} ppi={} (remix: {})", - self.name, - c.tsn, - c.payload_type, - c.nsent - ); - } - } else if reliability_type == ReliabilityType::Timed { - if let Ok(elapsed) = SystemTime::now().duration_since(c.since) { - if elapsed.as_millis() as u32 >= reliability_value { - c.set_abandoned(true); - log::trace!( - "[{}] marked as abandoned: tsn={} ppi={} (timed: {:?})", - self.name, - c.tsn, - c.payload_type, - elapsed - ); - } - } - } - } else { - log::error!("[{}] stream {} not found)", self.name, c.stream_identifier); - } - } - - /// get_data_packets_to_retransmit is called when T3-rtx is timed out and retransmit outstanding data chunks - /// that are not acked or abandoned yet. - fn get_data_packets_to_retransmit(&mut self) -> Vec { - let awnd = std::cmp::min(self.cwnd, self.rwnd); - let mut chunks = vec![]; - let mut bytes_to_send = 0; - let mut done = false; - let mut i = 0; - while !done { - let tsn = self.cumulative_tsn_ack_point + i + 1; - if let Some(c) = self.inflight_queue.get_mut(tsn) { - if !c.retransmit { - i += 1; - continue; - } - - if i == 0 && self.rwnd < c.user_data.len() as u32 { - // Send it as a zero window probe - done = true; - } else if bytes_to_send + c.user_data.len() > awnd as usize { - break; - } - - // reset the retransmit flag not to retransmit again before the next - // t3-rtx timer fires - c.retransmit = false; - bytes_to_send += c.user_data.len(); - - c.nsent += 1; - } else { - break; // end of pending data - } - - if let Some(c) = self.inflight_queue.get(tsn) { - self.check_partial_reliability_status(c); - - log::trace!( - "[{}] retransmitting tsn={} ssn={} sent={}", - self.name, - c.tsn, - c.stream_sequence_number, - c.nsent - ); - - chunks.push(c.clone()); - } - i += 1; - } - - self.bundle_data_chunks_into_packets(chunks) - } - - /// generate_next_tsn returns the my_next_tsn and increases it. The caller should hold the lock. - fn generate_next_tsn(&mut self) -> u32 { - let tsn = self.my_next_tsn; - self.my_next_tsn += 1; - tsn - } - - /// generate_next_rsn returns the my_next_rsn and increases it. 
The caller should hold the lock. - fn generate_next_rsn(&mut self) -> u32 { - let rsn = self.my_next_rsn; - self.my_next_rsn += 1; - rsn - } - - async fn create_selective_ack_chunk(&mut self) -> ChunkSelectiveAck { - ChunkSelectiveAck { - cumulative_tsn_ack: self.peer_last_tsn, - advertised_receiver_window_credit: self.get_my_receiver_window_credit().await, - gap_ack_blocks: self.payload_queue.get_gap_ack_blocks(self.peer_last_tsn), - duplicate_tsn: self.payload_queue.pop_duplicates(), - } - } - - fn pack(p: Packet) -> Vec { - vec![p] - } - - fn handle_chunk_start(&mut self) { - self.delayed_ack_triggered = false; - self.immediate_ack_triggered = false; - } - - fn handle_chunk_end(&mut self) { - if self.immediate_ack_triggered { - self.ack_state = AckState::Immediate; - if let Some(ack_timer) = &mut self.ack_timer { - ack_timer.stop(); - } - self.awake_write_loop(); - } else if self.delayed_ack_triggered { - // Will send delayed ack in the next ack timeout - self.ack_state = AckState::Delay; - if let Some(ack_timer) = &mut self.ack_timer { - ack_timer.start(); - } - } - } - - #[allow(clippy::borrowed_box)] - async fn handle_chunk( - &mut self, - p: &Packet, - chunk: &Box, - ) -> Result<()> { - chunk.check()?; - let chunk_any = chunk.as_any(); - let packets = if let Some(c) = chunk_any.downcast_ref::() { - if c.is_ack { - self.handle_init_ack(p, c).await? - } else { - self.handle_init(p, c).await? - } - } else if chunk_any.downcast_ref::().is_some() - || chunk_any.downcast_ref::().is_some() - { - return Err(Error::ErrChunk); - } else if let Some(c) = chunk_any.downcast_ref::() { - self.handle_heartbeat(c).await? - } else if let Some(c) = chunk_any.downcast_ref::() { - self.handle_cookie_echo(c).await? - } else if chunk_any.downcast_ref::().is_some() { - self.handle_cookie_ack().await? - } else if let Some(c) = chunk_any.downcast_ref::() { - self.handle_data(c).await? - } else if let Some(c) = chunk_any.downcast_ref::() { - self.handle_sack(c).await? - } else if let Some(c) = chunk_any.downcast_ref::() { - self.handle_reconfig(c).await? - } else if let Some(c) = chunk_any.downcast_ref::() { - self.handle_forward_tsn(c).await? - } else if let Some(c) = chunk_any.downcast_ref::() { - self.handle_shutdown(c).await? - } else if let Some(c) = chunk_any.downcast_ref::() { - self.handle_shutdown_ack(c).await? - } else if let Some(c) = chunk_any.downcast_ref::() { - self.handle_shutdown_complete(c).await? - } else { - /* - https://datatracker.ietf.org/doc/html/rfc4960#section-3 - - 00 - Stop processing this SCTP packet and discard it, do not - process any further chunks within it. - - 01 - Stop processing this SCTP packet and discard it, do not - process any further chunks within it, and report the - unrecognized chunk in an 'Unrecognized Chunk Type'. - - 10 - Skip this chunk and continue processing. - - 11 - Skip this chunk and continue processing, but report in an - ERROR chunk using the 'Unrecognized Chunk Type' cause of - error. 
- */ - let handle_code = chunk.header().typ.0 >> 6; - match handle_code { - 0b00 => { - // Stop processing this packet - return Err(Error::ErrChunkTypeUnhandled); - } - 0b01 => { - // stop processing but report the chunk as unrecognized - let err_chunk = ChunkError { - error_causes: vec![ErrorCause { - code: UNRECOGNIZED_CHUNK_TYPE, - raw: chunk.marshal()?, - }], - }; - let packet = Packet { - verification_tag: self.peer_verification_tag, - source_port: self.source_port, - destination_port: self.destination_port, - chunks: vec![Box::new(err_chunk)], - }; - self.control_queue.push_back(packet); - self.awake_write_loop(); - return Err(Error::ErrChunkTypeUnhandled); - } - 0b10 => { - // just ignore - vec![] - } - 0b11 => { - // keep processing but report the chunk as unrecognized - let err_chunk = ChunkError { - error_causes: vec![ErrorCause { - code: UNRECOGNIZED_CHUNK_TYPE, - raw: chunk.marshal()?, - }], - }; - let packet = Packet { - verification_tag: self.peer_verification_tag, - source_port: self.source_port, - destination_port: self.destination_port, - chunks: vec![Box::new(err_chunk)], - }; - vec![packet] - } - _ => unreachable!("This can only have 4 values."), - } - }; - - if !packets.is_empty() { - let mut buf: VecDeque<_> = packets.into_iter().collect(); - self.control_queue.append(&mut buf); - self.awake_write_loop(); - } - - Ok(()) - } - - /// buffered_amount returns total amount (in bytes) of currently buffered user data. - /// This is used only by testing. - pub(crate) fn buffered_amount(&self) -> usize { - self.pending_queue.get_num_bytes() + self.inflight_queue.get_num_bytes() - } -} - -#[async_trait] -impl AckTimerObserver for AssociationInternal { - async fn on_ack_timeout(&mut self) { - log::trace!( - "[{}] ack timed out (ack_state: {})", - self.name, - self.ack_state - ); - self.stats.inc_ack_timeouts(); - self.ack_state = AckState::Immediate; - self.awake_write_loop(); - } -} - -#[async_trait] -impl RtxTimerObserver for AssociationInternal { - async fn on_retransmission_timeout(&mut self, id: RtxTimerId, n_rtos: usize) { - match id { - RtxTimerId::T1Init => { - if let Err(err) = self.send_init() { - log::debug!( - "[{}] failed to retransmit init (n_rtos={}): {:?}", - self.name, - n_rtos, - err - ); - } - } - - RtxTimerId::T1Cookie => { - if let Err(err) = self.send_cookie_echo() { - log::debug!( - "[{}] failed to retransmit cookie-echo (n_rtos={}): {:?}", - self.name, - n_rtos, - err - ); - } - } - - RtxTimerId::T2Shutdown => { - log::debug!( - "[{}] retransmission of shutdown timeout (n_rtos={})", - self.name, - n_rtos - ); - let state = self.get_state(); - match state { - AssociationState::ShutdownSent => { - self.will_send_shutdown.store(true, Ordering::SeqCst); - self.awake_write_loop(); - } - AssociationState::ShutdownAckSent => { - self.will_send_shutdown_ack = true; - self.awake_write_loop(); - } - _ => {} - } - } - - RtxTimerId::T3RTX => { - self.stats.inc_t3timeouts(); - - // RFC 4960 sec 6.3.3 - // E1) For the destination address for which the timer expires, adjust - // its ssthresh with rules defined in Section 7.2.3 and set the - // cwnd <- MTU. 
- // RFC 4960 sec 7.2.3 - // When the T3-rtx timer expires on an address, SCTP should perform slow - // start by: - // ssthresh = max(cwnd/2, 4*MTU) - // cwnd = 1*MTU - - self.ssthresh = std::cmp::max(self.cwnd / 2, 4 * self.mtu); - self.cwnd = self.mtu; - log::trace!( - "[{}] updated cwnd={} ssthresh={} inflight={} (RTO)", - self.name, - self.cwnd, - self.ssthresh, - self.inflight_queue.get_num_bytes() - ); - - // RFC 3758 sec 3.5 - // A5) Any time the T3-rtx timer expires, on any destination, the sender - // SHOULD try to advance the "Advanced.Peer.Ack.Point" by following - // the procedures outlined in C2 - C5. - if self.use_forward_tsn { - // RFC 3758 Sec 3.5 C2 - let mut i = self.advanced_peer_tsn_ack_point + 1; - while let Some(c) = self.inflight_queue.get(i) { - if !c.abandoned() { - break; - } - self.advanced_peer_tsn_ack_point = i; - i += 1; - } - - // RFC 3758 Sec 3.5 C3 - if sna32gt( - self.advanced_peer_tsn_ack_point, - self.cumulative_tsn_ack_point, - ) { - self.will_send_forward_tsn = true; - log::debug!( - "[{}] on_retransmission_timeout {}: sna32GT({}, {})", - self.name, - self.will_send_forward_tsn, - self.advanced_peer_tsn_ack_point, - self.cumulative_tsn_ack_point - ); - } - } - - log::debug!( - "[{}] T3-rtx timed out: n_rtos={} cwnd={} ssthresh={}", - self.name, - n_rtos, - self.cwnd, - self.ssthresh - ); - - self.inflight_queue.mark_all_to_retrasmit(); - self.awake_write_loop(); - } - - RtxTimerId::Reconfig => { - self.will_retransmit_reconfig = true; - self.awake_write_loop(); - } - } - } - - async fn on_retransmission_failure(&mut self, id: RtxTimerId) { - match id { - RtxTimerId::T1Init => { - log::error!("[{}] retransmission failure: T1-init", self.name); - if let Some(handshake_completed_ch) = &self.handshake_completed_ch_tx { - let _ = handshake_completed_ch - .send(Some(Error::ErrHandshakeInitAck)) - .await; - } - } - RtxTimerId::T1Cookie => { - log::error!("[{}] retransmission failure: T1-cookie", self.name); - if let Some(handshake_completed_ch) = &self.handshake_completed_ch_tx { - let _ = handshake_completed_ch - .send(Some(Error::ErrHandshakeCookieEcho)) - .await; - } - } - - RtxTimerId::T2Shutdown => { - log::error!("[{}] retransmission failure: T2-shutdown", self.name); - } - - RtxTimerId::T3RTX => { - // T3-rtx timer will not fail by design - // Justifications: - // * ICE would fail if the connectivity is lost - // * WebRTC spec is not clear how this incident should be reported to ULP - log::error!("[{}] retransmission failure: T3-rtx (DATA)", self.name); - } - _ => {} - } - } -} diff --git a/sctp/src/association/association_internal/association_internal_test.rs b/sctp/src/association/association_internal/association_internal_test.rs deleted file mode 100644 index 2a830aef9..000000000 --- a/sctp/src/association/association_internal/association_internal_test.rs +++ /dev/null @@ -1,544 +0,0 @@ -use std::io; -use std::net::SocketAddr; - -use super::*; - -type Result = std::result::Result; - -impl From for util::Error { - fn from(e: Error) -> Self { - util::Error::from_std(e) - } -} - -struct DumbConn; - -#[async_trait] -impl Conn for DumbConn { - async fn connect(&self, _addr: SocketAddr) -> Result<()> { - Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) - } - - async fn recv(&self, _b: &mut [u8]) -> Result { - Ok(0) - } - - async fn recv_from(&self, _buf: &mut [u8]) -> Result<(usize, SocketAddr)> { - Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) - } - - async fn send(&self, _b: &[u8]) -> Result { - Ok(0) - } - - async 
fn send_to(&self, _buf: &[u8], _target: SocketAddr) -> Result { - Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) - } - - fn local_addr(&self) -> Result { - Err(io::Error::new(io::ErrorKind::AddrNotAvailable, "Addr Not Available").into()) - } - - fn remote_addr(&self) -> Option { - None - } - - async fn close(&self) -> Result<()> { - Ok(()) - } - - fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) { - self - } -} - -fn create_association_internal(config: Config) -> AssociationInternal { - let (close_loop_ch_tx, _close_loop_ch_rx) = broadcast::channel(1); - let (accept_ch_tx, _accept_ch_rx) = mpsc::channel(1); - let (handshake_completed_ch_tx, _handshake_completed_ch_rx) = mpsc::channel(1); - let (awake_write_loop_ch_tx, _awake_write_loop_ch_rx) = mpsc::channel(1); - AssociationInternal::new( - config, - close_loop_ch_tx, - accept_ch_tx, - handshake_completed_ch_tx, - Arc::new(awake_write_loop_ch_tx), - ) -} - -#[test] -fn test_create_forward_tsn_forward_one_abandoned() -> Result<()> { - let mut a = AssociationInternal { - cumulative_tsn_ack_point: 9, - ..Default::default() - }; - - a.advanced_peer_tsn_ack_point = 10; - a.inflight_queue.push_no_check(ChunkPayloadData { - beginning_fragment: true, - ending_fragment: true, - tsn: 10, - stream_identifier: 1, - stream_sequence_number: 2, - user_data: Bytes::from_static(b"ABC"), - nsent: 1, - abandoned: Arc::new(AtomicBool::new(true)), - ..Default::default() - }); - - let fwdtsn = a.create_forward_tsn(); - - assert_eq!(fwdtsn.new_cumulative_tsn, 10, "should be able to serialize"); - assert_eq!(fwdtsn.streams.len(), 1, "there should be one stream"); - assert_eq!(fwdtsn.streams[0].identifier, 1, "si should be 1"); - assert_eq!(fwdtsn.streams[0].sequence, 2, "ssn should be 2"); - - Ok(()) -} - -#[test] -fn test_create_forward_tsn_forward_two_abandoned_with_the_same_si() -> Result<()> { - let mut a = AssociationInternal { - cumulative_tsn_ack_point: 9, - ..Default::default() - }; - - a.advanced_peer_tsn_ack_point = 12; - a.inflight_queue.push_no_check(ChunkPayloadData { - beginning_fragment: true, - ending_fragment: true, - tsn: 10, - stream_identifier: 1, - stream_sequence_number: 2, - user_data: Bytes::from_static(b"ABC"), - nsent: 1, - abandoned: Arc::new(AtomicBool::new(true)), - ..Default::default() - }); - a.inflight_queue.push_no_check(ChunkPayloadData { - beginning_fragment: true, - ending_fragment: true, - tsn: 11, - stream_identifier: 1, - stream_sequence_number: 3, - user_data: Bytes::from_static(b"DEF"), - nsent: 1, - abandoned: Arc::new(AtomicBool::new(true)), - ..Default::default() - }); - a.inflight_queue.push_no_check(ChunkPayloadData { - beginning_fragment: true, - ending_fragment: true, - tsn: 12, - stream_identifier: 2, - stream_sequence_number: 1, - user_data: Bytes::from_static(b"123"), - nsent: 1, - abandoned: Arc::new(AtomicBool::new(true)), - ..Default::default() - }); - - let fwdtsn = a.create_forward_tsn(); - - assert_eq!(fwdtsn.new_cumulative_tsn, 12, "should be able to serialize"); - assert_eq!(fwdtsn.streams.len(), 2, "there should be two stream"); - - let mut si1ok = false; - let mut si2ok = false; - for s in &fwdtsn.streams { - match s.identifier { - 1 => { - assert_eq!(3, s.sequence, "ssn should be 3"); - si1ok = true; - } - 2 => { - assert_eq!(1, s.sequence, "ssn should be 1"); - si2ok = true; - } - _ => panic!("unexpected stream identifier"), - } - } - assert!(si1ok, "si=1 should be present"); - assert!(si2ok, "si=2 should be present"); - - Ok(()) -} - -#[tokio::test] -async fn 
test_handle_forward_tsn_forward_3unreceived_chunks() -> Result<()> { - let mut a = AssociationInternal { - use_forward_tsn: true, - ..Default::default() - }; - - let prev_tsn = a.peer_last_tsn; - - let fwdtsn = ChunkForwardTsn { - new_cumulative_tsn: a.peer_last_tsn + 3, - streams: vec![ChunkForwardTsnStream { - identifier: 0, - sequence: 0, - }], - }; - - let p = a.handle_forward_tsn(&fwdtsn).await?; - - let delayed_ack_triggered = a.delayed_ack_triggered; - let immediate_ack_triggered = a.immediate_ack_triggered; - assert_eq!( - a.peer_last_tsn, - prev_tsn + 3, - "peerLastTSN should advance by 3 " - ); - assert!(delayed_ack_triggered, "delayed sack should be triggered"); - assert!( - !immediate_ack_triggered, - "immediate sack should NOT be triggered" - ); - assert!(p.is_empty(), "should return empty"); - - Ok(()) -} - -#[tokio::test] -async fn test_handle_forward_tsn_forward_1for1_missing() -> Result<()> { - let mut a = AssociationInternal { - use_forward_tsn: true, - ..Default::default() - }; - - let prev_tsn = a.peer_last_tsn; - - // this chunk is blocked by the missing chunk at tsn=1 - a.payload_queue.push( - ChunkPayloadData { - beginning_fragment: true, - ending_fragment: true, - tsn: a.peer_last_tsn + 2, - stream_identifier: 0, - stream_sequence_number: 1, - user_data: Bytes::from_static(b"ABC"), - ..Default::default() - }, - a.peer_last_tsn, - ); - - let fwdtsn = ChunkForwardTsn { - new_cumulative_tsn: a.peer_last_tsn + 1, - streams: vec![ChunkForwardTsnStream { - identifier: 0, - sequence: 1, - }], - }; - - let p = a.handle_forward_tsn(&fwdtsn).await?; - - let delayed_ack_triggered = a.delayed_ack_triggered; - let immediate_ack_triggered = a.immediate_ack_triggered; - assert_eq!( - a.peer_last_tsn, - prev_tsn + 2, - "peerLastTSN should advance by 2" - ); - assert!(delayed_ack_triggered, "delayed sack should be triggered"); - assert!( - !immediate_ack_triggered, - "immediate sack should NOT be triggered" - ); - assert!(p.is_empty(), "should return empty"); - - Ok(()) -} - -#[tokio::test] -async fn test_handle_forward_tsn_forward_1for2_missing() -> Result<()> { - let mut a = AssociationInternal { - use_forward_tsn: true, - ..Default::default() - }; - - let prev_tsn = a.peer_last_tsn; - - // this chunk is blocked by the missing chunk at tsn=1 - a.payload_queue.push( - ChunkPayloadData { - beginning_fragment: true, - ending_fragment: true, - tsn: a.peer_last_tsn + 3, - stream_identifier: 0, - stream_sequence_number: 1, - user_data: Bytes::from_static(b"ABC"), - ..Default::default() - }, - a.peer_last_tsn, - ); - - let fwdtsn = ChunkForwardTsn { - new_cumulative_tsn: a.peer_last_tsn + 1, - streams: vec![ChunkForwardTsnStream { - identifier: 0, - sequence: 1, - }], - }; - - let p = a.handle_forward_tsn(&fwdtsn).await?; - - let immediate_ack_triggered = a.immediate_ack_triggered; - assert_eq!( - a.peer_last_tsn, - prev_tsn + 1, - "peerLastTSN should advance by 1" - ); - assert!( - immediate_ack_triggered, - "immediate sack should be triggered" - ); - assert!(p.is_empty(), "should return empty"); - - Ok(()) -} - -#[tokio::test] -async fn test_handle_forward_tsn_dup_forward_tsn_chunk_should_generate_sack() -> Result<()> { - let mut a = AssociationInternal { - use_forward_tsn: true, - ..Default::default() - }; - - let prev_tsn = a.peer_last_tsn; - - let fwdtsn = ChunkForwardTsn { - new_cumulative_tsn: a.peer_last_tsn, - streams: vec![ChunkForwardTsnStream { - identifier: 0, - sequence: 1, - }], - }; - - let p = a.handle_forward_tsn(&fwdtsn).await?; - - assert_eq!(a.peer_last_tsn, 
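// The duplicate-FORWARD-TSN test that follows relies on serial-number
// arithmetic over the 32-bit TSN space: the chunk only advances the peer's
// cumulative point when new_cumulative_tsn is "greater" modulo 2^32. A
// self-contained sketch of that RFC 1982 style comparison (equivalent in
// spirit to the crate's sna32gt; the helper name here is hypothetical):
fn tsn_advances(new_cumulative_tsn: u32, peer_last_tsn: u32) -> bool {
    let diff = new_cumulative_tsn.wrapping_sub(peer_last_tsn);
    diff != 0 && diff < (1u32 << 31)
}

// tsn_advances(10, 7) is true, while tsn_advances(7, 7) is false, so a
// duplicate FORWARD TSN does not advance peer_last_tsn and only triggers an
// immediate SACK.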
prev_tsn, "peerLastTSN should not advance"); - assert_eq!(a.ack_state, AckState::Immediate, "sack should be requested"); - assert!(p.is_empty(), "should return empty"); - - Ok(()) -} - -#[tokio::test] -async fn test_assoc_create_new_stream() -> Result<()> { - let (accept_ch_tx, _accept_ch_rx) = mpsc::channel(ACCEPT_CH_SIZE); - let mut a = AssociationInternal { - accept_ch_tx: Some(accept_ch_tx), - ..Default::default() - }; - - for i in 0..ACCEPT_CH_SIZE { - let s = a.create_stream(i as u16, true); - if let Some(s) = s { - let result = a.streams.get(&s.stream_identifier); - assert!(result.is_some(), "should be in a.streams map"); - } else { - panic!("{i} should success"); - } - } - - let new_si = ACCEPT_CH_SIZE as u16; - let s = a.create_stream(new_si, true); - assert!(s.is_none(), "should be none"); - let result = a.streams.get(&new_si); - assert!(result.is_none(), "should NOT be in a.streams map"); - - let to_be_ignored = ChunkPayloadData { - beginning_fragment: true, - ending_fragment: true, - tsn: a.peer_last_tsn + 1, - stream_identifier: new_si, - user_data: Bytes::from_static(b"ABC"), - ..Default::default() - }; - - let p = a.handle_data(&to_be_ignored).await?; - assert!(p.is_empty(), "should return empty"); - - Ok(()) -} - -async fn handle_init_test(name: &str, initial_state: AssociationState, expect_err: bool) { - let mut a = create_association_internal(Config { - net_conn: Arc::new(DumbConn {}), - max_receive_buffer_size: 0, - max_message_size: 0, - name: "client".to_owned(), - }); - a.set_state(initial_state); - let pkt = Packet { - source_port: 5001, - destination_port: 5002, - ..Default::default() - }; - let mut init = ChunkInit { - initial_tsn: 1234, - num_outbound_streams: 1001, - num_inbound_streams: 1002, - initiate_tag: 5678, - advertised_receiver_window_credit: 512 * 1024, - ..Default::default() - }; - init.set_supported_extensions(); - - let result = a.handle_init(&pkt, &init).await; - if expect_err { - assert!(result.is_err(), "{name} should fail"); - return; - } else { - assert!(result.is_ok(), "{name} should be ok"); - } - assert_eq!( - a.peer_last_tsn, - if init.initial_tsn == 0 { - u32::MAX - } else { - init.initial_tsn - 1 - }, - "{name} should match" - ); - assert_eq!(a.my_max_num_outbound_streams, 1001, "{name} should match"); - assert_eq!(a.my_max_num_inbound_streams, 1002, "{name} should match"); - assert_eq!(a.peer_verification_tag, 5678, "{name} should match"); - assert_eq!(a.destination_port, pkt.source_port, "{name} should match"); - assert_eq!(a.source_port, pkt.destination_port, "{name} should match"); - assert!(a.use_forward_tsn, "{name} should be set to true"); -} - -#[tokio::test] -async fn test_assoc_handle_init() -> Result<()> { - handle_init_test("normal", AssociationState::Closed, false).await; - - handle_init_test( - "unexpected state established", - AssociationState::Established, - true, - ) - .await; - - handle_init_test( - "unexpected state shutdownAckSent", - AssociationState::ShutdownAckSent, - true, - ) - .await; - - handle_init_test( - "unexpected state shutdownPending", - AssociationState::ShutdownPending, - true, - ) - .await; - - handle_init_test( - "unexpected state shutdownReceived", - AssociationState::ShutdownReceived, - true, - ) - .await; - - handle_init_test( - "unexpected state shutdownSent", - AssociationState::ShutdownSent, - true, - ) - .await; - - Ok(()) -} - -#[tokio::test] -async fn test_assoc_max_message_size_default() -> Result<()> { - let mut a = create_association_internal(Config { - net_conn: Arc::new(DumbConn {}), - 
max_receive_buffer_size: 0, - max_message_size: 0, - name: "client".to_owned(), - }); - assert_eq!( - a.max_message_size.load(Ordering::SeqCst), - 65536, - "should match" - ); - - let stream = a.create_stream(1, false); - assert!(stream.is_some(), "should succeed"); - - if let Some(s) = stream { - let p = Bytes::from(vec![0u8; 65537]); - let ppi = PayloadProtocolIdentifier::from(s.default_payload_type.load(Ordering::SeqCst)); - - if let Err(err) = s.write_sctp(&p.slice(..65536), ppi).await { - assert_ne!( - err, - Error::ErrOutboundPacketTooLarge, - "should be not Error::ErrOutboundPacketTooLarge" - ); - } else { - panic!("should be error"); - } - - if let Err(err) = s.write_sctp(&p.slice(..65537), ppi).await { - assert_eq!( - err, - Error::ErrOutboundPacketTooLarge, - "should be Error::ErrOutboundPacketTooLarge" - ); - } else { - panic!("should be error"); - } - } - - Ok(()) -} - -#[tokio::test] -async fn test_assoc_max_message_size_explicit() -> Result<()> { - let mut a = create_association_internal(Config { - net_conn: Arc::new(DumbConn {}), - max_receive_buffer_size: 0, - max_message_size: 30000, - name: "client".to_owned(), - }); - - assert_eq!( - a.max_message_size.load(Ordering::SeqCst), - 30000, - "should match" - ); - - let stream = a.create_stream(1, false); - assert!(stream.is_some(), "should succeed"); - - if let Some(s) = stream { - let p = Bytes::from(vec![0u8; 30001]); - let ppi = PayloadProtocolIdentifier::from(s.default_payload_type.load(Ordering::SeqCst)); - - if let Err(err) = s.write_sctp(&p.slice(..30000), ppi).await { - assert_ne!( - err, - Error::ErrOutboundPacketTooLarge, - "should be not Error::ErrOutboundPacketTooLarge" - ); - } else { - panic!("should be error"); - } - - if let Err(err) = s.write_sctp(&p.slice(..30001), ppi).await { - assert_eq!( - err, - Error::ErrOutboundPacketTooLarge, - "should be Error::ErrOutboundPacketTooLarge" - ); - } else { - panic!("should be error"); - } - } - - Ok(()) -} diff --git a/sctp/src/association/association_stats.rs b/sctp/src/association/association_stats.rs deleted file mode 100644 index 0fe390c0c..000000000 --- a/sctp/src/association/association_stats.rs +++ /dev/null @@ -1,61 +0,0 @@ -use portable_atomic::AtomicU64; -use std::sync::atomic::Ordering; - -#[derive(Default, Debug)] -pub(crate) struct AssociationStats { - n_datas: AtomicU64, - n_sacks: AtomicU64, - n_t3timeouts: AtomicU64, - n_ack_timeouts: AtomicU64, - n_fast_retrans: AtomicU64, -} - -impl AssociationStats { - pub(crate) fn inc_datas(&self) { - self.n_datas.fetch_add(1, Ordering::SeqCst); - } - - pub(crate) fn get_num_datas(&self) -> u64 { - self.n_datas.load(Ordering::SeqCst) - } - - pub(crate) fn inc_sacks(&self) { - self.n_sacks.fetch_add(1, Ordering::SeqCst); - } - - pub(crate) fn get_num_sacks(&self) -> u64 { - self.n_sacks.load(Ordering::SeqCst) - } - - pub(crate) fn inc_t3timeouts(&self) { - self.n_t3timeouts.fetch_add(1, Ordering::SeqCst); - } - - pub(crate) fn get_num_t3timeouts(&self) -> u64 { - self.n_t3timeouts.load(Ordering::SeqCst) - } - - pub(crate) fn inc_ack_timeouts(&self) { - self.n_ack_timeouts.fetch_add(1, Ordering::SeqCst); - } - - pub(crate) fn get_num_ack_timeouts(&self) -> u64 { - self.n_ack_timeouts.load(Ordering::SeqCst) - } - - pub(crate) fn inc_fast_retrans(&self) { - self.n_fast_retrans.fetch_add(1, Ordering::SeqCst); - } - - pub(crate) fn get_num_fast_retrans(&self) -> u64 { - self.n_fast_retrans.load(Ordering::SeqCst) - } - - pub(crate) fn reset(&self) { - self.n_datas.store(0, Ordering::SeqCst); - self.n_sacks.store(0, 
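// Both max_message_size tests above probe the same inclusive boundary: writing
// exactly max_message_size bytes is not rejected as too large, while one byte
// more fails with Error::ErrOutboundPacketTooLarge. A hedged sketch of that
// check as a free function (in the crate the check lives inside
// Stream::write_sctp, so the name and signature here are illustrative):
fn exceeds_max_message_size(payload_len: usize, max_message_size: u32) -> bool {
    payload_len > max_message_size as usize
}

// exceeds_max_message_size(65536, 65536) == false (other errors may still
// apply, as the test notes), exceeds_max_message_size(65537, 65536) == true.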
Ordering::SeqCst); - self.n_t3timeouts.store(0, Ordering::SeqCst); - self.n_ack_timeouts.store(0, Ordering::SeqCst); - self.n_fast_retrans.store(0, Ordering::SeqCst); - } -} diff --git a/sctp/src/association/association_test.rs b/sctp/src/association/association_test.rs deleted file mode 100644 index 069d35773..000000000 --- a/sctp/src/association/association_test.rs +++ /dev/null @@ -1,2616 +0,0 @@ -// Silence warning on `for i in 0..vec.len() { โ€ฆ }`: -#![allow(clippy::needless_range_loop)] - -use std::io; -use std::net::{Shutdown, SocketAddr}; -use std::str::FromStr; -use std::time::Duration; - -use async_trait::async_trait; -use tokio::net::UdpSocket; -use util::conn::conn_bridge::*; -use util::conn::conn_pipe::pipe; -use util::conn::*; - -use super::*; -use crate::chunk::chunk_selective_ack::GapAckBlock; -use crate::stream::*; - -async fn create_new_association_pair( - br: &Arc, - ca: Arc, - cb: Arc, - ack_mode: AckMode, - recv_buf_size: u32, -) -> Result<(Association, Association)> { - let (handshake0ch_tx, mut handshake0ch_rx) = mpsc::channel(1); - let (handshake1ch_tx, mut handshake1ch_rx) = mpsc::channel(1); - let (closed_tx, mut closed_rx0) = broadcast::channel::<()>(1); - let mut closed_rx1 = closed_tx.subscribe(); - - // Setup client - tokio::spawn(async move { - let client = Association::client(Config { - net_conn: ca, - max_receive_buffer_size: recv_buf_size, - max_message_size: 0, - name: "client".to_owned(), - }) - .await; - - let _ = handshake0ch_tx.send(client).await; - let _ = closed_rx0.recv().await; - - Result::<()>::Ok(()) - }); - - // Setup server - tokio::spawn(async move { - let server = Association::server(Config { - net_conn: cb, - max_receive_buffer_size: recv_buf_size, - max_message_size: 0, - name: "server".to_owned(), - }) - .await; - - let _ = handshake1ch_tx.send(server).await; - let _ = closed_rx1.recv().await; - - Result::<()>::Ok(()) - }); - - let mut client = None; - let mut server = None; - let mut a0handshake_done = false; - let mut a1handshake_done = false; - let mut i = 0; - while (!a0handshake_done || !a1handshake_done) && i < 100 { - br.tick().await; - - let timer = tokio::time::sleep(Duration::from_millis(10)); - tokio::pin!(timer); - - tokio::select! 
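// AssociationStats above is a bundle of atomic counters bumped with fetch_add,
// read with load, and cleared with store, all under Ordering::SeqCst. A
// self-contained sketch of the same pattern using std's AtomicU64 (the removed
// code uses portable_atomic for broader platform support):
use std::sync::atomic::{AtomicU64, Ordering};

#[derive(Default)]
struct EventCounter(AtomicU64);

impl EventCounter {
    fn inc(&self) {
        self.0.fetch_add(1, Ordering::SeqCst);
    }
    fn get(&self) -> u64 {
        self.0.load(Ordering::SeqCst)
    }
    fn reset(&self) {
        self.0.store(0, Ordering::SeqCst);
    }
}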
{ - _ = timer.as_mut() =>{}, - r0 = handshake0ch_rx.recv() => { - if let Ok(c) = r0.unwrap() { - client = Some(c); - } - a0handshake_done = true; - }, - r1 = handshake1ch_rx.recv() => { - if let Ok(s) = r1.unwrap() { - server = Some(s); - } - a1handshake_done = true; - }, - }; - i += 1; - } - - if !a0handshake_done || !a1handshake_done { - return Err(Error::Other("handshake failed".to_owned())); - } - - drop(closed_tx); - - let (client, server) = (client.unwrap(), server.unwrap()); - { - let mut ai = client.association_internal.lock().await; - ai.ack_mode = ack_mode; - } - { - let mut ai = server.association_internal.lock().await; - ai.ack_mode = ack_mode; - } - - Ok((client, server)) -} - -async fn close_association_pair(br: &Arc, client: Association, server: Association) { - let (handshake0ch_tx, mut handshake0ch_rx) = mpsc::channel(1); - let (handshake1ch_tx, mut handshake1ch_rx) = mpsc::channel(1); - let (closed_tx, mut closed_rx0) = broadcast::channel::<()>(1); - let mut closed_rx1 = closed_tx.subscribe(); - - // Close client - tokio::spawn(async move { - client.close().await?; - let _ = handshake0ch_tx.send(()).await; - let _ = closed_rx0.recv().await; - - Result::<()>::Ok(()) - }); - - // Close server - tokio::spawn(async move { - server.close().await?; - let _ = handshake1ch_tx.send(()).await; - let _ = closed_rx1.recv().await; - - Result::<()>::Ok(()) - }); - - let mut a0handshake_done = false; - let mut a1handshake_done = false; - let mut i = 0; - while (!a0handshake_done || !a1handshake_done) && i < 100 { - br.tick().await; - - let timer = tokio::time::sleep(Duration::from_millis(10)); - tokio::pin!(timer); - - tokio::select! { - _ = timer.as_mut() =>{}, - _ = handshake0ch_rx.recv() => { - a0handshake_done = true; - }, - _ = handshake1ch_rx.recv() => { - a1handshake_done = true; - }, - }; - i += 1; - } - - drop(closed_tx); -} - -async fn flush_buffers(br: &Arc, client: &Association, server: &Association) { - loop { - loop { - let n = br.tick().await; - if n == 0 { - break; - } - } - - { - let (a0, a1) = ( - client.association_internal.lock().await, - server.association_internal.lock().await, - ); - if a0.buffered_amount() == 0 && a1.buffered_amount() == 0 { - break; - } - } - tokio::time::sleep(Duration::from_millis(10)).await; - } -} - -async fn establish_session_pair( - br: &Arc, - client: &Association, - server: &mut Association, - si: u16, -) -> Result<(Arc, Arc)> { - let hello_msg = Bytes::from_static(b"Hello"); - let s0 = client - .open_stream(si, PayloadProtocolIdentifier::Binary) - .await?; - let _ = s0 - .write_sctp(&hello_msg, PayloadProtocolIdentifier::Dcep) - .await?; - - flush_buffers(br, client, server).await; - - let s1 = server.accept_stream().await.unwrap(); - if s0.stream_identifier != s1.stream_identifier { - return Err(Error::Other("SI should match".to_owned())); - } - - br.process().await; - - let mut buf = vec![0u8; 1024]; - let (n, ppi) = s1.read_sctp(&mut buf).await?; - - if n != hello_msg.len() { - return Err(Error::Other("received data must by 3 bytes".to_owned())); - } - - if ppi != PayloadProtocolIdentifier::Dcep { - return Err(Error::Other("unexpected ppi".to_owned())); - } - - if buf[..n] != hello_msg { - return Err(Error::Other("received data mismatch".to_owned())); - } - - flush_buffers(br, client, server).await; - - Ok((s0, s1)) -} - -//use std::io::Write; - -#[cfg(not(target_os = "windows"))] // this times out in CI on windows. 
-#[tokio::test] -async fn test_assoc_reliable_simple() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - const SI: u16 = 1; - static MSG: Bytes = Bytes::from_static(b"ABC"); - - let (br, ca, cb) = Bridge::new(0, None, None); - - let (a0, mut a1) = - create_new_association_pair(&br, Arc::new(ca), Arc::new(cb), AckMode::NoDelay, 0).await?; - - let (s0, s1) = establish_session_pair(&br, &a0, &mut a1, SI).await?; - - { - let a = a0.association_internal.lock().await; - assert_eq!(a.buffered_amount(), 0, "incorrect bufferedAmount"); - } - - let n = s0 - .write_sctp(&MSG, PayloadProtocolIdentifier::Binary) - .await?; - assert_eq!(n, MSG.len(), "unexpected length of received data"); - { - let a = a0.association_internal.lock().await; - assert_eq!(a.buffered_amount(), MSG.len(), "incorrect bufferedAmount"); - } - - flush_buffers(&br, &a0, &a1).await; - - let mut buf = vec![0u8; 32]; - let (n, ppi) = s1.read_sctp(&mut buf).await?; - assert_eq!(n, MSG.len(), "unexpected length of received data"); - assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi"); - - { - let q = s0.reassembly_queue.lock().await; - assert!(!q.is_readable(), "should no longer be readable"); - } - - { - let a = a0.association_internal.lock().await; - assert_eq!(a.buffered_amount(), 0, "incorrect bufferedAmount"); - } - - close_association_pair(&br, a0, a1).await; - - Ok(()) -} - -//use std::io::Write; - -// NB: This is ignored on Windows due to flakiness with timing/IO interactions. -// TODO: Refactor this and other tests that are disabled for similar reason to not have such issues -#[cfg(not(target_os = "windows"))] -#[tokio::test] -async fn test_assoc_reliable_ordered_reordered() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - const SI: u16 = 2; - let mut sbuf = vec![0u8; 1000]; - for i in 0..sbuf.len() { - sbuf[i] = (i & 0xff) as u8; - } - let mut sbufl = vec![0u8; 2000]; - for i in 0..sbufl.len() { - sbufl[i] = (i & 0xff) as u8; - } - - let (br, ca, cb) = Bridge::new(0, None, None); - - let (a0, mut a1) = - create_new_association_pair(&br, Arc::new(ca), Arc::new(cb), AckMode::NoDelay, 0).await?; - - let (s0, s1) = establish_session_pair(&br, &a0, &mut a1, SI).await?; - - { - let a = a0.association_internal.lock().await; - assert_eq!(a.buffered_amount(), 0, "incorrect bufferedAmount"); - } - - sbuf[0..4].copy_from_slice(&0u32.to_be_bytes()); - let n = s0 - .write_sctp( - &Bytes::from(sbuf.clone()), - PayloadProtocolIdentifier::Binary, - ) - .await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - - sbuf[0..4].copy_from_slice(&1u32.to_be_bytes()); - let n = s0 - .write_sctp( - &Bytes::from(sbuf.clone()), - PayloadProtocolIdentifier::Binary, - ) - .await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - - tokio::time::sleep(Duration::from_millis(10)).await; - br.reorder(0).await; - br.process().await; - - let mut buf = vec![0u8; 2000]; - - let (n, ppi) = s1.read_sctp(&mut buf).await?; - assert_eq!(n, sbuf.len(), 
"unexpected length of received data"); - assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi"); - assert_eq!( - u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]), - 0, - "unexpected received data" - ); - - let (n, ppi) = s1.read_sctp(&mut buf).await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi"); - assert_eq!( - u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]), - 1, - "unexpected received data" - ); - - br.process().await; - - { - let q = s0.reassembly_queue.lock().await; - assert!(!q.is_readable(), "should no longer be readable"); - } - - close_association_pair(&br, a0, a1).await; - - Ok(()) -} - -//use std::io::Write; - -#[tokio::test] -async fn test_assoc_reliable_ordered_fragmented_then_defragmented() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - const SI: u16 = 3; - let mut sbuf = vec![0u8; 1000]; - for i in 0..sbuf.len() { - sbuf[i] = (i & 0xff) as u8; - } - let mut sbufl = vec![0u8; 2000]; - for i in 0..sbufl.len() { - sbufl[i] = (i & 0xff) as u8; - } - - let (br, ca, cb) = Bridge::new(0, None, None); - - let (a0, mut a1) = - create_new_association_pair(&br, Arc::new(ca), Arc::new(cb), AckMode::NoDelay, 0).await?; - - let (s0, s1) = establish_session_pair(&br, &a0, &mut a1, SI).await?; - - s0.set_reliability_params(false, ReliabilityType::Reliable, 0); - s1.set_reliability_params(false, ReliabilityType::Reliable, 0); - - let n = s0 - .write_sctp( - &Bytes::from(sbufl.clone()), - PayloadProtocolIdentifier::Binary, - ) - .await?; - assert_eq!(n, sbufl.len(), "unexpected length of received data"); - - flush_buffers(&br, &a0, &a1).await; - - let mut rbuf = vec![0u8; 2000]; - let (n, ppi) = s1.read_sctp(&mut rbuf).await?; - assert_eq!(n, sbufl.len(), "unexpected length of received data"); - assert_eq!(&rbuf[..n], &sbufl, "unexpected received data"); - assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi"); - - br.process().await; - - { - let q = s0.reassembly_queue.lock().await; - assert!(!q.is_readable(), "should no longer be readable"); - } - - close_association_pair(&br, a0, a1).await; - - Ok(()) -} - -//use std::io::Write; - -#[tokio::test] -async fn test_assoc_reliable_unordered_fragmented_then_defragmented() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - const SI: u16 = 4; - let mut sbuf = vec![0u8; 1000]; - for i in 0..sbuf.len() { - sbuf[i] = (i & 0xff) as u8; - } - let mut sbufl = vec![0u8; 2000]; - for i in 0..sbufl.len() { - sbufl[i] = (i & 0xff) as u8; - } - - let (br, ca, cb) = Bridge::new(0, None, None); - - let (a0, mut a1) = - create_new_association_pair(&br, Arc::new(ca), Arc::new(cb), AckMode::NoDelay, 0).await?; - - let (s0, s1) = establish_session_pair(&br, &a0, &mut a1, SI).await?; - - s0.set_reliability_params(true, ReliabilityType::Reliable, 0); - s1.set_reliability_params(true, ReliabilityType::Reliable, 0); - - let n = s0 - .write_sctp( - &Bytes::from(sbufl.clone()), - 
PayloadProtocolIdentifier::Binary, - ) - .await?; - assert_eq!(n, sbufl.len(), "unexpected length of received data"); - - flush_buffers(&br, &a0, &a1).await; - - let mut rbuf = vec![0u8; 2000]; - let (n, ppi) = s1.read_sctp(&mut rbuf).await?; - assert_eq!(n, sbufl.len(), "unexpected length of received data"); - assert_eq!(&rbuf[..n], &sbufl, "unexpected received data"); - assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi"); - - br.process().await; - - { - let q = s0.reassembly_queue.lock().await; - assert!(!q.is_readable(), "should no longer be readable"); - } - - close_association_pair(&br, a0, a1).await; - - Ok(()) -} - -//use std::io::Write; - -#[tokio::test] -async fn test_assoc_reliable_unordered_ordered() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - const SI: u16 = 5; - let mut sbuf = vec![0u8; 1000]; - for i in 0..sbuf.len() { - sbuf[i] = (i & 0xff) as u8; - } - let mut sbufl = vec![0u8; 2000]; - for i in 0..sbufl.len() { - sbufl[i] = (i & 0xff) as u8; - } - - let (br, ca, cb) = Bridge::new(0, None, None); - - let (a0, mut a1) = - create_new_association_pair(&br, Arc::new(ca), Arc::new(cb), AckMode::NoDelay, 0).await?; - - let (s0, s1) = establish_session_pair(&br, &a0, &mut a1, SI).await?; - - s0.set_reliability_params(true, ReliabilityType::Reliable, 0); - s1.set_reliability_params(true, ReliabilityType::Reliable, 0); - - br.reorder_next_nwrites(0, 2); - - sbuf[0..4].copy_from_slice(&0u32.to_be_bytes()); - let n = s0 - .write_sctp( - &Bytes::from(sbuf.clone()), - PayloadProtocolIdentifier::Binary, - ) - .await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - - sbuf[0..4].copy_from_slice(&1u32.to_be_bytes()); - let n = s0 - .write_sctp( - &Bytes::from(sbuf.clone()), - PayloadProtocolIdentifier::Binary, - ) - .await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - - flush_buffers(&br, &a0, &a1).await; - - let mut buf = vec![0u8; 2000]; - - let (n, ppi) = s1.read_sctp(&mut buf).await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi"); - assert_eq!( - u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]), - 1, - "unexpected received data" - ); - - let (n, ppi) = s1.read_sctp(&mut buf).await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi"); - assert_eq!( - u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]), - 0, - "unexpected received data" - ); - - br.process().await; - - { - let q = s0.reassembly_queue.lock().await; - assert!(!q.is_readable(), "should no longer be readable"); - } - - close_association_pair(&br, a0, a1).await; - - Ok(()) -} - -//use std::io::Write; - -// NB: This is ignored on Windows due to flakiness with timing/IO interactions. 
-// TODO: Refactor this and other tests that are disabled for similar reason to not have such issues -#[cfg(not(target_os = "windows"))] -#[tokio::test] -async fn test_assoc_reliable_retransmission() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - const SI: u16 = 6; - static MSG1: Bytes = Bytes::from_static(b"ABC"); - static MSG2: Bytes = Bytes::from_static(b"DEFG"); - - let (br, ca, cb) = Bridge::new(0, None, None); - - let (a0, mut a1) = - create_new_association_pair(&br, Arc::new(ca), Arc::new(cb), AckMode::NoDelay, 0).await?; - { - let mut a = a0.association_internal.lock().await; - a.rto_mgr.set_rto(100, true); - } - - let (s0, s1) = establish_session_pair(&br, &a0, &mut a1, SI).await?; - - let n = s0 - .write_sctp(&MSG1, PayloadProtocolIdentifier::Binary) - .await?; - assert_eq!(n, MSG1.len(), "unexpected length of received data"); - - let n = s0 - .write_sctp(&MSG2, PayloadProtocolIdentifier::Binary) - .await?; - assert_eq!(n, MSG2.len(), "unexpected length of received data"); - - tokio::time::sleep(Duration::from_millis(10)).await; - log::debug!("dropping packet"); - br.drop_offset(0, 0, 1).await; // drop the first packet (second one should be sacked) - - // process packets for 200 msec - for _ in 0..20 { - br.tick().await; - tokio::time::sleep(Duration::from_millis(10)).await; - } - - let mut buf = vec![0u8; 32]; - - let (n, ppi) = s1.read_sctp(&mut buf).await?; - assert_eq!(n, MSG1.len(), "unexpected length of received data"); - assert_eq!(&buf[..n], &MSG1, "unexpected length of received data"); - assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi"); - - let (n, ppi) = s1.read_sctp(&mut buf).await?; - assert_eq!(n, MSG2.len(), "unexpected length of received data"); - assert_eq!(&buf[..n], &MSG2, "unexpected length of received data"); - assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi"); - - br.process().await; - - { - let q = s0.reassembly_queue.lock().await; - assert!(!q.is_readable(), "should no longer be readable"); - } - - close_association_pair(&br, a0, a1).await; - - Ok(()) -} - -//use std::io::Write; - -#[tokio::test] -async fn test_assoc_reliable_short_buffer() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - const SI: u16 = 1; - static MSG: Bytes = Bytes::from_static(b"Hello"); - - let (br, ca, cb) = Bridge::new(0, None, None); - - let (a0, mut a1) = - create_new_association_pair(&br, Arc::new(ca), Arc::new(cb), AckMode::NoDelay, 0).await?; - - let (s0, s1) = establish_session_pair(&br, &a0, &mut a1, SI).await?; - - { - let a = a0.association_internal.lock().await; - assert_eq!(a.buffered_amount(), 0, "incorrect bufferedAmount"); - } - - let n = s0 - .write_sctp(&MSG, PayloadProtocolIdentifier::Binary) - .await?; - assert_eq!(n, MSG.len(), "unexpected length of received data"); - { - let a = a0.association_internal.lock().await; - assert_eq!(a.buffered_amount(), MSG.len(), "incorrect bufferedAmount"); - } - - flush_buffers(&br, &a0, &a1).await; - - let mut buf = vec![0u8; 3]; - let 
result = s1.read_sctp(&mut buf).await; - assert!(result.is_err(), "expected error to be ErrShortBuffer"); - if let Err(err) = result { - assert_eq!( - err, - Error::ErrShortBuffer { size: 3 }, - "expected error to be ErrShortBuffer" - ); - } - - { - let q = s0.reassembly_queue.lock().await; - assert!(!q.is_readable(), "should no longer be readable"); - } - - { - let a = a0.association_internal.lock().await; - assert_eq!(a.buffered_amount(), 0, "incorrect bufferedAmount"); - } - - close_association_pair(&br, a0, a1).await; - - Ok(()) -} - -//use std::io::Write; - -#[tokio::test] -async fn test_assoc_unreliable_rexmit_ordered_no_fragment() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - const SI: u16 = 1; - let mut sbuf = vec![0u8; 1000]; - for i in 0..sbuf.len() { - sbuf[i] = (i & 0xff) as u8; - } - - let (br, ca, cb) = Bridge::new(0, None, None); - - let (a0, mut a1) = - create_new_association_pair(&br, Arc::new(ca), Arc::new(cb), AckMode::NoDelay, 0).await?; - - let (s0, s1) = establish_session_pair(&br, &a0, &mut a1, SI).await?; - - // When we set the reliability value to 0 [times], then it will cause - // the chunk to be abandoned immediately after the first transmission. - s0.set_reliability_params(false, ReliabilityType::Rexmit, 0); - s1.set_reliability_params(false, ReliabilityType::Rexmit, 0); // doesn't matter - - br.drop_next_nwrites(0, 1); // drop the first packet (second one should be sacked) - - sbuf[0..4].copy_from_slice(&0u32.to_be_bytes()); - let n = s0 - .write_sctp( - &Bytes::from(sbuf.clone()), - PayloadProtocolIdentifier::Binary, - ) - .await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - - sbuf[0..4].copy_from_slice(&1u32.to_be_bytes()); - let n = s0 - .write_sctp( - &Bytes::from(sbuf.clone()), - PayloadProtocolIdentifier::Binary, - ) - .await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - - log::debug!("flush_buffers"); - flush_buffers(&br, &a0, &a1).await; - - let mut buf = vec![0u8; 2000]; - - log::debug!("read_sctp"); - let (n, ppi) = s1.read_sctp(&mut buf).await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi"); - assert_eq!( - u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]), - 1, - "unexpected received data" - ); - - log::debug!("process"); - br.process().await; - - { - let q = s0.reassembly_queue.lock().await; - assert!(!q.is_readable(), "should no longer be readable"); - } - - close_association_pair(&br, a0, a1).await; - - Ok(()) -} - -//use std::io::Write; - -#[tokio::test] -async fn test_assoc_unreliable_rexmit_ordered_fragment() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - const SI: u16 = 1; - let mut sbuf = vec![0u8; 2000]; - for i in 0..sbuf.len() { - sbuf[i] = (i & 0xff) as u8; - } - - let (br, ca, cb) = Bridge::new(0, None, None); - - let (a0, mut a1) = - create_new_association_pair(&br, Arc::new(ca), Arc::new(cb), AckMode::NoDelay, 0).await?; - - let 
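// Error::ErrShortBuffer { size } reports the size of the buffer the caller
// supplied (3 in the test above). A hedged usage sketch of handling it, with
// `stream` standing in for an established Stream; sizing the buffer for the
// largest message expected up front avoids the error entirely:
let mut buf = vec![0u8; 64 * 1024];
match stream.read_sctp(&mut buf).await {
    Ok((n, ppi)) => log::debug!("read {} bytes with ppi {:?}", n, ppi),
    Err(Error::ErrShortBuffer { size }) => {
        log::warn!("supplied buffer of {} bytes was too small for the message", size);
    }
    Err(err) => return Err(err),
}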
(s0, s1) = establish_session_pair(&br, &a0, &mut a1, SI).await?; - - { - // lock RTO value at 100 [msec] - let mut a = a0.association_internal.lock().await; - a.rto_mgr.set_rto(100, true); - } - // When we set the reliability value to 0 [times], then it will cause - // the chunk to be abandoned immediately after the first transmission. - s0.set_reliability_params(false, ReliabilityType::Rexmit, 0); - s1.set_reliability_params(false, ReliabilityType::Rexmit, 0); // doesn't matter - - br.drop_next_nwrites(0, 1); // drop the first packet (second one should be sacked) - - sbuf[0..4].copy_from_slice(&0u32.to_be_bytes()); - let n = s0 - .write_sctp( - &Bytes::from(sbuf.clone()), - PayloadProtocolIdentifier::Binary, - ) - .await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - - sbuf[0..4].copy_from_slice(&1u32.to_be_bytes()); - let n = s0 - .write_sctp( - &Bytes::from(sbuf.clone()), - PayloadProtocolIdentifier::Binary, - ) - .await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - - //log::debug!("flush_buffers"); - flush_buffers(&br, &a0, &a1).await; - - let mut buf = vec![0u8; 2000]; - - //log::debug!("read_sctp"); - let (n, ppi) = s1.read_sctp(&mut buf).await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi"); - assert_eq!( - u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]), - 1, - "unexpected received data" - ); - - //log::debug!("process"); - br.process().await; - - { - let q = s0.reassembly_queue.lock().await; - assert!(!q.is_readable(), "should no longer be readable"); - } - - close_association_pair(&br, a0, a1).await; - - Ok(()) -} - -//use std::io::Write; - -#[tokio::test] -async fn test_assoc_unreliable_rexmit_unordered_no_fragment() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - const SI: u16 = 2; - let mut sbuf = vec![0u8; 1000]; - for i in 0..sbuf.len() { - sbuf[i] = (i & 0xff) as u8; - } - - let (br, ca, cb) = Bridge::new(0, None, None); - - let (a0, mut a1) = - create_new_association_pair(&br, Arc::new(ca), Arc::new(cb), AckMode::NoDelay, 0).await?; - - let (s0, s1) = establish_session_pair(&br, &a0, &mut a1, SI).await?; - - // When we set the reliability value to 0 [times], then it will cause - // the chunk to be abandoned immediately after the first transmission. 
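// The "unreliable" tests above and below configure RFC 3758 partial
// reliability through Stream::set_reliability_params(unordered, reliability
// type, value), exactly as called in this file. A short usage sketch under the
// same assumptions (`stream` is an already-opened Stream); with
// ReliabilityType::Rexmit and a value of 0 a chunk is abandoned right after
// its first transmission, which is what these tests exploit:
stream.set_reliability_params(true, ReliabilityType::Rexmit, 0); // unordered, no retransmissions
stream.set_reliability_params(false, ReliabilityType::Timed, 500); // ordered, abandon once the timed value elapses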
- s0.set_reliability_params(true, ReliabilityType::Rexmit, 0); - s1.set_reliability_params(true, ReliabilityType::Rexmit, 0); // doesn't matter - - br.drop_next_nwrites(0, 1); // drop the first packet (second one should be sacked) - - sbuf[0..4].copy_from_slice(&0u32.to_be_bytes()); - let n = s0 - .write_sctp( - &Bytes::from(sbuf.clone()), - PayloadProtocolIdentifier::Binary, - ) - .await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - - sbuf[0..4].copy_from_slice(&1u32.to_be_bytes()); - let n = s0 - .write_sctp( - &Bytes::from(sbuf.clone()), - PayloadProtocolIdentifier::Binary, - ) - .await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - - //log::debug!("flush_buffers"); - flush_buffers(&br, &a0, &a1).await; - - let mut buf = vec![0u8; 2000]; - - //log::debug!("read_sctp"); - let (n, ppi) = s1.read_sctp(&mut buf).await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi"); - assert_eq!( - u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]), - 1, - "unexpected received data" - ); - - //log::debug!("process"); - br.process().await; - - { - let q = s0.reassembly_queue.lock().await; - assert!(!q.is_readable(), "should no longer be readable"); - } - - close_association_pair(&br, a0, a1).await; - - Ok(()) -} - -//use std::io::Write; - -// NB: This is ignored on Windows and macOS due to flakiness with timing/IO interactions. -// TODO: Refactor this and other tests that are disabled for similar reason to not have such issues -#[cfg(not(any(target_os = "macos", target_os = "windows")))] -#[tokio::test] -async fn test_assoc_unreliable_rexmit_unordered_fragment() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - const SI: u16 = 1; - let mut sbuf = vec![0u8; 2000]; - for i in 0..sbuf.len() { - sbuf[i] = (i & 0xff) as u8; - } - - let (br, ca, cb) = Bridge::new(0, None, None); - - let (a0, mut a1) = - create_new_association_pair(&br, Arc::new(ca), Arc::new(cb), AckMode::NoDelay, 0).await?; - - let (s0, s1) = establish_session_pair(&br, &a0, &mut a1, SI).await?; - - // When we set the reliability value to 0 [times], then it will cause - // the chunk to be abandoned immediately after the first transmission. 
- s0.set_reliability_params(true, ReliabilityType::Rexmit, 0); - s1.set_reliability_params(true, ReliabilityType::Rexmit, 0); // doesn't matter - - sbuf[0..4].copy_from_slice(&0u32.to_be_bytes()); - let n = s0 - .write_sctp( - &Bytes::from(sbuf.clone()), - PayloadProtocolIdentifier::Binary, - ) - .await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - - sbuf[0..4].copy_from_slice(&1u32.to_be_bytes()); - let n = s0 - .write_sctp( - &Bytes::from(sbuf.clone()), - PayloadProtocolIdentifier::Binary, - ) - .await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - - //log::debug!("flush_buffers"); - tokio::time::sleep(Duration::from_millis(10)).await; - br.drop_offset(0, 0, 2).await; // drop the second fragment of the first chunk (second chunk should be sacked) - flush_buffers(&br, &a0, &a1).await; - - let mut buf = vec![0u8; 2000]; - - //log::debug!("read_sctp"); - let (n, ppi) = s1.read_sctp(&mut buf).await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi"); - assert_eq!( - u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]), - 1, - "unexpected received data" - ); - - //log::debug!("process"); - br.process().await; - - { - let q = s0.reassembly_queue.lock().await; - assert!(!q.is_readable(), "should no longer be readable"); - assert_eq!( - q.unordered.len(), - 0, - "should be nothing in the unordered queue" - ); - assert_eq!( - q.unordered_chunks.len(), - 0, - "should be nothing in the unorderedChunks list" - ); - } - - close_association_pair(&br, a0, a1).await; - - Ok(()) -} - -//use std::io::Write; - -#[tokio::test] -async fn test_assoc_unreliable_rexmit_timed_ordered() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - const SI: u16 = 3; - let mut sbuf = vec![0u8; 1000]; - for i in 0..sbuf.len() { - sbuf[i] = (i & 0xff) as u8; - } - - let (br, ca, cb) = Bridge::new(0, None, None); - - let (a0, mut a1) = - create_new_association_pair(&br, Arc::new(ca), Arc::new(cb), AckMode::NoDelay, 0).await?; - - let (s0, s1) = establish_session_pair(&br, &a0, &mut a1, SI).await?; - - // When we set the reliability value to 0 [times], then it will cause - // the chunk to be abandoned immediately after the first transmission. 
- s0.set_reliability_params(false, ReliabilityType::Timed, 0); - s1.set_reliability_params(false, ReliabilityType::Timed, 0); // doesn't matter - - br.drop_next_nwrites(0, 1); // drop the first packet (second one should be sacked) - - sbuf[0..4].copy_from_slice(&0u32.to_be_bytes()); - let n = s0 - .write_sctp( - &Bytes::from(sbuf.clone()), - PayloadProtocolIdentifier::Binary, - ) - .await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - - sbuf[0..4].copy_from_slice(&1u32.to_be_bytes()); - let n = s0 - .write_sctp( - &Bytes::from(sbuf.clone()), - PayloadProtocolIdentifier::Binary, - ) - .await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - - //log::debug!("flush_buffers"); - flush_buffers(&br, &a0, &a1).await; - - let mut buf = vec![0u8; 2000]; - - //log::debug!("read_sctp"); - let (n, ppi) = s1.read_sctp(&mut buf).await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi"); - assert_eq!( - u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]), - 1, - "unexpected received data" - ); - - //log::debug!("process"); - br.process().await; - - { - let q = s0.reassembly_queue.lock().await; - assert!(!q.is_readable(), "should no longer be readable"); - } - - close_association_pair(&br, a0, a1).await; - - Ok(()) -} - -//use std::io::Write; - -#[tokio::test] -async fn test_assoc_unreliable_rexmit_timed_unordered() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - const SI: u16 = 3; - let mut sbuf = vec![0u8; 1000]; - for i in 0..sbuf.len() { - sbuf[i] = (i & 0xff) as u8; - } - - let (br, ca, cb) = Bridge::new(0, None, None); - - let (a0, mut a1) = - create_new_association_pair(&br, Arc::new(ca), Arc::new(cb), AckMode::NoDelay, 0).await?; - - let (s0, s1) = establish_session_pair(&br, &a0, &mut a1, SI).await?; - - // When we set the reliability value to 0 [times], then it will cause - // the chunk to be abandoned immediately after the first transmission. 
- s0.set_reliability_params(true, ReliabilityType::Timed, 0); - s1.set_reliability_params(true, ReliabilityType::Timed, 0); // doesn't matter - - br.drop_next_nwrites(0, 1); // drop the first packet (second one should be sacked) - - sbuf[0..4].copy_from_slice(&0u32.to_be_bytes()); - let n = s0 - .write_sctp( - &Bytes::from(sbuf.clone()), - PayloadProtocolIdentifier::Binary, - ) - .await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - - sbuf[0..4].copy_from_slice(&1u32.to_be_bytes()); - let n = s0 - .write_sctp( - &Bytes::from(sbuf.clone()), - PayloadProtocolIdentifier::Binary, - ) - .await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - - //log::debug!("flush_buffers"); - flush_buffers(&br, &a0, &a1).await; - - let mut buf = vec![0u8; 2000]; - - //log::debug!("read_sctp"); - let (n, ppi) = s1.read_sctp(&mut buf).await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi"); - assert_eq!( - u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]), - 1, - "unexpected received data" - ); - - //log::debug!("process"); - br.process().await; - - { - let q = s0.reassembly_queue.lock().await; - assert!(!q.is_readable(), "should no longer be readable"); - assert_eq!( - q.unordered.len(), - 0, - "should be nothing in the unordered queue" - ); - assert_eq!( - q.unordered_chunks.len(), - 0, - "should be nothing in the unorderedChunks list" - ); - } - - close_association_pair(&br, a0, a1).await; - - Ok(()) -} - -//TODO: TestAssocT1InitTimer -//TODO: TestAssocT1CookieTimer -//TODO: TestAssocT3RtxTimer - -//use std::io::Write; - -// 1) Send 4 packets. drop the first one. -// 2) Last 3 packet will be received, which triggers fast-retransmission -// 3) The first one is retransmitted, which makes s1 readable -// Above should be done before RTO occurs (fast recovery) -#[tokio::test] -async fn test_assoc_congestion_control_fast_retransmission() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - const SI: u16 = 6; - let mut sbuf = vec![0u8; 1000]; - for i in 0..sbuf.len() { - sbuf[i] = (i & 0xff) as u8; - } - - let (br, ca, cb) = Bridge::new(0, None, None); - - let (a0, mut a1) = - create_new_association_pair(&br, Arc::new(ca), Arc::new(cb), AckMode::Normal, 0).await?; - - let (s0, s1) = establish_session_pair(&br, &a0, &mut a1, SI).await?; - - br.drop_next_nwrites(0, 1); // drop the first packet (second one should be sacked) - - for i in 0..4u32 { - sbuf[0..4].copy_from_slice(&i.to_be_bytes()); - let n = s0 - .write_sctp( - &Bytes::from(sbuf.clone()), - PayloadProtocolIdentifier::Binary, - ) - .await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - } - - // process packets for 500 msec, assuming that the fast retrans/recover - // should complete within 500 msec. 
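// This fast-retransmission test leans on the RFC 4960 sec 7.2.4 rule: a
// missing chunk is fast-retransmitted once it has accumulated three miss
// indications from incoming SACKs, without waiting for the T3-rtx timer. A
// self-contained sketch of that counter (illustrative only, not the crate's
// bookkeeping):
struct MissIndications {
    count: u8,
}

impl MissIndications {
    // Returns true exactly once, when the third miss indication arrives.
    fn on_reported_missing(&mut self) -> bool {
        self.count = self.count.saturating_add(1);
        self.count == 3
    }
}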
- for _ in 0..50 { - br.tick().await; - tokio::time::sleep(Duration::from_millis(10)).await; - } - - let mut buf = vec![0u8; 3000]; - - // Try to read all 4 packets - for i in 0..4 { - { - let q = s1.reassembly_queue.lock().await; - assert!(q.is_readable(), "should be readable"); - } - - let (n, ppi) = s1.read_sctp(&mut buf).await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi"); - assert_eq!( - u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]), - i, - "unexpected received data" - ); - } - - //br.process().await; - - { - let a = a0.association_internal.lock().await; - let b = a1.association_internal.lock().await; - assert!(!a.in_fast_recovery, "should not be in fast-recovery"); - - log::debug!("nDATAs : {}", b.stats.get_num_datas()); - log::debug!("nSACKs : {}", a.stats.get_num_sacks()); - log::debug!("nAckTimeouts: {}", b.stats.get_num_ack_timeouts()); - log::debug!("nFastRetrans: {}", a.stats.get_num_fast_retrans()); - - assert_eq!(a.stats.get_num_fast_retrans(), 1, "should be 1"); - } - - close_association_pair(&br, a0, a1).await; - - Ok(()) -} - -//use std::io::Write; - -#[tokio::test] -async fn test_assoc_congestion_control_congestion_avoidance() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - const MAX_RECEIVE_BUFFER_SIZE: u32 = 64 * 1024; - const SI: u16 = 6; - const N_PACKETS_TO_SEND: u32 = 2000; - - let mut sbuf = vec![0u8; 1000]; - for i in 0..sbuf.len() { - sbuf[i] = (i & 0xff) as u8; - } - - let (br, ca, cb) = Bridge::new(0, None, None); - - let (a0, mut a1) = create_new_association_pair( - &br, - Arc::new(ca), - Arc::new(cb), - AckMode::Normal, - MAX_RECEIVE_BUFFER_SIZE, - ) - .await?; - - let (s0, s1) = establish_session_pair(&br, &a0, &mut a1, SI).await?; - - { - let a = a0.association_internal.lock().await; - let b = a1.association_internal.lock().await; - a.stats.reset(); - b.stats.reset(); - } - - for i in 0..N_PACKETS_TO_SEND { - sbuf[0..4].copy_from_slice(&i.to_be_bytes()); - let n = s0 - .write_sctp( - &Bytes::from(sbuf.clone()), - PayloadProtocolIdentifier::Binary, - ) - .await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - } - - let mut rbuf = vec![0u8; 3000]; - - // Repeat calling br.Tick() until the buffered amount becomes 0 - let mut n_packets_received = 0u32; - while s0.buffered_amount() > 0 && n_packets_received < N_PACKETS_TO_SEND { - loop { - let n = br.tick().await; - if n == 0 { - break; - } - } - - loop { - let readable = { - let q = s1.reassembly_queue.lock().await; - q.is_readable() - }; - if !readable { - break; - } - let (n, ppi) = s1.read_sctp(&mut rbuf).await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - assert_eq!( - n_packets_received, - u32::from_be_bytes([rbuf[0], rbuf[1], rbuf[2], rbuf[3]]), - "unexpected length of received data" - ); - assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi"); - - n_packets_received += 1; - } - } - - br.process().await; - - assert_eq!( - n_packets_received, N_PACKETS_TO_SEND, - "unexpected num of packets received" - ); - - { - let a = a0.association_internal.lock().await; - let b = a1.association_internal.lock().await; - - assert!(!a.in_fast_recovery, "should not be in 
fast-recovery"); - assert!( - a.cwnd > a.ssthresh, - "should be in congestion avoidance mode" - ); - assert!( - a.ssthresh >= MAX_RECEIVE_BUFFER_SIZE, - "{} should not be less than the initial size of 128KB {}", - a.ssthresh, - MAX_RECEIVE_BUFFER_SIZE - ); - - assert_eq!( - 0, - s1.get_num_bytes_in_reassembly_queue().await, - "reassembly queue should be empty" - ); - - log::debug!("nDATAs : {}", b.stats.get_num_datas()); - log::debug!("nSACKs : {}", a.stats.get_num_sacks()); - log::debug!("nT3Timeouts: {}", a.stats.get_num_t3timeouts()); - - assert_eq!( - b.stats.get_num_datas(), - N_PACKETS_TO_SEND as u64, - "packet count mismatch" - ); - assert!( - a.stats.get_num_sacks() <= N_PACKETS_TO_SEND as u64 / 2, - "too many sacks" - ); - assert_eq!(a.stats.get_num_t3timeouts(), 0, "should be no retransmit"); - } - - close_association_pair(&br, a0, a1).await; - - Ok(()) -} - -//use std::io::Write; - -#[tokio::test] -async fn test_assoc_congestion_control_slow_reader() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - const MAX_RECEIVE_BUFFER_SIZE: u32 = 64 * 1024; - const SI: u16 = 6; - const N_PACKETS_TO_SEND: u32 = 130; - - let mut sbuf = vec![0u8; 1000]; - for i in 0..sbuf.len() { - sbuf[i] = (i & 0xff) as u8; - } - - let (br, ca, cb) = Bridge::new(0, None, None); - - let (a0, mut a1) = create_new_association_pair( - &br, - Arc::new(ca), - Arc::new(cb), - AckMode::Normal, - MAX_RECEIVE_BUFFER_SIZE, - ) - .await?; - - let (s0, s1) = establish_session_pair(&br, &a0, &mut a1, SI).await?; - - for i in 0..N_PACKETS_TO_SEND { - sbuf[0..4].copy_from_slice(&i.to_be_bytes()); - let n = s0 - .write_sctp( - &Bytes::from(sbuf.clone()), - PayloadProtocolIdentifier::Binary, - ) - .await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - } - - let mut rbuf = vec![0u8; 3000]; - - // 1. First forward packets to receiver until rwnd becomes 0 - // 2. Wait until the sender's cwnd becomes 1*MTU (RTO occurred) - // 3. 
Stat reading a1's data - let mut n_packets_received = 0u32; - let mut has_rtoed = false; - while s0.buffered_amount() > 0 && n_packets_received < N_PACKETS_TO_SEND { - loop { - let n = br.tick().await; - if n == 0 { - break; - } - } - - if !has_rtoed { - let a = a0.association_internal.lock().await; - let b = a1.association_internal.lock().await; - - let rwnd = b.get_my_receiver_window_credit().await; - let cwnd = a.cwnd; - if cwnd > a.mtu || rwnd > 0 { - // Do not read until a1.getMyReceiverWindowCredit() becomes zero - continue; - } - - has_rtoed = true; - } - - loop { - let readable = { - let q = s1.reassembly_queue.lock().await; - q.is_readable() - }; - if !readable { - break; - } - let (n, ppi) = s1.read_sctp(&mut rbuf).await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - assert_eq!( - n_packets_received, - u32::from_be_bytes([rbuf[0], rbuf[1], rbuf[2], rbuf[3]]), - "unexpected length of received data" - ); - assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi"); - - n_packets_received += 1; - } - - tokio::time::sleep(Duration::from_millis(4)).await; - } - - br.process().await; - - assert_eq!( - n_packets_received, N_PACKETS_TO_SEND, - "unexpected num of packets received" - ); - assert_eq!( - s1.get_num_bytes_in_reassembly_queue().await, - 0, - "reassembly queue should be empty" - ); - - { - let a = a0.association_internal.lock().await; - let b = a1.association_internal.lock().await; - - log::debug!("nDATAs : {}", b.stats.get_num_datas()); - log::debug!("nSACKs : {}", a.stats.get_num_sacks()); - log::debug!("nAckTimeouts: {}", b.stats.get_num_ack_timeouts()); - } - - close_association_pair(&br, a0, a1).await; - - Ok(()) -} - -/*FIXME -use std::io::Write; - -#[tokio::test] -async fn test_assoc_delayed_ack() -> Result<()> { - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init(); - - const SI: u16 = 6; - let mut sbuf = vec![0u8; 1000]; - let mut rbuf = vec![0u8; 1500]; - for i in 0..sbuf.len() { - sbuf[i] = (i & 0xff) as u8; - } - - let (br, ca, cb) = Bridge::new(0, None, None); - - let (a0, mut a1) = - create_new_association_pair(&br, Arc::new(ca), Arc::new(cb), AckMode::AlwaysDelay, 0) - .await?; - - let (s0, s1) = establish_session_pair(&br, &a0, &mut a1, SI).await?; - - { - let a = a0.association_internal.lock().await; - let b = a1.association_internal.lock().await; - a.stats.reset(); - b.stats.reset(); - } - - let n = s0 - .write_sctp( - &Bytes::from(sbuf.clone()), - PayloadProtocolIdentifier::Binary, - )?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - - // Repeat calling br.Tick() until the buffered amount becomes 0 - let since = SystemTime::now(); - let mut n_packets_received = 0; - while s0.buffered_amount() > 0 { - loop { - let n = br.tick().await; - if n == 0 { - break; - } - } - - loop { - let readable = { - let q = s1.reassembly_queue.lock().await; - q.is_readable() - }; - if !readable { - break; - } - let (n, ppi) = s1.read_sctp(&mut rbuf).await?; - assert_eq!(n, sbuf.len(), "unexpected length of received data"); - assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi"); - - n_packets_received += 1; - } - } - let delay = (SystemTime::now().duration_since(since).unwrap().as_millis() as f64) / 1000.0; - log::debug!("received in {} seconds", delay); - 
assert!(delay >= 0.2, "should be >= 200msec"); - - br.process().await; - - assert_eq!(n_packets_received, 1, "unexpected num of packets received"); - assert_eq!( - s1.get_num_bytes_in_reassembly_queue().await, - 0, - "reassembly queue should be empty" - ); - - { - let a = a0.association_internal.lock().await; - let b = a1.association_internal.lock().await; - - log::debug!("nDATAs : {}", b.stats.get_num_datas()); - log::debug!("nSACKs : {}", a.stats.get_num_sacks()); - log::debug!("nAckTimeouts: {}", b.stats.get_num_ack_timeouts()); - - assert_eq!(b.stats.get_num_datas(), 1, "DATA chunk count mismatch"); - assert_eq!( - a.stats.get_num_sacks(), - b.stats.get_num_datas(), - "sack count should be equal to the number of data chunks" - ); - assert_eq!( - b.stats.get_num_ack_timeouts(), - 1, - "ackTimeout count mismatch" - ); - assert_eq!(a.stats.get_num_t3timeouts(), 0, "should be no retransmit"); - } - - close_association_pair(&br, a0, a1).await; - - Ok(()) -} -*/ - -//use std::io::Write; - -#[tokio::test] -async fn test_assoc_reset_close_one_way() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - const SI: u16 = 1; - static MSG: Bytes = Bytes::from_static(b"ABC"); - - let (br, ca, cb) = Bridge::new(0, None, None); - - let (a0, mut a1) = - create_new_association_pair(&br, Arc::new(ca), Arc::new(cb), AckMode::NoDelay, 0).await?; - - let (s0, s1) = establish_session_pair(&br, &a0, &mut a1, SI).await?; - - { - let a = a0.association_internal.lock().await; - assert_eq!(0, a.buffered_amount(), "incorrect bufferedAmount"); - } - - let n = s0 - .write_sctp(&MSG, PayloadProtocolIdentifier::Binary) - .await?; - assert_eq!(n, MSG.len(), "unexpected length of received data"); - { - let a = a0.association_internal.lock().await; - assert_eq!(a.buffered_amount(), MSG.len(), "incorrect bufferedAmount"); - } - - log::debug!("s0.shutdown"); - s0.shutdown(Shutdown::Both).await?; // send reset - - let (done_ch_tx, mut done_ch_rx) = mpsc::channel(1); - let mut buf = vec![0u8; 32]; - - tokio::spawn(async move { - loop { - log::debug!("s1.read_sctp begin"); - match s1.read_sctp(&mut buf).await { - Ok((0, PayloadProtocolIdentifier::Unknown)) => { - log::debug!("s1.read_sctp EOF"); - let _ = done_ch_tx.send(Some(Error::ErrEof)).await; - break; - } - Ok((n, ppi)) => { - log::debug!("s1.read_sctp done with {:?}", &buf[..n]); - assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi"); - assert_eq!(n, MSG.len(), "unexpected length of received data"); - let _ = done_ch_tx.send(None).await; - } - Err(err) => { - log::debug!("s1.read_sctp err {:?}", err); - let _ = done_ch_tx.send(Some(err)).await; - break; - } - } - } - }); - - loop { - br.process().await; - - let timer = tokio::time::sleep(Duration::from_millis(10)); - tokio::pin!(timer); - - tokio::select! { - _ = timer.as_mut() =>{}, - result = done_ch_rx.recv() => { - log::debug!("s1. 
{:?}", result); - if let Some(err_opt) = result { - if err_opt.is_some() { - break; - } - } else { - break; - } - } - } - } - - close_association_pair(&br, a0, a1).await; - - Ok(()) -} - -//use std::io::Write; - -#[tokio::test] -async fn test_assoc_reset_close_both_ways() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - const SI: u16 = 1; - static MSG: Bytes = Bytes::from_static(b"ABC"); - - let (br, ca, cb) = Bridge::new(0, None, None); - - let (a0, mut a1) = - create_new_association_pair(&br, Arc::new(ca), Arc::new(cb), AckMode::NoDelay, 0).await?; - - let (s0, s1) = establish_session_pair(&br, &a0, &mut a1, SI).await?; - - { - let a = a0.association_internal.lock().await; - assert_eq!(0, a.buffered_amount(), "incorrect bufferedAmount"); - } - - let n = s0 - .write_sctp(&MSG, PayloadProtocolIdentifier::Binary) - .await?; - assert_eq!(n, MSG.len(), "unexpected length of received data"); - { - let a = a0.association_internal.lock().await; - assert_eq!(a.buffered_amount(), MSG.len(), "incorrect bufferedAmount"); - } - - log::debug!("s0.shutdown"); - s0.shutdown(Shutdown::Both).await?; // send reset - - let (done_ch_tx, mut done_ch_rx) = mpsc::channel(1); - let done_ch_tx = Arc::new(done_ch_tx); - - let done_ch_tx1 = Arc::clone(&done_ch_tx); - let ss1 = Arc::clone(&s1); - tokio::spawn(async move { - let mut buf = vec![0u8; 32]; - loop { - log::debug!("s1.read_sctp begin"); - match ss1.read_sctp(&mut buf).await { - Ok((0, PayloadProtocolIdentifier::Unknown)) => { - log::debug!("s1.read_sctp EOF"); - let _ = done_ch_tx1.send(Some(Error::ErrEof)).await; - break; - } - Ok((n, ppi)) => { - log::debug!("s1.read_sctp done with {:?}", &buf[..n]); - assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi"); - assert_eq!(n, MSG.len(), "unexpected length of received data"); - let _ = done_ch_tx1.send(None).await; - } - Err(err) => { - log::debug!("s1.read_sctp err {:?}", err); - let _ = done_ch_tx1.send(Some(err)).await; - break; - } - } - } - }); - - loop { - br.process().await; - - let timer = tokio::time::sleep(Duration::from_millis(10)); - tokio::pin!(timer); - - tokio::select! { - _ = timer.as_mut() =>{}, - result = done_ch_rx.recv() => { - log::debug!("s1. {:?}", result); - if let Some(err_opt) = result { - if err_opt.is_some() { - break; - } - } else { - break; - } - } - } - } - - log::debug!("s1.shutdown"); - s1.shutdown(Shutdown::Both).await?; // send reset - - let done_ch_tx0 = Arc::clone(&done_ch_tx); - tokio::spawn(async move { - let mut buf = vec![0u8; 32]; - - log::debug!("s.read_sctp begin"); - match s0.read_sctp(&mut buf).await { - Ok((0, PayloadProtocolIdentifier::Unknown)) => { - log::debug!("s0.read_sctp EOF"); - let _ = done_ch_tx0.send(Some(Error::ErrEof)).await; - } - Ok(_) => { - panic!("must be error"); - } - Err(err) => { - log::debug!("s0.read_sctp err {:?}", err); - let _ = done_ch_tx0.send(Some(err)).await; - } - } - }); - - loop { - br.process().await; - - let timer = tokio::time::sleep(Duration::from_millis(10)); - tokio::pin!(timer); - - tokio::select! { - _ = timer.as_mut() =>{}, - result = done_ch_rx.recv() => { - log::debug!("s0. 
{:?}", result); - if let Some(err_opt) = result { - if err_opt.is_some() { - break; - } else { - panic!("must be error"); - } - } else { - break; - } - } - } - } - - close_association_pair(&br, a0, a1).await; - - Ok(()) -} - -//use std::io::Write; - -#[tokio::test] -async fn test_assoc_abort() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - const SI: u16 = 1; - let (br, ca, cb) = Bridge::new(0, None, None); - - let (a0, mut a1) = - create_new_association_pair(&br, Arc::new(ca), Arc::new(cb), AckMode::NoDelay, 0).await?; - - let abort = ChunkAbort { - error_causes: vec![ErrorCauseProtocolViolation { - code: PROTOCOL_VIOLATION, - ..Default::default() - }], - }; - - let packet = { - let a = a0.association_internal.lock().await; - a.create_packet(vec![Box::new(abort)]).marshal()? - }; - - let (_s0, _s1) = establish_session_pair(&br, &a0, &mut a1, SI).await?; - - // Both associations are established - assert_eq!(a0.get_state(), AssociationState::Established); - assert_eq!(a1.get_state(), AssociationState::Established); - - let result = a0.net_conn.send(&packet).await; - assert!(result.is_ok(), "must be ok"); - - flush_buffers(&br, &a0, &a1).await; - - // There is a little delay before changing the state to closed - tokio::time::sleep(Duration::from_millis(10)).await; - - // The receiving association should be closed because it got an ABORT - assert_eq!(a0.get_state(), AssociationState::Established); - assert_eq!(a1.get_state(), AssociationState::Closed); - - close_association_pair(&br, a0, a1).await; - - Ok(()) -} - -struct FakeEchoConn { - wr_tx: Mutex>>, - rd_rx: Mutex>>, - bytes_sent: AtomicUsize, - bytes_received: AtomicUsize, -} - -impl FakeEchoConn { - fn type_erased() -> impl Conn { - Self::default() - } -} - -impl Default for FakeEchoConn { - fn default() -> Self { - let (wr_tx, rd_rx) = mpsc::channel(1); - FakeEchoConn { - wr_tx: Mutex::new(wr_tx), - rd_rx: Mutex::new(rd_rx), - bytes_sent: AtomicUsize::new(0), - bytes_received: AtomicUsize::new(0), - } - } -} - -type UResult = std::result::Result; - -#[async_trait] -impl Conn for FakeEchoConn { - async fn connect(&self, _addr: SocketAddr) -> UResult<()> { - Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) - } - - async fn recv(&self, b: &mut [u8]) -> UResult { - let mut rd_rx = self.rd_rx.lock().await; - let v = match rd_rx.recv().await { - Some(v) => v, - None => { - return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "Unexpected EOF").into()) - } - }; - let l = std::cmp::min(v.len(), b.len()); - b[..l].copy_from_slice(&v[..l]); - self.bytes_received.fetch_add(l, Ordering::SeqCst); - Ok(l) - } - - async fn recv_from(&self, _buf: &mut [u8]) -> UResult<(usize, SocketAddr)> { - Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) - } - - async fn send(&self, b: &[u8]) -> UResult { - let wr_tx = self.wr_tx.lock().await; - match wr_tx.send(b.to_vec()).await { - Ok(_) => {} - Err(err) => return Err(io::Error::new(io::ErrorKind::Other, err.to_string()).into()), - }; - self.bytes_sent.fetch_add(b.len(), Ordering::SeqCst); - Ok(b.len()) - } - - async fn send_to(&self, _buf: &[u8], _target: SocketAddr) -> UResult { - Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) - } - - fn local_addr(&self) -> UResult { - 
Err(io::Error::new(io::ErrorKind::AddrNotAvailable, "Addr Not Available").into()) - } - - fn remote_addr(&self) -> Option { - None - } - - async fn close(&self) -> UResult<()> { - Ok(()) - } - - fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) { - self - } -} - -//use std::io::Write; - -#[tokio::test] -async fn test_stats() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - let conn = Arc::new(FakeEchoConn::type_erased()); - let a = Association::client(Config { - net_conn: Arc::clone(&conn) as Arc, - max_receive_buffer_size: 0, - max_message_size: 0, - name: "client".to_owned(), - }) - .await?; - - if let Some(conn) = conn.as_any().downcast_ref::() { - assert_eq!( - conn.bytes_received.load(Ordering::SeqCst), - a.bytes_received() - ); - assert_eq!(conn.bytes_sent.load(Ordering::SeqCst), a.bytes_sent()); - } else { - panic!("must be FakeEchoConn"); - } - - Ok(()) -} - -async fn create_assocs() -> Result<(Association, Association)> { - let addr1 = SocketAddr::from_str("0.0.0.0:0").unwrap(); - let addr2 = SocketAddr::from_str("0.0.0.0:0").unwrap(); - - let udp1 = UdpSocket::bind(addr1).await.unwrap(); - let udp2 = UdpSocket::bind(addr2).await.unwrap(); - - udp1.connect(udp2.local_addr().unwrap()).await.unwrap(); - udp2.connect(udp1.local_addr().unwrap()).await.unwrap(); - - let (a1chan_tx, mut a1chan_rx) = mpsc::channel(1); - let (a2chan_tx, mut a2chan_rx) = mpsc::channel(1); - - tokio::spawn(async move { - let a = Association::client(Config { - net_conn: Arc::new(udp1), - max_receive_buffer_size: 0, - max_message_size: 0, - name: "client".to_owned(), - }) - .await?; - - let _ = a1chan_tx.send(a).await; - - Result::<()>::Ok(()) - }); - - tokio::spawn(async move { - let a = Association::server(Config { - net_conn: Arc::new(udp2), - max_receive_buffer_size: 0, - max_message_size: 0, - name: "server".to_owned(), - }) - .await?; - - let _ = a2chan_tx.send(a).await; - - Result::<()>::Ok(()) - }); - - let timer1 = tokio::time::sleep(Duration::from_secs(1)); - tokio::pin!(timer1); - let a1 = tokio::select! { - _ = timer1.as_mut() =>{ - panic!("timed out waiting for a1"); - }, - a1 = a1chan_rx.recv() => { - a1.unwrap() - } - }; - - let timer2 = tokio::time::sleep(Duration::from_secs(1)); - tokio::pin!(timer2); - let a2 = tokio::select! 
{ - _ = timer2.as_mut() =>{ - panic!("timed out waiting for a2"); - }, - a2 = a2chan_rx.recv() => { - a2.unwrap() - } - }; - - Ok((a1, a2)) -} - -//use std::io::Write; -//TODO: remove this conditional test -#[cfg(not(target_os = "windows"))] -#[tokio::test] -async fn test_association_shutdown() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - let (a1, a2) = create_assocs().await?; - - let s11 = a1.open_stream(1, PayloadProtocolIdentifier::String).await?; - let s21 = a2.open_stream(1, PayloadProtocolIdentifier::String).await?; - - let test_data = Bytes::from_static(b"test"); - - let n = s11.write(&test_data).await?; - assert_eq!(n, test_data.len()); - - let mut buf = vec![0u8; test_data.len()]; - let n = s21.read(&mut buf).await?; - assert_eq!(n, test_data.len()); - assert_eq!(&buf[0..n], &test_data); - - if let Ok(result) = tokio::time::timeout(Duration::from_secs(1), a1.shutdown()).await { - assert!(result.is_ok(), "shutdown should be ok"); - } else { - panic!("shutdown timeout"); - } - - { - let mut close_loop_ch_rx = a2.close_loop_ch_rx.lock().await; - - // Wait for close read loop channels to prevent flaky tests. - let timer2 = tokio::time::sleep(Duration::from_secs(1)); - tokio::pin!(timer2); - tokio::select! { - _ = timer2.as_mut() =>{ - panic!("timed out waiting for a2 read loop to close"); - }, - _ = close_loop_ch_rx.recv() => { - log::debug!("recv a2.close_loop_ch_rx"); - } - }; - } - Ok(()) -} - -//use std::io::Write; -//TODO: remove this conditional test -#[cfg(not(target_os = "windows"))] -#[tokio::test] -async fn test_association_shutdown_during_write() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - let (a1, a2) = create_assocs().await?; - - let s11 = a1.open_stream(1, PayloadProtocolIdentifier::String).await?; - let s21 = a2.open_stream(1, PayloadProtocolIdentifier::String).await?; - - let (writing_done_tx, mut writing_done_rx) = mpsc::channel::<()>(1); - let ss21 = Arc::clone(&s21); - tokio::spawn(async move { - let mut i = 0; - while ss21.write(&Bytes::from(vec![i])).await.is_ok() { - if i == 255 { - i = 0; - } else { - i += 1; - } - - if i % 100 == 0 { - tokio::time::sleep(Duration::from_millis(20)).await; - } - } - - drop(writing_done_tx); - }); - - let test_data = Bytes::from_static(b"test"); - - let n = s11.write(&test_data).await?; - assert_eq!(n, test_data.len()); - - let mut buf = vec![0u8; test_data.len()]; - let n = s21.read(&mut buf).await?; - assert_eq!(n, test_data.len()); - assert_eq!(&buf[0..n], &test_data); - - { - let mut close_loop_ch_rx = a1.close_loop_ch_rx.lock().await; - tokio::select! 
{ - res = tokio::time::timeout(Duration::from_secs(1), a1.shutdown()) => { - if let Ok(result) = res { - assert!(result.is_ok(), "shutdown should be ok"); - } else { - panic!("shutdown timeout"); - } - } - _ = writing_done_rx.recv() => { - log::debug!("writing_done_rx"); - let result = close_loop_ch_rx.recv().await; - log::debug!("a1.close_loop_ch_rx.recv: {:?}", result); - }, - }; - } - - { - let mut close_loop_ch_rx = a2.close_loop_ch_rx.lock().await; - // Wait for close read loop channels to prevent flaky tests. - let timer2 = tokio::time::sleep(Duration::from_secs(1)); - tokio::pin!(timer2); - tokio::select! { - _ = timer2.as_mut() =>{ - panic!("timed out waiting for a2 read loop to close"); - }, - _ = close_loop_ch_rx.recv() => { - log::debug!("recv a2.close_loop_ch_rx"); - } - }; - } - - Ok(()) -} - -//use std::io::Write; - -#[tokio::test] -async fn test_association_handle_packet_before_init() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - let tests = vec![ - ( - "InitAck", - Packet { - source_port: 1, - destination_port: 1, - verification_tag: 0, - chunks: vec![Box::new(ChunkInit { - is_ack: true, - initiate_tag: 1, - num_inbound_streams: 1, - num_outbound_streams: 1, - advertised_receiver_window_credit: 1500, - ..Default::default() - })], - }, - ), - ( - "Abort", - Packet { - source_port: 1, - destination_port: 1, - verification_tag: 0, - chunks: vec![Box::::default()], - }, - ), - ( - "CoockeEcho", - Packet { - source_port: 1, - destination_port: 1, - verification_tag: 0, - chunks: vec![Box::::default()], - }, - ), - ( - "HeartBeat", - Packet { - source_port: 1, - destination_port: 1, - verification_tag: 0, - chunks: vec![Box::::default()], - }, - ), - ( - "PayloadData", - Packet { - source_port: 1, - destination_port: 1, - verification_tag: 0, - chunks: vec![Box::::default()], - }, - ), - ( - "Sack", - Packet { - source_port: 1, - destination_port: 1, - verification_tag: 0, - chunks: vec![Box::new(ChunkSelectiveAck { - cumulative_tsn_ack: 1000, - advertised_receiver_window_credit: 1500, - gap_ack_blocks: vec![GapAckBlock { - start: 100, - end: 200, - }], - ..Default::default() - })], - }, - ), - ( - "Reconfig", - Packet { - source_port: 1, - destination_port: 1, - verification_tag: 0, - chunks: vec![Box::new(ChunkReconfig { - param_a: Some(Box::::default()), - param_b: Some(Box::::default()), - })], - }, - ), - ( - "ForwardTSN", - Packet { - source_port: 1, - destination_port: 1, - verification_tag: 0, - chunks: vec![Box::new(ChunkForwardTsn { - new_cumulative_tsn: 100, - ..Default::default() - })], - }, - ), - ( - "Error", - Packet { - source_port: 1, - destination_port: 1, - verification_tag: 0, - chunks: vec![Box::::default()], - }, - ), - ( - "Shutdown", - Packet { - source_port: 1, - destination_port: 1, - verification_tag: 0, - chunks: vec![Box::::default()], - }, - ), - ( - "ShutdownAck", - Packet { - source_port: 1, - destination_port: 1, - verification_tag: 0, - chunks: vec![Box::::default()], - }, - ), - ( - "ShutdownComplete", - Packet { - source_port: 1, - destination_port: 1, - verification_tag: 0, - chunks: vec![Box::::default()], - }, - ), - ]; - - for (name, packet) in tests { - log::debug!("testing {}", name); - - let (a_conn, charlie_conn) = pipe(); - - let (a, _) = 
Association::new( - Config { - net_conn: Arc::new(a_conn), - max_message_size: 0, - max_receive_buffer_size: 0, - name: "client".to_owned(), - }, - true, - ) - .await - .unwrap(); - - let packet = packet.marshal()?; - let result = charlie_conn.send(&packet).await; - assert!(result.is_ok(), "{name} charlie_conn.send should be ok"); - - // Should not panic. - tokio::time::sleep(Duration::from_millis(100)).await; - - a.close().await.unwrap(); - } - - Ok(()) -} diff --git a/sctp/src/association/mod.rs b/sctp/src/association/mod.rs deleted file mode 100644 index cec94accb..000000000 --- a/sctp/src/association/mod.rs +++ /dev/null @@ -1,626 +0,0 @@ -#[cfg(test)] -mod association_test; - -mod association_internal; -mod association_stats; - -use std::collections::{HashMap, VecDeque}; -use std::fmt; -use std::sync::atomic::Ordering; -use std::sync::Arc; -use std::time::SystemTime; - -use association_internal::*; -use association_stats::*; -use bytes::{Bytes, BytesMut}; -use portable_atomic::{AtomicBool, AtomicU32, AtomicU8, AtomicUsize}; -use rand::random; -use tokio::sync::{broadcast, mpsc, Mutex}; -use util::Conn; - -use crate::chunk::chunk_abort::ChunkAbort; -use crate::chunk::chunk_cookie_ack::ChunkCookieAck; -use crate::chunk::chunk_cookie_echo::ChunkCookieEcho; -use crate::chunk::chunk_error::ChunkError; -use crate::chunk::chunk_forward_tsn::{ChunkForwardTsn, ChunkForwardTsnStream}; -use crate::chunk::chunk_heartbeat::ChunkHeartbeat; -use crate::chunk::chunk_heartbeat_ack::ChunkHeartbeatAck; -use crate::chunk::chunk_init::ChunkInit; -use crate::chunk::chunk_payload_data::{ChunkPayloadData, PayloadProtocolIdentifier}; -use crate::chunk::chunk_reconfig::ChunkReconfig; -use crate::chunk::chunk_selective_ack::ChunkSelectiveAck; -use crate::chunk::chunk_shutdown::ChunkShutdown; -use crate::chunk::chunk_shutdown_ack::ChunkShutdownAck; -use crate::chunk::chunk_shutdown_complete::ChunkShutdownComplete; -use crate::chunk::chunk_type::*; -use crate::chunk::Chunk; -use crate::error::{Error, Result}; -use crate::error_cause::*; -use crate::packet::Packet; -use crate::param::param_heartbeat_info::ParamHeartbeatInfo; -use crate::param::param_outgoing_reset_request::ParamOutgoingResetRequest; -use crate::param::param_reconfig_response::{ParamReconfigResponse, ReconfigResult}; -use crate::param::param_state_cookie::ParamStateCookie; -use crate::param::param_supported_extensions::ParamSupportedExtensions; -use crate::param::Param; -use crate::queue::control_queue::ControlQueue; -use crate::queue::payload_queue::PayloadQueue; -use crate::queue::pending_queue::PendingQueue; -use crate::stream::*; -use crate::timer::ack_timer::*; -use crate::timer::rtx_timer::*; -use crate::util::*; - -pub(crate) const RECEIVE_MTU: usize = 8192; -/// MTU for inbound packet (from DTLS) -pub(crate) const INITIAL_MTU: u32 = 1228; -/// initial MTU for outgoing packets (to DTLS) -pub(crate) const INITIAL_RECV_BUF_SIZE: u32 = 1024 * 1024; -pub(crate) const COMMON_HEADER_SIZE: u32 = 12; -pub(crate) const DATA_CHUNK_HEADER_SIZE: u32 = 16; -pub(crate) const DEFAULT_MAX_MESSAGE_SIZE: u32 = 65536; - -/// other constants -pub(crate) const ACCEPT_CH_SIZE: usize = 16; - -/// association state enums -#[derive(Debug, Copy, Clone, PartialEq)] -pub(crate) enum AssociationState { - Closed = 0, - CookieWait = 1, - CookieEchoed = 2, - Established = 3, - ShutdownAckSent = 4, - ShutdownPending = 5, - ShutdownReceived = 6, - ShutdownSent = 7, -} - -impl From for AssociationState { - fn from(v: u8) -> AssociationState { - match v { - 1 => 
AssociationState::CookieWait, - 2 => AssociationState::CookieEchoed, - 3 => AssociationState::Established, - 4 => AssociationState::ShutdownAckSent, - 5 => AssociationState::ShutdownPending, - 6 => AssociationState::ShutdownReceived, - 7 => AssociationState::ShutdownSent, - _ => AssociationState::Closed, - } - } -} - -impl fmt::Display for AssociationState { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - AssociationState::Closed => "Closed", - AssociationState::CookieWait => "CookieWait", - AssociationState::CookieEchoed => "CookieEchoed", - AssociationState::Established => "Established", - AssociationState::ShutdownPending => "ShutdownPending", - AssociationState::ShutdownSent => "ShutdownSent", - AssociationState::ShutdownReceived => "ShutdownReceived", - AssociationState::ShutdownAckSent => "ShutdownAckSent", - }; - write!(f, "{s}") - } -} - -/// retransmission timer IDs -#[derive(Default, Debug, Copy, Clone, PartialEq)] -pub(crate) enum RtxTimerId { - #[default] - T1Init, - T1Cookie, - T2Shutdown, - T3RTX, - Reconfig, -} - -impl fmt::Display for RtxTimerId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - RtxTimerId::T1Init => "T1Init", - RtxTimerId::T1Cookie => "T1Cookie", - RtxTimerId::T2Shutdown => "T2Shutdown", - RtxTimerId::T3RTX => "T3RTX", - RtxTimerId::Reconfig => "Reconfig", - }; - write!(f, "{s}") - } -} - -/// ack mode (for testing) -#[derive(Default, Debug, Copy, Clone, PartialEq)] -pub(crate) enum AckMode { - #[default] - Normal, - NoDelay, - AlwaysDelay, -} - -impl fmt::Display for AckMode { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - AckMode::Normal => "Normal", - AckMode::NoDelay => "NoDelay", - AckMode::AlwaysDelay => "AlwaysDelay", - }; - write!(f, "{s}") - } -} - -/// ack transmission state -#[derive(Default, Debug, Copy, Clone, PartialEq)] -pub(crate) enum AckState { - #[default] - Idle, // ack timer is off - Immediate, // will send ack immediately - Delay, // ack timer is on (ack is being delayed) -} - -impl fmt::Display for AckState { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - AckState::Idle => "Idle", - AckState::Immediate => "Immediate", - AckState::Delay => "Delay", - }; - write!(f, "{s}") - } -} - -/// Config collects the arguments to create_association construction into -/// a single structure -pub struct Config { - pub net_conn: Arc, - pub max_receive_buffer_size: u32, - pub max_message_size: u32, - pub name: String, -} - -///Association represents an SCTP association -///13.2. Parameters Necessary per Association (i.e., the TCB) -///Peer : Tag value to be sent in every packet and is received -///Verification: in the INIT or INIT ACK chunk. -///Tag : -/// -///My : Tag expected in every inbound packet and sent in the -///Verification: INIT or INIT ACK chunk. -/// -///Tag : -///State : A state variable indicating what state the association -/// : is in, i.e., COOKIE-WAIT, COOKIE-ECHOED, ESTABLISHED, -/// : SHUTDOWN-PENDING, SHUTDOWN-SENT, SHUTDOWN-RECEIVED, -/// : SHUTDOWN-ACK-SENT. -/// -/// No Closed state is illustrated since if a -/// association is Closed its TCB SHOULD be removed. 
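A minimal usage sketch of the public API declared below, assuming an already-connected `Conn` transport (`my_conn` is a placeholder name) and mirroring the `Association::client` / `open_stream` / `shutdown` calls exercised by the tests earlier in this diff:

use std::sync::Arc;
use bytes::Bytes;
use util::Conn;
use crate::association::{Association, Config};
use crate::chunk::chunk_payload_data::PayloadProtocolIdentifier;
use crate::error::Result;

async fn run_client(my_conn: Arc<dyn Conn + Send + Sync>) -> Result<()> {
    // The handshake completes inside Association::client; the zero sizes are
    // what the tests above pass as well.
    let assoc = Association::client(Config {
        net_conn: my_conn,
        max_receive_buffer_size: 0,
        max_message_size: 0,
        name: "client".to_owned(),
    })
    .await?;

    // Open stream 1, send a small message, then run the graceful shutdown.
    let stream = assoc.open_stream(1, PayloadProtocolIdentifier::Binary).await?;
    stream.write(&Bytes::from_static(b"hello")).await?;
    assoc.shutdown().await?;
    Ok(())
}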
-pub struct Association { - name: String, - state: Arc, - max_message_size: Arc, - inflight_queue_length: Arc, - will_send_shutdown: Arc, - awake_write_loop_ch: Arc>, - close_loop_ch_rx: Mutex>, - accept_ch_rx: Mutex>>, - net_conn: Arc, - bytes_received: Arc, - bytes_sent: Arc, - - pub(crate) association_internal: Arc>, -} - -impl Association { - /// server accepts a SCTP stream over a conn - pub async fn server(config: Config) -> Result { - let (a, mut handshake_completed_ch_rx) = Association::new(config, false).await?; - - if let Some(err_opt) = handshake_completed_ch_rx.recv().await { - if let Some(err) = err_opt { - Err(err) - } else { - Ok(a) - } - } else { - Err(Error::ErrAssociationHandshakeClosed) - } - } - - /// Client opens a SCTP stream over a conn - pub async fn client(config: Config) -> Result { - let (a, mut handshake_completed_ch_rx) = Association::new(config, true).await?; - - if let Some(err_opt) = handshake_completed_ch_rx.recv().await { - if let Some(err) = err_opt { - Err(err) - } else { - Ok(a) - } - } else { - Err(Error::ErrAssociationHandshakeClosed) - } - } - - /// Shutdown initiates the shutdown sequence. The method blocks until the - /// shutdown sequence is completed and the connection is closed, or until the - /// passed context is done, in which case the context's error is returned. - pub async fn shutdown(&self) -> Result<()> { - log::debug!("[{}] closing association..", self.name); - - let state = self.get_state(); - if state != AssociationState::Established { - return Err(Error::ErrShutdownNonEstablished); - } - - // Attempt a graceful shutdown. - self.set_state(AssociationState::ShutdownPending); - - if self.inflight_queue_length.load(Ordering::SeqCst) == 0 { - // No more outstanding, send shutdown. - self.will_send_shutdown.store(true, Ordering::SeqCst); - let _ = self.awake_write_loop_ch.try_send(()); - self.set_state(AssociationState::ShutdownSent); - } - - { - let mut close_loop_ch_rx = self.close_loop_ch_rx.lock().await; - let _ = close_loop_ch_rx.recv().await; - } - - Ok(()) - } - - /// Close ends the SCTP Association and cleans up any state - pub async fn close(&self) -> Result<()> { - log::debug!("[{}] closing association..", self.name); - - let _ = self.net_conn.close().await; - - let mut ai = self.association_internal.lock().await; - ai.close().await - } - - async fn new(config: Config, is_client: bool) -> Result<(Self, mpsc::Receiver>)> { - let net_conn = Arc::clone(&config.net_conn); - - let (awake_write_loop_ch_tx, awake_write_loop_ch_rx) = mpsc::channel(1); - let (accept_ch_tx, accept_ch_rx) = mpsc::channel(ACCEPT_CH_SIZE); - let (handshake_completed_ch_tx, handshake_completed_ch_rx) = mpsc::channel(1); - let (close_loop_ch_tx, close_loop_ch_rx) = broadcast::channel(1); - let (close_loop_ch_rx1, close_loop_ch_rx2) = - (close_loop_ch_tx.subscribe(), close_loop_ch_tx.subscribe()); - let awake_write_loop_ch = Arc::new(awake_write_loop_ch_tx); - - let ai = AssociationInternal::new( - config, - close_loop_ch_tx, - accept_ch_tx, - handshake_completed_ch_tx, - Arc::clone(&awake_write_loop_ch), - ); - - let bytes_received = Arc::new(AtomicUsize::new(0)); - let bytes_sent = Arc::new(AtomicUsize::new(0)); - let name = ai.name.clone(); - let state = Arc::clone(&ai.state); - let max_message_size = Arc::clone(&ai.max_message_size); - let inflight_queue_length = Arc::clone(&ai.inflight_queue_length); - let will_send_shutdown = Arc::clone(&ai.will_send_shutdown); - - let mut init = ChunkInit { - initial_tsn: ai.my_next_tsn, - num_outbound_streams: 
ai.my_max_num_outbound_streams, - num_inbound_streams: ai.my_max_num_inbound_streams, - initiate_tag: ai.my_verification_tag, - advertised_receiver_window_credit: ai.max_receive_buffer_size, - ..Default::default() - }; - init.set_supported_extensions(); - - let name1 = name.clone(); - let name2 = name.clone(); - - let bytes_received1 = Arc::clone(&bytes_received); - let bytes_sent2 = Arc::clone(&bytes_sent); - - let net_conn1 = Arc::clone(&net_conn); - let net_conn2 = Arc::clone(&net_conn); - - let association_internal = Arc::new(Mutex::new(ai)); - let association_internal1 = Arc::clone(&association_internal); - let association_internal2 = Arc::clone(&association_internal); - - { - let association_internal3 = Arc::clone(&association_internal); - - let mut ai = association_internal.lock().await; - ai.t1init = Some(RtxTimer::new( - Arc::downgrade(&association_internal3), - RtxTimerId::T1Init, - MAX_INIT_RETRANS, - )); - ai.t1cookie = Some(RtxTimer::new( - Arc::downgrade(&association_internal3), - RtxTimerId::T1Cookie, - MAX_INIT_RETRANS, - )); - ai.t2shutdown = Some(RtxTimer::new( - Arc::downgrade(&association_internal3), - RtxTimerId::T2Shutdown, - NO_MAX_RETRANS, - )); // retransmit forever - ai.t3rtx = Some(RtxTimer::new( - Arc::downgrade(&association_internal3), - RtxTimerId::T3RTX, - NO_MAX_RETRANS, - )); // retransmit forever - ai.treconfig = Some(RtxTimer::new( - Arc::downgrade(&association_internal3), - RtxTimerId::Reconfig, - NO_MAX_RETRANS, - )); // retransmit forever - ai.ack_timer = Some(AckTimer::new( - Arc::downgrade(&association_internal3), - ACK_INTERVAL, - )); - } - - tokio::spawn(async move { - Association::read_loop( - name1, - bytes_received1, - net_conn1, - close_loop_ch_rx1, - association_internal1, - ) - .await; - }); - - tokio::spawn(async move { - Association::write_loop( - name2, - bytes_sent2, - net_conn2, - close_loop_ch_rx2, - association_internal2, - awake_write_loop_ch_rx, - ) - .await; - }); - - if is_client { - let mut ai = association_internal.lock().await; - ai.set_state(AssociationState::CookieWait); - ai.stored_init = Some(init); - ai.send_init()?; - let rto = ai.rto_mgr.get_rto(); - if let Some(t1init) = &ai.t1init { - t1init.start(rto).await; - } - } - - Ok(( - Association { - name, - state, - max_message_size, - inflight_queue_length, - will_send_shutdown, - awake_write_loop_ch, - close_loop_ch_rx: Mutex::new(close_loop_ch_rx), - accept_ch_rx: Mutex::new(accept_ch_rx), - net_conn, - bytes_received, - bytes_sent, - association_internal, - }, - handshake_completed_ch_rx, - )) - } - - async fn read_loop( - name: String, - bytes_received: Arc, - net_conn: Arc, - mut close_loop_ch: broadcast::Receiver<()>, - association_internal: Arc>, - ) { - log::debug!("[{}] read_loop entered", name); - - let mut buffer = vec![0u8; RECEIVE_MTU]; - let mut done = false; - let mut n; - while !done { - tokio::select! { - _ = close_loop_ch.recv() => break, - result = net_conn.recv(&mut buffer) => { - match result { - Ok(m) => { - n=m; - } - Err(err) => { - log::warn!("[{}] failed to read packets on net_conn: {}", name, err); - break; - } - } - } - }; - - // Make a buffer sized to what we read, then copy the data we - // read from the underlying transport. We do this because the - // user data is passed to the reassembly queue without - // copying. 
- log::debug!("[{}] recving {} bytes", name, n); - let inbound = Bytes::from(buffer[..n].to_vec()); - bytes_received.fetch_add(n, Ordering::SeqCst); - - { - let mut ai = association_internal.lock().await; - if let Err(err) = ai.handle_inbound(&inbound).await { - log::warn!("[{}] failed to handle_inbound: {:?}", name, err); - done = true; - } - } - } - - { - let mut ai = association_internal.lock().await; - if let Err(err) = ai.close().await { - log::warn!("[{}] failed to close association: {:?}", name, err); - } - } - - log::debug!("[{}] read_loop exited", name); - } - - async fn write_loop( - name: String, - bytes_sent: Arc, - net_conn: Arc, - mut close_loop_ch: broadcast::Receiver<()>, - association_internal: Arc>, - mut awake_write_loop_ch: mpsc::Receiver<()>, - ) { - log::debug!("[{}] write_loop entered", name); - let done = Arc::new(AtomicBool::new(false)); - let name = Arc::new(name); - - 'outer: while !done.load(Ordering::Relaxed) { - //log::debug!("[{}] gather_outbound begin", name); - let (packets, continue_loop) = { - let mut ai = association_internal.lock().await; - ai.gather_outbound().await - }; - //log::debug!("[{}] gather_outbound done with {}", name, packets.len()); - - let net_conn = Arc::clone(&net_conn); - let bytes_sent = Arc::clone(&bytes_sent); - let name2 = Arc::clone(&name); - let done2 = Arc::clone(&done); - let mut buffer = None; - for raw in packets { - let mut buf = buffer - .take() - .unwrap_or_else(|| BytesMut::with_capacity(16 * 1024)); - - // We do the marshalling work in a blocking task here for a reason: - // If we don't tokio tends to run the write_loop and read_loop of one connection on the same OS thread - // This means that even though we release the lock above, the read_loop isn't able to take it, simply because it is not being scheduled by tokio - // Doing it this way, tokio schedules this work on a dedicated blocking thread, this future is suspended, and the read_loop can make progress - match tokio::task::spawn_blocking(move || raw.marshal_to(&mut buf).map(|_| buf)) - .await - { - Ok(Ok(mut buf)) => { - let raw = buf.as_ref(); - if let Err(err) = net_conn.send(raw.as_ref()).await { - log::warn!("[{}] failed to write packets on net_conn: {}", name2, err); - done2.store(true, Ordering::Relaxed) - } else { - bytes_sent.fetch_add(raw.len(), Ordering::SeqCst); - } - - // Reuse allocation. Have to use options, since spawn blocking can't borrow, has to take ownership. - buf.clear(); - buffer = Some(buf); - } - Ok(Err(err)) => { - log::warn!("[{}] failed to serialize a packet: {:?}", name2, err); - } - Err(err) => { - if err.is_cancelled() { - log::debug!( - "[{}] task cancelled while serializing a packet: {:?}", - name, - err - ); - break 'outer; - } else { - log::error!("[{}] panic while serializing a packet: {:?}", name, err); - } - } - } - //log::debug!("[{}] sending {} bytes done", name, raw.len()); - } - - if !continue_loop { - break; - } - - //log::debug!("[{}] wait awake_write_loop_ch", name); - tokio::select! 
{ - _ = awake_write_loop_ch.recv() =>{} - _ = close_loop_ch.recv() => { - done.store(true, Ordering::Relaxed); - } - }; - //log::debug!("[{}] wait awake_write_loop_ch done", name); - } - - { - let mut ai = association_internal.lock().await; - if let Err(err) = ai.close().await { - log::warn!("[{}] failed to close association: {:?}", name, err); - } - } - - log::debug!("[{}] write_loop exited", name); - } - - /// bytes_sent returns the number of bytes sent - pub fn bytes_sent(&self) -> usize { - self.bytes_sent.load(Ordering::SeqCst) - } - - /// bytes_received returns the number of bytes received - pub fn bytes_received(&self) -> usize { - self.bytes_received.load(Ordering::SeqCst) - } - - /// open_stream opens a stream - pub async fn open_stream( - &self, - stream_identifier: u16, - default_payload_type: PayloadProtocolIdentifier, - ) -> Result> { - let mut ai = self.association_internal.lock().await; - ai.open_stream(stream_identifier, default_payload_type) - } - - /// accept_stream accepts a stream - pub async fn accept_stream(&self) -> Option> { - let mut accept_ch_rx = self.accept_ch_rx.lock().await; - accept_ch_rx.recv().await - } - - /// max_message_size returns the maximum message size you can send. - pub fn max_message_size(&self) -> u32 { - self.max_message_size.load(Ordering::SeqCst) - } - - /// set_max_message_size sets the maximum message size you can send. - pub fn set_max_message_size(&self, max_message_size: u32) { - self.max_message_size - .store(max_message_size, Ordering::SeqCst); - } - - /// set_state atomically sets the state of the Association. - fn set_state(&self, new_state: AssociationState) { - let old_state = AssociationState::from(self.state.swap(new_state as u8, Ordering::SeqCst)); - if new_state != old_state { - log::debug!( - "[{}] state change: '{}' => '{}'", - self.name, - old_state, - new_state, - ); - } - } - - /// get_state atomically returns the state of the Association. - fn get_state(&self) -> AssociationState { - self.state.load(Ordering::SeqCst).into() - } -} diff --git a/sctp/src/chunk/chunk_abort.rs b/sctp/src/chunk/chunk_abort.rs deleted file mode 100644 index 6fc131442..000000000 --- a/sctp/src/chunk/chunk_abort.rs +++ /dev/null @@ -1,97 +0,0 @@ -use std::fmt; - -use bytes::{Bytes, BytesMut}; - -use super::chunk_header::*; -use super::chunk_type::*; -use super::*; -use crate::error_cause::*; - -///Abort represents an SCTP Chunk of type ABORT -/// -///The ABORT chunk is sent to the peer of an association to close the -///association. The ABORT chunk may contain Cause Parameters to inform -///the receiver about the reason of the abort. DATA chunks MUST NOT be -///bundled with ABORT. Control chunks (except for INIT, INIT ACK, and -///SHUTDOWN COMPLETE) MAY be bundled with an ABORT, but they MUST be -///placed before the ABORT in the SCTP packet or they will be ignored by -///the receiver. 
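A short sketch of how an ABORT carrying a single Protocol Violation cause is built and serialized, mirroring test_assoc_abort earlier in this diff (there the chunk is handed to the association-internal create_packet before being sent on the raw connection):

use bytes::BytesMut;
use crate::chunk::chunk_abort::ChunkAbort;
use crate::chunk::Chunk;
use crate::error::Result;
use crate::error_cause::*;

fn marshal_abort() -> Result<BytesMut> {
    // One Protocol Violation cause; DATA chunks must never be bundled with it.
    let abort = ChunkAbort {
        error_causes: vec![ErrorCauseProtocolViolation {
            code: PROTOCOL_VIOLATION,
            ..Default::default()
        }],
    };
    let mut buf = BytesMut::new();
    // marshal_to writes the chunk header followed by each marshalled cause.
    abort.marshal_to(&mut buf)?;
    Ok(buf)
}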
-/// -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Type = 6 |Reserved |T| Length | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| | -///| zero or more Error Causes | -///| | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -#[derive(Default, Debug, Clone)] -pub(crate) struct ChunkAbort { - pub(crate) error_causes: Vec, -} - -/// String makes chunkAbort printable -impl fmt::Display for ChunkAbort { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut res = vec![self.header().to_string()]; - - for cause in &self.error_causes { - res.push(format!(" - {cause}")); - } - - write!(f, "{}", res.join("\n")) - } -} - -impl Chunk for ChunkAbort { - fn header(&self) -> ChunkHeader { - ChunkHeader { - typ: CT_ABORT, - flags: 0, - value_length: self.value_length() as u16, - } - } - - fn unmarshal(raw: &Bytes) -> Result { - let header = ChunkHeader::unmarshal(raw)?; - - if header.typ != CT_ABORT { - return Err(Error::ErrChunkTypeNotAbort); - } - - let mut error_causes = vec![]; - let mut offset = CHUNK_HEADER_SIZE; - while offset + 4 <= raw.len() { - let e = ErrorCause::unmarshal( - &raw.slice(offset..CHUNK_HEADER_SIZE + header.value_length()), - )?; - offset += e.length(); - error_causes.push(e); - } - - Ok(ChunkAbort { error_causes }) - } - - fn marshal_to(&self, buf: &mut BytesMut) -> Result { - self.header().marshal_to(buf)?; - for ec in &self.error_causes { - buf.extend(ec.marshal()); - } - Ok(buf.len()) - } - - fn check(&self) -> Result<()> { - Ok(()) - } - - fn value_length(&self) -> usize { - self.error_causes - .iter() - .fold(0, |length, ec| length + ec.length()) - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } -} diff --git a/sctp/src/chunk/chunk_cookie_ack.rs b/sctp/src/chunk/chunk_cookie_ack.rs deleted file mode 100644 index a3548b645..000000000 --- a/sctp/src/chunk/chunk_cookie_ack.rs +++ /dev/null @@ -1,61 +0,0 @@ -use std::fmt; - -use bytes::{Bytes, BytesMut}; - -use super::chunk_header::*; -use super::chunk_type::*; -use super::*; - -/// chunkCookieAck represents an SCTP Chunk of type chunkCookieAck -/// -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | Type = 11 |Chunk Flags | Length = 4 | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -#[derive(Debug, Clone)] -pub(crate) struct ChunkCookieAck; - -/// makes ChunkCookieAck printable -impl fmt::Display for ChunkCookieAck { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.header()) - } -} - -impl Chunk for ChunkCookieAck { - fn header(&self) -> ChunkHeader { - ChunkHeader { - typ: CT_COOKIE_ACK, - flags: 0, - value_length: self.value_length() as u16, - } - } - - fn unmarshal(raw: &Bytes) -> Result { - let header = ChunkHeader::unmarshal(raw)?; - - if header.typ != CT_COOKIE_ACK { - return Err(Error::ErrChunkTypeNotCookieAck); - } - - Ok(ChunkCookieAck {}) - } - - fn marshal_to(&self, buf: &mut BytesMut) -> Result { - self.header().marshal_to(buf)?; - Ok(buf.len()) - } - - fn check(&self) -> Result<()> { - Ok(()) - } - - fn value_length(&self) -> usize { - 0 - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } -} diff --git a/sctp/src/chunk/chunk_cookie_echo.rs b/sctp/src/chunk/chunk_cookie_echo.rs deleted file mode 100644 index c49c88a60..000000000 --- 
a/sctp/src/chunk/chunk_cookie_echo.rs +++ /dev/null @@ -1,68 +0,0 @@ -use std::fmt; - -use bytes::{Bytes, BytesMut}; - -use super::chunk_header::*; -use super::chunk_type::*; -use super::*; - -/// CookieEcho represents an SCTP Chunk of type CookieEcho -/// -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | Type = 10 |Chunk Flags | Length | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | Cookie | -/// | | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -#[derive(Default, Debug, Clone)] -pub(crate) struct ChunkCookieEcho { - pub(crate) cookie: Bytes, -} - -/// makes ChunkCookieEcho printable -impl fmt::Display for ChunkCookieEcho { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.header()) - } -} - -impl Chunk for ChunkCookieEcho { - fn header(&self) -> ChunkHeader { - ChunkHeader { - typ: CT_COOKIE_ECHO, - flags: 0, - value_length: self.value_length() as u16, - } - } - - fn unmarshal(raw: &Bytes) -> Result { - let header = ChunkHeader::unmarshal(raw)?; - - if header.typ != CT_COOKIE_ECHO { - return Err(Error::ErrChunkTypeNotCookieEcho); - } - - let cookie = raw.slice(CHUNK_HEADER_SIZE..CHUNK_HEADER_SIZE + header.value_length()); - Ok(ChunkCookieEcho { cookie }) - } - - fn marshal_to(&self, buf: &mut BytesMut) -> Result { - self.header().marshal_to(buf)?; - buf.extend(self.cookie.clone()); - Ok(buf.len()) - } - - fn check(&self) -> Result<()> { - Ok(()) - } - - fn value_length(&self) -> usize { - self.cookie.len() - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } -} diff --git a/sctp/src/chunk/chunk_error.rs b/sctp/src/chunk/chunk_error.rs deleted file mode 100644 index 3ce538b3a..000000000 --- a/sctp/src/chunk/chunk_error.rs +++ /dev/null @@ -1,99 +0,0 @@ -use std::fmt; - -use bytes::{Bytes, BytesMut}; - -use super::chunk_header::*; -use super::chunk_type::*; -use super::*; -use crate::error_cause::*; - -///Operation Error (ERROR) (9) -/// -///An endpoint sends this chunk to its peer endpoint to notify it of -///certain error conditions. It contains one or more error causes. An -///Operation Error is not considered fatal in and of itself, but may be -///used with an ERROR chunk to report a fatal condition. It has the -///following parameters: -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Type = 9 | Chunk Flags | Length | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| | -///| one or more Error Causes | -///| | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///Chunk Flags: 8 bits -/// Set to 0 on transmit and ignored on receipt. -///Length: 16 bits (unsigned integer) -/// Set to the size of the chunk in bytes, including the chunk header -/// and all the Error Cause fields present. 
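A small sketch of consuming a received ERROR chunk, assuming `raw` already holds the chunk bytes; unmarshal walks the causes using the same Length accounting described above (the chunk header plus every Error Cause):

use bytes::Bytes;
use crate::chunk::chunk_error::ChunkError;
use crate::chunk::Chunk;
use crate::error::Result;

fn log_error_causes(raw: &Bytes) -> Result<()> {
    let chunk = ChunkError::unmarshal(raw)?;
    for cause in &chunk.error_causes {
        // ErrorCause implements Display, so this prints code and detail.
        log::debug!("received ERROR cause: {cause}");
    }
    Ok(())
}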
-#[derive(Default, Debug, Clone)] -pub(crate) struct ChunkError { - pub(crate) error_causes: Vec, -} - -/// makes ChunkError printable -impl fmt::Display for ChunkError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut res = vec![self.header().to_string()]; - - for cause in &self.error_causes { - res.push(format!(" - {cause}")); - } - - write!(f, "{}", res.join("\n")) - } -} - -impl Chunk for ChunkError { - fn header(&self) -> ChunkHeader { - ChunkHeader { - typ: CT_ERROR, - flags: 0, - value_length: self.value_length() as u16, - } - } - - fn unmarshal(raw: &Bytes) -> Result { - let header = ChunkHeader::unmarshal(raw)?; - - if header.typ != CT_ERROR { - return Err(Error::ErrChunkTypeNotCtError); - } - - let mut error_causes = vec![]; - let mut offset = CHUNK_HEADER_SIZE; - while offset + 4 <= raw.len() { - let e = ErrorCause::unmarshal( - &raw.slice(offset..CHUNK_HEADER_SIZE + header.value_length()), - )?; - offset += e.length(); - error_causes.push(e); - } - - Ok(ChunkError { error_causes }) - } - - fn marshal_to(&self, buf: &mut BytesMut) -> Result { - self.header().marshal_to(buf)?; - for ec in &self.error_causes { - buf.extend(ec.marshal()); - } - Ok(buf.len()) - } - - fn check(&self) -> Result<()> { - Ok(()) - } - - fn value_length(&self) -> usize { - self.error_causes - .iter() - .fold(0, |length, ec| length + ec.length()) - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } -} diff --git a/sctp/src/chunk/chunk_forward_tsn.rs b/sctp/src/chunk/chunk_forward_tsn.rs deleted file mode 100644 index 5599460dc..000000000 --- a/sctp/src/chunk/chunk_forward_tsn.rs +++ /dev/null @@ -1,184 +0,0 @@ -use std::fmt; - -use bytes::{Buf, BufMut, Bytes, BytesMut}; - -use super::chunk_header::*; -use super::chunk_type::*; -use super::*; - -///This chunk shall be used by the data sender to inform the data -///receiver to adjust its cumulative received TSN point forward because -///some missing TSNs are associated with data chunks that SHOULD NOT be -///transmitted or retransmitted by the sender. -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Type = 192 | Flags = 0x00 | Length = Variable | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| New Cumulative TSN | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Stream-1 | Stream Sequence-1 | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| | -///| | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Stream-N | Stream Sequence-N | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -#[derive(Default, Debug, Clone)] -pub(crate) struct ChunkForwardTsn { - /// This indicates the new cumulative TSN to the data receiver. Upon - /// the reception of this value, the data receiver MUST consider - /// any missing TSNs earlier than or equal to this value as received, - /// and stop reporting them as gaps in any subsequent SACKs. 
- pub(crate) new_cumulative_tsn: u32, - pub(crate) streams: Vec, -} - -pub(crate) const NEW_CUMULATIVE_TSN_LENGTH: usize = 4; -pub(crate) const FORWARD_TSN_STREAM_LENGTH: usize = 4; - -/// makes ChunkForwardTsn printable -impl fmt::Display for ChunkForwardTsn { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut res = vec![self.header().to_string()]; - res.push(format!("New Cumulative TSN: {}", self.new_cumulative_tsn)); - for s in &self.streams { - res.push(format!(" - si={}, ssn={}", s.identifier, s.sequence)); - } - - write!(f, "{}", res.join("\n")) - } -} - -impl Chunk for ChunkForwardTsn { - fn header(&self) -> ChunkHeader { - ChunkHeader { - typ: CT_FORWARD_TSN, - flags: 0, - value_length: self.value_length() as u16, - } - } - - fn unmarshal(buf: &Bytes) -> Result { - let header = ChunkHeader::unmarshal(buf)?; - - if header.typ != CT_FORWARD_TSN { - return Err(Error::ErrChunkTypeNotForwardTsn); - } - - let mut offset = CHUNK_HEADER_SIZE + NEW_CUMULATIVE_TSN_LENGTH; - if buf.len() < offset { - return Err(Error::ErrChunkTooShort); - } - - let reader = &mut buf.slice(CHUNK_HEADER_SIZE..CHUNK_HEADER_SIZE + header.value_length()); - let new_cumulative_tsn = reader.get_u32(); - - let mut streams = vec![]; - let mut remaining = buf.len() - offset; - while remaining > 0 { - let s = ChunkForwardTsnStream::unmarshal( - &buf.slice(offset..CHUNK_HEADER_SIZE + header.value_length()), - )?; - offset += s.value_length(); - remaining -= s.value_length(); - streams.push(s); - } - - Ok(ChunkForwardTsn { - new_cumulative_tsn, - streams, - }) - } - - fn marshal_to(&self, writer: &mut BytesMut) -> Result { - self.header().marshal_to(writer)?; - - writer.put_u32(self.new_cumulative_tsn); - - for s in &self.streams { - writer.extend(s.marshal()?); - } - - Ok(writer.len()) - } - - fn check(&self) -> Result<()> { - Ok(()) - } - - fn value_length(&self) -> usize { - NEW_CUMULATIVE_TSN_LENGTH + FORWARD_TSN_STREAM_LENGTH * self.streams.len() - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } -} - -#[derive(Debug, Clone)] -pub(crate) struct ChunkForwardTsnStream { - /// This field holds a stream number that was skipped by this - /// FWD-TSN. - pub(crate) identifier: u16, - - /// This field holds the sequence number associated with the stream - /// that was skipped. The stream sequence field holds the largest - /// stream sequence number in this stream being skipped. The receiver - /// of the FWD-TSN's can use the Stream-N and Stream Sequence-N fields - /// to enable delivery of any stranded TSN's that remain on the stream - /// re-ordering queues. This field MUST NOT report TSN's corresponding - /// to DATA chunks that are marked as unordered. For ordered DATA - /// chunks this field MUST be filled in. 
- pub(crate) sequence: u16, -} - -/// makes ChunkForwardTsnStream printable -impl fmt::Display for ChunkForwardTsnStream { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}, {}", self.identifier, self.sequence) - } -} - -impl Chunk for ChunkForwardTsnStream { - fn header(&self) -> ChunkHeader { - ChunkHeader { - typ: ChunkType(0), - flags: 0, - value_length: self.value_length() as u16, - } - } - - fn unmarshal(buf: &Bytes) -> Result { - if buf.len() < FORWARD_TSN_STREAM_LENGTH { - return Err(Error::ErrChunkTooShort); - } - - let reader = &mut buf.clone(); - let identifier = reader.get_u16(); - let sequence = reader.get_u16(); - - Ok(ChunkForwardTsnStream { - identifier, - sequence, - }) - } - - fn marshal_to(&self, writer: &mut BytesMut) -> Result { - writer.put_u16(self.identifier); - writer.put_u16(self.sequence); - Ok(writer.len()) - } - - fn check(&self) -> Result<()> { - Ok(()) - } - - fn value_length(&self) -> usize { - FORWARD_TSN_STREAM_LENGTH - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } -} diff --git a/sctp/src/chunk/chunk_header.rs b/sctp/src/chunk/chunk_header.rs deleted file mode 100644 index 8c59f8bdc..000000000 --- a/sctp/src/chunk/chunk_header.rs +++ /dev/null @@ -1,111 +0,0 @@ -use std::fmt; - -use bytes::{Buf, BufMut, Bytes, BytesMut}; - -use super::chunk_type::*; -use super::*; - -///chunkHeader represents a SCTP Chunk header, defined in https://tools.ietf.org/html/rfc4960#section-3.2 -///The figure below illustrates the field format for the chunks to be -///transmitted in the SCTP packet. Each chunk is formatted with a Chunk -///Type field, a chunk-specific Flag field, a Chunk Length field, and a -///Value field. -/// -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Chunk Type | Chunk Flags | Chunk Length | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| | -///| Chunk Value | -///| | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -#[derive(Debug, Clone)] -pub(crate) struct ChunkHeader { - pub(crate) typ: ChunkType, - pub(crate) flags: u8, - pub(crate) value_length: u16, -} - -pub(crate) const CHUNK_HEADER_SIZE: usize = 4; - -/// makes ChunkHeader printable -impl fmt::Display for ChunkHeader { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.typ) - } -} - -impl Chunk for ChunkHeader { - fn header(&self) -> ChunkHeader { - self.clone() - } - - fn unmarshal(raw: &Bytes) -> Result { - if raw.len() < CHUNK_HEADER_SIZE { - return Err(Error::ErrChunkHeaderTooSmall); - } - - let reader = &mut raw.clone(); - - let typ = ChunkType(reader.get_u8()); - let flags = reader.get_u8(); - let length = reader.get_u16(); - - if length < CHUNK_HEADER_SIZE as u16 { - return Err(Error::ErrChunkHeaderInvalidLength); - } - if (length as usize) > raw.len() { - return Err(Error::ErrChunkHeaderInvalidLength); - } - - // Length includes Chunk header - let value_length = length as isize - CHUNK_HEADER_SIZE as isize; - - let length_after_value = raw.len() as isize - length as isize; - if length_after_value < 0 { - return Err(Error::ErrChunkHeaderNotEnoughSpace); - } else if length_after_value < 4 { - // https://tools.ietf.org/html/rfc4960#section-3.2 - // The Chunk Length field does not count any chunk PADDING. - // Chunks (including Type, Length, and Value fields) are padded out - // by the sender with all zero bytes to be a multiple of 4 bytes - // long. 
This PADDING MUST NOT be more than 3 bytes in total. The - // Chunk Length value does not include terminating PADDING of the - // chunk. However, it does include PADDING of any variable-length - // parameter except the last parameter in the chunk. The receiver - // MUST ignore the PADDING. - for i in (1..=length_after_value).rev() { - let padding_offset = CHUNK_HEADER_SIZE + (value_length + i - 1) as usize; - if raw[padding_offset] != 0 { - return Err(Error::ErrChunkHeaderPaddingNonZero); - } - } - } - - Ok(ChunkHeader { - typ, - flags, - value_length: length - CHUNK_HEADER_SIZE as u16, - }) - } - - fn marshal_to(&self, writer: &mut BytesMut) -> Result { - writer.put_u8(self.typ.0); - writer.put_u8(self.flags); - writer.put_u16(self.value_length + CHUNK_HEADER_SIZE as u16); - Ok(writer.len()) - } - - fn check(&self) -> Result<()> { - Ok(()) - } - - fn value_length(&self) -> usize { - self.value_length as usize - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } -} diff --git a/sctp/src/chunk/chunk_heartbeat.rs b/sctp/src/chunk/chunk_heartbeat.rs deleted file mode 100644 index f2091a4c0..000000000 --- a/sctp/src/chunk/chunk_heartbeat.rs +++ /dev/null @@ -1,101 +0,0 @@ -use std::fmt; - -use bytes::{Bytes, BytesMut}; - -use super::chunk_header::*; -use super::chunk_type::*; -use super::*; -use crate::param::param_header::*; -use crate::param::param_type::*; -use crate::param::*; - -///chunkHeartbeat represents an SCTP Chunk of type HEARTBEAT -/// -///An endpoint should send this chunk to its peer endpoint to probe the -///reachability of a particular destination transport address defined in -///the present association. -/// -///The parameter field contains the Heartbeat Information, which is a -///variable-length opaque data structure understood only by the sender. 
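A sketch of building such a probe, assuming ParamHeartbeatInfo carries its opaque payload in a single Bytes field (the field name here is an assumption, not taken from this diff):

use bytes::Bytes;
use crate::chunk::chunk_heartbeat::ChunkHeartbeat;
use crate::param::param_heartbeat_info::ParamHeartbeatInfo;

fn make_heartbeat() -> ChunkHeartbeat {
    // One HeartbeatInfo TLV with sender-defined opaque data; unmarshal() on
    // the receiving side insists on exactly this parameter type.
    ChunkHeartbeat {
        params: vec![Box::new(ParamHeartbeatInfo {
            heartbeat_information: Bytes::from_static(b"probe-0001"), // assumed field name
        })],
    }
}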
-/// -/// -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Type = 4 | Chunk Flags | Heartbeat Length | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| | -///| Heartbeat Information TLV (Variable-Length) | -///| | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// -///Defined as a variable-length parameter using the format described -///in Section 3.2.1, i.e.: -/// -///Variable Parameters Status Type Value -///------------------------------------------------------------- -///heartbeat Info Mandatory 1 -#[derive(Default, Debug)] -pub(crate) struct ChunkHeartbeat { - pub(crate) params: Vec>, -} - -/// makes ChunkHeartbeat printable -impl fmt::Display for ChunkHeartbeat { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.header()) - } -} - -impl Chunk for ChunkHeartbeat { - fn header(&self) -> ChunkHeader { - ChunkHeader { - typ: CT_HEARTBEAT, - flags: 0, - value_length: self.value_length() as u16, - } - } - - fn unmarshal(raw: &Bytes) -> Result { - let header = ChunkHeader::unmarshal(raw)?; - - if header.typ != CT_HEARTBEAT { - return Err(Error::ErrChunkTypeNotHeartbeat); - } - - if raw.len() <= CHUNK_HEADER_SIZE { - return Err(Error::ErrHeartbeatNotLongEnoughInfo); - } - - let p = - build_param(&raw.slice(CHUNK_HEADER_SIZE..CHUNK_HEADER_SIZE + header.value_length()))?; - if p.header().typ != ParamType::HeartbeatInfo { - return Err(Error::ErrHeartbeatParam); - } - let params = vec![p]; - - Ok(ChunkHeartbeat { params }) - } - - fn marshal_to(&self, buf: &mut BytesMut) -> Result { - self.header().marshal_to(buf)?; - for p in &self.params { - buf.extend(p.marshal()?); - } - Ok(buf.len()) - } - - fn check(&self) -> Result<()> { - Ok(()) - } - - fn value_length(&self) -> usize { - self.params.iter().fold(0, |length, p| { - length + PARAM_HEADER_LENGTH + p.value_length() - }) - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } -} diff --git a/sctp/src/chunk/chunk_heartbeat_ack.rs b/sctp/src/chunk/chunk_heartbeat_ack.rs deleted file mode 100644 index 0d058c261..000000000 --- a/sctp/src/chunk/chunk_heartbeat_ack.rs +++ /dev/null @@ -1,129 +0,0 @@ -use std::fmt; - -use bytes::{Bytes, BytesMut}; - -use super::chunk_header::*; -use super::chunk_type::*; -use super::*; -use crate::param::param_header::*; -use crate::param::param_type::ParamType; -use crate::param::*; -use crate::util::get_padding_size; - -///chunkHeartbeatAck represents an SCTP Chunk of type HEARTBEAT ACK -/// -///An endpoint should send this chunk to its peer endpoint as a response -///to a HEARTBEAT chunk (see Section 8.3). A HEARTBEAT ACK is always -///sent to the source IP address of the IP datagram containing the -///HEARTBEAT chunk to which this ack is responding. -/// -///The parameter field contains a variable-length opaque data structure. 
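A sketch of the echo behaviour: the ack simply re-uses the single HeartbeatInfo parameter carried by the HEARTBEAT it answers, which is exactly what marshal_to below enforces (one parameter, and it must be HeartbeatInfo):

use crate::chunk::chunk_heartbeat::ChunkHeartbeat;
use crate::chunk::chunk_heartbeat_ack::ChunkHeartbeatAck;

// Move the opaque HeartbeatInfo TLV from the probe straight into the ack.
fn ack_for(heartbeat: ChunkHeartbeat) -> ChunkHeartbeatAck {
    ChunkHeartbeatAck {
        params: heartbeat.params,
    }
}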
-/// -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Type = 5 | Chunk Flags | Heartbeat Ack Length | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| | -///| Heartbeat Information TLV (Variable-Length) | -///| | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// -/// -///Defined as a variable-length parameter using the format described -///in Section 3.2.1, i.e.: -/// -///Variable Parameters Status Type Value -///------------------------------------------------------------- -///Heartbeat Info Mandatory 1 -#[derive(Default, Debug)] -pub(crate) struct ChunkHeartbeatAck { - pub(crate) params: Vec>, -} - -/// makes ChunkHeartbeatAck printable -impl fmt::Display for ChunkHeartbeatAck { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.header()) - } -} - -impl Chunk for ChunkHeartbeatAck { - fn header(&self) -> ChunkHeader { - ChunkHeader { - typ: CT_HEARTBEAT_ACK, - flags: 0, - value_length: self.value_length() as u16, - } - } - - fn unmarshal(raw: &Bytes) -> Result { - let header = ChunkHeader::unmarshal(raw)?; - - if header.typ != CT_HEARTBEAT_ACK { - return Err(Error::ErrChunkTypeNotHeartbeatAck); - } - - if raw.len() <= CHUNK_HEADER_SIZE { - return Err(Error::ErrHeartbeatNotLongEnoughInfo); - } - - let p = - build_param(&raw.slice(CHUNK_HEADER_SIZE..CHUNK_HEADER_SIZE + header.value_length()))?; - if p.header().typ != ParamType::HeartbeatInfo { - return Err(Error::ErrHeartbeatParam); - } - let params = vec![p]; - - Ok(ChunkHeartbeatAck { params }) - } - - fn marshal_to(&self, buf: &mut BytesMut) -> Result { - if self.params.len() != 1 { - return Err(Error::ErrHeartbeatAckParams); - } - if self.params[0].header().typ != ParamType::HeartbeatInfo { - return Err(Error::ErrHeartbeatAckNotHeartbeatInfo); - } - - self.header().marshal_to(buf)?; - for (idx, p) in self.params.iter().enumerate() { - let pp = p.marshal()?; - let pp_len = pp.len(); - buf.extend(pp); - - // Chunks (including Type, Length, and Value fields) are padded out - // by the sender with all zero bytes to be a multiple of 4 bytes - // long. This PADDING MUST NOT be more than 3 bytes in total. The - // Chunk Length value does not include terminating PADDING of the - // chunk. *However, it does include PADDING of any variable-length - // parameter except the last parameter in the chunk.* The receiver - // MUST ignore the PADDING. 
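// Illustrative note (not part of the original file): get_padding_size rounds a
// length up to the next 4-byte boundary, i.e. pad = (4 - len % 4) % 4, so a
// 10-byte HeartbeatInfo TLV is followed by 2 zero bytes. Only parameters that
// are not the last one in the chunk get this padding counted in the Length
// field; the chunk's own trailing padding is never counted.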
- if idx != self.params.len() - 1 { - let cnt = get_padding_size(pp_len); - buf.extend(vec![0u8; cnt]); - } - } - Ok(buf.len()) - } - - fn check(&self) -> Result<()> { - Ok(()) - } - - fn value_length(&self) -> usize { - let mut l = 0; - for (idx, p) in self.params.iter().enumerate() { - let p_len = PARAM_HEADER_LENGTH + p.value_length(); - l += p_len; - if idx != self.params.len() - 1 { - l += get_padding_size(p_len); - } - } - l - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } -} diff --git a/sctp/src/chunk/chunk_init.rs b/sctp/src/chunk/chunk_init.rs deleted file mode 100644 index 141ed102d..000000000 --- a/sctp/src/chunk/chunk_init.rs +++ /dev/null @@ -1,304 +0,0 @@ -use std::fmt; - -use bytes::{Buf, BufMut, Bytes, BytesMut}; - -use super::chunk_header::*; -use super::chunk_type::*; -use super::*; -use crate::param::param_header::*; -use crate::param::param_supported_extensions::ParamSupportedExtensions; -use crate::param::*; -use crate::util::get_padding_size; - -///chunkInitCommon represents an SCTP Chunk body of type INIT and INIT ACK -/// -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Type = 1 | Chunk Flags | Chunk Length | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Initiate Tag | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Advertised Receiver Window Credit (a_rwnd) | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Number of Outbound Streams | Number of Inbound Streams | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Initial TSN | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| | -///| Optional/Variable-Length Parameters | -///| | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// -///The INIT chunk contains the following parameters. Unless otherwise -///noted, each parameter MUST only be included once in the INIT chunk. 
-/// -///Fixed Parameters Status -///---------------------------------------------- -///Initiate Tag Mandatory -///Advertised Receiver Window Credit Mandatory -///Number of Outbound Streams Mandatory -///Number of Inbound Streams Mandatory -///Initial TSN Mandatory -/// -///Init represents an SCTP Chunk of type INIT -/// -///See chunkInitCommon for the fixed headers -/// -///Variable Parameters Status Type Value -///------------------------------------------------------------- -///IPv4 IP (Note 1) Optional 5 -///IPv6 IP (Note 1) Optional 6 -///Cookie Preservative Optional 9 -///Reserved for ECN Capable (Note 2) Optional 32768 (0x8000) -///Host Name IP (Note 3) Optional 11 -///Supported IP Types (Note 4) Optional 12 -/// -/// -/// chunkInitAck represents an SCTP Chunk of type INIT ACK -/// -///See chunkInitCommon for the fixed headers -/// -///Variable Parameters Status Type Value -///------------------------------------------------------------- -///State Cookie Mandatory 7 -///IPv4 IP (Note 1) Optional 5 -///IPv6 IP (Note 1) Optional 6 -///Unrecognized Parameter Optional 8 -///Reserved for ECN Capable (Note 2) Optional 32768 (0x8000) -///Host Name IP (Note 3) Optional 11 -#[derive(Default, Debug)] -pub(crate) struct ChunkInit { - pub(crate) is_ack: bool, - pub(crate) initiate_tag: u32, - pub(crate) advertised_receiver_window_credit: u32, - pub(crate) num_outbound_streams: u16, - pub(crate) num_inbound_streams: u16, - pub(crate) initial_tsn: u32, - pub(crate) params: Vec>, -} - -impl Clone for ChunkInit { - fn clone(&self) -> Self { - ChunkInit { - is_ack: self.is_ack, - initiate_tag: self.initiate_tag, - advertised_receiver_window_credit: self.advertised_receiver_window_credit, - num_outbound_streams: self.num_outbound_streams, - num_inbound_streams: self.num_inbound_streams, - initial_tsn: self.initial_tsn, - params: self.params.to_vec(), - } - } -} - -pub(crate) const INIT_CHUNK_MIN_LENGTH: usize = 16; -pub(crate) const INIT_OPTIONAL_VAR_HEADER_LENGTH: usize = 4; - -/// makes chunkInitCommon printable -impl fmt::Display for ChunkInit { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut res = format!( - "is_ack: {} - initiate_tag: {} - advertised_receiver_window_credit: {} - num_outbound_streams: {} - num_inbound_streams: {} - initial_tsn: {}", - self.is_ack, - self.initiate_tag, - self.advertised_receiver_window_credit, - self.num_outbound_streams, - self.num_inbound_streams, - self.initial_tsn, - ); - - for (i, param) in self.params.iter().enumerate() { - res += format!("Param {i}:\n {param}").as_str(); - } - write!(f, "{} {}", self.header(), res) - } -} - -impl Chunk for ChunkInit { - fn header(&self) -> ChunkHeader { - ChunkHeader { - typ: if self.is_ack { CT_INIT_ACK } else { CT_INIT }, - flags: 0, - value_length: self.value_length() as u16, - } - } - - ///https://tools.ietf.org/html/rfc4960#section-3.2.1 - /// - ///Chunk values of SCTP control chunks consist of a chunk-type-specific - ///header of required fields, followed by zero or more parameters. The - ///optional and variable-length parameters contained in a chunk are - ///defined in a Type-Length-Value format as shown below. 
- /// - ///0 1 2 3 - ///0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - ///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - ///| Parameter Type | Parameter Length | - ///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - ///| | - ///| Parameter Value | - ///| | - ///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - fn unmarshal(raw: &Bytes) -> Result { - let header = ChunkHeader::unmarshal(raw)?; - - if !(header.typ == CT_INIT || header.typ == CT_INIT_ACK) { - return Err(Error::ErrChunkTypeNotTypeInit); - } else if header.value_length() < INIT_CHUNK_MIN_LENGTH { - // validity of value_length is checked in ChunkHeader::unmarshal - return Err(Error::ErrChunkValueNotLongEnough); - } - - // The Chunk Flags field in INIT is reserved, and all bits in it should - // be set to 0 by the sender and ignored by the receiver. The sequence - // of parameters within an INIT can be processed in any order. - if header.flags != 0 { - return Err(Error::ErrChunkTypeInitFlagZero); - } - - let reader = &mut raw.slice(CHUNK_HEADER_SIZE..CHUNK_HEADER_SIZE + header.value_length()); - - let initiate_tag = reader.get_u32(); - let advertised_receiver_window_credit = reader.get_u32(); - let num_outbound_streams = reader.get_u16(); - let num_inbound_streams = reader.get_u16(); - let initial_tsn = reader.get_u32(); - - let mut params = vec![]; - let mut offset = CHUNK_HEADER_SIZE + INIT_CHUNK_MIN_LENGTH; - let mut remaining = raw.len() as isize - offset as isize; - while remaining >= INIT_OPTIONAL_VAR_HEADER_LENGTH as isize { - let p = build_param(&raw.slice(offset..CHUNK_HEADER_SIZE + header.value_length()))?; - let p_len = PARAM_HEADER_LENGTH + p.value_length(); - let len_plus_padding = p_len + get_padding_size(p_len); - params.push(p); - offset += len_plus_padding; - remaining -= len_plus_padding as isize; - } - - Ok(ChunkInit { - is_ack: header.typ == CT_INIT_ACK, - initiate_tag, - advertised_receiver_window_credit, - num_outbound_streams, - num_inbound_streams, - initial_tsn, - params, - }) - } - - fn marshal_to(&self, writer: &mut BytesMut) -> Result { - self.header().marshal_to(writer)?; - - writer.put_u32(self.initiate_tag); - writer.put_u32(self.advertised_receiver_window_credit); - writer.put_u16(self.num_outbound_streams); - writer.put_u16(self.num_inbound_streams); - writer.put_u32(self.initial_tsn); - for (idx, p) in self.params.iter().enumerate() { - let pp = p.marshal()?; - let pp_len = pp.len(); - writer.extend(pp); - - // Chunks (including Type, Length, and Value fields) are padded out - // by the sender with all zero bytes to be a multiple of 4 bytes - // long. This padding MUST NOT be more than 3 bytes in total. The - // Chunk Length value does not include terminating padding of the - // chunk. *However, it does include padding of any variable-length - // parameter except the last parameter in the chunk.* The receiver - // MUST ignore the padding. - if idx != self.params.len() - 1 { - let cnt = get_padding_size(pp_len); - writer.extend(vec![0u8; cnt]); - } - } - - Ok(writer.len()) - } - - fn check(&self) -> Result<()> { - // The receiver of the INIT (the responding end) records the value of - // the Initiate Tag parameter. This value MUST be placed into the - // Verification Tag field of every SCTP packet that the receiver of - // the INIT transmits within this association. - // - // The Initiate Tag is allowed to have any value except 0. See - // Section 5.3.1 for more on the selection of the tag value. 
- // - // If the value of the Initiate Tag in a received INIT chunk is found - // to be 0, the receiver MUST treat it as an error and close the - // association by transmitting an ABORT. - if self.initiate_tag == 0 { - return Err(Error::ErrChunkTypeInitInitiateTagZero); - } - - // Defines the maximum number of streams the sender of this INIT - // chunk allows the peer end to create in this association. The - // value 0 MUST NOT be used. - // - // Note: There is no negotiation of the actual number of streams but - // instead the two endpoints will use the min(requested, offered). - // See Section 5.1.1 for details. - // - // Note: A receiver of an INIT with the MIS value of 0 SHOULD abort - // the association. - if self.num_inbound_streams == 0 { - return Err(Error::ErrInitInboundStreamRequestZero); - } - - // Defines the number of outbound streams the sender of this INIT - // chunk wishes to create in this association. The value of 0 MUST - // NOT be used. - // - // Note: A receiver of an INIT with the OS value set to 0 SHOULD - // abort the association. - - if self.num_outbound_streams == 0 { - return Err(Error::ErrInitOutboundStreamRequestZero); - } - - // An SCTP receiver MUST be able to receive a minimum of 1500 bytes in - // one SCTP packet. This means that an SCTP endpoint MUST NOT indicate - // less than 1500 bytes in its initial a_rwnd sent in the INIT or INIT - // ACK. - if self.advertised_receiver_window_credit < 1500 { - return Err(Error::ErrInitAdvertisedReceiver1500); - } - - Ok(()) - } - - fn value_length(&self) -> usize { - let mut l = 4 + 4 + 2 + 2 + 4; - for (idx, p) in self.params.iter().enumerate() { - let p_len = PARAM_HEADER_LENGTH + p.value_length(); - l += p_len; - if idx != self.params.len() - 1 { - l += get_padding_size(p_len); - } - } - l - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } -} - -impl ChunkInit { - pub(crate) fn set_supported_extensions(&mut self) { - // TODO RFC5061 https://tools.ietf.org/html/rfc6525#section-5.2 - // An implementation supporting this (Supported Extensions Parameter) - // extension MUST list the ASCONF, the ASCONF-ACK, and the AUTH chunks - // in its INIT and INIT-ACK parameters. 
- self.params.push(Box::new(ParamSupportedExtensions { - chunk_types: vec![CT_RECONFIG, CT_FORWARD_TSN], - })); - } -} diff --git a/sctp/src/chunk/chunk_payload_data.rs b/sctp/src/chunk/chunk_payload_data.rs deleted file mode 100644 index 2a2da26ad..000000000 --- a/sctp/src/chunk/chunk_payload_data.rs +++ /dev/null @@ -1,272 +0,0 @@ -use std::fmt; -use std::sync::atomic::Ordering; -use std::sync::Arc; -use std::time::SystemTime; - -use bytes::{Buf, BufMut, Bytes, BytesMut}; -use portable_atomic::AtomicBool; - -use super::chunk_header::*; -use super::chunk_type::*; -use super::*; - -pub(crate) const PAYLOAD_DATA_ENDING_FRAGMENT_BITMASK: u8 = 1; -pub(crate) const PAYLOAD_DATA_BEGINNING_FRAGMENT_BITMASK: u8 = 2; -pub(crate) const PAYLOAD_DATA_UNORDERED_BITMASK: u8 = 4; -pub(crate) const PAYLOAD_DATA_IMMEDIATE_SACK: u8 = 8; -pub(crate) const PAYLOAD_DATA_HEADER_SIZE: usize = 12; - -/// PayloadProtocolIdentifier is an enum for DataChannel payload types -/// PayloadProtocolIdentifier enums -/// https://www.iana.org/assignments/sctp-parameters/sctp-parameters.xhtml#sctp-parameters-25 -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] -#[repr(C)] -pub enum PayloadProtocolIdentifier { - Dcep = 50, - String = 51, - Binary = 53, - StringEmpty = 56, - BinaryEmpty = 57, - #[default] - Unknown, -} - -impl fmt::Display for PayloadProtocolIdentifier { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - PayloadProtocolIdentifier::Dcep => "WebRTC DCEP", - PayloadProtocolIdentifier::String => "WebRTC String", - PayloadProtocolIdentifier::Binary => "WebRTC Binary", - PayloadProtocolIdentifier::StringEmpty => "WebRTC String (Empty)", - PayloadProtocolIdentifier::BinaryEmpty => "WebRTC Binary (Empty)", - _ => "Unknown Payload Protocol Identifier", - }; - write!(f, "{s}") - } -} - -impl From for PayloadProtocolIdentifier { - fn from(v: u32) -> PayloadProtocolIdentifier { - match v { - 50 => PayloadProtocolIdentifier::Dcep, - 51 => PayloadProtocolIdentifier::String, - 53 => PayloadProtocolIdentifier::Binary, - 56 => PayloadProtocolIdentifier::StringEmpty, - 57 => PayloadProtocolIdentifier::BinaryEmpty, - _ => PayloadProtocolIdentifier::Unknown, - } - } -} - -///chunkPayloadData represents an SCTP Chunk of type DATA -/// -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Type = 0 | Reserved|U|B|E| Length | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| TSN | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Stream Identifier S | Stream Sequence Number n | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Payload Protocol Identifier | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| | -///| User Data (seq n of Stream S) | -///| | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// -/// -///An unfragmented user message shall have both the B and E bits set to -///'1'. 
Setting both B and E bits to '0' indicates a middle fragment of -///a multi-fragment user message, as summarized in the following table: -/// B E Description -///============================================================ -///| 1 0 | First piece of a fragmented user message | -///+----------------------------------------------------------+ -///| 0 0 | Middle piece of a fragmented user message | -///+----------------------------------------------------------+ -///| 0 1 | Last piece of a fragmented user message | -///+----------------------------------------------------------+ -///| 1 1 | Unfragmented message | -///============================================================ -///| Table 1: Fragment Description Flags | -///============================================================ -#[derive(Debug, Clone)] -pub struct ChunkPayloadData { - pub(crate) unordered: bool, - pub(crate) beginning_fragment: bool, - pub(crate) ending_fragment: bool, - pub(crate) immediate_sack: bool, - - pub(crate) tsn: u32, - pub(crate) stream_identifier: u16, - pub(crate) stream_sequence_number: u16, - pub(crate) payload_type: PayloadProtocolIdentifier, - pub(crate) user_data: Bytes, - - /// Whether this data chunk was acknowledged (received by peer) - pub(crate) acked: bool, - pub(crate) miss_indicator: u32, - - /// Partial-reliability parameters used only by sender - pub(crate) since: SystemTime, - /// number of transmission made for this chunk - pub(crate) nsent: u32, - - /// valid only with the first fragment - pub(crate) abandoned: Arc, - /// valid only with the first fragment - pub(crate) all_inflight: Arc, - - /// Retransmission flag set when T1-RTX timeout occurred and this - /// chunk is still in the inflight queue - pub(crate) retransmit: bool, -} - -impl Default for ChunkPayloadData { - fn default() -> Self { - ChunkPayloadData { - unordered: false, - beginning_fragment: false, - ending_fragment: false, - immediate_sack: false, - tsn: 0, - stream_identifier: 0, - stream_sequence_number: 0, - payload_type: PayloadProtocolIdentifier::default(), - user_data: Bytes::new(), - acked: false, - miss_indicator: 0, - since: SystemTime::now(), - nsent: 0, - abandoned: Arc::new(AtomicBool::new(false)), - all_inflight: Arc::new(AtomicBool::new(false)), - retransmit: false, - } - } -} - -/// makes chunkPayloadData printable -impl fmt::Display for ChunkPayloadData { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}\n{}", self.header(), self.tsn) - } -} - -impl Chunk for ChunkPayloadData { - fn header(&self) -> ChunkHeader { - let mut flags: u8 = 0; - if self.ending_fragment { - flags = 1; - } - if self.beginning_fragment { - flags |= 1 << 1; - } - if self.unordered { - flags |= 1 << 2; - } - if self.immediate_sack { - flags |= 1 << 3; - } - - ChunkHeader { - typ: CT_PAYLOAD_DATA, - flags, - value_length: self.value_length() as u16, - } - } - - fn unmarshal(raw: &Bytes) -> Result { - let header = ChunkHeader::unmarshal(raw)?; - - if header.typ != CT_PAYLOAD_DATA { - return Err(Error::ErrChunkTypeNotPayloadData); - } - - let immediate_sack = (header.flags & PAYLOAD_DATA_IMMEDIATE_SACK) != 0; - let unordered = (header.flags & PAYLOAD_DATA_UNORDERED_BITMASK) != 0; - let beginning_fragment = (header.flags & PAYLOAD_DATA_BEGINNING_FRAGMENT_BITMASK) != 0; - let ending_fragment = (header.flags & PAYLOAD_DATA_ENDING_FRAGMENT_BITMASK) != 0; - - // validity of value_length is checked in ChunkHeader::unmarshal - if header.value_length() < PAYLOAD_DATA_HEADER_SIZE { - return 
Err(Error::ErrChunkPayloadSmall); - } - - let reader = &mut raw.slice(CHUNK_HEADER_SIZE..CHUNK_HEADER_SIZE + header.value_length()); - - let tsn = reader.get_u32(); - let stream_identifier = reader.get_u16(); - let stream_sequence_number = reader.get_u16(); - let payload_type: PayloadProtocolIdentifier = reader.get_u32().into(); - let user_data = raw.slice( - CHUNK_HEADER_SIZE + PAYLOAD_DATA_HEADER_SIZE..CHUNK_HEADER_SIZE + header.value_length(), - ); - - Ok(ChunkPayloadData { - unordered, - beginning_fragment, - ending_fragment, - immediate_sack, - - tsn, - stream_identifier, - stream_sequence_number, - payload_type, - user_data, - acked: false, - miss_indicator: 0, - since: SystemTime::now(), - nsent: 0, - abandoned: Arc::new(AtomicBool::new(false)), - all_inflight: Arc::new(AtomicBool::new(false)), - retransmit: false, - }) - } - - fn marshal_to(&self, writer: &mut BytesMut) -> Result { - self.header().marshal_to(writer)?; - - writer.put_u32(self.tsn); - writer.put_u16(self.stream_identifier); - writer.put_u16(self.stream_sequence_number); - writer.put_u32(self.payload_type as u32); - writer.extend_from_slice(&self.user_data); - - Ok(writer.len()) - } - - fn check(&self) -> Result<()> { - Ok(()) - } - - fn value_length(&self) -> usize { - PAYLOAD_DATA_HEADER_SIZE + self.user_data.len() - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } -} - -impl ChunkPayloadData { - pub(crate) fn abandoned(&self) -> bool { - let (abandoned, all_inflight) = ( - self.abandoned.load(Ordering::SeqCst), - self.all_inflight.load(Ordering::SeqCst), - ); - - abandoned && all_inflight - } - - pub(crate) fn set_abandoned(&self, abandoned: bool) { - self.abandoned.store(abandoned, Ordering::SeqCst); - } - - pub(crate) fn set_all_inflight(&mut self) { - if self.ending_fragment { - self.all_inflight.store(true, Ordering::SeqCst); - } - } -} diff --git a/sctp/src/chunk/chunk_reconfig.rs b/sctp/src/chunk/chunk_reconfig.rs deleted file mode 100644 index 2282857d2..000000000 --- a/sctp/src/chunk/chunk_reconfig.rs +++ /dev/null @@ -1,133 +0,0 @@ -use std::fmt; - -use bytes::{Bytes, BytesMut}; - -use super::chunk_header::*; -use super::chunk_type::*; -use super::*; -use crate::param::param_header::*; -use crate::param::*; -use crate::util::get_padding_size; - -///https://tools.ietf.org/html/rfc6525#section-3.1 -///chunkReconfig represents an SCTP Chunk used to reconfigure streams. 
-/// -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Type = 130 | Chunk Flags | Chunk Length | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| | -///| Re-configuration Parameter | -///| | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| | -///| Re-configuration Parameter (optional) | -///| | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -#[derive(Default, Debug)] -pub(crate) struct ChunkReconfig { - pub(crate) param_a: Option>, - pub(crate) param_b: Option>, -} - -impl Clone for ChunkReconfig { - fn clone(&self) -> Self { - ChunkReconfig { - param_a: self.param_a.as_ref().cloned(), - param_b: self.param_b.as_ref().cloned(), - } - } -} - -/// makes chunkReconfig printable -impl fmt::Display for ChunkReconfig { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut res = String::new(); - if let Some(param_a) = &self.param_a { - res += format!("Param A:\n {param_a}").as_str(); - } - if let Some(param_b) = &self.param_b { - res += format!("Param B:\n {param_b}").as_str() - } - write!(f, "{res}") - } -} - -impl Chunk for ChunkReconfig { - fn header(&self) -> ChunkHeader { - ChunkHeader { - typ: CT_RECONFIG, - flags: 0, - value_length: self.value_length() as u16, - } - } - - fn unmarshal(raw: &Bytes) -> Result { - let header = ChunkHeader::unmarshal(raw)?; - - if header.typ != CT_RECONFIG { - return Err(Error::ErrChunkTypeNotReconfig); - } - - let param_a = - build_param(&raw.slice(CHUNK_HEADER_SIZE..CHUNK_HEADER_SIZE + header.value_length()))?; - - let padding = get_padding_size(PARAM_HEADER_LENGTH + param_a.value_length()); - let offset = CHUNK_HEADER_SIZE + PARAM_HEADER_LENGTH + param_a.value_length() + padding; - let param_b = if CHUNK_HEADER_SIZE + header.value_length() > offset { - Some(build_param( - &raw.slice(offset..CHUNK_HEADER_SIZE + header.value_length()), - )?) 
- } else { - None - }; - - Ok(ChunkReconfig { - param_a: Some(param_a), - param_b, - }) - } - - fn marshal_to(&self, writer: &mut BytesMut) -> Result { - self.header().marshal_to(writer)?; - - let param_a_value_length = if let Some(param_a) = &self.param_a { - writer.extend(param_a.marshal()?); - param_a.value_length() - } else { - return Err(Error::ErrChunkReconfigInvalidParamA); - }; - - if let Some(param_b) = &self.param_b { - // Pad param A - let padding = get_padding_size(PARAM_HEADER_LENGTH + param_a_value_length); - writer.extend(vec![0u8; padding]); - writer.extend(param_b.marshal()?); - } - Ok(writer.len()) - } - - fn check(&self) -> Result<()> { - Ok(()) - } - - fn value_length(&self) -> usize { - let mut l = PARAM_HEADER_LENGTH; - let param_a_value_length = if let Some(param_a) = &self.param_a { - l += param_a.value_length(); - param_a.value_length() - } else { - 0 - }; - if let Some(param_b) = &self.param_b { - let padding = get_padding_size(PARAM_HEADER_LENGTH + param_a_value_length); - l += PARAM_HEADER_LENGTH + param_b.value_length() + padding; - } - l - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } -} diff --git a/sctp/src/chunk/chunk_selective_ack.rs b/sctp/src/chunk/chunk_selective_ack.rs deleted file mode 100644 index 38e22e6c1..000000000 --- a/sctp/src/chunk/chunk_selective_ack.rs +++ /dev/null @@ -1,167 +0,0 @@ -use std::fmt; - -use bytes::{Buf, BufMut, Bytes, BytesMut}; - -use super::chunk_header::*; -use super::chunk_type::*; -use super::*; - -///chunkSelectiveAck represents an SCTP Chunk of type SACK -/// -///This chunk is sent to the peer endpoint to acknowledge received DATA -///chunks and to inform the peer endpoint of gaps in the received -///subsequences of DATA chunks as represented by their TSNs. -///0 1 2 3 -///0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Type = 3 |Chunk Flags | Chunk Length | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Cumulative TSN Ack | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Advertised Receiver Window Credit (a_rwnd) | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Number of Gap Ack Blocks = N | Number of Duplicate TSNs = X | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Gap Ack Block #1 Start | Gap Ack Block #1 End | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| | -///| ... | -///| | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Gap Ack Block #N Start | Gap Ack Block #N End | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Duplicate TSN 1 | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| | -///| ... 
| -///| | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Duplicate TSN X | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -#[derive(Debug, Default, Copy, Clone)] -pub(crate) struct GapAckBlock { - pub(crate) start: u16, - pub(crate) end: u16, -} - -/// makes gapAckBlock printable -impl fmt::Display for GapAckBlock { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{} - {}", self.start, self.end) - } -} - -#[derive(Default, Debug)] -pub(crate) struct ChunkSelectiveAck { - pub(crate) cumulative_tsn_ack: u32, - pub(crate) advertised_receiver_window_credit: u32, - pub(crate) gap_ack_blocks: Vec, - pub(crate) duplicate_tsn: Vec, -} - -/// makes chunkSelectiveAck printable -impl fmt::Display for ChunkSelectiveAck { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut res = format!( - "SACK cumTsnAck={} arwnd={} dupTsn={:?}", - self.cumulative_tsn_ack, self.advertised_receiver_window_credit, self.duplicate_tsn - ); - - for gap in &self.gap_ack_blocks { - res += format!("\n gap ack: {gap}").as_str(); - } - - write!(f, "{res}") - } -} - -pub(crate) const SELECTIVE_ACK_HEADER_SIZE: usize = 12; - -impl Chunk for ChunkSelectiveAck { - fn header(&self) -> ChunkHeader { - ChunkHeader { - typ: CT_SACK, - flags: 0, - value_length: self.value_length() as u16, - } - } - - fn unmarshal(raw: &Bytes) -> Result { - let header = ChunkHeader::unmarshal(raw)?; - - if header.typ != CT_SACK { - return Err(Error::ErrChunkTypeNotSack); - } - - // validity of value_length is checked in ChunkHeader::unmarshal - if header.value_length() < SELECTIVE_ACK_HEADER_SIZE { - return Err(Error::ErrSackSizeNotLargeEnoughInfo); - } - - let reader = &mut raw.slice(CHUNK_HEADER_SIZE..CHUNK_HEADER_SIZE + header.value_length()); - - let cumulative_tsn_ack = reader.get_u32(); - let advertised_receiver_window_credit = reader.get_u32(); - let gap_ack_blocks_len = reader.get_u16() as usize; - let duplicate_tsn_len = reader.get_u16() as usize; - - // Here we must account for case where the buffer contains another chunk - // right after this one. Testing for equality would incorrectly fail the - // parsing of this chunk and incorrectly close the transport. 
- - // validity of value_length is checked in ChunkHeader::unmarshal - if header.value_length() - < SELECTIVE_ACK_HEADER_SIZE + (4 * gap_ack_blocks_len + 4 * duplicate_tsn_len) - { - return Err(Error::ErrSackSizeNotLargeEnoughInfo); - } - - let mut gap_ack_blocks = vec![]; - let mut duplicate_tsn = vec![]; - for _ in 0..gap_ack_blocks_len { - let start = reader.get_u16(); - let end = reader.get_u16(); - gap_ack_blocks.push(GapAckBlock { start, end }); - } - for _ in 0..duplicate_tsn_len { - duplicate_tsn.push(reader.get_u32()); - } - - Ok(ChunkSelectiveAck { - cumulative_tsn_ack, - advertised_receiver_window_credit, - gap_ack_blocks, - duplicate_tsn, - }) - } - - fn marshal_to(&self, writer: &mut BytesMut) -> Result { - self.header().marshal_to(writer)?; - - writer.put_u32(self.cumulative_tsn_ack); - writer.put_u32(self.advertised_receiver_window_credit); - writer.put_u16(self.gap_ack_blocks.len() as u16); - writer.put_u16(self.duplicate_tsn.len() as u16); - for g in &self.gap_ack_blocks { - writer.put_u16(g.start); - writer.put_u16(g.end); - } - for t in &self.duplicate_tsn { - writer.put_u32(*t); - } - - Ok(writer.len()) - } - - fn check(&self) -> Result<()> { - Ok(()) - } - - fn value_length(&self) -> usize { - SELECTIVE_ACK_HEADER_SIZE + self.gap_ack_blocks.len() * 4 + self.duplicate_tsn.len() * 4 - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } -} diff --git a/sctp/src/chunk/chunk_shutdown.rs b/sctp/src/chunk/chunk_shutdown.rs deleted file mode 100644 index 44a3e4547..000000000 --- a/sctp/src/chunk/chunk_shutdown.rs +++ /dev/null @@ -1,76 +0,0 @@ -use std::fmt; - -use bytes::{Buf, BufMut, Bytes, BytesMut}; - -use super::chunk_header::*; -use super::chunk_type::*; -use super::*; - -///chunkShutdown represents an SCTP Chunk of type chunkShutdown -/// -///0 1 2 3 -///0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Type = 7 | Chunk Flags | Length = 8 | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Cumulative TSN Ack | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -#[derive(Default, Debug, Clone)] -pub(crate) struct ChunkShutdown { - pub(crate) cumulative_tsn_ack: u32, -} - -pub(crate) const CUMULATIVE_TSN_ACK_LENGTH: usize = 4; - -/// makes chunkShutdown printable -impl fmt::Display for ChunkShutdown { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.header()) - } -} - -impl Chunk for ChunkShutdown { - fn header(&self) -> ChunkHeader { - ChunkHeader { - typ: CT_SHUTDOWN, - flags: 0, - value_length: self.value_length() as u16, - } - } - - fn unmarshal(raw: &Bytes) -> Result { - let header = ChunkHeader::unmarshal(raw)?; - - if header.typ != CT_SHUTDOWN { - return Err(Error::ErrChunkTypeNotShutdown); - } - - if raw.len() != CHUNK_HEADER_SIZE + CUMULATIVE_TSN_ACK_LENGTH { - return Err(Error::ErrInvalidChunkSize); - } - - let reader = &mut raw.slice(CHUNK_HEADER_SIZE..CHUNK_HEADER_SIZE + header.value_length()); - - let cumulative_tsn_ack = reader.get_u32(); - - Ok(ChunkShutdown { cumulative_tsn_ack }) - } - - fn marshal_to(&self, writer: &mut BytesMut) -> Result { - self.header().marshal_to(writer)?; - writer.put_u32(self.cumulative_tsn_ack); - Ok(writer.len()) - } - - fn check(&self) -> Result<()> { - Ok(()) - } - - fn value_length(&self) -> usize { - CUMULATIVE_TSN_ACK_LENGTH - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } -} diff --git a/sctp/src/chunk/chunk_shutdown_ack.rs 
b/sctp/src/chunk/chunk_shutdown_ack.rs deleted file mode 100644 index d22c86739..000000000 --- a/sctp/src/chunk/chunk_shutdown_ack.rs +++ /dev/null @@ -1,61 +0,0 @@ -use std::fmt; - -use bytes::{Bytes, BytesMut}; - -use super::chunk_header::*; -use super::chunk_type::*; -use super::*; - -///chunkShutdownAck represents an SCTP Chunk of type chunkShutdownAck -/// -///0 1 2 3 -///0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Type = 8 | Chunk Flags | Length = 4 | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -#[derive(Default, Debug, Clone)] -pub(crate) struct ChunkShutdownAck; - -/// makes chunkShutdownAck printable -impl fmt::Display for ChunkShutdownAck { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.header()) - } -} - -impl Chunk for ChunkShutdownAck { - fn header(&self) -> ChunkHeader { - ChunkHeader { - typ: CT_SHUTDOWN_ACK, - flags: 0, - value_length: self.value_length() as u16, - } - } - - fn unmarshal(raw: &Bytes) -> Result { - let header = ChunkHeader::unmarshal(raw)?; - - if header.typ != CT_SHUTDOWN_ACK { - return Err(Error::ErrChunkTypeNotShutdownAck); - } - - Ok(ChunkShutdownAck {}) - } - - fn marshal_to(&self, writer: &mut BytesMut) -> Result { - self.header().marshal_to(writer)?; - Ok(writer.len()) - } - - fn check(&self) -> Result<()> { - Ok(()) - } - - fn value_length(&self) -> usize { - 0 - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } -} diff --git a/sctp/src/chunk/chunk_shutdown_complete.rs b/sctp/src/chunk/chunk_shutdown_complete.rs deleted file mode 100644 index c84da65a6..000000000 --- a/sctp/src/chunk/chunk_shutdown_complete.rs +++ /dev/null @@ -1,61 +0,0 @@ -use std::fmt; - -use bytes::{Bytes, BytesMut}; - -use super::chunk_header::*; -use super::chunk_type::*; -use super::*; - -///chunkShutdownComplete represents an SCTP Chunk of type chunkShutdownComplete -/// -///0 1 2 3 -///0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Type = 14 |Reserved |T| Length = 4 | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -#[derive(Default, Debug, Clone)] -pub(crate) struct ChunkShutdownComplete; - -/// makes chunkShutdownComplete printable -impl fmt::Display for ChunkShutdownComplete { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.header()) - } -} - -impl Chunk for ChunkShutdownComplete { - fn header(&self) -> ChunkHeader { - ChunkHeader { - typ: CT_SHUTDOWN_COMPLETE, - flags: 0, - value_length: self.value_length() as u16, - } - } - - fn unmarshal(raw: &Bytes) -> Result { - let header = ChunkHeader::unmarshal(raw)?; - - if header.typ != CT_SHUTDOWN_COMPLETE { - return Err(Error::ErrChunkTypeNotShutdownComplete); - } - - Ok(ChunkShutdownComplete {}) - } - - fn marshal_to(&self, writer: &mut BytesMut) -> Result { - self.header().marshal_to(writer)?; - Ok(writer.len()) - } - - fn check(&self) -> Result<()> { - Ok(()) - } - - fn value_length(&self) -> usize { - 0 - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } -} diff --git a/sctp/src/chunk/chunk_test.rs b/sctp/src/chunk/chunk_test.rs deleted file mode 100644 index f16169661..000000000 --- a/sctp/src/chunk/chunk_test.rs +++ /dev/null @@ -1,752 +0,0 @@ -/////////////////////////////////////////////////////////////////// -//chunk_type_test 
-/////////////////////////////////////////////////////////////////// -use super::chunk_type::*; -use super::*; - -#[test] -fn test_chunk_type_string() -> Result<()> { - let tests = vec![ - (CT_PAYLOAD_DATA, "DATA"), - (CT_INIT, "INIT"), - (CT_INIT_ACK, "INIT-ACK"), - (CT_SACK, "SACK"), - (CT_HEARTBEAT, "HEARTBEAT"), - (CT_HEARTBEAT_ACK, "HEARTBEAT-ACK"), - (CT_ABORT, "ABORT"), - (CT_SHUTDOWN, "SHUTDOWN"), - (CT_SHUTDOWN_ACK, "SHUTDOWN-ACK"), - (CT_ERROR, "ERROR"), - (CT_COOKIE_ECHO, "COOKIE-ECHO"), - (CT_COOKIE_ACK, "COOKIE-ACK"), - (CT_ECNE, "ECNE"), - (CT_CWR, "CWR"), - (CT_SHUTDOWN_COMPLETE, "SHUTDOWN-COMPLETE"), - (CT_RECONFIG, "RECONFIG"), - (CT_FORWARD_TSN, "FORWARD-TSN"), - (ChunkType(255), "Unknown ChunkType: 255"), - ]; - - for (ct, expected) in tests { - assert_eq!( - ct.to_string(), - expected, - "failed to stringify chunkType {ct}, expected {expected}" - ); - } - - Ok(()) -} - -/////////////////////////////////////////////////////////////////// -//chunk_abort_test -/////////////////////////////////////////////////////////////////// -use super::chunk_abort::*; -use crate::error_cause::*; - -#[test] -fn test_abort_chunk_one_error_cause() -> Result<()> { - let abort1 = ChunkAbort { - error_causes: vec![ErrorCause { - code: PROTOCOL_VIOLATION, - ..Default::default() - }], - }; - - let b = abort1.marshal()?; - let abort2 = ChunkAbort::unmarshal(&b)?; - - assert_eq!(abort2.error_causes.len(), 1, "should have only one cause"); - assert_eq!( - abort2.error_causes[0].error_cause_code(), - abort1.error_causes[0].error_cause_code(), - "errorCause code should match" - ); - - Ok(()) -} - -#[test] -fn test_abort_chunk_many_error_causes() -> Result<()> { - let abort1 = ChunkAbort { - error_causes: vec![ - ErrorCause { - code: INVALID_MANDATORY_PARAMETER, - ..Default::default() - }, - ErrorCause { - code: UNRECOGNIZED_CHUNK_TYPE, - ..Default::default() - }, - ErrorCause { - code: PROTOCOL_VIOLATION, - ..Default::default() - }, - ], - }; - - let b = abort1.marshal()?; - let abort2 = ChunkAbort::unmarshal(&b)?; - assert_eq!(abort2.error_causes.len(), 3, "should have only one cause"); - for (i, error_cause) in abort1.error_causes.iter().enumerate() { - assert_eq!( - abort2.error_causes[i].error_cause_code(), - error_cause.error_cause_code(), - "errorCause code should match" - ); - } - - Ok(()) -} - -/////////////////////////////////////////////////////////////////// -//chunk_error_test -/////////////////////////////////////////////////////////////////// -use bytes::BufMut; -use lazy_static::lazy_static; - -use super::chunk_error::*; - -const CHUNK_FLAGS: u8 = 0x00; -static ORG_UNRECOGNIZED_CHUNK: Bytes = - Bytes::from_static(&[0xc0, 0x0, 0x0, 0x8, 0x0, 0x0, 0x0, 0x3]); - -lazy_static! 
{ - static ref RAW_IN: Bytes = { - let mut raw = BytesMut::new(); - raw.put_u8(CT_ERROR.0); - raw.put_u8(CHUNK_FLAGS); - raw.extend(vec![0x00, 0x10, 0x00, 0x06, 0x00, 0x0c]); - raw.extend(ORG_UNRECOGNIZED_CHUNK.clone()); - raw.freeze() - }; -} - -#[test] -fn test_chunk_error_unrecognized_chunk_type_unmarshal() -> Result<()> { - let c = ChunkError::unmarshal(&RAW_IN)?; - assert_eq!(c.header().typ, CT_ERROR, "chunk type should be ERROR"); - assert_eq!(c.error_causes.len(), 1, "there should be on errorCause"); - - let ec = &c.error_causes[0]; - assert_eq!( - ec.error_cause_code(), - UNRECOGNIZED_CHUNK_TYPE, - "cause code should be unrecognizedChunkType" - ); - assert_eq!( - ec.raw, ORG_UNRECOGNIZED_CHUNK, - "should have valid unrecognizedChunk" - ); - - Ok(()) -} - -#[test] -fn test_chunk_error_unrecognized_chunk_type_marshal() -> Result<()> { - let ec_unrecognized_chunk_type = ErrorCause { - code: UNRECOGNIZED_CHUNK_TYPE, - raw: ORG_UNRECOGNIZED_CHUNK.clone(), - }; - - let ec = ChunkError { - error_causes: vec![ec_unrecognized_chunk_type], - }; - - let raw = ec.marshal()?; - assert_eq!(raw, *RAW_IN, "unexpected serialization result"); - - Ok(()) -} - -#[test] -fn test_chunk_error_unrecognized_chunk_type_marshal_with_cause_value_being_nil() -> Result<()> { - let expected = - Bytes::from_static(&[CT_ERROR.0, CHUNK_FLAGS, 0x00, 0x08, 0x00, 0x06, 0x00, 0x04]); - let ec_unrecognized_chunk_type = ErrorCause { - code: UNRECOGNIZED_CHUNK_TYPE, - ..Default::default() - }; - - let ec = ChunkError { - error_causes: vec![ec_unrecognized_chunk_type], - }; - - let raw = ec.marshal()?; - assert_eq!(raw, expected, "unexpected serialization result"); - - Ok(()) -} - -/////////////////////////////////////////////////////////////////// -//chunk_forward_tsn_test -/////////////////////////////////////////////////////////////////// -use super::chunk_forward_tsn::*; - -static CHUNK_FORWARD_TSN_BYTES: Bytes = - Bytes::from_static(&[0xc0, 0x0, 0x0, 0x8, 0x0, 0x0, 0x0, 0x3]); - -#[test] -fn test_chunk_forward_tsn_success() -> Result<()> { - let tests = vec![ - CHUNK_FORWARD_TSN_BYTES.clone(), - Bytes::from_static(&[0xc0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x3, 0x0, 0x4, 0x0, 0x5]), - Bytes::from_static(&[ - 0xc0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x3, 0x0, 0x4, 0x0, 0x5, 0x0, 0x6, 0x0, 0x7, - ]), - ]; - - for binary in tests { - let actual = ChunkForwardTsn::unmarshal(&binary)?; - let b = actual.marshal()?; - assert_eq!(b, binary, "test not equal"); - } - - Ok(()) -} - -#[test] -fn test_chunk_forward_tsn_unmarshal_failure() -> Result<()> { - let tests = vec![ - ("chunk header to short", Bytes::from_static(&[0xc0])), - ( - "missing New Cumulative TSN", - Bytes::from_static(&[0xc0, 0x0, 0x0, 0x4]), - ), - ( - "missing stream sequence", - Bytes::from_static(&[ - 0xc0, 0x0, 0x0, 0xe, 0x0, 0x0, 0x0, 0x3, 0x0, 0x4, 0x0, 0x5, 0x0, 0x6, - ]), - ), - ]; - - for (name, binary) in tests { - let result = ChunkForwardTsn::unmarshal(&binary); - assert!(result.is_err(), "expected unmarshal: {name} to fail."); - } - - Ok(()) -} - -/////////////////////////////////////////////////////////////////// -//chunk_reconfig_test -/////////////////////////////////////////////////////////////////// -use super::chunk_reconfig::*; - -static TEST_CHUNK_RECONFIG_PARAM_A: Bytes = Bytes::from_static(&[ - 0x0, 0xd, 0x0, 0x16, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x3, 0x0, 0x4, 0x0, - 0x5, 0x0, 0x6, -]); - -static TEST_CHUNK_RECONFIG_PARAM_B: Bytes = Bytes::from_static(&[ - 0x0, 0xd, 0x0, 0x10, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 
0x0, 0x0, 0x3, -]); - -static TEST_CHUNK_RECONFIG_RESPONSE: Bytes = - Bytes::from_static(&[0x0, 0x10, 0x0, 0xc, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x1]); - -lazy_static! { - static ref TEST_CHUNK_RECONFIG_BYTES: Vec = { - let mut tests = vec![]; - { - let mut test = BytesMut::new(); - test.extend(vec![0x82, 0x0, 0x0, 0x1a]); - test.extend(TEST_CHUNK_RECONFIG_PARAM_A.clone()); - tests.push(test.freeze()); - } - { - let mut test = BytesMut::new(); - test.extend(vec![0x82, 0x0, 0x0, 0x14]); - test.extend(TEST_CHUNK_RECONFIG_PARAM_B.clone()); - tests.push(test.freeze()); - } - { - let mut test = BytesMut::new(); - test.extend(vec![0x82, 0x0, 0x0, 0x10]); - test.extend(TEST_CHUNK_RECONFIG_RESPONSE.clone()); - tests.push(test.freeze()); - } - { - let mut test = BytesMut::new(); - test.extend(vec![0x82, 0x0, 0x0, 0x2c]); - test.extend(TEST_CHUNK_RECONFIG_PARAM_A.clone()); - test.extend(vec![0u8; 2]); - test.extend(TEST_CHUNK_RECONFIG_PARAM_B.clone()); - tests.push(test.freeze()); - } - { - let mut test = BytesMut::new(); - test.extend(vec![0x82, 0x0, 0x0, 0x2a]); - test.extend(TEST_CHUNK_RECONFIG_PARAM_B.clone()); - test.extend(TEST_CHUNK_RECONFIG_PARAM_A.clone()); - tests.push(test.freeze()); - } - - tests - }; -} - -#[test] -fn test_chunk_reconfig_success() -> Result<()> { - for (i, binary) in TEST_CHUNK_RECONFIG_BYTES.iter().enumerate() { - let actual = ChunkReconfig::unmarshal(binary)?; - let b = actual.marshal()?; - assert_eq!(*binary, b, "test {} not equal: {:?} vs {:?}", i, *binary, b); - } - - Ok(()) -} - -#[test] -fn test_chunk_reconfig_unmarshal_failure() -> Result<()> { - let mut test = BytesMut::new(); - test.extend(vec![0x82, 0x0, 0x0, 0x18]); - test.extend(TEST_CHUNK_RECONFIG_PARAM_B.clone()); - test.extend(vec![0x0, 0xd, 0x0, 0x0]); - let tests = vec![ - ("chunk header to short", Bytes::from_static(&[0x82])), - ( - "missing parse param type (A)", - Bytes::from_static(&[0x82, 0x0, 0x0, 0x4]), - ), - ( - "wrong param (A)", - Bytes::from_static(&[0x82, 0x0, 0x0, 0x8, 0x0, 0xd, 0x0, 0x0]), - ), - ("wrong param (B)", test.freeze()), - ]; - - for (name, binary) in tests { - let result = ChunkReconfig::unmarshal(&binary); - assert!(result.is_err(), "expected unmarshal: {name} to fail."); - } - - Ok(()) -} - -/////////////////////////////////////////////////////////////////// -//chunk_shutdown_test -/////////////////////////////////////////////////////////////////// -use super::chunk_shutdown::*; - -#[test] -fn test_chunk_shutdown_success() -> Result<()> { - let tests = vec![Bytes::from_static(&[ - 0x07, 0x00, 0x00, 0x08, 0x12, 0x34, 0x56, 0x78, - ])]; - - for binary in tests { - let actual = ChunkShutdown::unmarshal(&binary)?; - let b = actual.marshal()?; - assert_eq!(b, binary, "test not equal"); - } - - Ok(()) -} - -#[test] -fn test_chunk_shutdown_failure() -> Result<()> { - let tests = vec![ - ( - "length too short", - Bytes::from_static(&[0x07, 0x00, 0x00, 0x07, 0x12, 0x34, 0x56, 0x78]), - ), - ( - "length too long", - Bytes::from_static(&[0x07, 0x00, 0x00, 0x09, 0x12, 0x34, 0x56, 0x78]), - ), - ( - "payload too short", - Bytes::from_static(&[0x07, 0x00, 0x00, 0x08, 0x12, 0x34, 0x56]), - ), - ( - "payload too long", - Bytes::from_static(&[0x07, 0x00, 0x00, 0x08, 0x12, 0x34, 0x56, 0x78, 0x9f]), - ), - ( - "invalid type", - Bytes::from_static(&[0x08, 0x00, 0x00, 0x08, 0x12, 0x34, 0x56, 0x78]), - ), - ]; - - for (name, binary) in tests { - let result = ChunkShutdown::unmarshal(&binary); - assert!(result.is_err(), "expected unmarshal: {name} to fail."); - } - - Ok(()) -} - 
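The shutdown tests above all hinge on SCTP's fixed-size and 4-byte-aligned length rules; the same alignment rule is what get_padding_size enforces for the TLV parameters marshalled by the chunks deleted earlier in this diff. A minimal, self-contained sketch of that rule (illustrative names only, not the crate's actual helper):

/// Bytes of zero padding needed to reach the next 4-byte boundary
/// (RFC 4960, Section 3.2): always between 0 and 3.
fn padding_size(len: usize) -> usize {
    (4 - (len % 4)) % 4
}

fn main() {
    // SHUTDOWN carries a 4-byte Cumulative TSN Ack after the 4-byte
    // chunk header, so its total length is 8 and needs no padding.
    assert_eq!(padding_size(8), 0);
    // A 22-byte parameter (like TEST_CHUNK_RECONFIG_PARAM_A above)
    // is followed by 2 zero bytes when another parameter follows it.
    assert_eq!(padding_size(22), 2);
    // Padding never exceeds 3 bytes.
    assert!((0..=64).all(|n| padding_size(n) <= 3));
}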
-/////////////////////////////////////////////////////////////////// -//chunk_shutdown_ack_test -/////////////////////////////////////////////////////////////////// -use super::chunk_shutdown_ack::*; - -#[test] -fn test_chunk_shutdown_ack_success() -> Result<()> { - let tests = vec![Bytes::from_static(&[0x08, 0x00, 0x00, 0x04])]; - - for binary in tests { - let actual = ChunkShutdownAck::unmarshal(&binary)?; - let b = actual.marshal()?; - assert_eq!(binary, b, "test not equal"); - } - - Ok(()) -} - -#[test] -fn test_chunk_shutdown_ack_failure() -> Result<()> { - let tests = vec![ - ("length too short", Bytes::from_static(&[0x08, 0x00, 0x00])), - ( - "length too long", - Bytes::from_static(&[0x08, 0x00, 0x00, 0x04, 0x12]), - ), - ( - "invalid type", - Bytes::from_static(&[0x0f, 0x00, 0x00, 0x04]), - ), - ]; - - for (name, binary) in tests { - let result = ChunkShutdownAck::unmarshal(&binary); - assert!(result.is_err(), "expected unmarshal: {name} to fail."); - } - - Ok(()) -} - -/////////////////////////////////////////////////////////////////// -//chunk_shutdown_complete_test -/////////////////////////////////////////////////////////////////// -use super::chunk_shutdown_complete::*; - -#[test] -fn test_chunk_shutdown_complete_success() -> Result<()> { - let tests = vec![Bytes::from_static(&[0x0e, 0x00, 0x00, 0x04])]; - - for binary in tests { - let actual = ChunkShutdownComplete::unmarshal(&binary)?; - let b = actual.marshal()?; - assert_eq!(b, binary, "test not equal"); - } - - Ok(()) -} - -#[test] -fn test_chunk_shutdown_complete_failure() -> Result<()> { - let tests = vec![ - ("length too short", Bytes::from_static(&[0x0e, 0x00, 0x00])), - ( - "length too long", - Bytes::from_static(&[0x0e, 0x00, 0x00, 0x04, 0x12]), - ), - ( - "invalid type", - Bytes::from_static(&[0x0f, 0x00, 0x00, 0x04]), - ), - ]; - - for (name, binary) in tests { - let result = ChunkShutdownComplete::unmarshal(&binary); - assert!(result.is_err(), "expected unmarshal: {name} to fail."); - } - - Ok(()) -} - -/////////////////////////////////////////////////////////////////// -//chunk_test -/////////////////////////////////////////////////////////////////// -use crate::chunk::chunk_init::*; -use crate::chunk::chunk_payload_data::*; -use crate::chunk::chunk_selective_ack::ChunkSelectiveAck; -use crate::packet::*; -use crate::param::param_outgoing_reset_request::ParamOutgoingResetRequest; -use crate::param::param_state_cookie::*; - -#[test] -fn test_init_chunk() -> Result<()> { - let raw_pkt = Bytes::from_static(&[ - 0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x00, 0x00, 0x81, 0x46, 0x9d, 0xfc, 0x01, 0x00, 0x00, - 0x56, 0x55, 0xb9, 0x64, 0xa5, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0xe8, 0x6d, - 0x10, 0x30, 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, - 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0x9f, 0xeb, 0xbb, 0x5c, 0x50, 0xc9, 0xbf, 0x75, - 0x9c, 0xb1, 0x2c, 0x57, 0x4f, 0xa4, 0x5a, 0x51, 0xba, 0x60, 0x17, 0x78, 0x27, 0x94, 0x5c, - 0x31, 0xe6, 0x5d, 0x5b, 0x09, 0x47, 0xe2, 0x22, 0x06, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, - 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00, - ]); - let pkt = Packet::unmarshal(&raw_pkt)?; - - if let Some(c) = pkt.chunks[0].as_any().downcast_ref::() { - assert_eq!( - c.initiate_tag, 1438213285, - "Unmarshal passed for SCTP packet, but got incorrect initiate tag exp: {} act: {}", - 1438213285, c.initiate_tag - ); - assert_eq!(c.advertised_receiver_window_credit, 131072, "Unmarshal passed for SCTP packet, but got incorrect 
advertisedReceiverWindowCredit exp: {} act: {}", 131072, c.advertised_receiver_window_credit); - assert_eq!(c.num_outbound_streams, 1024, "Unmarshal passed for SCTP packet, but got incorrect numOutboundStreams tag exp:{} act: {}", 1024, c.num_outbound_streams); - assert_eq!( - c.num_inbound_streams, 2048, - "Unmarshal passed for SCTP packet, but got incorrect numInboundStreams exp: {} act: {}", - 2048, c.num_inbound_streams - ); - assert_eq!( - c.initial_tsn, 3899461680u32, - "Unmarshal passed for SCTP packet, but got incorrect initialTSN exp: {} act: {}", - 3899461680u32, c.initial_tsn - ); - } else { - panic!("Failed to cast Chunk -> Init"); - } - - Ok(()) -} - -#[test] -fn test_init_ack() -> Result<()> { - let raw_pkt = Bytes::from_static(&[ - 0x13, 0x88, 0x13, 0x88, 0xce, 0x15, 0x79, 0xa2, 0x96, 0x19, 0xe8, 0xb2, 0x02, 0x00, 0x00, - 0x1c, 0xeb, 0x81, 0x4e, 0x01, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x50, 0xdf, - 0x90, 0xd9, 0x00, 0x07, 0x00, 0x08, 0x94, 0x06, 0x2f, 0x93, - ]); - let pkt = Packet::unmarshal(&raw_pkt)?; - assert!( - pkt.chunks[0].as_any().downcast_ref::().is_some(), - "Failed to cast Chunk -> Init" - ); - - Ok(()) -} - -#[test] -fn test_chrome_chunk1_init() -> Result<()> { - let raw_pkt = Bytes::from_static(&[ - 0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x00, 0x00, 0xbc, 0xb3, 0x45, 0xa2, 0x01, 0x00, 0x00, - 0x56, 0xce, 0x15, 0x79, 0xa2, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x94, 0x57, - 0x95, 0xc0, 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, - 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0xff, 0x5c, 0x49, 0x19, 0x4a, 0x94, 0xe8, 0x2a, - 0xec, 0x58, 0x55, 0x62, 0x29, 0x1f, 0x8e, 0x23, 0xcd, 0x7c, 0xe8, 0x46, 0xba, 0x58, 0x1b, - 0x3d, 0xab, 0xd7, 0x7e, 0x50, 0xf2, 0x41, 0xb1, 0x2e, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, - 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00, - ]); - let pkt = Packet::unmarshal(&raw_pkt)?; - let raw_pkt2 = pkt.marshal()?; - assert_eq!(raw_pkt2, raw_pkt); - - Ok(()) -} - -#[test] -fn test_chrome_chunk2_init_ack() -> Result<()> { - let raw_pkt = Bytes::from_static(&[ - 0x13, 0x88, 0x13, 0x88, 0xce, 0x15, 0x79, 0xa2, 0xb5, 0xdb, 0x2d, 0x93, 0x02, 0x00, 0x01, - 0x90, 0x9b, 0xd5, 0xb3, 0x6f, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0xef, 0xb4, - 0x72, 0x87, 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, - 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0x2e, 0xf9, 0x9c, 0x10, 0x63, 0x72, 0xed, 0x0d, - 0x33, 0xc2, 0xdc, 0x7f, 0x9f, 0xd7, 0xef, 0x1b, 0xc9, 0xc4, 0xa7, 0x41, 0x9a, 0x07, 0x68, - 0x6b, 0x66, 0xfb, 0x6a, 0x4e, 0x32, 0x5d, 0xe4, 0x25, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, - 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00, 0x00, 0x07, 0x01, 0x38, 0x4b, - 0x41, 0x4d, 0x45, 0x2d, 0x42, 0x53, 0x44, 0x20, 0x31, 0x2e, 0x31, 0x00, 0x00, 0x00, 0x00, - 0x9c, 0x1e, 0x49, 0x5b, 0x00, 0x00, 0x00, 0x00, 0xd2, 0x42, 0x06, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x60, 0xea, 0x00, 0x00, 0xc4, 0x13, 0x3d, 0xe9, 0x86, 0xb1, 0x85, 0x75, 0xa2, 0x79, - 0x15, 0xce, 0x9b, 0xd5, 0xb3, 0x6f, 0x20, 0xe0, 0x9f, 0x89, 0xe0, 0x27, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x20, 0xe0, 0x9f, 0x89, - 0xe0, 0x27, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x01, 0x00, 0x01, 0x01, - 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x56, 0xce, 0x15, 0x79, 0xa2, 0x00, - 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x94, 0x57, 0x95, 0xc0, 0xc0, 
0x00, 0x00, 0x04, - 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, - 0x24, 0xff, 0x5c, 0x49, 0x19, 0x4a, 0x94, 0xe8, 0x2a, 0xec, 0x58, 0x55, 0x62, 0x29, 0x1f, - 0x8e, 0x23, 0xcd, 0x7c, 0xe8, 0x46, 0xba, 0x58, 0x1b, 0x3d, 0xab, 0xd7, 0x7e, 0x50, 0xf2, - 0x41, 0xb1, 0x2e, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, - 0x80, 0xc1, 0x00, 0x00, 0x02, 0x00, 0x01, 0x90, 0x9b, 0xd5, 0xb3, 0x6f, 0x00, 0x02, 0x00, - 0x00, 0x04, 0x00, 0x08, 0x00, 0xef, 0xb4, 0x72, 0x87, 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, - 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0x2e, - 0xf9, 0x9c, 0x10, 0x63, 0x72, 0xed, 0x0d, 0x33, 0xc2, 0xdc, 0x7f, 0x9f, 0xd7, 0xef, 0x1b, - 0xc9, 0xc4, 0xa7, 0x41, 0x9a, 0x07, 0x68, 0x6b, 0x66, 0xfb, 0x6a, 0x4e, 0x32, 0x5d, 0xe4, - 0x25, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, - 0x00, 0x00, 0xca, 0x0c, 0x21, 0x11, 0xce, 0xf4, 0xfc, 0xb3, 0x66, 0x99, 0x4f, 0xdb, 0x4f, - 0x95, 0x6b, 0x6f, 0x3b, 0xb1, 0xdb, 0x5a, - ]); - let pkt = Packet::unmarshal(&raw_pkt)?; - let raw_pkt2 = pkt.marshal()?; - assert_eq!(raw_pkt2, raw_pkt); - - Ok(()) -} - -#[test] -fn test_init_marshal_unmarshal() -> Result<()> { - let mut p = Packet { - destination_port: 1, - source_port: 1, - verification_tag: 123, - chunks: vec![], - }; - - let mut init_ack = ChunkInit { - is_ack: true, - initiate_tag: 123, - advertised_receiver_window_credit: 1024, - num_outbound_streams: 1, - num_inbound_streams: 1, - initial_tsn: 123, - params: vec![], - }; - - let cookie = Box::new(ParamStateCookie::new()); - init_ack.params.push(cookie); - - p.chunks.push(Box::new(init_ack)); - - let raw_pkt = p.marshal()?; - let pkt = Packet::unmarshal(&raw_pkt)?; - - if let Some(c) = pkt.chunks[0].as_any().downcast_ref::() { - assert_eq!( - c.initiate_tag, 123, - "Unmarshal passed for SCTP packet, but got incorrect initiate tag exp: {} act: {}", - 123, c.initiate_tag - ); - assert_eq!(c.advertised_receiver_window_credit, 1024, "Unmarshal passed for SCTP packet, but got incorrect advertisedReceiverWindowCredit exp: {} act: {}", 1024, c.advertised_receiver_window_credit); - assert_eq!(c.num_outbound_streams, 1, "Unmarshal passed for SCTP packet, but got incorrect numOutboundStreams tag exp:{} act: {}", 1, c.num_outbound_streams); - assert_eq!( - c.num_inbound_streams, 1, - "Unmarshal passed for SCTP packet, but got incorrect numInboundStreams exp: {} act: {}", - 1, c.num_inbound_streams - ); - assert_eq!( - c.initial_tsn, 123, - "Unmarshal passed for SCTP packet, but got incorrect initialTSN exp: {} act: {}", - 123, c.initial_tsn - ); - } else { - panic!("Failed to cast Chunk -> InitAck"); - } - - Ok(()) -} - -#[test] -fn test_payload_data_marshal_unmarshal() -> Result<()> { - let raw_pkt = Bytes::from_static(&[ - 0x13, 0x88, 0x13, 0x88, 0xfc, 0xd6, 0x3f, 0xc6, 0xbe, 0xfa, 0xdc, 0x52, 0x0a, 0x00, 0x00, - 0x24, 0x9b, 0x28, 0x7e, 0x48, 0xa3, 0x7b, 0xc1, 0x83, 0xc4, 0x4b, 0x41, 0x04, 0xa4, 0xf7, - 0xed, 0x4c, 0x93, 0x62, 0xc3, 0x49, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x1f, 0xa8, 0x79, 0xa1, 0xc7, 0x00, 0x01, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x32, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, - 0x00, 0x66, 0x6f, 0x6f, 0x00, - ]); - let pkt = Packet::unmarshal(&raw_pkt)?; - assert!( - pkt.chunks[1] - .as_any() - .downcast_ref::() - .is_some(), - "Failed to cast Chunk -> PayloadData" - ); - Ok(()) -} - -#[test] -fn 
test_select_ack_chunk() -> Result<()> { - let raw_pkt = Bytes::from_static(&[ - 0x13, 0x88, 0x13, 0x88, 0xc2, 0x98, 0x98, 0x0f, 0x42, 0x31, 0xea, 0x78, 0x03, 0x00, 0x00, - 0x14, 0x87, 0x73, 0xbd, 0xa4, 0x00, 0x01, 0xfe, 0x74, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, - 0x00, 0x02, - ]); - let pkt = Packet::unmarshal(&raw_pkt)?; - assert!( - pkt.chunks[0] - .as_any() - .downcast_ref::() - .is_some(), - "Failed to cast Chunk -> SelectiveAck" - ); - Ok(()) -} - -#[test] -fn test_reconfig_chunk() -> Result<()> { - let raw_pkt = Bytes::from_static(&[ - 0x13, 0x88, 0x13, 0x88, 0xb6, 0xa5, 0x12, 0xe5, 0x75, 0x3b, 0x12, 0xd3, 0x82, 0x0, 0x0, - 0x16, 0x0, 0xd, 0x0, 0x12, 0x4e, 0x1c, 0xb9, 0xe6, 0x3a, 0x74, 0x8d, 0xff, 0x4e, 0x1c, - 0xb9, 0xe6, 0x0, 0x1, 0x0, 0x0, - ]); - let pkt = Packet::unmarshal(&raw_pkt)?; - if let Some(c) = pkt.chunks[0].as_any().downcast_ref::() { - assert!(c.param_a.is_some(), "param_a must not be none"); - assert_eq!( - c.param_a - .as_ref() - .unwrap() - .as_any() - .downcast_ref::() - .unwrap() - .stream_identifiers[0], - 1, - "unexpected stream identifier" - ); - } else { - panic!("Failed to cast Chunk -> Reconfig"); - } - - Ok(()) -} - -#[test] -fn test_forward_tsn_chunk() -> Result<()> { - let mut raw_pkt = BytesMut::new(); - raw_pkt.extend(vec![ - 0x13, 0x88, 0x13, 0x88, 0xb6, 0xa5, 0x12, 0xe5, 0x1f, 0x9d, 0xa0, 0xfb, - ]); - raw_pkt.extend(CHUNK_FORWARD_TSN_BYTES.clone()); - let raw_pkt = raw_pkt.freeze(); - let pkt = Packet::unmarshal(&raw_pkt)?; - - if let Some(c) = pkt.chunks[0].as_any().downcast_ref::() { - assert_eq!( - c.new_cumulative_tsn, 3, - "unexpected New Cumulative TSN: {}", - c.new_cumulative_tsn - ); - } else { - panic!("Failed to cast Chunk -> Forward TSN"); - } - - Ok(()) -} - -#[test] -fn test_select_ack_chunk_followed_by_a_payload_data_chunk() -> Result<()> { - let raw_pkt = Bytes::from_static(&[ - 0x13, 0x88, 0x13, 0x88, 0xc2, 0x98, 0x98, 0x0f, 0x58, 0xcf, 0x38, - 0xC0, // A SACK chunk follows. - 0x03, 0x00, 0x00, 0x14, 0x87, 0x73, 0xbd, 0xa4, 0x00, 0x01, 0xfe, 0x74, 0x00, 0x01, 0x00, - 0x00, 0x00, 0x02, 0x00, 0x02, // A payload data chunk follows. - 0x00, 0x07, 0x00, 0x3B, 0xA4, 0x50, 0x7B, 0xC5, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x33, 0x7B, 0x22, 0x65, 0x76, 0x65, 0x6E, 0x74, 0x22, 0x3A, 0x22, 0x72, 0x65, 0x73, 0x69, - 0x7A, 0x65, 0x22, 0x2C, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22, 0x3A, 0x36, 0x36, 0x35, - 0x2C, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74, 0x22, 0x3A, 0x34, 0x39, 0x39, 0x7D, 0x00, - ]); - let pkt = Packet::unmarshal(&raw_pkt)?; - assert!( - pkt.chunks[0] - .as_any() - .downcast_ref::() - .is_some(), - "Failed to cast Chunk -> SelectiveAck" - ); - assert!( - pkt.chunks[1] - .as_any() - .downcast_ref::() - .is_some(), - "Failed to cast Chunk -> PayloadData" - ); - Ok(()) -} diff --git a/sctp/src/chunk/chunk_type.rs b/sctp/src/chunk/chunk_type.rs deleted file mode 100644 index 60db10985..000000000 --- a/sctp/src/chunk/chunk_type.rs +++ /dev/null @@ -1,89 +0,0 @@ -use std::fmt; - -// chunkType is an enum for SCTP Chunk Type field -// This field identifies the type of information contained in the -// Chunk Value field. 
-#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)] -pub(crate) struct ChunkType(pub(crate) u8); - -pub(crate) const CT_PAYLOAD_DATA: ChunkType = ChunkType(0); -pub(crate) const CT_INIT: ChunkType = ChunkType(1); -pub(crate) const CT_INIT_ACK: ChunkType = ChunkType(2); -pub(crate) const CT_SACK: ChunkType = ChunkType(3); -pub(crate) const CT_HEARTBEAT: ChunkType = ChunkType(4); -pub(crate) const CT_HEARTBEAT_ACK: ChunkType = ChunkType(5); -pub(crate) const CT_ABORT: ChunkType = ChunkType(6); -pub(crate) const CT_SHUTDOWN: ChunkType = ChunkType(7); -pub(crate) const CT_SHUTDOWN_ACK: ChunkType = ChunkType(8); -pub(crate) const CT_ERROR: ChunkType = ChunkType(9); -pub(crate) const CT_COOKIE_ECHO: ChunkType = ChunkType(10); -pub(crate) const CT_COOKIE_ACK: ChunkType = ChunkType(11); -pub(crate) const CT_ECNE: ChunkType = ChunkType(12); -pub(crate) const CT_CWR: ChunkType = ChunkType(13); -pub(crate) const CT_SHUTDOWN_COMPLETE: ChunkType = ChunkType(14); -pub(crate) const CT_RECONFIG: ChunkType = ChunkType(130); -pub(crate) const CT_FORWARD_TSN: ChunkType = ChunkType(192); - -impl fmt::Display for ChunkType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let others = format!("Unknown ChunkType: {}", self.0); - let s = match *self { - CT_PAYLOAD_DATA => "DATA", - CT_INIT => "INIT", - CT_INIT_ACK => "INIT-ACK", - CT_SACK => "SACK", - CT_HEARTBEAT => "HEARTBEAT", - CT_HEARTBEAT_ACK => "HEARTBEAT-ACK", - CT_ABORT => "ABORT", - CT_SHUTDOWN => "SHUTDOWN", - CT_SHUTDOWN_ACK => "SHUTDOWN-ACK", - CT_ERROR => "ERROR", - CT_COOKIE_ECHO => "COOKIE-ECHO", - CT_COOKIE_ACK => "COOKIE-ACK", - CT_ECNE => "ECNE", // Explicit Congestion Notification Echo - CT_CWR => "CWR", // Reserved for Congestion Window Reduced (CWR) - CT_SHUTDOWN_COMPLETE => "SHUTDOWN-COMPLETE", - CT_RECONFIG => "RECONFIG", // Re-configuration - CT_FORWARD_TSN => "FORWARD-TSN", - _ => others.as_str(), - }; - write!(f, "{s}") - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_chunk_type_string() { - let tests = vec![ - (CT_PAYLOAD_DATA, "DATA"), - (CT_INIT, "INIT"), - (CT_INIT_ACK, "INIT-ACK"), - (CT_SACK, "SACK"), - (CT_HEARTBEAT, "HEARTBEAT"), - (CT_HEARTBEAT_ACK, "HEARTBEAT-ACK"), - (CT_ABORT, "ABORT"), - (CT_SHUTDOWN, "SHUTDOWN"), - (CT_SHUTDOWN_ACK, "SHUTDOWN-ACK"), - (CT_ERROR, "ERROR"), - (CT_COOKIE_ECHO, "COOKIE-ECHO"), - (CT_COOKIE_ACK, "COOKIE-ACK"), - (CT_ECNE, "ECNE"), - (CT_CWR, "CWR"), - (CT_SHUTDOWN_COMPLETE, "SHUTDOWN-COMPLETE"), - (CT_RECONFIG, "RECONFIG"), - (CT_FORWARD_TSN, "FORWARD-TSN"), - (ChunkType(255), "Unknown ChunkType: 255"), - ]; - - for (ct, expected) in tests { - assert_eq!( - ct.to_string(), - expected, - "failed to stringify chunkType {ct}, expected {expected}" - ); - } - } -} diff --git a/sctp/src/chunk/chunk_unknown.rs b/sctp/src/chunk/chunk_unknown.rs deleted file mode 100644 index 5ee94ee41..000000000 --- a/sctp/src/chunk/chunk_unknown.rs +++ /dev/null @@ -1,55 +0,0 @@ -use std::any::Any; -use std::fmt::{Debug, Display, Formatter}; - -use bytes::{Bytes, BytesMut}; - -use crate::chunk::chunk_header::{ChunkHeader, CHUNK_HEADER_SIZE}; -use crate::chunk::Chunk; - -#[derive(Clone, Debug)] -pub struct ChunkUnknown { - hdr: ChunkHeader, - value: Bytes, -} - -impl Display for ChunkUnknown { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "ChunkUnknown( {} {:?} )", self.hdr, self.value) - } -} - -impl Chunk for ChunkUnknown { - fn header(&self) -> ChunkHeader { - self.hdr.clone() - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } - - 
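    // Editor's sketch, not part of the original impl: ChunkUnknown::unmarshal
    // below leans on ChunkHeader::unmarshal, which this diff does not show.
    // Per RFC 4960 section 3.2 a chunk header is 4 bytes -- type, flags, and a
    // big-endian length that counts the header plus the value but not the
    // trailing padding. A hedged, dependency-free illustration of that split:
    fn split_chunk(raw: &[u8]) -> Option<(u8, u8, &[u8])> {
        let len = u16::from_be_bytes([*raw.get(2)?, *raw.get(3)?]) as usize;
        if len < 4 || raw.len() < len {
            return None; // shorter than a header, or the value is truncated
        }
        Some((raw[0], raw[1], &raw[4..len])) // (type, flags, value)
    }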
fn check(&self) -> crate::error::Result<()> { - Ok(()) - } - - fn value_length(&self) -> usize { - self.value.len() - } - - fn marshal_to(&self, buf: &mut BytesMut) -> crate::error::Result { - self.header().marshal_to(buf)?; - buf.extend(&self.value); - Ok(buf.len()) - } - - fn unmarshal(raw: &Bytes) -> crate::error::Result - where - Self: Sized, - { - let header = ChunkHeader::unmarshal(raw)?; - let len = header.value_length(); - Ok(Self { - hdr: header, - value: raw.slice(CHUNK_HEADER_SIZE..CHUNK_HEADER_SIZE + len), - }) - } -} diff --git a/sctp/src/chunk/mod.rs b/sctp/src/chunk/mod.rs deleted file mode 100644 index e7fb6b96d..000000000 --- a/sctp/src/chunk/mod.rs +++ /dev/null @@ -1,47 +0,0 @@ -#[cfg(test)] -mod chunk_test; - -pub(crate) mod chunk_abort; -pub(crate) mod chunk_cookie_ack; -pub(crate) mod chunk_cookie_echo; -pub(crate) mod chunk_error; -pub(crate) mod chunk_forward_tsn; -pub(crate) mod chunk_header; -pub(crate) mod chunk_heartbeat; -pub(crate) mod chunk_heartbeat_ack; -pub(crate) mod chunk_init; -pub mod chunk_payload_data; -pub(crate) mod chunk_reconfig; -pub(crate) mod chunk_selective_ack; -pub(crate) mod chunk_shutdown; -pub(crate) mod chunk_shutdown_ack; -pub(crate) mod chunk_shutdown_complete; -pub(crate) mod chunk_type; -pub(crate) mod chunk_unknown; - -use std::any::Any; -use std::fmt; -use std::marker::Sized; - -use bytes::{Bytes, BytesMut}; -use chunk_header::*; - -use crate::error::{Error, Result}; - -pub(crate) trait Chunk: fmt::Display + fmt::Debug { - fn header(&self) -> ChunkHeader; - fn unmarshal(raw: &Bytes) -> Result - where - Self: Sized; - fn marshal_to(&self, buf: &mut BytesMut) -> Result; - fn check(&self) -> Result<()>; - fn value_length(&self) -> usize; - fn as_any(&self) -> &(dyn Any + Send + Sync); - - fn marshal(&self) -> Result { - let capacity = CHUNK_HEADER_SIZE + self.value_length(); - let mut buf = BytesMut::with_capacity(capacity); - self.marshal_to(&mut buf)?; - Ok(buf.freeze()) - } -} diff --git a/sctp/src/error.rs b/sctp/src/error.rs deleted file mode 100644 index 3069f02f4..000000000 --- a/sctp/src/error.rs +++ /dev/null @@ -1,238 +0,0 @@ -use std::io; - -use thiserror::Error; - -pub type Result = std::result::Result; - -#[derive(Debug, Error, PartialEq, Eq, Clone)] -#[non_exhaustive] -pub enum Error { - #[error("raw is too small for a SCTP chunk")] - ErrChunkHeaderTooSmall, - #[error("not enough data left in SCTP packet to satisfy requested length")] - ErrChunkHeaderNotEnoughSpace, - #[error("chunk PADDING is non-zero at offset")] - ErrChunkHeaderPaddingNonZero, - #[error("chunk has invalid length")] - ErrChunkHeaderInvalidLength, - - #[error("ChunkType is not of type ABORT")] - ErrChunkTypeNotAbort, - #[error("failed build Abort Chunk")] - ErrBuildAbortChunkFailed, - #[error("ChunkType is not of type COOKIEACK")] - ErrChunkTypeNotCookieAck, - #[error("ChunkType is not of type COOKIEECHO")] - ErrChunkTypeNotCookieEcho, - #[error("ChunkType is not of type ctError")] - ErrChunkTypeNotCtError, - #[error("failed build Error Chunk")] - ErrBuildErrorChunkFailed, - #[error("failed to marshal stream")] - ErrMarshalStreamFailed, - #[error("chunk too short")] - ErrChunkTooShort, - #[error("ChunkType is not of type ForwardTsn")] - ErrChunkTypeNotForwardTsn, - #[error("ChunkType is not of type HEARTBEAT")] - ErrChunkTypeNotHeartbeat, - #[error("ChunkType is not of type HEARTBEATACK")] - ErrChunkTypeNotHeartbeatAck, - #[error("heartbeat is not long enough to contain Heartbeat Info")] - ErrHeartbeatNotLongEnoughInfo, - #[error("failed to parse param 
type")] - ErrParseParamTypeFailed, - #[error("heartbeat should only have HEARTBEAT param")] - ErrHeartbeatParam, - #[error("failed unmarshalling param in Heartbeat Chunk")] - ErrHeartbeatChunkUnmarshal, - #[error("unimplemented")] - ErrUnimplemented, - #[error("heartbeat Ack must have one param")] - ErrHeartbeatAckParams, - #[error("heartbeat Ack must have one param, and it should be a HeartbeatInfo")] - ErrHeartbeatAckNotHeartbeatInfo, - #[error("unable to marshal parameter for Heartbeat Ack")] - ErrHeartbeatAckMarshalParam, - - #[error("raw is too small for error cause")] - ErrErrorCauseTooSmall, - - #[error("unhandled ParamType `{typ}`")] - ErrParamTypeUnhandled { typ: u16 }, - - #[error("unexpected ParamType")] - ErrParamTypeUnexpected, - - #[error("param header too short")] - ErrParamHeaderTooShort, - #[error("param self reported length is shorter than header length")] - ErrParamHeaderSelfReportedLengthShorter, - #[error("param self reported length is longer than header length")] - ErrParamHeaderSelfReportedLengthLonger, - #[error("failed to parse param type")] - ErrParamHeaderParseFailed, - - #[error("packet to short")] - ErrParamPacketTooShort, - #[error("outgoing SSN reset request parameter too short")] - ErrSsnResetRequestParamTooShort, - #[error("reconfig response parameter too short")] - ErrReconfigRespParamTooShort, - #[error("invalid algorithm type")] - ErrInvalidAlgorithmType, - - #[error("failed to parse param type")] - ErrInitChunkParseParamTypeFailed, - #[error("failed unmarshalling param in Init Chunk")] - ErrInitChunkUnmarshalParam, - #[error("unable to marshal parameter for INIT/INITACK")] - ErrInitAckMarshalParam, - - #[error("ChunkType is not of type INIT")] - ErrChunkTypeNotTypeInit, - #[error("chunk Value isn't long enough for mandatory parameters exp")] - ErrChunkValueNotLongEnough, - #[error("ChunkType of type INIT flags must be all 0")] - ErrChunkTypeInitFlagZero, - #[error("failed to unmarshal INIT body")] - ErrChunkTypeInitUnmarshalFailed, - #[error("failed marshaling INIT common data")] - ErrChunkTypeInitMarshalFailed, - #[error("ChunkType of type INIT ACK InitiateTag must not be 0")] - ErrChunkTypeInitInitiateTagZero, - #[error("INIT ACK inbound stream request must be > 0")] - ErrInitInboundStreamRequestZero, - #[error("INIT ACK outbound stream request must be > 0")] - ErrInitOutboundStreamRequestZero, - #[error("INIT ACK Advertised Receiver Window Credit (a_rwnd) must be >= 1500")] - ErrInitAdvertisedReceiver1500, - - #[error("packet is smaller than the header size")] - ErrChunkPayloadSmall, - #[error("ChunkType is not of type PayloadData")] - ErrChunkTypeNotPayloadData, - #[error("ChunkType is not of type Reconfig")] - ErrChunkTypeNotReconfig, - #[error("ChunkReconfig has invalid ParamA")] - ErrChunkReconfigInvalidParamA, - - #[error("failed to parse param type")] - ErrChunkParseParamTypeFailed, - #[error("unable to marshal parameter A for reconfig")] - ErrChunkMarshalParamAReconfigFailed, - #[error("unable to marshal parameter B for reconfig")] - ErrChunkMarshalParamBReconfigFailed, - - #[error("ChunkType is not of type SACK")] - ErrChunkTypeNotSack, - #[error("SACK Chunk size is not large enough to contain header")] - ErrSackSizeNotLargeEnoughInfo, - - #[error("invalid chunk size")] - ErrInvalidChunkSize, - #[error("ChunkType is not of type SHUTDOWN")] - ErrChunkTypeNotShutdown, - - #[error("ChunkType is not of type SHUTDOWN-ACK")] - ErrChunkTypeNotShutdownAck, - #[error("ChunkType is not of type SHUTDOWN-COMPLETE")] - ErrChunkTypeNotShutdownComplete, - - 
#[error("raw is smaller than the minimum length for a SCTP packet")] - ErrPacketRawTooSmall, - #[error("unable to parse SCTP chunk, not enough data for complete header")] - ErrParseSctpChunkNotEnoughData, - #[error("failed to unmarshal, contains unknown chunk type")] - ErrUnmarshalUnknownChunkType, - #[error("checksum mismatch theirs")] - ErrChecksumMismatch, - - #[error("unexpected chunk popped (unordered)")] - ErrUnexpectedChuckPoppedUnordered, - #[error("unexpected chunk popped (ordered)")] - ErrUnexpectedChuckPoppedOrdered, - #[error("unexpected q state (should've been selected)")] - ErrUnexpectedQState, - #[error("try again")] - ErrTryAgain, - - #[error("abort chunk, with following errors")] - ErrChunk, - #[error("shutdown called in non-Established state")] - ErrShutdownNonEstablished, - #[error("association closed before connecting")] - ErrAssociationClosedBeforeConn, - #[error("association init failed")] - ErrAssociationInitFailed, - #[error("association handshake closed")] - ErrAssociationHandshakeClosed, - #[error("silently discard")] - ErrSilentlyDiscard, - #[error("the init not stored to send")] - ErrInitNotStoredToSend, - #[error("cookieEcho not stored to send")] - ErrCookieEchoNotStoredToSend, - #[error("sctp packet must not have a source port of 0")] - ErrSctpPacketSourcePortZero, - #[error("sctp packet must not have a destination port of 0")] - ErrSctpPacketDestinationPortZero, - #[error("init chunk must not be bundled with any other chunk")] - ErrInitChunkBundled, - #[error("init chunk expects a verification tag of 0 on the packet when out-of-the-blue")] - ErrInitChunkVerifyTagNotZero, - #[error("todo: handle Init when in state")] - ErrHandleInitState, - #[error("no cookie in InitAck")] - ErrInitAckNoCookie, - #[error("there already exists a stream with identifier")] - ErrStreamAlreadyExist, - #[error("Failed to create a stream with identifier")] - ErrStreamCreateFailed, - #[error("unable to be popped from inflight queue TSN")] - ErrInflightQueueTsnPop, - #[error("requested non-existent TSN")] - ErrTsnRequestNotExist, - #[error("sending reset packet in non-Established state")] - ErrResetPacketInStateNotExist, - #[error("unexpected parameter type")] - ErrParameterType, - #[error("sending payload data in non-Established state")] - ErrPayloadDataStateNotExist, - #[error("unhandled chunk type")] - ErrChunkTypeUnhandled, - #[error("handshake failed (INIT ACK)")] - ErrHandshakeInitAck, - #[error("handshake failed (COOKIE ECHO)")] - ErrHandshakeCookieEcho, - - #[error("outbound packet larger than maximum message size")] - ErrOutboundPacketTooLarge, - #[error("Stream closed")] - ErrStreamClosed, - #[error("Short buffer (size: {size:?}) to be filled")] - ErrShortBuffer { size: usize }, - #[error("Io EOF")] - ErrEof, - #[error("Invalid SystemTime")] - ErrInvalidSystemTime, - #[error("Net Conn read error")] - ErrNetConnReadError, - #[error("Max Data Channel ID")] - ErrMaxDataChannelID, - - #[error("{0}")] - Other(String), -} - -impl From for io::Error { - fn from(error: Error) -> Self { - match error { - e @ Error::ErrEof => io::Error::new(io::ErrorKind::UnexpectedEof, e.to_string()), - e @ Error::ErrStreamClosed => { - io::Error::new(io::ErrorKind::ConnectionAborted, e.to_string()) - } - e => io::Error::new(io::ErrorKind::Other, e.to_string()), - } - } -} diff --git a/sctp/src/error_cause.rs b/sctp/src/error_cause.rs deleted file mode 100644 index d8afd02ce..000000000 --- a/sctp/src/error_cause.rs +++ /dev/null @@ -1,136 +0,0 @@ -use std::fmt; - -use bytes::{Buf, BufMut, Bytes, 
BytesMut}; - -use crate::error::{Error, Result}; - -/// errorCauseCode is a cause code that appears in either a ERROR or ABORT chunk -#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)] -pub(crate) struct ErrorCauseCode(pub(crate) u16); - -pub(crate) const INVALID_STREAM_IDENTIFIER: ErrorCauseCode = ErrorCauseCode(1); -pub(crate) const MISSING_MANDATORY_PARAMETER: ErrorCauseCode = ErrorCauseCode(2); -pub(crate) const STALE_COOKIE_ERROR: ErrorCauseCode = ErrorCauseCode(3); -pub(crate) const OUT_OF_RESOURCE: ErrorCauseCode = ErrorCauseCode(4); -pub(crate) const UNRESOLVABLE_ADDRESS: ErrorCauseCode = ErrorCauseCode(5); -pub(crate) const UNRECOGNIZED_CHUNK_TYPE: ErrorCauseCode = ErrorCauseCode(6); -pub(crate) const INVALID_MANDATORY_PARAMETER: ErrorCauseCode = ErrorCauseCode(7); -pub(crate) const UNRECOGNIZED_PARAMETERS: ErrorCauseCode = ErrorCauseCode(8); -pub(crate) const NO_USER_DATA: ErrorCauseCode = ErrorCauseCode(9); -pub(crate) const COOKIE_RECEIVED_WHILE_SHUTTING_DOWN: ErrorCauseCode = ErrorCauseCode(10); -pub(crate) const RESTART_OF_AN_ASSOCIATION_WITH_NEW_ADDRESSES: ErrorCauseCode = ErrorCauseCode(11); -pub(crate) const USER_INITIATED_ABORT: ErrorCauseCode = ErrorCauseCode(12); -pub(crate) const PROTOCOL_VIOLATION: ErrorCauseCode = ErrorCauseCode(13); - -impl fmt::Display for ErrorCauseCode { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let others = format!("Unknown CauseCode: {}", self.0); - let s = match *self { - INVALID_STREAM_IDENTIFIER => "Invalid Stream Identifier", - MISSING_MANDATORY_PARAMETER => "Missing Mandatory Parameter", - STALE_COOKIE_ERROR => "Stale Cookie Error", - OUT_OF_RESOURCE => "Out Of Resource", - UNRESOLVABLE_ADDRESS => "Unresolvable IP", - UNRECOGNIZED_CHUNK_TYPE => "Unrecognized Chunk Type", - INVALID_MANDATORY_PARAMETER => "Invalid Mandatory Parameter", - UNRECOGNIZED_PARAMETERS => "Unrecognized Parameters", - NO_USER_DATA => "No User Data", - COOKIE_RECEIVED_WHILE_SHUTTING_DOWN => "Cookie Received While Shutting Down", - RESTART_OF_AN_ASSOCIATION_WITH_NEW_ADDRESSES => { - "Restart Of An Association With New Addresses" - } - USER_INITIATED_ABORT => "User Initiated Abort", - PROTOCOL_VIOLATION => "Protocol Violation", - _ => others.as_str(), - }; - write!(f, "{s}") - } -} - -/// ErrorCauseHeader represents the shared header that is shared by all error causes -#[derive(Debug, Clone, Default)] -pub(crate) struct ErrorCause { - pub(crate) code: ErrorCauseCode, - pub(crate) raw: Bytes, -} - -/// ErrorCauseInvalidMandatoryParameter represents an SCTP error cause -pub(crate) type ErrorCauseInvalidMandatoryParameter = ErrorCause; - -/// ErrorCauseUnrecognizedChunkType represents an SCTP error cause -pub(crate) type ErrorCauseUnrecognizedChunkType = ErrorCause; - -/// -/// This error cause MAY be included in ABORT chunks that are sent -/// because an SCTP endpoint detects a protocol violation of the peer -/// that is not covered by the error causes described in Section 3.3.10.1 -/// to Section 3.3.10.12. An implementation MAY provide additional -/// information specifying what kind of protocol violation has been -/// detected. 
-/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | Cause Code=13 | Cause Length=Variable | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// / Additional Information / -/// \ \ -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// -pub(crate) type ErrorCauseProtocolViolation = ErrorCause; - -pub(crate) const ERROR_CAUSE_HEADER_LENGTH: usize = 4; - -/// makes ErrorCauseHeader printable -impl fmt::Display for ErrorCause { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.code) - } -} - -impl ErrorCause { - pub(crate) fn unmarshal(buf: &Bytes) -> Result { - if buf.len() < ERROR_CAUSE_HEADER_LENGTH { - return Err(Error::ErrErrorCauseTooSmall); - } - - let reader = &mut buf.clone(); - - let code = ErrorCauseCode(reader.get_u16()); - let len = reader.get_u16(); - - if len < ERROR_CAUSE_HEADER_LENGTH as u16 { - return Err(Error::ErrErrorCauseTooSmall); - } - if buf.len() < len as usize { - return Err(Error::ErrErrorCauseTooSmall); - } - - let value_length = len as usize - ERROR_CAUSE_HEADER_LENGTH; - - let raw = buf.slice(ERROR_CAUSE_HEADER_LENGTH..ERROR_CAUSE_HEADER_LENGTH + value_length); - - Ok(ErrorCause { code, raw }) - } - - pub(crate) fn marshal(&self) -> Bytes { - let mut buf = BytesMut::with_capacity(self.length()); - let _ = self.marshal_to(&mut buf); - buf.freeze() - } - - pub(crate) fn marshal_to(&self, writer: &mut BytesMut) -> usize { - let len = self.raw.len() + ERROR_CAUSE_HEADER_LENGTH; - writer.put_u16(self.code.0); - writer.put_u16(len as u16); - writer.extend(self.raw.clone()); - writer.len() - } - - pub(crate) fn length(&self) -> usize { - self.raw.len() + ERROR_CAUSE_HEADER_LENGTH - } - - pub(crate) fn error_cause_code(&self) -> ErrorCauseCode { - self.code - } -} diff --git a/sctp/src/fuzz_artifact_test.rs b/sctp/src/fuzz_artifact_test.rs deleted file mode 100644 index c40c9db10..000000000 --- a/sctp/src/fuzz_artifact_test.rs +++ /dev/null @@ -1,44 +0,0 @@ -//! # What are these tests? -//! -//! These tests ensure that regressions in the unmarshalling code are caught. -//! -//! They check all artifacts of the fuzzer that crashed this lib, and make sure they no longer crash the library. -//! -//! The content of the files is mostly garbage, but it triggers "interesting" behaviour in the unmarshalling code. -//! So if your change fails one of these tests you probably made an error somewhere. -//! -//! Sadly these tests cannot really tell you where your error is specifically outside the standard backtrace rust will provide to you. Sorry. 
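In other words, each test below is a small corpus replay: read every crash- artifact the fuzzer saved and push it through the corresponding parser, expecting an Err at worst and never a panic. A hedged, generic sketch of that pattern (the directory argument and the parser closure are placeholders, not the crate's API):

use std::path::Path;

// Replays every saved fuzzer artifact through a fallible parser; the only
// requirement is that parsing returns -- an Err is acceptable, a panic is a bug.
fn replay_crash_artifacts(dir: &Path, parse: impl Fn(&[u8])) {
    for entry in std::fs::read_dir(dir).expect("artifact directory exists") {
        let entry = entry.expect("readable directory entry");
        let name = entry.file_name().to_string_lossy().into_owned();
        if name.starts_with("crash-") {
            let bytes = std::fs::read(entry.path()).expect("readable artifact file");
            parse(&bytes);
        }
    }
}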
- -use bytes::Bytes; - -#[test] -fn param_crash_artifacts() { - for artifact in std::fs::read_dir("fuzz/artifacts/param").unwrap() { - let artifact = artifact.unwrap(); - if artifact - .file_name() - .into_string() - .unwrap() - .starts_with("crash-") - { - let artifact = std::fs::read(artifact.path()).unwrap(); - crate::param::build_param(&Bytes::from(artifact)).ok(); - } - } -} - -#[test] -fn packet_crash_artifacts() { - for artifact in std::fs::read_dir("fuzz/artifacts/packet").unwrap() { - let artifact = artifact.unwrap(); - if artifact - .file_name() - .into_string() - .unwrap() - .starts_with("crash-") - { - let artifact = std::fs::read(artifact.path()).unwrap(); - crate::packet::Packet::unmarshal(&Bytes::from(artifact)).ok(); - } - } -} diff --git a/sctp/src/lib.rs b/sctp/src/lib.rs deleted file mode 100644 index f299abdb3..000000000 --- a/sctp/src/lib.rs +++ /dev/null @@ -1,18 +0,0 @@ -#![warn(rust_2018_idioms)] -#![allow(dead_code)] - -pub mod association; -pub mod chunk; -mod error; -pub mod error_cause; -pub mod packet; -pub mod param; -pub(crate) mod queue; -pub mod stream; -pub(crate) mod timer; -pub(crate) mod util; - -pub use error::Error; - -#[cfg(test)] -mod fuzz_artifact_test; diff --git a/sctp/src/packet.rs b/sctp/src/packet.rs deleted file mode 100644 index 2c37fa76a..000000000 --- a/sctp/src/packet.rs +++ /dev/null @@ -1,304 +0,0 @@ -use std::fmt; - -use bytes::{Buf, BufMut, Bytes, BytesMut}; - -use crate::chunk::chunk_abort::ChunkAbort; -use crate::chunk::chunk_cookie_ack::ChunkCookieAck; -use crate::chunk::chunk_cookie_echo::ChunkCookieEcho; -use crate::chunk::chunk_error::ChunkError; -use crate::chunk::chunk_forward_tsn::ChunkForwardTsn; -use crate::chunk::chunk_header::*; -use crate::chunk::chunk_heartbeat::ChunkHeartbeat; -use crate::chunk::chunk_init::ChunkInit; -use crate::chunk::chunk_payload_data::ChunkPayloadData; -use crate::chunk::chunk_reconfig::ChunkReconfig; -use crate::chunk::chunk_selective_ack::ChunkSelectiveAck; -use crate::chunk::chunk_shutdown::ChunkShutdown; -use crate::chunk::chunk_shutdown_ack::ChunkShutdownAck; -use crate::chunk::chunk_shutdown_complete::ChunkShutdownComplete; -use crate::chunk::chunk_type::*; -use crate::chunk::chunk_unknown::ChunkUnknown; -use crate::chunk::Chunk; -use crate::error::{Error, Result}; -use crate::util::*; - -///Packet represents an SCTP packet, defined in https://tools.ietf.org/html/rfc4960#section-3 -///An SCTP packet is composed of a common header and chunks. A chunk -///contains either control information or user data. -/// -/// -///SCTP Packet Format -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Common Header | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Chunk #1 | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| ... 
| -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Chunk #n | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// -/// -///SCTP Common Header Format -/// -/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Source Value Number | Destination Value Number | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Verification Tag | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Checksum | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -#[derive(Default, Debug)] -pub(crate) struct Packet { - pub(crate) source_port: u16, - pub(crate) destination_port: u16, - pub(crate) verification_tag: u32, - pub(crate) chunks: Vec>, -} - -/// makes packet printable -impl fmt::Display for Packet { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut res = format!( - "Packet: - source_port: {} - destination_port: {} - verification_tag: {} - ", - self.source_port, self.destination_port, self.verification_tag, - ); - for chunk in &self.chunks { - res += format!("Chunk: {chunk}").as_str(); - } - write!(f, "{res}") - } -} - -pub(crate) const PACKET_HEADER_SIZE: usize = 12; - -impl Packet { - pub(crate) fn unmarshal(raw: &Bytes) -> Result { - if raw.len() < PACKET_HEADER_SIZE { - return Err(Error::ErrPacketRawTooSmall); - } - - let reader = &mut raw.clone(); - - let source_port = reader.get_u16(); - let destination_port = reader.get_u16(); - let verification_tag = reader.get_u32(); - - #[cfg(not(fuzzing))] - // only check for checksums when we are not fuzzing. This lets the fuzzer test the code much easier without guessing correct checksums. - { - let their_checksum = reader.get_u32_le(); - let our_checksum = generate_packet_checksum(raw); - - if their_checksum != our_checksum { - return Err(Error::ErrChecksumMismatch); - } - } - - let mut chunks = vec![]; - let mut offset = PACKET_HEADER_SIZE; - loop { - // Exact match, no more chunks - if offset == raw.len() { - break; - } else if offset + CHUNK_HEADER_SIZE > raw.len() { - return Err(Error::ErrParseSctpChunkNotEnoughData); - } - - let ct = ChunkType(raw[offset]); - let c: Box = match ct { - CT_INIT => Box::new(ChunkInit::unmarshal(&raw.slice(offset..))?), - CT_INIT_ACK => Box::new(ChunkInit::unmarshal(&raw.slice(offset..))?), - CT_ABORT => Box::new(ChunkAbort::unmarshal(&raw.slice(offset..))?), - CT_COOKIE_ECHO => Box::new(ChunkCookieEcho::unmarshal(&raw.slice(offset..))?), - CT_COOKIE_ACK => Box::new(ChunkCookieAck::unmarshal(&raw.slice(offset..))?), - CT_HEARTBEAT => Box::new(ChunkHeartbeat::unmarshal(&raw.slice(offset..))?), - CT_PAYLOAD_DATA => Box::new(ChunkPayloadData::unmarshal(&raw.slice(offset..))?), - CT_SACK => Box::new(ChunkSelectiveAck::unmarshal(&raw.slice(offset..))?), - CT_RECONFIG => Box::new(ChunkReconfig::unmarshal(&raw.slice(offset..))?), - CT_FORWARD_TSN => Box::new(ChunkForwardTsn::unmarshal(&raw.slice(offset..))?), - CT_ERROR => Box::new(ChunkError::unmarshal(&raw.slice(offset..))?), - CT_SHUTDOWN => Box::new(ChunkShutdown::unmarshal(&raw.slice(offset..))?), - CT_SHUTDOWN_ACK => Box::new(ChunkShutdownAck::unmarshal(&raw.slice(offset..))?), - CT_SHUTDOWN_COMPLETE => { - Box::new(ChunkShutdownComplete::unmarshal(&raw.slice(offset..))?) 
- } - _ => Box::new(ChunkUnknown::unmarshal(&raw.slice(offset..))?), - }; - - let chunk_value_padding = get_padding_size(c.value_length()); - offset += CHUNK_HEADER_SIZE + c.value_length() + chunk_value_padding; - chunks.push(c); - } - - Ok(Packet { - source_port, - destination_port, - verification_tag, - chunks, - }) - } - - pub(crate) fn marshal_to(&self, writer: &mut BytesMut) -> Result { - // Populate static headers - // 8-12 is Checksum which will be populated when packet is complete - writer.put_u16(self.source_port); - writer.put_u16(self.destination_port); - writer.put_u32(self.verification_tag); - - // This is where the checksum will be written - let checksum_pos = writer.len(); - writer.extend_from_slice(&[0, 0, 0, 0]); - - // Populate chunks - for c in &self.chunks { - c.marshal_to(writer)?; - - let padding_needed = get_padding_size(writer.len()); - if padding_needed != 0 { - // padding needed if < 4 because we pad to 4 - writer.extend_from_slice(&[0u8; PADDING_MULTIPLE][..padding_needed]); - } - } - - let mut digest = ISCSI_CRC.digest(); - digest.update(writer); - let checksum = digest.finalize(); - - // Checksum is already in BigEndian - // Using LittleEndian stops it from being flipped - let checksum_place = &mut writer[checksum_pos..checksum_pos + 4]; - checksum_place.copy_from_slice(&checksum.to_le_bytes()); - - Ok(writer.len()) - } - - pub(crate) fn marshal(&self) -> Result { - let mut buf = BytesMut::with_capacity(PACKET_HEADER_SIZE); - self.marshal_to(&mut buf)?; - Ok(buf.freeze()) - } -} - -impl Packet { - pub(crate) fn check_packet(&self) -> Result<()> { - // All packets must adhere to these rules - - // This is the SCTP sender's port number. It can be used by the - // receiver in combination with the source IP address, the SCTP - // destination port, and possibly the destination IP address to - // identify the association to which this packet belongs. The port - // number 0 MUST NOT be used. - if self.source_port == 0 { - return Err(Error::ErrSctpPacketSourcePortZero); - } - - // This is the SCTP port number to which this packet is destined. - // The receiving host will use this port number to de-multiplex the - // SCTP packet to the correct receiving endpoint/application. The - // port number 0 MUST NOT be used. - if self.destination_port == 0 { - return Err(Error::ErrSctpPacketDestinationPortZero); - } - - // Check values on the packet that are specific to a particular chunk type - for c in &self.chunks { - if let Some(ci) = c.as_any().downcast_ref::() { - if !ci.is_ack { - // An INIT or INIT ACK chunk MUST NOT be bundled with any other chunk. - // They MUST be the only chunks present in the SCTP packets that carry - // them. - if self.chunks.len() != 1 { - return Err(Error::ErrInitChunkBundled); - } - - // A packet containing an INIT chunk MUST have a zero Verification - // Tag. 
- if self.verification_tag != 0 { - return Err(Error::ErrInitChunkVerifyTagNotZero); - } - } - } - } - - Ok(()) - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_packet_unmarshal() -> Result<()> { - let result = Packet::unmarshal(&Bytes::new()); - assert!( - result.is_err(), - "Unmarshal should fail when a packet is too small to be SCTP" - ); - - let header_only = Bytes::from_static(&[ - 0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x00, 0x00, 0x06, 0xa9, 0x00, 0xe1, - ]); - let pkt = Packet::unmarshal(&header_only)?; - //assert!(result.o(), "Unmarshal failed for SCTP packet with no chunks: {}", result); - assert_eq!( - pkt.source_port, 5000, - "Unmarshal passed for SCTP packet, but got incorrect source port exp: {} act: {}", - 5000, pkt.source_port - ); - assert_eq!( - pkt.destination_port, 5000, - "Unmarshal passed for SCTP packet, but got incorrect destination port exp: {} act: {}", - 5000, pkt.destination_port - ); - assert_eq!( - pkt.verification_tag, 0, - "Unmarshal passed for SCTP packet, but got incorrect verification tag exp: {} act: {}", - 0, pkt.verification_tag - ); - - let raw_chunk = Bytes::from_static(&[ - 0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x00, 0x00, 0x81, 0x46, 0x9d, 0xfc, 0x01, 0x00, - 0x00, 0x56, 0x55, 0xb9, 0x64, 0xa5, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, - 0xe8, 0x6d, 0x10, 0x30, 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, - 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0x9f, 0xeb, 0xbb, 0x5c, - 0x50, 0xc9, 0xbf, 0x75, 0x9c, 0xb1, 0x2c, 0x57, 0x4f, 0xa4, 0x5a, 0x51, 0xba, 0x60, - 0x17, 0x78, 0x27, 0x94, 0x5c, 0x31, 0xe6, 0x5d, 0x5b, 0x09, 0x47, 0xe2, 0x22, 0x06, - 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, - 0x00, 0x00, - ]); - - Packet::unmarshal(&raw_chunk)?; - - Ok(()) - } - - #[test] - fn test_packet_marshal() -> Result<()> { - let header_only = Bytes::from_static(&[ - 0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x00, 0x00, 0x06, 0xa9, 0x00, 0xe1, - ]); - let pkt = Packet::unmarshal(&header_only)?; - let header_only_marshaled = pkt.marshal()?; - assert_eq!(header_only, header_only_marshaled, "Unmarshal/Marshaled header only packet did not match \nheaderOnly: {header_only:?} \nheader_only_marshaled {header_only_marshaled:?}"); - - Ok(()) - } - - /*fn BenchmarkPacketGenerateChecksum(b *testing.B) { - var data [1024]byte - - for i := 0; i < b.N; i++ { - _ = generatePacketChecksum(data[:]) - } - }*/ -} diff --git a/sctp/src/param/mod.rs b/sctp/src/param/mod.rs deleted file mode 100644 index 49ca3c962..000000000 --- a/sctp/src/param/mod.rs +++ /dev/null @@ -1,89 +0,0 @@ -#[cfg(test)] -mod param_test; - -pub(crate) mod param_chunk_list; -pub(crate) mod param_forward_tsn_supported; -pub(crate) mod param_header; -pub(crate) mod param_heartbeat_info; -pub(crate) mod param_outgoing_reset_request; -pub(crate) mod param_random; -pub(crate) mod param_reconfig_response; -pub(crate) mod param_requested_hmac_algorithm; -pub(crate) mod param_state_cookie; -pub(crate) mod param_supported_extensions; -pub(crate) mod param_type; -pub(crate) mod param_unknown; -pub(crate) mod param_unrecognized; - -use std::any::Any; -use std::fmt; - -use bytes::{Buf, Bytes, BytesMut}; -use param_header::*; -use param_type::*; - -use crate::error::{Error, Result}; -use crate::param::param_chunk_list::ParamChunkList; -use crate::param::param_forward_tsn_supported::ParamForwardTsnSupported; -use crate::param::param_heartbeat_info::ParamHeartbeatInfo; -use 
crate::param::param_outgoing_reset_request::ParamOutgoingResetRequest; -use crate::param::param_random::ParamRandom; -use crate::param::param_reconfig_response::ParamReconfigResponse; -use crate::param::param_requested_hmac_algorithm::ParamRequestedHmacAlgorithm; -use crate::param::param_state_cookie::ParamStateCookie; -use crate::param::param_supported_extensions::ParamSupportedExtensions; -use crate::param::param_unknown::ParamUnknown; - -pub(crate) trait Param: fmt::Display + fmt::Debug { - fn header(&self) -> ParamHeader; - fn unmarshal(raw: &Bytes) -> Result - where - Self: Sized; - fn marshal_to(&self, buf: &mut BytesMut) -> Result; - fn value_length(&self) -> usize; - fn clone_to(&self) -> Box; - fn as_any(&self) -> &(dyn Any + Send + Sync); - - fn marshal(&self) -> Result { - let capacity = PARAM_HEADER_LENGTH + self.value_length(); - let mut buf = BytesMut::with_capacity(capacity); - self.marshal_to(&mut buf)?; - Ok(buf.freeze()) - } -} - -impl Clone for Box { - fn clone(&self) -> Box { - self.clone_to() - } -} - -pub(crate) fn build_param(raw_param: &Bytes) -> Result> { - if raw_param.len() < PARAM_HEADER_LENGTH { - return Err(Error::ErrParamHeaderTooShort); - } - let reader = &mut raw_param.slice(..2); - let raw_type = reader.get_u16(); - match raw_type.into() { - ParamType::ForwardTsnSupp => Ok(Box::new(ParamForwardTsnSupported::unmarshal(raw_param)?)), - ParamType::SupportedExt => Ok(Box::new(ParamSupportedExtensions::unmarshal(raw_param)?)), - ParamType::Random => Ok(Box::new(ParamRandom::unmarshal(raw_param)?)), - ParamType::ReqHmacAlgo => Ok(Box::new(ParamRequestedHmacAlgorithm::unmarshal(raw_param)?)), - ParamType::ChunkList => Ok(Box::new(ParamChunkList::unmarshal(raw_param)?)), - ParamType::StateCookie => Ok(Box::new(ParamStateCookie::unmarshal(raw_param)?)), - ParamType::HeartbeatInfo => Ok(Box::new(ParamHeartbeatInfo::unmarshal(raw_param)?)), - ParamType::OutSsnResetReq => Ok(Box::new(ParamOutgoingResetRequest::unmarshal(raw_param)?)), - ParamType::ReconfigResp => Ok(Box::new(ParamReconfigResponse::unmarshal(raw_param)?)), - _ => { - // According to RFC https://datatracker.ietf.org/doc/html/rfc4960#section-3.2.1 - let stop_processing = ((raw_type >> 15) & 0x01) == 0; - if stop_processing { - Err(Error::ErrParamTypeUnhandled { typ: raw_type }) - } else { - // We still might need to report this param as unrecognized. - // This depends on the context though. 
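                // Editor's aside (hedged sketch, not the crate's own logic): the two
                // leading bits of the parameter type encode what to do with a parameter
                // the receiver does not recognize (RFC 4960 section 3.2.1). Bit 15 says
                // whether processing continues past it, bit 14 whether it must be
                // reported back to the sender.
                fn unrecognized_param_action(raw_type: u16) -> (&'static str, bool) {
                    let keep_going = ((raw_type >> 15) & 0x01) == 1; // 1x -> skip and continue
                    let report = ((raw_type >> 14) & 0x01) == 1; // x1 -> report to the sender
                    (if keep_going { "skip" } else { "stop" }, report)
                }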
- Ok(Box::new(ParamUnknown::unmarshal(raw_param)?)) - } - } - } -} diff --git a/sctp/src/param/param_chunk_list.rs b/sctp/src/param/param_chunk_list.rs deleted file mode 100644 index d3b704e73..000000000 --- a/sctp/src/param/param_chunk_list.rs +++ /dev/null @@ -1,73 +0,0 @@ -use bytes::{Buf, BufMut, Bytes, BytesMut}; - -use super::param_header::*; -use super::param_type::*; -use super::*; -use crate::chunk::chunk_type::*; - -#[derive(Default, Debug, Clone, PartialEq)] -pub(crate) struct ParamChunkList { - pub(crate) chunk_types: Vec, -} - -impl fmt::Display for ParamChunkList { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "{} {}", - self.header(), - self.chunk_types - .iter() - .map(|ct| ct.to_string()) - .collect::>() - .join(" ") - ) - } -} - -impl Param for ParamChunkList { - fn header(&self) -> ParamHeader { - ParamHeader { - typ: ParamType::ChunkList, - value_length: self.value_length() as u16, - } - } - - fn unmarshal(raw: &Bytes) -> Result { - let header = ParamHeader::unmarshal(raw)?; - - if header.typ != ParamType::ChunkList { - return Err(Error::ErrParamTypeUnexpected); - } - - let reader = - &mut raw.slice(PARAM_HEADER_LENGTH..PARAM_HEADER_LENGTH + header.value_length()); - - let mut chunk_types = vec![]; - while reader.has_remaining() { - chunk_types.push(ChunkType(reader.get_u8())); - } - - Ok(ParamChunkList { chunk_types }) - } - - fn marshal_to(&self, buf: &mut BytesMut) -> Result { - self.header().marshal_to(buf)?; - for ct in &self.chunk_types { - buf.put_u8(ct.0); - } - Ok(buf.len()) - } - - fn value_length(&self) -> usize { - self.chunk_types.len() - } - - fn clone_to(&self) -> Box { - Box::new(self.clone()) - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } -} diff --git a/sctp/src/param/param_forward_tsn_supported.rs b/sctp/src/param/param_forward_tsn_supported.rs deleted file mode 100644 index c742b1408..000000000 --- a/sctp/src/param/param_forward_tsn_supported.rs +++ /dev/null @@ -1,53 +0,0 @@ -use bytes::{Bytes, BytesMut}; - -use super::param_header::*; -use super::param_type::*; -use super::*; - -/// At the initialization of the association, the sender of the INIT or -/// INIT ACK chunk MAY include this OPTIONAL parameter to inform its peer -/// that it is able to support the Forward TSN chunk -/// -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Parameter Type = 49152 | Parameter Length = 4 | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -#[derive(Default, Debug, Clone, PartialEq)] -pub(crate) struct ParamForwardTsnSupported; - -impl fmt::Display for ParamForwardTsnSupported { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.header()) - } -} - -impl Param for ParamForwardTsnSupported { - fn header(&self) -> ParamHeader { - ParamHeader { - typ: ParamType::ForwardTsnSupp, - value_length: self.value_length() as u16, - } - } - - fn unmarshal(raw: &Bytes) -> Result { - let _ = ParamHeader::unmarshal(raw)?; - Ok(ParamForwardTsnSupported {}) - } - - fn marshal_to(&self, buf: &mut BytesMut) -> Result { - self.header().marshal_to(buf)?; - Ok(buf.len()) - } - - fn value_length(&self) -> usize { - 0 - } - - fn clone_to(&self) -> Box { - Box::new(self.clone()) - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } -} diff --git a/sctp/src/param/param_header.rs b/sctp/src/param/param_header.rs deleted file mode 100644 index a004f3ab3..000000000 --- 
a/sctp/src/param/param_header.rs +++ /dev/null @@ -1,65 +0,0 @@ -use std::fmt; - -use bytes::{Buf, BufMut, Bytes, BytesMut}; - -use super::param_type::*; -use super::*; - -#[derive(Debug, Clone, PartialEq)] -pub(crate) struct ParamHeader { - pub(crate) typ: ParamType, - pub(crate) value_length: u16, -} - -pub(crate) const PARAM_HEADER_LENGTH: usize = 4; - -/// String makes paramHeader printable -impl fmt::Display for ParamHeader { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.typ) - } -} - -impl Param for ParamHeader { - fn header(&self) -> ParamHeader { - self.clone() - } - - fn unmarshal(raw: &Bytes) -> Result { - if raw.len() < PARAM_HEADER_LENGTH { - return Err(Error::ErrParamHeaderTooShort); - } - - let reader = &mut raw.clone(); - - let typ: ParamType = reader.get_u16().into(); - - let len = reader.get_u16() as usize; - if len < PARAM_HEADER_LENGTH || raw.len() < len { - return Err(Error::ErrParamHeaderTooShort); - } - - Ok(ParamHeader { - typ, - value_length: (len - PARAM_HEADER_LENGTH) as u16, - }) - } - - fn marshal_to(&self, writer: &mut BytesMut) -> Result { - writer.put_u16(self.typ.into()); - writer.put_u16(self.value_length + PARAM_HEADER_LENGTH as u16); - Ok(writer.len()) - } - - fn value_length(&self) -> usize { - self.value_length as usize - } - - fn clone_to(&self) -> Box { - Box::new(self.clone()) - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } -} diff --git a/sctp/src/param/param_heartbeat_info.rs b/sctp/src/param/param_heartbeat_info.rs deleted file mode 100644 index 12b310381..000000000 --- a/sctp/src/param/param_heartbeat_info.rs +++ /dev/null @@ -1,52 +0,0 @@ -use bytes::{Bytes, BytesMut}; - -use super::param_header::*; -use super::param_type::*; -use super::*; - -#[derive(Default, Debug, Clone, PartialEq)] -pub(crate) struct ParamHeartbeatInfo { - pub(crate) heartbeat_information: Bytes, -} - -impl fmt::Display for ParamHeartbeatInfo { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{} {:?}", self.header(), self.heartbeat_information) - } -} - -impl Param for ParamHeartbeatInfo { - fn header(&self) -> ParamHeader { - ParamHeader { - typ: ParamType::HeartbeatInfo, - value_length: self.value_length() as u16, - } - } - - fn unmarshal(raw: &Bytes) -> Result { - let header = ParamHeader::unmarshal(raw)?; - let heartbeat_information = - raw.slice(PARAM_HEADER_LENGTH..PARAM_HEADER_LENGTH + header.value_length()); - Ok(ParamHeartbeatInfo { - heartbeat_information, - }) - } - - fn marshal_to(&self, buf: &mut BytesMut) -> Result { - self.header().marshal_to(buf)?; - buf.extend(self.heartbeat_information.clone()); - Ok(buf.len()) - } - - fn value_length(&self) -> usize { - self.heartbeat_information.len() - } - - fn clone_to(&self) -> Box { - Box::new(self.clone()) - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } -} diff --git a/sctp/src/param/param_outgoing_reset_request.rs b/sctp/src/param/param_outgoing_reset_request.rs deleted file mode 100644 index abd2a78b7..000000000 --- a/sctp/src/param/param_outgoing_reset_request.rs +++ /dev/null @@ -1,124 +0,0 @@ -use bytes::{Buf, BufMut, Bytes, BytesMut}; - -use super::param_header::*; -use super::param_type::*; -use super::*; - -pub(crate) const PARAM_OUTGOING_RESET_REQUEST_STREAM_IDENTIFIERS_OFFSET: usize = 12; - -///This parameter is used by the sender to request the reset of some or -///all outgoing streams. 
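Before the wire diagram that follows, a hedged standalone sketch of the same layout: a parameter type of 13, a length of 16 + 2*N that includes the 4-byte parameter header, three 32-bit fields, and N optional 16-bit stream numbers. The bytes it produces match the CHUNK_RECONFIG_PARAM_A vector used by the tests in param_test.rs further down; the encoder itself is an editorial illustration, not the crate's API.

fn encode_outgoing_reset_request(
    req_seq: u32,
    resp_seq: u32,
    last_tsn: u32,
    streams: &[u16],
) -> Vec<u8> {
    let mut buf = Vec::new();
    buf.extend_from_slice(&13u16.to_be_bytes()); // Parameter Type = 13
    buf.extend_from_slice(&((16 + 2 * streams.len()) as u16).to_be_bytes()); // Length = 16 + 2 * N
    buf.extend_from_slice(&req_seq.to_be_bytes()); // Re-configuration Request Sequence Number
    buf.extend_from_slice(&resp_seq.to_be_bytes()); // Re-configuration Response Sequence Number
    buf.extend_from_slice(&last_tsn.to_be_bytes()); // Sender's Last Assigned TSN
    for s in streams {
        buf.extend_from_slice(&s.to_be_bytes()); // optional stream numbers
    }
    buf
}

fn main() {
    // Matches the CHUNK_RECONFIG_PARAM_A test vector in param_test.rs.
    assert_eq!(
        encode_outgoing_reset_request(1, 2, 3, &[4, 5, 6]),
        vec![
            0x0, 0xd, 0x0, 0x16, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x3,
            0x0, 0x4, 0x0, 0x5, 0x0, 0x6,
        ]
    );
}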
-/// 0 1 2 3 -/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Parameter Type = 13 | Parameter Length = 16 + 2 * N | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Re-configuration Request Sequence Number | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Re-configuration Response Sequence Number | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Sender's Last Assigned TSN | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Stream Number 1 (optional) | Stream Number 2 (optional) | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| ...... | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Stream Number N-1 (optional) | Stream Number N (optional) | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -#[derive(Default, Debug, Clone, PartialEq)] -pub(crate) struct ParamOutgoingResetRequest { - /// reconfig_request_sequence_number is used to identify the request. It is a monotonically - /// increasing number that is initialized to the same value as the - /// initial TSN. It is increased by 1 whenever sending a new Re- - /// configuration Request Parameter. - pub(crate) reconfig_request_sequence_number: u32, - /// When this Outgoing SSN Reset Request Parameter is sent in response - /// to an Incoming SSN Reset Request Parameter, this parameter is also - /// an implicit response to the incoming request. This field then - /// holds the Re-configuration Request Sequence Number of the incoming - /// request. In other cases, it holds the next expected - /// Re-configuration Request Sequence Number minus 1. - pub(crate) reconfig_response_sequence_number: u32, - /// This value holds the next TSN minus 1 -- in other words, the last - /// TSN that this sender assigned. - pub(crate) sender_last_tsn: u32, - /// This optional field, if included, is used to indicate specific - /// streams that are to be reset. If no streams are listed, then all - /// streams are to be reset. 
- pub(crate) stream_identifiers: Vec, -} - -impl fmt::Display for ParamOutgoingResetRequest { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "{} {} {} {} {:?}", - self.header(), - self.reconfig_request_sequence_number, - self.reconfig_request_sequence_number, - self.reconfig_response_sequence_number, - self.stream_identifiers - ) - } -} - -impl Param for ParamOutgoingResetRequest { - fn header(&self) -> ParamHeader { - ParamHeader { - typ: ParamType::OutSsnResetReq, - value_length: self.value_length() as u16, - } - } - - fn unmarshal(raw: &Bytes) -> Result { - let header = ParamHeader::unmarshal(raw)?; - - // validity of value_length is checked in ParamHeader::unmarshal - if header.value_length() < PARAM_OUTGOING_RESET_REQUEST_STREAM_IDENTIFIERS_OFFSET { - return Err(Error::ErrSsnResetRequestParamTooShort); - } - - let reader = - &mut raw.slice(PARAM_HEADER_LENGTH..PARAM_HEADER_LENGTH + header.value_length()); - let reconfig_request_sequence_number = reader.get_u32(); - let reconfig_response_sequence_number = reader.get_u32(); - let sender_last_tsn = reader.get_u32(); - - let lim = - (header.value_length() - PARAM_OUTGOING_RESET_REQUEST_STREAM_IDENTIFIERS_OFFSET) / 2; - let mut stream_identifiers = vec![]; - for _ in 0..lim { - stream_identifiers.push(reader.get_u16()); - } - - Ok(ParamOutgoingResetRequest { - reconfig_request_sequence_number, - reconfig_response_sequence_number, - sender_last_tsn, - stream_identifiers, - }) - } - - fn marshal_to(&self, buf: &mut BytesMut) -> Result { - self.header().marshal_to(buf)?; - buf.put_u32(self.reconfig_request_sequence_number); - buf.put_u32(self.reconfig_response_sequence_number); - buf.put_u32(self.sender_last_tsn); - for sid in &self.stream_identifiers { - buf.put_u16(*sid); - } - Ok(buf.len()) - } - - fn value_length(&self) -> usize { - PARAM_OUTGOING_RESET_REQUEST_STREAM_IDENTIFIERS_OFFSET + self.stream_identifiers.len() * 2 - } - - fn clone_to(&self) -> Box { - Box::new(self.clone()) - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } -} diff --git a/sctp/src/param/param_random.rs b/sctp/src/param/param_random.rs deleted file mode 100644 index d645b76e5..000000000 --- a/sctp/src/param/param_random.rs +++ /dev/null @@ -1,50 +0,0 @@ -use bytes::{Bytes, BytesMut}; - -use super::param_header::*; -use super::param_type::*; -use super::*; - -#[derive(Default, Debug, Clone, PartialEq)] -pub(crate) struct ParamRandom { - pub(crate) random_data: Bytes, -} - -impl fmt::Display for ParamRandom { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{} {:?}", self.header(), self.random_data) - } -} - -impl Param for ParamRandom { - fn header(&self) -> ParamHeader { - ParamHeader { - typ: ParamType::Random, - value_length: self.value_length() as u16, - } - } - - fn unmarshal(raw: &Bytes) -> Result { - let header = ParamHeader::unmarshal(raw)?; - let random_data = - raw.slice(PARAM_HEADER_LENGTH..PARAM_HEADER_LENGTH + header.value_length()); - Ok(ParamRandom { random_data }) - } - - fn marshal_to(&self, buf: &mut BytesMut) -> Result { - self.header().marshal_to(buf)?; - buf.extend(self.random_data.clone()); - Ok(buf.len()) - } - - fn value_length(&self) -> usize { - self.random_data.len() - } - - fn clone_to(&self) -> Box { - Box::new(self.clone()) - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } -} diff --git a/sctp/src/param/param_reconfig_response.rs b/sctp/src/param/param_reconfig_response.rs deleted file mode 100644 index 7ef7a029c..000000000 --- 
a/sctp/src/param/param_reconfig_response.rs +++ /dev/null @@ -1,140 +0,0 @@ -use std::fmt; - -use bytes::{Buf, BufMut, Bytes, BytesMut}; - -use super::param_header::*; -use super::param_type::*; -use super::*; - -#[derive(Default, Debug, Copy, Clone, PartialEq)] -#[repr(C)] -pub(crate) enum ReconfigResult { - SuccessNop = 0, - SuccessPerformed = 1, - Denied = 2, - ErrorWrongSsn = 3, - ErrorRequestAlreadyInProgress = 4, - ErrorBadSequenceNumber = 5, - InProgress = 6, - #[default] - Unknown, -} - -impl fmt::Display for ReconfigResult { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - ReconfigResult::SuccessNop => "0: Success - Nothing to do", - ReconfigResult::SuccessPerformed => "1: Success - Performed", - ReconfigResult::Denied => "2: Denied", - ReconfigResult::ErrorWrongSsn => "3: Error - Wrong SSN", - ReconfigResult::ErrorRequestAlreadyInProgress => { - "4: Error - Request already in progress" - } - ReconfigResult::ErrorBadSequenceNumber => "5: Error - Bad Sequence Number", - ReconfigResult::InProgress => "6: In progress", - _ => "Unknown ReconfigResult", - }; - write!(f, "{s}") - } -} - -impl From for ReconfigResult { - fn from(v: u32) -> ReconfigResult { - match v { - 0 => ReconfigResult::SuccessNop, - 1 => ReconfigResult::SuccessPerformed, - 2 => ReconfigResult::Denied, - 3 => ReconfigResult::ErrorWrongSsn, - 4 => ReconfigResult::ErrorRequestAlreadyInProgress, - 5 => ReconfigResult::ErrorBadSequenceNumber, - 6 => ReconfigResult::InProgress, - _ => ReconfigResult::Unknown, - } - } -} - -///This parameter is used by the receiver of a Re-configuration Request -///Parameter to respond to the request. -/// -///0 1 2 3 -///0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Parameter Type = 16 | Parameter Length | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Re-configuration Response Sequence Number | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Result | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Sender's Next TSN (optional) | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -///| Receiver's Next TSN (optional) | -///+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -#[derive(Default, Debug, Clone, PartialEq)] -pub(crate) struct ParamReconfigResponse { - /// This value is copied from the request parameter and is used by the - /// receiver of the Re-configuration Response Parameter to tie the - /// response to the request. - pub(crate) reconfig_response_sequence_number: u32, - /// This value describes the result of the processing of the request. 
- pub(crate) result: ReconfigResult, -} - -impl fmt::Display for ParamReconfigResponse { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "{} {} {}", - self.header(), - self.reconfig_response_sequence_number, - self.result - ) - } -} - -impl Param for ParamReconfigResponse { - fn header(&self) -> ParamHeader { - ParamHeader { - typ: ParamType::ReconfigResp, - value_length: self.value_length() as u16, - } - } - - fn unmarshal(raw: &Bytes) -> Result { - let header = ParamHeader::unmarshal(raw)?; - - // validity of value_length is checked in ParamHeader::unmarshal - if header.value_length < 8 { - return Err(Error::ErrReconfigRespParamTooShort); - } - - let reader = - &mut raw.slice(PARAM_HEADER_LENGTH..PARAM_HEADER_LENGTH + header.value_length()); - - let reconfig_response_sequence_number = reader.get_u32(); - let result = reader.get_u32().into(); - - Ok(ParamReconfigResponse { - reconfig_response_sequence_number, - result, - }) - } - - fn marshal_to(&self, buf: &mut BytesMut) -> Result { - self.header().marshal_to(buf)?; - buf.put_u32(self.reconfig_response_sequence_number); - buf.put_u32(self.result as u32); - Ok(buf.len()) - } - - fn value_length(&self) -> usize { - 8 - } - - fn clone_to(&self) -> Box { - Box::new(self.clone()) - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } -} diff --git a/sctp/src/param/param_requested_hmac_algorithm.rs b/sctp/src/param/param_requested_hmac_algorithm.rs deleted file mode 100644 index e9ef97b32..000000000 --- a/sctp/src/param/param_requested_hmac_algorithm.rs +++ /dev/null @@ -1,115 +0,0 @@ -use std::fmt; - -use bytes::{Buf, BufMut, Bytes, BytesMut}; - -use super::param_header::*; -use super::param_type::*; -use super::*; - -#[derive(Debug, Copy, Clone, PartialEq)] -#[repr(C)] -pub(crate) enum HmacAlgorithm { - HmacResv1 = 0, - HmacSha128 = 1, - HmacResv2 = 2, - HmacSha256 = 3, - Unknown, -} - -impl fmt::Display for HmacAlgorithm { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - HmacAlgorithm::HmacResv1 => "HMAC Reserved (0x00)", - HmacAlgorithm::HmacSha128 => "HMAC SHA-128", - HmacAlgorithm::HmacResv2 => "HMAC Reserved (0x02)", - HmacAlgorithm::HmacSha256 => "HMAC SHA-256", - _ => "Unknown HMAC Algorithm", - }; - write!(f, "{s}") - } -} - -impl From for HmacAlgorithm { - fn from(v: u16) -> HmacAlgorithm { - match v { - 0 => HmacAlgorithm::HmacResv1, - 1 => HmacAlgorithm::HmacSha128, - 2 => HmacAlgorithm::HmacResv2, - 3 => HmacAlgorithm::HmacSha256, - _ => HmacAlgorithm::Unknown, - } - } -} - -#[derive(Default, Debug, Clone, PartialEq)] -pub(crate) struct ParamRequestedHmacAlgorithm { - pub(crate) available_algorithms: Vec, -} - -impl fmt::Display for ParamRequestedHmacAlgorithm { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "{} {}", - self.header(), - self.available_algorithms - .iter() - .map(|ct| ct.to_string()) - .collect::>() - .join(" "), - ) - } -} - -impl Param for ParamRequestedHmacAlgorithm { - fn header(&self) -> ParamHeader { - ParamHeader { - typ: ParamType::ReqHmacAlgo, - value_length: self.value_length() as u16, - } - } - - fn unmarshal(raw: &Bytes) -> Result { - let header = ParamHeader::unmarshal(raw)?; - - let reader = - &mut raw.slice(PARAM_HEADER_LENGTH..PARAM_HEADER_LENGTH + header.value_length()); - - let mut available_algorithms = vec![]; - let mut offset = 0; - while offset + 1 < header.value_length() { - let a: HmacAlgorithm = reader.get_u16().into(); - if a == HmacAlgorithm::HmacSha128 || a == 
HmacAlgorithm::HmacSha256 { - available_algorithms.push(a); - } else { - return Err(Error::ErrInvalidAlgorithmType); - } - - offset += 2; - } - - Ok(ParamRequestedHmacAlgorithm { - available_algorithms, - }) - } - - fn marshal_to(&self, buf: &mut BytesMut) -> Result { - self.header().marshal_to(buf)?; - for a in &self.available_algorithms { - buf.put_u16(*a as u16); - } - Ok(buf.len()) - } - - fn value_length(&self) -> usize { - 2 * self.available_algorithms.len() - } - - fn clone_to(&self) -> Box { - Box::new(self.clone()) - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } -} diff --git a/sctp/src/param/param_state_cookie.rs b/sctp/src/param/param_state_cookie.rs deleted file mode 100644 index 7c9e785d3..000000000 --- a/sctp/src/param/param_state_cookie.rs +++ /dev/null @@ -1,65 +0,0 @@ -use std::fmt; - -use bytes::{Bytes, BytesMut}; -use rand::Rng; - -use super::param_header::*; -use super::param_type::*; -use super::*; - -#[derive(Default, Debug, Clone, PartialEq)] -pub(crate) struct ParamStateCookie { - pub(crate) cookie: Bytes, -} - -/// String makes paramStateCookie printable -impl fmt::Display for ParamStateCookie { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}: {:?}", self.header(), self.cookie) - } -} - -impl Param for ParamStateCookie { - fn header(&self) -> ParamHeader { - ParamHeader { - typ: ParamType::StateCookie, - value_length: self.value_length() as u16, - } - } - - fn unmarshal(raw: &Bytes) -> Result { - let header = ParamHeader::unmarshal(raw)?; - let cookie = raw.slice(PARAM_HEADER_LENGTH..PARAM_HEADER_LENGTH + header.value_length()); - Ok(ParamStateCookie { cookie }) - } - - fn marshal_to(&self, buf: &mut BytesMut) -> Result { - self.header().marshal_to(buf)?; - buf.extend(self.cookie.clone()); - Ok(buf.len()) - } - - fn value_length(&self) -> usize { - self.cookie.len() - } - - fn clone_to(&self) -> Box { - Box::new(self.clone()) - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } -} - -impl ParamStateCookie { - pub(crate) fn new() -> Self { - let mut cookie = BytesMut::new(); - cookie.resize(32, 0); - rand::thread_rng().fill(cookie.as_mut()); - - ParamStateCookie { - cookie: cookie.freeze(), - } - } -} diff --git a/sctp/src/param/param_supported_extensions.rs b/sctp/src/param/param_supported_extensions.rs deleted file mode 100644 index b1689959d..000000000 --- a/sctp/src/param/param_supported_extensions.rs +++ /dev/null @@ -1,69 +0,0 @@ -use bytes::{Buf, BufMut, Bytes, BytesMut}; - -use super::param_header::*; -use super::param_type::*; -use super::*; -use crate::chunk::chunk_type::*; - -#[derive(Default, Debug, Clone, PartialEq)] -pub(crate) struct ParamSupportedExtensions { - pub(crate) chunk_types: Vec, -} - -impl fmt::Display for ParamSupportedExtensions { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "{} {}", - self.header(), - self.chunk_types - .iter() - .map(|ct| ct.to_string()) - .collect::>() - .join(" "), - ) - } -} - -impl Param for ParamSupportedExtensions { - fn header(&self) -> ParamHeader { - ParamHeader { - typ: ParamType::SupportedExt, - value_length: self.value_length() as u16, - } - } - - fn unmarshal(raw: &Bytes) -> Result { - let header = ParamHeader::unmarshal(raw)?; - - let reader = - &mut raw.slice(PARAM_HEADER_LENGTH..PARAM_HEADER_LENGTH + header.value_length()); - - let mut chunk_types = vec![]; - while reader.has_remaining() { - chunk_types.push(ChunkType(reader.get_u8())); - } - - Ok(ParamSupportedExtensions { chunk_types }) - } - - fn 
marshal_to(&self, buf: &mut BytesMut) -> Result { - self.header().marshal_to(buf)?; - for ct in &self.chunk_types { - buf.put_u8(ct.0); - } - Ok(buf.len()) - } - - fn value_length(&self) -> usize { - self.chunk_types.len() - } - - fn clone_to(&self) -> Box { - Box::new(self.clone()) - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } -} diff --git a/sctp/src/param/param_test.rs b/sctp/src/param/param_test.rs deleted file mode 100644 index d8c567ecd..000000000 --- a/sctp/src/param/param_test.rs +++ /dev/null @@ -1,270 +0,0 @@ -/////////////////////////////////////////////////////////////////// -//param_type_test -/////////////////////////////////////////////////////////////////// -use super::param_type::*; -use super::*; - -#[test] -fn test_parse_param_type_success() -> Result<()> { - let tests = vec![ - (Bytes::from_static(&[0x0, 0x1]), ParamType::HeartbeatInfo), - (Bytes::from_static(&[0x0, 0xd]), ParamType::OutSsnResetReq), - ]; - - for (mut binary, expected) in tests { - let pt: ParamType = binary.get_u16().into(); - assert_eq!(pt, expected); - } - - Ok(()) -} - -/////////////////////////////////////////////////////////////////// -//param_header_test -/////////////////////////////////////////////////////////////////// -use super::param_header::*; - -static PARAM_HEADER_BYTES: Bytes = Bytes::from_static(&[0x0, 0x1, 0x0, 0x4]); - -#[test] -fn test_param_header_success() -> Result<()> { - let tests = vec![( - PARAM_HEADER_BYTES.clone(), - ParamHeader { - typ: ParamType::HeartbeatInfo, - value_length: 0, - }, - )]; - - for (binary, parsed) in tests { - let actual = ParamHeader::unmarshal(&binary)?; - assert_eq!(actual, parsed); - let b = actual.marshal()?; - assert_eq!(b, binary); - } - - Ok(()) -} - -#[test] -fn test_param_header_unmarshal_failure() -> Result<()> { - let tests = vec![ - ("header too short", PARAM_HEADER_BYTES.slice(..2)), - // {"wrong param type", []byte{0x0, 0x0, 0x0, 0x4}}, // Not possible to fail parseParamType atm. 
- ( - "reported length below header length", - Bytes::from_static(&[0x0, 0xd, 0x0, 0x3]), - ), - ("wrong reported length", CHUNK_RECONFIG_PARAM_A.slice(0..4)), - ]; - - for (name, binary) in tests { - let result = ParamHeader::unmarshal(&binary); - assert!(result.is_err(), "expected unmarshal: {name} to fail."); - } - - Ok(()) -} - -/////////////////////////////////////////////////////////////////// -//param_forward_tsn_supported_test -/////////////////////////////////////////////////////////////////// -use super::param_forward_tsn_supported::*; - -static PARAM_FORWARD_TSN_SUPPORTED_BYTES: Bytes = Bytes::from_static(&[0xc0, 0x0, 0x0, 0x4]); - -#[test] -fn test_param_forward_tsn_supported_success() -> Result<()> { - let tests = vec![( - PARAM_FORWARD_TSN_SUPPORTED_BYTES.clone(), - ParamForwardTsnSupported {}, - )]; - - for (binary, parsed) in tests { - let actual = ParamForwardTsnSupported::unmarshal(&binary)?; - assert_eq!(actual, parsed); - let b = actual.marshal()?; - assert_eq!(b, binary); - } - - Ok(()) -} - -#[test] -fn test_param_forward_tsn_supported_failure() -> Result<()> { - let tests = vec![("param too short", Bytes::from_static(&[0x0, 0xd, 0x0]))]; - - for (name, binary) in tests { - let result = ParamForwardTsnSupported::unmarshal(&binary); - assert!(result.is_err(), "expected unmarshal: {name} to fail."); - } - - Ok(()) -} - -/////////////////////////////////////////////////////////////////// -//param_outgoing_reset_request_test -/////////////////////////////////////////////////////////////////// -use super::param_outgoing_reset_request::*; - -static CHUNK_RECONFIG_PARAM_A: Bytes = Bytes::from_static(&[ - 0x0, 0xd, 0x0, 0x16, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x3, 0x0, 0x4, 0x0, - 0x5, 0x0, 0x6, -]); -static CHUNK_RECONFIG_PARAM_B: Bytes = Bytes::from_static(&[ - 0x0, 0xd, 0x0, 0x10, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x3, -]); - -#[test] -fn test_param_outgoing_reset_request_success() -> Result<()> { - let tests = vec![ - ( - CHUNK_RECONFIG_PARAM_A.clone(), - ParamOutgoingResetRequest { - reconfig_request_sequence_number: 1, - reconfig_response_sequence_number: 2, - sender_last_tsn: 3, - stream_identifiers: vec![4, 5, 6], - }, - ), - ( - CHUNK_RECONFIG_PARAM_B.clone(), - ParamOutgoingResetRequest { - reconfig_request_sequence_number: 1, - reconfig_response_sequence_number: 2, - sender_last_tsn: 3, - stream_identifiers: vec![], - }, - ), - ]; - - for (binary, parsed) in tests { - let actual = ParamOutgoingResetRequest::unmarshal(&binary)?; - assert_eq!(actual, parsed); - let b = actual.marshal()?; - assert_eq!(b, binary); - } - - Ok(()) -} - -#[test] -fn test_param_outgoing_reset_request_failure() -> Result<()> { - let tests = vec![ - ("packet too short", CHUNK_RECONFIG_PARAM_A.slice(..8)), - ("param too short", Bytes::from_static(&[0x0, 0xd, 0x0, 0x4])), - ]; - - for (name, binary) in tests { - let result = ParamOutgoingResetRequest::unmarshal(&binary); - assert!(result.is_err(), "expected unmarshal: {name} to fail."); - } - - Ok(()) -} - -/////////////////////////////////////////////////////////////////// -//param_reconfig_response_test -/////////////////////////////////////////////////////////////////// -use bytes::Buf; - -use super::param_reconfig_response::*; - -static CHUNK_RECONFIG_RESPONSE: Bytes = - Bytes::from_static(&[0x0, 0x10, 0x0, 0xc, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x1]); - -#[test] -fn test_param_reconfig_response_success() -> Result<()> { - let tests = vec![( - CHUNK_RECONFIG_RESPONSE.clone(), - 
ParamReconfigResponse { - reconfig_response_sequence_number: 1, - result: ReconfigResult::SuccessPerformed, - }, - )]; - - for (binary, parsed) in tests { - let actual = ParamReconfigResponse::unmarshal(&binary)?; - assert_eq!(actual, parsed); - let b = actual.marshal()?; - assert_eq!(b, binary); - } - - Ok(()) -} - -#[test] -fn test_param_reconfig_response_failure() -> Result<()> { - let tests = vec![ - ("packet too short", CHUNK_RECONFIG_RESPONSE.slice(..8)), - ( - "param too short", - Bytes::from_static(&[0x0, 0x10, 0x0, 0x4]), - ), - ]; - - for (name, binary) in tests { - let result = ParamReconfigResponse::unmarshal(&binary); - assert!(result.is_err(), "expected unmarshal: {name} to fail."); - } - - Ok(()) -} - -#[test] -fn test_reconfig_result_stringer() -> Result<()> { - let tests = vec![ - (ReconfigResult::SuccessNop, "0: Success - Nothing to do"), - (ReconfigResult::SuccessPerformed, "1: Success - Performed"), - (ReconfigResult::Denied, "2: Denied"), - (ReconfigResult::ErrorWrongSsn, "3: Error - Wrong SSN"), - ( - ReconfigResult::ErrorRequestAlreadyInProgress, - "4: Error - Request already in progress", - ), - ( - ReconfigResult::ErrorBadSequenceNumber, - "5: Error - Bad Sequence Number", - ), - (ReconfigResult::InProgress, "6: In progress"), - ]; - - for (result, expected) in tests { - let actual = result.to_string(); - assert_eq!(actual, expected, "Test case {expected}"); - } - - Ok(()) -} - -/////////////////////////////////////////////////////////////////// -//param_test -/////////////////////////////////////////////////////////////////// - -#[test] -fn test_build_param_success() -> Result<()> { - let tests = vec![CHUNK_RECONFIG_PARAM_A.clone()]; - - for binary in tests { - let p = build_param(&binary)?; - let b = p.marshal()?; - assert_eq!(b, binary); - } - - Ok(()) -} - -#[test] -fn test_build_param_failure() -> Result<()> { - let tests = vec![ - ("invalid ParamType", Bytes::from_static(&[0x0, 0x0])), - ("build failure", CHUNK_RECONFIG_PARAM_A.slice(..8)), - ]; - - for (name, binary) in tests { - let result = build_param(&binary); - assert!(result.is_err(), "expected unmarshal: {name} to fail."); - } - - Ok(()) -} diff --git a/sctp/src/param/param_type.rs b/sctp/src/param/param_type.rs deleted file mode 100644 index 7e195e5db..000000000 --- a/sctp/src/param/param_type.rs +++ /dev/null @@ -1,167 +0,0 @@ -use std::fmt; - -/// paramType represents a SCTP INIT/INITACK parameter -#[derive(Debug, Copy, Clone, PartialEq)] -#[repr(C)] -pub(crate) enum ParamType { - HeartbeatInfo, - /// Heartbeat Info [RFCRFC4960] - Ipv4Addr, - /// IPv4 IP [RFCRFC4960] - Ipv6Addr, - /// IPv6 IP [RFCRFC4960] - StateCookie, - /// State Cookie [RFCRFC4960] - UnrecognizedParam, - /// Unrecognized Parameters [RFCRFC4960] - CookiePreservative, - /// Cookie Preservative [RFCRFC4960] - HostNameAddr, - /// Host Name IP [RFCRFC4960] - SupportedAddrTypes, - /// Supported IP Types [RFCRFC4960] - OutSsnResetReq, - /// Outgoing SSN Reset Request Parameter [RFCRFC6525] - IncSsnResetReq, - /// Incoming SSN Reset Request Parameter [RFCRFC6525] - SsnTsnResetReq, - /// SSN/TSN Reset Request Parameter [RFCRFC6525] - ReconfigResp, - /// Re-configuration Response Parameter [RFCRFC6525] - AddOutStreamsReq, - /// Add Outgoing Streams Request Parameter [RFCRFC6525] - AddIncStreamsReq, - /// Add Incoming Streams Request Parameter [RFCRFC6525] - Random, - /// Random (0x8002) [RFCRFC4805] - ChunkList, - /// Chunk List (0x8003) [RFCRFC4895] - ReqHmacAlgo, - /// Requested HMAC Algorithm Parameter (0x8004) [RFCRFC4895] - Padding, 
- /// Padding (0x8005) - SupportedExt, - /// Supported Extensions (0x8008) [RFCRFC5061] - ForwardTsnSupp, - /// Forward TSN supported (0xC000) [RFCRFC3758] - AddIpAddr, - /// Add IP IP (0xC001) [RFCRFC5061] - DelIpaddr, - /// Delete IP IP (0xC002) [RFCRFC5061] - ErrClauseInd, - /// Error Cause Indication (0xC003) [RFCRFC5061] - SetPriAddr, - /// Set Primary IP (0xC004) [RFCRFC5061] - SuccessInd, - /// Success Indication (0xC005) [RFCRFC5061] - AdaptLayerInd, - /// Adaptation Layer Indication (0xC006) [RFCRFC5061] - Unknown { - param_type: u16, - }, -} - -impl fmt::Display for ParamType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - ParamType::HeartbeatInfo => "Heartbeat Info", - ParamType::Ipv4Addr => "IPv4 IP", - ParamType::Ipv6Addr => "IPv6 IP", - ParamType::StateCookie => "State Cookie", - ParamType::UnrecognizedParam => "Unrecognized Parameters", - ParamType::CookiePreservative => "Cookie Preservative", - ParamType::HostNameAddr => "Host Name IP", - ParamType::SupportedAddrTypes => "Supported IP Types", - ParamType::OutSsnResetReq => "Outgoing SSN Reset Request Parameter", - ParamType::IncSsnResetReq => "Incoming SSN Reset Request Parameter", - ParamType::SsnTsnResetReq => "SSN/TSN Reset Request Parameter", - ParamType::ReconfigResp => "Re-configuration Response Parameter", - ParamType::AddOutStreamsReq => "Add Outgoing Streams Request Parameter", - ParamType::AddIncStreamsReq => "Add Incoming Streams Request Parameter", - ParamType::Random => "Random", - ParamType::ChunkList => "Chunk List", - ParamType::ReqHmacAlgo => "Requested HMAC Algorithm Parameter", - ParamType::Padding => "Padding", - ParamType::SupportedExt => "Supported Extensions", - ParamType::ForwardTsnSupp => "Forward TSN supported", - ParamType::AddIpAddr => "Add IP IP", - ParamType::DelIpaddr => "Delete IP IP", - ParamType::ErrClauseInd => "Error Cause Indication", - ParamType::SetPriAddr => "Set Primary IP", - ParamType::SuccessInd => "Success Indication", - ParamType::AdaptLayerInd => "Adaptation Layer Indication", - _ => "Unknown ParamType", - }; - write!(f, "{s}") - } -} - -impl From for ParamType { - fn from(v: u16) -> ParamType { - match v { - 1 => ParamType::HeartbeatInfo, - 5 => ParamType::Ipv4Addr, - 6 => ParamType::Ipv6Addr, - 7 => ParamType::StateCookie, - 8 => ParamType::UnrecognizedParam, - 9 => ParamType::CookiePreservative, - 11 => ParamType::HostNameAddr, - 12 => ParamType::SupportedAddrTypes, - 13 => ParamType::OutSsnResetReq, - 14 => ParamType::IncSsnResetReq, - 15 => ParamType::SsnTsnResetReq, - 16 => ParamType::ReconfigResp, - 17 => ParamType::AddOutStreamsReq, - 18 => ParamType::AddIncStreamsReq, - 32770 => ParamType::Random, - 32771 => ParamType::ChunkList, - 32772 => ParamType::ReqHmacAlgo, - 32773 => ParamType::Padding, - 32776 => ParamType::SupportedExt, - 49152 => ParamType::ForwardTsnSupp, - 49153 => ParamType::AddIpAddr, - 49154 => ParamType::DelIpaddr, - 49155 => ParamType::ErrClauseInd, - 49156 => ParamType::SetPriAddr, - 49157 => ParamType::SuccessInd, - 49158 => ParamType::AdaptLayerInd, - unknown => ParamType::Unknown { - param_type: unknown, - }, - } - } -} - -impl From for u16 { - fn from(v: ParamType) -> u16 { - match v { - ParamType::HeartbeatInfo => 1, - ParamType::Ipv4Addr => 5, - ParamType::Ipv6Addr => 6, - ParamType::StateCookie => 7, - ParamType::UnrecognizedParam => 8, - ParamType::CookiePreservative => 9, - ParamType::HostNameAddr => 11, - ParamType::SupportedAddrTypes => 12, - ParamType::OutSsnResetReq => 13, - 
ParamType::IncSsnResetReq => 14, - ParamType::SsnTsnResetReq => 15, - ParamType::ReconfigResp => 16, - ParamType::AddOutStreamsReq => 17, - ParamType::AddIncStreamsReq => 18, - ParamType::Random => 32770, - ParamType::ChunkList => 32771, - ParamType::ReqHmacAlgo => 32772, - ParamType::Padding => 32773, - ParamType::SupportedExt => 32776, - ParamType::ForwardTsnSupp => 49152, - ParamType::AddIpAddr => 49153, - ParamType::DelIpaddr => 49154, - ParamType::ErrClauseInd => 49155, - ParamType::SetPriAddr => 49156, - ParamType::SuccessInd => 49157, - ParamType::AdaptLayerInd => 49158, - ParamType::Unknown { param_type, .. } => param_type, - } - } -} diff --git a/sctp/src/param/param_unknown.rs b/sctp/src/param/param_unknown.rs deleted file mode 100644 index 028b38169..000000000 --- a/sctp/src/param/param_unknown.rs +++ /dev/null @@ -1,65 +0,0 @@ -use std::any::Any; -use std::fmt::{Debug, Display, Formatter}; - -use bytes::{Bytes, BytesMut}; - -use crate::param::param_header::{ParamHeader, PARAM_HEADER_LENGTH}; -use crate::param::param_type::ParamType; -use crate::param::Param; - -/// This type is meant to represent ANY parameter for un/remarshaling purposes, where we do not have a more specific type for it. -/// This means we do not really understand the semantics of the param but can represent it. -/// -/// This is useful for usage in e.g.`ParamUnrecognized` where we want to report some unrecognized params back to the sender. -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct ParamUnknown { - typ: u16, - value: Bytes, -} - -impl Display for ParamUnknown { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "ParamUnknown( {} {:?} )", self.header(), self.value) - } -} - -impl Param for ParamUnknown { - fn header(&self) -> ParamHeader { - ParamHeader { - typ: ParamType::Unknown { - param_type: self.typ, - }, - value_length: self.value.len() as u16, - } - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } - - fn unmarshal(raw: &Bytes) -> crate::error::Result - where - Self: Sized, - { - let header = ParamHeader::unmarshal(raw)?; - let value = raw.slice(PARAM_HEADER_LENGTH..PARAM_HEADER_LENGTH + header.value_length()); - Ok(Self { - typ: header.typ.into(), - value, - }) - } - - fn marshal_to(&self, buf: &mut BytesMut) -> crate::error::Result { - self.header().marshal_to(buf)?; - buf.extend(self.value.clone()); - Ok(buf.len()) - } - - fn value_length(&self) -> usize { - self.value.len() - } - - fn clone_to(&self) -> Box { - Box::new(self.clone()) - } -} diff --git a/sctp/src/param/param_unrecognized.rs b/sctp/src/param/param_unrecognized.rs deleted file mode 100644 index dc3c6dfb4..000000000 --- a/sctp/src/param/param_unrecognized.rs +++ /dev/null @@ -1,65 +0,0 @@ -use std::any::Any; -use std::fmt::{Debug, Display, Formatter}; - -use bytes::{Bytes, BytesMut}; - -use crate::param::param_header::PARAM_HEADER_LENGTH; -use crate::param::param_type::ParamType; -use crate::param::{build_param, Param, ParamHeader}; - -/// This is the parameter type used to report unrecognized parameters in e.g. init chunks back to the sender in the init ack. -/// The contained param is likely to be a `ParamUnknown` but might be something more specific. 
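To make the idea above concrete: an implementation only needs a parameter's raw type code and value bytes in order to report it back verbatim, even when it does not understand the semantics. A minimal standalone sketch of such a TLV round trip (the `RawParam`, `parse_raw_param`, and `marshal_raw_param` names are hypothetical, not the crate's API, and 4-byte value padding is ignored):

```rust
/// Hypothetical minimal TLV: 2-byte type, 2-byte length covering the 4-byte
/// header, then the value (wire padding ignored in this sketch).
#[derive(Debug, PartialEq)]
struct RawParam {
    typ: u16,
    value: Vec<u8>,
}

fn parse_raw_param(raw: &[u8]) -> Option<RawParam> {
    if raw.len() < 4 {
        return None;
    }
    let typ = u16::from_be_bytes([raw[0], raw[1]]);
    let len = u16::from_be_bytes([raw[2], raw[3]]) as usize;
    if len < 4 || raw.len() < len {
        return None;
    }
    Some(RawParam {
        typ,
        value: raw[4..len].to_vec(),
    })
}

fn marshal_raw_param(p: &RawParam) -> Vec<u8> {
    let mut out = Vec::with_capacity(4 + p.value.len());
    out.extend_from_slice(&p.typ.to_be_bytes());
    out.extend_from_slice(&((4 + p.value.len()) as u16).to_be_bytes());
    out.extend_from_slice(&p.value);
    out
}

fn main() {
    // 0x4001 is not one of the codes listed above, so a stack would not recognize it.
    let wire = [0x40, 0x01, 0x00, 0x06, 0xde, 0xad];
    let p = parse_raw_param(&wire).expect("valid header");
    assert_eq!(p.typ, 0x4001);
    // The round trip is byte-exact, so the parameter can be echoed back verbatim.
    assert_eq!(marshal_raw_param(&p), wire.to_vec());
}
```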
-#[derive(Clone, Debug)] -pub struct ParamUnrecognized { - param: Box, -} - -impl ParamUnrecognized { - pub(crate) fn wrap(param: Box) -> Self { - Self { param } - } -} - -impl Display for ParamUnrecognized { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.write_str("UnrecognizedParam")?; - Display::fmt(&self.param, f) - } -} - -impl Param for ParamUnrecognized { - fn header(&self) -> ParamHeader { - ParamHeader { - typ: ParamType::UnrecognizedParam, - value_length: self.value_length() as u16, - } - } - - fn as_any(&self) -> &(dyn Any + Send + Sync) { - self - } - - fn unmarshal(raw: &Bytes) -> crate::error::Result - where - Self: Sized, - { - let header = ParamHeader::unmarshal(raw)?; - let raw_param = raw.slice(PARAM_HEADER_LENGTH..PARAM_HEADER_LENGTH + header.value_length()); - let param = build_param(&raw_param)?; - Ok(Self { param }) - } - - fn marshal_to(&self, buf: &mut BytesMut) -> crate::error::Result { - self.header().marshal_to(buf)?; - self.param.marshal_to(buf)?; - Ok(buf.len()) - } - - fn value_length(&self) -> usize { - self.param.value_length() + PARAM_HEADER_LENGTH - } - - fn clone_to(&self) -> Box { - Box::new(self.clone()) - } -} diff --git a/sctp/src/queue/control_queue.rs b/sctp/src/queue/control_queue.rs deleted file mode 100644 index 10b19539f..000000000 --- a/sctp/src/queue/control_queue.rs +++ /dev/null @@ -1,6 +0,0 @@ -use std::collections::VecDeque; - -use crate::packet::Packet; - -/// control queue -pub(crate) type ControlQueue = VecDeque; diff --git a/sctp/src/queue/mod.rs b/sctp/src/queue/mod.rs deleted file mode 100644 index 836e7aeb4..000000000 --- a/sctp/src/queue/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -#[cfg(test)] -mod queue_test; - -pub(crate) mod control_queue; -pub(crate) mod payload_queue; -pub(crate) mod pending_queue; -pub(crate) mod reassembly_queue; diff --git a/sctp/src/queue/payload_queue.rs b/sctp/src/queue/payload_queue.rs deleted file mode 100644 index 481d99d9e..000000000 --- a/sctp/src/queue/payload_queue.rs +++ /dev/null @@ -1,184 +0,0 @@ -use std::collections::{HashMap, VecDeque}; -use std::sync::atomic::Ordering; -use std::sync::Arc; - -use portable_atomic::AtomicUsize; - -use crate::chunk::chunk_payload_data::ChunkPayloadData; -use crate::chunk::chunk_selective_ack::GapAckBlock; -use crate::util::*; - -#[derive(Default, Debug)] -pub(crate) struct PayloadQueue { - pub(crate) length: Arc, - pub(crate) chunk_map: HashMap, - pub(crate) sorted: VecDeque, - pub(crate) dup_tsn: Vec, - pub(crate) n_bytes: usize, -} - -impl PayloadQueue { - pub(crate) fn new(length: Arc) -> Self { - length.store(0, Ordering::SeqCst); - PayloadQueue { - length, - ..Default::default() - } - } - - pub(crate) fn can_push(&self, p: &ChunkPayloadData, cumulative_tsn: u32) -> bool { - !(self.chunk_map.contains_key(&p.tsn) || sna32lte(p.tsn, cumulative_tsn)) - } - - pub(crate) fn push_no_check(&mut self, p: ChunkPayloadData) { - let tsn = p.tsn; - self.n_bytes += p.user_data.len(); - self.chunk_map.insert(tsn, p); - self.length.fetch_add(1, Ordering::SeqCst); - - if self.sorted.is_empty() || sna32gt(tsn, *self.sorted.back().unwrap()) { - self.sorted.push_back(tsn); - } else if sna32lt(tsn, *self.sorted.front().unwrap()) { - self.sorted.push_front(tsn); - } else { - fn compare_tsn(a: u32, b: u32) -> std::cmp::Ordering { - if sna32lt(a, b) { - std::cmp::Ordering::Less - } else { - std::cmp::Ordering::Greater - } - } - let pos = match self - .sorted - .binary_search_by(|element| compare_tsn(*element, tsn)) - { - Ok(pos) => pos, - Err(pos) => pos, - }; - 
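The custom comparator above exists because TSNs are serial numbers that wrap at 2^32, so plain integer `<` gives the wrong order near the wrap point. A self-contained sketch of one common RFC 1982-style definition (the crate's `sna32lt`/`sna32gt` in `crate::util` may differ in detail):

```rust
// Serial-number "less than" for 32-bit sequence numbers (RFC 1982 style):
// a < b iff the forward distance from a to b is less than half the space.
fn sna32lt(a: u32, b: u32) -> bool {
    // wrapping_sub gives the forward distance modulo 2^32.
    a != b && b.wrapping_sub(a) < (1 << 31)
}

fn sna32gt(a: u32, b: u32) -> bool {
    sna32lt(b, a)
}

fn main() {
    // Plain `<` would say 4_294_967_295 > 0, but in serial-number terms the
    // TSN has just wrapped, so 0xFFFF_FFFF sorts before 0.
    assert!(sna32lt(u32::MAX, 0));
    assert!(!sna32lt(0, u32::MAX));
    assert!(sna32gt(1, u32::MAX));
    // Ordinary nearby values still compare as expected.
    assert!(sna32lt(5, 10));
}
```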
self.sorted.insert(pos, tsn); - } - } - - /// push pushes a payload data. If the payload data is already in our queue or - /// older than our cumulative_tsn marker, it will be recorded as duplications, - /// which can later be retrieved using popDuplicates. - pub(crate) fn push(&mut self, p: ChunkPayloadData, cumulative_tsn: u32) -> bool { - let ok = self.chunk_map.contains_key(&p.tsn); - if ok || sna32lte(p.tsn, cumulative_tsn) { - // Found the packet, log in dups - self.dup_tsn.push(p.tsn); - return false; - } - - self.push_no_check(p); - true - } - - /// pop pops only if the oldest chunk's TSN matches the given TSN. - pub(crate) fn pop(&mut self, tsn: u32) -> Option { - if Some(&tsn) == self.sorted.front() { - self.sorted.pop_front(); - if let Some(c) = self.chunk_map.remove(&tsn) { - self.length.fetch_sub(1, Ordering::SeqCst); - self.n_bytes -= c.user_data.len(); - return Some(c); - } - } - - None - } - - /// get returns reference to chunkPayloadData with the given TSN value. - pub(crate) fn get(&self, tsn: u32) -> Option<&ChunkPayloadData> { - self.chunk_map.get(&tsn) - } - pub(crate) fn get_mut(&mut self, tsn: u32) -> Option<&mut ChunkPayloadData> { - self.chunk_map.get_mut(&tsn) - } - - /// popDuplicates returns an array of TSN values that were found duplicate. - pub(crate) fn pop_duplicates(&mut self) -> Vec { - self.dup_tsn.drain(..).collect() - } - - pub(crate) fn get_gap_ack_blocks(&self, cumulative_tsn: u32) -> Vec { - if self.chunk_map.is_empty() { - return vec![]; - } - - let mut b = GapAckBlock::default(); - let mut gap_ack_blocks = vec![]; - for (i, tsn) in self.sorted.iter().enumerate() { - let diff = if *tsn >= cumulative_tsn { - (*tsn - cumulative_tsn) as u16 - } else { - 0 - }; - - if i == 0 { - b.start = diff; - b.end = b.start; - } else if b.end + 1 == diff { - b.end += 1; - } else { - gap_ack_blocks.push(b); - - b.start = diff; - b.end = diff; - } - } - - gap_ack_blocks.push(b); - - gap_ack_blocks - } - - pub(crate) fn get_gap_ack_blocks_string(&self, cumulative_tsn: u32) -> String { - let mut s = format!("cumTSN={cumulative_tsn}"); - for b in self.get_gap_ack_blocks(cumulative_tsn) { - s += format!(",{}-{}", b.start, b.end).as_str(); - } - s - } - - pub(crate) fn mark_as_acked(&mut self, tsn: u32) -> usize { - let n_bytes_acked = if let Some(c) = self.chunk_map.get_mut(&tsn) { - c.acked = true; - c.retransmit = false; - let n = c.user_data.len(); - self.n_bytes -= n; - c.user_data.clear(); - n - } else { - 0 - }; - - n_bytes_acked - } - - pub(crate) fn get_last_tsn_received(&self) -> Option<&u32> { - self.sorted.back() - } - - pub(crate) fn mark_all_to_retrasmit(&mut self) { - for c in self.chunk_map.values_mut() { - if c.acked || c.abandoned() { - continue; - } - c.retransmit = true; - } - } - - pub(crate) fn get_num_bytes(&self) -> usize { - self.n_bytes - } - - pub(crate) fn len(&self) -> usize { - assert_eq!(self.chunk_map.len(), self.length.load(Ordering::SeqCst)); - self.chunk_map.len() - } - - pub(crate) fn is_empty(&self) -> bool { - self.len() == 0 - } -} diff --git a/sctp/src/queue/pending_queue.rs b/sctp/src/queue/pending_queue.rs deleted file mode 100644 index caff6ab42..000000000 --- a/sctp/src/queue/pending_queue.rs +++ /dev/null @@ -1,260 +0,0 @@ -use std::collections::VecDeque; -use std::sync::atomic::Ordering; - -use portable_atomic::{AtomicBool, AtomicUsize}; -use tokio::sync::{Mutex, Semaphore}; -use util::sync::RwLock; - -use crate::chunk::chunk_payload_data::ChunkPayloadData; - -// TODO: benchmark performance between multiple Atomic+Mutex vs one 
Mutex - -// Some tests push a lot of data before starting to process any data... -#[cfg(test)] -const QUEUE_BYTES_LIMIT: usize = 128 * 1024 * 1024; -/// Maximum size of the pending queue, in bytes. -#[cfg(not(test))] -const QUEUE_BYTES_LIMIT: usize = 128 * 1024; -/// Total user data size, beyond which the packet will be split into chunks. The chunks will be -/// added to the pending queue one by one. -const QUEUE_APPEND_LARGE: usize = (QUEUE_BYTES_LIMIT * 2) / 3; - -/// Basic queue for either ordered or unordered chunks. -pub(crate) type PendingBaseQueue = VecDeque; - -/// A queue for both ordered and unordered chunks. -#[derive(Debug)] -pub(crate) struct PendingQueue { - // These two fields limit appending bytes to the queue - // This two step process is necessary because - // A) We need backpressure which the semaphore applies by limiting the total amount of bytes via the permits - // B) The chunks of one fragmented message need to be put in direct sequence into the queue which the lock guarantees - // - // The semaphore is not inside the lock because the permits need to be returned without needing a lock on the semaphore - semaphore_lock: Mutex<()>, - semaphore: Semaphore, - - unordered_queue: RwLock, - ordered_queue: RwLock, - queue_len: AtomicUsize, - n_bytes: AtomicUsize, - selected: AtomicBool, - unordered_is_selected: AtomicBool, -} - -impl Default for PendingQueue { - fn default() -> Self { - PendingQueue::new() - } -} - -impl PendingQueue { - pub(crate) fn new() -> Self { - Self { - semaphore_lock: Mutex::default(), - semaphore: Semaphore::new(QUEUE_BYTES_LIMIT), - unordered_queue: Default::default(), - ordered_queue: Default::default(), - queue_len: Default::default(), - n_bytes: Default::default(), - selected: Default::default(), - unordered_is_selected: Default::default(), - } - } - - /// Appends a chunk to the back of the pending queue. - pub(crate) async fn push(&self, c: ChunkPayloadData) { - let user_data_len = c.user_data.len(); - - { - let _sem_lock = self.semaphore_lock.lock().await; - let permits = self.semaphore.acquire_many(user_data_len as u32).await; - // unwrap ok because we never close the semaphore unless we have dropped self - permits.unwrap().forget(); - - if c.unordered { - let mut unordered_queue = self.unordered_queue.write(); - unordered_queue.push_back(c); - } else { - let mut ordered_queue = self.ordered_queue.write(); - ordered_queue.push_back(c); - } - } - - self.n_bytes.fetch_add(user_data_len, Ordering::SeqCst); - self.queue_len.fetch_add(1, Ordering::SeqCst); - } - - /// Appends chunks to the back of the pending queue. - /// - /// # Panics - /// - /// If it's a mix of unordered and ordered chunks. 
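The push path above applies backpressure by acquiring one semaphore permit per byte of user data and forgetting the permit; `pop` later returns the permits with `add_permits`, so writers stall once `QUEUE_BYTES_LIMIT` bytes are queued. A stripped-down sketch of that pattern using the same `tokio::sync::Semaphore` calls (the `ByteGate` type and the 16-byte limit are made up for the demo):

```rust
use tokio::sync::Semaphore;

/// Hypothetical byte-budget gate: one permit == one queued byte.
struct ByteGate {
    sem: Semaphore,
}

impl ByteGate {
    fn new(limit: usize) -> Self {
        Self { sem: Semaphore::new(limit) }
    }

    /// Waits until `n` bytes of budget are free, then consumes them.
    /// `forget()` keeps the permits "spent" after the guard is dropped.
    async fn reserve(&self, n: usize) {
        let permit = self
            .sem
            .acquire_many(n as u32)
            .await
            .expect("semaphore is never closed in this sketch");
        permit.forget();
    }

    /// Returns budget when data leaves the queue.
    fn release(&self, n: usize) {
        self.sem.add_permits(n);
    }
}

#[tokio::main]
async fn main() {
    let gate = ByteGate::new(16); // tiny byte budget, just for the demo
    gate.reserve(10).await; // first chunk fits
    gate.reserve(6).await;  // budget is now exhausted
    // A further gate.reserve(1).await would park here until release() runs,
    // which is the backpressure the pending queue applies to writers.
    gate.release(16);
    gate.reserve(16).await; // budget available again
}
```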
- pub(crate) async fn append(&self, chunks: Vec) { - if chunks.is_empty() { - return; - } - - let total_user_data_len = chunks.iter().fold(0, |acc, c| acc + c.user_data.len()); - - if total_user_data_len >= QUEUE_APPEND_LARGE { - self.append_large(chunks).await - } else { - let _sem_lock = self.semaphore_lock.lock().await; - let permits = self - .semaphore - .acquire_many(total_user_data_len as u32) - .await; - // unwrap ok because we never close the semaphore unless we have dropped self - permits.unwrap().forget(); - self.append_unlimited(chunks, total_user_data_len); - } - } - - // If this is a very large message we append chunks one by one to allow progress while we are appending - async fn append_large(&self, chunks: Vec) { - // lock this for the whole duration - let _sem_lock = self.semaphore_lock.lock().await; - - for chunk in chunks.into_iter() { - let user_data_len = chunk.user_data.len(); - let permits = self.semaphore.acquire_many(user_data_len as u32).await; - // unwrap ok because we never close the semaphore unless we have dropped self - permits.unwrap().forget(); - - if chunk.unordered { - let mut unordered_queue = self.unordered_queue.write(); - unordered_queue.push_back(chunk); - } else { - let mut ordered_queue = self.ordered_queue.write(); - ordered_queue.push_back(chunk); - } - self.n_bytes.fetch_add(user_data_len, Ordering::SeqCst); - self.queue_len.fetch_add(1, Ordering::SeqCst); - } - } - - /// Assumes that A) enough permits have been acquired and forget from the semaphore and that the semaphore_lock is held - fn append_unlimited(&self, chunks: Vec, total_user_data_len: usize) { - let chunks_len = chunks.len(); - let unordered = chunks - .first() - .expect("chunks to not be empty because of the above check") - .unordered; - if unordered { - let mut unordered_queue = self.unordered_queue.write(); - assert!( - chunks.iter().all(|c| c.unordered), - "expected all chunks to be unordered" - ); - unordered_queue.extend(chunks); - } else { - let mut ordered_queue = self.ordered_queue.write(); - assert!( - chunks.iter().all(|c| !c.unordered), - "expected all chunks to be ordered" - ); - ordered_queue.extend(chunks); - } - - self.n_bytes - .fetch_add(total_user_data_len, Ordering::SeqCst); - self.queue_len.fetch_add(chunks_len, Ordering::SeqCst); - } - - pub(crate) fn peek(&self) -> Option { - if self.selected.load(Ordering::SeqCst) { - if self.unordered_is_selected.load(Ordering::SeqCst) { - let unordered_queue = self.unordered_queue.read(); - return unordered_queue.front().cloned(); - } else { - let ordered_queue = self.ordered_queue.read(); - return ordered_queue.front().cloned(); - } - } - - let c = { - let unordered_queue = self.unordered_queue.read(); - unordered_queue.front().cloned() - }; - - if c.is_some() { - return c; - } - - let ordered_queue = self.ordered_queue.read(); - ordered_queue.front().cloned() - } - - pub(crate) fn pop( - &self, - beginning_fragment: bool, - unordered: bool, - ) -> Option { - let popped = if self.selected.load(Ordering::SeqCst) { - let popped = if self.unordered_is_selected.load(Ordering::SeqCst) { - let mut unordered_queue = self.unordered_queue.write(); - unordered_queue.pop_front() - } else { - let mut ordered_queue = self.ordered_queue.write(); - ordered_queue.pop_front() - }; - if let Some(p) = &popped { - if p.ending_fragment { - self.selected.store(false, Ordering::SeqCst); - } - } - popped - } else { - if !beginning_fragment { - return None; - } - if unordered { - let popped = { - let mut unordered_queue = 
self.unordered_queue.write(); - unordered_queue.pop_front() - }; - if let Some(p) = &popped { - if !p.ending_fragment { - self.selected.store(true, Ordering::SeqCst); - self.unordered_is_selected.store(true, Ordering::SeqCst); - } - } - popped - } else { - let popped = { - let mut ordered_queue = self.ordered_queue.write(); - ordered_queue.pop_front() - }; - if let Some(p) = &popped { - if !p.ending_fragment { - self.selected.store(true, Ordering::SeqCst); - self.unordered_is_selected.store(false, Ordering::SeqCst); - } - } - popped - } - }; - - if let Some(p) = &popped { - let user_data_len = p.user_data.len(); - self.n_bytes.fetch_sub(user_data_len, Ordering::SeqCst); - self.queue_len.fetch_sub(1, Ordering::SeqCst); - self.semaphore.add_permits(user_data_len); - } - - popped - } - - pub(crate) fn get_num_bytes(&self) -> usize { - self.n_bytes.load(Ordering::SeqCst) - } - - pub(crate) fn len(&self) -> usize { - self.queue_len.load(Ordering::SeqCst) - } - - pub(crate) fn is_empty(&self) -> bool { - self.len() == 0 - } -} diff --git a/sctp/src/queue/queue_test.rs b/sctp/src/queue/queue_test.rs deleted file mode 100644 index 4ba1a92c6..000000000 --- a/sctp/src/queue/queue_test.rs +++ /dev/null @@ -1,997 +0,0 @@ -use bytes::{Bytes, BytesMut}; - -/////////////////////////////////////////////////////////////////// -//payload_queue_test -/////////////////////////////////////////////////////////////////// -use super::payload_queue::*; -use crate::chunk::chunk_payload_data::{ChunkPayloadData, PayloadProtocolIdentifier}; -use crate::chunk::chunk_selective_ack::GapAckBlock; -use crate::error::{Error, Result}; - -fn make_payload(tsn: u32, n_bytes: usize) -> ChunkPayloadData { - ChunkPayloadData { - tsn, - user_data: { - let mut b = BytesMut::new(); - b.resize(n_bytes, 0); - b.freeze() - }, - ..Default::default() - } -} - -#[test] -fn test_payload_queue_push_no_check() -> Result<()> { - let mut pq = PayloadQueue::new(Arc::new(AtomicUsize::new(0))); - - pq.push_no_check(make_payload(0, 10)); - assert_eq!(pq.get_num_bytes(), 10, "total bytes mismatch"); - assert_eq!(pq.len(), 1, "item count mismatch"); - pq.push_no_check(make_payload(1, 11)); - assert_eq!(pq.get_num_bytes(), 21, "total bytes mismatch"); - assert_eq!(pq.len(), 2, "item count mismatch"); - pq.push_no_check(make_payload(2, 12)); - assert_eq!(pq.get_num_bytes(), 33, "total bytes mismatch"); - assert_eq!(pq.len(), 3, "item count mismatch"); - - for i in 0..3 { - assert!(!pq.sorted.is_empty(), "should not be empty"); - let c = pq.pop(i); - assert!(c.is_some(), "pop should succeed"); - if let Some(c) = c { - assert_eq!(c.tsn, i, "TSN should match"); - } - } - - assert_eq!(pq.get_num_bytes(), 0, "total bytes mismatch"); - assert_eq!(pq.len(), 0, "item count mismatch"); - - assert!(pq.sorted.is_empty(), "should be empty"); - pq.push_no_check(make_payload(3, 13)); - assert_eq!(pq.get_num_bytes(), 13, "total bytes mismatch"); - pq.push_no_check(make_payload(4, 14)); - assert_eq!(pq.get_num_bytes(), 27, "total bytes mismatch"); - - for i in 3..5 { - assert!(!pq.sorted.is_empty(), "should not be empty"); - let c = pq.pop(i); - assert!(c.is_some(), "pop should succeed"); - if let Some(c) = c { - assert_eq!(c.tsn, i, "TSN should match"); - } - } - - assert_eq!(pq.get_num_bytes(), 0, "total bytes mismatch"); - assert_eq!(pq.len(), 0, "item count mismatch"); - - Ok(()) -} - -#[test] -fn test_payload_queue_get_gap_ack_block() -> Result<()> { - let mut pq = PayloadQueue::new(Arc::new(AtomicUsize::new(0))); - - pq.push(make_payload(1, 0), 0); - 
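The gap-ACK expectations in this test (and the `get_gap_ack_blocks` implementation above) boil down to: given the cumulative TSN and the sorted TSNs received beyond it, report each contiguous run as a (start, end) offset pair. A rough standalone sketch with hypothetical names that reproduces the 1–6 and 8–9 blocks asserted below:

```rust
#[derive(Debug, PartialEq, Clone, Copy)]
struct Gap {
    start: u16,
    end: u16,
}

/// `received` must be sorted and contain only TSNs greater than `cum_tsn`.
/// Offsets are relative to `cum_tsn`, as in a SACK's gap ack blocks.
fn gap_blocks(cum_tsn: u32, received: &[u32]) -> Vec<Gap> {
    let mut blocks: Vec<Gap> = Vec::new();
    for &tsn in received {
        let off = (tsn - cum_tsn) as u16;
        let extends = blocks.last().map_or(false, |g| g.end + 1 == off);
        if extends {
            // Contiguous with the current run: grow it.
            blocks.last_mut().unwrap().end = off;
        } else {
            // Gap in the TSN sequence: start a new block.
            blocks.push(Gap { start: off, end: off });
        }
    }
    blocks
}

fn main() {
    // Mirrors the test: TSNs 1..=6 then 8..=9, cumulative TSN 0.
    let blocks = gap_blocks(0, &[1, 2, 3, 4, 5, 6, 8, 9]);
    assert_eq!(
        blocks,
        vec![Gap { start: 1, end: 6 }, Gap { start: 8, end: 9 }]
    );
}
```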
pq.push(make_payload(2, 0), 0); - pq.push(make_payload(3, 0), 0); - pq.push(make_payload(4, 0), 0); - pq.push(make_payload(5, 0), 0); - pq.push(make_payload(6, 0), 0); - - let gab1 = [GapAckBlock { start: 1, end: 6 }]; - let gab2 = pq.get_gap_ack_blocks(0); - assert!(!gab2.is_empty()); - assert_eq!(gab2.len(), 1); - - assert_eq!(gab2[0].start, gab1[0].start); - assert_eq!(gab2[0].end, gab1[0].end); - - pq.push(make_payload(8, 0), 0); - pq.push(make_payload(9, 0), 0); - - let gab1 = [ - GapAckBlock { start: 1, end: 6 }, - GapAckBlock { start: 8, end: 9 }, - ]; - let gab2 = pq.get_gap_ack_blocks(0); - assert!(!gab2.is_empty()); - assert_eq!(gab2.len(), 2); - - assert_eq!(gab2[0].start, gab1[0].start); - assert_eq!(gab2[0].end, gab1[0].end); - assert_eq!(gab2[1].start, gab1[1].start); - assert_eq!(gab2[1].end, gab1[1].end); - - Ok(()) -} - -#[test] -fn test_payload_queue_get_last_tsn_received() -> Result<()> { - let mut pq = PayloadQueue::new(Arc::new(AtomicUsize::new(0))); - - // empty queie should return false - let ok = pq.get_last_tsn_received(); - assert!(ok.is_none(), "should be none"); - - let ok = pq.push(make_payload(20, 0), 0); - assert!(ok, "should be true"); - let tsn = pq.get_last_tsn_received(); - assert!(tsn.is_some(), "should be false"); - assert_eq!(tsn, Some(&20), "should match"); - - // append should work - let ok = pq.push(make_payload(21, 0), 0); - assert!(ok, "should be true"); - let tsn = pq.get_last_tsn_received(); - assert!(tsn.is_some(), "should be false"); - assert_eq!(tsn, Some(&21), "should match"); - - // check if sorting applied - let ok = pq.push(make_payload(19, 0), 0); - assert!(ok, "should be true"); - let tsn = pq.get_last_tsn_received(); - assert!(tsn.is_some(), "should be false"); - assert_eq!(tsn, Some(&21), "should match"); - - Ok(()) -} - -#[test] -fn test_payload_queue_mark_all_to_retrasmit() -> Result<()> { - let mut pq = PayloadQueue::new(Arc::new(AtomicUsize::new(0))); - - for i in 0..3 { - pq.push(make_payload(i + 1, 10), 0); - } - pq.mark_as_acked(2); - pq.mark_all_to_retrasmit(); - - let c = pq.get(1); - assert!(c.is_some(), "should be true"); - assert!(c.unwrap().retransmit, "should be marked as retransmit"); - let c = pq.get(2); - assert!(c.is_some(), "should be true"); - assert!(!c.unwrap().retransmit, "should NOT be marked as retransmit"); - let c = pq.get(3); - assert!(c.is_some(), "should be true"); - assert!(c.unwrap().retransmit, "should be marked as retransmit"); - - Ok(()) -} - -#[test] -fn test_payload_queue_reset_retransmit_flag_on_ack() -> Result<()> { - let mut pq = PayloadQueue::new(Arc::new(AtomicUsize::new(0))); - - for i in 0..4 { - pq.push(make_payload(i + 1, 10), 0); - } - - pq.mark_all_to_retrasmit(); - pq.mark_as_acked(2); // should cancel retransmission for TSN 2 - pq.mark_as_acked(4); // should cancel retransmission for TSN 4 - - let c = pq.get(1); - assert!(c.is_some(), "should be true"); - assert!(c.unwrap().retransmit, "should be marked as retransmit"); - let c = pq.get(2); - assert!(c.is_some(), "should be true"); - assert!(!c.unwrap().retransmit, "should NOT be marked as retransmit"); - let c = pq.get(3); - assert!(c.is_some(), "should be true"); - assert!(c.unwrap().retransmit, "should be marked as retransmit"); - let c = pq.get(4); - assert!(c.is_some(), "should be true"); - assert!(!c.unwrap().retransmit, "should NOT be marked as retransmit"); - - Ok(()) -} - -/////////////////////////////////////////////////////////////////// -//pending_queue_test 
-/////////////////////////////////////////////////////////////////// -use super::pending_queue::*; - -const NO_FRAGMENT: usize = 0; -const FRAG_BEGIN: usize = 1; -const FRAG_MIDDLE: usize = 2; -const FRAG_END: usize = 3; - -fn make_data_chunk(tsn: u32, unordered: bool, frag: usize) -> ChunkPayloadData { - let mut b = false; - let mut e = false; - - match frag { - NO_FRAGMENT => { - b = true; - e = true; - } - FRAG_BEGIN => { - b = true; - } - FRAG_END => e = true, - _ => {} - }; - - ChunkPayloadData { - tsn, - unordered, - beginning_fragment: b, - ending_fragment: e, - user_data: { - let mut b = BytesMut::new(); - b.resize(10, 0); // always 10 bytes - b.freeze() - }, - ..Default::default() - } -} - -#[test] -fn test_pending_base_queue_push_and_pop() -> Result<()> { - let mut pq = PendingBaseQueue::new(); - pq.push_back(make_data_chunk(0, false, NO_FRAGMENT)); - pq.push_back(make_data_chunk(1, false, NO_FRAGMENT)); - pq.push_back(make_data_chunk(2, false, NO_FRAGMENT)); - - for i in 0..3 { - let c = pq.get(i); - assert!(c.is_some(), "should not be none"); - assert_eq!(c.unwrap().tsn, i as u32, "TSN should match"); - } - - for i in 0..3 { - let c = pq.pop_front(); - assert!(c.is_some(), "should not be none"); - assert_eq!(c.unwrap().tsn, i, "TSN should match"); - } - - pq.push_back(make_data_chunk(3, false, NO_FRAGMENT)); - pq.push_back(make_data_chunk(4, false, NO_FRAGMENT)); - - for i in 3..5 { - let c = pq.pop_front(); - assert!(c.is_some(), "should not be none"); - assert_eq!(c.unwrap().tsn, i, "TSN should match"); - } - Ok(()) -} - -#[test] -fn test_pending_base_queue_out_of_bounce() -> Result<()> { - let mut pq = PendingBaseQueue::new(); - assert!(pq.pop_front().is_none(), "should be none"); - assert!(pq.front().is_none(), "should be none"); - - pq.push_back(make_data_chunk(0, false, NO_FRAGMENT)); - assert!(pq.get(1).is_none(), "should be none"); - - Ok(()) -} - -// NOTE: TSN is not used in pendingQueue in the actual usage. -// Following tests use TSN field as a chunk ID. 
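The tests that follow exercise the queue-selection rule implemented above: `peek`/`pop` prefer unordered chunks, but once a fragmented message starts being popped, the same queue stays selected until its ending fragment goes out, so fragments of one message are never interleaved. A deliberately simplified, synchronous sketch of just that rule (hypothetical `Selector` and `Frag` types; only the ending flag is modelled):

```rust
use std::collections::VecDeque;

/// Hypothetical, deliberately simplified chunk: only the ending flag matters here.
#[derive(Clone, Copy, Debug)]
struct Frag {
    id: u32,
    ending: bool,
}

/// Minimal model of the selection rule: unordered wins, but once a fragmented
/// message has started, its queue stays selected until the ending fragment
/// has been popped.
#[derive(Default)]
struct Selector {
    unordered: VecDeque<Frag>,
    ordered: VecDeque<Frag>,
    selected_unordered: Option<bool>, // Some(which queue) while mid-message
}

impl Selector {
    fn pop(&mut self) -> Option<Frag> {
        let use_unordered = match self.selected_unordered {
            Some(u) => u,                       // finish the in-flight message first
            None => !self.unordered.is_empty(), // otherwise prefer unordered
        };
        let popped = if use_unordered {
            self.unordered.pop_front()
        } else {
            self.ordered.pop_front()
        };
        let popped = popped?;
        // Remember the queue while the current message is still incomplete.
        self.selected_unordered = if popped.ending { None } else { Some(use_unordered) };
        Some(popped)
    }
}

fn main() {
    let mut s = Selector::default();
    // An ordered message fragmented into ids 0, 1, 2.
    s.ordered.extend([
        Frag { id: 0, ending: false },
        Frag { id: 1, ending: false },
        Frag { id: 2, ending: true },
    ]);
    assert_eq!(s.pop().unwrap().id, 0); // nothing unordered yet, start the ordered message
    // An unordered single-fragment message arrives mid-way.
    s.unordered.push_back(Frag { id: 3, ending: true });
    // The ordered message still finishes before the unordered chunk is served.
    assert_eq!(s.pop().unwrap().id, 1);
    assert_eq!(s.pop().unwrap().id, 2);
    assert_eq!(s.pop().unwrap().id, 3);
}
```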
-#[tokio::test] -async fn test_pending_queue_push_and_pop() -> Result<()> { - let pq = PendingQueue::new(); - pq.push(make_data_chunk(0, false, NO_FRAGMENT)).await; - assert_eq!(pq.get_num_bytes(), 10, "total bytes mismatch"); - pq.push(make_data_chunk(1, false, NO_FRAGMENT)).await; - assert_eq!(pq.get_num_bytes(), 20, "total bytes mismatch"); - pq.push(make_data_chunk(2, false, NO_FRAGMENT)).await; - assert_eq!(pq.get_num_bytes(), 30, "total bytes mismatch"); - - for i in 0..3 { - let c = pq.peek(); - assert!(c.is_some(), "peek error"); - let c = c.unwrap(); - assert_eq!(c.tsn, i, "TSN should match"); - let (beginning_fragment, unordered) = (c.beginning_fragment, c.unordered); - - let result = pq.pop(beginning_fragment, unordered); - assert!(result.is_some(), "should not error: {i}"); - } - - assert_eq!(pq.get_num_bytes(), 0, "total bytes mismatch"); - - pq.push(make_data_chunk(3, false, NO_FRAGMENT)).await; - assert_eq!(pq.get_num_bytes(), 10, "total bytes mismatch"); - pq.push(make_data_chunk(4, false, NO_FRAGMENT)).await; - assert_eq!(pq.get_num_bytes(), 20, "total bytes mismatch"); - - for i in 3..5 { - let c = pq.peek(); - assert!(c.is_some(), "peek error"); - let c = c.unwrap(); - assert_eq!(c.tsn, i, "TSN should match"); - let (beginning_fragment, unordered) = (c.beginning_fragment, c.unordered); - - let result = pq.pop(beginning_fragment, unordered); - assert!(result.is_some(), "should not error: {i}"); - } - - assert_eq!(pq.get_num_bytes(), 0, "total bytes mismatch"); - - Ok(()) -} - -#[tokio::test] -async fn test_pending_queue_unordered_wins() -> Result<()> { - let pq = PendingQueue::new(); - - pq.push(make_data_chunk(0, false, NO_FRAGMENT)).await; - assert_eq!(10, pq.get_num_bytes(), "total bytes mismatch"); - pq.push(make_data_chunk(1, true, NO_FRAGMENT)).await; - assert_eq!(20, pq.get_num_bytes(), "total bytes mismatch"); - pq.push(make_data_chunk(2, false, NO_FRAGMENT)).await; - assert_eq!(30, pq.get_num_bytes(), "total bytes mismatch"); - pq.push(make_data_chunk(3, true, NO_FRAGMENT)).await; - assert_eq!(40, pq.get_num_bytes(), "total bytes mismatch"); - - let c = pq.peek(); - assert!(c.is_some(), "peek error"); - let c = c.unwrap(); - assert_eq!(c.tsn, 1, "TSN should match"); - let (beginning_fragment, unordered) = (c.beginning_fragment, c.unordered); - let result = pq.pop(beginning_fragment, unordered); - assert!(result.is_some(), "should not error"); - - let c = pq.peek(); - assert!(c.is_some(), "peek error"); - let c = c.unwrap(); - assert_eq!(c.tsn, 3, "TSN should match"); - let (beginning_fragment, unordered) = (c.beginning_fragment, c.unordered); - let result = pq.pop(beginning_fragment, unordered); - assert!(result.is_some(), "should not error"); - - let c = pq.peek(); - assert!(c.is_some(), "peek error"); - let c = c.unwrap(); - assert_eq!(c.tsn, 0, "TSN should match"); - let (beginning_fragment, unordered) = (c.beginning_fragment, c.unordered); - let result = pq.pop(beginning_fragment, unordered); - assert!(result.is_some(), "should not error"); - - let c = pq.peek(); - assert!(c.is_some(), "peek error"); - let c = c.unwrap(); - assert_eq!(c.tsn, 2, "TSN should match"); - let (beginning_fragment, unordered) = (c.beginning_fragment, c.unordered); - let result = pq.pop(beginning_fragment, unordered); - assert!(result.is_some(), "should not error"); - - assert_eq!(pq.get_num_bytes(), 0, "total bytes mismatch"); - - Ok(()) -} - -#[tokio::test] -async fn test_pending_queue_fragments() -> Result<()> { - let pq = PendingQueue::new(); - pq.push(make_data_chunk(0, false, 
FRAG_BEGIN)).await; - pq.push(make_data_chunk(1, false, FRAG_MIDDLE)).await; - pq.push(make_data_chunk(2, false, FRAG_END)).await; - pq.push(make_data_chunk(3, true, FRAG_BEGIN)).await; - pq.push(make_data_chunk(4, true, FRAG_MIDDLE)).await; - pq.push(make_data_chunk(5, true, FRAG_END)).await; - - let expects = vec![3, 4, 5, 0, 1, 2]; - - for exp in expects { - let c = pq.peek(); - assert!(c.is_some(), "peek error"); - let c = c.unwrap(); - assert_eq!(c.tsn, exp, "TSN should match"); - let (beginning_fragment, unordered) = (c.beginning_fragment, c.unordered); - let result = pq.pop(beginning_fragment, unordered); - assert!(result.is_some(), "should not error: {exp}"); - } - - Ok(()) -} - -// Once decided ordered or unordered, the decision should persist until -// it pops a chunk with ending_fragment flags set to true. -#[tokio::test] -async fn test_pending_queue_selection_persistence() -> Result<()> { - let pq = PendingQueue::new(); - pq.push(make_data_chunk(0, false, FRAG_BEGIN)).await; - - let c = pq.peek(); - assert!(c.is_some(), "peek error"); - let c = c.unwrap(); - assert_eq!(c.tsn, 0, "TSN should match"); - let (beginning_fragment, unordered) = (c.beginning_fragment, c.unordered); - let result = pq.pop(beginning_fragment, unordered); - assert!(result.is_some(), "should not error: {}", 0); - - pq.push(make_data_chunk(1, true, NO_FRAGMENT)).await; - pq.push(make_data_chunk(2, false, FRAG_MIDDLE)).await; - pq.push(make_data_chunk(3, false, FRAG_END)).await; - - let expects = vec![2, 3, 1]; - - for exp in expects { - let c = pq.peek(); - assert!(c.is_some(), "peek error"); - let c = c.unwrap(); - assert_eq!(c.tsn, exp, "TSN should match"); - let (beginning_fragment, unordered) = (c.beginning_fragment, c.unordered); - let result = pq.pop(beginning_fragment, unordered); - assert!(result.is_some(), "should not error: {exp}"); - } - - Ok(()) -} - -#[tokio::test] -async fn test_pending_queue_append() -> Result<()> { - let pq = PendingQueue::new(); - pq.append(vec![ - make_data_chunk(0, false, NO_FRAGMENT), - make_data_chunk(1, false, NO_FRAGMENT), - make_data_chunk(3, false, NO_FRAGMENT), - ]) - .await; - assert_eq!(pq.get_num_bytes(), 30, "total bytes mismatch"); - assert_eq!(pq.len(), 3, "len mismatch"); - - Ok(()) -} - -/////////////////////////////////////////////////////////////////// -//reassembly_queue_test -/////////////////////////////////////////////////////////////////// -use std::sync::Arc; - -use portable_atomic::AtomicUsize; - -use super::reassembly_queue::*; - -#[test] -fn test_reassembly_queue_ordered_fragments() -> Result<()> { - let mut rq = ReassemblyQueue::new(0); - - let org_ppi = PayloadProtocolIdentifier::Binary; - - let chunk = ChunkPayloadData { - payload_type: org_ppi, - beginning_fragment: true, - tsn: 1, - stream_sequence_number: 0, - user_data: Bytes::from_static(b"ABC"), - ..Default::default() - }; - - let complete = rq.push(chunk); - assert!(!complete, "chunk set should not be complete yet"); - assert_eq!(rq.get_num_bytes(), 3, "num bytes mismatch"); - - let chunk = ChunkPayloadData { - payload_type: org_ppi, - ending_fragment: true, - tsn: 2, - stream_sequence_number: 0, - user_data: Bytes::from_static(b"DEFG"), - ..Default::default() - }; - - let complete = rq.push(chunk); - assert!(complete, "chunk set should be complete"); - assert_eq!(7, rq.get_num_bytes(), "num bytes mismatch"); - - let mut buf = vec![0u8; 16]; - - let (n, ppi) = rq.read(&mut buf)?; - assert_eq!(n, 7, "should received 7 bytes"); - assert_eq!(rq.get_num_bytes(), 0, "num bytes mismatch"); - 
assert_eq!(ppi, org_ppi, "should have valid ppi"); - assert_eq!(&buf[..n], b"ABCDEFG", "data should match"); - - Ok(()) -} - -#[test] -fn test_reassembly_queue_unordered_fragments() -> Result<()> { - let mut rq = ReassemblyQueue::new(0); - - let org_ppi = PayloadProtocolIdentifier::Binary; - - let chunk = ChunkPayloadData { - payload_type: org_ppi, - unordered: true, - beginning_fragment: true, - tsn: 1, - stream_sequence_number: 0, - user_data: Bytes::from_static(b"ABC"), - ..Default::default() - }; - - let complete = rq.push(chunk); - assert!(!complete, "chunk set should not be complete yet"); - assert_eq!(rq.get_num_bytes(), 3, "num bytes mismatch"); - - let chunk = ChunkPayloadData { - payload_type: org_ppi, - unordered: true, - tsn: 2, - stream_sequence_number: 0, - user_data: Bytes::from_static(b"DEFG"), - ..Default::default() - }; - - let complete = rq.push(chunk); - assert!(!complete, "chunk set should not be complete yet"); - assert_eq!(rq.get_num_bytes(), 7, "num bytes mismatch"); - - let chunk = ChunkPayloadData { - payload_type: org_ppi, - unordered: true, - ending_fragment: true, - tsn: 3, - stream_sequence_number: 0, - user_data: Bytes::from_static(b"H"), - ..Default::default() - }; - - let complete = rq.push(chunk); - assert!(complete, "chunk set should be complete"); - assert_eq!(rq.get_num_bytes(), 8, "num bytes mismatch"); - - let mut buf = vec![0u8; 16]; - - let (n, ppi) = rq.read(&mut buf)?; - assert_eq!(n, 8, "should received 8 bytes"); - assert_eq!(rq.get_num_bytes(), 0, "num bytes mismatch"); - assert_eq!(ppi, org_ppi, "should have valid ppi"); - assert_eq!(&buf[..n], b"ABCDEFGH", "data should match"); - - Ok(()) -} - -#[test] -fn test_reassembly_queue_ordered_and_unordered_fragments() -> Result<()> { - let mut rq = ReassemblyQueue::new(0); - let org_ppi = PayloadProtocolIdentifier::Binary; - let chunk = ChunkPayloadData { - payload_type: org_ppi, - beginning_fragment: true, - ending_fragment: true, - tsn: 1, - stream_sequence_number: 0, - user_data: Bytes::from_static(b"ABC"), - ..Default::default() - }; - - let complete = rq.push(chunk); - assert!(complete, "chunk set should be complete"); - assert_eq!(rq.get_num_bytes(), 3, "num bytes mismatch"); - - let chunk = ChunkPayloadData { - payload_type: org_ppi, - unordered: true, - beginning_fragment: true, - ending_fragment: true, - tsn: 2, - stream_sequence_number: 1, - user_data: Bytes::from_static(b"DEF"), - ..Default::default() - }; - - let complete = rq.push(chunk); - assert!(complete, "chunk set should be complete"); - assert_eq!(rq.get_num_bytes(), 6, "num bytes mismatch"); - - // - // Now we have two complete chunks ready to read in the reassemblyQueue. 
- // - - let mut buf = vec![0u8; 16]; - - // Should read unordered chunks first - let (n, ppi) = rq.read(&mut buf)?; - assert_eq!(n, 3, "should received 3 bytes"); - assert_eq!(rq.get_num_bytes(), 3, "num bytes mismatch"); - assert_eq!(ppi, org_ppi, "should have valid ppi"); - assert_eq!(&buf[..n], b"DEF", "data should match"); - - // Next should read ordered chunks - let (n, ppi) = rq.read(&mut buf)?; - assert_eq!(n, 3, "should received 3 bytes"); - assert_eq!(rq.get_num_bytes(), 0, "num bytes mismatch"); - assert_eq!(ppi, org_ppi, "should have valid ppi"); - assert_eq!(&buf[..n], b"ABC", "data should match"); - - Ok(()) -} - -#[test] -fn test_reassembly_queue_unordered_complete_skips_incomplete() -> Result<()> { - let mut rq = ReassemblyQueue::new(0); - - let org_ppi = PayloadProtocolIdentifier::Binary; - - let chunk = ChunkPayloadData { - payload_type: org_ppi, - unordered: true, - beginning_fragment: true, - tsn: 10, - stream_sequence_number: 0, - user_data: Bytes::from_static(b"IN"), - ..Default::default() - }; - - let complete = rq.push(chunk); - assert!(!complete, "chunk set should not be complete yet"); - assert_eq!(2, rq.get_num_bytes(), "num bytes mismatch"); - - let chunk = ChunkPayloadData { - payload_type: org_ppi, - unordered: true, - ending_fragment: true, - tsn: 12, // <- incongiguous - stream_sequence_number: 1, - user_data: Bytes::from_static(b"COMPLETE"), - ..Default::default() - }; - - let complete = rq.push(chunk); - assert!(!complete, "chunk set should not be complete yet"); - assert_eq!(rq.get_num_bytes(), 10, "num bytes mismatch"); - - let chunk = ChunkPayloadData { - payload_type: org_ppi, - unordered: true, - beginning_fragment: true, - ending_fragment: true, - tsn: 13, - stream_sequence_number: 1, - user_data: Bytes::from_static(b"GOOD"), - ..Default::default() - }; - - let complete = rq.push(chunk); - assert!(complete, "chunk set should be complete"); - assert_eq!(rq.get_num_bytes(), 14, "num bytes mismatch"); - - // - // Now we have two complete chunks ready to read in the reassemblyQueue. 
- // - - let mut buf = vec![0u8; 16]; - - // Should pick the one that has "GOOD" - let (n, ppi) = rq.read(&mut buf)?; - assert_eq!(n, 4, "should receive 4 bytes"); - assert_eq!(rq.get_num_bytes(), 10, "num bytes mismatch"); - assert_eq!(ppi, org_ppi, "should have valid ppi"); - assert_eq!(&buf[..n], b"GOOD", "data should match"); - - Ok(()) -} - -#[test] -fn test_reassembly_queue_ignores_chunk_with_wrong_si() -> Result<()> { - let mut rq = ReassemblyQueue::new(123); - - let org_ppi = PayloadProtocolIdentifier::Binary; - - let chunk = ChunkPayloadData { - payload_type: org_ppi, - stream_identifier: 124, - beginning_fragment: true, - ending_fragment: true, - tsn: 10, - stream_sequence_number: 0, - user_data: Bytes::from_static(b"IN"), - ..Default::default() - }; - - let complete = rq.push(chunk); - assert!(!complete, "chunk should be ignored"); - assert_eq!(rq.get_num_bytes(), 0, "num bytes mismatch"); - Ok(()) -} - -#[test] -fn test_reassembly_queue_ignores_chunk_with_stale_ssn() -> Result<()> { - let mut rq = ReassemblyQueue::new(0); - rq.next_ssn = 7; // forcibly set expected SSN to 7 - - let org_ppi = PayloadProtocolIdentifier::Binary; - - let chunk = ChunkPayloadData { - payload_type: org_ppi, - beginning_fragment: true, - ending_fragment: true, - tsn: 10, - stream_sequence_number: 6, // <-- stale - user_data: Bytes::from_static(b"IN"), - ..Default::default() - }; - - let complete = rq.push(chunk); - assert!(!complete, "chunk should not be ignored"); - assert_eq!(rq.get_num_bytes(), 0, "num bytes mismatch"); - - Ok(()) -} - -#[test] -fn test_reassembly_queue_should_fail_to_read_incomplete_chunk() -> Result<()> { - let mut rq = ReassemblyQueue::new(0); - - let org_ppi = PayloadProtocolIdentifier::Binary; - - let chunk = ChunkPayloadData { - payload_type: org_ppi, - beginning_fragment: true, - tsn: 123, - stream_sequence_number: 0, - user_data: Bytes::from_static(b"IN"), - ..Default::default() - }; - - let complete = rq.push(chunk); - assert!(!complete, "the set should not be complete"); - assert_eq!(rq.get_num_bytes(), 2, "num bytes mismatch"); - - let mut buf = vec![0u8; 16]; - let result = rq.read(&mut buf); - assert!(result.is_err(), "read() should not succeed"); - assert_eq!(rq.get_num_bytes(), 2, "num bytes mismatch"); - - Ok(()) -} - -#[test] -fn test_reassembly_queue_should_fail_to_read_if_the_nex_ssn_is_not_ready() -> Result<()> { - let mut rq = ReassemblyQueue::new(0); - - let org_ppi = PayloadProtocolIdentifier::Binary; - - let chunk = ChunkPayloadData { - payload_type: org_ppi, - beginning_fragment: true, - ending_fragment: true, - tsn: 123, - stream_sequence_number: 1, - user_data: Bytes::from_static(b"IN"), - ..Default::default() - }; - - let complete = rq.push(chunk); - assert!(complete, "the set should be complete"); - assert_eq!(rq.get_num_bytes(), 2, "num bytes mismatch"); - - let mut buf = vec![0u8; 16]; - let result = rq.read(&mut buf); - assert!(result.is_err(), "read() should not succeed"); - assert_eq!(rq.get_num_bytes(), 2, "num bytes mismatch"); - - Ok(()) -} - -#[test] -fn test_reassembly_queue_detect_buffer_too_short() -> Result<()> { - let mut rq = ReassemblyQueue::new(0); - - let org_ppi = PayloadProtocolIdentifier::Binary; - - let chunk = ChunkPayloadData { - payload_type: org_ppi, - beginning_fragment: true, - ending_fragment: true, - tsn: 123, - stream_sequence_number: 0, - user_data: Bytes::from_static(b"0123456789"), - ..Default::default() - }; - - let complete = rq.push(chunk); - assert!(complete, "the set should be complete"); - 
assert_eq!(rq.get_num_bytes(), 10, "num bytes mismatch"); - - let mut buf = vec![0u8; 8]; // <- passing buffer too short - let result = rq.read(&mut buf); - assert!(result.is_err(), "read() should not succeed"); - if let Err(err) = result { - assert_eq!( - err, - Error::ErrShortBuffer { size: 8 }, - "read() should not succeed" - ); - } - assert_eq!(rq.get_num_bytes(), 0, "num bytes mismatch"); - - Ok(()) -} - -#[test] -fn test_reassembly_queue_forward_tsn_for_ordered_fragments() -> Result<()> { - let mut rq = ReassemblyQueue::new(0); - - let org_ppi = PayloadProtocolIdentifier::Binary; - - let ssn_complete = 5u16; - let ssn_dropped = 6u16; - - let chunk = ChunkPayloadData { - payload_type: org_ppi, - beginning_fragment: true, - ending_fragment: true, - tsn: 10, - stream_sequence_number: ssn_complete, - user_data: Bytes::from_static(b"123"), - ..Default::default() - }; - - let complete = rq.push(chunk); - assert!(complete, "chunk set should be complete"); - assert_eq!(rq.get_num_bytes(), 3, "num bytes mismatch"); - - let chunk = ChunkPayloadData { - payload_type: org_ppi, - beginning_fragment: true, - tsn: 11, - stream_sequence_number: ssn_dropped, - user_data: Bytes::from_static(b"ABC"), - ..Default::default() - }; - - let complete = rq.push(chunk); - assert!(!complete, "chunk set should not be complete yet"); - assert_eq!(rq.get_num_bytes(), 6, "num bytes mismatch"); - - let chunk = ChunkPayloadData { - payload_type: org_ppi, - tsn: 12, - stream_sequence_number: ssn_dropped, - user_data: Bytes::from_static(b"DEF"), - ..Default::default() - }; - - let complete = rq.push(chunk); - assert!(!complete, "chunk set should not be complete yet"); - assert_eq!(rq.get_num_bytes(), 9, "num bytes mismatch"); - - rq.forward_tsn_for_ordered(ssn_dropped); - - assert_eq!(rq.ordered.len(), 1, "there should be one chunk left"); - assert_eq!(rq.get_num_bytes(), 3, "num bytes mismatch"); - - Ok(()) -} - -#[test] -fn test_reassembly_queue_forward_tsn_for_unordered_fragments() -> Result<()> { - let mut rq = ReassemblyQueue::new(0); - - let org_ppi = PayloadProtocolIdentifier::Binary; - - let ssn_dropped = 6u16; - let ssn_kept = 7u16; - - let chunk = ChunkPayloadData { - payload_type: org_ppi, - unordered: true, - beginning_fragment: true, - tsn: 11, - stream_sequence_number: ssn_dropped, - user_data: Bytes::from_static(b"ABC"), - ..Default::default() - }; - - let complete = rq.push(chunk); - assert!(!complete, "chunk set should not be complete yet"); - assert_eq!(rq.get_num_bytes(), 3, "num bytes mismatch"); - - let chunk = ChunkPayloadData { - payload_type: org_ppi, - unordered: true, - tsn: 12, - stream_sequence_number: ssn_dropped, - user_data: Bytes::from_static(b"DEF"), - ..Default::default() - }; - - let complete = rq.push(chunk); - assert!(!complete, "chunk set should not be complete yet"); - assert_eq!(rq.get_num_bytes(), 6, "num bytes mismatch"); - - let chunk = ChunkPayloadData { - payload_type: org_ppi, - unordered: true, - tsn: 14, - beginning_fragment: true, - stream_sequence_number: ssn_kept, - user_data: Bytes::from_static(b"SOS"), - ..Default::default() - }; - - let complete = rq.push(chunk); - assert!(!complete, "chunk set should not be complete yet"); - assert_eq!(rq.get_num_bytes(), 9, "num bytes mismatch"); - - // At this point, there are 3 chunks in the rq.unorderedChunks. - // This call should remove chunks with tsn equals to 13 or older. 
- rq.forward_tsn_for_unordered(13); - - // As a result, there should be one chunk (tsn=14) - assert_eq!( - rq.unordered_chunks.len(), - 1, - "there should be one chunk kept" - ); - assert_eq!(rq.get_num_bytes(), 3, "num bytes mismatch"); - - Ok(()) -} - -#[test] -fn test_chunk_set_empty_chunk_set() -> Result<()> { - let cset = ChunkSet::new(0, PayloadProtocolIdentifier::default()); - assert!(!cset.is_complete(), "empty chunkSet cannot be complete"); - Ok(()) -} - -#[test] -fn test_chunk_set_push_dup_chunks_to_chunk_set() -> Result<()> { - let mut cset = ChunkSet::new(0, PayloadProtocolIdentifier::default()); - cset.push(ChunkPayloadData { - tsn: 100, - beginning_fragment: true, - ..Default::default() - }); - let complete = cset.push(ChunkPayloadData { - tsn: 100, - ending_fragment: true, - ..Default::default() - }); - assert!(!complete, "chunk with dup TSN is not complete"); - assert_eq!(cset.chunks.len(), 1, "chunk with dup TSN should be ignored"); - Ok(()) -} - -#[test] -fn test_chunk_set_incomplete_chunk_set_no_beginning() -> Result<()> { - let cset = ChunkSet { - ssn: 0, - ppi: PayloadProtocolIdentifier::default(), - chunks: vec![], - }; - assert!( - !cset.is_complete(), - "chunkSet not starting with B=1 cannot be complete" - ); - Ok(()) -} - -#[test] -fn test_chunk_set_incomplete_chunk_set_no_contiguous_tsn() -> Result<()> { - let cset = ChunkSet { - ssn: 0, - ppi: PayloadProtocolIdentifier::default(), - chunks: vec![ - ChunkPayloadData { - tsn: 100, - beginning_fragment: true, - ..Default::default() - }, - ChunkPayloadData { - tsn: 101, - ..Default::default() - }, - ChunkPayloadData { - tsn: 103, - ending_fragment: true, - ..Default::default() - }, - ], - }; - assert!( - !cset.is_complete(), - "chunkSet not starting with incontiguous tsn cannot be complete" - ); - Ok(()) -} diff --git a/sctp/src/queue/reassembly_queue.rs b/sctp/src/queue/reassembly_queue.rs deleted file mode 100644 index 3aafecd23..000000000 --- a/sctp/src/queue/reassembly_queue.rs +++ /dev/null @@ -1,353 +0,0 @@ -use std::cmp::Ordering; - -use crate::chunk::chunk_payload_data::{ChunkPayloadData, PayloadProtocolIdentifier}; -use crate::error::{Error, Result}; -use crate::util::*; - -fn sort_chunks_by_tsn(c: &mut [ChunkPayloadData]) { - c.sort_by(|a, b| { - if sna32lt(a.tsn, b.tsn) { - Ordering::Less - } else { - Ordering::Greater - } - }); -} - -fn sort_chunks_by_ssn(c: &mut [ChunkSet]) { - c.sort_by(|a, b| { - if sna16lt(a.ssn, b.ssn) { - Ordering::Less - } else { - Ordering::Greater - } - }); -} - -/// chunkSet is a set of chunks that share the same SSN -#[derive(Debug, Clone)] -pub(crate) struct ChunkSet { - /// used only with the ordered chunks - pub(crate) ssn: u16, - pub(crate) ppi: PayloadProtocolIdentifier, - pub(crate) chunks: Vec, -} - -impl ChunkSet { - pub(crate) fn new(ssn: u16, ppi: PayloadProtocolIdentifier) -> Self { - ChunkSet { - ssn, - ppi, - chunks: vec![], - } - } - - pub(crate) fn push(&mut self, chunk: ChunkPayloadData) -> bool { - // check if dup - for c in &self.chunks { - if c.tsn == chunk.tsn { - return false; - } - } - - // append and sort - self.chunks.push(chunk); - sort_chunks_by_tsn(&mut self.chunks); - - // Check if we now have a complete set - self.is_complete() - } - - pub(crate) fn is_complete(&self) -> bool { - // Condition for complete set - // 0. Has at least one chunk. - // 1. Begins with beginningFragment set to true - // 2. Ends with endingFragment set to true - // 3. TSN monotinically increase by 1 from beginning to end - - // 0. 
- let n_chunks = self.chunks.len(); - if n_chunks == 0 { - return false; - } - - // 1. - if !self.chunks[0].beginning_fragment { - return false; - } - - // 2. - if !self.chunks[n_chunks - 1].ending_fragment { - return false; - } - - // 3. - let mut last_tsn = 0u32; - for (i, c) in self.chunks.iter().enumerate() { - if i > 0 { - // Fragments must have contiguous TSN - // From RFC 4960 Section 3.3.1: - // When a user message is fragmented into multiple chunks, the TSNs are - // used by the receiver to reassemble the message. This means that the - // TSNs for each fragment of a fragmented user message MUST be strictly - // sequential. - if c.tsn != last_tsn + 1 { - // mid or end fragment is missing - return false; - } - } - - last_tsn = c.tsn; - } - - true - } -} - -#[derive(Default, Debug)] -pub(crate) struct ReassemblyQueue { - pub(crate) si: u16, - pub(crate) next_ssn: u16, - /// expected SSN for next ordered chunk - pub(crate) ordered: Vec, - pub(crate) unordered: Vec, - pub(crate) unordered_chunks: Vec, - pub(crate) n_bytes: usize, -} - -impl ReassemblyQueue { - /// From RFC 4960 Sec 6.5: - /// The Stream Sequence Number in all the streams MUST start from 0 when - /// the association is Established. Also, when the Stream Sequence - /// Number reaches the value 65535 the next Stream Sequence Number MUST - /// be set to 0. - pub(crate) fn new(si: u16) -> Self { - ReassemblyQueue { - si, - next_ssn: 0, // From RFC 4960 Sec 6.5: - ordered: vec![], - unordered: vec![], - unordered_chunks: vec![], - n_bytes: 0, - } - } - - pub(crate) fn push(&mut self, chunk: ChunkPayloadData) -> bool { - if chunk.stream_identifier != self.si { - return false; - } - - if chunk.unordered { - // First, insert into unordered_chunks array - //atomic.AddUint64(&r.n_bytes, uint64(len(chunk.userData))) - self.n_bytes += chunk.user_data.len(); - self.unordered_chunks.push(chunk); - sort_chunks_by_tsn(&mut self.unordered_chunks); - - // Scan unordered_chunks that are contiguous (in TSN) - // If found, append the complete set to the unordered array - if let Some(cset) = self.find_complete_unordered_chunk_set() { - self.unordered.push(cset); - return true; - } - - false - } else { - // This is an ordered chunk - if sna16lt(chunk.stream_sequence_number, self.next_ssn) { - return false; - } - - self.n_bytes += chunk.user_data.len(); - - // Check if a chunkSet with the SSN already exists - for s in &mut self.ordered { - if s.ssn == chunk.stream_sequence_number { - return s.push(chunk); - } - } - - // If not found, create a new chunkSet - let mut cset = ChunkSet::new(chunk.stream_sequence_number, chunk.payload_type); - let unordered = chunk.unordered; - let ok = cset.push(chunk); - self.ordered.push(cset); - if !unordered { - sort_chunks_by_ssn(&mut self.ordered); - } - - ok - } - } - - pub(crate) fn find_complete_unordered_chunk_set(&mut self) -> Option { - let mut start_idx = -1isize; - let mut n_chunks = 0usize; - let mut last_tsn = 0u32; - let mut found = false; - - for (i, c) in self.unordered_chunks.iter().enumerate() { - // seek beigining - if c.beginning_fragment { - start_idx = i as isize; - n_chunks = 1; - last_tsn = c.tsn; - - if c.ending_fragment { - found = true; - break; - } - continue; - } - - if start_idx < 0 { - continue; - } - - // Check if contiguous in TSN - if c.tsn != last_tsn + 1 { - start_idx = -1; - continue; - } - - last_tsn = c.tsn; - n_chunks += 1; - - if c.ending_fragment { - found = true; - break; - } - } - - if !found { - return None; - } - - // Extract the range of chunks - let chunks: Vec = 
self - .unordered_chunks - .drain(start_idx as usize..(start_idx as usize) + n_chunks) - .collect(); - - let mut chunk_set = ChunkSet::new(0, chunks[0].payload_type); - chunk_set.chunks = chunks; - - Some(chunk_set) - } - - pub(crate) fn is_readable(&self) -> bool { - // Check unordered first - if !self.unordered.is_empty() { - // The chunk sets in r.unordered should all be complete. - return true; - } - - // Check ordered sets - if !self.ordered.is_empty() { - let cset = &self.ordered[0]; - if cset.is_complete() && sna16lte(cset.ssn, self.next_ssn) { - return true; - } - } - false - } - - pub(crate) fn read(&mut self, buf: &mut [u8]) -> Result<(usize, PayloadProtocolIdentifier)> { - // Check unordered first - let cset = if !self.unordered.is_empty() { - self.unordered.remove(0) - } else if !self.ordered.is_empty() { - // Now, check ordered - let cset = &self.ordered[0]; - if !cset.is_complete() { - return Err(Error::ErrTryAgain); - } - if sna16gt(cset.ssn, self.next_ssn) { - return Err(Error::ErrTryAgain); - } - if cset.ssn == self.next_ssn { - // From RFC 4960 Sec 6.5: - self.next_ssn = self.next_ssn.wrapping_add(1); - } - self.ordered.remove(0) - } else { - return Err(Error::ErrTryAgain); - }; - - // Concat all fragments into the buffer - let mut n_written = 0; - let mut err = None; - for c in &cset.chunks { - let to_copy = c.user_data.len(); - self.subtract_num_bytes(to_copy); - if err.is_none() { - let n = std::cmp::min(to_copy, buf.len() - n_written); - buf[n_written..n_written + n].copy_from_slice(&c.user_data[..n]); - n_written += n; - if n < to_copy { - err = Some(Error::ErrShortBuffer { size: buf.len() }); - } - } - } - - if let Some(err) = err { - Err(err) - } else { - Ok((n_written, cset.ppi)) - } - } - - /// Use last_ssn to locate a chunkSet then remove it if the set has - /// not been complete - pub(crate) fn forward_tsn_for_ordered(&mut self, last_ssn: u16) { - let num_bytes = self - .ordered - .iter() - .filter(|s| sna16lte(s.ssn, last_ssn) && !s.is_complete()) - .fold(0, |n, s| { - n + s.chunks.iter().fold(0, |acc, c| acc + c.user_data.len()) - }); - self.subtract_num_bytes(num_bytes); - - self.ordered - .retain(|s| !sna16lte(s.ssn, last_ssn) || s.is_complete()); - - // Finally, forward next_ssn - if sna16lte(self.next_ssn, last_ssn) { - self.next_ssn = last_ssn.wrapping_add(1); - } - } - - /// Remove all fragments in the unordered sets that contains chunks - /// equal to or older than `new_cumulative_tsn`. - /// We know all sets in the r.unordered are complete ones. 
- /// Just remove chunks that are equal to or older than new_cumulative_tsn - /// from the unordered_chunks - pub(crate) fn forward_tsn_for_unordered(&mut self, new_cumulative_tsn: u32) { - let mut last_idx: isize = -1; - for (i, c) in self.unordered_chunks.iter().enumerate() { - if sna32gt(c.tsn, new_cumulative_tsn) { - break; - } - last_idx = i as isize; - } - if last_idx >= 0 { - for i in 0..(last_idx + 1) as usize { - self.subtract_num_bytes(self.unordered_chunks[i].user_data.len()); - } - self.unordered_chunks.drain(..(last_idx + 1) as usize); - } - } - - pub(crate) fn subtract_num_bytes(&mut self, n_bytes: usize) { - if self.n_bytes >= n_bytes { - self.n_bytes -= n_bytes; - } else { - self.n_bytes = 0; - } - } - - pub(crate) fn get_num_bytes(&self) -> usize { - self.n_bytes - } -} diff --git a/sctp/src/stream/mod.rs b/sctp/src/stream/mod.rs deleted file mode 100644 index 1ae1980c1..000000000 --- a/sctp/src/stream/mod.rs +++ /dev/null @@ -1,852 +0,0 @@ -#[cfg(test)] -mod stream_test; - -use std::future::Future; -use std::net::Shutdown; -use std::pin::Pin; -use std::sync::atomic::Ordering; -use std::sync::Arc; -use std::task::{Context, Poll}; -use std::{fmt, io}; - -use arc_swap::ArcSwapOption; -use bytes::Bytes; -use portable_atomic::{AtomicBool, AtomicU16, AtomicU32, AtomicU8, AtomicUsize}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use tokio::sync::{mpsc, Mutex, Notify}; - -use crate::association::AssociationState; -use crate::chunk::chunk_payload_data::{ChunkPayloadData, PayloadProtocolIdentifier}; -use crate::error::{Error, Result}; -use crate::queue::pending_queue::PendingQueue; -use crate::queue::reassembly_queue::ReassemblyQueue; - -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] -#[repr(C)] -pub enum ReliabilityType { - /// ReliabilityTypeReliable is used for reliable transmission - #[default] - Reliable = 0, - /// ReliabilityTypeRexmit is used for partial reliability by retransmission count - Rexmit = 1, - /// ReliabilityTypeTimed is used for partial reliability by retransmission duration - Timed = 2, -} - -impl fmt::Display for ReliabilityType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - ReliabilityType::Reliable => "Reliable", - ReliabilityType::Rexmit => "Rexmit", - ReliabilityType::Timed => "Timed", - }; - write!(f, "{s}") - } -} - -impl From for ReliabilityType { - fn from(v: u8) -> ReliabilityType { - match v { - 1 => ReliabilityType::Rexmit, - 2 => ReliabilityType::Timed, - _ => ReliabilityType::Reliable, - } - } -} - -pub type OnBufferedAmountLowFn = - Box Pin + Send + 'static>>) + Send + Sync>; - -// TODO: benchmark performance between multiple Atomic+Mutex vs one Mutex - -/// Stream represents an SCTP stream -#[derive(Default)] -pub struct Stream { - pub(crate) max_payload_size: u32, - pub(crate) max_message_size: Arc, // clone from association - pub(crate) state: Arc, // clone from association - pub(crate) awake_write_loop_ch: Option>>, - pub(crate) pending_queue: Arc, - - pub(crate) stream_identifier: u16, - pub(crate) default_payload_type: AtomicU32, //PayloadProtocolIdentifier, - pub(crate) reassembly_queue: Mutex, - pub(crate) sequence_number: AtomicU16, - pub(crate) read_notifier: Notify, - pub(crate) read_shutdown: AtomicBool, - pub(crate) write_shutdown: AtomicBool, - pub(crate) unordered: AtomicBool, - pub(crate) reliability_type: AtomicU8, //ReliabilityType, - pub(crate) reliability_value: AtomicU32, - pub(crate) buffered_amount: AtomicUsize, - pub(crate) buffered_amount_low: AtomicUsize, - 
pub(crate) on_buffered_amount_low: ArcSwapOption>, - pub(crate) name: String, -} - -impl fmt::Debug for Stream { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Stream") - .field("max_payload_size", &self.max_payload_size) - .field("max_message_size", &self.max_message_size) - .field("state", &self.state) - .field("awake_write_loop_ch", &self.awake_write_loop_ch) - .field("stream_identifier", &self.stream_identifier) - .field("default_payload_type", &self.default_payload_type) - .field("reassembly_queue", &self.reassembly_queue) - .field("sequence_number", &self.sequence_number) - .field("read_shutdown", &self.read_shutdown) - .field("write_shutdown", &self.write_shutdown) - .field("unordered", &self.unordered) - .field("reliability_type", &self.reliability_type) - .field("reliability_value", &self.reliability_value) - .field("buffered_amount", &self.buffered_amount) - .field("buffered_amount_low", &self.buffered_amount_low) - .field("name", &self.name) - .finish() - } -} - -impl Stream { - pub(crate) fn new( - name: String, - stream_identifier: u16, - max_payload_size: u32, - max_message_size: Arc, - state: Arc, - awake_write_loop_ch: Option>>, - pending_queue: Arc, - ) -> Self { - Stream { - max_payload_size, - max_message_size, - state, - awake_write_loop_ch, - pending_queue, - - stream_identifier, - default_payload_type: AtomicU32::new(0), //PayloadProtocolIdentifier::Unknown, - reassembly_queue: Mutex::new(ReassemblyQueue::new(stream_identifier)), - sequence_number: AtomicU16::new(0), - read_notifier: Notify::new(), - read_shutdown: AtomicBool::new(false), - write_shutdown: AtomicBool::new(false), - unordered: AtomicBool::new(false), - reliability_type: AtomicU8::new(0), //ReliabilityType::Reliable, - reliability_value: AtomicU32::new(0), - buffered_amount: AtomicUsize::new(0), - buffered_amount_low: AtomicUsize::new(0), - on_buffered_amount_low: ArcSwapOption::empty(), - name, - } - } - - /// stream_identifier returns the Stream identifier associated to the stream. - pub fn stream_identifier(&self) -> u16 { - self.stream_identifier - } - - /// set_default_payload_type sets the default payload type used by write. - pub fn set_default_payload_type(&self, default_payload_type: PayloadProtocolIdentifier) { - self.default_payload_type - .store(default_payload_type as u32, Ordering::SeqCst); - } - - /// set_reliability_params sets reliability parameters for this stream. - pub fn set_reliability_params(&self, unordered: bool, rel_type: ReliabilityType, rel_val: u32) { - log::debug!( - "[{}] reliability params: ordered={} type={} value={}", - self.name, - !unordered, - rel_type, - rel_val - ); - self.unordered.store(unordered, Ordering::SeqCst); - self.reliability_type - .store(rel_type as u8, Ordering::SeqCst); - self.reliability_value.store(rel_val, Ordering::SeqCst); - } - - /// Reads a packet of len(p) bytes, dropping the Payload Protocol Identifier. - /// - /// Returns `Error::ErrShortBuffer` if `p` is too short. - /// Returns `0` if the reading half of this stream is shutdown or it (the stream) was reset. - pub async fn read(&self, p: &mut [u8]) -> Result { - let (n, _) = self.read_sctp(p).await?; - Ok(n) - } - - /// Reads a packet of len(p) bytes and returns the associated Payload Protocol Identifier. - /// - /// Returns `Error::ErrShortBuffer` if `p` is too short. - /// Returns `(0, PayloadProtocolIdentifier::Unknown)` if the reading half of this stream is shutdown or it (the stream) was reset. 
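A minimal usage sketch of the read API documented above. The `stream` handle (an `Arc<Stream>` on an established association) is assumed here and is not part of this file:

    let mut buf = vec![0u8; 64 * 1024];
    let (n, ppi) = stream.read_sctp(&mut buf).await?;
    if n == 0 {
        // the reading half was shut down or the stream was reset
    } else {
        // process &buf[..n] according to `ppi`
    }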
- pub async fn read_sctp(&self, p: &mut [u8]) -> Result<(usize, PayloadProtocolIdentifier)> { - loop { - if self.read_shutdown.load(Ordering::SeqCst) { - return Ok((0, PayloadProtocolIdentifier::Unknown)); - } - - let result = { - let mut reassembly_queue = self.reassembly_queue.lock().await; - reassembly_queue.read(p) - }; - - match result { - Ok(_) | Err(Error::ErrShortBuffer { .. }) => return result, - Err(_) => { - // wait for the next chunk to become available - self.read_notifier.notified().await; - } - } - } - } - - pub(crate) async fn handle_data(&self, pd: ChunkPayloadData) { - let readable = { - let mut reassembly_queue = self.reassembly_queue.lock().await; - if reassembly_queue.push(pd) { - let readable = reassembly_queue.is_readable(); - log::debug!("[{}] reassemblyQueue readable={}", self.name, readable); - readable - } else { - false - } - }; - - if readable { - log::debug!("[{}] readNotifier.signal()", self.name); - self.read_notifier.notify_one(); - log::debug!("[{}] readNotifier.signal() done", self.name); - } - } - - pub(crate) async fn handle_forward_tsn_for_ordered(&self, ssn: u16) { - if self.unordered.load(Ordering::SeqCst) { - return; // unordered chunks are handled by handleForwardUnordered method - } - - // Remove all chunks older than or equal to the new TSN from - // the reassembly_queue. - let readable = { - let mut reassembly_queue = self.reassembly_queue.lock().await; - reassembly_queue.forward_tsn_for_ordered(ssn); - reassembly_queue.is_readable() - }; - - // Notify the reader asynchronously if there's a data chunk to read. - if readable { - self.read_notifier.notify_one(); - } - } - - pub(crate) async fn handle_forward_tsn_for_unordered(&self, new_cumulative_tsn: u32) { - if !self.unordered.load(Ordering::SeqCst) { - return; // ordered chunks are handled by handleForwardTSNOrdered method - } - - // Remove all chunks older than or equal to the new TSN from - // the reassembly_queue. - let readable = { - let mut reassembly_queue = self.reassembly_queue.lock().await; - reassembly_queue.forward_tsn_for_unordered(new_cumulative_tsn); - reassembly_queue.is_readable() - }; - - // Notify the reader asynchronously if there's a data chunk to read. - if readable { - self.read_notifier.notify_one(); - } - } - - /// Writes `p` to the DTLS connection with the default Payload Protocol Identifier. - /// - /// Returns an error if the write half of this stream is shutdown or `p` is too large. - pub async fn write(&self, p: &Bytes) -> Result { - self.write_sctp(p, self.default_payload_type.load(Ordering::SeqCst).into()) - .await - } - - /// Writes `p` to the DTLS connection with the given Payload Protocol Identifier. - /// - /// Returns an error if the write half of this stream is shutdown or `p` is too large. 
- pub async fn write_sctp(&self, p: &Bytes, ppi: PayloadProtocolIdentifier) -> Result { - let chunks = self.prepare_write(p, ppi)?; - self.send_payload_data(chunks).await?; - - Ok(p.len()) - } - - /// common stuff for write and try_write - fn prepare_write( - &self, - p: &Bytes, - ppi: PayloadProtocolIdentifier, - ) -> Result> { - if self.write_shutdown.load(Ordering::SeqCst) { - return Err(Error::ErrStreamClosed); - } - - if p.len() > self.max_message_size.load(Ordering::SeqCst) as usize { - return Err(Error::ErrOutboundPacketTooLarge); - } - - let state: AssociationState = self.state.load(Ordering::SeqCst).into(); - match state { - AssociationState::ShutdownSent - | AssociationState::ShutdownAckSent - | AssociationState::ShutdownPending - | AssociationState::ShutdownReceived => return Err(Error::ErrStreamClosed), - _ => {} - }; - - Ok(self.packetize(p, ppi)) - } - - fn packetize(&self, raw: &Bytes, ppi: PayloadProtocolIdentifier) -> Vec { - let mut i = 0; - let mut remaining = raw.len(); - - // From draft-ietf-rtcweb-data-protocol-09, section 6: - // All Data Channel Establishment Protocol messages MUST be sent using - // ordered delivery and reliable transmission. - let unordered = - ppi != PayloadProtocolIdentifier::Dcep && self.unordered.load(Ordering::SeqCst); - - let mut chunks = vec![]; - - let head_abandoned = Arc::new(AtomicBool::new(false)); - let head_all_inflight = Arc::new(AtomicBool::new(false)); - while remaining != 0 { - let fragment_size = std::cmp::min(self.max_payload_size as usize, remaining); //self.association.max_payload_size - - // Copy the userdata since we'll have to store it until acked - // and the caller may re-use the buffer in the mean time - let user_data = raw.slice(i..i + fragment_size); - - let chunk = ChunkPayloadData { - stream_identifier: self.stream_identifier, - user_data, - unordered, - beginning_fragment: i == 0, - ending_fragment: remaining - fragment_size == 0, - immediate_sack: false, - payload_type: ppi, - stream_sequence_number: self.sequence_number.load(Ordering::SeqCst), - abandoned: head_abandoned.clone(), // all fragmented chunks use the same abandoned - all_inflight: head_all_inflight.clone(), // all fragmented chunks use the same all_inflight - ..Default::default() - }; - - chunks.push(chunk); - - remaining -= fragment_size; - i += fragment_size; - } - - // RFC 4960 Sec 6.6 - // Note: When transmitting ordered and unordered data, an endpoint does - // not increment its Stream Sequence Number when transmitting a DATA - // chunk with U flag set to 1. - if !unordered { - self.sequence_number.fetch_add(1, Ordering::SeqCst); - } - - let old_value = self.buffered_amount.fetch_add(raw.len(), Ordering::SeqCst); - log::trace!("[{}] bufferedAmount = {}", self.name, old_value + raw.len()); - - chunks - } - - /// Closes both read and write halves of this stream. - /// - /// Use [`Stream::shutdown`] instead. - #[deprecated] - pub async fn close(&self) -> Result<()> { - self.shutdown(Shutdown::Both).await - } - - /// Shuts down the read, write, or both halves of this stream. - /// - /// This function will cause all pending and future I/O on the specified portions to return - /// immediately with an appropriate value (see the documentation of [`Shutdown`]). - /// - /// Resets the stream when both halves of this stream are shutdown. 
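A sketch of the shutdown semantics described above (same hypothetical `stream` handle; error handling elided). Shutting down only the write half keeps reads working; once both halves are shut down, the stream queues a reset request:

    stream.shutdown(Shutdown::Write).await?; // writes now fail, reads still work
    stream.shutdown(Shutdown::Read).await?;  // both halves closed -> stream reset is requested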
- pub async fn shutdown(&self, how: Shutdown) -> Result<()> { - if self.read_shutdown.load(Ordering::SeqCst) && self.write_shutdown.load(Ordering::SeqCst) { - return Ok(()); - } - - if how == Shutdown::Write || how == Shutdown::Both { - self.write_shutdown.store(true, Ordering::SeqCst); - } - - if (how == Shutdown::Read || how == Shutdown::Both) - && !self.read_shutdown.swap(true, Ordering::SeqCst) - { - self.read_notifier.notify_waiters(); - } - - if how == Shutdown::Both - || (self.read_shutdown.load(Ordering::SeqCst) - && self.write_shutdown.load(Ordering::SeqCst)) - { - // Reset the stream - // https://tools.ietf.org/html/rfc6525 - self.send_reset_request(self.stream_identifier).await?; - } - - Ok(()) - } - - /// buffered_amount returns the number of bytes of data currently queued to be sent over this stream. - pub fn buffered_amount(&self) -> usize { - self.buffered_amount.load(Ordering::SeqCst) - } - - /// buffered_amount_low_threshold returns the number of bytes of buffered outgoing data that is - /// considered "low." Defaults to 0. - pub fn buffered_amount_low_threshold(&self) -> usize { - self.buffered_amount_low.load(Ordering::SeqCst) - } - - /// set_buffered_amount_low_threshold is used to update the threshold. - /// See buffered_amount_low_threshold(). - pub fn set_buffered_amount_low_threshold(&self, th: usize) { - self.buffered_amount_low.store(th, Ordering::SeqCst); - } - - /// on_buffered_amount_low sets the callback handler which would be called when the number of - /// bytes of outgoing data buffered is lower than the threshold. - pub fn on_buffered_amount_low(&self, f: OnBufferedAmountLowFn) { - self.on_buffered_amount_low - .store(Some(Arc::new(Mutex::new(f)))); - } - - /// This method is called by association's read_loop (go-)routine to notify this stream - /// of the specified amount of outgoing data has been delivered to the peer. - pub(crate) async fn on_buffer_released(&self, n_bytes_released: i64) { - if n_bytes_released <= 0 { - return; - } - - let from_amount = self.buffered_amount.load(Ordering::SeqCst); - let new_amount = if from_amount < n_bytes_released as usize { - self.buffered_amount.store(0, Ordering::SeqCst); - log::error!( - "[{}] released buffer size {} should be <= {}", - self.name, - n_bytes_released, - 0, - ); - 0 - } else { - self.buffered_amount - .fetch_sub(n_bytes_released as usize, Ordering::SeqCst); - - from_amount - n_bytes_released as usize - }; - - let buffered_amount_low = self.buffered_amount_low.load(Ordering::SeqCst); - - log::trace!( - "[{}] bufferedAmount = {}, from_amount = {}, buffered_amount_low = {}", - self.name, - new_amount, - from_amount, - buffered_amount_low, - ); - - if from_amount > buffered_amount_low && new_amount <= buffered_amount_low { - if let Some(handler) = &*self.on_buffered_amount_low.load() { - let mut f = handler.lock().await; - f().await; - } - } - } - - /// get_num_bytes_in_reassembly_queue returns the number of bytes of data currently queued to - /// be read (once chunk is complete). - pub(crate) async fn get_num_bytes_in_reassembly_queue(&self) -> usize { - // No lock is required as it reads the size with atomic load function. - let reassembly_queue = self.reassembly_queue.lock().await; - reassembly_queue.get_num_bytes() - } - - /// get_state atomically returns the state of the Association. 
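A sketch of the buffered-amount-low callback wiring described above (hypothetical `stream: Arc<Stream>`; the 64 KiB threshold is an arbitrary example). The callback fires when on_buffer_released() moves buffered_amount from above the threshold to at or below it:

    stream.set_buffered_amount_low_threshold(64 * 1024);
    stream.on_buffered_amount_low(Box::new(move || {
        Box::pin(async move {
            // buffered_amount() is now at or below the threshold; queue more data here
        })
    }));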
- fn get_state(&self) -> AssociationState { - self.state.load(Ordering::SeqCst).into() - } - - fn awake_write_loop(&self) { - //log::debug!("[{}] awake_write_loop_ch.notify_one", self.name); - if let Some(awake_write_loop_ch) = &self.awake_write_loop_ch { - let _ = awake_write_loop_ch.try_send(()); - } - } - - async fn send_payload_data(&self, chunks: Vec) -> Result<()> { - let state = self.get_state(); - if state != AssociationState::Established { - return Err(Error::ErrPayloadDataStateNotExist); - } - - // NOTE: append is used here instead of push in order to prevent chunks interlacing. - self.pending_queue.append(chunks).await; - - self.awake_write_loop(); - Ok(()) - } - - async fn send_reset_request(&self, stream_identifier: u16) -> Result<()> { - let state = self.get_state(); - if state != AssociationState::Established { - return Err(Error::ErrResetPacketInStateNotExist); - } - - // Create DATA chunk which only contains valid stream identifier with - // nil userData and use it as a EOS from the stream. - let c = ChunkPayloadData { - stream_identifier, - beginning_fragment: true, - ending_fragment: true, - user_data: Bytes::new(), - ..Default::default() - }; - - self.pending_queue.push(c).await; - - self.awake_write_loop(); - Ok(()) - } -} - -/// Default capacity of the temporary read buffer used by [`PollStream`]. -const DEFAULT_READ_BUF_SIZE: usize = 8192; - -/// State of the read `Future` in [`PollStream`]. -enum ReadFut { - /// Nothing in progress. - Idle, - /// Reading data from the underlying stream. - Reading(Pin>> + Send>>), - /// Finished reading, but there's unread data in the temporary buffer. - RemainingData(Vec), -} - -enum ShutdownFut { - /// Nothing in progress. - Idle, - /// Reading data from the underlying stream. - ShuttingDown(Pin>>>), - /// Shutdown future has run - Done, - Errored(crate::error::Error), -} - -impl ReadFut { - /// Gets a mutable reference to the future stored inside `Reading(future)`. - /// - /// # Panics - /// - /// Panics if `ReadFut` variant is not `Reading`. - fn get_reading_mut(&mut self) -> &mut Pin>> + Send>> { - match self { - ReadFut::Reading(ref mut fut) => fut, - _ => panic!("expected ReadFut to be Reading"), - } - } -} - -impl ShutdownFut { - /// Gets a mutable reference to the future stored inside `ShuttingDown(future)`. - /// - /// # Panics - /// - /// Panics if `ShutdownFut` variant is not `ShuttingDown`. - fn get_shutting_down_mut( - &mut self, - ) -> &mut Pin>>> { - match self { - ShutdownFut::ShuttingDown(ref mut fut) => fut, - _ => panic!("expected ShutdownFut to be ShuttingDown"), - } - } -} - -/// A wrapper around around [`Stream`], which implements [`AsyncRead`] and -/// [`AsyncWrite`]. -/// -/// Both `poll_read` and `poll_write` calls allocate temporary buffers, which results in an -/// additional overhead. -pub struct PollStream { - stream: Arc, - - read_fut: ReadFut, - write_fut: Option>>>>, - shutdown_fut: ShutdownFut, - - read_buf_cap: usize, -} - -impl PollStream { - /// Constructs a new `PollStream`. - /// - /// # Examples - /// - /// ``` - /// use webrtc_sctp::stream::{Stream, PollStream}; - /// use std::sync::Arc; - /// - /// let stream = Arc::new(Stream::default()); - /// let poll_stream = PollStream::new(stream); - /// ``` - pub fn new(stream: Arc) -> Self { - Self { - stream, - read_fut: ReadFut::Idle, - write_fut: None, - shutdown_fut: ShutdownFut::Idle, - read_buf_cap: DEFAULT_READ_BUF_SIZE, - } - } - - /// Get back the inner stream. 
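Because PollStream implements AsyncRead and AsyncWrite, it can be driven with the usual tokio I/O helpers. A sketch assuming a `poll_stream` built as in the constructor example above (the peer and payload are hypothetical):

    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    poll_stream.write_all(b"ping").await?;
    let mut buf = [0u8; 4];
    poll_stream.read_exact(&mut buf).await?;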
- #[must_use] - pub fn into_inner(self) -> Arc { - self.stream - } - - /// Obtain a clone of the inner stream. - #[must_use] - pub fn clone_inner(&self) -> Arc { - self.stream.clone() - } - - /// stream_identifier returns the Stream identifier associated to the stream. - pub fn stream_identifier(&self) -> u16 { - self.stream.stream_identifier - } - - /// buffered_amount returns the number of bytes of data currently queued to be sent over this stream. - pub fn buffered_amount(&self) -> usize { - self.stream.buffered_amount.load(Ordering::SeqCst) - } - - /// buffered_amount_low_threshold returns the number of bytes of buffered outgoing data that is - /// considered "low." Defaults to 0. - pub fn buffered_amount_low_threshold(&self) -> usize { - self.stream.buffered_amount_low.load(Ordering::SeqCst) - } - - /// get_num_bytes_in_reassembly_queue returns the number of bytes of data currently queued to - /// be read (once chunk is complete). - pub(crate) async fn get_num_bytes_in_reassembly_queue(&self) -> usize { - // No lock is required as it reads the size with atomic load function. - let reassembly_queue = self.stream.reassembly_queue.lock().await; - reassembly_queue.get_num_bytes() - } - - /// Set the capacity of the temporary read buffer (default: 8192). - pub fn set_read_buf_capacity(&mut self, capacity: usize) { - self.read_buf_cap = capacity - } -} - -impl AsyncRead for PollStream { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - if buf.remaining() == 0 { - return Poll::Ready(Ok(())); - } - - let fut = match self.read_fut { - ReadFut::Idle => { - // read into a temporary buffer because `buf` has an unonymous lifetime, which can - // be shorter than the lifetime of `read_fut`. - let stream = self.stream.clone(); - let mut temp_buf = vec![0; self.read_buf_cap]; - self.read_fut = ReadFut::Reading(Box::pin(async move { - stream.read(temp_buf.as_mut_slice()).await.map(|n| { - temp_buf.truncate(n); - temp_buf - }) - })); - self.read_fut.get_reading_mut() - } - ReadFut::Reading(ref mut fut) => fut, - ReadFut::RemainingData(ref mut data) => { - let remaining = buf.remaining(); - let len = std::cmp::min(data.len(), remaining); - buf.put_slice(&data[..len]); - if data.len() > remaining { - // ReadFut remains to be RemainingData - data.drain(0..len); - } else { - self.read_fut = ReadFut::Idle; - } - return Poll::Ready(Ok(())); - } - }; - - loop { - match fut.as_mut().poll(cx) { - Poll::Pending => return Poll::Pending, - // retry immediately upon empty data or incomplete chunks - // since there's no way to setup a waker. 
- Poll::Ready(Err(Error::ErrTryAgain)) => {} - // EOF has been reached => don't touch buf and just return Ok - Poll::Ready(Err(Error::ErrEof)) => { - self.read_fut = ReadFut::Idle; - return Poll::Ready(Ok(())); - } - Poll::Ready(Err(e)) => { - self.read_fut = ReadFut::Idle; - return Poll::Ready(Err(e.into())); - } - Poll::Ready(Ok(mut temp_buf)) => { - let remaining = buf.remaining(); - let len = std::cmp::min(temp_buf.len(), remaining); - buf.put_slice(&temp_buf[..len]); - if temp_buf.len() > remaining { - temp_buf.drain(0..len); - self.read_fut = ReadFut::RemainingData(temp_buf); - } else { - self.read_fut = ReadFut::Idle; - } - return Poll::Ready(Ok(())); - } - } - } - } -} - -impl AsyncWrite for PollStream { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - if buf.is_empty() { - return Poll::Ready(Ok(0)); - } - - if let Some(fut) = self.write_fut.as_mut() { - match fut.as_mut().poll(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(Err(e)) => { - let stream = self.stream.clone(); - let bytes = Bytes::copy_from_slice(buf); - self.write_fut = Some(Box::pin(async move { stream.write(&bytes).await })); - Poll::Ready(Err(e.into())) - } - // Given the data is buffered, it's okay to ignore the number of written bytes. - // - // TODO: In the long term, `stream.write` should be made sync. Then we could - // remove the whole `if` condition and just call `stream.write`. - Poll::Ready(Ok(_)) => { - let stream = self.stream.clone(); - let bytes = Bytes::copy_from_slice(buf); - self.write_fut = Some(Box::pin(async move { stream.write(&bytes).await })); - Poll::Ready(Ok(buf.len())) - } - } - } else { - let stream = self.stream.clone(); - let bytes = Bytes::copy_from_slice(buf); - let fut = self - .write_fut - .insert(Box::pin(async move { stream.write(&bytes).await })); - - match fut.as_mut().poll(cx) { - // If it's the first time we're polling the future, `Poll::Pending` can't be - // returned because that would mean the `PollStream` is not ready for writing. And - // this is not true since we've just created a future, which is going to write the - // buf to the underlying stream. - // - // It's okay to return `Poll::Ready` if the data is buffered (this is what the - // buffered writer and `File` do). 
- Poll::Pending => Poll::Ready(Ok(buf.len())), - Poll::Ready(Err(e)) => { - self.write_fut = None; - Poll::Ready(Err(e.into())) - } - Poll::Ready(Ok(n)) => { - self.write_fut = None; - Poll::Ready(Ok(n)) - } - } - } - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.write_fut.as_mut() { - Some(fut) => match fut.as_mut().poll(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(Err(e)) => { - self.write_fut = None; - Poll::Ready(Err(e.into())) - } - Poll::Ready(Ok(_)) => { - self.write_fut = None; - Poll::Ready(Ok(())) - } - }, - None => Poll::Ready(Ok(())), - } - } - - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.as_mut().poll_flush(cx) { - Poll::Pending => return Poll::Pending, - Poll::Ready(_) => {} - } - let fut = match self.shutdown_fut { - ShutdownFut::Done => return Poll::Ready(Ok(())), - ShutdownFut::Errored(ref err) => return Poll::Ready(Err(err.clone().into())), - ShutdownFut::ShuttingDown(ref mut fut) => fut, - ShutdownFut::Idle => { - let stream = self.stream.clone(); - self.shutdown_fut = ShutdownFut::ShuttingDown(Box::pin(async move { - stream.shutdown(Shutdown::Write).await - })); - self.shutdown_fut.get_shutting_down_mut() - } - }; - - match fut.as_mut().poll(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(Err(e)) => { - self.shutdown_fut = ShutdownFut::Errored(e.clone()); - Poll::Ready(Err(e.into())) - } - Poll::Ready(Ok(_)) => { - self.shutdown_fut = ShutdownFut::Done; - Poll::Ready(Ok(())) - } - } - } -} - -impl Clone for PollStream { - fn clone(&self) -> PollStream { - PollStream::new(self.clone_inner()) - } -} - -impl fmt::Debug for PollStream { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("PollStream") - .field("stream", &self.stream) - .field("read_buf_cap", &self.read_buf_cap) - .finish() - } -} - -impl AsRef for PollStream { - fn as_ref(&self) -> &Stream { - &self.stream - } -} diff --git a/sctp/src/stream/stream_test.rs b/sctp/src/stream/stream_test.rs deleted file mode 100644 index 59aaa5ec3..000000000 --- a/sctp/src/stream/stream_test.rs +++ /dev/null @@ -1,215 +0,0 @@ -use std::sync::atomic::Ordering; -use std::sync::Arc; - -use portable_atomic::AtomicU32; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; - -use super::*; - -#[test] -fn test_stream_buffered_amount() -> Result<()> { - let s = Stream::default(); - - assert_eq!(s.buffered_amount(), 0); - assert_eq!(s.buffered_amount_low_threshold(), 0); - - s.buffered_amount.store(8192, Ordering::SeqCst); - s.set_buffered_amount_low_threshold(2048); - assert_eq!(s.buffered_amount(), 8192, "unexpected bufferedAmount"); - assert_eq!( - s.buffered_amount_low_threshold(), - 2048, - "unexpected threshold" - ); - - Ok(()) -} - -#[tokio::test] -async fn test_stream_amount_on_buffered_amount_low() -> Result<()> { - let s = Stream::default(); - - s.buffered_amount.store(4096, Ordering::SeqCst); - s.set_buffered_amount_low_threshold(2048); - - let n_cbs = Arc::new(AtomicU32::new(0)); - let n_cbs2 = n_cbs.clone(); - - s.on_buffered_amount_low(Box::new(move || { - n_cbs2.fetch_add(1, Ordering::SeqCst); - Box::pin(async {}) - })); - - // Negative value should be ignored (by design) - s.on_buffer_released(-32).await; // bufferedAmount = 3072 - assert_eq!(s.buffered_amount(), 4096, "unexpected bufferedAmount"); - assert_eq!(n_cbs.load(Ordering::SeqCst), 0, "callback count mismatch"); - - // Above to above, no callback - s.on_buffer_released(1024).await; // bufferedAmount = 3072 - 
    assert_eq!(s.buffered_amount(), 3072, "unexpected bufferedAmount");
-    assert_eq!(n_cbs.load(Ordering::SeqCst), 0, "callback count mismatch");
-
-    // Above to equal, callback should be made
-    s.on_buffer_released(1024).await; // bufferedAmount = 2048
-    assert_eq!(s.buffered_amount(), 2048, "unexpected bufferedAmount");
-    assert_eq!(n_cbs.load(Ordering::SeqCst), 1, "callback count mismatch");
-
-    // Equal to below, no callback
-    s.on_buffer_released(1024).await; // bufferedAmount = 1024
-    assert_eq!(s.buffered_amount(), 1024, "unexpected bufferedAmount");
-    assert_eq!(n_cbs.load(Ordering::SeqCst), 1, "callback count mismatch");
-
-    // Below to below, no callback
-    s.on_buffer_released(1024).await; // bufferedAmount = 0
-    assert_eq!(s.buffered_amount(), 0, "unexpected bufferedAmount");
-    assert_eq!(n_cbs.load(Ordering::SeqCst), 1, "callback count mismatch");
-
-    // Capped at 0, no callback
-    s.on_buffer_released(1024).await; // bufferedAmount = 0
-    assert_eq!(s.buffered_amount(), 0, "unexpected bufferedAmount");
-    assert_eq!(n_cbs.load(Ordering::SeqCst), 1, "callback count mismatch");
-
-    Ok(())
-}
-
-#[tokio::test]
-async fn test_stream() -> std::result::Result<(), io::Error> {
-    let s = Stream::new(
-        "test_stream".to_owned(),
-        0,
-        4096,
-        Arc::new(AtomicU32::new(4096)),
-        Arc::new(AtomicU8::new(AssociationState::Established as u8)),
-        None,
-        Arc::new(PendingQueue::new()),
-    );
-
-    // getters
-    assert_eq!(s.stream_identifier(), 0);
-    assert_eq!(s.buffered_amount(), 0);
-    assert_eq!(s.buffered_amount_low_threshold(), 0);
-    assert_eq!(s.get_num_bytes_in_reassembly_queue().await, 0);
-
-    // setters
-    s.set_default_payload_type(PayloadProtocolIdentifier::Binary);
-    s.set_reliability_params(true, ReliabilityType::Reliable, 0);
-
-    // write
-    let n = s.write(&Bytes::from("Hello ")).await?;
-    assert_eq!(n, 6);
-    assert_eq!(s.buffered_amount(), 6);
-    let n = s
-        .write_sctp(&Bytes::from("world"), PayloadProtocolIdentifier::Binary)
-        .await?;
-    assert_eq!(n, 5);
-    assert_eq!(s.buffered_amount(), 11);
-
-    // async read
-    // 1. pretend that we've received a chunk
-    s.handle_data(ChunkPayloadData {
-        unordered: true,
-        beginning_fragment: true,
-        ending_fragment: true,
-        user_data: Bytes::from_static(&[0, 1, 2, 3, 4]),
-        payload_type: PayloadProtocolIdentifier::Binary,
-        ..Default::default()
-    })
-    .await;
-    // 2.
read it - let mut buf = [0; 5]; - s.read(&mut buf).await?; - assert_eq!(buf, [0, 1, 2, 3, 4]); - - // shutdown write - s.shutdown(Shutdown::Write).await?; - // write must fail - assert!(s.write(&Bytes::from("error")).await.is_err()); - // read should continue working - s.handle_data(ChunkPayloadData { - unordered: true, - beginning_fragment: true, - ending_fragment: true, - user_data: Bytes::from_static(&[5, 6, 7, 8, 9]), - payload_type: PayloadProtocolIdentifier::Binary, - ..Default::default() - }) - .await; - let mut buf = [0; 5]; - s.read(&mut buf).await?; - assert_eq!(buf, [5, 6, 7, 8, 9]); - - // shutdown read - s.shutdown(Shutdown::Read).await?; - // read must return 0 - assert_eq!(s.read(&mut buf).await, Ok(0)); - - Ok(()) -} - -#[tokio::test] -async fn test_poll_stream() -> std::result::Result<(), io::Error> { - let s = Arc::new(Stream::new( - "test_poll_stream".to_owned(), - 0, - 4096, - Arc::new(AtomicU32::new(4096)), - Arc::new(AtomicU8::new(AssociationState::Established as u8)), - None, - Arc::new(PendingQueue::new()), - )); - let mut poll_stream = PollStream::new(s.clone()); - - // getters - assert_eq!(poll_stream.stream_identifier(), 0); - assert_eq!(poll_stream.buffered_amount(), 0); - assert_eq!(poll_stream.buffered_amount_low_threshold(), 0); - assert_eq!(poll_stream.get_num_bytes_in_reassembly_queue().await, 0); - - // async write - let n = poll_stream.write(&[1, 2, 3]).await?; - assert_eq!(n, 3); - poll_stream.flush().await?; - assert_eq!(poll_stream.buffered_amount(), 3); - - // async read - // 1. pretend that we've received a chunk - let sc = s.clone(); - sc.handle_data(ChunkPayloadData { - unordered: true, - beginning_fragment: true, - ending_fragment: true, - user_data: Bytes::from_static(&[0, 1, 2, 3, 4]), - payload_type: PayloadProtocolIdentifier::Binary, - ..Default::default() - }) - .await; - // 2. read it - let mut buf = [0; 5]; - poll_stream.read_exact(&mut buf).await?; - assert_eq!(buf, [0, 1, 2, 3, 4]); - - // shutdown write - poll_stream.shutdown().await?; - // write must fail - assert!(poll_stream.write(&[1, 2, 3]).await.is_err()); - // read should continue working - sc.handle_data(ChunkPayloadData { - unordered: true, - beginning_fragment: true, - ending_fragment: true, - user_data: Bytes::from_static(&[5, 6, 7, 8, 9]), - payload_type: PayloadProtocolIdentifier::Binary, - ..Default::default() - }) - .await; - let mut buf = [0; 5]; - poll_stream.read_exact(&mut buf).await?; - assert_eq!(buf, [5, 6, 7, 8, 9]); - - // misc. - let clone = poll_stream.clone(); - assert_eq!(clone.stream_identifier(), poll_stream.stream_identifier()); - - Ok(()) -} diff --git a/sctp/src/timer/ack_timer.rs b/sctp/src/timer/ack_timer.rs deleted file mode 100644 index 9a3d8074f..000000000 --- a/sctp/src/timer/ack_timer.rs +++ /dev/null @@ -1,74 +0,0 @@ -use std::sync::Weak; - -use async_trait::async_trait; -use tokio::sync::{mpsc, Mutex}; -use tokio::time::Duration; - -pub(crate) const ACK_INTERVAL: Duration = Duration::from_millis(200); - -/// ackTimerObserver is the interface to an ack timer observer. -#[async_trait] -pub(crate) trait AckTimerObserver { - async fn on_ack_timeout(&mut self); -} - -/// ackTimer provides the retnransmission timer conforms with RFC 4960 Sec 6.3.1 -#[derive(Default, Debug)] -pub(crate) struct AckTimer { - pub(crate) timeout_observer: Weak>, - pub(crate) interval: Duration, - pub(crate) close_tx: Option>, -} - -impl AckTimer { - /// newAckTimer creates a new acknowledgement timer used to enable delayed ack. 
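// Behaviour of the timer defined below (summarized from the implementation): start()
// spawns a one-shot task that calls on_ack_timeout() after `interval` (ACK_INTERVAL,
// 200 ms, in production) unless stop() is called first; a second start() returns false
// until stop() is called, and stop() drops the close channel, cancelling any pending sleep.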
- pub(crate) fn new(timeout_observer: Weak>, interval: Duration) -> Self { - AckTimer { - timeout_observer, - interval, - close_tx: None, - } - } - - /// start starts the timer. - pub(crate) fn start(&mut self) -> bool { - // this timer is already closed - if self.close_tx.is_some() { - return false; - } - - let (close_tx, mut close_rx) = mpsc::channel(1); - let interval = self.interval; - let timeout_observer = self.timeout_observer.clone(); - - tokio::spawn(async move { - let timer = tokio::time::sleep(interval); - tokio::pin!(timer); - - tokio::select! { - _ = timer.as_mut() => { - if let Some(observer) = timeout_observer.upgrade(){ - let mut observer = observer.lock().await; - observer.on_ack_timeout().await; - } - } - _ = close_rx.recv() => {}, - } - }); - - self.close_tx = Some(close_tx); - true - } - - /// stops the timer. this is similar to stop() but subsequent start() call - /// will fail (the timer is no longer usable) - pub(crate) fn stop(&mut self) { - self.close_tx.take(); - } - - /// isRunning tests if the timer is running. - /// Debug purpose only - pub(crate) fn is_running(&self) -> bool { - self.close_tx.is_some() - } -} diff --git a/sctp/src/timer/mod.rs b/sctp/src/timer/mod.rs deleted file mode 100644 index 822246260..000000000 --- a/sctp/src/timer/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -#[cfg(test)] -mod timer_test; - -pub(crate) mod ack_timer; -pub(crate) mod rtx_timer; diff --git a/sctp/src/timer/rtx_timer.rs b/sctp/src/timer/rtx_timer.rs deleted file mode 100644 index ab8ad66dc..000000000 --- a/sctp/src/timer/rtx_timer.rs +++ /dev/null @@ -1,208 +0,0 @@ -use std::sync::{Arc, Weak}; - -use async_trait::async_trait; -use tokio::sync::{mpsc, Mutex}; -use tokio::time::Duration; - -use crate::association::RtxTimerId; - -pub(crate) const RTO_INITIAL: u64 = 3000; // msec -pub(crate) const RTO_MIN: u64 = 1000; // msec -pub(crate) const RTO_MAX: u64 = 60000; // msec -pub(crate) const RTO_ALPHA: u64 = 1; -pub(crate) const RTO_BETA: u64 = 2; -pub(crate) const RTO_BASE: u64 = 8; -pub(crate) const MAX_INIT_RETRANS: usize = 8; -pub(crate) const PATH_MAX_RETRANS: usize = 5; -pub(crate) const NO_MAX_RETRANS: usize = 0; - -/// rtoManager manages Rtx timeout values. -/// This is an implementation of RFC 4960 sec 6.3.1. -#[derive(Default, Debug)] -pub(crate) struct RtoManager { - pub(crate) srtt: u64, - pub(crate) rttvar: f64, - pub(crate) rto: u64, - pub(crate) no_update: bool, -} - -impl RtoManager { - /// newRTOManager creates a new rtoManager. - pub(crate) fn new() -> Self { - RtoManager { - rto: RTO_INITIAL, - ..Default::default() - } - } - - /// set_new_rtt takes a newly measured RTT then adjust the RTO in msec. - pub(crate) fn set_new_rtt(&mut self, rtt: u64) -> u64 { - if self.no_update { - return self.srtt; - } - - if self.srtt == 0 { - // First measurement - self.srtt = rtt; - self.rttvar = rtt as f64 / 2.0; - } else { - // Subsequent rtt measurement - self.rttvar = ((RTO_BASE - RTO_BETA) as f64 * self.rttvar - + RTO_BETA as f64 * (self.srtt as i64 - rtt as i64).abs() as f64) - / RTO_BASE as f64; - self.srtt = ((RTO_BASE - RTO_ALPHA) * self.srtt + RTO_ALPHA * rtt) / RTO_BASE; - } - - self.rto = (self.srtt + (4.0 * self.rttvar) as u64).clamp(RTO_MIN, RTO_MAX); - - self.srtt - } - - /// get_rto simply returns the current RTO in msec. - pub(crate) fn get_rto(&self) -> u64 { - self.rto - } - - /// reset resets the RTO variables to the initial values. 
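// Worked example of the set_new_rtt() computation above, with alpha = RTO_ALPHA / RTO_BASE = 1/8
// and beta = RTO_BETA / RTO_BASE = 1/4: a first measured RTT of 600 ms gives srtt = 600,
// rttvar = 300 and rto = 600 + 4 * 300 = 1800 ms; a second 600 ms sample gives
// rttvar = 0.75 * 300 + 0.25 * |600 - 600| = 225 and rto = 600 + 4 * 225 = 1500 ms.
// These are the first two values expected by test_rto_manager_rto_calculation_small_rtt.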
-    pub(crate) fn reset(&mut self) {
-        if self.no_update {
-            return;
-        }
-
-        self.srtt = 0;
-        self.rttvar = 0.0;
-        self.rto = RTO_INITIAL;
-    }
-
-    /// set RTO value for testing
-    pub(crate) fn set_rto(&mut self, rto: u64, no_update: bool) {
-        self.rto = rto;
-        self.no_update = no_update;
-    }
-}
-
-pub(crate) fn calculate_next_timeout(rto: u64, n_rtos: usize) -> u64 {
-    // RFC 4960 sec 6.3.3. Handle T3-rtx Expiration
-    // E2) For the destination address for which the timer expires, set RTO
-    //     <- RTO * 2 ("back off the timer"). The maximum value discussed
-    //     in rule C7 above (RTO.max) may be used to provide an upper bound
-    //     to this doubling operation.
-    if n_rtos < 31 {
-        std::cmp::min(rto << n_rtos, RTO_MAX)
-    } else {
-        RTO_MAX
-    }
-}
-
-/// rtxTimerObserver is the interface to a timer observer.
-/// NOTE: Observers MUST NOT call start() or stop() method on rtxTimer
-/// from within these callbacks.
-#[async_trait]
-pub(crate) trait RtxTimerObserver {
-    async fn on_retransmission_timeout(&mut self, timer_id: RtxTimerId, n: usize);
-    async fn on_retransmission_failure(&mut self, timer_id: RtxTimerId);
-}
-
-/// rtxTimer provides the retransmission timer that conforms with RFC 4960 Sec 6.3.1
-#[derive(Default, Debug)]
-pub(crate) struct RtxTimer {
-    pub(crate) timeout_observer: Weak>,
-    pub(crate) id: RtxTimerId,
-    pub(crate) max_retrans: usize,
-    pub(crate) close_tx: Arc>>>,
-}
-
-impl RtxTimer {
-    /// newRTXTimer creates a new retransmission timer.
-    /// if max_retrans is set to 0, it will keep retransmitting until stop() is called
-    /// (it will never make the on_retransmission_failure() callback).
-    pub(crate) fn new(
-        timeout_observer: Weak>,
-        id: RtxTimerId,
-        max_retrans: usize,
-    ) -> Self {
-        RtxTimer {
-            timeout_observer,
-            id,
-            max_retrans,
-            close_tx: Arc::new(Mutex::new(None)),
-        }
-    }
-
-    /// start starts the timer.
-    pub(crate) async fn start(&self, rto: u64) -> bool {
-        // Note: rto value is intentionally not capped by RTO.Min to allow
-        // fast timeout for the tests. Non-test code should pass in the
-        // rto generated by rtoManager get_rto() method which caps the
-        // value at RTO.Min or at RTO.Max.
-
-        // this timer is already running
-        let mut close_rx = {
-            let mut close = self.close_tx.lock().await;
-            if close.is_some() {
-                return false;
-            }
-
-            let (close_tx, close_rx) = mpsc::channel(1);
-            *close = Some(close_tx);
-            close_rx
-        };
-
-        let id = self.id;
-        let max_retrans = self.max_retrans;
-        let close_tx = Arc::clone(&self.close_tx);
-        let timeout_observer = self.timeout_observer.clone();
-
-        tokio::spawn(async move {
-            let mut n_rtos = 0;
-
-            loop {
-                let interval = calculate_next_timeout(rto, n_rtos);
-                let timer = tokio::time::sleep(Duration::from_millis(interval));
-                tokio::pin!(timer);
-
-                tokio::select! {
-                    _ = timer.as_mut() => {
-                        n_rtos += 1;
-
-                        let failure = {
-                            if let Some(observer) = timeout_observer.upgrade() {
-                                let mut observer = observer.lock().await;
-                                if max_retrans == 0 || n_rtos <= max_retrans {
-                                    observer.on_retransmission_timeout(id, n_rtos).await;
-                                    false
-                                } else {
-                                    observer.on_retransmission_failure(id).await;
-                                    true
-                                }
-                            } else {
-                                true
-                            }
-                        };
-                        if failure {
-                            let mut close = close_tx.lock().await;
-                            *close = None;
-                            break;
-                        }
-                    }
-                    _ = close_rx.recv() => break,
-                }
-            }
-        });
-
-        true
-    }
-
-    /// stop stops the timer.
-    pub(crate) async fn stop(&self) {
-        let mut close_tx = self.close_tx.lock().await;
-        close_tx.take();
-    }
-
-    /// isRunning tests if the timer is running.
- /// Debug purpose only - pub(crate) async fn is_running(&self) -> bool { - let close_tx = self.close_tx.lock().await; - close_tx.is_some() - } -} diff --git a/sctp/src/timer/timer_test.rs b/sctp/src/timer/timer_test.rs deleted file mode 100644 index d7afdf300..000000000 --- a/sctp/src/timer/timer_test.rs +++ /dev/null @@ -1,511 +0,0 @@ -// Silence warning on `for i in 0..vec.len() { โ€ฆ }`: -#![allow(clippy::needless_range_loop)] -// Silence warning on `..Default::default()` with no effect: -#![allow(clippy::needless_update)] - -use std::sync::atomic::Ordering; -use std::sync::Arc; - -use async_trait::async_trait; -use portable_atomic::AtomicU32; -use tokio::sync::Mutex; -use tokio::time::{sleep, Duration}; - -/////////////////////////////////////////////////////////////////// -//ack_timer_test -/////////////////////////////////////////////////////////////////// -use super::ack_timer::*; - -mod test_ack_timer { - use super::*; - use crate::error::Result; - - struct TestAckTimerObserver { - ncbs: Arc, - } - - #[async_trait] - impl AckTimerObserver for TestAckTimerObserver { - async fn on_ack_timeout(&mut self) { - log::trace!("ack timed out"); - self.ncbs.fetch_add(1, Ordering::SeqCst); - } - } - - #[tokio::test] - async fn test_ack_timer_start_and_stop() -> Result<()> { - let ncbs = Arc::new(AtomicU32::new(0)); - let obs = Arc::new(Mutex::new(TestAckTimerObserver { ncbs: ncbs.clone() })); - - let mut rt = AckTimer::new(Arc::downgrade(&obs), ACK_INTERVAL); - - // should start ok - let ok = rt.start(); - assert!(ok, "start() should succeed"); - assert!(rt.is_running(), "should be running"); - - // stop immedidately - rt.stop(); - assert!(!rt.is_running(), "should not be running"); - - // Sleep more than 200msec of interval to test if it never times out - sleep(ACK_INTERVAL + Duration::from_millis(50)).await; - - assert_eq!( - ncbs.load(Ordering::SeqCst), - 0, - "should not be timed out (actual: {})", - ncbs.load(Ordering::SeqCst) - ); - - // can start again - let ok = rt.start(); - assert!(ok, "start() should succeed again"); - assert!(rt.is_running(), "should be running"); - - // should close ok - rt.stop(); - assert!(!rt.is_running(), "should not be running"); - - Ok(()) - } -} - -/////////////////////////////////////////////////////////////////// -//rtx_timer_test -/////////////////////////////////////////////////////////////////// -use super::rtx_timer::*; - -mod test_rto_manager { - use super::*; - use crate::error::Result; - - #[tokio::test] - async fn test_rto_manager_initial_values() -> Result<()> { - let m = RtoManager::new(); - assert_eq!(m.rto, RTO_INITIAL, "should be rtoInitial"); - assert_eq!(m.get_rto(), RTO_INITIAL, "should be rtoInitial"); - assert_eq!(m.srtt, 0, "should be 0"); - assert_eq!(m.rttvar, 0.0, "should be 0.0"); - - Ok(()) - } - - #[tokio::test] - async fn test_rto_manager_rto_calculation_small_rtt() -> Result<()> { - let mut m = RtoManager::new(); - let exp = [1800, 1500, 1275, 1106, 1000]; - - for i in 0..5 { - m.set_new_rtt(600); - let rto = m.get_rto(); - assert_eq!(rto, exp[i], "should be equal: {i}"); - } - - Ok(()) - } - - #[tokio::test] - async fn test_rto_manager_rto_calculation_large_rtt() -> Result<()> { - let mut m = RtoManager::new(); - let exp = [ - 60000, // capped at RTO.Max - 60000, // capped at RTO.Max - 60000, // capped at RTO.Max - 55312, 48984, - ]; - - for i in 0..5 { - m.set_new_rtt(30000); - let rto = m.get_rto(); - assert_eq!(rto, exp[i], "should be equal: {i}"); - } - - Ok(()) - } - - #[tokio::test] - async fn 
test_rto_manager_calculate_next_timeout() -> Result<()> { - let rto = calculate_next_timeout(1, 0); - assert_eq!(rto, 1, "should match"); - let rto = calculate_next_timeout(1, 1); - assert_eq!(rto, 2, "should match"); - let rto = calculate_next_timeout(1, 2); - assert_eq!(rto, 4, "should match"); - let rto = calculate_next_timeout(1, 30); - assert_eq!(rto, 60000, "should match"); - let rto = calculate_next_timeout(1, 63); - assert_eq!(rto, 60000, "should match"); - let rto = calculate_next_timeout(1, 64); - assert_eq!(rto, 60000, "should match"); - - Ok(()) - } - - #[tokio::test] - async fn test_rto_manager_reset() -> Result<()> { - let mut m = RtoManager::new(); - for _ in 0..10 { - m.set_new_rtt(200); - } - - m.reset(); - assert_eq!(m.get_rto(), RTO_INITIAL, "should be rtoInitial"); - assert_eq!(m.srtt, 0, "should be 0"); - assert_eq!(m.rttvar, 0.0, "should be 0"); - - Ok(()) - } -} - -//TODO: remove this conditional test -#[cfg(not(any(target_os = "macos", target_os = "windows")))] -mod test_rtx_timer { - use std::time::SystemTime; - - use tokio::sync::mpsc; - - use super::*; - use crate::association::RtxTimerId; - use crate::error::Result; - - struct TestTimerObserver { - ncbs: Arc, - timer_id: RtxTimerId, - done_tx: Option>, - max_rtos: usize, - } - - impl Default for TestTimerObserver { - fn default() -> Self { - TestTimerObserver { - ncbs: Arc::new(AtomicU32::new(0)), - timer_id: RtxTimerId::T1Init, - done_tx: None, - max_rtos: 0, - } - } - } - - #[async_trait] - impl RtxTimerObserver for TestTimerObserver { - async fn on_retransmission_timeout(&mut self, timer_id: RtxTimerId, n_rtos: usize) { - self.ncbs.fetch_add(1, Ordering::SeqCst); - // 30 : 1 (30) - // 60 : 2 (90) - // 120: 3 (210) - // 240: 4 (550) <== expected in 650 msec - assert_eq!(self.timer_id, timer_id, "unexpected timer ID: {timer_id}"); - if (self.max_rtos > 0 && n_rtos == self.max_rtos) || self.max_rtos == usize::MAX { - if let Some(done) = &self.done_tx { - let elapsed = SystemTime::now(); - let _ = done.send(elapsed).await; - } - } - } - - async fn on_retransmission_failure(&mut self, timer_id: RtxTimerId) { - if self.max_rtos == 0 { - if let Some(done) = &self.done_tx { - assert_eq!(self.timer_id, timer_id, "unexpted timer ID: {timer_id}"); - let elapsed = SystemTime::now(); - //t.Logf("onRtxFailure: elapsed=%.03f\n", elapsed) - let _ = done.send(elapsed).await; - } - } else { - panic!("timer should not fail"); - } - } - } - - #[tokio::test] - async fn test_rtx_timer_callback_interval() -> Result<()> { - let timer_id = RtxTimerId::T1Init; - let ncbs = Arc::new(AtomicU32::new(0)); - let obs = Arc::new(Mutex::new(TestTimerObserver { - ncbs: ncbs.clone(), - timer_id, - ..Default::default() - })); - let rt = RtxTimer::new(Arc::downgrade(&obs), timer_id, PATH_MAX_RETRANS); - - assert!(!rt.is_running().await, "should not be running"); - - // since := time.Now() - let ok = rt.start(30).await; - assert!(ok, "should be true"); - assert!(rt.is_running().await, "should be running"); - - sleep(Duration::from_millis(650)).await; - rt.stop().await; - assert!(!rt.is_running().await, "should not be running"); - - assert_eq!(ncbs.load(Ordering::SeqCst), 4, "should be called 4 times"); - - Ok(()) - } - - #[tokio::test] - async fn test_rtx_timer_last_start_wins() -> Result<()> { - let timer_id = RtxTimerId::T3RTX; - let ncbs = Arc::new(AtomicU32::new(0)); - let obs = Arc::new(Mutex::new(TestTimerObserver { - ncbs: ncbs.clone(), - timer_id, - ..Default::default() - })); - let rt = RtxTimer::new(Arc::downgrade(&obs), timer_id, 
PATH_MAX_RETRANS); - - let interval = 30; - let ok = rt.start(interval).await; - assert!(ok, "should be accepted"); - let ok = rt.start(interval * 99).await; // should ignored - assert!(!ok, "should be ignored"); - let ok = rt.start(interval * 99).await; // should ignored - assert!(!ok, "should be ignored"); - - sleep(Duration::from_millis((interval * 3) / 2)).await; - rt.stop().await; - - assert!(!rt.is_running().await, "should not be running"); - assert_eq!(ncbs.load(Ordering::SeqCst), 1, "must be called once"); - - Ok(()) - } - - #[tokio::test] - async fn test_rtx_timer_stop_right_after_start() -> Result<()> { - let timer_id = RtxTimerId::T3RTX; - let ncbs = Arc::new(AtomicU32::new(0)); - let obs = Arc::new(Mutex::new(TestTimerObserver { - ncbs: ncbs.clone(), - timer_id, - ..Default::default() - })); - let rt = RtxTimer::new(Arc::downgrade(&obs), timer_id, PATH_MAX_RETRANS); - - let interval = 30; - let ok = rt.start(interval).await; - assert!(ok, "should be accepted"); - rt.stop().await; - - sleep(Duration::from_millis((interval * 3) / 2)).await; - rt.stop().await; - - assert!(!rt.is_running().await, "should not be running"); - assert_eq!(ncbs.load(Ordering::SeqCst), 0, "no callback should be made"); - - Ok(()) - } - - #[tokio::test] - async fn test_rtx_timer_start_stop_then_start() -> Result<()> { - let timer_id = RtxTimerId::T1Cookie; - let ncbs = Arc::new(AtomicU32::new(0)); - let obs = Arc::new(Mutex::new(TestTimerObserver { - ncbs: ncbs.clone(), - timer_id, - ..Default::default() - })); - let rt = RtxTimer::new(Arc::downgrade(&obs), timer_id, PATH_MAX_RETRANS); - - let interval = 30; - let ok = rt.start(interval).await; - assert!(ok, "should be accepted"); - rt.stop().await; - assert!(!rt.is_running().await, "should NOT be running"); - let ok = rt.start(interval).await; - assert!(ok, "should be accepted"); - assert!(rt.is_running().await, "should be running"); - - sleep(Duration::from_millis((interval * 3) / 2)).await; - rt.stop().await; - - assert!(!rt.is_running().await, "should NOT be running"); - assert_eq!(ncbs.load(Ordering::SeqCst), 1, "must be called once"); - - Ok(()) - } - - #[tokio::test] - async fn test_rtx_timer_start_and_stop_in_atight_loop() -> Result<()> { - let timer_id = RtxTimerId::T2Shutdown; - let ncbs = Arc::new(AtomicU32::new(0)); - let obs = Arc::new(Mutex::new(TestTimerObserver { - ncbs: ncbs.clone(), - timer_id, - ..Default::default() - })); - let rt = RtxTimer::new(Arc::downgrade(&obs), timer_id, PATH_MAX_RETRANS); - - for _ in 0..1000 { - let ok = rt.start(30).await; - assert!(ok, "should be accepted"); - assert!(rt.is_running().await, "should be running"); - rt.stop().await; - assert!(!rt.is_running().await, "should NOT be running"); - } - - assert_eq!(ncbs.load(Ordering::SeqCst), 0, "no callback should be made"); - - Ok(()) - } - - #[tokio::test] - async fn test_rtx_timer_should_stop_after_rtx_failure() -> Result<()> { - let (done_tx, mut done_rx) = mpsc::channel(1); - - let timer_id = RtxTimerId::Reconfig; - let ncbs = Arc::new(AtomicU32::new(0)); - let obs = Arc::new(Mutex::new(TestTimerObserver { - ncbs: ncbs.clone(), - timer_id, - done_tx: Some(done_tx), - ..Default::default() - })); - - let since = SystemTime::now(); - let rt = RtxTimer::new(Arc::downgrade(&obs), timer_id, PATH_MAX_RETRANS); - - // RTO(msec) Total(msec) - // 10 10 1st RTO - // 20 30 2nd RTO - // 40 70 3rd RTO - // 80 150 4th RTO - // 160 310 5th RTO (== Path.Max.Retrans) - // 320 630 Failure - - let interval = 10; - let ok = rt.start(interval).await; - assert!(ok, "should be 
accepted"); - assert!(rt.is_running().await, "should be running"); - - let elapsed = done_rx.recv().await; - - assert!(!rt.is_running().await, "should not be running"); - assert_eq!(ncbs.load(Ordering::SeqCst), 5, "should be called 5 times"); - - if let Some(elapsed) = elapsed { - let diff = elapsed.duration_since(since).unwrap(); - assert!( - diff > Duration::from_millis(600), - "must have taken more than 600 msec" - ); - assert!( - diff < Duration::from_millis(700), - "must fail in less than 700 msec" - ); - } - - Ok(()) - } - - #[tokio::test] - async fn test_rtx_timer_should_not_stop_if_max_retrans_is_zero() -> Result<()> { - let (done_tx, mut done_rx) = mpsc::channel(1); - - let timer_id = RtxTimerId::Reconfig; - let max_rtos = 6; - let ncbs = Arc::new(AtomicU32::new(0)); - let obs = Arc::new(Mutex::new(TestTimerObserver { - ncbs: ncbs.clone(), - timer_id, - done_tx: Some(done_tx), - max_rtos, - ..Default::default() - })); - - let since = SystemTime::now(); - let rt = RtxTimer::new(Arc::downgrade(&obs), timer_id, 0); - - // RTO(msec) Total(msec) - // 10 10 1st RTO - // 20 30 2nd RTO - // 40 70 3rd RTO - // 80 150 4th RTO - // 160 310 5th RTO - // 320 630 6th RTO => exit test (timer should still be running) - - let interval = 10; - let ok = rt.start(interval).await; - assert!(ok, "should be accepted"); - assert!(rt.is_running().await, "should be running"); - - let elapsed = done_rx.recv().await; - - assert!(rt.is_running().await, "should still be running"); - assert_eq!(ncbs.load(Ordering::SeqCst), 6, "should be called 6 times"); - - if let Some(elapsed) = elapsed { - let diff = elapsed.duration_since(since).unwrap(); - assert!( - diff > Duration::from_millis(600), - "must have taken more than 600 msec" - ); - assert!( - diff < Duration::from_millis(700), - "must fail in less than 700 msec" - ); - } - - rt.stop().await; - - Ok(()) - } - - #[tokio::test] - async fn test_rtx_timer_stop_timer_that_is_not_running_is_noop() -> Result<()> { - let (done_tx, mut done_rx) = mpsc::channel(1); - - let timer_id = RtxTimerId::Reconfig; - let obs = Arc::new(Mutex::new(TestTimerObserver { - timer_id, - done_tx: Some(done_tx), - max_rtos: usize::MAX, - ..Default::default() - })); - let rt = RtxTimer::new(Arc::downgrade(&obs), timer_id, PATH_MAX_RETRANS); - - for _ in 0..10 { - rt.stop().await; - } - - let ok = rt.start(20).await; - assert!(ok, "should be accepted"); - assert!(rt.is_running().await, "must be running"); - - let _ = done_rx.recv().await; - rt.stop().await; - assert!(!rt.is_running().await, "must be false"); - - Ok(()) - } - - #[tokio::test] - async fn test_rtx_timer_closed_timer_wont_start() -> Result<()> { - let timer_id = RtxTimerId::Reconfig; - let ncbs = Arc::new(AtomicU32::new(0)); - let obs = Arc::new(Mutex::new(TestTimerObserver { - ncbs: ncbs.clone(), - timer_id, - ..Default::default() - })); - let rt = RtxTimer::new(Arc::downgrade(&obs), timer_id, PATH_MAX_RETRANS); - - let ok = rt.start(20).await; - assert!(ok, "should be accepted"); - assert!(rt.is_running().await, "must be running"); - - rt.stop().await; - assert!(!rt.is_running().await, "must be false"); - - //let ok = rt.start(obs.clone(), 20).await; - //assert!(!ok, "should not start"); - assert!(!rt.is_running().await, "must not be running"); - - sleep(Duration::from_millis(100)).await; - assert_eq!(ncbs.load(Ordering::SeqCst), 0, "RTO should not occur"); - - Ok(()) - } -} diff --git a/sctp/src/util.rs b/sctp/src/util.rs deleted file mode 100644 index ee0121a4f..000000000 --- a/sctp/src/util.rs +++ /dev/null @@ -1,241 +0,0 
@@ -use bytes::Bytes; -use crc::{Crc, Table, CRC_32_ISCSI}; - -pub(crate) const PADDING_MULTIPLE: usize = 4; - -pub(crate) fn get_padding_size(len: usize) -> usize { - (PADDING_MULTIPLE - (len % PADDING_MULTIPLE)) % PADDING_MULTIPLE -} - -/// Allocate and zero this data once. -/// We need to use it for the checksum and don't want to allocate/clear each time. -pub(crate) static FOUR_ZEROES: Bytes = Bytes::from_static(&[0, 0, 0, 0]); - -pub(crate) const ISCSI_CRC: Crc> = Crc::>::new(&CRC_32_ISCSI); - -/// Fastest way to do a crc32 without allocating. -pub(crate) fn generate_packet_checksum(raw: &Bytes) -> u32 { - let mut digest = ISCSI_CRC.digest(); - digest.update(&raw[0..8]); - digest.update(&FOUR_ZEROES[..]); - digest.update(&raw[12..]); - digest.finalize() -} - -/// Serial Number Arithmetic (RFC 1982) -#[inline] -pub(crate) fn sna32lt(i1: u32, i2: u32) -> bool { - (i1 < i2 && i2 - i1 < 1 << 31) || (i1 > i2 && i1 - i2 > 1 << 31) -} - -#[inline] -pub(crate) fn sna32lte(i1: u32, i2: u32) -> bool { - i1 == i2 || sna32lt(i1, i2) -} - -#[inline] -pub(crate) fn sna32gt(i1: u32, i2: u32) -> bool { - (i1 < i2 && (i2 - i1) >= 1 << 31) || (i1 > i2 && (i1 - i2) <= 1 << 31) -} - -#[inline] -pub(crate) fn sna32gte(i1: u32, i2: u32) -> bool { - i1 == i2 || sna32gt(i1, i2) -} - -#[inline] -pub(crate) fn sna32eq(i1: u32, i2: u32) -> bool { - i1 == i2 -} - -#[inline] -pub(crate) fn sna16lt(i1: u16, i2: u16) -> bool { - (i1 < i2 && (i2 - i1) < 1 << 15) || (i1 > i2 && (i1 - i2) > 1 << 15) -} - -#[inline] -pub(crate) fn sna16lte(i1: u16, i2: u16) -> bool { - i1 == i2 || sna16lt(i1, i2) -} - -#[inline] -pub(crate) fn sna16gt(i1: u16, i2: u16) -> bool { - (i1 < i2 && (i2 - i1) >= 1 << 15) || (i1 > i2 && (i1 - i2) <= 1 << 15) -} - -#[inline] -pub(crate) fn sna16gte(i1: u16, i2: u16) -> bool { - i1 == i2 || sna16gt(i1, i2) -} - -#[inline] -pub(crate) fn sna16eq(i1: u16, i2: u16) -> bool { - i1 == i2 -} - -#[cfg(test)] -mod test { - use super::*; - use crate::error::Result; - - const DIV: isize = 16; - - #[test] - fn test_serial_number_arithmetic32bit() -> Result<()> { - const SERIAL_BITS: u32 = 32; - const INTERVAL: u32 = ((1u64 << (SERIAL_BITS as u64)) / (DIV as u64)) as u32; - const MAX_FORWARD_DISTANCE: u32 = 1 << ((SERIAL_BITS - 1) - 1); - const MAX_BACKWARD_DISTANCE: u32 = 1 << (SERIAL_BITS - 1); - - for i in 0..DIV as u32 { - let s1 = i * INTERVAL; - let s2f = s1.checked_add(MAX_FORWARD_DISTANCE); - let s2b = s1.checked_add(MAX_BACKWARD_DISTANCE); - - if let (Some(s2f), Some(s2b)) = (s2f, s2b) { - assert!(sna32lt(s1, s2f), "s1 < s2 should be true: s1={s1} s2={s2f}"); - assert!( - !sna32lt(s1, s2b), - "s1 < s2 should be false: s1={s1} s2={s2b}" - ); - - assert!( - !sna32gt(s1, s2f), - "s1 > s2 should be false: s1={s1} s2={s2f}" - ); - assert!(sna32gt(s1, s2b), "s1 > s2 should be true: s1={s1} s2={s2b}"); - - assert!( - sna32lte(s1, s2f), - "s1 <= s2 should be true: s1={s1} s2={s2f}" - ); - assert!( - !sna32lte(s1, s2b), - "s1 <= s2 should be false: s1={s1} s2={s2b}" - ); - - assert!( - !sna32gte(s1, s2f), - "s1 >= s2 should be fales: s1={s1} s2={s2f}" - ); - assert!( - sna32gte(s1, s2b), - "s1 >= s2 should be true: s1={s1} s2={s2b}" - ); - - assert!( - sna32eq(s2b, s2b), - "s2 == s2 should be true: s2={s2b} s2={s2b}" - ); - assert!( - sna32lte(s2b, s2b), - "s2 == s2 should be true: s2={s2b} s2={s2b}" - ); - assert!( - sna32gte(s2b, s2b), - "s2 == s2 should be true: s2={s2b} s2={s2b}" - ); - } - - if let Some(s1add1) = s1.checked_add(1) { - assert!( - !sna32eq(s1, s1add1), - "s1 == s1+1 should be false: 
s1={s1} s1+1={s1add1}" - ); - } - - if let Some(s1sub1) = s1.checked_sub(1) { - assert!( - !sna32eq(s1, s1sub1), - "s1 == s1-1 should be false: s1={s1} s1-1={s1sub1}" - ); - } - - assert!(sna32eq(s1, s1), "s1 == s1 should be true: s1={s1} s2={s1}"); - assert!(sna32lte(s1, s1), "s1 == s1 should be true: s1={s1} s2={s1}"); - - assert!(sna32gte(s1, s1), "s1 == s1 should be true: s1={s1} s2={s1}"); - } - - Ok(()) - } - - #[test] - fn test_serial_number_arithmetic16bit() -> Result<()> { - const SERIAL_BITS: u16 = 16; - const INTERVAL: u16 = ((1u64 << (SERIAL_BITS as u64)) / (DIV as u64)) as u16; - const MAX_FORWARD_DISTANCE: u16 = 1 << ((SERIAL_BITS - 1) - 1); - const MAX_BACKWARD_DISTANCE: u16 = 1 << (SERIAL_BITS - 1); - - for i in 0..DIV as u16 { - let s1 = i * INTERVAL; - let s2f = s1.checked_add(MAX_FORWARD_DISTANCE); - let s2b = s1.checked_add(MAX_BACKWARD_DISTANCE); - - if let (Some(s2f), Some(s2b)) = (s2f, s2b) { - assert!(sna16lt(s1, s2f), "s1 < s2 should be true: s1={s1} s2={s2f}"); - assert!( - !sna16lt(s1, s2b), - "s1 < s2 should be false: s1={s1} s2={s2b}" - ); - - assert!( - !sna16gt(s1, s2f), - "s1 > s2 should be fales: s1={s1} s2={s2f}" - ); - assert!(sna16gt(s1, s2b), "s1 > s2 should be true: s1={s1} s2={s2b}"); - - assert!( - sna16lte(s1, s2f), - "s1 <= s2 should be true: s1={s1} s2={s2f}" - ); - assert!( - !sna16lte(s1, s2b), - "s1 <= s2 should be false: s1={s1} s2={s2b}" - ); - - assert!( - !sna16gte(s1, s2f), - "s1 >= s2 should be fales: s1={s1} s2={s2f}" - ); - assert!( - sna16gte(s1, s2b), - "s1 >= s2 should be true: s1={s1} s2={s2b}" - ); - - assert!( - sna16eq(s2b, s2b), - "s2 == s2 should be true: s2={s2b} s2={s2b}" - ); - assert!( - sna16lte(s2b, s2b), - "s2 == s2 should be true: s2={s2b} s2={s2b}" - ); - assert!( - sna16gte(s2b, s2b), - "s2 == s2 should be true: s2={s2b} s2={s2b}" - ); - } - - assert!(sna16eq(s1, s1), "s1 == s1 should be true: s1={s1} s2={s1}"); - - if let Some(s1add1) = s1.checked_add(1) { - assert!( - !sna16eq(s1, s1add1), - "s1 == s1+1 should be false: s1={s1} s1+1={s1add1}" - ); - } - if let Some(s1sub1) = s1.checked_sub(1) { - assert!( - !sna16eq(s1, s1sub1), - "s1 == s1-1 should be false: s1={s1} s1-1={s1sub1}" - ); - } - - assert!(sna16lte(s1, s1), "s1 == s1 should be true: s1={s1} s2={s1}"); - assert!(sna16gte(s1, s1), "s1 == s1 should be true: s1={s1} s2={s1}"); - } - - Ok(()) - } -} diff --git a/sdp/.gitignore b/sdp/.gitignore deleted file mode 100644 index 81561ed32..000000000 --- a/sdp/.gitignore +++ /dev/null @@ -1,11 +0,0 @@ -# Generated by Cargo -# will have compiled files and executables -/target/ -/.idea/ - -# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries -# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html -Cargo.lock - -# These are backup files generated by rustfmt -**/*.rs.bk diff --git a/sdp/CHANGELOG.md b/sdp/CHANGELOG.md deleted file mode 100644 index dc8be9014..000000000 --- a/sdp/CHANGELOG.md +++ /dev/null @@ -1,18 +0,0 @@ -# sdp changelog - -## Unreleased - -* Implement from and tryfrom string traits for SessionDescription. - -## v0.5.3 - -* Increased minimum support rust version to `1.60.0`. - -## v0.5.2 - -* [#10 update deps + loosen some requirements](https://github.com/webrtc-rs/sdp/pull/10) by [@melekes](https://github.com/melekes). - -## Prior to 0.5.2 - -Before 0.5.2 there was no changelog, previous changes are sometimes, but not always, available in the [GitHub Releases](https://github.com/webrtc-rs/sdp/releases). 
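The deleted rtx-timer tests pin the retransmission backoff rule down precisely: each successive timeout doubles the RTO, and the result is capped at 60 000 ms, which is why the 10 → 20 → 40 → 80 → 160 → 320 ms schedule in the comments adds up to roughly 630 ms before failure. A minimal standalone sketch of that doubling rule (not the crate's implementation; `next_timeout_ms` is a hypothetical helper) that reproduces the values asserted in `test_rto_manager_calculate_next_timeout`:

```rust
/// Hypothetical helper mirroring the rule the deleted tests assert:
/// the RTO after `n_rtos` retransmissions is `rto << n_rtos`, capped at 60_000 ms.
fn next_timeout_ms(rto: u64, n_rtos: u32) -> u64 {
    const RTO_MAX_MS: u64 = 60_000;
    // `checked_shl` covers the n_rtos >= 64 cases the tests also exercise.
    rto.checked_shl(n_rtos).unwrap_or(u64::MAX).min(RTO_MAX_MS)
}

fn main() {
    // Mirrors the deleted assertions: 1, 2, 4, then clamped to 60_000.
    for n in [0u32, 1, 2, 30, 63, 64] {
        println!("n_rtos = {n:>2} -> {} ms", next_timeout_ms(1, n));
    }
}
```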
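The serial-number helpers deleted from `sctp/src/util.rs` implement RFC 1982 comparisons so that TSN/SSN ordering remains correct after the 32-bit counters wrap. For illustration only, an equivalent wrapping-subtraction formulation of `sna32lt` (the crate itself used the explicit two-branch form shown above):

```rust
/// Illustrative equivalent of the deleted `sna32lt`: `i1` precedes `i2` when the
/// forward distance `i2 - i1 (mod 2^32)` is non-zero and below 2^31 (RFC 1982).
fn serial_lt_u32(i1: u32, i2: u32) -> bool {
    i1 != i2 && i2.wrapping_sub(i1) < (1 << 31)
}

fn main() {
    // Plain `<` gets the wrap-around case backwards; serial arithmetic does not.
    let (near_wrap, past_wrap) = (u32::MAX - 1, 2u32);
    assert!(near_wrap > past_wrap);               // ordinary integer comparison
    assert!(serial_lt_u32(near_wrap, past_wrap)); // but it still precedes it serially
    assert!(!serial_lt_u32(past_wrap, near_wrap));
    println!("serial-number comparison handles wrap-around correctly");
}
```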
- diff --git a/sdp/Cargo.toml b/sdp/Cargo.toml deleted file mode 100644 index 6f5eda650..000000000 --- a/sdp/Cargo.toml +++ /dev/null @@ -1,23 +0,0 @@ -[package] -name = "sdp" -version = "0.6.2" -authors = ["Rain Liu "] -edition = "2021" -description = "A pure Rust implementation of SDP" -license = "MIT OR Apache-2.0" -documentation = "https://docs.rs/sdp" -homepage = "https://webrtc.rs" -repository = "https://github.com/webrtc-rs/webrtc/tree/master/sdp" - -[dependencies] -url = "2" -rand = "0.8" -thiserror = "1" -substring = "1" - -[dev-dependencies] -criterion = "0.5" - -[[bench]] -name = "bench" -harness = false diff --git a/sdp/LICENSE-APACHE b/sdp/LICENSE-APACHE deleted file mode 100644 index 16fe87b06..000000000 --- a/sdp/LICENSE-APACHE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - -1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. 
For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - -2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - -3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - -4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. 
The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - -5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - -6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - -8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. 
- -END OF TERMS AND CONDITIONS - -APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - -Copyright [yyyy] [name of copyright owner] - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. diff --git a/sdp/LICENSE-MIT b/sdp/LICENSE-MIT deleted file mode 100644 index e11d93bef..000000000 --- a/sdp/LICENSE-MIT +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2021 WebRTC.rs - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/sdp/README.md b/sdp/README.md deleted file mode 100644 index 651165c8b..000000000 --- a/sdp/README.md +++ /dev/null @@ -1,30 +0,0 @@ -

- WebRTC.rs
-
- License: MIT/Apache 2.0
- Discord
-
- A pure Rust implementation of SDP. Rewrite Pion SDP in Rust
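The benchmark and unit tests removed later in this diff exercise the crate through a simple parse/serialize round trip. A minimal usage sketch of that API, assuming a dependency on the published `sdp` crate and borrowing a fixture in the style of the deleted `description_test.rs`:

```rust
use std::io::Cursor;

use sdp::SessionDescription;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Minimal session: version, origin, session name and timing.
    let input = "v=0\r\n\
                 o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\
                 s=SDP Seminar\r\n\
                 t=2873397496 2873404696\r\n";

    let mut reader = Cursor::new(input.as_bytes());
    let desc = SessionDescription::unmarshal(&mut reader)?;

    // Marshalling the parsed structure reproduces the canonical text.
    assert_eq!(desc.marshal(), input);
    println!("parsed session: {}", desc.session_name);
    Ok(())
}
```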

diff --git a/sdp/benches/bench.rs b/sdp/benches/bench.rs deleted file mode 100644 index 35b40047b..000000000 --- a/sdp/benches/bench.rs +++ /dev/null @@ -1,52 +0,0 @@ -use std::io::Cursor; - -use criterion::{criterion_group, criterion_main, Criterion}; -use sdp::SessionDescription; - -const CANONICAL_UNMARSHAL_SDP: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -i=A Seminar on the session description protocol\r\n\ -u=http://www.example.com/seminars/sdp.pdf\r\n\ -e=j.doe@example.com (Jane Doe)\r\n\ -p=+1 617 555-6011\r\n\ -c=IN IP4 224.2.17.12/127\r\n\ -b=X-YZ:128\r\n\ -b=AS:12345\r\n\ -t=2873397496 2873404696\r\n\ -t=3034423619 3042462419\r\n\ -r=604800 3600 0 90000\r\n\ -z=2882844526 -3600 2898848070 0\r\n\ -k=prompt\r\n\ -a=candidate:0 1 UDP 2113667327 203.0.113.1 54400 typ host\r\n\ -a=recvonly\r\n\ -m=audio 49170 RTP/AVP 0\r\n\ -i=Vivamus a posuere nisl\r\n\ -c=IN IP4 203.0.113.1\r\n\ -b=X-YZ:128\r\n\ -k=prompt\r\n\ -a=sendrecv\r\n\ -m=video 51372 RTP/AVP 99\r\n\ -a=rtpmap:99 h263-1998/90000\r\n"; - -fn benchmark_sdp(c: &mut Criterion) { - let mut reader = Cursor::new(CANONICAL_UNMARSHAL_SDP.as_bytes()); - let sdp = SessionDescription::unmarshal(&mut reader).unwrap(); - - /////////////////////////////////////////////////////////////////////////////////////////////// - c.bench_function("Benchmark Marshal", |b| { - b.iter(|| { - let _ = sdp.marshal(); - }) - }); - - c.bench_function("Benchmark Unmarshal ", |b| { - b.iter(|| { - let mut reader = Cursor::new(CANONICAL_UNMARSHAL_SDP.as_bytes()); - let _ = SessionDescription::unmarshal(&mut reader).unwrap(); - }) - }); -} - -criterion_group!(benches, benchmark_sdp); -criterion_main!(benches); diff --git a/sdp/codecov.yml b/sdp/codecov.yml deleted file mode 100644 index e794d966a..000000000 --- a/sdp/codecov.yml +++ /dev/null @@ -1,23 +0,0 @@ -codecov: - require_ci_to_pass: yes - max_report_age: off - token: 40894be8-0942-482a-b7cf-e58721cff2c5 - -coverage: - precision: 2 - round: down - range: 50..90 - status: - project: - default: - enabled: no - threshold: 0.2 - if_not_found: success - patch: - default: - enabled: no - if_not_found: success - changes: - default: - enabled: no - if_not_found: success diff --git a/sdp/doc/webrtc.rs.png b/sdp/doc/webrtc.rs.png deleted file mode 100644 index 7bf0dda2a..000000000 Binary files a/sdp/doc/webrtc.rs.png and /dev/null differ diff --git a/sdp/fuzz/.gitignore b/sdp/fuzz/.gitignore deleted file mode 100644 index 572e03bdf..000000000 --- a/sdp/fuzz/.gitignore +++ /dev/null @@ -1,4 +0,0 @@ - -target -corpus -artifacts diff --git a/sdp/fuzz/Cargo.toml b/sdp/fuzz/Cargo.toml deleted file mode 100644 index 52768eb2d..000000000 --- a/sdp/fuzz/Cargo.toml +++ /dev/null @@ -1,26 +0,0 @@ - -[package] -name = "sdp-fuzz" -version = "0.0.0" -authors = ["Automatically generated"] -publish = false -edition = "2021" - -[package.metadata] -cargo-fuzz = true - -[dependencies] -libfuzzer-sys = "0.4" - -[dependencies.sdp] -path = ".." 
- -# Prevent this from interfering with workspaces -[workspace] -members = ["."] - -[[bin]] -name = "parse_session" -path = "fuzz_targets/parse_session.rs" -test = false -doc = false diff --git a/sdp/fuzz/fuzz_targets/parse_session.rs b/sdp/fuzz/fuzz_targets/parse_session.rs deleted file mode 100644 index 4889ab62a..000000000 --- a/sdp/fuzz/fuzz_targets/parse_session.rs +++ /dev/null @@ -1,7 +0,0 @@ -#![no_main] -use libfuzzer_sys::fuzz_target; - -fuzz_target!(|data: &[u8]| { - let mut cursor = std::io::Cursor::new(data); - let _session = sdp::SessionDescription::unmarshal(&mut cursor); -}); diff --git a/sdp/src/description/common.rs b/sdp/src/description/common.rs deleted file mode 100644 index 4f93d9d46..000000000 --- a/sdp/src/description/common.rs +++ /dev/null @@ -1,94 +0,0 @@ -use std::fmt; - -/// Information describes the "i=" field which provides textual information -/// about the session. -pub type Information = String; - -/// ConnectionInformation defines the representation for the "c=" field -/// containing connection data. -#[derive(Debug, Default, Clone)] -pub struct ConnectionInformation { - pub network_type: String, - pub address_type: String, - pub address: Option
, -} - -impl fmt::Display for ConnectionInformation { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if let Some(address) = &self.address { - write!(f, "{} {} {}", self.network_type, self.address_type, address,) - } else { - write!(f, "{} {}", self.network_type, self.address_type,) - } - } -} - -/// Address describes a structured address token from within the "c=" field. -#[derive(Debug, Default, Clone)] -pub struct Address { - pub address: String, - pub ttl: Option, - pub range: Option, -} - -impl fmt::Display for Address { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut parts = vec![self.address.to_owned()]; - if let Some(t) = &self.ttl { - parts.push(t.to_string()); - } - if let Some(r) = &self.range { - parts.push(r.to_string()); - } - write!(f, "{}", parts.join("/")) - } -} - -/// Bandwidth describes an optional field which denotes the proposed bandwidth -/// to be used by the session or media. -#[derive(Debug, Default, Clone)] -pub struct Bandwidth { - pub experimental: bool, - pub bandwidth_type: String, - pub bandwidth: u64, -} - -impl fmt::Display for Bandwidth { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let output = if self.experimental { "X-" } else { "" }; - write!(f, "{}{}:{}", output, self.bandwidth_type, self.bandwidth) - } -} - -/// EncryptionKey describes the "k=" which conveys encryption key information. -pub type EncryptionKey = String; - -/// Attribute describes the "a=" field which represents the primary means for -/// extending SDP. -#[derive(Debug, Default, Clone)] -pub struct Attribute { - pub key: String, - pub value: Option, -} - -impl fmt::Display for Attribute { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if let Some(value) = &self.value { - write!(f, "{}:{}", self.key, value) - } else { - write!(f, "{}", self.key) - } - } -} - -impl Attribute { - /// new constructs a new attribute - pub fn new(key: String, value: Option) -> Self { - Attribute { key, value } - } - - /// is_ice_candidate returns true if the attribute key equals "candidate". 
- pub fn is_ice_candidate(&self) -> bool { - self.key.as_str() == "candidate" - } -} diff --git a/sdp/src/description/description_test.rs b/sdp/src/description/description_test.rs deleted file mode 100644 index 58ba6a282..000000000 --- a/sdp/src/description/description_test.rs +++ /dev/null @@ -1,607 +0,0 @@ -use std::io::Cursor; - -use url::Url; - -use super::common::*; -use super::media::*; -use super::session::*; -use crate::error::{Error, Result}; - -const CANONICAL_MARSHAL_SDP: &str = "v=0\r\n\ - o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ - s=SDP Seminar\r\n\ - i=A Seminar on the session description protocol\r\n\ - u=http://www.example.com/seminars/sdp.pdf\r\n\ - e=j.doe@example.com (Jane Doe)\r\n\ - p=+1 617 555-6011\r\n\ - c=IN IP4 224.2.17.12/127\r\n\ - b=X-YZ:128\r\n\ - b=AS:12345\r\n\ - t=2873397496 2873404696\r\n\ - t=3034423619 3042462419\r\n\ - r=604800 3600 0 90000\r\n\ - z=2882844526 -3600 2898848070 0\r\n\ - k=prompt\r\n\ - a=candidate:0 1 UDP 2113667327 203.0.113.1 54400 typ host\r\n\ - a=recvonly\r\n\ - m=audio 49170 RTP/AVP 0\r\n\ - i=Vivamus a posuere nisl\r\n\ - c=IN IP4 203.0.113.1\r\n\ - b=X-YZ:128\r\n\ - k=prompt\r\n\ - a=sendrecv\r\n\ - m=video 51372 RTP/AVP 99\r\n\ - a=rtpmap:99 h263-1998/90000\r\n"; - -#[test] -fn test_unmarshal_marshal() -> Result<()> { - let input = CANONICAL_MARSHAL_SDP; - let mut reader = Cursor::new(input.as_bytes()); - let sdp = SessionDescription::unmarshal(&mut reader)?; - let output = sdp.marshal(); - assert_eq!(output, input); - - Ok(()) -} - -#[test] -fn test_marshal() -> Result<()> { - let sd = SessionDescription { - version: 0, - origin: Origin { - username: "jdoe".to_string(), - session_id: 2890844526, - session_version: 2890842807, - network_type: "IN".to_string(), - address_type: "IP4".to_string(), - unicast_address: "10.47.16.5".to_string(), - }, - session_name: "SDP Seminar".to_string(), - session_information: Some("A Seminar on the session description protocol".to_string()), - uri: Some(Url::parse("http://www.example.com/seminars/sdp.pdf")?), - email_address: Some("j.doe@example.com (Jane Doe)".to_string()), - phone_number: Some("+1 617 555-6011".to_string()), - connection_information: Some(ConnectionInformation { - network_type: "IN".to_string(), - address_type: "IP4".to_string(), - address: Some(Address { - address: "224.2.17.12".to_string(), - ttl: Some(127), - range: None, - }), - }), - bandwidth: vec![ - Bandwidth { - experimental: true, - bandwidth_type: "YZ".to_string(), - bandwidth: 128, - }, - Bandwidth { - experimental: false, - bandwidth_type: "AS".to_string(), - bandwidth: 12345, - }, - ], - time_descriptions: vec![ - TimeDescription { - timing: Timing { - start_time: 2873397496, - stop_time: 2873404696, - }, - repeat_times: vec![], - }, - TimeDescription { - timing: Timing { - start_time: 3034423619, - stop_time: 3042462419, - }, - repeat_times: vec![RepeatTime { - interval: 604800, - duration: 3600, - offsets: vec![0, 90000], - }], - }, - ], - time_zones: vec![ - TimeZone { - adjustment_time: 2882844526, - offset: -3600, - }, - TimeZone { - adjustment_time: 2898848070, - offset: 0, - }, - ], - encryption_key: Some("prompt".to_string()), - attributes: vec![ - Attribute::new( - "candidate".to_string(), - Some("0 1 UDP 2113667327 203.0.113.1 54400 typ host".to_string()), - ), - Attribute::new("recvonly".to_string(), None), - ], - media_descriptions: vec![ - MediaDescription { - media_name: MediaName { - media: "audio".to_string(), - port: RangedPort { - value: 49170, - range: None, - }, - protos: 
vec!["RTP".to_string(), "AVP".to_string()], - formats: vec!["0".to_string()], - }, - media_title: Some("Vivamus a posuere nisl".to_string()), - connection_information: Some(ConnectionInformation { - network_type: "IN".to_string(), - address_type: "IP4".to_string(), - address: Some(Address { - address: "203.0.113.1".to_string(), - ttl: None, - range: None, - }), - }), - bandwidth: vec![Bandwidth { - experimental: true, - bandwidth_type: "YZ".to_string(), - bandwidth: 128, - }], - encryption_key: Some("prompt".to_string()), - attributes: vec![Attribute::new("sendrecv".to_string(), None)], - }, - MediaDescription { - media_name: MediaName { - media: "video".to_string(), - port: RangedPort { - value: 51372, - range: None, - }, - protos: vec!["RTP".to_string(), "AVP".to_string()], - formats: vec!["99".to_string()], - }, - media_title: None, - connection_information: None, - bandwidth: vec![], - encryption_key: None, - attributes: vec![Attribute::new( - "rtpmap".to_string(), - Some("99 h263-1998/90000".to_string()), - )], - }, - ], - }; - - let actual = sd.marshal(); - assert!( - actual == CANONICAL_MARSHAL_SDP, - "error:\n\nEXPECTED:\n{CANONICAL_MARSHAL_SDP}\nACTUAL:\n{actual}!!!!\n" - ); - - Ok(()) -} - -const BASE_SDP: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n"; - -const SESSION_INFORMATION_SDP: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -i=A Seminar on the session description protocol\r\n\ -t=3034423619 3042462419\r\n"; - -// https://tools.ietf.org/html/rfc4566#section-5 -// Parsers SHOULD be tolerant and also accept records terminated -// with a single newline character. -const SESSION_INFORMATION_SDPLFONLY: &str = "v=0\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\n\ -s=SDP Seminar\n\ -i=A Seminar on the session description protocol\n\ -t=3034423619 3042462419\n"; - -// SessionInformationSDPCROnly = "v=0\r" + -// "o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r" + -// "s=SDP Seminar\r" -// "i=A Seminar on the session description protocol\r" + -// "t=3034423619 3042462419\r" - -// Other SDP parsers (e.g. one in VLC media player) allow -// empty lines. 
-const SESSION_INFORMATION_SDPEXTRA_CRLF: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -\r\n\ -s=SDP Seminar\r\n\ -\r\n\ -i=A Seminar on the session description protocol\r\n\ -\r\n\ -t=3034423619 3042462419\r\n\ -\r\n"; - -const URI_SDP: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -u=http://www.example.com/seminars/sdp.pdf\r\n\ -t=3034423619 3042462419\r\n"; - -const EMAIL_ADDRESS_SDP: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -e=j.doe@example.com (Jane Doe)\r\n\ -t=3034423619 3042462419\r\n"; - -const PHONE_NUMBER_SDP: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -p=+1 617 555-6011\r\n\ -t=3034423619 3042462419\r\n"; - -const SESSION_CONNECTION_INFORMATION_SDP: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -c=IN IP4 224.2.17.12/127\r\n\ -t=3034423619 3042462419\r\n"; - -const SESSION_BANDWIDTH_SDP: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -b=X-YZ:128\r\n\ -b=AS:12345\r\n\ -t=3034423619 3042462419\r\n"; - -const TIMING_SDP: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -t=2873397496 2873404696\r\n"; - -// Short hand time notation is converted into NTP timestamp format in -// seconds. Because of that unittest comparisons will fail as the same time -// will be expressed in different units. -const REPEAT_TIMES_SDP: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -t=2873397496 2873404696\r\n\ -r=604800 3600 0 90000\r\n\ -r=3d 2h 0 21h\r\n"; - -const REPEAT_TIMES_SDPEXPECTED: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -t=2873397496 2873404696\r\n\ -r=604800 3600 0 90000\r\n\ -r=259200 7200 0 75600\r\n"; - -const REPEAT_TIMES_OVERFLOW_SDP: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -t=2873397496 2873404696\r\n\ -r=604800 3600 0 90000\r\n\ -r=106751991167301d 2h 0 21h\r\n"; - -const REPEAT_TIMES_SDPEXTRA_CRLF: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -t=2873397496 2873404696\r\n\ -r=604800 3600 0 90000\r\n\ -r=259200 7200 0 75600\r\n\ -\r\n"; - -// The expected value looks a bit different for the same reason as mentioned -// above regarding RepeatTimes. 
-const TIME_ZONES_SDP: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -t=2873397496 2873404696\r\n\ -r=2882844526 -1h 2898848070 0\r\n"; - -const TIME_ZONES_SDPEXPECTED: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -t=2873397496 2873404696\r\n\ -r=2882844526 -3600 2898848070 0\r\n"; - -const TIME_ZONES_SDP2: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -t=2873397496 2873404696\r\n\ -z=2882844526 -3600 2898848070 0\r\n"; - -const TIME_ZONES_SDP2EXTRA_CRLF: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -t=2873397496 2873404696\r\n\ -z=2882844526 -3600 2898848070 0\r\n\ -\r\n"; - -const SESSION_ENCRYPTION_KEY_SDP: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -t=2873397496 2873404696\r\n\ -k=prompt\r\n"; - -const SESSION_ENCRYPTION_KEY_SDPEXTRA_CRLF: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -t=2873397496 2873404696\r\n\ -k=prompt\r\n -\r\n"; - -const SESSION_ATTRIBUTES_SDP: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -t=2873397496 2873404696\r\n\ -a=rtpmap:96 opus/48000\r\n"; - -const MEDIA_NAME_SDP: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -t=2873397496 2873404696\r\n\ -m=video 51372 RTP/AVP 99\r\n\ -m=audio 54400 RTP/SAVPF 0 96\r\n"; - -const MEDIA_NAME_SDPEXTRA_CRLF: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -t=2873397496 2873404696\r\n\ -m=video 51372 RTP/AVP 99\r\n\ -m=audio 54400 RTP/SAVPF 0 96\r\n -\r\n"; - -const MEDIA_TITLE_SDP: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -t=2873397496 2873404696\r\n\ -m=video 51372 RTP/AVP 99\r\n\ -m=audio 54400 RTP/SAVPF 0 96\r\n\ -i=Vivamus a posuere nisl\r\n"; - -const MEDIA_CONNECTION_INFORMATION_SDP: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -t=2873397496 2873404696\r\n\ -m=video 51372 RTP/AVP 99\r\n\ -m=audio 54400 RTP/SAVPF 0 96\r\n\ -c=IN IP4 203.0.113.1\r\n"; - -const MEDIA_CONNECTION_INFORMATION_SDPEXTRA_CRLF: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -t=2873397496 2873404696\r\n\ -m=video 51372 RTP/AVP 99\r\n\ -m=audio 54400 RTP/SAVPF 0 96\r\n\ -c=IN IP4 203.0.113.1\r\n\ -\r\n"; - -const MEDIA_DESCRIPTION_OUT_OF_ORDER_SDP: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -t=2873397496 2873404696\r\n\ -m=video 51372 RTP/AVP 99\r\n\ -m=audio 54400 RTP/SAVPF 0 96\r\n\ -a=rtpmap:99 h263-1998/90000\r\n\ -a=candidate:0 1 UDP 2113667327 203.0.113.1 54400 typ host\r\n\ -c=IN IP4 203.0.113.1\r\n\ -i=Vivamus a posuere nisl\r\n"; - -const MEDIA_DESCRIPTION_OUT_OF_ORDER_SDPACTUAL: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -t=2873397496 2873404696\r\n\ -m=video 51372 RTP/AVP 99\r\n\ -m=audio 54400 RTP/SAVPF 0 96\r\n\ -i=Vivamus a posuere nisl\r\n\ -c=IN IP4 203.0.113.1\r\n\ -a=rtpmap:99 h263-1998/90000\r\n\ -a=candidate:0 1 UDP 2113667327 203.0.113.1 54400 typ host\r\n"; - -const MEDIA_BANDWIDTH_SDP: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -t=2873397496 2873404696\r\n\ -m=video 51372 RTP/AVP 99\r\n\ -m=audio 54400 RTP/SAVPF 0 96\r\n\ -b=X-YZ:128\r\n\ -b=AS:12345\r\n"; - 
-const MEDIA_TRANSPORT_BANDWIDTH_SDP: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -t=2873397496 2873404696\r\n\ -m=video 51372 RTP/AVP 99\r\n\ -m=audio 54400 RTP/SAVPF 0 96\r\n\ -b=AS:12345\r\n\ -b=TIAS:12345\r\n"; - -const MEDIA_ENCRYPTION_KEY_SDP: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -t=2873397496 2873404696\r\n\ -m=video 51372 RTP/AVP 99\r\n\ -m=audio 54400 RTP/SAVPF 0 96\r\n\ -k=prompt\r\n"; - -const MEDIA_ENCRYPTION_KEY_SDPEXTRA_CRLF: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -t=2873397496 2873404696\r\n\ -m=video 51372 RTP/AVP 99\r\n\ -m=audio 54400 RTP/SAVPF 0 96\r\n\ -k=prompt\r\n\ -\r\n"; - -const MEDIA_ATTRIBUTES_SDP: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -t=2873397496 2873404696\r\n\ -m=video 51372 RTP/AVP 99\r\n\ -m=audio 54400 RTP/SAVPF 0 96\r\n\ -a=rtpmap:99 h263-1998/90000\r\n\ -a=candidate:0 1 UDP 2113667327 203.0.113.1 54400 typ host\r\n\ -a=rtcp-fb:97 ccm fir\r\n\ -a=rtcp-fb:97 nack\r\n\ -a=rtcp-fb:97 nack pli\r\n"; - -const CANONICAL_UNMARSHAL_SDP: &str = "v=0\r\n\ -o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5\r\n\ -s=SDP Seminar\r\n\ -i=A Seminar on the session description protocol\r\n\ -u=http://www.example.com/seminars/sdp.pdf\r\n\ -e=j.doe@example.com (Jane Doe)\r\n\ -p=+1 617 555-6011\r\n\ -c=IN IP4 224.2.17.12/127\r\n\ -b=X-YZ:128\r\n\ -b=AS:12345\r\n\ -t=2873397496 2873404696\r\n\ -t=3034423619 3042462419\r\n\ -r=604800 3600 0 90000\r\n\ -z=2882844526 -3600 2898848070 0\r\n\ -k=prompt\r\n\ -a=candidate:0 1 UDP 2113667327 203.0.113.1 54400 typ host\r\n\ -a=recvonly\r\n\ -m=audio 49170 RTP/AVP 0\r\n\ -i=Vivamus a posuere nisl\r\n\ -c=IN IP4 203.0.113.1\r\n\ -b=X-YZ:128\r\n\ -k=prompt\r\n\ -a=sendrecv\r\n\ -m=video 51372 RTP/AVP 99\r\n\ -a=rtpmap:99 h263-1998/90000\r\n"; - -#[test] -fn test_round_trip() -> Result<()> { - let tests = vec![ - ( - "SessionInformationSDPLFOnly", - SESSION_INFORMATION_SDPLFONLY, - Some(SESSION_INFORMATION_SDP), - ), - ( - "SessionInformationSDPExtraCRLF", - SESSION_INFORMATION_SDPEXTRA_CRLF, - Some(SESSION_INFORMATION_SDP), - ), - ("SessionInformation", SESSION_INFORMATION_SDP, None), - ("URI", URI_SDP, None), - ("EmailAddress", EMAIL_ADDRESS_SDP, None), - ("PhoneNumber", PHONE_NUMBER_SDP, None), - ( - "RepeatTimesSDPExtraCRLF", - REPEAT_TIMES_SDPEXTRA_CRLF, - Some(REPEAT_TIMES_SDPEXPECTED), - ), - ( - "SessionConnectionInformation", - SESSION_CONNECTION_INFORMATION_SDP, - None, - ), - ("SessionBandwidth", SESSION_BANDWIDTH_SDP, None), - ("SessionEncryptionKey", SESSION_ENCRYPTION_KEY_SDP, None), - ( - "SessionEncryptionKeyExtraCRLF", - SESSION_ENCRYPTION_KEY_SDPEXTRA_CRLF, - Some(SESSION_ENCRYPTION_KEY_SDP), - ), - ("SessionAttributes", SESSION_ATTRIBUTES_SDP, None), - ( - "TimeZonesSDP2ExtraCRLF", - TIME_ZONES_SDP2EXTRA_CRLF, - Some(TIME_ZONES_SDP2), - ), - ("MediaName", MEDIA_NAME_SDP, None), - ( - "MediaNameExtraCRLF", - MEDIA_NAME_SDPEXTRA_CRLF, - Some(MEDIA_NAME_SDP), - ), - ("MediaTitle", MEDIA_TITLE_SDP, None), - ( - "MediaConnectionInformation", - MEDIA_CONNECTION_INFORMATION_SDP, - None, - ), - ( - "MediaConnectionInformationExtraCRLF", - MEDIA_CONNECTION_INFORMATION_SDPEXTRA_CRLF, - Some(MEDIA_CONNECTION_INFORMATION_SDP), - ), - ( - "MediaDescriptionOutOfOrder", - MEDIA_DESCRIPTION_OUT_OF_ORDER_SDP, - Some(MEDIA_DESCRIPTION_OUT_OF_ORDER_SDPACTUAL), - ), - ("MediaBandwidth", MEDIA_BANDWIDTH_SDP, None), - ( - 
"MediaTransportBandwidth", - MEDIA_TRANSPORT_BANDWIDTH_SDP, - None, - ), - ("MediaEncryptionKey", MEDIA_ENCRYPTION_KEY_SDP, None), - ( - "MediaEncryptionKeyExtraCRLF", - MEDIA_ENCRYPTION_KEY_SDPEXTRA_CRLF, - Some(MEDIA_ENCRYPTION_KEY_SDP), - ), - ("MediaAttributes", MEDIA_ATTRIBUTES_SDP, None), - ("CanonicalUnmarshal", CANONICAL_UNMARSHAL_SDP, None), - ]; - - for (name, sdp_str, expected) in tests { - let mut reader = Cursor::new(sdp_str.as_bytes()); - let sdp = SessionDescription::unmarshal(&mut reader); - if let Ok(sdp) = sdp { - let actual = sdp.marshal(); - if let Some(expected) = expected { - assert_eq!(actual.as_str(), expected, "{name}\n{sdp_str}"); - } else { - assert_eq!(actual.as_str(), sdp_str, "{name}\n{sdp_str}"); - } - } else { - panic!("{name}\n{sdp_str}"); - } - } - - Ok(()) -} - -#[test] -fn test_unmarshal_repeat_times() -> Result<()> { - let mut reader = Cursor::new(REPEAT_TIMES_SDP.as_bytes()); - let sdp = SessionDescription::unmarshal(&mut reader)?; - let actual = sdp.marshal(); - assert_eq!(actual.as_str(), REPEAT_TIMES_SDPEXPECTED); - Ok(()) -} - -#[test] -fn test_unmarshal_repeat_times_overflow() -> Result<()> { - let mut reader = Cursor::new(REPEAT_TIMES_OVERFLOW_SDP.as_bytes()); - let result = SessionDescription::unmarshal(&mut reader); - assert!(result.is_err()); - assert_eq!( - Error::SdpInvalidValue("106751991167301d".to_owned()), - result.unwrap_err() - ); - Ok(()) -} - -#[test] -fn test_unmarshal_time_zones() -> Result<()> { - let mut reader = Cursor::new(TIME_ZONES_SDP.as_bytes()); - let sdp = SessionDescription::unmarshal(&mut reader)?; - let actual = sdp.marshal(); - assert_eq!(actual.as_str(), TIME_ZONES_SDPEXPECTED); - Ok(()) -} - -#[test] -fn test_unmarshal_non_nil_address() -> Result<()> { - let input = "v=0\r\no=0 0 0 IN IP4 0\r\ns=0\r\nc=IN IP4\r\nt=0 0\r\n"; - let mut reader = Cursor::new(input); - let sdp = SessionDescription::unmarshal(&mut reader); - if let Ok(sdp) = sdp { - let output = sdp.marshal(); - assert_eq!(output.as_str(), input); - } else { - panic!("{}", input); - } - Ok(()) -} diff --git a/sdp/src/description/media.rs b/sdp/src/description/media.rs deleted file mode 100644 index d6c973cc2..000000000 --- a/sdp/src/description/media.rs +++ /dev/null @@ -1,270 +0,0 @@ -use std::collections::HashMap; -use std::fmt; - -use url::Url; - -use crate::description::common::*; -use crate::extmap::*; - -/// Constants for extmap key -pub const EXT_MAP_VALUE_TRANSPORT_CC_KEY: isize = 3; -pub const EXT_MAP_VALUE_TRANSPORT_CC_URI: &str = - "http://www.ietf.org/id/draft-holmer-rmcat-transport-wide-cc-extensions-01"; - -fn ext_map_uri() -> HashMap { - let mut m = HashMap::new(); - m.insert( - EXT_MAP_VALUE_TRANSPORT_CC_KEY, - EXT_MAP_VALUE_TRANSPORT_CC_URI, - ); - m -} - -/// MediaDescription represents a media type. -/// -#[derive(Debug, Default, Clone)] -pub struct MediaDescription { - /// `m= / ...` - /// - /// - pub media_name: MediaName, - - /// `i=` - /// - /// - pub media_title: Option, - - /// `c= ` - /// - /// - pub connection_information: Option, - - /// `b=:` - /// - /// - pub bandwidth: Vec, - - /// `k=` - /// - /// `k=:` - /// - /// - pub encryption_key: Option, - - /// Attributes are the primary means for extending SDP. Attributes may - /// be defined to be used as "session-level" attributes, "media-level" - /// attributes, or both. 
- /// - /// - pub attributes: Vec, -} - -impl MediaDescription { - /// attribute returns the value of an attribute and if it exists - pub fn attribute(&self, key: &str) -> Option> { - for a in &self.attributes { - if a.key == key { - return Some(a.value.as_ref().map(|s| s.as_ref())); - } - } - None - } - - /// new_jsep_media_description creates a new MediaName with - /// some settings that are required by the JSEP spec. - pub fn new_jsep_media_description(codec_type: String, _codec_prefs: Vec<&str>) -> Self { - MediaDescription { - media_name: MediaName { - media: codec_type, - port: RangedPort { - value: 9, - range: None, - }, - protos: vec![ - "UDP".to_string(), - "TLS".to_string(), - "RTP".to_string(), - "SAVPF".to_string(), - ], - formats: vec![], - }, - media_title: None, - connection_information: Some(ConnectionInformation { - network_type: "IN".to_string(), - address_type: "IP4".to_string(), - address: Some(Address { - address: "0.0.0.0".to_string(), - ttl: None, - range: None, - }), - }), - bandwidth: vec![], - encryption_key: None, - attributes: vec![], - } - } - - /// with_property_attribute adds a property attribute 'a=key' to the media description - pub fn with_property_attribute(mut self, key: String) -> Self { - self.attributes.push(Attribute::new(key, None)); - self - } - - /// with_value_attribute adds a value attribute 'a=key:value' to the media description - pub fn with_value_attribute(mut self, key: String, value: String) -> Self { - self.attributes.push(Attribute::new(key, Some(value))); - self - } - - /// with_fingerprint adds a fingerprint to the media description - pub fn with_fingerprint(self, algorithm: String, value: String) -> Self { - self.with_value_attribute("fingerprint".to_owned(), algorithm + " " + &value) - } - - /// with_ice_credentials adds ICE credentials to the media description - pub fn with_ice_credentials(self, username: String, password: String) -> Self { - self.with_value_attribute("ice-ufrag".to_string(), username) - .with_value_attribute("ice-pwd".to_string(), password) - } - - /// with_codec adds codec information to the media description - pub fn with_codec( - mut self, - payload_type: u8, - name: String, - clockrate: u32, - channels: u16, - fmtp: String, - ) -> Self { - self.media_name.formats.push(payload_type.to_string()); - let mut rtpmap = format!("{payload_type} {name}/{clockrate}"); - if channels > 0 { - rtpmap += format!("/{channels}").as_str(); - } - - if !fmtp.is_empty() { - self.with_value_attribute("rtpmap".to_string(), rtpmap) - .with_value_attribute("fmtp".to_string(), format!("{payload_type} {fmtp}")) - } else { - self.with_value_attribute("rtpmap".to_string(), rtpmap) - } - } - - /// with_media_source adds media source information to the media description - pub fn with_media_source( - self, - ssrc: u32, - cname: String, - stream_label: String, - label: String, - ) -> Self { - self. - with_value_attribute("ssrc".to_string(), format!("{ssrc} cname:{cname}")). // Deprecated but not phased out? - with_value_attribute("ssrc".to_string(), format!("{ssrc} msid:{stream_label} {label}")). - with_value_attribute("ssrc".to_string(), format!("{ssrc} mslabel:{stream_label}")). // Deprecated but not phased out? - with_value_attribute("ssrc".to_string(), format!("{ssrc} label:{label}")) - // Deprecated but not phased out? 
- } - - /// with_candidate adds an ICE candidate to the media description - /// Deprecated: use WithICECandidate instead - pub fn with_candidate(self, value: String) -> Self { - self.with_value_attribute("candidate".to_string(), value) - } - - pub fn with_extmap(self, e: ExtMap) -> Self { - self.with_property_attribute(e.marshal()) - } - - /// with_transport_cc_extmap adds an extmap to the media description - pub fn with_transport_cc_extmap(self) -> Self { - let uri = { - let m = ext_map_uri(); - if let Some(uri_str) = m.get(&EXT_MAP_VALUE_TRANSPORT_CC_KEY) { - match Url::parse(uri_str) { - Ok(uri) => Some(uri), - Err(_) => None, - } - } else { - None - } - }; - - let e = ExtMap { - value: EXT_MAP_VALUE_TRANSPORT_CC_KEY, - uri, - ..Default::default() - }; - - self.with_extmap(e) - } -} - -/// RangedPort supports special format for the media field "m=" port value. If -/// it may be necessary to specify multiple transport ports, the protocol allows -/// to write it as: `/` where number of ports is a an -/// offsetting range. -#[derive(Debug, Default, Clone)] -pub struct RangedPort { - pub value: isize, - pub range: Option, -} - -impl fmt::Display for RangedPort { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if let Some(range) = self.range { - write!(f, "{}/{}", self.value, range) - } else { - write!(f, "{}", self.value) - } - } -} - -/// MediaName describes the "m=" field storage structure. -#[derive(Debug, Default, Clone)] -pub struct MediaName { - pub media: String, - pub port: RangedPort, - pub protos: Vec, - pub formats: Vec, -} - -impl fmt::Display for MediaName { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = [ - self.media.clone(), - self.port.to_string(), - self.protos.join("/"), - self.formats.join(" "), - ]; - write!(f, "{}", s.join(" ")) - } -} - -#[cfg(test)] -mod tests { - use super::MediaDescription; - - #[test] - fn test_attribute_missing() { - let media_description = MediaDescription::default(); - - assert_eq!(media_description.attribute("recvonly"), None); - } - - #[test] - fn test_attribute_present_with_no_value() { - let media_description = - MediaDescription::default().with_property_attribute("recvonly".to_owned()); - - assert_eq!(media_description.attribute("recvonly"), Some(None)); - } - - #[test] - fn test_attribute_present_with_value() { - let media_description = - MediaDescription::default().with_value_attribute("ptime".to_owned(), "1".to_owned()); - - assert_eq!(media_description.attribute("ptime"), Some(Some("1"))); - } -} diff --git a/sdp/src/description/mod.rs b/sdp/src/description/mod.rs deleted file mode 100644 index 0cab15c04..000000000 --- a/sdp/src/description/mod.rs +++ /dev/null @@ -1,6 +0,0 @@ -#[cfg(test)] -mod description_test; - -pub mod common; -pub mod media; -pub mod session; diff --git a/sdp/src/description/session.rs b/sdp/src/description/session.rs deleted file mode 100644 index 472fae8c1..000000000 --- a/sdp/src/description/session.rs +++ /dev/null @@ -1,1365 +0,0 @@ -use std::collections::HashMap; -use std::convert::TryFrom; -use std::time::{Duration, SystemTime, UNIX_EPOCH}; -use std::{fmt, io}; - -use url::Url; - -use super::common::*; -use super::media::*; -use crate::error::{Error, Result}; -use crate::lexer::*; -use crate::util::*; - -/// Constants for SDP attributes used in JSEP -pub const ATTR_KEY_CANDIDATE: &str = "candidate"; -pub const ATTR_KEY_END_OF_CANDIDATES: &str = "end-of-candidates"; -pub const ATTR_KEY_IDENTITY: &str = "identity"; -pub const ATTR_KEY_GROUP: &str = "group"; -pub const 
ATTR_KEY_SSRC: &str = "ssrc"; -pub const ATTR_KEY_SSRCGROUP: &str = "ssrc-group"; -pub const ATTR_KEY_MSID: &str = "msid"; -pub const ATTR_KEY_MSID_SEMANTIC: &str = "msid-semantic"; -pub const ATTR_KEY_CONNECTION_SETUP: &str = "setup"; -pub const ATTR_KEY_MID: &str = "mid"; -pub const ATTR_KEY_ICELITE: &str = "ice-lite"; -pub const ATTR_KEY_RTCPMUX: &str = "rtcp-mux"; -pub const ATTR_KEY_RTCPRSIZE: &str = "rtcp-rsize"; -pub const ATTR_KEY_INACTIVE: &str = "inactive"; -pub const ATTR_KEY_RECV_ONLY: &str = "recvonly"; -pub const ATTR_KEY_SEND_ONLY: &str = "sendonly"; -pub const ATTR_KEY_SEND_RECV: &str = "sendrecv"; -pub const ATTR_KEY_EXT_MAP: &str = "extmap"; - -/// Constants for semantic tokens used in JSEP -pub const SEMANTIC_TOKEN_LIP_SYNCHRONIZATION: &str = "LS"; -pub const SEMANTIC_TOKEN_FLOW_IDENTIFICATION: &str = "FID"; -pub const SEMANTIC_TOKEN_FORWARD_ERROR_CORRECTION: &str = "FEC"; -pub const SEMANTIC_TOKEN_WEBRTC_MEDIA_STREAMS: &str = "WMS"; - -/// Version describes the value provided by the "v=" field which gives -/// the version of the Session Description Protocol. -pub type Version = isize; - -/// Origin defines the structure for the "o=" field which provides the -/// originator of the session plus a session identifier and version number. -#[derive(Debug, Default, Clone)] -pub struct Origin { - pub username: String, - pub session_id: u64, - pub session_version: u64, - pub network_type: String, - pub address_type: String, - pub unicast_address: String, -} - -impl fmt::Display for Origin { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "{} {} {} {} {} {}", - self.username, - self.session_id, - self.session_version, - self.network_type, - self.address_type, - self.unicast_address, - ) - } -} - -impl Origin { - pub fn new() -> Self { - Origin { - username: "".to_owned(), - session_id: 0, - session_version: 0, - network_type: "".to_owned(), - address_type: "".to_owned(), - unicast_address: "".to_owned(), - } - } -} - -/// SessionName describes a structured representations for the "s=" field -/// and is the textual session name. -pub type SessionName = String; - -/// EmailAddress describes a structured representations for the "e=" line -/// which specifies email contact information for the person responsible for -/// the conference. -pub type EmailAddress = String; - -/// PhoneNumber describes a structured representations for the "p=" line -/// specify phone contact information for the person responsible for the -/// conference. -pub type PhoneNumber = String; - -/// TimeZone defines the structured object for "z=" line which describes -/// repeated sessions scheduling. -#[derive(Debug, Default, Clone)] -pub struct TimeZone { - pub adjustment_time: u64, - pub offset: i64, -} - -impl fmt::Display for TimeZone { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{} {}", self.adjustment_time, self.offset) - } -} - -/// TimeDescription describes "t=", "r=" fields of the session description -/// which are used to specify the start and stop times for a session as well as -/// repeat intervals and durations for the scheduled session. -#[derive(Debug, Default, Clone)] -pub struct TimeDescription { - /// `t= ` - /// - /// - pub timing: Timing, - - /// `r= ` - /// - /// - pub repeat_times: Vec, -} - -/// Timing defines the "t=" field's structured representation for the start and -/// stop times. 
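As a side note on how these time types end up on the wire: their `Display` impls simply join the numeric fields with spaces, and `marshal` later prefixes them with `t=`, `r=` and `z=`. A minimal sketch of that rendering, assuming the crate is importable as `sdp` (values are illustrative):

```rust
// Sketch only: shows how the time types shown here format themselves.
use sdp::description::session::{TimeZone, Timing};

fn main() {
    // `t=<start> <stop>` — a JSEP-style session typically uses "0 0" (unbounded).
    let timing = Timing { start_time: 0, stop_time: 0 };
    assert_eq!(timing.to_string(), "0 0");

    // `z=<adjustment time> <offset> ...` — one "<time> <offset>" pair per entry.
    let tz = TimeZone { adjustment_time: 2_882_844_526, offset: -3_600 };
    assert_eq!(tz.to_string(), "2882844526 -3600");
}
```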
-#[derive(Debug, Default, Clone)] -pub struct Timing { - pub start_time: u64, - pub stop_time: u64, -} - -impl fmt::Display for Timing { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{} {}", self.start_time, self.stop_time) - } -} - -/// RepeatTime describes the "r=" fields of the session description which -/// represents the intervals and durations for repeated scheduled sessions. -#[derive(Debug, Default, Clone)] -pub struct RepeatTime { - pub interval: i64, - pub duration: i64, - pub offsets: Vec, -} - -impl fmt::Display for RepeatTime { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut fields = vec![format!("{}", self.interval), format!("{}", self.duration)]; - for value in &self.offsets { - fields.push(format!("{value}")); - } - write!(f, "{}", fields.join(" ")) - } -} - -/// SessionDescription is a a well-defined format for conveying sufficient -/// information to discover and participate in a multimedia session. -#[derive(Debug, Default, Clone)] -pub struct SessionDescription { - /// `v=0` - /// - /// - pub version: Version, - - /// `o= ` - /// - /// - pub origin: Origin, - - /// `s=` - /// - /// - pub session_name: SessionName, - - /// `i=` - /// - /// - pub session_information: Option, - - /// `u=` - /// - /// - pub uri: Option, - - /// `e=` - /// - /// - pub email_address: Option, - - /// `p=` - /// - /// - pub phone_number: Option, - - /// `c= ` - /// - /// - pub connection_information: Option, - - /// `b=:` - /// - /// - pub bandwidth: Vec, - - /// - /// - pub time_descriptions: Vec, - - /// `z= ...` - /// - /// - pub time_zones: Vec, - - /// `k=` - /// - /// `k=:` - /// - /// - pub encryption_key: Option, - - /// `a=` - /// - /// `a=:` - /// - /// - pub attributes: Vec, - - /// - pub media_descriptions: Vec, -} - -/// Reset cleans the SessionDescription, and sets all fields back to their default values -impl SessionDescription { - /// API to match draft-ietf-rtcweb-jsep - /// Move to webrtc or its own package? - - /// NewJSEPSessionDescription creates a new SessionDescription with - /// some settings that are required by the JSEP spec. 
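To make the builder flow concrete, here is a rough usage sketch chaining the JSEP constructors with the `with_*` helpers defined in this file and in `media.rs`. The crate path (`sdp`), the payload type and the codec parameters are illustrative, not prescriptive:

```rust
use sdp::description::media::MediaDescription;
use sdp::description::session::SessionDescription;

fn main() {
    // `new_jsep_media_description` takes the media type plus (currently unused)
    // codec preferences; `with_codec` appends the payload type to the `m=` line
    // and adds matching `a=rtpmap:` / `a=fmtp:` attributes.
    let audio = MediaDescription::new_jsep_media_description("audio".to_owned(), vec![])
        .with_codec(111, "opus".to_owned(), 48000, 2, "minptime=10".to_owned())
        .with_property_attribute("sendrecv".to_owned());

    // Passing `false` skips the `a=identity` session attribute.
    let sdp = SessionDescription::new_jsep_session_description(false).with_media(audio);

    // `marshal` serializes the whole description back to SDP text.
    println!("{}", sdp.marshal());
}
```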
- pub fn new_jsep_session_description(identity: bool) -> Self { - let d = SessionDescription { - version: 0, - origin: Origin { - username: "-".to_string(), - session_id: new_session_id(), - session_version: SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap_or_else(|_| Duration::from_secs(0)) - .subsec_nanos() as u64, - network_type: "IN".to_string(), - address_type: "IP4".to_string(), - unicast_address: "0.0.0.0".to_string(), - }, - session_name: "-".to_string(), - session_information: None, - uri: None, - email_address: None, - phone_number: None, - connection_information: None, - bandwidth: vec![], - time_descriptions: vec![TimeDescription { - timing: Timing { - start_time: 0, - stop_time: 0, - }, - repeat_times: vec![], - }], - time_zones: vec![], - encryption_key: None, - attributes: vec![], // TODO: implement trickle ICE - media_descriptions: vec![], - }; - - if identity { - d.with_property_attribute(ATTR_KEY_IDENTITY.to_string()) - } else { - d - } - } - - /// WithPropertyAttribute adds a property attribute 'a=key' to the session description - pub fn with_property_attribute(mut self, key: String) -> Self { - self.attributes.push(Attribute::new(key, None)); - self - } - - /// WithValueAttribute adds a value attribute 'a=key:value' to the session description - pub fn with_value_attribute(mut self, key: String, value: String) -> Self { - self.attributes.push(Attribute::new(key, Some(value))); - self - } - - /// WithFingerprint adds a fingerprint to the session description - pub fn with_fingerprint(self, algorithm: String, value: String) -> Self { - self.with_value_attribute("fingerprint".to_string(), algorithm + " " + value.as_str()) - } - - /// WithMedia adds a media description to the session description - pub fn with_media(mut self, md: MediaDescription) -> Self { - self.media_descriptions.push(md); - self - } - - fn build_codec_map(&self) -> HashMap { - let mut codecs: HashMap = HashMap::new(); - - for m in &self.media_descriptions { - for a in &m.attributes { - let attr = a.to_string(); - if attr.starts_with("rtpmap:") { - if let Ok(codec) = parse_rtpmap(&attr) { - merge_codecs(codec, &mut codecs); - } - } else if attr.starts_with("fmtp:") { - if let Ok(codec) = parse_fmtp(&attr) { - merge_codecs(codec, &mut codecs); - } - } else if attr.starts_with("rtcp-fb:") { - if let Ok(codec) = parse_rtcp_fb(&attr) { - merge_codecs(codec, &mut codecs); - } - } - } - } - - codecs - } - - /// get_codec_for_payload_type scans the SessionDescription for the given payload type and returns the codec - pub fn get_codec_for_payload_type(&self, payload_type: u8) -> Result { - let codecs = self.build_codec_map(); - - if let Some(codec) = codecs.get(&payload_type) { - Ok(codec.clone()) - } else { - Err(Error::PayloadTypeNotFound) - } - } - - /// get_payload_type_for_codec scans the SessionDescription for a codec that matches the provided codec - /// as closely as possible and returns its payload type - pub fn get_payload_type_for_codec(&self, wanted: &Codec) -> Result { - let codecs = self.build_codec_map(); - - for (payload_type, codec) in codecs.iter() { - if codecs_match(wanted, codec) { - return Ok(*payload_type); - } - } - - Err(Error::CodecNotFound) - } - - /// Attribute returns the value of an attribute and if it exists - pub fn attribute(&self, key: &str) -> Option<&String> { - for a in &self.attributes { - if a.key == key { - return a.value.as_ref(); - } - } - None - } - - /// Marshal takes a SDP struct to text - /// - /// - /// - /// Session description - /// v= (protocol version) - 
/// o= (originator and session identifier) - /// s= (session name) - /// i=* (session information) - /// u=* (URI of description) - /// e=* (email address) - /// p=* (phone number) - /// c=* (connection information -- not required if included in - /// all media) - /// b=* (zero or more bandwidth information lines) - /// One or more time descriptions ("t=" and "r=" lines; see below) - /// z=* (time zone adjustments) - /// k=* (encryption key) - /// a=* (zero or more session attribute lines) - /// Zero or more media descriptions - /// - /// Time description - /// t= (time the session is active) - /// r=* (zero or more repeat times) - /// - /// Media description, if present - /// m= (media name and transport address) - /// i=* (media title) - /// c=* (connection information -- optional if included at - /// session level) - /// b=* (zero or more bandwidth information lines) - /// k=* (encryption key) - /// a=* (zero or more media attribute lines) - pub fn marshal(&self) -> String { - let mut result = String::new(); - - result += key_value_build("v=", Some(&self.version.to_string())).as_str(); - result += key_value_build("o=", Some(&self.origin.to_string())).as_str(); - result += key_value_build("s=", Some(&self.session_name)).as_str(); - - result += key_value_build("i=", self.session_information.as_ref()).as_str(); - - if let Some(uri) = &self.uri { - result += key_value_build("u=", Some(&format!("{uri}"))).as_str(); - } - result += key_value_build("e=", self.email_address.as_ref()).as_str(); - result += key_value_build("p=", self.phone_number.as_ref()).as_str(); - if let Some(connection_information) = &self.connection_information { - result += key_value_build("c=", Some(&connection_information.to_string())).as_str(); - } - - for bandwidth in &self.bandwidth { - result += key_value_build("b=", Some(&bandwidth.to_string())).as_str(); - } - for time_description in &self.time_descriptions { - result += key_value_build("t=", Some(&time_description.timing.to_string())).as_str(); - for repeat_time in &time_description.repeat_times { - result += key_value_build("r=", Some(&repeat_time.to_string())).as_str(); - } - } - if !self.time_zones.is_empty() { - let mut time_zones = vec![]; - for time_zone in &self.time_zones { - time_zones.push(time_zone.to_string()); - } - result += key_value_build("z=", Some(&time_zones.join(" "))).as_str(); - } - result += key_value_build("k=", self.encryption_key.as_ref()).as_str(); - for attribute in &self.attributes { - result += key_value_build("a=", Some(&attribute.to_string())).as_str(); - } - - for media_description in &self.media_descriptions { - result += - key_value_build("m=", Some(&media_description.media_name.to_string())).as_str(); - result += key_value_build("i=", media_description.media_title.as_ref()).as_str(); - if let Some(connection_information) = &media_description.connection_information { - result += key_value_build("c=", Some(&connection_information.to_string())).as_str(); - } - for bandwidth in &media_description.bandwidth { - result += key_value_build("b=", Some(&bandwidth.to_string())).as_str(); - } - result += key_value_build("k=", media_description.encryption_key.as_ref()).as_str(); - for attribute in &media_description.attributes { - result += key_value_build("a=", Some(&attribute.to_string())).as_str(); - } - } - - result - } - - /// Unmarshal is the primary function that deserializes the session description - /// message and stores it inside of a structured SessionDescription object. 
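Before the state-machine details that follow, a quick round-trip sketch of the public entry points (`TryFrom<String>` wraps this same `unmarshal` in an `io::Cursor`). The crate path and the minimal SDP snippet are illustrative:

```rust
use std::convert::TryFrom;
use sdp::SessionDescription;

fn main() -> Result<(), sdp::Error> {
    let input = "v=0\r\n\
                 o=- 0 0 IN IP4 127.0.0.1\r\n\
                 s=-\r\n\
                 t=0 0\r\n"
        .to_string();

    // `TryFrom<String>` drives the s1..s16 state machine via `unmarshal`.
    let desc = SessionDescription::try_from(input)?;

    // `marshal` writes the fields back out in the canonical order listed above.
    let output = desc.marshal();
    assert!(output.starts_with("v=0\r\n"));
    Ok(())
}
```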
- /// - /// The States Transition Table describes the computation flow between functions - /// (namely s1, s2, s3, ...) for a parsing procedure that complies with the - /// specifications laid out by the rfc4566#section-5 as well as by JavaScript - /// Session Establishment Protocol draft. Links: - /// - /// - /// - /// - /// - /// Session description - /// v= (protocol version) - /// o= (originator and session identifier) - /// s= (session name) - /// i=* (session information) - /// u=* (URI of description) - /// e=* (email address) - /// p=* (phone number) - /// c=* (connection information -- not required if included in - /// all media) - /// b=* (zero or more bandwidth information lines) - /// One or more time descriptions ("t=" and "r=" lines; see below) - /// z=* (time zone adjustments) - /// k=* (encryption key) - /// a=* (zero or more session attribute lines) - /// Zero or more media descriptions - /// - /// Time description - /// t= (time the session is active) - /// r=* (zero or more repeat times) - /// - /// Media description, if present - /// m= (media name and transport address) - /// i=* (media title) - /// c=* (connection information -- optional if included at - /// session level) - /// b=* (zero or more bandwidth information lines) - /// k=* (encryption key) - /// a=* (zero or more media attribute lines) - /// - /// In order to generate the following state table and draw subsequent - /// deterministic finite-state automota ("DFA") the following regex was used to - /// derive the DFA: - /// vosi?u?e?p?c?b*(tr*)+z?k?a*(mi?c?b*k?a*)* - /// possible place and state to exit: - /// ** * * * ** * * * * - /// 99 1 1 1 11 1 1 1 1 - /// 3 1 1 26 5 5 4 4 - /// - /// Please pay close attention to the `k`, and `a` parsing states. In the table - /// below in order to distinguish between the states belonging to the media - /// description as opposed to the session description, the states are marked - /// with an asterisk ("a*", "k*"). 
- /// - /// ```ignore - /// +--------+----+-------+----+-----+----+-----+---+----+----+---+---+-----+---+---+----+---+----+ - /// | STATES | a* | a*,k* | a | a,k | b | b,c | e | i | m | o | p | r,t | s | t | u | v | z | - /// +--------+----+-------+----+-----+----+-----+---+----+----+---+---+-----+---+---+----+---+----+ - /// | s1 | | | | | | | | | | | | | | | | 2 | | - /// | s2 | | | | | | | | | | 3 | | | | | | | | - /// | s3 | | | | | | | | | | | | | 4 | | | | | - /// | s4 | | | | | | 5 | 6 | 7 | | | 8 | | | 9 | 10 | | | - /// | s5 | | | | | 5 | | | | | | | | | 9 | | | | - /// | s6 | | | | | | 5 | | | | | 8 | | | 9 | | | | - /// | s7 | | | | | | 5 | 6 | | | | 8 | | | 9 | 10 | | | - /// | s8 | | | | | | 5 | | | | | | | | 9 | | | | - /// | s9 | | | | 11 | | | | | 12 | | | 9 | | | | | 13 | - /// | s10 | | | | | | 5 | 6 | | | | 8 | | | 9 | | | | - /// | s11 | | | 11 | | | | | | 12 | | | | | | | | | - /// | s12 | | 14 | | | | 15 | | 16 | 12 | | | | | | | | | - /// | s13 | | | | 11 | | | | | 12 | | | | | | | | | - /// | s14 | 14 | | | | | | | | 12 | | | | | | | | | - /// | s15 | | 14 | | | 15 | | | | 12 | | | | | | | | | - /// | s16 | | 14 | | | | 15 | | | 12 | | | | | | | | | - /// +--------+----+-------+----+-----+----+-----+---+----+----+---+---+-----+---+---+----+---+----+ - /// ``` - pub fn unmarshal(reader: &mut R) -> Result { - let mut lexer = Lexer { - desc: SessionDescription { - version: 0, - origin: Origin::new(), - session_name: "".to_owned(), - session_information: None, - uri: None, - email_address: None, - phone_number: None, - connection_information: None, - bandwidth: vec![], - time_descriptions: vec![], - time_zones: vec![], - encryption_key: None, - attributes: vec![], - media_descriptions: vec![], - }, - reader, - }; - - let mut state = Some(StateFn { f: s1 }); - while let Some(s) = state { - state = (s.f)(&mut lexer)?; - } - - Ok(lexer.desc) - } -} - -impl From for String { - fn from(sdp: SessionDescription) -> String { - sdp.marshal() - } -} - -impl TryFrom for SessionDescription { - type Error = Error; - fn try_from(sdp_string: String) -> Result { - let mut reader = io::Cursor::new(sdp_string.as_bytes()); - let session_description = SessionDescription::unmarshal(&mut reader)?; - Ok(session_description) - } -} - -fn s1<'a, R: io::BufRead + io::Seek>(lexer: &mut Lexer<'a, R>) -> Result>> { - let (key, _) = read_type(lexer.reader)?; - if &key == b"v=" { - return Ok(Some(StateFn { - f: unmarshal_protocol_version, - })); - } - - Err(Error::SdpInvalidSyntax(String::from_utf8(key)?)) -} - -fn s2<'a, R: io::BufRead + io::Seek>(lexer: &mut Lexer<'a, R>) -> Result>> { - let (key, _) = read_type(lexer.reader)?; - if &key == b"o=" { - return Ok(Some(StateFn { - f: unmarshal_origin, - })); - } - - Err(Error::SdpInvalidSyntax(String::from_utf8(key)?)) -} - -fn s3<'a, R: io::BufRead + io::Seek>(lexer: &mut Lexer<'a, R>) -> Result>> { - let (key, _) = read_type(lexer.reader)?; - if &key == b"s=" { - return Ok(Some(StateFn { - f: unmarshal_session_name, - })); - } - - Err(Error::SdpInvalidSyntax(String::from_utf8(key)?)) -} - -fn s4<'a, R: io::BufRead + io::Seek>(lexer: &mut Lexer<'a, R>) -> Result>> { - let (key, _) = read_type(lexer.reader)?; - match key.as_slice() { - b"i=" => Ok(Some(StateFn { - f: unmarshal_session_information, - })), - b"u=" => Ok(Some(StateFn { f: unmarshal_uri })), - b"e=" => Ok(Some(StateFn { f: unmarshal_email })), - b"p=" => Ok(Some(StateFn { f: unmarshal_phone })), - b"c=" => Ok(Some(StateFn { - f: unmarshal_session_connection_information, - })), - b"b=" => 
Ok(Some(StateFn { - f: unmarshal_session_bandwidth, - })), - b"t=" => Ok(Some(StateFn { - f: unmarshal_timing, - })), - _ => Err(Error::SdpInvalidSyntax(String::from_utf8(key)?)), - } -} - -fn s5<'a, R: io::BufRead + io::Seek>(lexer: &mut Lexer<'a, R>) -> Result>> { - let (key, _) = read_type(lexer.reader)?; - match key.as_slice() { - b"b=" => Ok(Some(StateFn { - f: unmarshal_session_bandwidth, - })), - b"t=" => Ok(Some(StateFn { - f: unmarshal_timing, - })), - _ => Err(Error::SdpInvalidSyntax(String::from_utf8(key)?)), - } -} - -fn s6<'a, R: io::BufRead + io::Seek>(lexer: &mut Lexer<'a, R>) -> Result>> { - let (key, _) = read_type(lexer.reader)?; - match key.as_slice() { - b"p=" => Ok(Some(StateFn { f: unmarshal_phone })), - b"c=" => Ok(Some(StateFn { - f: unmarshal_session_connection_information, - })), - b"b=" => Ok(Some(StateFn { - f: unmarshal_session_bandwidth, - })), - b"t=" => Ok(Some(StateFn { - f: unmarshal_timing, - })), - _ => Err(Error::SdpInvalidSyntax(String::from_utf8(key)?)), - } -} - -fn s7<'a, R: io::BufRead + io::Seek>(lexer: &mut Lexer<'a, R>) -> Result>> { - let (key, _) = read_type(lexer.reader)?; - match key.as_slice() { - b"u=" => Ok(Some(StateFn { f: unmarshal_uri })), - b"e=" => Ok(Some(StateFn { f: unmarshal_email })), - b"p=" => Ok(Some(StateFn { f: unmarshal_phone })), - b"c=" => Ok(Some(StateFn { - f: unmarshal_session_connection_information, - })), - b"b=" => Ok(Some(StateFn { - f: unmarshal_session_bandwidth, - })), - b"t=" => Ok(Some(StateFn { - f: unmarshal_timing, - })), - _ => Err(Error::SdpInvalidSyntax(String::from_utf8(key)?)), - } -} - -fn s8<'a, R: io::BufRead + io::Seek>(lexer: &mut Lexer<'a, R>) -> Result>> { - let (key, _) = read_type(lexer.reader)?; - match key.as_slice() { - b"c=" => Ok(Some(StateFn { - f: unmarshal_session_connection_information, - })), - b"b=" => Ok(Some(StateFn { - f: unmarshal_session_bandwidth, - })), - b"t=" => Ok(Some(StateFn { - f: unmarshal_timing, - })), - _ => Err(Error::SdpInvalidSyntax(String::from_utf8(key)?)), - } -} - -fn s9<'a, R: io::BufRead + io::Seek>(lexer: &mut Lexer<'a, R>) -> Result>> { - let (key, num_bytes) = read_type(lexer.reader)?; - if key.is_empty() && num_bytes == 0 { - return Ok(None); - } - - match key.as_slice() { - b"z=" => Ok(Some(StateFn { - f: unmarshal_time_zones, - })), - b"k=" => Ok(Some(StateFn { - f: unmarshal_session_encryption_key, - })), - b"a=" => Ok(Some(StateFn { - f: unmarshal_session_attribute, - })), - b"r=" => Ok(Some(StateFn { - f: unmarshal_repeat_times, - })), - b"t=" => Ok(Some(StateFn { - f: unmarshal_timing, - })), - b"m=" => Ok(Some(StateFn { - f: unmarshal_media_description, - })), - _ => Err(Error::SdpInvalidSyntax(String::from_utf8(key)?)), - } -} - -fn s10<'a, R: io::BufRead + io::Seek>(lexer: &mut Lexer<'a, R>) -> Result>> { - let (key, _) = read_type(lexer.reader)?; - match key.as_slice() { - b"e=" => Ok(Some(StateFn { f: unmarshal_email })), - b"p=" => Ok(Some(StateFn { f: unmarshal_phone })), - b"c=" => Ok(Some(StateFn { - f: unmarshal_session_connection_information, - })), - b"b=" => Ok(Some(StateFn { - f: unmarshal_session_bandwidth, - })), - b"t=" => Ok(Some(StateFn { - f: unmarshal_timing, - })), - _ => Err(Error::SdpInvalidSyntax(String::from_utf8(key)?)), - } -} - -fn s11<'a, R: io::BufRead + io::Seek>(lexer: &mut Lexer<'a, R>) -> Result>> { - let (key, num_bytes) = read_type(lexer.reader)?; - if key.is_empty() && num_bytes == 0 { - return Ok(None); - } - - match key.as_slice() { - b"a=" => Ok(Some(StateFn { - f: unmarshal_session_attribute, - })), - 
b"m=" => Ok(Some(StateFn { - f: unmarshal_media_description, - })), - _ => Err(Error::SdpInvalidSyntax(String::from_utf8(key)?)), - } -} - -fn s12<'a, R: io::BufRead + io::Seek>(lexer: &mut Lexer<'a, R>) -> Result>> { - let (key, num_bytes) = read_type(lexer.reader)?; - if key.is_empty() && num_bytes == 0 { - return Ok(None); - } - - match key.as_slice() { - b"a=" => Ok(Some(StateFn { - f: unmarshal_media_attribute, - })), - b"k=" => Ok(Some(StateFn { - f: unmarshal_media_encryption_key, - })), - b"b=" => Ok(Some(StateFn { - f: unmarshal_media_bandwidth, - })), - b"c=" => Ok(Some(StateFn { - f: unmarshal_media_connection_information, - })), - b"i=" => Ok(Some(StateFn { - f: unmarshal_media_title, - })), - b"m=" => Ok(Some(StateFn { - f: unmarshal_media_description, - })), - _ => Err(Error::SdpInvalidSyntax(String::from_utf8(key)?)), - } -} - -fn s13<'a, R: io::BufRead + io::Seek>(lexer: &mut Lexer<'a, R>) -> Result>> { - let (key, num_bytes) = read_type(lexer.reader)?; - if key.is_empty() && num_bytes == 0 { - return Ok(None); - } - - match key.as_slice() { - b"a=" => Ok(Some(StateFn { - f: unmarshal_session_attribute, - })), - b"k=" => Ok(Some(StateFn { - f: unmarshal_session_encryption_key, - })), - b"m=" => Ok(Some(StateFn { - f: unmarshal_media_description, - })), - _ => Err(Error::SdpInvalidSyntax(String::from_utf8(key)?)), - } -} - -fn s14<'a, R: io::BufRead + io::Seek>(lexer: &mut Lexer<'a, R>) -> Result>> { - let (key, num_bytes) = read_type(lexer.reader)?; - if key.is_empty() && num_bytes == 0 { - return Ok(None); - } - - match key.as_slice() { - b"a=" => Ok(Some(StateFn { - f: unmarshal_media_attribute, - })), - // Non-spec ordering - b"k=" => Ok(Some(StateFn { - f: unmarshal_media_encryption_key, - })), - // Non-spec ordering - b"b=" => Ok(Some(StateFn { - f: unmarshal_media_bandwidth, - })), - // Non-spec ordering - b"c=" => Ok(Some(StateFn { - f: unmarshal_media_connection_information, - })), - // Non-spec ordering - b"i=" => Ok(Some(StateFn { - f: unmarshal_media_title, - })), - b"m=" => Ok(Some(StateFn { - f: unmarshal_media_description, - })), - _ => Err(Error::SdpInvalidSyntax(String::from_utf8(key)?)), - } -} - -fn s15<'a, R: io::BufRead + io::Seek>(lexer: &mut Lexer<'a, R>) -> Result>> { - let (key, num_bytes) = read_type(lexer.reader)?; - if key.is_empty() && num_bytes == 0 { - return Ok(None); - } - - match key.as_slice() { - b"a=" => Ok(Some(StateFn { - f: unmarshal_media_attribute, - })), - b"k=" => Ok(Some(StateFn { - f: unmarshal_media_encryption_key, - })), - b"b=" => Ok(Some(StateFn { - f: unmarshal_media_bandwidth, - })), - b"c=" => Ok(Some(StateFn { - f: unmarshal_media_connection_information, - })), - // Non-spec ordering - b"i=" => Ok(Some(StateFn { - f: unmarshal_media_title, - })), - b"m=" => Ok(Some(StateFn { - f: unmarshal_media_description, - })), - _ => Err(Error::SdpInvalidSyntax(String::from_utf8(key)?)), - } -} - -fn s16<'a, R: io::BufRead + io::Seek>(lexer: &mut Lexer<'a, R>) -> Result>> { - let (key, num_bytes) = read_type(lexer.reader)?; - if key.is_empty() && num_bytes == 0 { - return Ok(None); - } - - match key.as_slice() { - b"a=" => Ok(Some(StateFn { - f: unmarshal_media_attribute, - })), - b"k=" => Ok(Some(StateFn { - f: unmarshal_media_encryption_key, - })), - b"c=" => Ok(Some(StateFn { - f: unmarshal_media_connection_information, - })), - b"b=" => Ok(Some(StateFn { - f: unmarshal_media_bandwidth, - })), - // Non-spec ordering - b"i=" => Ok(Some(StateFn { - f: unmarshal_media_title, - })), - b"m=" => Ok(Some(StateFn { - f: 
unmarshal_media_description, - })), - _ => Err(Error::SdpInvalidSyntax(String::from_utf8(key)?)), - } -} - -fn unmarshal_protocol_version<'a, R: io::BufRead + io::Seek>( - lexer: &mut Lexer<'a, R>, -) -> Result>> { - let (value, _) = read_value(lexer.reader)?; - - let version = value.parse::()?; - - // As off the latest draft of the rfc this value is required to be 0. - // https://tools.ietf.org/html/draft-ietf-rtcweb-jsep-24#section-5.8.1 - if version != 0 { - return Err(Error::SdpInvalidSyntax(value)); - } - - Ok(Some(StateFn { f: s2 })) -} - -fn unmarshal_origin<'a, R: io::BufRead + io::Seek>( - lexer: &mut Lexer<'a, R>, -) -> Result>> { - let (value, _) = read_value(lexer.reader)?; - - let fields: Vec<&str> = value.split_whitespace().collect(); - if fields.len() != 6 { - return Err(Error::SdpInvalidSyntax(format!("`o={value}`"))); - } - - let session_id = fields[1].parse::()?; - let session_version = fields[2].parse::()?; - - // Set according to currently registered with IANA - // https://tools.ietf.org/html/rfc4566#section-8.2.6 - let i = index_of(fields[3], &["IN"]); - if i == -1 { - return Err(Error::SdpInvalidValue(fields[3].to_owned())); - } - - // Set according to currently registered with IANA - // https://tools.ietf.org/html/rfc4566#section-8.2.7 - let i = index_of(fields[4], &["IP4", "IP6"]); - if i == -1 { - return Err(Error::SdpInvalidValue(fields[4].to_owned())); - } - - // TODO validated UnicastAddress - - lexer.desc.origin = Origin { - username: fields[0].to_owned(), - session_id, - session_version, - network_type: fields[3].to_owned(), - address_type: fields[4].to_owned(), - unicast_address: fields[5].to_owned(), - }; - - Ok(Some(StateFn { f: s3 })) -} - -fn unmarshal_session_name<'a, R: io::BufRead + io::Seek>( - lexer: &mut Lexer<'a, R>, -) -> Result>> { - let (value, _) = read_value(lexer.reader)?; - lexer.desc.session_name = value; - Ok(Some(StateFn { f: s4 })) -} - -fn unmarshal_session_information<'a, R: io::BufRead + io::Seek>( - lexer: &mut Lexer<'a, R>, -) -> Result>> { - let (value, _) = read_value(lexer.reader)?; - lexer.desc.session_information = Some(value); - Ok(Some(StateFn { f: s7 })) -} - -fn unmarshal_uri<'a, R: io::BufRead + io::Seek>( - lexer: &mut Lexer<'a, R>, -) -> Result>> { - let (value, _) = read_value(lexer.reader)?; - lexer.desc.uri = Some(Url::parse(&value)?); - Ok(Some(StateFn { f: s10 })) -} - -fn unmarshal_email<'a, R: io::BufRead + io::Seek>( - lexer: &mut Lexer<'a, R>, -) -> Result>> { - let (value, _) = read_value(lexer.reader)?; - lexer.desc.email_address = Some(value); - Ok(Some(StateFn { f: s6 })) -} - -fn unmarshal_phone<'a, R: io::BufRead + io::Seek>( - lexer: &mut Lexer<'a, R>, -) -> Result>> { - let (value, _) = read_value(lexer.reader)?; - lexer.desc.phone_number = Some(value); - Ok(Some(StateFn { f: s8 })) -} - -fn unmarshal_session_connection_information<'a, R: io::BufRead + io::Seek>( - lexer: &mut Lexer<'a, R>, -) -> Result>> { - let (value, _) = read_value(lexer.reader)?; - lexer.desc.connection_information = unmarshal_connection_information(&value)?; - Ok(Some(StateFn { f: s5 })) -} - -fn unmarshal_connection_information(value: &str) -> Result> { - let fields: Vec<&str> = value.split_whitespace().collect(); - if fields.len() < 2 { - return Err(Error::SdpInvalidSyntax(format!("`c={value}`"))); - } - - // Set according to currently registered with IANA - // https://tools.ietf.org/html/rfc4566#section-8.2.6 - let i = index_of(fields[0], &["IN"]); - if i == -1 { - return Err(Error::SdpInvalidValue(fields[0].to_owned())); - } - 
- // Set according to currently registered with IANA - // https://tools.ietf.org/html/rfc4566#section-8.2.7 - let i = index_of(fields[1], &["IP4", "IP6"]); - if i == -1 { - return Err(Error::SdpInvalidValue(fields[1].to_owned())); - } - - let address = if fields.len() > 2 { - Some(Address { - address: fields[2].to_owned(), - ttl: None, - range: None, - }) - } else { - None - }; - - Ok(Some(ConnectionInformation { - network_type: fields[0].to_owned(), - address_type: fields[1].to_owned(), - address, - })) -} - -fn unmarshal_session_bandwidth<'a, R: io::BufRead + io::Seek>( - lexer: &mut Lexer<'a, R>, -) -> Result>> { - let (value, _) = read_value(lexer.reader)?; - lexer.desc.bandwidth.push(unmarshal_bandwidth(&value)?); - Ok(Some(StateFn { f: s5 })) -} - -fn unmarshal_bandwidth(value: &str) -> Result { - let mut parts: Vec<&str> = value.split(':').collect(); - if parts.len() != 2 { - return Err(Error::SdpInvalidSyntax(format!("`b={value}`"))); - } - - let experimental = parts[0].starts_with("X-"); - if experimental { - parts[0] = parts[0].trim_start_matches("X-"); - } else { - // Set according to currently registered with IANA - // https://tools.ietf.org/html/rfc4566#section-5.8 and - // https://datatracker.ietf.org/doc/html/rfc3890 - let i = index_of(parts[0], &["CT", "AS", "TIAS"]); - if i == -1 { - return Err(Error::SdpInvalidValue(parts[0].to_owned())); - } - } - - let bandwidth = parts[1].parse::()?; - - Ok(Bandwidth { - experimental, - bandwidth_type: parts[0].to_owned(), - bandwidth, - }) -} - -fn unmarshal_timing<'a, R: io::BufRead + io::Seek>( - lexer: &mut Lexer<'a, R>, -) -> Result>> { - let (value, _) = read_value(lexer.reader)?; - - let fields: Vec<&str> = value.split_whitespace().collect(); - if fields.len() < 2 { - return Err(Error::SdpInvalidSyntax(format!("`t={value}`"))); - } - - let start_time = fields[0].parse::()?; - let stop_time = fields[1].parse::()?; - - lexer.desc.time_descriptions.push(TimeDescription { - timing: Timing { - start_time, - stop_time, - }, - repeat_times: vec![], - }); - - Ok(Some(StateFn { f: s9 })) -} - -fn unmarshal_repeat_times<'a, R: io::BufRead + io::Seek>( - lexer: &mut Lexer<'a, R>, -) -> Result>> { - let (value, _) = read_value(lexer.reader)?; - - let fields: Vec<&str> = value.split_whitespace().collect(); - if fields.len() < 3 { - return Err(Error::SdpInvalidSyntax(format!("`r={value}`"))); - } - - if let Some(latest_time_desc) = lexer.desc.time_descriptions.last_mut() { - let interval = parse_time_units(fields[0])?; - let duration = parse_time_units(fields[1])?; - let mut offsets = vec![]; - for field in fields.iter().skip(2) { - let offset = parse_time_units(field)?; - offsets.push(offset); - } - latest_time_desc.repeat_times.push(RepeatTime { - interval, - duration, - offsets, - }); - - Ok(Some(StateFn { f: s9 })) - } else { - Err(Error::SdpEmptyTimeDescription) - } -} - -fn unmarshal_time_zones<'a, R: io::BufRead + io::Seek>( - lexer: &mut Lexer<'a, R>, -) -> Result>> { - let (value, _) = read_value(lexer.reader)?; - - // These fields are transimitted in pairs - // z= .... - // so we are making sure that there are actually multiple of 2 total. 
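For example, the RFC 4566 sample line `z=2882844526 -1h 2898848070 0` yields two `TimeZone` entries, with the shorthand suffix expanded by `parse_time_units` further down (`d`/`h`/`m`/`s` multiply by 86400/3600/60/1). A standalone sketch that mirrors that factor table rather than calling the private helper:

```rust
// Standalone sketch of the shorthand expansion performed by `parse_time_units`.
fn expand(value: &str) -> Option<i64> {
    let (num, factor) = match value.as_bytes().last()? {
        b'd' => (&value[..value.len() - 1], 86_400),
        b'h' => (&value[..value.len() - 1], 3_600),
        b'm' => (&value[..value.len() - 1], 60),
        b's' => (&value[..value.len() - 1], 1),
        _ => (value, 1),
    };
    num.parse::<i64>().ok()?.checked_mul(factor)
}

fn main() {
    // `z=2882844526 -1h 2898848070 0` => pairs (2882844526, -3600) and (2898848070, 0).
    assert_eq!(expand("-1h"), Some(-3_600));
    assert_eq!(expand("0"), Some(0));
}
```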
- let fields: Vec<&str> = value.split_whitespace().collect(); - if fields.len() % 2 != 0 { - return Err(Error::SdpInvalidSyntax(format!("`t={value}`"))); - } - - for i in (0..fields.len()).step_by(2) { - let adjustment_time = fields[i].parse::()?; - let offset = parse_time_units(fields[i + 1])?; - - lexer.desc.time_zones.push(TimeZone { - adjustment_time, - offset, - }); - } - - Ok(Some(StateFn { f: s13 })) -} - -fn unmarshal_session_encryption_key<'a, R: io::BufRead + io::Seek>( - lexer: &mut Lexer<'a, R>, -) -> Result>> { - let (value, _) = read_value(lexer.reader)?; - lexer.desc.encryption_key = Some(value); - Ok(Some(StateFn { f: s11 })) -} - -fn unmarshal_session_attribute<'a, R: io::BufRead + io::Seek>( - lexer: &mut Lexer<'a, R>, -) -> Result>> { - let (value, _) = read_value(lexer.reader)?; - - let fields: Vec<&str> = value.splitn(2, ':').collect(); - let attribute = if fields.len() == 2 { - Attribute { - key: fields[0].to_owned(), - value: Some(fields[1].to_owned()), - } - } else { - Attribute { - key: fields[0].to_owned(), - value: None, - } - }; - lexer.desc.attributes.push(attribute); - - Ok(Some(StateFn { f: s11 })) -} - -fn unmarshal_media_description<'a, R: io::BufRead + io::Seek>( - lexer: &mut Lexer<'a, R>, -) -> Result>> { - let (value, _) = read_value(lexer.reader)?; - - let fields: Vec<&str> = value.split_whitespace().collect(); - if fields.len() < 4 { - return Err(Error::SdpInvalidSyntax(format!("`m={value}`"))); - } - - // - // Set according to currently registered with IANA - // https://tools.ietf.org/html/rfc4566#section-5.14 - // including "image", registered here: - // https://datatracker.ietf.org/doc/html/rfc6466 - let i = index_of( - fields[0], - &["audio", "video", "text", "application", "message", "image"], - ); - if i == -1 { - return Err(Error::SdpInvalidValue(fields[0].to_owned())); - } - - // - let parts: Vec<&str> = fields[1].split('/').collect(); - let port_value = parts[0].parse::()? as isize; - let port_range = if parts.len() > 1 { - Some(parts[1].parse::()? as isize) - } else { - None - }; - - // - // Set according to currently registered with IANA - // https://tools.ietf.org/html/rfc4566#section-5.14 - let mut protos = vec![]; - for proto in fields[2].split('/').collect::>() { - let i = index_of( - proto, - &[ - "UDP", "RTP", "AVP", "SAVP", "SAVPF", "TLS", "DTLS", "SCTP", "AVPF", "udptl", - ], - ); - if i == -1 { - return Err(Error::SdpInvalidValue(fields[2].to_owned())); - } - protos.push(proto.to_owned()); - } - - // ... 
- let mut formats = vec![]; - for field in fields.iter().skip(3) { - formats.push(field.to_string()); - } - - lexer.desc.media_descriptions.push(MediaDescription { - media_name: MediaName { - media: fields[0].to_owned(), - port: RangedPort { - value: port_value, - range: port_range, - }, - protos, - formats, - }, - media_title: None, - connection_information: None, - bandwidth: vec![], - encryption_key: None, - attributes: vec![], - }); - - Ok(Some(StateFn { f: s12 })) -} - -fn unmarshal_media_title<'a, R: io::BufRead + io::Seek>( - lexer: &mut Lexer<'a, R>, -) -> Result>> { - let (value, _) = read_value(lexer.reader)?; - - if let Some(latest_media_desc) = lexer.desc.media_descriptions.last_mut() { - latest_media_desc.media_title = Some(value); - Ok(Some(StateFn { f: s16 })) - } else { - Err(Error::SdpEmptyTimeDescription) - } -} - -fn unmarshal_media_connection_information<'a, R: io::BufRead + io::Seek>( - lexer: &mut Lexer<'a, R>, -) -> Result>> { - let (value, _) = read_value(lexer.reader)?; - - if let Some(latest_media_desc) = lexer.desc.media_descriptions.last_mut() { - latest_media_desc.connection_information = unmarshal_connection_information(&value)?; - Ok(Some(StateFn { f: s15 })) - } else { - Err(Error::SdpEmptyTimeDescription) - } -} - -fn unmarshal_media_bandwidth<'a, R: io::BufRead + io::Seek>( - lexer: &mut Lexer<'a, R>, -) -> Result>> { - let (value, _) = read_value(lexer.reader)?; - - if let Some(latest_media_desc) = lexer.desc.media_descriptions.last_mut() { - let bandwidth = unmarshal_bandwidth(&value)?; - latest_media_desc.bandwidth.push(bandwidth); - Ok(Some(StateFn { f: s15 })) - } else { - Err(Error::SdpEmptyTimeDescription) - } -} - -fn unmarshal_media_encryption_key<'a, R: io::BufRead + io::Seek>( - lexer: &mut Lexer<'a, R>, -) -> Result>> { - let (value, _) = read_value(lexer.reader)?; - - if let Some(latest_media_desc) = lexer.desc.media_descriptions.last_mut() { - latest_media_desc.encryption_key = Some(value); - Ok(Some(StateFn { f: s14 })) - } else { - Err(Error::SdpEmptyTimeDescription) - } -} - -fn unmarshal_media_attribute<'a, R: io::BufRead + io::Seek>( - lexer: &mut Lexer<'a, R>, -) -> Result>> { - let (value, _) = read_value(lexer.reader)?; - - let fields: Vec<&str> = value.splitn(2, ':').collect(); - let attribute = if fields.len() == 2 { - Attribute { - key: fields[0].to_owned(), - value: Some(fields[1].to_owned()), - } - } else { - Attribute { - key: fields[0].to_owned(), - value: None, - } - }; - - if let Some(latest_media_desc) = lexer.desc.media_descriptions.last_mut() { - latest_media_desc.attributes.push(attribute); - Ok(Some(StateFn { f: s14 })) - } else { - Err(Error::SdpEmptyTimeDescription) - } -} - -fn parse_time_units(value: &str) -> Result { - // Some time offsets in the protocol can be provided with a shorthand - // notation. This code ensures to convert it to NTP timestamp format. - let val = value.as_bytes(); - let len = val.len(); - let (num, factor) = match val.last() { - Some(b'd') => (&value[..len - 1], 86400), // days - Some(b'h') => (&value[..len - 1], 3600), // hours - Some(b'm') => (&value[..len - 1], 60), // minutes - Some(b's') => (&value[..len - 1], 1), // seconds (allowed for completeness) - _ => (value, 1), - }; - num.parse::()? 
- .checked_mul(factor) - .ok_or_else(|| Error::SdpInvalidValue(value.to_owned())) -} diff --git a/sdp/src/direction/direction_test.rs b/sdp/src/direction/direction_test.rs deleted file mode 100644 index ff543f754..000000000 --- a/sdp/src/direction/direction_test.rs +++ /dev/null @@ -1,39 +0,0 @@ -use std::iter::Iterator; - -use super::*; - -#[test] -fn test_new_direction() { - let passingtests = [ - ("sendrecv", Direction::SendRecv), - ("sendonly", Direction::SendOnly), - ("recvonly", Direction::RecvOnly), - ("inactive", Direction::Inactive), - ]; - - let failingtests = ["", "notadirection"]; - - for (i, u) in passingtests.iter().enumerate() { - let dir = Direction::new(u.0); - assert!(u.1 == dir, "{}: {}", i, u.0); - } - for &u in failingtests.iter() { - let dir = Direction::new(u); - assert!(dir == Direction::Unspecified); - } -} - -#[test] -fn test_direction_string() { - let tests = [ - (Direction::Unspecified, DIRECTION_UNSPECIFIED_STR), - (Direction::SendRecv, "sendrecv"), - (Direction::SendOnly, "sendonly"), - (Direction::RecvOnly, "recvonly"), - (Direction::Inactive, "inactive"), - ]; - - for (i, u) in tests.iter().enumerate() { - assert!(u.1 == u.0.to_string(), "{}: {}", i, u.1); - } -} diff --git a/sdp/src/direction/mod.rs b/sdp/src/direction/mod.rs deleted file mode 100644 index 23e5e22e3..000000000 --- a/sdp/src/direction/mod.rs +++ /dev/null @@ -1,51 +0,0 @@ -use std::fmt; - -#[cfg(test)] -mod direction_test; - -/// Direction is a marker for transmission direction of an endpoint -#[derive(Default, Debug, PartialEq, Eq, Clone)] -pub enum Direction { - #[default] - Unspecified = 0, - /// Direction::SendRecv is for bidirectional communication - SendRecv = 1, - /// Direction::SendOnly is for outgoing communication - SendOnly = 2, - /// Direction::RecvOnly is for incoming communication - RecvOnly = 3, - /// Direction::Inactive is for no communication - Inactive = 4, -} - -const DIRECTION_SEND_RECV_STR: &str = "sendrecv"; -const DIRECTION_SEND_ONLY_STR: &str = "sendonly"; -const DIRECTION_RECV_ONLY_STR: &str = "recvonly"; -const DIRECTION_INACTIVE_STR: &str = "inactive"; -const DIRECTION_UNSPECIFIED_STR: &str = "Unspecified"; - -impl fmt::Display for Direction { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match self { - Direction::SendRecv => DIRECTION_SEND_RECV_STR, - Direction::SendOnly => DIRECTION_SEND_ONLY_STR, - Direction::RecvOnly => DIRECTION_RECV_ONLY_STR, - Direction::Inactive => DIRECTION_INACTIVE_STR, - _ => DIRECTION_UNSPECIFIED_STR, - }; - write!(f, "{s}") - } -} - -impl Direction { - /// new defines a procedure for creating a new direction from a raw string. 
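A tiny sketch of that mapping (per the match below, any unrecognised string falls back to `Direction::Unspecified`); the crate path is assumed to be `sdp`:

```rust
use sdp::direction::Direction;

fn main() {
    // Known attribute strings map to their variants...
    assert_eq!(Direction::new("sendonly"), Direction::SendOnly);
    assert_eq!(Direction::new("sendonly").to_string(), "sendonly");
    // ...and anything else is treated as Unspecified.
    assert_eq!(Direction::new("not-a-direction"), Direction::Unspecified);
}
```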
- pub fn new(raw: &str) -> Self { - match raw { - DIRECTION_SEND_RECV_STR => Direction::SendRecv, - DIRECTION_SEND_ONLY_STR => Direction::SendOnly, - DIRECTION_RECV_ONLY_STR => Direction::RecvOnly, - DIRECTION_INACTIVE_STR => Direction::Inactive, - _ => Direction::Unspecified, - } - } -} diff --git a/sdp/src/error.rs b/sdp/src/error.rs deleted file mode 100644 index 3a912dfa6..000000000 --- a/sdp/src/error.rs +++ /dev/null @@ -1,56 +0,0 @@ -use std::io; -use std::num::ParseIntError; -use std::string::FromUtf8Error; - -use substring::Substring; -use thiserror::Error; - -pub type Result = std::result::Result; - -#[derive(Debug, Error, PartialEq)] -#[non_exhaustive] -pub enum Error { - #[error("codec not found")] - CodecNotFound, - #[error("missing whitespace")] - MissingWhitespace, - #[error("missing colon")] - MissingColon, - #[error("payload type not found")] - PayloadTypeNotFound, - #[error("{0}")] - Io(#[source] IoError), - #[error("utf-8 error: {0}")] - Utf8(#[from] FromUtf8Error), - #[error("SdpInvalidSyntax: {0}")] - SdpInvalidSyntax(String), - #[error("SdpInvalidValue: {0}")] - SdpInvalidValue(String), - #[error("sdp: empty time_descriptions")] - SdpEmptyTimeDescription, - #[error("parse int: {0}")] - ParseInt(#[from] ParseIntError), - #[error("parse url: {0}")] - ParseUrl(#[from] url::ParseError), - #[error("parse extmap: {0}")] - ParseExtMap(String), - #[error("{} --> {} <-- {}", .s.substring(0,*.p), .s.substring(*.p, *.p+1), .s.substring(*.p+1, .s.len()))] - SyntaxError { s: String, p: usize }, -} - -#[derive(Debug, Error)] -#[error("io error: {0}")] -pub struct IoError(#[from] pub io::Error); - -// Workaround for wanting PartialEq for io::Error. -impl PartialEq for IoError { - fn eq(&self, other: &Self) -> bool { - self.0.kind() == other.0.kind() - } -} - -impl From for Error { - fn from(e: io::Error) -> Self { - Error::Io(IoError(e)) - } -} diff --git a/sdp/src/extmap/extmap_test.rs b/sdp/src/extmap/extmap_test.rs deleted file mode 100644 index e2c7c9e66..000000000 --- a/sdp/src/extmap/extmap_test.rs +++ /dev/null @@ -1,77 +0,0 @@ -use std::io::BufReader; -use std::iter::Iterator; - -use super::*; -use crate::lexer::END_LINE; -use crate::util::ATTRIBUTE_KEY; - -const EXAMPLE_ATTR_EXTMAP1: &str = "extmap:1 http://example.com/082005/ext.htm#ttime"; -const EXAMPLE_ATTR_EXTMAP2: &str = - "extmap:2/sendrecv http://example.com/082005/ext.htm#xmeta short"; -const FAILING_ATTR_EXTMAP1: &str = - "extmap:257/sendrecv http://example.com/082005/ext.htm#xmeta short"; -const FAILING_ATTR_EXTMAP2: &str = "extmap:2/blorg http://example.com/082005/ext.htm#xmeta short"; - -#[test] -fn test_extmap() -> Result<()> { - let example_attr_extmap1_line = EXAMPLE_ATTR_EXTMAP1; - let example_attr_extmap2_line = EXAMPLE_ATTR_EXTMAP2; - let failing_attr_extmap1_line = format!("{ATTRIBUTE_KEY}{FAILING_ATTR_EXTMAP1}{END_LINE}"); - let failing_attr_extmap2_line = format!("{ATTRIBUTE_KEY}{FAILING_ATTR_EXTMAP2}{END_LINE}"); - let passingtests = [ - (EXAMPLE_ATTR_EXTMAP1, example_attr_extmap1_line), - (EXAMPLE_ATTR_EXTMAP2, example_attr_extmap2_line), - ]; - let failingtests = vec![ - (FAILING_ATTR_EXTMAP1, failing_attr_extmap1_line), - (FAILING_ATTR_EXTMAP2, failing_attr_extmap2_line), - ]; - - for (i, u) in passingtests.iter().enumerate() { - let mut reader = BufReader::new(u.1.as_bytes()); - let actual = ExtMap::unmarshal(&mut reader)?; - assert_eq!( - actual.marshal(), - u.1, - "{}: {} vs {}", - i, - u.1, - actual.marshal() - ); - } - - for u in failingtests { - let mut reader = 
BufReader::new(u.1.as_bytes()); - let actual = ExtMap::unmarshal(&mut reader); - assert!(actual.is_err()); - } - - Ok(()) -} - -#[test] -fn test_transport_cc_extmap() -> Result<()> { - // a=extmap:["/"] - // a=extmap:3 http://www.ietf.org/id/draft-holmer-rmcat-transport-wide-cc-extensions-01 - let uri = Some(Url::parse( - "http://www.ietf.org/id/draft-holmer-rmcat-transport-wide-cc-extensions-01", - )?); - let e = ExtMap { - value: 3, - uri, - direction: Direction::Unspecified, - ext_attr: None, - }; - - let s = e.marshal(); - if s == "3 http://www.ietf.org/id/draft-holmer-rmcat-transport-wide-cc-extensions-01" { - panic!("TestTransportCC failed"); - } else { - assert_eq!( - s, - "extmap:3 http://www.ietf.org/id/draft-holmer-rmcat-transport-wide-cc-extensions-01" - ) - } - - Ok(()) -} diff --git a/sdp/src/extmap/mod.rs b/sdp/src/extmap/mod.rs deleted file mode 100644 index e304e71cd..000000000 --- a/sdp/src/extmap/mod.rs +++ /dev/null @@ -1,120 +0,0 @@ -#[cfg(test)] -mod extmap_test; - -use std::{fmt, io}; - -use url::Url; - -use super::direction::*; -use super::error::{Error, Result}; -use crate::description::common::*; - -/// Default ext values -pub const DEF_EXT_MAP_VALUE_ABS_SEND_TIME: usize = 1; -pub const DEF_EXT_MAP_VALUE_TRANSPORT_CC: usize = 2; -pub const DEF_EXT_MAP_VALUE_SDES_MID: usize = 3; -pub const DEF_EXT_MAP_VALUE_SDES_RTP_STREAM_ID: usize = 4; - -pub const ABS_SEND_TIME_URI: &str = "http://www.webrtc.org/experiments/rtp-hdrext/abs-send-time"; -pub const TRANSPORT_CC_URI: &str = - "http://www.ietf.org/id/draft-holmer-rmcat-transport-wide-cc-extensions-01"; -pub const SDES_MID_URI: &str = "urn:ietf:params:rtp-hdrext:sdes:mid"; -pub const SDES_RTP_STREAM_ID_URI: &str = "urn:ietf:params:rtp-hdrext:sdes:rtp-stream-id"; -pub const SDES_REPAIR_RTP_STREAM_ID_URI: &str = - "urn:ietf:params:rtp-hdrext:sdes:repaired-rtp-stream-id"; - -pub const AUDIO_LEVEL_URI: &str = "urn:ietf:params:rtp-hdrext:ssrc-audio-level"; -pub const VIDEO_ORIENTATION_URI: &str = "urn:3gpp:video-orientation"; - -/// ExtMap represents the activation of a single RTP header extension -#[derive(Debug, Clone, Default)] -pub struct ExtMap { - pub value: isize, - pub direction: Direction, - pub uri: Option, - pub ext_attr: Option, -} - -impl fmt::Display for ExtMap { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut output = format!("{}", self.value); - if self.direction != Direction::Unspecified { - output += format!("/{}", self.direction).as_str(); - } - - if let Some(uri) = &self.uri { - output += format!(" {uri}").as_str(); - } - - if let Some(ext_attr) = &self.ext_attr { - output += format!(" {ext_attr}").as_str(); - } - - write!(f, "{output}") - } -} - -impl ExtMap { - /// converts this object to an Attribute - pub fn convert(&self) -> Attribute { - Attribute { - key: "extmap".to_string(), - value: Some(self.to_string()), - } - } - - /// unmarshal creates an Extmap from a string - pub fn unmarshal(reader: &mut R) -> Result { - let mut line = String::new(); - reader.read_line(&mut line)?; - let parts: Vec<&str> = line.trim().splitn(2, ':').collect(); - if parts.len() != 2 { - return Err(Error::ParseExtMap(line)); - } - - let fields: Vec<&str> = parts[1].split_whitespace().collect(); - if fields.len() < 2 { - return Err(Error::ParseExtMap(line)); - } - - let valdir: Vec<&str> = fields[0].split('/').collect(); - let value = valdir[0].parse::()?; - if !(1..=246).contains(&value) { - return Err(Error::ParseExtMap(format!( - "{} -- extmap key must be in the range 1-256", - valdir[0] - ))); - } 
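For reference, the second passing test case above would come out of this parser roughly as follows (a sketch assuming the crate root is `sdp`):

```rust
use std::io::BufReader;
use sdp::extmap::ExtMap;

fn main() -> Result<(), sdp::Error> {
    // Same shape as EXAMPLE_ATTR_EXTMAP2 in the tests above.
    let line = "extmap:2/sendrecv http://example.com/082005/ext.htm#xmeta short";
    let ext = ExtMap::unmarshal(&mut BufReader::new(line.as_bytes()))?;

    assert_eq!(ext.value, 2);                            // <value> before the optional "/<direction>"
    assert_eq!(ext.ext_attr.as_deref(), Some("short"));  // trailing extension attribute
    assert_eq!(ext.marshal(), line);                     // round-trips back to the attribute value
    Ok(())
}
```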
- - let mut direction = Direction::Unspecified; - if valdir.len() == 2 { - direction = Direction::new(valdir[1]); - if direction == Direction::Unspecified { - return Err(Error::ParseExtMap(format!( - "unknown direction from {}", - valdir[1] - ))); - } - } - - let uri = Some(Url::parse(fields[1])?); - - let ext_attr = if fields.len() == 3 { - Some(fields[2].to_owned()) - } else { - None - }; - - Ok(ExtMap { - value, - direction, - uri, - ext_attr, - }) - } - - /// marshal creates a string from an ExtMap - pub fn marshal(&self) -> String { - "extmap:".to_string() + self.to_string().as_str() - } -} diff --git a/sdp/src/lexer/mod.rs b/sdp/src/lexer/mod.rs deleted file mode 100644 index a65a387a9..000000000 --- a/sdp/src/lexer/mod.rs +++ /dev/null @@ -1,66 +0,0 @@ -use std::io; -use std::io::SeekFrom; - -use super::description::session::SessionDescription; -use super::error::{Error, Result}; - -pub(crate) const END_LINE: &str = "\r\n"; - -pub struct Lexer<'a, R: io::BufRead + io::Seek> { - pub desc: SessionDescription, - pub reader: &'a mut R, -} - -pub type StateFnType<'a, R> = fn(&mut Lexer<'a, R>) -> Result>>; - -pub struct StateFn<'a, R: io::BufRead + io::Seek> { - pub f: StateFnType<'a, R>, -} - -pub fn read_type(reader: &mut R) -> Result<(Vec, usize)> { - let mut b = [0; 1]; - - loop { - if reader.read_exact(&mut b).is_err() { - return Ok((b"".to_vec(), 0)); - } - - if b[0] == b'\n' || b[0] == b'\r' { - continue; - } - reader.seek(SeekFrom::Current(-1))?; - - let mut buf = Vec::with_capacity(2); - let num_bytes = reader.read_until(b'=', &mut buf)?; - if num_bytes == 0 { - return Ok((b"".to_vec(), num_bytes)); - } - match buf.len() { - 2 => return Ok((buf, num_bytes)), - _ => return Err(Error::SdpInvalidSyntax(String::from_utf8(buf)?)), - } - } -} - -pub fn read_value(reader: &mut R) -> Result<(String, usize)> { - let mut value = String::new(); - let num_bytes = reader.read_line(&mut value)?; - Ok((value.trim().to_string(), num_bytes)) -} - -pub fn index_of(element: &str, data: &[&str]) -> i32 { - for (k, &v) in data.iter().enumerate() { - if element == v { - return k as i32; - } - } - -1 -} - -pub fn key_value_build(key: &str, value: Option<&String>) -> String { - if let Some(val) = value { - format!("{key}{val}{END_LINE}") - } else { - "".to_string() - } -} diff --git a/sdp/src/lib.rs b/sdp/src/lib.rs deleted file mode 100644 index f64a5b671..000000000 --- a/sdp/src/lib.rs +++ /dev/null @@ -1,14 +0,0 @@ -#![warn(rust_2018_idioms)] -#![allow(dead_code)] - -pub mod description; -pub mod direction; -pub mod extmap; -pub mod util; - -mod error; -pub(crate) mod lexer; - -pub use description::media::MediaDescription; -pub use description::session::SessionDescription; -pub use error::Error; diff --git a/sdp/src/util/mod.rs b/sdp/src/util/mod.rs deleted file mode 100644 index 38f146027..000000000 --- a/sdp/src/util/mod.rs +++ /dev/null @@ -1,246 +0,0 @@ -#[cfg(test)] -mod util_test; - -use std::collections::HashMap; -use std::fmt; - -use super::error::{Error, Result}; - -pub const ATTRIBUTE_KEY: &str = "a="; - -/// ConnectionRole indicates which of the end points should initiate the connection establishment -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] -pub enum ConnectionRole { - #[default] - Unspecified, - - /// ConnectionRoleActive indicates the endpoint will initiate an outgoing connection. - Active, - - /// ConnectionRolePassive indicates the endpoint will accept an incoming connection. 
- Passive, - - /// ConnectionRoleActpass indicates the endpoint is willing to accept an incoming connection or to initiate an outgoing connection. - Actpass, - - /// ConnectionRoleHoldconn indicates the endpoint does not want the connection to be established for the time being. - Holdconn, -} - -const CONNECTION_ROLE_ACTIVE_STR: &str = "active"; -const CONNECTION_ROLE_PASSIVE_STR: &str = "passive"; -const CONNECTION_ROLE_ACTPASS_STR: &str = "actpass"; -const CONNECTION_ROLE_HOLDCONN_STR: &str = "holdconn"; - -impl fmt::Display for ConnectionRole { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match self { - ConnectionRole::Active => CONNECTION_ROLE_ACTIVE_STR, - ConnectionRole::Passive => CONNECTION_ROLE_PASSIVE_STR, - ConnectionRole::Actpass => CONNECTION_ROLE_ACTPASS_STR, - ConnectionRole::Holdconn => CONNECTION_ROLE_HOLDCONN_STR, - _ => "Unspecified", - }; - write!(f, "{s}") - } -} - -impl From for ConnectionRole { - fn from(v: u8) -> Self { - match v { - 1 => ConnectionRole::Active, - 2 => ConnectionRole::Passive, - 3 => ConnectionRole::Actpass, - 4 => ConnectionRole::Holdconn, - _ => ConnectionRole::Unspecified, - } - } -} - -impl From<&str> for ConnectionRole { - fn from(raw: &str) -> Self { - match raw { - CONNECTION_ROLE_ACTIVE_STR => ConnectionRole::Active, - CONNECTION_ROLE_PASSIVE_STR => ConnectionRole::Passive, - CONNECTION_ROLE_ACTPASS_STR => ConnectionRole::Actpass, - CONNECTION_ROLE_HOLDCONN_STR => ConnectionRole::Holdconn, - _ => ConnectionRole::Unspecified, - } - } -} - -/// https://tools.ietf.org/html/draft-ietf-rtcweb-jsep-26#section-5.2.1 -/// Session ID is recommended to be constructed by generating a 64-bit -/// quantity with the highest bit set to zero and the remaining 63-bits -/// being cryptographically random. -pub(crate) fn new_session_id() -> u64 { - let c = u64::MAX ^ (1u64 << 63); - rand::random::() & c -} - -// Codec represents a codec -#[derive(Debug, Clone, Default, PartialEq, Eq)] -pub struct Codec { - pub payload_type: u8, - pub name: String, - pub clock_rate: u32, - pub encoding_parameters: String, - pub fmtp: String, - pub rtcp_feedback: Vec, -} - -impl fmt::Display for Codec { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "{} {}/{}/{} ({}) [{}]", - self.payload_type, - self.name, - self.clock_rate, - self.encoding_parameters, - self.fmtp, - self.rtcp_feedback.join(", "), - ) - } -} - -pub(crate) fn parse_rtpmap(rtpmap: &str) -> Result { - // a=rtpmap: /[/] - let split: Vec<&str> = rtpmap.split_whitespace().collect(); - if split.len() != 2 { - return Err(Error::MissingWhitespace); - } - - let pt_split: Vec<&str> = split[0].split(':').collect(); - if pt_split.len() != 2 { - return Err(Error::MissingColon); - } - let payload_type = pt_split[1].parse::()?; - - let split: Vec<&str> = split[1].split('/').collect(); - let name = split[0].to_string(); - let parts = split.len(); - let clock_rate = if parts > 1 { - split[1].parse::()? 
- } else { - 0 - }; - let encoding_parameters = if parts > 2 { - split[2].to_string() - } else { - "".to_string() - }; - - Ok(Codec { - payload_type, - name, - clock_rate, - encoding_parameters, - ..Default::default() - }) -} - -pub(crate) fn parse_fmtp(fmtp: &str) -> Result { - // a=fmtp: - let split: Vec<&str> = fmtp.split_whitespace().collect(); - if split.len() != 2 { - return Err(Error::MissingWhitespace); - } - - let fmtp = split[1].to_string(); - - let split: Vec<&str> = split[0].split(':').collect(); - if split.len() != 2 { - return Err(Error::MissingColon); - } - let payload_type = split[1].parse::()?; - - Ok(Codec { - payload_type, - fmtp, - ..Default::default() - }) -} - -pub(crate) fn parse_rtcp_fb(rtcp_fb: &str) -> Result { - // a=ftcp-fb: [] - let split: Vec<&str> = rtcp_fb.splitn(2, ' ').collect(); - if split.len() != 2 { - return Err(Error::MissingWhitespace); - } - - let pt_split: Vec<&str> = split[0].split(':').collect(); - if pt_split.len() != 2 { - return Err(Error::MissingColon); - } - - Ok(Codec { - payload_type: pt_split[1].parse::()?, - rtcp_feedback: vec![split[1].to_string()], - ..Default::default() - }) -} - -pub(crate) fn merge_codecs(mut codec: Codec, codecs: &mut HashMap) { - if let Some(saved_codec) = codecs.get_mut(&codec.payload_type) { - if saved_codec.payload_type == 0 { - saved_codec.payload_type = codec.payload_type - } - if saved_codec.name.is_empty() { - saved_codec.name = codec.name - } - if saved_codec.clock_rate == 0 { - saved_codec.clock_rate = codec.clock_rate - } - if saved_codec.encoding_parameters.is_empty() { - saved_codec.encoding_parameters = codec.encoding_parameters - } - if saved_codec.fmtp.is_empty() { - saved_codec.fmtp = codec.fmtp - } - saved_codec.rtcp_feedback.append(&mut codec.rtcp_feedback); - } else { - codecs.insert(codec.payload_type, codec); - } -} - -fn equivalent_fmtp(want: &str, got: &str) -> bool { - let mut want_split: Vec<&str> = want.split(';').collect(); - let mut got_split: Vec<&str> = got.split(';').collect(); - - if want_split.len() != got_split.len() { - return false; - } - - want_split.sort_unstable(); - got_split.sort_unstable(); - - for (i, &want_part) in want_split.iter().enumerate() { - let want_part = want_part.trim(); - let got_part = got_split[i].trim(); - if got_part != want_part { - return false; - } - } - - true -} - -pub(crate) fn codecs_match(wanted: &Codec, got: &Codec) -> bool { - if !wanted.name.is_empty() && wanted.name.to_lowercase() != got.name.to_lowercase() { - return false; - } - if wanted.clock_rate != 0 && wanted.clock_rate != got.clock_rate { - return false; - } - if !wanted.encoding_parameters.is_empty() - && wanted.encoding_parameters != got.encoding_parameters - { - return false; - } - if !wanted.fmtp.is_empty() && !equivalent_fmtp(&wanted.fmtp, &got.fmtp) { - return false; - } - - true -} diff --git a/sdp/src/util/util_test.rs b/sdp/src/util/util_test.rs deleted file mode 100644 index 33d36f7a1..000000000 --- a/sdp/src/util/util_test.rs +++ /dev/null @@ -1,177 +0,0 @@ -use super::*; -use crate::description::common::*; -use crate::description::media::*; -use crate::description::session::*; - -fn get_test_session_description() -> SessionDescription { - SessionDescription{ - media_descriptions: vec![ - MediaDescription { - media_name: MediaName { - media: "video".to_string(), - port: RangedPort { - value: 51372, - range: None, - }, - protos: vec!["RTP".to_string(), "AVP".to_string()], - formats: vec!["120".to_string(), "121".to_string(), "126".to_string(), "97".to_string()], - }, - 
attributes: vec![ - Attribute::new("fmtp:126 profile-level-id=42e01f;level-asymmetry-allowed=1;packetization-mode=1".to_string(), None), - Attribute::new("fmtp:97 profile-level-id=42e01f;level-asymmetry-allowed=1".to_string(), None), - Attribute::new("fmtp:120 max-fs=12288;max-fr=60".to_string(), None), - Attribute::new("fmtp:121 max-fs=12288;max-fr=60".to_string(), None), - Attribute::new("rtpmap:120 VP8/90000".to_string(), None), - Attribute::new("rtpmap:121 VP9/90000".to_string(), None), - Attribute::new("rtpmap:126 H264/90000".to_string(), None), - Attribute::new("rtpmap:97 H264/90000".to_string(), None), - Attribute::new("rtcp-fb:97 ccm fir".to_string(), None), - Attribute::new("rtcp-fb:97 nack".to_string(), None), - Attribute::new("rtcp-fb:97 nack pli".to_string(), None), - ], - ..Default::default() - }, - ], - ..Default::default() - } -} - -#[test] -fn test_get_payload_type_for_vp8() -> Result<()> { - let tests = vec![ - ( - Codec { - name: "VP8".to_string(), - ..Default::default() - }, - 120, - ), - ( - Codec { - name: "VP9".to_string(), - ..Default::default() - }, - 121, - ), - ( - Codec { - name: "H264".to_string(), - fmtp: "profile-level-id=42e01f;level-asymmetry-allowed=1".to_string(), - ..Default::default() - }, - 97, - ), - ( - Codec { - name: "H264".to_string(), - fmtp: "level-asymmetry-allowed=1;profile-level-id=42e01f".to_string(), - ..Default::default() - }, - 97, - ), - ( - Codec { - name: "H264".to_string(), - fmtp: "profile-level-id=42e01f;level-asymmetry-allowed=1;packetization-mode=1" - .to_string(), - ..Default::default() - }, - 126, - ), - ]; - - for (codec, expected) in tests { - let sdp = get_test_session_description(); - let actual = sdp.get_payload_type_for_codec(&codec)?; - assert_eq!(actual, expected); - } - - Ok(()) -} - -#[test] -fn test_get_codec_for_payload_type() -> Result<()> { - let tests: Vec<(u8, Codec)> = vec![ - ( - 120, - Codec { - payload_type: 120, - name: "VP8".to_string(), - clock_rate: 90000, - fmtp: "max-fs=12288;max-fr=60".to_string(), - ..Default::default() - }, - ), - ( - 121, - Codec { - payload_type: 121, - name: "VP9".to_string(), - clock_rate: 90000, - fmtp: "max-fs=12288;max-fr=60".to_string(), - ..Default::default() - }, - ), - ( - 126, - Codec { - payload_type: 126, - name: "H264".to_string(), - clock_rate: 90000, - fmtp: "profile-level-id=42e01f;level-asymmetry-allowed=1;packetization-mode=1" - .to_string(), - ..Default::default() - }, - ), - ( - 97, - Codec { - payload_type: 97, - name: "H264".to_string(), - clock_rate: 90000, - fmtp: "profile-level-id=42e01f;level-asymmetry-allowed=1".to_string(), - rtcp_feedback: vec![ - "ccm fir".to_string(), - "nack".to_string(), - "nack pli".to_string(), - ], - ..Default::default() - }, - ), - ]; - - for (payload_type, expected) in &tests { - let sdp = get_test_session_description(); - let actual = sdp.get_codec_for_payload_type(*payload_type)?; - assert_eq!(actual, *expected); - } - - Ok(()) -} - -#[test] -fn test_new_session_id() -> Result<()> { - let mut min = 0x7FFFFFFFFFFFFFFFu64; - let mut max = 0u64; - for _ in 0..10000 { - let r = new_session_id(); - - if r > (1 << 63) - 1 { - panic!("Session ID must be less than 2**64-1, got {r}") - } - if r < min { - min = r - } - if r > max { - max = r - } - } - if min > 0x1000000000000000 { - panic!("Value around lower boundary was not generated") - } - if max < 0x7000000000000000 { - panic!("Value around upper boundary was not generated") - } - - Ok(()) -} diff --git a/src/allocation/allocation_manager.rs b/src/allocation/allocation_manager.rs 
new file mode 100644
index 000000000..987e347b5
--- /dev/null
+++ b/src/allocation/allocation_manager.rs
@@ -0,0 +1,625 @@
+//! [Allocation]s storage.
+//!
+//! [Allocation]: https://datatracker.ietf.org/doc/html/rfc5766#section-5
+
+use std::{
+    collections::HashMap,
+    mem,
+    sync::{atomic::Ordering, Arc, Mutex as SyncMutex},
+    time::Duration,
+};
+
+use futures::future;
+use tokio::{
+    sync::{mpsc, Mutex},
+    time::sleep,
+};
+
+use crate::{
+    allocation::{Allocation, AllocationMap},
+    attr::Username,
+    con::Conn,
+    relay::RelayAllocator,
+    AllocInfo, Error, FiveTuple,
+};
+
+/// [`ManagerConfig`] is a bag of configuration parameters for a [`Manager`].
+pub(crate) struct ManagerConfig {
+    /// Relay connections allocator.
+    pub(crate) relay_addr_generator: RelayAllocator,
+
+    /// Injected into allocations to notify when allocation is closed.
+    pub(crate) alloc_close_notify: Option<mpsc::Sender<AllocInfo>>,
+}
+
+/// [`Manager`] is used to hold active allocations.
+pub(crate) struct Manager {
+    /// [`Allocation`]s storage.
+    allocations: AllocationMap,
+
+    /// [Reservation][1]s storage.
+    ///
+    /// [1]: https://datatracker.ietf.org/doc/html/rfc5766#section-14.9
+    reservations: Arc<Mutex<HashMap<u64, u16>>>,
+
+    /// Relay connections allocator.
+    relay_allocator: RelayAllocator,
+
+    /// Injected into allocations to notify when allocation is closed.
+    alloc_close_notify: Option<mpsc::Sender<AllocInfo>>,
+}
+
+impl Manager {
+    /// Creates a new [`Manager`].
+    pub(crate) fn new(config: ManagerConfig) -> Self {
+        Self {
+            allocations: Arc::new(SyncMutex::new(HashMap::new())),
+            reservations: Arc::new(Mutex::new(HashMap::new())),
+            relay_allocator: config.relay_addr_generator,
+            alloc_close_notify: config.alloc_close_notify,
+        }
+    }
+
+    /// Returns the information about all the [`Allocation`]s associated with
+    /// the specified [`FiveTuple`]s.
+    pub(crate) fn get_allocations_info(
+        &self,
+        five_tuples: &Option<Vec<FiveTuple>>,
+    ) -> HashMap<FiveTuple, AllocInfo> {
+        let mut infos = HashMap::new();
+
+        #[allow(
+            clippy::unwrap_used,
+            clippy::iter_over_hash_type,
+            clippy::significant_drop_in_scrutinee
+        )]
+        for (five_tuple, alloc) in self.allocations.lock().unwrap().iter() {
+            #[allow(clippy::unwrap_used)]
+            if five_tuples.is_none()
+                || five_tuples.as_ref().unwrap().contains(five_tuple)
+            {
+                drop(infos.insert(
+                    *five_tuple,
+                    AllocInfo::new(
+                        *five_tuple,
+                        alloc.username.name().to_owned(),
+                        alloc.relayed_bytes.load(Ordering::Acquire),
+                    ),
+                ));
+            }
+        }
+
+        infos
+    }
+
+    /// Checks whether an [`Allocation`] exists for the passed [`FiveTuple`].
+    pub(crate) fn has_alloc(&self, five_tuple: &FiveTuple) -> bool {
+        #[allow(clippy::unwrap_used)]
+        self.allocations.lock().unwrap().get(five_tuple).is_some()
+    }
+
+    /// Fetches the [`Allocation`] matching the passed [`FiveTuple`].
+    #[allow(clippy::unwrap_in_result)]
+    pub(crate) fn get_alloc(
+        &self,
+        five_tuple: &FiveTuple,
+    ) -> Option<Arc<Allocation>> {
+        #[allow(clippy::unwrap_used)]
+        self.allocations.lock().unwrap().get(five_tuple).cloned()
+    }
+
+    /// Creates a new [`Allocation`] and starts relaying.
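// A minimal sketch of the storage pattern `Manager` uses above: a map keyed
// by the five-tuple, shared behind `Arc<Mutex<_>>` between the manager and
// the spawned lifetime tasks. `Key`, `Entry`, `Registry` and `try_insert`
// are illustrative stand-ins for `FiveTuple`, `Arc<Allocation>` and
// `AllocationMap`, not types or APIs from this crate.
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};

type Key = (u8, SocketAddr, SocketAddr); // (protocol, src, dst)

struct Entry {
    username: String,
    relayed_bytes: usize,
}

type Registry = Arc<Mutex<HashMap<Key, Arc<Entry>>>>;

// Mirrors the duplicate-five-tuple check in `create_allocation` below: an
// existing entry for the same key is an error, never silently replaced.
fn try_insert(
    reg: &Registry,
    key: Key,
    entry: Entry,
) -> Result<Arc<Entry>, &'static str> {
    let mut map = reg.lock().unwrap();
    if map.contains_key(&key) {
        return Err("allocation already exists for this five-tuple");
    }
    let entry = Arc::new(entry);
    map.insert(key, Arc::clone(&entry));
    Ok(entry)
}

fn main() {
    let reg: Registry = Registry::default();
    let key: Key = (
        17, // UDP
        "10.0.0.1:5000".parse().unwrap(),
        "10.0.0.2:3478".parse().unwrap(),
    );

    let entry = try_insert(&reg, key, Entry { username: "user".into(), relayed_bytes: 0 })
        .expect("first allocation for this five-tuple");
    println!("{} has relayed {} bytes", entry.username, entry.relayed_bytes);

    // A second allocation for the same five-tuple is rejected.
    assert!(try_insert(&reg, key, Entry { username: "user".into(), relayed_bytes: 0 }).is_err());
}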
+ #[allow(clippy::too_many_arguments)] + pub(crate) async fn create_allocation( + &self, + five_tuple: FiveTuple, + turn_socket: Arc, + requested_port: u16, + lifetime: Duration, + username: Username, + use_ipv4: bool, + ) -> Result, Error> { + if lifetime == Duration::from_secs(0) { + return Err(Error::LifetimeZero); + } + + if self.has_alloc(&five_tuple) { + return Err(Error::DupeFiveTuple); + } + + let (relay_socket, relay_addr) = self + .relay_allocator + .allocate_conn(use_ipv4, requested_port) + .await?; + let mut a = Allocation::new( + turn_socket, + relay_socket, + relay_addr, + five_tuple, + username, + self.alloc_close_notify.clone(), + ); + a.allocations = Some(Arc::clone(&self.allocations)); + + log::trace!("listening on relay addr: {:?}", a.relay_addr); + a.start(lifetime); + a.packet_handler(); + + let a = Arc::new(a); + #[allow(clippy::unwrap_used)] + drop( + self.allocations.lock().unwrap().insert(five_tuple, Arc::clone(&a)), + ); + + Ok(a) + } + + /// Removes an [`Allocation`]. + pub(crate) async fn delete_allocation(&self, five_tuple: &FiveTuple) { + #[allow(clippy::unwrap_used)] + let allocation = self.allocations.lock().unwrap().remove(five_tuple); + + if let Some(a) = allocation { + if let Err(err) = a.close().await { + log::error!("Failed to close allocation: {}", err); + } + } + } + + /// Deletes the [`Allocation`]s according to the specified username `name`. + pub(crate) async fn delete_allocations_by_username(&self, name: &str) { + let to_delete = { + #[allow(clippy::unwrap_used)] + let mut allocations = self.allocations.lock().unwrap(); + + let mut to_delete = Vec::new(); + + // TODO(logist322): Use `.drain_filter()` once stabilized. + allocations.retain(|_, allocation| { + let match_name = allocation.username.name() == name; + + if match_name { + to_delete.push(Arc::clone(allocation)); + } + + !match_name + }); + + to_delete + }; + + drop( + future::join_all(to_delete.iter().map(|a| async move { + if let Err(err) = a.close().await { + log::error!("Failed to close allocation: {}", err); + } + })) + .await, + ); + } + + /// Stores the reservation for the token+port. + pub(crate) async fn create_reservation(&self, token: u64, port: u16) { + let reservations = Arc::clone(&self.reservations); + + drop(tokio::spawn(async move { + let liftime = sleep(Duration::from_secs(30)); + tokio::pin!(liftime); + + tokio::select! { + () = &mut liftime => { + _ = reservations.lock().await.remove(&token); + }, + } + })); + + _ = self.reservations.lock().await.insert(token, port); + } + + /// Returns a random un-allocated udp4 port. + pub(crate) async fn get_random_even_port(&self) -> Result { + let (_, addr) = self.relay_allocator.allocate_conn(true, 0).await?; + Ok(addr.port()) + } + + /// Closes this [`Manager`] and closes all [`Allocation`]s it manages. 
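// A minimal sketch of the reservation-expiry pattern `create_reservation`
// uses above: store a token -> port entry and spawn a task that removes it
// again after a fixed timeout. It assumes only tokio; the token and port
// values are arbitrary examples.
use std::{collections::HashMap, sync::Arc, time::Duration};
use tokio::{sync::Mutex, time::sleep};

async fn reserve(map: Arc<Mutex<HashMap<u64, u16>>>, token: u64, port: u16) {
    map.lock().await.insert(token, port);

    tokio::spawn(async move {
        // After the timeout the reservation is no longer honoured.
        sleep(Duration::from_secs(30)).await;
        map.lock().await.remove(&token);
    });
}

#[tokio::main]
async fn main() {
    let reservations = Arc::new(Mutex::new(HashMap::new()));
    reserve(Arc::clone(&reservations), 0xfeed, 50_000).await;
    assert_eq!(reservations.lock().await.get(&0xfeed), Some(&50_000));
}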
+ pub(crate) async fn close(&self) -> Result<(), Error> { + #[allow(clippy::unwrap_used)] + let allocations = mem::take(&mut *self.allocations.lock().unwrap()); + + #[allow(clippy::iter_over_hash_type)] + for a in allocations.values() { + a.close().await?; + } + + Ok(()) + } +} + +#[cfg(test)] +mod allocation_manager_test { + use bytecodec::DecodeExt; + use rand::random; + use std::{ + net::{IpAddr, Ipv4Addr, SocketAddr}, + str::FromStr, + }; + use stun_codec::MessageDecoder; + use tokio::net::UdpSocket; + + use crate::{ + attr::{Attribute, ChannelNumber, Data}, + chandata::ChannelData, + server::DEFAULT_LIFETIME, + }; + + use super::*; + + fn new_test_manager() -> Manager { + let config = ManagerConfig { + relay_addr_generator: RelayAllocator { + relay_address: IpAddr::from([127, 0, 0, 1]), + min_port: 49152, + max_port: 65535, + max_retries: 10, + address: String::from("127.0.0.1"), + }, + alloc_close_notify: None, + }; + Manager::new(config) + } + + fn random_five_tuple() -> FiveTuple { + FiveTuple { + src_addr: SocketAddr::new( + Ipv4Addr::new(0, 0, 0, 0).into(), + random(), + ), + dst_addr: SocketAddr::new( + Ipv4Addr::new(0, 0, 0, 0).into(), + random(), + ), + ..Default::default() + } + } + + #[tokio::test] + async fn test_packet_handler() { + // turn server initialization + let turn_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + + // client listener initialization + let client_listener = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let src_addr = client_listener.local_addr().unwrap(); + let (data_ch_tx, mut data_ch_rx) = mpsc::channel(1); + // client listener read data + tokio::spawn(async move { + let mut buffer = vec![0u8; 1500]; + loop { + let n = match client_listener.recv_from(&mut buffer).await { + Ok((n, _)) => n, + Err(_) => break, + }; + + let _ = data_ch_tx.send(buffer[..n].to_vec()).await; + } + }); + + let m = new_test_manager(); + let a = m + .create_allocation( + FiveTuple { + src_addr, + dst_addr: turn_socket.local_addr().unwrap(), + ..Default::default() + }, + Arc::new(turn_socket), + 0, + DEFAULT_LIFETIME, + Username::new(String::from("user")).unwrap(), + true, + ) + .await + .unwrap(); + + let peer_listener1 = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let peer_listener2 = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + + let port = { + // add permission with peer1 address + a.add_permission(peer_listener1.local_addr().unwrap().ip()).await; + // add channel with min channel number and peer2 address + a.add_channel_bind( + ChannelNumber::MIN, + peer_listener2.local_addr().unwrap(), + DEFAULT_LIFETIME, + ) + .await + .unwrap(); + + a.relay_socket.local_addr().unwrap().port() + }; + + let relay_addr_with_host_str = format!("127.0.0.1:{port}"); + let relay_addr_with_host = + SocketAddr::from_str(&relay_addr_with_host_str).unwrap(); + + // test for permission and data message + let target_text = "permission"; + let _ = peer_listener1 + .send_to(target_text.as_bytes(), relay_addr_with_host) + .await + .unwrap(); + let data = data_ch_rx.recv().await.unwrap(); + + let msg = MessageDecoder::::new() + .decode_from_bytes(&data) + .unwrap() + .unwrap(); + + let msg_data = msg.get_attribute::().unwrap().data().to_vec(); + assert_eq!( + target_text.as_bytes(), + &msg_data, + "get message doesn't equal the target text" + ); + + // test for channel bind and channel data + let target_text2 = "channel bind"; + let _ = peer_listener2 + .send_to(target_text2.as_bytes(), relay_addr_with_host) + .await + .unwrap(); + let data = data_ch_rx.recv().await.unwrap(); + + 
// resolve channel data + assert!(ChannelData::is_channel_data(&data), "should be channel data"); + + let channel_data = ChannelData::decode(data).unwrap(); + assert_eq!( + ChannelNumber::MIN, + channel_data.num(), + "get channel data's number is invalid" + ); + assert_eq!( + target_text2.as_bytes(), + &channel_data.data(), + "get data doesn't equal the target text." + ); + + // listeners close + m.close().await.unwrap(); + } + + #[tokio::test] + async fn test_create_allocation_duplicate_five_tuple() { + // turn server initialization + let turn_socket: Arc = + Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap()); + + let m = new_test_manager(); + + let five_tuple = random_five_tuple(); + + let _ = m + .create_allocation( + five_tuple, + Arc::clone(&turn_socket), + 0, + DEFAULT_LIFETIME, + Username::new(String::from("user")).unwrap(), + true, + ) + .await + .unwrap(); + + let result = m + .create_allocation( + five_tuple, + Arc::clone(&turn_socket), + 0, + DEFAULT_LIFETIME, + Username::new(String::from("user")).unwrap(), + true, + ) + .await; + assert!(result.is_err(), "expected error, but got ok"); + } + + #[tokio::test] + async fn test_delete_allocation() { + // turn server initialization + let turn_socket: Arc = + Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap()); + + let m = new_test_manager(); + + let five_tuple = random_five_tuple(); + + let _ = m + .create_allocation( + five_tuple, + Arc::clone(&turn_socket), + 0, + DEFAULT_LIFETIME, + Username::new(String::from("user")).unwrap(), + true, + ) + .await + .unwrap(); + + assert!( + m.has_alloc(&five_tuple), + "Failed to get allocation right after creation" + ); + + m.delete_allocation(&five_tuple).await; + + assert!( + !m.has_alloc(&five_tuple), + "Get allocation with {five_tuple} should be nil after delete" + ); + } + + #[tokio::test] + async fn test_allocation_timeout() { + // turn server initialization + let turn_socket: Arc = + Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap()); + + let m = new_test_manager(); + + let mut allocations = vec![]; + let lifetime = Duration::from_millis(100); + + for _ in 0..5 { + let five_tuple = random_five_tuple(); + + let a = m + .create_allocation( + five_tuple, + Arc::clone(&turn_socket), + 0, + lifetime, + Username::new(String::from("user")).unwrap(), + true, + ) + .await + .unwrap(); + + allocations.push(a); + } + + let mut count = 0; + + 'outer: loop { + count += 1; + + if count >= 10 { + panic!("Allocations didn't timeout"); + } + + sleep(lifetime + Duration::from_millis(100)).await; + + let any_outstanding = false; + + for a in &allocations { + if a.close().await.is_ok() { + continue 'outer; + } + } + + if !any_outstanding { + return; + } + } + } + + #[tokio::test] + async fn test_manager_close() { + // turn server initialization + let turn_socket: Arc = + Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap()); + + let m = new_test_manager(); + + let mut allocations = vec![]; + + let a1 = m + .create_allocation( + random_five_tuple(), + Arc::clone(&turn_socket), + 0, + Duration::from_millis(100), + Username::new(String::from("user")).unwrap(), + true, + ) + .await + .unwrap(); + allocations.push(a1); + + let a2 = m + .create_allocation( + random_five_tuple(), + Arc::clone(&turn_socket), + 0, + Duration::from_millis(200), + Username::new(String::from("user")).unwrap(), + true, + ) + .await + .unwrap(); + allocations.push(a2); + + sleep(Duration::from_millis(150)).await; + + log::trace!("Mgr is going to be closed..."); + + m.close().await.unwrap(); + + for a in allocations { + assert!( + 
a.close().await.is_err(), + "Allocation should be closed if lifetime timeout" + ); + } + } + + #[tokio::test] + async fn test_delete_allocation_by_username() { + let turn_socket: Arc = + Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap()); + + let m = new_test_manager(); + + let five_tuple1 = random_five_tuple(); + let five_tuple2 = random_five_tuple(); + let five_tuple3 = random_five_tuple(); + + let _ = m + .create_allocation( + five_tuple1, + Arc::clone(&turn_socket), + 0, + DEFAULT_LIFETIME, + Username::new(String::from("user")).unwrap(), + true, + ) + .await + .unwrap(); + let _ = m + .create_allocation( + five_tuple2, + Arc::clone(&turn_socket), + 0, + DEFAULT_LIFETIME, + Username::new(String::from("user")).unwrap(), + true, + ) + .await + .unwrap(); + let _ = m + .create_allocation( + five_tuple3, + Arc::clone(&turn_socket), + 0, + DEFAULT_LIFETIME, + Username::new(String::from("user2")).unwrap(), + true, + ) + .await + .unwrap(); + + assert_eq!(m.allocations.lock().unwrap().len(), 3); + + m.delete_allocations_by_username("user").await; + + assert_eq!(m.allocations.lock().unwrap().len(), 1); + + assert!( + m.get_alloc(&five_tuple1).is_none() + && m.get_alloc(&five_tuple2).is_none() + && m.get_alloc(&five_tuple3).is_some() + ); + } +} diff --git a/src/allocation/channel_bind.rs b/src/allocation/channel_bind.rs new file mode 100644 index 000000000..f19a82bde --- /dev/null +++ b/src/allocation/channel_bind.rs @@ -0,0 +1,136 @@ +//! TURN [`Channel`]. +//! +//! [`Channel`]: https://tools.ietf.org/html/rfc5766#section-2.5 + +use std::{collections::HashMap, net::SocketAddr, sync::Arc}; + +use tokio::{ + sync::{mpsc, Mutex}, + time::{sleep, Duration, Instant}, +}; + +/// TURN [`Channel`]. +/// +/// [`Channel`]: https://tools.ietf.org/html/rfc5766#section-2.5 +#[derive(Clone)] +pub(crate) struct ChannelBind { + /// Transport address of the peer. + peer: SocketAddr, + + /// Channel number. + number: u16, + + /// Channel to the internal loop used to update lifetime or drop channel + /// binding. + reset_tx: Option>, +} + +impl ChannelBind { + /// Creates a new [`ChannelBind`] + pub(crate) const fn new(number: u16, peer: SocketAddr) -> Self { + Self { number, peer, reset_tx: None } + } + + /// Starts [`ChannelBind`]'s internal lifetime watching loop. + pub(crate) fn start( + &mut self, + bindings: Arc>>, + lifetime: Duration, + ) { + let (reset_tx, mut reset_rx) = mpsc::channel(1); + self.reset_tx = Some(reset_tx); + + let number = self.number; + + drop(tokio::spawn(async move { + let timer = sleep(lifetime); + tokio::pin!(timer); + + loop { + tokio::select! { + () = &mut timer => { + if bindings.lock().await.remove(&number).is_none() { + log::error!( + "Failed to remove ChannelBind for {number}" + ); + } + break; + }, + result = reset_rx.recv() => { + if let Some(d) = result { + timer.as_mut().reset(Instant::now() + d); + } else { + break; + } + }, + } + } + })); + } + + /// Returns transport address of the peer. + pub(crate) const fn peer(&self) -> SocketAddr { + self.peer + } + + /// Returns channel number. + pub(crate) const fn num(&self) -> u16 { + self.number + } + + /// Updates [`ChannelBind`]'s lifetime. 
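// A minimal, self-contained sketch of the resettable expiry timer that
// `ChannelBind::start`/`refresh` above (and, later in this patch,
// `Permission` and `Allocation`) are built on: a pinned `sleep` raced
// against an mpsc channel that pushes the deadline forward. `spawn_expiry`
// and its callback are illustrative names, not crate APIs, and the function
// must be called from within a Tokio runtime.
use tokio::{
    sync::mpsc,
    time::{sleep, Duration, Instant},
};

fn spawn_expiry(
    lifetime: Duration,
    on_expire: impl FnOnce() + Send + 'static,
) -> mpsc::Sender<Duration> {
    let (reset_tx, mut reset_rx) = mpsc::channel::<Duration>(1);
    tokio::spawn(async move {
        let mut on_expire = Some(on_expire);
        let timer = sleep(lifetime);
        tokio::pin!(timer);
        loop {
            tokio::select! {
                () = &mut timer => {
                    // Deadline reached: fire the cleanup exactly once and stop.
                    if let Some(f) = on_expire.take() {
                        f();
                    }
                    break;
                }
                reset = reset_rx.recv() => match reset {
                    // A refresh arrived: push the deadline forward.
                    Some(d) => timer.as_mut().reset(Instant::now() + d),
                    // All senders are gone: stop watching.
                    None => break,
                },
            }
        }
    });
    reset_tx
}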
+ pub(crate) async fn refresh(&self, lifetime: Duration) { + if let Some(tx) = &self.reset_tx { + _ = tx.send(lifetime).await; + } + } +} + +#[cfg(test)] +mod channel_bind_test { + use std::net::Ipv4Addr; + + use tokio::net::UdpSocket; + + use crate::{ + allocation::Allocation, + attr::{ChannelNumber, Username}, + con, Error, FiveTuple, + }; + + use super::*; + + async fn create_channel_bind( + lifetime: Duration, + ) -> Result { + let turn_socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); + let relay_socket = Arc::clone(&turn_socket); + let relay_addr = relay_socket.local_addr().unwrap(); + let a = Allocation::new( + turn_socket, + relay_socket, + relay_addr, + FiveTuple::default(), + Username::new(String::from("user")).unwrap(), + None, + ); + + let addr = SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0); + + a.add_channel_bind(ChannelNumber::MIN, addr, lifetime).await?; + + Ok(a) + } + + #[tokio::test] + async fn test_channel_bind() { + let a = create_channel_bind(Duration::from_millis(20)).await.unwrap(); + + let result = a.get_channel_addr(&ChannelNumber::MIN).await; + if let Some(addr) = result { + assert_eq!(addr.ip().to_string(), "0.0.0.0"); + } else { + panic!("expected some, but got none"); + } + } +} diff --git a/src/allocation/mod.rs b/src/allocation/mod.rs new file mode 100644 index 000000000..42ed53f44 --- /dev/null +++ b/src/allocation/mod.rs @@ -0,0 +1,761 @@ +//! TURN server [allocation]. +//! +//! [allocation]: https://datatracker.ietf.org/doc/html/rfc5766#section-5 + +mod allocation_manager; +mod channel_bind; +mod permission; + +use std::{ + collections::HashMap, + fmt, + marker::{Send, Sync}, + mem, + net::{IpAddr, SocketAddr}, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, Mutex as SyncMutex, + }, +}; + +use bytecodec::EncodeExt; +use rand::random; +use stun_codec::{ + rfc5766::methods::DATA, Message, MessageClass, MessageEncoder, + TransactionId, +}; +use tokio::{ + net::UdpSocket, + sync::{ + mpsc, + oneshot::{self, Sender}, + Mutex, + }, + time::{sleep, Duration, Instant}, +}; + +use crate::{ + allocation::permission::PERMISSION_LIFETIME, + attr::{Attribute, Data, Username, XorPeerAddress}, + chandata::ChannelData, + con::Conn, + server::INBOUND_MTU, + Error, +}; + +use self::{channel_bind::ChannelBind, permission::Permission}; + +pub(crate) use allocation_manager::{Manager, ManagerConfig}; + +/// [`Allocation`]s storage. +pub(crate) type AllocationMap = + Arc>>>; + +/// Information about an allocation. +#[derive(Debug, Clone)] +pub struct AllocInfo { + /// [`FiveTuple`] of this allocation. + pub five_tuple: FiveTuple, + + /// Username of this allocation. + pub username: String, + + /// Relayed bytes with this allocation. + pub relayed_bytes: usize, +} + +impl AllocInfo { + /// Creates a new [`AllocInfo`]. + #[must_use] + pub const fn new( + five_tuple: FiveTuple, + username: String, + relayed_bytes: usize, + ) -> Self { + Self { five_tuple, username, relayed_bytes } + } +} + +/// The tuple (source IP address, source port, destination IP +/// address, destination port, transport protocol). A 5-tuple +/// uniquely identifies a UDP/TCP session. +#[derive(PartialEq, Eq, Clone, Copy, Debug, Hash)] +pub struct FiveTuple { + /// Transport protocol according to [IANA] protocol numbers. + /// + /// [IANA]: https://tinyurl.com/iana-protocol-numbers + pub protocol: u8, + + /// Packet source address. + pub src_addr: SocketAddr, + + /// Packet target address. 
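// For example, with the `Display` impl that follows this struct, a UDP
// five-tuple from 10.0.0.1:5000 to 10.0.0.2:3478 renders as
// "17_10.0.0.1:5000_10.0.0.2:3478" (17 being the IANA protocol number for
// UDP); this is the string that appears in the allocation trace logs.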
+ pub dst_addr: SocketAddr, +} + +impl fmt::Display for FiveTuple { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}_{}_{}", self.protocol, self.src_addr, self.dst_addr) + } +} + +/// TURN server [Allocation]. +/// +/// [Allocation]:https://datatracker.ietf.org/doc/html/rfc5766#section-5 +pub(crate) struct Allocation { + /// [`Conn`] used to create this [`Allocation`]. + turn_socket: Arc, + + /// Relay socket address. + relay_addr: SocketAddr, + + /// Allocated relay socket. + relay_socket: Arc, + + /// [`FiveTuple`] this allocation is created with. + five_tuple: FiveTuple, + + /// Remote user ICE [`Username`]. + username: Username, + + /// List of [`Permission`]s for this [`Allocation`] + permissions: Arc>>, + + /// This [`Allocation`] [`ChannelBind`]ings. + channel_bindings: Arc>>, + + /// All [`Allocation`]s storage. + allocations: Option, + + /// Channel to the internal loop used to update lifetime or drop + /// allocation. + reset_tx: SyncMutex>>, + + /// Total number of relayed bytes. + relayed_bytes: AtomicUsize, + + /// Channel to the packet handler loop used to stop it. + drop_tx: Option>, + + /// Injected into allocations to notify when allocation is closed. + alloc_close_notify: Option>, +} + +impl Allocation { + /// Creates a new [`Allocation`]. + pub(crate) fn new( + turn_socket: Arc, + relay_socket: Arc, + relay_addr: SocketAddr, + five_tuple: FiveTuple, + username: Username, + alloc_close_notify: Option>, + ) -> Self { + Self { + turn_socket, + relay_addr, + relay_socket, + five_tuple, + username, + permissions: Arc::new(Mutex::new(HashMap::new())), + channel_bindings: Arc::new(Mutex::new(HashMap::new())), + allocations: None, + reset_tx: SyncMutex::new(None), + relayed_bytes: AtomicUsize::default(), + drop_tx: None, + alloc_close_notify, + } + } + + /// Send the given data via associated relay socket. + pub(crate) async fn relay( + &self, + data: &[u8], + to: SocketAddr, + ) -> Result<(), Error> { + match self.relay_socket.send_to(data, to).await { + Ok(n) => { + _ = self.relayed_bytes.fetch_add(n, Ordering::AcqRel); + + Ok(()) + } + Err(err) => Err(Error::from(err)), + } + } + + /// Returns [`SocketAddr`] of the associated relay socket. + pub(crate) const fn relay_addr(&self) -> SocketAddr { + self.relay_addr + } + + /// Checks the Permission for the `addr`. + pub(crate) async fn has_permission(&self, addr: &SocketAddr) -> bool { + self.permissions.lock().await.get(&addr.ip()).is_some() + } + + /// Adds a new [`Permission`] to this [`Allocation`]. + pub(crate) async fn add_permission(&self, ip: IpAddr) { + let mut permissions = self.permissions.lock().await; + + if let Some(existed_permission) = permissions.get(&ip) { + existed_permission.refresh(PERMISSION_LIFETIME).await; + } else { + let mut p = Permission::new(ip); + p.start(Arc::clone(&self.permissions), PERMISSION_LIFETIME); + drop(permissions.insert(p.ip(), p)); + } + } + + /// Adds a new [`ChannelBind`] to this [`Allocation`], it also updates the + /// permissions needed for this [`ChannelBind`]. 
+ #[allow(clippy::significant_drop_tightening)] // false-positive + pub(crate) async fn add_channel_bind( + &self, + number: u16, + peer_addr: SocketAddr, + lifetime: Duration, + ) -> Result<(), Error> { + // The channel number is not currently bound to a different transport + // address (same transport address is OK); + if let Some(addr) = self.get_channel_addr(&number).await { + if addr != peer_addr { + return Err(Error::SameChannelDifferentPeer); + } + } + + // The transport address is not currently bound to a different + // channel number. + if let Some(n) = self.get_channel_number(&peer_addr).await { + if number != n { + return Err(Error::SamePeerDifferentChannel); + } + } + + let mut channel_bindings = self.channel_bindings.lock().await; + if let Some(cb) = channel_bindings.get(&number).cloned() { + drop(channel_bindings); + + cb.refresh(lifetime).await; + + // Channel binds also refresh permissions. + self.add_permission(cb.peer().ip()).await; + } else { + let mut bind = ChannelBind::new(number, peer_addr); + bind.start(Arc::clone(&self.channel_bindings), lifetime); + + drop(channel_bindings.insert(number, bind)); + + // Channel binds also refresh permissions. + self.add_permission(peer_addr.ip()).await; + } + Ok(()) + } + + /// Gets the [`ChannelBind`]'s address by `number`. + pub(crate) async fn get_channel_addr( + &self, + number: &u16, + ) -> Option { + self.channel_bindings.lock().await.get(number).map(ChannelBind::peer) + } + + /// Gets the [`ChannelBind`]'s number from this [`Allocation`] by `addr`. + pub(crate) async fn get_channel_number( + &self, + addr: &SocketAddr, + ) -> Option { + self.channel_bindings + .lock() + .await + .values() + .find_map(|b| (b.peer() == *addr).then_some(b.num())) + } + + /// Closes the [`Allocation`]. + pub(crate) async fn close(&self) -> Result<(), Error> { + #[allow(clippy::unwrap_used)] + if self.reset_tx.lock().unwrap().take().is_none() { + return Err(Error::Closed); + } + + drop(mem::take(&mut *self.permissions.lock().await)); + drop(mem::take(&mut *self.channel_bindings.lock().await)); + + log::trace!("allocation with {} closed!", self.five_tuple); + + drop(self.relay_socket.close().await); + + if let Some(notify_tx) = &self.alloc_close_notify { + drop( + notify_tx + .send(AllocInfo { + five_tuple: self.five_tuple, + username: self.username.name().to_owned(), + relayed_bytes: self + .relayed_bytes + .load(Ordering::Acquire), + }) + .await, + ); + } + + Ok(()) + } + + /// Starts the internal lifetime watching loop. + pub(crate) fn start(&self, lifetime: Duration) { + let (reset_tx, mut reset_rx) = mpsc::channel(1); + #[allow(clippy::unwrap_used)] + drop(self.reset_tx.lock().unwrap().replace(reset_tx)); + + let allocations = self.allocations.clone(); + let five_tuple = self.five_tuple; + + drop(tokio::spawn(async move { + let timer = sleep(lifetime); + tokio::pin!(timer); + + loop { + tokio::select! { + () = &mut timer => { + if let Some(allocs) = &allocations{ + #[allow(clippy::unwrap_used)] + let alloc = allocs + .lock() + .unwrap() + .remove(&five_tuple); + + if let Some(a) = alloc { + drop(a.close().await); + } + } + break; + }, + result = reset_rx.recv() => { + if let Some(d) = result { + timer.as_mut().reset(Instant::now() + d); + } else { + break; + } + }, + } + } + })); + } + + /// Updates the allocations lifetime. 
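// A minimal sketch of the two RFC 5766 consistency checks enforced by
// `add_channel_bind` above (a channel number may not be rebound to a
// different peer, and a peer may not be bound under two channel numbers),
// written against a plain map so they read in isolation. `check_bind` is an
// illustrative helper, not a crate API.
use std::{collections::HashMap, net::SocketAddr};

fn check_bind(
    bindings: &HashMap<u16, SocketAddr>,
    number: u16,
    peer: SocketAddr,
) -> Result<(), &'static str> {
    // The channel number must not already be bound to a different peer
    // (rebinding to the same peer is allowed and refreshes the binding).
    if let Some(existing) = bindings.get(&number) {
        if *existing != peer {
            return Err("channel number already bound to another peer");
        }
    }
    // The peer must not already be bound under a different channel number.
    if bindings.iter().any(|(n, p)| *p == peer && *n != number) {
        return Err("peer already bound to another channel number");
    }
    Ok(())
}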
+ pub(crate) async fn refresh(&self, lifetime: Duration) { + #[allow(clippy::unwrap_used)] + let reset_tx = self.reset_tx.lock().unwrap().clone(); + + if let Some(tx) = reset_tx { + _ = tx.send(lifetime).await; + } + } + + /// When the server receives a UDP datagram at a currently allocated + /// relayed transport address, the server looks up the allocation + /// associated with the relayed transport address. The server then + /// checks to see whether the set of permissions for the allocation allow + /// the relaying of the UDP datagram as described in Section 8. + /// + /// If relaying is permitted, then the server checks if there is a + /// channel bound to the peer that sent the UDP datagram (see + /// Section 11). If a channel is bound, then processing proceeds as + /// described in Section 11.7. + /// + /// If relaying is permitted but no channel is bound to the peer, then + /// the server forms and sends a Data indication. The Data indication + /// MUST contain both an XOR-PEER-ADDRESS and a DATA attribute. The DATA + /// attribute is set to the value of the 'data octets' field from the + /// datagram, and the XOR-PEER-ADDRESS attribute is set to the source + /// transport address of the received UDP datagram. The Data indication + /// is then sent on the 5-tuple associated with the allocation. + #[allow(clippy::too_many_lines)] + fn packet_handler(&mut self) { + let five_tuple = self.five_tuple; + let relay_addr = self.relay_addr; + let relay_socket = Arc::clone(&self.relay_socket); + let turn_socket = Arc::clone(&self.turn_socket); + let allocations = self.allocations.clone(); + let channel_bindings = Arc::clone(&self.channel_bindings); + let permissions = Arc::clone(&self.permissions); + let (drop_tx, drop_rx) = oneshot::channel::(); + self.drop_tx = Some(drop_tx); + + drop(tokio::spawn(async move { + let mut buffer = vec![0u8; INBOUND_MTU]; + + tokio::pin!(drop_rx); + loop { + let (n, src_addr) = tokio::select! { + result = relay_socket.recv_from(&mut buffer) => { + if let Ok((data, src_addr)) = result { + (data, src_addr) + } else { + if let Some(allocs) = &allocations { + #[allow(clippy::unwrap_used)] + drop( + allocs.lock().unwrap().remove(&five_tuple) + ); + } + break; + } + } + _ = drop_rx.as_mut() => { + log::trace!("allocation has stopped, \ + stop packet_handler. 
five_tuple: {:?}", + five_tuple); + break; + } + }; + + let cb_number = { + let mut cb_number = None; + #[allow( + clippy::iter_over_hash_type, + clippy::significant_drop_in_scrutinee + )] + for cb in channel_bindings.lock().await.values() { + if cb.peer() == src_addr { + cb_number = Some(cb.num()); + break; + } + } + cb_number + }; + + if let Some(number) = cb_number { + match ChannelData::encode(buffer[..n].to_vec(), number) { + Ok(data) => { + if let Err(err) = turn_socket + .send_to(data, five_tuple.src_addr) + .await + { + log::error!( + "Failed to send ChannelData from \ + allocation {src_addr}: {err}", + ); + } + } + Err(err) => { + log::error!( + "Failed to send ChannelData from allocation \ + {src_addr}: {err}" + ); + } + }; + } else { + let exist = + permissions.lock().await.get(&src_addr.ip()).is_some(); + + if exist { + log::trace!( + "relaying message from {} to client at {}", + src_addr, + five_tuple.src_addr + ); + + let mut msg: Message = Message::new( + MessageClass::Indication, + DATA, + TransactionId::new(random()), + ); + msg.add_attribute(XorPeerAddress::new(src_addr)); + let Ok(data) = Data::new(buffer[..n].to_vec()) else { + log::error!("DataIndication is too long"); + continue; + }; + msg.add_attribute(data); + + match MessageEncoder::new().encode_into_bytes(msg) { + Ok(encoded) => { + if let Err(err) = turn_socket + .send_to(encoded, five_tuple.src_addr) + .await + { + log::error!( + "Failed to send DataIndication from \ + allocation {} {}", + src_addr, + err + ); + } + } + Err(e) => { + log::error!("DataIndication encode err: {e}"); + } + } + } else { + log::info!( + "No Permission or Channel exists for {} on \ + allocation {}", + src_addr, + relay_addr + ); + } + } + } + })); + } +} + +#[cfg(test)] +mod allocation_test { + use std::{net::Ipv4Addr, str::FromStr}; + + use tokio::net::UdpSocket; + + use super::*; + + use crate::{ + attr::{ChannelNumber, PROTO_UDP}, + server::DEFAULT_LIFETIME, + }; + + impl Default for FiveTuple { + fn default() -> Self { + FiveTuple { + protocol: PROTO_UDP, + src_addr: SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0), + dst_addr: SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0), + } + } + } + + #[tokio::test] + async fn test_has_permission() { + let turn_socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap()); + let relay_socket = Arc::clone(&turn_socket); + let relay_addr = relay_socket.local_addr().unwrap(); + let a = Allocation::new( + turn_socket, + relay_socket, + relay_addr, + FiveTuple::default(), + Username::new(String::from("user")).unwrap(), + None, + ); + + let addr1 = SocketAddr::from_str("127.0.0.1:3478").unwrap(); + let addr2 = SocketAddr::from_str("127.0.0.1:3479").unwrap(); + let addr3 = SocketAddr::from_str("127.0.0.2:3478").unwrap(); + + a.add_permission(addr1.ip()).await; + a.add_permission(addr2.ip()).await; + a.add_permission(addr3.ip()).await; + + let found_p1 = a.has_permission(&addr1).await; + assert!(found_p1, "Should keep the first one."); + + let found_p2 = a.has_permission(&addr2).await; + assert!(found_p2, "Second one should be ignored."); + + let found_p3 = a.has_permission(&addr3).await; + assert!(found_p3, "Permission with another IP should be found"); + } + + #[tokio::test] + async fn test_add_permission() { + let turn_socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap()); + let relay_socket = Arc::clone(&turn_socket); + let relay_addr = relay_socket.local_addr().unwrap(); + let a = Allocation::new( + turn_socket, + relay_socket, + relay_addr, + FiveTuple::default(), + 
Username::new(String::from("user")).unwrap(), + None, + ); + + let addr = SocketAddr::from_str("127.0.0.1:3478").unwrap(); + a.add_permission(addr.ip()).await; + + let found_p = a.has_permission(&addr).await; + assert!(found_p, "Should keep the first one."); + } + + #[tokio::test] + async fn test_get_channel_by_number() { + let turn_socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap()); + let relay_socket = Arc::clone(&turn_socket); + let relay_addr = relay_socket.local_addr().unwrap(); + let a = Allocation::new( + turn_socket, + relay_socket, + relay_addr, + FiveTuple::default(), + Username::new(String::from("user")).unwrap(), + None, + ); + + let addr = SocketAddr::from_str("127.0.0.1:3478").unwrap(); + + a.add_channel_bind(ChannelNumber::MIN, addr, DEFAULT_LIFETIME) + .await + .unwrap(); + + let exist_channel_addr = + a.get_channel_addr(&ChannelNumber::MIN).await.unwrap(); + assert_eq!(addr, exist_channel_addr); + + let not_exist_channel = + a.get_channel_addr(&(ChannelNumber::MIN + 1)).await; + assert!( + not_exist_channel.is_none(), + "should be nil for not existed channel." + ); + } + + #[tokio::test] + async fn test_get_channel_by_addr() { + let turn_socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap()); + let relay_socket = Arc::clone(&turn_socket); + let relay_addr = relay_socket.local_addr().unwrap(); + let a = Allocation::new( + turn_socket, + relay_socket, + relay_addr, + FiveTuple::default(), + Username::new(String::from("user")).unwrap(), + None, + ); + + let addr = SocketAddr::from_str("127.0.0.1:3478").unwrap(); + let addr2 = SocketAddr::from_str("127.0.0.1:3479").unwrap(); + + a.add_channel_bind(ChannelNumber::MIN, addr, DEFAULT_LIFETIME) + .await + .unwrap(); + + let exist_channel_number = a.get_channel_number(&addr).await.unwrap(); + assert_eq!(ChannelNumber::MIN, exist_channel_number); + + let not_exist_channel = a.get_channel_number(&addr2).await; + assert!( + not_exist_channel.is_none(), + "should be nil for not existed channel." 
+ ); + } + + #[tokio::test] + async fn test_allocation_close() { + let turn_socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap()); + let relay_socket = Arc::clone(&turn_socket); + let relay_addr = relay_socket.local_addr().unwrap(); + let a = Allocation::new( + turn_socket, + relay_socket, + relay_addr, + FiveTuple::default(), + Username::new(String::from("user")).unwrap(), + None, + ); + + // add mock lifetimeTimer + a.start(DEFAULT_LIFETIME); + + // add channel + let addr = SocketAddr::from_str("127.0.0.1:3478").unwrap(); + + a.add_channel_bind(ChannelNumber::MIN, addr, DEFAULT_LIFETIME) + .await + .unwrap(); + + // add permission + a.add_permission(addr.ip()).await; + + a.close().await.unwrap(); + } +} + +#[cfg(test)] +mod five_tuple_test { + use std::net::SocketAddr; + + use crate::{ + attr::{PROTO_TCP, PROTO_UDP}, + FiveTuple, + }; + + #[test] + fn test_five_tuple_equal() { + let src_addr1: SocketAddr = + "0.0.0.0:3478".parse::().unwrap(); + let src_addr2: SocketAddr = + "0.0.0.0:3479".parse::().unwrap(); + + let dst_addr1: SocketAddr = + "0.0.0.0:3480".parse::().unwrap(); + let dst_addr2: SocketAddr = + "0.0.0.0:3481".parse::().unwrap(); + + let tests = vec![ + ( + "Equal", + true, + FiveTuple { + protocol: PROTO_UDP, + src_addr: src_addr1, + dst_addr: dst_addr1, + }, + FiveTuple { + protocol: PROTO_UDP, + src_addr: src_addr1, + dst_addr: dst_addr1, + }, + ), + ( + "DifferentProtocol", + false, + FiveTuple { + protocol: PROTO_TCP, + src_addr: src_addr1, + dst_addr: dst_addr1, + }, + FiveTuple { + protocol: PROTO_UDP, + src_addr: src_addr1, + dst_addr: dst_addr1, + }, + ), + ( + "DifferentSrcAddr", + false, + FiveTuple { + protocol: PROTO_UDP, + src_addr: src_addr1, + dst_addr: dst_addr1, + }, + FiveTuple { + protocol: PROTO_UDP, + src_addr: src_addr2, + dst_addr: dst_addr1, + }, + ), + ( + "DifferentDstAddr", + false, + FiveTuple { + protocol: PROTO_UDP, + src_addr: src_addr1, + dst_addr: dst_addr1, + }, + FiveTuple { + protocol: PROTO_UDP, + src_addr: src_addr1, + dst_addr: dst_addr2, + }, + ), + ]; + + for (name, expect, a, b) in tests { + let fact = a == b; + assert_eq!( + expect, fact, + "{name}: {a}, {b} equal check should be {expect}, but {fact}" + ); + } + } +} diff --git a/src/allocation/permission.rs b/src/allocation/permission.rs new file mode 100644 index 000000000..4507a8acf --- /dev/null +++ b/src/allocation/permission.rs @@ -0,0 +1,81 @@ +//! TURN [Allocation] [Permission]. +//! +//! [Allocation]: https://datatracker.ietf.org/doc/html/rfc5766#section-2.2 +//! [Permission]: https://datatracker.ietf.org/doc/html/rfc5766#section-8 + +use std::{collections::HashMap, net::IpAddr, sync::Arc}; + +use tokio::{ + sync::{mpsc, Mutex}, + time::{sleep, Duration, Instant}, +}; + +/// The Permission Lifetime MUST be 300 seconds (= 5 minutes)[1]. +/// +/// [1]: https://datatracker.ietf.org/doc/html/rfc5766#section-8 +pub(crate) const PERMISSION_LIFETIME: Duration = Duration::from_secs(5 * 60); + +/// TURN [Allocation] [Permission]. +/// +/// [Allocation]: https://datatracker.ietf.org/doc/html/rfc5766#section-2.2 +/// [Permission]: https://datatracker.ietf.org/doc/html/rfc5766#section-8 +pub(crate) struct Permission { + /// [`IpAddr`] of this permission that is matched with the source IP + /// address of the datagram received. + ip: IpAddr, + + /// Channel to the inner lifetime watching loop. + reset_tx: Option>, +} + +impl Permission { + /// Creates a new [`Permission`]. 
+ pub(crate) const fn new(ip: IpAddr) -> Self { + Self { ip, reset_tx: None } + } + + /// Starts [`Permission`]'s internal lifetime watching loop. + pub(crate) fn start( + &mut self, + permissions: Arc>>, + lifetime: Duration, + ) { + let (reset_tx, mut reset_rx) = mpsc::channel(1); + self.reset_tx = Some(reset_tx); + + let ip = self.ip; + + drop(tokio::spawn(async move { + let timer = sleep(lifetime); + tokio::pin!(timer); + + loop { + tokio::select! { + () = &mut timer => { + drop(permissions.lock().await.remove(&ip)); + break; + }, + result = reset_rx.recv() => { + if let Some(d) = result { + timer.as_mut().reset(Instant::now() + d); + } else { + break; + } + }, + } + } + })); + } + + /// Returns [`IpAddr`] of this [`Permission`]. + pub(crate) const fn ip(&self) -> IpAddr { + self.ip + } + + /// Updates [`Permission`]'s lifetime. + pub(crate) async fn refresh(&self, lifetime: Duration) { + if let Some(tx) = &self.reset_tx { + _ = tx.send(lifetime).await; + } + } +} diff --git a/src/attr.rs b/src/attr.rs new file mode 100644 index 000000000..9c2b4c66a --- /dev/null +++ b/src/attr.rs @@ -0,0 +1,58 @@ +//! STUN and TURN attributes used by the server. + +use stun_codec::define_attribute_enums; + +pub(crate) use stun_codec::{ + rfc5389::attributes::{ + AlternateServer, ErrorCode, Fingerprint, MappedAddress, + MessageIntegrity, Nonce, Realm, Software, UnknownAttributes, Username, + XorMappedAddress, + }, + rfc5766::attributes::{ + ChannelNumber, Data, DontFragment, EvenPort, Lifetime, + RequestedTransport, ReservationToken, XorPeerAddress, XorRelayAddress, + }, + rfc8656::attributes::{AddressFamily, RequestedAddressFamily}, +}; + +/// UDP protocol number according to [IANA]. +/// +/// [IANA]: https://tinyurl.com/iana-protocol-numbers +pub(crate) const PROTO_UDP: u8 = 17; + +/// TCP protocol number according to [IANA]. +/// +/// [IANA]: https://tinyurl.com/iana-protocol-numbers +pub(crate) const PROTO_TCP: u8 = 6; + +define_attribute_enums!( + Attribute, + AttributeDecoder, + AttributeEncoder, + [ + // RFC 5389 + MappedAddress, + Username, + MessageIntegrity, + ErrorCode, + UnknownAttributes, + Realm, + Nonce, + XorMappedAddress, + Software, + AlternateServer, + Fingerprint, + // RFC 5766 + ChannelNumber, + Lifetime, + XorPeerAddress, + Data, + XorRelayAddress, + EvenPort, + RequestedTransport, + DontFragment, + ReservationToken, + // RFC 8656 + RequestedAddressFamily + ] +); diff --git a/src/chandata.rs b/src/chandata.rs new file mode 100644 index 000000000..f5ddf6aab --- /dev/null +++ b/src/chandata.rs @@ -0,0 +1,294 @@ +//! [`ChannelData`] message implementation. + +use crate::{attr::ChannelNumber, Error}; + +/// [`ChannelData`] message MUST be padded to a multiple of four bytes in order +/// to ensure the alignment of subsequent messages. +const PADDING: usize = 4; + +/// [Channel Number] field size. +/// +/// [Channel Number]: https://datatracker.ietf.org/doc/html/rfc5766#section-11.4 +const CHANNEL_DATA_NUMBER_SIZE: usize = 2; + +/// [Length] field size. +/// +/// [Length]: https://datatracker.ietf.org/doc/html/rfc5766#section-11.4 +const CHANNEL_DATA_LENGTH_SIZE: usize = 2; + +/// [ChannelData] message header size. +/// +/// [ChannelData]: https://datatracker.ietf.org/doc/html/rfc5766#section-11.4 +const CHANNEL_DATA_HEADER_SIZE: usize = + CHANNEL_DATA_LENGTH_SIZE + CHANNEL_DATA_NUMBER_SIZE; + +/// [`ChannelData`] represents the `ChannelData` Message defined in +/// [RFC 5766](https://www.rfc-editor.org/rfc/rfc5766#section-11.4). 
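// A minimal sketch of the wire format handled below: a 2-byte channel
// number, a 2-byte payload length, the payload, and zero padding up to a
// 4-byte boundary. `frame_channel_data` is an illustrative helper that
// mirrors what `ChannelData::encode` produces; it is not the crate API.
fn frame_channel_data(number: u16, payload: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(4 + payload.len());
    out.extend_from_slice(&number.to_be_bytes()); // channel number
    out.extend_from_slice(&(payload.len() as u16).to_be_bytes()); // length
    out.extend_from_slice(payload); // application data
    while out.len() % 4 != 0 {
        out.push(0); // pad to a 4-byte boundary
    }
    out
}

fn main() {
    // Framing 3 bytes of data on channel 0x4000 yields an 8-byte message.
    let frame = frame_channel_data(0x4000, &[1, 2, 3]);
    assert_eq!(frame, vec![0x40, 0x00, 0x00, 0x03, 1, 2, 3, 0]);
}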
+#[derive(Debug)]
+pub(crate) struct ChannelData {
+    /// Parsed [`ChannelData`] [Channel Number][1].
+    ///
+    /// [1]: https://datatracker.ietf.org/doc/html/rfc5766#section-11.4
+    number: u16,
+
+    /// Parsed [`ChannelData`] payload.
+    data: Vec<u8>,
+}
+
+impl ChannelData {
+    /// Returns `true` if `buf` looks like the `ChannelData` Message.
+    #[allow(clippy::missing_asserts_for_indexing)] // Length is checked
+    pub(crate) fn is_channel_data(buf: &[u8]) -> bool {
+        if buf.len() < CHANNEL_DATA_HEADER_SIZE {
+            return false;
+        }
+        let len = usize::from(u16::from_be_bytes([
+            buf[CHANNEL_DATA_NUMBER_SIZE],
+            buf[CHANNEL_DATA_NUMBER_SIZE + 1],
+        ]));
+
+        if len > buf[CHANNEL_DATA_HEADER_SIZE..].len() {
+            return false;
+        }
+
+        ChannelNumber::new(u16::from_be_bytes([buf[0], buf[1]])).is_ok()
+    }
+
+    /// Decodes the given raw message as [`ChannelData`].
+    pub(crate) fn decode(mut raw: Vec<u8>) -> Result<Self, Error> {
+        if raw.len() < CHANNEL_DATA_HEADER_SIZE {
+            return Err(Error::UnexpectedEof);
+        }
+
+        let number = u16::from_be_bytes([raw[0], raw[1]]);
+        if ChannelNumber::new(number).is_err() {
+            return Err(Error::InvalidChannelNumber);
+        }
+
+        let l = usize::from(u16::from_be_bytes([
+            raw[CHANNEL_DATA_NUMBER_SIZE],
+            raw[CHANNEL_DATA_NUMBER_SIZE + 1],
+        ]));
+
+        if l > raw[CHANNEL_DATA_HEADER_SIZE..].len() {
+            return Err(Error::BadChannelDataLength);
+        }
+
+        // Discard header and padding.
+        drop(raw.drain(0..CHANNEL_DATA_HEADER_SIZE));
+        if l != raw.len() {
+            raw.truncate(l);
+        }
+
+        Ok(Self { data: raw, number })
+    }
+
+    /// Returns [`ChannelData`] [Channel Number][1].
+    ///
+    /// [1]: https://datatracker.ietf.org/doc/html/rfc5766#section-11.4
+    pub(crate) const fn num(&self) -> u16 {
+        self.number
+    }
+
+    /// Encodes the provided [`ChannelData`] payload and channel number to
+    /// bytes.
+    pub(crate) fn encode(
+        mut data: Vec<u8>,
+        chan_num: u16,
+    ) -> Result<Vec<u8>, Error> {
+        #[allow(clippy::map_err_ignore)]
+        let len = u16::try_from(data.len())
+            .map_err(|_| Error::BadChannelDataLength)?;
+        for i in len.to_be_bytes().into_iter().rev() {
+            data.insert(0, i);
+        }
+        for i in chan_num.to_be_bytes().into_iter().rev() {
+            data.insert(0, i);
+        }
+
+        let padded = nearest_padded_value_length(data.len());
+        let bytes_to_add = padded - data.len();
+        if bytes_to_add > 0 {
+            data.extend_from_slice(&vec![0; bytes_to_add]);
+        }
+
+        Ok(data)
+    }
+
+    /// Returns [`ChannelData`] payload.
+    pub(crate) fn data(self) -> Vec<u8> {
+        self.data
+    }
+}
+
+/// Calculates nearest padded length for the [`ChannelData`].
+pub(crate) const fn nearest_padded_value_length(l: usize) -> usize { + let mut n = PADDING * (l / PADDING); + if n < l { + n += PADDING; + } + n +} + +#[cfg(test)] +mod chandata_test { + use super::*; + + #[test] + fn test_channel_data_encode() { + let encoded = + ChannelData::encode(vec![1, 2, 3, 4], ChannelNumber::MIN + 1) + .unwrap(); + let decoded = ChannelData::decode(encoded.clone()).unwrap(); + + assert!( + ChannelData::is_channel_data(&encoded), + "unexpected IsChannelData" + ); + + assert_eq!(vec![1, 2, 3, 4], decoded.data, "not equal"); + assert_eq!(ChannelNumber::MIN + 1, decoded.number, "not equal"); + } + + #[test] + fn test_channel_data_equal() { + let tests = vec![ + ( + "equal", + ChannelData { number: ChannelNumber::MIN, data: vec![1, 2, 3] }, + ChannelData { number: ChannelNumber::MIN, data: vec![1, 2, 3] }, + true, + ), + ( + "number", + ChannelData { + number: ChannelNumber::MIN + 1, + data: vec![1, 2, 3], + }, + ChannelData { number: ChannelNumber::MIN, data: vec![1, 2, 3] }, + false, + ), + ( + "length", + ChannelData { + number: ChannelNumber::MIN, + data: vec![1, 2, 3, 4], + }, + ChannelData { number: ChannelNumber::MIN, data: vec![1, 2, 3] }, + false, + ), + ( + "data", + ChannelData { number: ChannelNumber::MIN, data: vec![1, 2, 2] }, + ChannelData { number: ChannelNumber::MIN, data: vec![1, 2, 3] }, + false, + ), + ]; + + for (name, a, b, r) in tests { + let v = ChannelData::encode(a.data.clone(), a.number) + == ChannelData::encode(b.data.clone(), b.number); + assert_eq!(v, r, "unexpected: ({name}) {r} != {r}"); + } + } + + #[test] + fn test_channel_data_decode() { + let tests = vec![ + ("small", vec![1, 2, 3], Error::UnexpectedEof), + ( + "zeroes", + vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + Error::InvalidChannelNumber, + ), + ( + "bad chan number", + vec![63, 255, 0, 0, 0, 4, 0, 0, 1, 2, 3, 4], + Error::InvalidChannelNumber, + ), + ( + "bad length", + vec![0x40, 0x40, 0x02, 0x23, 0x16, 0, 0, 0, 0, 0, 0, 0], + Error::BadChannelDataLength, + ), + ]; + + for (name, buf, want_err) in tests { + if let Err(err) = ChannelData::decode(buf) { + assert_eq!( + want_err, err, + "unexpected: ({name}) {want_err} != {err}" + ); + } else { + panic!("expected error, but got ok"); + } + } + } + + #[test] + fn test_is_channel_data() { + let tests = vec![ + ("small", vec![1, 2, 3, 4], false), + ("zeroes", vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], false), + ]; + + for (name, buf, r) in tests { + let v = ChannelData::is_channel_data(&buf); + assert_eq!(v, r, "unexpected: ({name}) {r} != {v}"); + } + } + + const CHANDATA_TEST_HEX: [&str; 2] = [ + "40000064000100502112a442453731722f2b322b6e4e7a5800060009443758343a3377\ + 6c59000000c0570004000003e7802a00081d5136dab65b169300250000002400046e00\ + 1eff0008001465d11a330e104a9f5f598af4abc6a805f26003cf802800046b334442", + "4000022316fefd0000000000000011012c0b000120000100000000012000011d00011a\ + 308201163081bda003020102020900afe52871340bd13e300a06082a8648ce3d040302\ + 3011310f300d06035504030c06576562525443301e170d313830383131303335323030\ + 5a170d3138303931313033353230305a3011310f300d06035504030c06576562525443\ + 3059301306072a8648ce3d020106082a8648ce3d030107034200048080e348bd41469c\ + fb7a7df316676fd72a06211765a50a0f0b07526c872dcf80093ed5caa3f5a40a725dd7\ + 4b41b79bdd19ee630c5313c8601d6983286c8722c1300a06082a8648ce3d0403020348\ + 003045022100d13a0a131bc2a9f27abd3d4c547f7ef172996a0c0755c707b6a3e048d8\ + 762ded0220055fc8182818a644a3d3b5b157304cc3f1421fadb06263bfb451cd28be4b\ + 
c9ee16fefd0000000000000012002d10000021000200000000002120f7e23c97df45a9\ + 6e13cb3e76b37eff5e73e2aee0b6415d29443d0bd24f578b7e16fefd00000000000000\ + 1300580f00004c000300000000004c040300483046022100fdbb74eab1aca1532e6ac0\ + ab267d5b83a24bb4d5d7d504936e2785e6e388b2bd022100f6a457b9edd9ead52a9d0e\ + 9a19240b3a68b95699546c044f863cf8349bc8046214fefd0000000000000014000101\ + 16fefd0001000000000004003000010000000000040aae2421e7d549632a7def8ed068\ + 98c3c5b53f5b812a963a39ab6cdd303b79bdb237f3314c1da21b", + ]; + + #[test] + fn test_chrome_channel_data() { + let mut data = vec![]; + let mut messages = vec![]; + + // Decoding hex data into binary. + for h in &CHANDATA_TEST_HEX { + let b = match hex::decode(h) { + Ok(b) => b, + Err(_) => panic!("hex decode error"), + }; + data.push(b); + } + + // All hex streams decoded to raw binary format and stored in data + // slice. Decoding packets to messages. + for packet in data { + let m = ChannelData::decode(packet.clone()).unwrap(); + + let encoded = + ChannelData::encode(m.data.clone(), m.number).unwrap(); + let decoded = ChannelData::decode(encoded.clone()).unwrap(); + + assert_eq!(m.data, decoded.data, "should be equal"); + assert_eq!(m.number, decoded.number, "should be equal"); + + messages.push(m); + } + + assert_eq!(messages.len(), 2, "unexpected message slice list"); + } +} diff --git a/src/con/mod.rs b/src/con/mod.rs new file mode 100644 index 000000000..08c5d69d8 --- /dev/null +++ b/src/con/mod.rs @@ -0,0 +1,137 @@ +//! Main STUN/TURN transport implementation. + +mod tcp; + +use std::io; + +use std::net::SocketAddr; + +use async_trait::async_trait; + +use tokio::{ + net, + net::{ToSocketAddrs, UdpSocket}, +}; + +use crate::{attr::PROTO_UDP, server::INBOUND_MTU, Error}; + +pub use tcp::TcpServer; + +/// Abstracting over transport implementation. +#[async_trait] +pub trait Conn { + async fn recv_from(&self) -> Result<(Vec, SocketAddr), Error>; + async fn send_to( + &self, + buf: Vec, + target: SocketAddr, + ) -> Result; + + /// Returns the local transport address. + fn local_addr(&self) -> SocketAddr; + + /// Return the transport protocol according to [IANA]. + /// + /// [IANA]: https://tinyurl.com/iana-protocol-numbers + fn proto(&self) -> u8; + + /// Closes the underlying transport. + async fn close(&self) -> Result<(), Error>; +} + +/// Performs a DNS resolution. +pub(crate) async fn lookup_host( + use_ipv4: bool, + host: T, +) -> Result +where + T: ToSocketAddrs, +{ + for remote_addr in net::lookup_host(host).await? { + if (use_ipv4 && remote_addr.is_ipv4()) + || (!use_ipv4 && remote_addr.is_ipv6()) + { + return Ok(remote_addr); + } + } + + Err(io::Error::new( + io::ErrorKind::Other, + format!( + "No available {} IP address found!", + if use_ipv4 { "ipv4" } else { "ipv6" }, + ), + ) + .into()) +} + +#[async_trait] +impl Conn for UdpSocket { + async fn recv_from(&self) -> Result<(Vec, SocketAddr), Error> { + let mut buf = vec![0u8; INBOUND_MTU]; + let (len, addr) = self.recv_from(&mut buf).await?; + buf.truncate(len); + + Ok((buf, addr)) + } + + async fn send_to( + &self, + data: Vec, + target: SocketAddr, + ) -> Result { + Ok(self.send_to(&data, target).await?) 
+ } + + fn local_addr(&self) -> SocketAddr { + #[allow(clippy::unwrap_used)] + self.local_addr().unwrap() + } + + fn proto(&self) -> u8 { + PROTO_UDP + } + + async fn close(&self) -> Result<(), Error> { + Ok(()) + } +} + +#[cfg(test)] +mod conn_test { + use super::*; + + #[tokio::test] + async fn test_conn_lookup_host() { + let stun_serv_addr = "stun1.l.google.com:19302"; + + if let Ok(ipv4_addr) = lookup_host(true, stun_serv_addr).await { + assert!( + ipv4_addr.is_ipv4(), + "expected ipv4 but got ipv6: {ipv4_addr}" + ); + } + + if let Ok(ipv6_addr) = lookup_host(false, stun_serv_addr).await { + assert!( + ipv6_addr.is_ipv6(), + "expected ipv6 but got ipv4: {ipv6_addr}" + ); + } + } +} + +#[cfg(test)] +mod net_test { + use super::*; + + #[tokio::test] + async fn test_net_native_resolve_addr() { + let udp_addr = lookup_host(true, "localhost:1234").await.unwrap(); + assert_eq!(udp_addr.ip().to_string(), "127.0.0.1", "should match"); + assert_eq!(udp_addr.port(), 1234, "should match"); + + let result = lookup_host(false, "127.0.0.1:1234").await; + assert!(result.is_err(), "should not match"); + } +} diff --git a/src/con/tcp.rs b/src/con/tcp.rs new file mode 100644 index 000000000..e8c4f222b --- /dev/null +++ b/src/con/tcp.rs @@ -0,0 +1,295 @@ +//! STUN/TURN TCP server connection implementation. + +#![allow(clippy::module_name_repetitions)] + +use std::{ + collections::{hash_map::Entry, HashMap}, + net::SocketAddr, + sync::Arc, +}; + +use async_trait::async_trait; +use bytes::BytesMut; +use futures::StreamExt; +use tokio::{ + io::AsyncWriteExt as _, + net::{TcpListener, TcpStream}, + sync::{mpsc, mpsc::error::TrySendError, oneshot, Mutex}, +}; +use tokio_util::codec::{Decoder, FramedRead}; + +use crate::{ + attr::PROTO_TCP, + chandata::nearest_padded_value_length, + con::{Conn, Error}, +}; + +/// Channels to the active TCP sessions. +type TcpWritersMap = Arc< + Mutex< + HashMap< + SocketAddr, + mpsc::Sender<(Vec, oneshot::Sender>)>, + >, + >, +>; + +/// TURN TCP transport. +#[derive(Debug)] +pub struct TcpServer { + /// Ingress messages receiver. + ingress_rx: Mutex, SocketAddr)>>, + + /// Local [`TcpListener`] address. + local_addr: SocketAddr, + + /// Channels to all active TCP sessions. + writers: TcpWritersMap, +} + +#[async_trait] +impl Conn for TcpServer { + async fn recv_from(&self) -> Result<(Vec, SocketAddr), Error> { + if let Some((data, addr)) = self.ingress_rx.lock().await.recv().await { + Ok((data, addr)) + } else { + Err(Error::TransportIsDead) + } + } + + #[allow(clippy::significant_drop_in_scrutinee)] + async fn send_to( + &self, + data: Vec, + target: SocketAddr, + ) -> Result { + let mut writers = self.writers.lock().await; + match writers.entry(target) { + Entry::Occupied(mut e) => { + let (res_tx, res_rx) = oneshot::channel(); + if e.get_mut().send((data, res_tx)).await.is_err() { + // Underlying TCP stream is dead. + drop(e.remove_entry()); + Err(Error::TransportIsDead) + } else { + #[allow(clippy::map_err_ignore)] + res_rx.await.map_err(|_| Error::TransportIsDead)? + } + } + Entry::Vacant(_) => Err(Error::TransportIsDead), + } + } + + fn local_addr(&self) -> SocketAddr { + self.local_addr + } + + fn proto(&self) -> u8 { + PROTO_TCP + } + + async fn close(&self) -> Result<(), Error> { + Ok(()) + } +} + +impl TcpServer { + /// Creates a new [`TcpServer`]. + /// + /// # Errors + /// + /// With [`enum@Error`] if failed to receive local [`SocketAddr`] for the + /// provided [`TcpListener`]. 
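// A minimal sketch of the per-connection writer pattern `TcpServer::send_to`
// above relies on: egress bytes are handed to a dedicated task over an mpsc
// channel together with a oneshot sender, and the caller awaits the oneshot
// to learn the write result. The channel capacity and the error type are
// illustrative choices, not the crate's.
use tokio::sync::{mpsc, oneshot};

type Job = (Vec<u8>, oneshot::Sender<Result<usize, String>>);

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel::<Job>(256);

    // Writer task: here it just "writes" by reporting the byte count back.
    tokio::spawn(async move {
        while let Some((bytes, done)) = rx.recv().await {
            let _ = done.send(Ok(bytes.len()));
        }
    });

    // Caller side, as in `send_to`: enqueue the bytes, then await the result.
    let (done_tx, done_rx) = oneshot::channel();
    tx.send((b"hello".to_vec(), done_tx)).await.unwrap();
    assert_eq!(done_rx.await.unwrap(), Ok(5));
}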
+ pub fn new(listener: TcpListener) -> Result { + let local_addr = listener.local_addr()?; + let (ingress_tx, ingress_rx) = mpsc::channel(256); + let writers = Arc::new(Mutex::new(HashMap::new())); + + drop(tokio::spawn({ + let writers = Arc::clone(&writers); + async move { + loop { + let Ok((stream, remote)) = listener.accept().await else { + log::debug!("Closing TCP listener at {local_addr}"); + break; + }; + if ingress_tx.is_closed() { + break; + } + + Self::spawn_stream_handler( + stream, + local_addr, + remote, + ingress_tx.clone(), + Arc::clone(&writers), + ); + } + } + })); + + Ok(Self { ingress_rx: Mutex::new(ingress_rx), local_addr, writers }) + } + + /// Spawns a handler task for the given [`TcpStream`] + fn spawn_stream_handler( + mut stream: TcpStream, + local_addr: SocketAddr, + remote: SocketAddr, + ingress_tx: mpsc::Sender<(Vec, SocketAddr)>, + writers: TcpWritersMap, + ) { + drop(tokio::spawn(async move { + let (egress_tx, mut egress_rx) = mpsc::channel::<( + Vec, + oneshot::Sender>, + )>(256); + drop(writers.lock().await.insert(remote, egress_tx)); + + let (reader, mut writer) = stream.split(); + let mut reader = FramedRead::new(reader, StunTcpCodec::default()); + loop { + tokio::select! { + msg = egress_rx.recv() => { + if let Some((msg, tx)) = msg { + let len = msg.len(); + let res = + writer.write_all(msg.as_slice()).await + .map(|()| len) + .map_err(Error::from); + + drop(tx.send(res)); + } else { + log::debug!("Closing TCP {local_addr} <=> \ + {remote}"); + + break; + } + }, + msg = reader.next() => { + match msg { + Some(Ok(msg)) => { + match ingress_tx.try_send((msg, remote)) { + Ok(()) => {}, + Err(TrySendError::Full(_)) => { + log::debug!("Dropped ingress message \ + from TCP {local_addr} <=> {remote}"); + } + Err(TrySendError::Closed(_)) => + { + log::debug!("Closing TCP \ + {local_addr} <=> {remote}"); + + break; + } + } + } + Some(Err(_)) => {}, + None => { + log::debug!("Closing TCP \ + {local_addr} <=> {remote}"); + + break; + } + } + }, + } + } + })); + } +} + +#[derive(Debug, Clone, Copy)] +enum StunMessageKind { + /// STUN method. + /// + /// 0 1 2 3 + /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// |0 0| STUN Message Type | Message Length | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Magic Cookie | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | | + /// | Transaction ID (96 bits) | + /// | | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + Method(usize), + + /// TURN [ChannelData][1]. + /// + /// 0 1 2 3 + /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Channel Number | Length | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | | + /// / Application Data | + /// / | + /// | | + /// | +-------------------------------+ + /// | | + /// +-------------------------------+ + /// + /// [1]: https://datatracker.ietf.org/doc/html/rfc5766#section-11.4 + ChannelData(usize), +} + +impl StunMessageKind { + /// Detects [`StunMessageKind`] from the given 4 bytes. + fn detect_kind(first_4_bytes: [u8; 4]) -> Self { + let size = usize::from(u16::from_be_bytes([ + first_4_bytes[2], + first_4_bytes[3], + ])); + + // If the first two bits are zeroes, then this is a STUN method. 
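+        // E.g., a STUN message with Message Length = 4 takes
+        // 20 (header) + 4 = 24 bytes on the wire, while a ChannelData frame
+        // with Length = 5 takes 4 (header) + 5 = 9 bytes, padded up to 12 on
+        // stream-oriented transports (RFC 5766, Section 11.5).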
+ if first_4_bytes[0] & 0b1100_0000 == 0 { + Self::Method(nearest_padded_value_length(size + 20)) + } else { + Self::ChannelData(nearest_padded_value_length(size + 4)) + } + } + + /// Returns the expected length of the message. + const fn length(&self) -> usize { + *match self { + Self::Method(l) | Self::ChannelData(l) => l, + } + } +} + +/// [`Decoder`] that splits STUN/TURN stream into frames. +#[derive(Default)] +struct StunTcpCodec { + /// Current message kind. + current: Option, +} + +impl Decoder for StunTcpCodec { + type Error = Error; + type Item = Vec; + + #[allow(clippy::unwrap_in_result, clippy::missing_asserts_for_indexing)] + fn decode( + &mut self, + buf: &mut BytesMut, + ) -> Result, Self::Error> { + if self.current.is_none() && buf.len() >= 4 { + self.current = Some(StunMessageKind::detect_kind([ + buf[0], buf[1], buf[2], buf[3], + ])); + } + if let Some(pending) = self.current { + if buf.len() >= pending.length() { + #[allow(clippy::unwrap_used)] + return Ok(Some( + buf.split_to(self.current.take().unwrap().length()) + .to_vec(), + )); + } + } + + Ok(None) + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 000000000..b1d83732a --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,333 @@ +//! A pure Rust implementation of TURN. + +#![deny( + macro_use_extern_crate, + nonstandard_style, + rust_2018_idioms, + rustdoc::all, + trivial_casts, + trivial_numeric_casts, + unsafe_code +)] +#![forbid(non_ascii_idents)] +#![warn( + clippy::absolute_paths, + clippy::as_conversions, + clippy::as_ptr_cast_mut, + clippy::assertions_on_result_states, + clippy::branches_sharing_code, + clippy::clear_with_drain, + clippy::clone_on_ref_ptr, + clippy::collection_is_never_read, + clippy::create_dir, + clippy::dbg_macro, + clippy::debug_assert_with_mut_call, + clippy::decimal_literal_representation, + clippy::default_union_representation, + clippy::derive_partial_eq_without_eq, + clippy::else_if_without_else, + clippy::empty_drop, + clippy::empty_line_after_outer_attr, + clippy::empty_structs_with_brackets, + clippy::equatable_if_let, + clippy::empty_enum_variants_with_brackets, + clippy::exit, + clippy::expect_used, + clippy::fallible_impl_from, + clippy::filetype_is_file, + clippy::float_cmp_const, + clippy::fn_to_numeric_cast, + clippy::fn_to_numeric_cast_any, + clippy::format_push_string, + clippy::get_unwrap, + clippy::if_then_some_else_none, + clippy::imprecise_flops, + clippy::index_refutable_slice, + clippy::infinite_loop, + clippy::iter_on_empty_collections, + clippy::iter_on_single_items, + clippy::iter_over_hash_type, + clippy::iter_with_drain, + clippy::large_include_file, + clippy::large_stack_frames, + clippy::let_underscore_untyped, + clippy::lossy_float_literal, + clippy::manual_c_str_literals, + clippy::manual_clamp, + clippy::map_err_ignore, + clippy::mem_forget, + clippy::missing_assert_message, + clippy::missing_asserts_for_indexing, + clippy::missing_const_for_fn, + clippy::missing_docs_in_private_items, + clippy::multiple_inherent_impl, + clippy::multiple_unsafe_ops_per_block, + clippy::mutex_atomic, + clippy::mutex_integer, + clippy::needless_collect, + clippy::needless_pass_by_ref_mut, + clippy::needless_raw_strings, + clippy::nonstandard_macro_braces, + clippy::option_if_let_else, + clippy::or_fun_call, + clippy::panic_in_result_fn, + clippy::partial_pub_fields, + clippy::pedantic, + clippy::print_stderr, + clippy::print_stdout, + clippy::pub_without_shorthand, + clippy::ref_as_ptr, + clippy::rc_buffer, + clippy::rc_mutex, + clippy::read_zero_byte_vec, + 
clippy::readonly_write_lock, + clippy::redundant_clone, + clippy::redundant_type_annotations, + clippy::ref_patterns, + clippy::rest_pat_in_fully_bound_structs, + clippy::same_name_method, + clippy::semicolon_inside_block, + clippy::shadow_unrelated, + clippy::significant_drop_in_scrutinee, + clippy::significant_drop_tightening, + clippy::str_to_string, + clippy::string_add, + clippy::string_lit_as_bytes, + clippy::string_lit_chars_any, + clippy::string_slice, + clippy::string_to_string, + clippy::suboptimal_flops, + clippy::suspicious_operation_groupings, + clippy::suspicious_xor_used_as_pow, + clippy::tests_outside_test_module, + clippy::todo, + clippy::trailing_empty_array, + clippy::transmute_undefined_repr, + clippy::trivial_regex, + clippy::try_err, + clippy::undocumented_unsafe_blocks, + clippy::unimplemented, + clippy::uninhabited_references, + clippy::unnecessary_safety_comment, + clippy::unnecessary_safety_doc, + clippy::unnecessary_self_imports, + clippy::unnecessary_struct_initialization, + clippy::unneeded_field_pattern, + clippy::unused_peekable, + clippy::unwrap_in_result, + clippy::unwrap_used, + clippy::use_debug, + clippy::use_self, + clippy::useless_let_if_seq, + clippy::verbose_file_reads, + clippy::wildcard_enum_match_arm, + explicit_outlives_requirements, + future_incompatible, + let_underscore_drop, + meta_variable_misuse, + missing_abi, + missing_copy_implementations, + missing_debug_implementations, + missing_docs, + semicolon_in_expressions_from_macros, + single_use_lifetimes, + unit_bindings, + unreachable_pub, + unsafe_op_in_unsafe_fn, + unstable_features, + unused_crate_dependencies, + unused_extern_crates, + unused_import_braces, + unused_lifetimes, + unused_macro_rules, + unused_qualifications, + unused_results, + variant_size_differences +)] +#![cfg_attr(test, allow(unused_crate_dependencies, unused_lifetimes))] + +mod allocation; +mod attr; +mod chandata; +mod con; +mod relay; +mod server; + +use std::{io, net::SocketAddr}; + +use thiserror::Error; + +pub use self::{ + allocation::{AllocInfo, FiveTuple}, + con::TcpServer, + relay::RelayAllocator, + server::{Config, ConnConfig, Server}, +}; + +/// External authentication handler. +pub trait AuthHandler { + /// Perform authentication of the given user data returning ICE password + /// on success. + /// + /// # Errors + /// + /// If authentication fails. + fn auth_handle( + &self, + username: &str, + realm: &str, + src_addr: SocketAddr, + ) -> Result, Error>; +} + +/// TURN server errors. +#[derive(Debug, Error, PartialEq)] +#[non_exhaustive] +#[allow(variant_size_differences)] +pub enum Error { + /// Failed to allocate new relay connection sine maximum retires count + /// exceeded. + #[error("turn: max retries exceeded")] + MaxRetriesExceeded, + + /// Failed to handle channel data since channel number is incorrect. + #[error("channel number not in [0x4000, 0x7FFF]")] + InvalidChannelNumber, + + /// Failed to handle channel data cause of incorrect message length. + #[error("channelData length != len(Data)")] + BadChannelDataLength, + + /// Failed to handle message since it's shorter than expected. + #[error("unexpected EOF")] + UnexpectedEof, + + /// A peer address is part of a different address family than that of the + /// relayed transport address of the allocation. + #[error("error code 443: peer address family mismatch")] + PeerAddressFamilyMismatch, + + /// Error when trying to perform action after closing server. 
+ #[error("use of closed network connection")] + Closed, + + /// Channel binding request failed since channel number is currently bound + /// to a different transport address. + #[error("you cannot use the same channel number with different peer")] + SameChannelDifferentPeer, + + /// Channel binding request failed since the transport address is currently + /// bound to a different channel number. + #[error("you cannot use the same peer number with different channel")] + SamePeerDifferentChannel, + + /// Cannot create allocation with zero lifetime. + #[error("allocations must not be created with a lifetime of 0")] + LifetimeZero, + + /// Cannot create allocation for the same five-tuple. + #[error("allocation attempt created with duplicate FiveTuple")] + DupeFiveTuple, + + /// The given nonce is wrong or already been used. + #[error("duplicated Nonce generated, discarding request")] + RequestReplay, + + /// Authentication error. + #[error("no such user exists")] + NoSuchUser, + + /// Unsupported request class. + #[error("unexpected class")] + UnexpectedClass, + + /// Allocate request failed since allocation already exists for the given + /// five-tuple. + #[error("relay already allocated for 5-TUPLE")] + RelayAlreadyAllocatedForFiveTuple, + + /// STUN message does not have a required attribute. + #[error("requested attribute not found")] + AttributeNotFound, + + /// STUN message contains wrong message integrity. + #[error("message integrity mismatch")] + IntegrityMismatch, + + /// [DONT-FRAGMENT][1] attribute is not supported. + /// + /// [1]: https://datatracker.ietf.org/doc/html/rfc5766#section-14.8 + #[error("no support for DONT-FRAGMENT")] + NoDontFragmentSupport, + + /// Allocate request cannot have both [RESERVATION-TOKEN][1] and + /// [EVEN-PORT]. + /// + /// [1]: https://datatracker.ietf.org/doc/html/rfc5766#section-14.9 + /// [EVEN-PORT]: https://datatracker.ietf.org/doc/html/rfc5766#section-14.6 + #[error("Request must not contain RESERVATION-TOKEN and EVEN-PORT")] + RequestWithReservationTokenAndEvenPort, + + /// Allocation request cannot contain both [RESERVATION-TOKEN][1] and + /// [REQUESTED-ADDRESS-FAMILY][2]. + /// + /// [1]: https://datatracker.ietf.org/doc/html/rfc5766#section-14.9 + /// [2]: https://www.rfc-editor.org/rfc/rfc6156#section-4.1.1 + #[error( + "Request must not contain RESERVATION-TOKEN \ + and REQUESTED-ADDRESS-FAMILY" + )] + RequestWithReservationTokenAndReqAddressFamily, + + /// No allocation for the given five-tuple. + #[error("no allocation found")] + NoAllocationFound, + + /// The specified protocol is not supported. + #[error("allocation requested unsupported proto")] + UnsupportedRelayProto, + + /// Failed to handle send indication since there is no permission for the + /// given address. + #[error("unable to handle send-indication, no permission added")] + NoPermission, + + /// Failed to handle channel data since ther is no binding for the given + /// channel. + #[error("no such channel bind")] + NoSuchChannelBind, + + /// Failed to decode message. + #[error("Failed to decode STUN/TURN message: {0:?}")] + Decode(bytecodec::ErrorKind), + + /// Failed to encode message. + #[error("Failed to encode STUN/TURN message: {0:?}")] + Encode(bytecodec::ErrorKind), + + /// Tried to use dead transport. + #[error("Underlying TCP/UDP transport is dead")] + TransportIsDead, + + /// Error for transport. + #[error("{0}")] + Io(#[source] IoError), +} + +/// [`io::Error`] wrapper. 
+#[derive(Debug, Error)] +#[error("io error: {0}")] +pub struct IoError(#[from] pub io::Error); + +// Workaround for wanting PartialEq for io::Error. +impl PartialEq for IoError { + fn eq(&self, other: &Self) -> bool { + self.0.kind() == other.0.kind() + } +} + +impl From for Error { + fn from(e: io::Error) -> Self { + Self::Io(IoError(e)) + } +} diff --git a/src/relay.rs b/src/relay.rs new file mode 100644 index 000000000..79d60502b --- /dev/null +++ b/src/relay.rs @@ -0,0 +1,86 @@ +//! [`RelayAllocator`] is used to create relay transports wit the given +//! configuration. + +#![allow(clippy::module_name_repetitions)] + +use std::{ + net::{IpAddr, SocketAddr}, + sync::Arc, +}; +use tokio::net::UdpSocket; + +use crate::{con, Error}; + +/// [`RelayAllocator`] is used to generate a Relay Address when creating an +/// allocation. +#[derive(Debug)] +pub struct RelayAllocator { + /// `relay_address` is the IP returned to the user when the relay is + /// created. + pub relay_address: IpAddr, + + /// `min_port` the minimum port to allocate. + pub min_port: u16, + + /// `max_port` the maximum (inclusive) port to allocate. + pub max_port: u16, + + /// `max_retries` the amount of tries to allocate a random port in the + /// defined range. + pub max_retries: u16, + + /// `address` is passed to Listen/ListenPacket when creating the Relay. + pub address: String, +} + +impl RelayAllocator { + /// Allocates a new relay connection. + /// + /// # Errors + /// + /// With [`Error::MaxRetriesExceeded`] if the requested port is `0` and + /// failed to find a free port in the specified maximum retries. + /// + /// With [`Error::Io`] if failed to bind to the specified port. + pub async fn allocate_conn( + &self, + use_ipv4: bool, + requested_port: u16, + ) -> Result<(Arc, SocketAddr), Error> { + let max_retries = + if self.max_retries == 0 { 10 } else { self.max_retries }; + + if requested_port == 0 { + for _ in 0..max_retries { + let port = self.min_port + + rand::random::() + % (self.max_port - self.min_port + 1); + let addr = con::lookup_host( + use_ipv4, + &format!("{}:{}", self.address, port), + ) + .await?; + let Ok(conn) = UdpSocket::bind(addr).await else { + continue; + }; + + let mut relay_addr = conn.local_addr()?; + relay_addr.set_ip(self.relay_address); + return Ok((Arc::new(conn), relay_addr)); + } + + Err(Error::MaxRetriesExceeded) + } else { + let addr = con::lookup_host( + use_ipv4, + &format!("{}:{}", self.address, requested_port), + ) + .await?; + let conn = Arc::new(UdpSocket::bind(addr).await?); + let mut relay_addr = conn.local_addr()?; + relay_addr.set_ip(self.relay_address); + + Ok((conn, relay_addr)) + } + } +} diff --git a/src/server/config.rs b/src/server/config.rs new file mode 100644 index 000000000..57c991803 --- /dev/null +++ b/src/server/config.rs @@ -0,0 +1,82 @@ +//! TURN server configuration. + +#![allow(clippy::module_name_repetitions)] + +use std::{fmt, sync::Arc}; + +use tokio::{sync::mpsc, time::Duration}; + +use crate::{ + allocation::AllocInfo, con::Conn, relay::RelayAllocator, AuthHandler, +}; + +/// Main STUN/TURN socket configuration. +pub struct ConnConfig { + /// STUN socket. + pub conn: Arc, + + /// Relay connections allocator. + pub relay_addr_generator: RelayAllocator, +} + +impl ConnConfig { + /// Creates a new [`ConnConfig`]. + /// + /// # Panics + /// + /// If the configured min port or max port is `0`. + /// If the configured min port is greater than max port. + /// If the configured address is an empty string. 
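+    ///
+    /// A rough configuration sketch (illustrative only; the relay allocator
+    /// values mirror the ones used in this crate's tests):
+    ///
+    /// ```ignore
+    /// let conn = Arc::new(UdpSocket::bind("0.0.0.0:3478").await?);
+    /// let cfg = ConnConfig::new(conn, RelayAllocator {
+    ///     relay_address: IpAddr::from([127, 0, 0, 1]),
+    ///     min_port: 49152,
+    ///     max_port: 65535,
+    ///     max_retries: 10,
+    ///     address: String::from("127.0.0.1"),
+    /// });
+    /// ```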
+    pub fn new(
+        conn: Arc<dyn Conn + Send + Sync>,
+        gen: RelayAllocator,
+    ) -> Self {
+        assert!(gen.min_port > 0, "min_port must be greater than 0");
+        assert!(gen.max_port > 0, "max_port must be greater than 0");
+        assert!(
+            gen.max_port >= gen.min_port,
+            "max_port must not be less than min_port"
+        );
+        assert!(
+            !gen.address.is_empty(),
+            "address must not be an empty string"
+        );
+
+        Self { conn, relay_addr_generator: gen }
+    }
+}
+
+impl fmt::Debug for ConnConfig {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("ConnConfig")
+            .field("relay_addr_generator", &self.relay_addr_generator)
+            .field("conn", &self.conn.local_addr())
+            .finish()
+    }
+}
+
+/// [`Config`] configures the TURN Server.
+pub struct Config {
+    /// `conn_configs` are a list of all the TURN listeners.
+    /// Each listener can have custom behavior around the creation of Relays.
+    pub conn_configs: Vec<ConnConfig>,
+
+    /// `realm` sets the realm for this server.
+    pub realm: String,
+
+    /// `auth_handler` is a callback used to handle incoming auth requests,
+    /// allowing users to plug in custom authentication behavior.
+    pub auth_handler: Arc<dyn AuthHandler + Send + Sync>,
+
+    /// Channel for receiving notifications on allocation close events, with
+    /// metrics data.
+    pub alloc_close_notify: Option<mpsc::Sender<AllocInfo>>,
+
+    /// Sets the lifetime of channel binding.
+    pub channel_bind_lifetime: Duration,
+}
+
+impl fmt::Debug for Config {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("Config")
+            .field("conn_configs", &self.conn_configs)
+            .field("realm", &self.realm)
+            .field("channel_bind_lifetime", &self.channel_bind_lifetime)
+            .field("alloc_close_notify", &self.alloc_close_notify)
+            .field("auth_handler", &"dyn AuthHandler")
+            .finish()
+    }
+}
diff --git a/src/server/mod.rs b/src/server/mod.rs
new file mode 100644
index 000000000..43f06116f
--- /dev/null
+++ b/src/server/mod.rs
@@ -0,0 +1,306 @@
+//! TURN server implementation.
+
+mod config;
+mod request;
+
+use std::{collections::HashMap, fmt, sync::Arc};
+
+use tokio::{
+    sync::{
+        broadcast::{
+            error::RecvError,
+            {self},
+        },
+        mpsc, oneshot, Mutex,
+    },
+    time::{Duration, Instant},
+};
+
+use crate::{
+    allocation::{AllocInfo, FiveTuple, Manager, ManagerConfig},
+    con::Conn,
+    AuthHandler, Error,
+};
+
+pub use self::config::{Config, ConnConfig};
+
+/// `DEFAULT_LIFETIME` in RFC 5766 is 10 minutes.
+///
+/// [RFC 5766 Section 2.2](https://www.rfc-editor.org/rfc/rfc5766#section-2.2)
+pub(crate) const DEFAULT_LIFETIME: Duration = Duration::from_secs(10 * 60);
+
+/// MTU used for UDP connections.
+pub(crate) const INBOUND_MTU: usize = 1500;
+
+/// Server is an instance of the TURN server.
+pub struct Server {
+    /// [`AuthHandler`] used to authenticate certain types of requests.
+    auth_handler: Arc<dyn AuthHandler + Send + Sync>,
+
+    /// A string used to describe the server or a context within the server.
+    realm: String,
+
+    /// [Channel binding][1] lifetime.
+    ///
+    /// [1]: https://datatracker.ietf.org/doc/html/rfc5766#section-11
+    channel_bind_lifetime: Duration,
+
+    /// Nonces generated by server.
+    pub(crate) nonces: Arc<Mutex<HashMap<String, Instant>>>,
+
+    /// Channel to [`Server`]'s internal loop.
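+    ///
+    /// Taken (left as [`None`]) once [`Server::close`] has been called, so
+    /// later API calls fail with [`Error::Closed`].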
+ command_tx: Mutex>>, +} + +impl Server { + /// creates a new TURN server + #[must_use] + pub fn new(config: Config) -> Self { + let (command_tx, _) = broadcast::channel(16); + let mut this = Self { + auth_handler: config.auth_handler, + realm: config.realm, + channel_bind_lifetime: config.channel_bind_lifetime, + nonces: Arc::new(Mutex::new(HashMap::new())), + command_tx: Mutex::new(Some(command_tx.clone())), + }; + if this.channel_bind_lifetime == Duration::from_secs(0) { + this.channel_bind_lifetime = DEFAULT_LIFETIME; + } + for p in config.conn_configs { + let nonces = Arc::clone(&this.nonces); + let auth_handler = Arc::clone(&this.auth_handler); + let realm = this.realm.clone(); + let channel_bind_lifetime = this.channel_bind_lifetime; + let handle_rx = command_tx.subscribe(); + let conn = p.conn; + let allocation_manager = Arc::new(Manager::new(ManagerConfig { + relay_addr_generator: p.relay_addr_generator, + alloc_close_notify: config.alloc_close_notify.clone(), + })); + + Self::spawn_read_loop( + conn, + allocation_manager, + nonces, + auth_handler, + realm, + channel_bind_lifetime, + handle_rx, + ); + } + + this + } + + /// Deletes all existing allocations by the provided `username`. + /// + /// # Errors + /// + /// With [`Error::Closed`] if the [`Server`] was closed already. + pub async fn delete_allocations_by_username( + &self, + username: String, + ) -> Result<(), Error> { + let tx = self.command_tx.lock().await.clone(); + + #[allow(clippy::map_err_ignore)] + if let Some(tx) = tx { + let (closed_tx, closed_rx) = mpsc::channel(1); + _ = tx + .send(Command::DeleteAllocations(username, Arc::new(closed_rx))) + .map_err(|_| Error::Closed)?; + + closed_tx.closed().await; + + Ok(()) + } else { + Err(Error::Closed) + } + } + + /// Returns [`AllocInfo`]s by specified [`FiveTuple`]s. + /// + /// If `five_tuples` is: + /// - [`None`]: It returns information about the all + /// allocations. + /// - [`Some`] and not empty: It returns information about the allocations + /// associated with the specified [`FiveTuple`]s. + /// - [`Some`], but empty: It returns an empty [`HashMap`]. + /// + /// # Errors + /// + /// With [`Error::Closed`] if the [`Server`] was closed already. + pub async fn get_allocations_info( + &self, + five_tuples: Option>, + ) -> Result, Error> { + if let Some(five_tuples) = &five_tuples { + if five_tuples.is_empty() { + return Ok(HashMap::new()); + } + } + + let tx = self.command_tx.lock().await.clone(); + #[allow(clippy::map_err_ignore)] + if let Some(tx) = tx { + let (infos_tx, mut infos_rx) = mpsc::channel(1); + + _ = tx + .send(Command::GetAllocationsInfo(five_tuples, infos_tx)) + .map_err(|_| Error::Closed)?; + + let mut info: HashMap = HashMap::new(); + + for _ in 0..tx.receiver_count() { + info.extend(infos_rx.recv().await.ok_or(Error::Closed)?); + } + + Ok(info) + } else { + Err(Error::Closed) + } + } + + /// Spawns a message handler task for the given [`Conn`]. 
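+    ///
+    /// Two tasks are spawned per [`Conn`]: one drains the `Command` channel
+    /// (allocation deletion, info queries and shutdown), the other reads
+    /// inbound packets and passes them to `request::handle_message`.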
+ fn spawn_read_loop( + conn: Arc, + allocation_manager: Arc, + nonces: Arc>>, + auth_handler: Arc, + realm: String, + channel_bind_lifetime: Duration, + mut handle_rx: broadcast::Receiver, + ) { + let (mut close_tx, mut close_rx) = oneshot::channel::<()>(); + + drop(tokio::spawn({ + let allocation_manager = Arc::clone(&allocation_manager); + + async move { + loop { + match handle_rx.recv().await { + Ok(Command::DeleteAllocations(name, completion)) => { + allocation_manager + .delete_allocations_by_username(name.as_str()) + .await; + drop(completion); + continue; + } + Ok(Command::GetAllocationsInfo(five_tuples, tx)) => { + let infos = allocation_manager + .get_allocations_info(&five_tuples); + drop(tx.send(infos).await); + + continue; + } + Err(RecvError::Closed) => { + close_rx.close(); + break; + } + Ok(Command::Close(completion)) => { + close_rx.close(); + drop(completion); + break; + } + Err(RecvError::Lagged(n)) => { + log::warn!( + "Turn server has lagged by {} messages", + n + ); + continue; + } + } + } + } + })); + + drop(tokio::spawn(async move { + let local_con_addr = conn.local_addr(); + let protocol = conn.proto(); + + loop { + let (msg, src_addr) = tokio::select! { + v = conn.recv_from() => { + match v { + Ok(v) => v, + Err(err) => { + log::debug!("exit read loop on error: {}", err); + break; + } + } + }, + () = close_tx.closed() => break + }; + + let handle = request::handle_message( + msg, + &conn, + FiveTuple { src_addr, dst_addr: local_con_addr, protocol }, + realm.as_str(), + channel_bind_lifetime, + &allocation_manager, + &nonces, + &auth_handler, + ); + + if let Err(err) = handle.await { + log::error!("error when handling datagram: {}", err); + } + } + + drop(allocation_manager.close().await); + drop(conn.close().await); + })); + } + + /// Close stops the TURN Server. It cleans up any associated state and + /// closes all connections it is managing. + pub async fn close(&self) { + let tx = self.command_tx.lock().await.take(); + + if let Some(tx) = tx { + if tx.receiver_count() == 0 { + return; + } + + let (closed_tx, closed_rx) = mpsc::channel(1); + drop(tx.send(Command::Close(Arc::new(closed_rx)))); + closed_tx.closed().await; + } + } +} + +impl fmt::Debug for Server { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Server") + .field("realm", &self.realm) + .field("channel_bind_lifetime", &self.channel_bind_lifetime) + .field("nonces", &self.nonces) + .field("command_tx", &self.command_tx) + .field("auth_handler", &"dyn AuthHandler") + .finish() + } +} + +/// The protocol to communicate between the [`Server`]'s public methods +/// and the tasks spawned in the [`Server::spawn_read_loop`] method. +#[derive(Clone)] +enum Command { + /// Command to delete [`Allocation`][`Allocation`] by provided `username`. + /// + /// [`Allocation`]: `crate::allocation::Allocation` + DeleteAllocations(String, Arc>), + + /// Command to get information of [`Allocation`][`Allocation`]s by provided + /// [`FiveTuple`]s. + /// + /// [`Allocation`]: `crate::allocation::Allocation` + GetAllocationsInfo( + Option>, + mpsc::Sender>, + ), + + /// Command to close the [`Server`]. + Close(Arc>), +} diff --git a/src/server/request.rs b/src/server/request.rs new file mode 100644 index 000000000..b8d4e2168 --- /dev/null +++ b/src/server/request.rs @@ -0,0 +1,1061 @@ +//! Ingress STUN/TURN messages handlers. 
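+//!
+//! The entry point is [`handle_message`]: raw bytes are first checked for
+//! TURN ChannelData framing and, if so, relayed directly; otherwise they are
+//! decoded as a STUN message, authenticated where required (Allocate,
+//! Refresh, CreatePermission and ChannelBind requests) and dispatched to the
+//! matching handler by method and class.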
+ +use bytecodec::{DecodeExt, EncodeExt}; +use std::{ + collections::HashMap, + marker::{Send, Sync}, + mem, + net::SocketAddr, + sync::Arc, +}; + +use rand::{distributions::Alphanumeric, random, Rng}; +use stun_codec::{ + rfc5389::{ + errors::{BadRequest, StaleNonce, Unauthorized, UnknownAttribute}, + methods::BINDING, + }, + rfc5766::{ + errors::{ + AllocationMismatch, InsufficientCapacity, + UnsupportedTransportProtocol, + }, + methods::{ALLOCATE, CHANNEL_BIND, CREATE_PERMISSION, REFRESH, SEND}, + }, + rfc8656::errors::{AddressFamilyNotSupported, PeerAddressFamilyMismatch}, + Attribute as _, Message, MessageClass, MessageDecoder, MessageEncoder, + TransactionId, +}; +use tokio::{ + sync::Mutex, + time::{Duration, Instant}, +}; + +use crate::{ + allocation::{FiveTuple, Manager}, + attr::{ + AddressFamily, Attribute, ChannelNumber, Data, DontFragment, ErrorCode, + EvenPort, Fingerprint, Lifetime, MessageIntegrity, Nonce, Realm, + RequestedAddressFamily, RequestedTransport, ReservationToken, + UnknownAttributes, Username, XorMappedAddress, XorPeerAddress, + XorRelayAddress, PROTO_UDP, + }, + chandata::ChannelData, + con::Conn, + server::DEFAULT_LIFETIME, + AuthHandler, Error, +}; + +/// It is RECOMMENDED that the server use a maximum allowed lifetime value of no +/// more than 3600 seconds (1 hour). +const MAXIMUM_ALLOCATION_LIFETIME: Duration = Duration::from_secs(3600); + +/// Lifetime of the NONCE sent by server. +const NONCE_LIFETIME: Duration = Duration::from_secs(3600); + +/// Handles the given STUN/TURN message according to [spec]. +/// +/// [spec]: https://datatracker.ietf.org/doc/html/rfc5389#section-7.3 +#[allow(clippy::too_many_arguments)] +pub(crate) async fn handle_message( + mut raw: Vec, + conn: &Arc, + five_tuple: FiveTuple, + server_realm: &str, + channel_bind_lifetime: Duration, + allocs: &Arc, + nonces: &Arc>>, + auth_handler: &Arc, +) -> Result<(), Error> { + if ChannelData::is_channel_data(&raw) { + let data = ChannelData::decode(mem::take(&mut raw))?; + + handle_data_packet(data, five_tuple, allocs).await + } else { + use stun_codec::MessageClass::{Indication, Request}; + + let msg = MessageDecoder::::new() + .decode_from_bytes(&raw) + .map_err(|e| Error::Decode(*e.kind()))? + .map_err(|e| Error::Decode(*e.error().kind()))?; + + let auth = match (msg.method(), msg.class()) { + ( + ALLOCATE | REFRESH | CREATE_PERMISSION | CHANNEL_BIND, + Request, + ) => { + authenticate_request( + &msg, + auth_handler, + conn, + nonces, + five_tuple, + server_realm, + ) + .await? 
+ } + _ => None, + }; + + match (msg.method(), msg.class()) { + (ALLOCATE, Request) => { + if let Some((uname, realm, pass)) = auth { + handle_allocate_request( + msg, conn, allocs, five_tuple, uname, realm, pass, + ) + .await + } else { + Ok(()) + } + } + (REFRESH, Request) => { + if let Some((uname, realm, pass)) = auth { + handle_refresh_request( + msg, conn, allocs, five_tuple, uname, realm, pass, + ) + .await + } else { + Ok(()) + } + } + (CREATE_PERMISSION, Request) => { + if let Some((uname, realm, pass)) = auth { + handle_create_permission_request( + msg, conn, allocs, five_tuple, uname, realm, pass, + ) + .await + } else { + Ok(()) + } + } + (CHANNEL_BIND, Request) => { + if let Some((uname, realm, pass)) = auth { + handle_channel_bind_request( + msg, + conn, + allocs, + five_tuple, + channel_bind_lifetime, + uname, + realm, + pass, + ) + .await + } else { + Ok(()) + } + } + (BINDING, Request) => { + handle_binding_request(conn, five_tuple).await + } + (SEND, Indication) => { + handle_send_indication(msg, allocs, five_tuple).await + } + (_, _) => Err(Error::UnexpectedClass), + } + } +} + +/// Relays the given [`ChannelData`]. +async fn handle_data_packet( + data: ChannelData, + five_tuple: FiveTuple, + allocs: &Arc, +) -> Result<(), Error> { + if let Some(alloc) = allocs.get_alloc(&five_tuple) { + let channel = alloc.get_channel_addr(&data.num()).await; + if let Some(peer) = channel { + alloc.relay(&data.data(), peer).await?; + + Ok(()) + } else { + Err(Error::NoSuchChannelBind) + } + } else { + Err(Error::NoAllocationFound) + } +} + +/// Handles the given STUN [`Message`] as an [AllocateRequest]. +/// +/// [AllocateRequest]: https://datatracker.ietf.org/doc/html/rfc5766#section-6.2 +#[allow(clippy::too_many_lines)] +async fn handle_allocate_request( + msg: Message, + conn: &Arc, + allocs: &Arc, + five_tuple: FiveTuple, + uname: Username, + realm: Realm, + pass: Box, +) -> Result<(), Error> { + // 1. The server MUST require that the request be authenticated. This + // authentication MUST be done using the long-term credential + // mechanism of [https://tools.ietf.org/html/rfc5389#section-10.2.2] + // unless the client and server agree to use another mechanism through + // some procedure outside the scope of this document. + + let mut requested_port = 0; + let mut reservation_token: Option = None; + let mut use_ipv4 = true; + + // 2. The server checks if the 5-tuple is currently in use by an existing + // allocation. If yes, the server rejects the request with a 437 + // (Allocation Mismatch) error. + if allocs.has_alloc(&five_tuple) { + let mut msg = Message::new( + MessageClass::ErrorResponse, + ALLOCATE, + msg.transaction_id(), + ); + msg.add_attribute(ErrorCode::from(AllocationMismatch)); + + answer_with_err(conn, five_tuple.src_addr, msg).await?; + + return Err(Error::RelayAlreadyAllocatedForFiveTuple); + } + + // 3. The server checks if the request contains a REQUESTED-TRANSPORT + // attribute. If the REQUESTED-TRANSPORT attribute is not included or is + // malformed, the server rejects the request with a 400 (Bad Request) + // error. Otherwise, if the attribute is included but specifies a + // protocol other that UDP, the server rejects the request with a 442 + // (Unsupported Transport Protocol) error. 
+ let Some(requested_proto) = msg + .get_attribute::() + .map(RequestedTransport::protocol) + else { + let mut msg = Message::new( + MessageClass::ErrorResponse, + ALLOCATE, + msg.transaction_id(), + ); + msg.add_attribute(ErrorCode::from(BadRequest)); + + answer_with_err(conn, five_tuple.src_addr, msg).await?; + + return Err(Error::AttributeNotFound); + }; + + if requested_proto != PROTO_UDP { + let mut msg = Message::new( + MessageClass::ErrorResponse, + ALLOCATE, + msg.transaction_id(), + ); + msg.add_attribute(ErrorCode::from(UnsupportedTransportProtocol)); + + answer_with_err(conn, five_tuple.src_addr, msg).await?; + + return Err(Error::UnsupportedRelayProto); + } + + // 4. The request may contain a DONT-FRAGMENT attribute. If it does, but + // the server does not support sending UDP datagrams with the DF bit set + // to 1 (see Section 12), then the server treats the DONT- FRAGMENT + // attribute in the Allocate request as an unknown comprehension-required + // attribute. + if msg.get_attribute::().is_some() { + let mut msg = Message::new( + MessageClass::ErrorResponse, + ALLOCATE, + msg.transaction_id(), + ); + msg.add_attribute(ErrorCode::from(UnknownAttribute)); + msg.add_attribute(UnknownAttributes::new( + vec![DontFragment.get_type()], + )); + + answer_with_err(conn, five_tuple.src_addr, msg).await?; + + return Err(Error::NoDontFragmentSupport); + } + + // 5. The server checks if the request contains a RESERVATION-TOKEN + // attribute. If yes, and the request also contains an EVEN-PORT + // attribute, then the server rejects the request with a 400 (Bad + // Request) error. Otherwise, it checks to see if the token is valid + // (i.e., the token is in range and has not expired and the corresponding + // relayed transport address is still available). If the token is not + // valid for some reason, the server rejects the request with a 508 + // (Insufficient Capacity) error. + let has_reservation_token = + msg.get_attribute::().is_some(); + let even_port = msg.get_attribute::(); + + if has_reservation_token && even_port.is_some() { + let mut msg = Message::new( + MessageClass::ErrorResponse, + ALLOCATE, + msg.transaction_id(), + ); + msg.add_attribute(ErrorCode::from(BadRequest)); + + answer_with_err(conn, five_tuple.src_addr, msg).await?; + + return Err(Error::RequestWithReservationTokenAndEvenPort); + } + + // RFC 6156, Section 4.2: + // + // If it contains both a RESERVATION-TOKEN and a + // REQUESTED-ADDRESS-FAMILY, the server replies with a 400 + // (Bad Request) Allocate error response. + // + // 4.2.1. Unsupported Address Family + // This document defines the following new error response code: + // 440 (Address Family not Supported): The server does not support the + // address family requested by the client. + if let Some(req_family) = msg + .get_attribute::() + .map(RequestedAddressFamily::address_family) + { + if has_reservation_token { + let mut msg = Message::new( + MessageClass::ErrorResponse, + ALLOCATE, + msg.transaction_id(), + ); + msg.add_attribute(ErrorCode::from(AddressFamilyNotSupported)); + + answer_with_err(conn, five_tuple.src_addr, msg).await?; + + return Err(Error::RequestWithReservationTokenAndReqAddressFamily); + } + + if req_family == AddressFamily::V6 { + use_ipv4 = false; + } + } + + // 6. The server checks if the request contains an EVEN-PORT attribute. If + // yes, then the server checks that it can satisfy the request (i.e., can + // allocate a relayed transport address as described below). 
If the + // server cannot satisfy the request, then the server rejects the request + // with a 508 (Insufficient Capacity) error. + if even_port.is_some() { + let mut random_port = 1; + + while random_port % 2 != 0 { + random_port = match allocs.get_random_even_port().await { + Ok(port) => port, + Err(err) => { + let mut msg = Message::new( + MessageClass::ErrorResponse, + ALLOCATE, + msg.transaction_id(), + ); + msg.add_attribute(ErrorCode::from(InsufficientCapacity)); + + answer_with_err(conn, five_tuple.src_addr, msg).await?; + + return Err(err); + } + }; + } + + requested_port = random_port; + reservation_token = Some(random()); + } + + // 7. At any point, the server MAY choose to reject the request with a 486 + // (Allocation Quota Reached) error if it feels the client is trying to + // exceed some locally defined allocation quota. The server is free to + // define this allocation quota any way it wishes, but SHOULD define it + // based on the username used to authenticate the request, and not on the + // client's transport address. + + // 8. Also at any point, the server MAY choose to reject the request with a + // 300 (Try Alternate) error if it wishes to redirect the client to a + // different server. The use of this error code and attribute follow the + // specification in [RFC5389]. + let lifetime_duration = get_lifetime(&msg); + let a = match allocs + .create_allocation( + five_tuple, + Arc::clone(conn), + requested_port, + lifetime_duration, + uname.clone(), + use_ipv4, + ) + .await + { + Ok(a) => a, + Err(err) => { + let mut msg = Message::new( + MessageClass::ErrorResponse, + ALLOCATE, + msg.transaction_id(), + ); + msg.add_attribute(ErrorCode::from(InsufficientCapacity)); + + answer_with_err(conn, five_tuple.src_addr, msg).await?; + + return Err(err); + } + }; + + // Once the allocation is created, the server replies with a success + // response. The success response contains: + // * An XOR-RELAYED-ADDRESS attribute containing the relayed transport + // address. + // * A LIFETIME attribute containing the current value of the time-to- + // expiry timer. + // * A RESERVATION-TOKEN attribute (if a second relayed transport address + // was reserved). + // * An XOR-MAPPED-ADDRESS attribute containing the client's IP address + // and port (from the 5-tuple). + + let msg = { + if let Some(token) = reservation_token { + allocs.create_reservation(token, a.relay_addr().port()).await; + } + + let mut msg = Message::new( + MessageClass::SuccessResponse, + ALLOCATE, + msg.transaction_id(), + ); + + msg.add_attribute(XorRelayAddress::new(a.relay_addr())); + msg.add_attribute( + Lifetime::new(lifetime_duration) + .map_err(|e| Error::Encode(*e.kind()))?, + ); + msg.add_attribute(XorMappedAddress::new(five_tuple.src_addr)); + + if let Some(token) = reservation_token { + msg.add_attribute(ReservationToken::new(token)); + } + + let integrity = MessageIntegrity::new_long_term_credential( + &msg, &uname, &realm, &pass, + ) + .map_err(|e| Error::Encode(*e.kind()))?; + msg.add_attribute(integrity); + + msg + }; + + build_and_send(conn, five_tuple.src_addr, msg).await +} + +/// Authenticates the given [`Message`]. 
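+///
+/// Follows the long-term credential mechanism: a request without a
+/// MESSAGE-INTEGRITY attribute is answered with 401 (Unauthorized) and a
+/// freshly generated NONCE; an expired or unknown NONCE yields 438 (Stale
+/// Nonce); otherwise the USERNAME and REALM attributes are resolved through
+/// the configured [`AuthHandler`] and the message integrity is verified,
+/// returning the authenticated credentials on success.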
+async fn authenticate_request( + msg: &Message, + auth_handler: &Arc, + conn: &Arc, + nonces: &Arc>>, + five_tuple: FiveTuple, + realm: &str, +) -> Result)>, Error> { + let Some(integrity) = msg.get_attribute::() else { + respond_with_nonce( + msg, + ErrorCode::from(Unauthorized), + conn, + realm, + five_tuple, + nonces, + ) + .await?; + return Ok(None); + }; + + let mut bad_request_msg = Message::new( + MessageClass::ErrorResponse, + msg.method(), + msg.transaction_id(), + ); + bad_request_msg.add_attribute(ErrorCode::from(BadRequest)); + + let Some(nonce_attr) = &msg.get_attribute::() else { + answer_with_err(conn, five_tuple.src_addr, bad_request_msg).await?; + return Err(Error::AttributeNotFound); + }; + + let to_be_deleted = { + // Assert Nonce exists and is not expired + let mut nonces = nonces.lock().await; + + let to_be_deleted = nonces.get(nonce_attr.value()).map_or( + true, + |nonce_creation_time| { + Instant::now() + .checked_duration_since(*nonce_creation_time) + .unwrap_or_else(|| Duration::from_secs(0)) + >= NONCE_LIFETIME + }, + ); + + if to_be_deleted { + _ = nonces.remove(nonce_attr.value()); + } + to_be_deleted + }; + + if to_be_deleted { + respond_with_nonce( + msg, + ErrorCode::from(StaleNonce), + conn, + realm, + five_tuple, + nonces, + ) + .await?; + return Ok(None); + } + + let Some(uname_attr) = msg.get_attribute::() else { + answer_with_err(conn, five_tuple.src_addr, bad_request_msg).await?; + return Err(Error::AttributeNotFound); + }; + let Some(realm_attr) = msg.get_attribute::() else { + answer_with_err(conn, five_tuple.src_addr, bad_request_msg).await?; + return Err(Error::AttributeNotFound); + }; + + let Ok(pass) = auth_handler.auth_handle( + uname_attr.name(), + realm_attr.text(), + five_tuple.src_addr, + ) else { + answer_with_err(conn, five_tuple.src_addr, bad_request_msg).await?; + return Err(Error::NoSuchUser); + }; + + if let Err(err) = + integrity.check_long_term_credential(uname_attr, realm_attr, &pass) + { + let mut unauthorized_msg = Message::new( + MessageClass::ErrorResponse, + msg.method(), + msg.transaction_id(), + ); + unauthorized_msg.add_attribute(err); + + answer_with_err(conn, five_tuple.src_addr, unauthorized_msg).await?; + + Err(Error::IntegrityMismatch) + } else { + Ok(Some((uname_attr.clone(), realm_attr.clone(), pass))) + } +} + +/// Sends a [`MessageClass::SuccessResponse`] message with a +/// [`XorMappedAddress`] attribute to the given [`Conn`]. +async fn handle_binding_request( + conn: &Arc, + five_tuple: FiveTuple, +) -> Result<(), Error> { + log::trace!("received BindingRequest from {}", five_tuple.src_addr); + + let mut msg = Message::new( + MessageClass::SuccessResponse, + BINDING, + TransactionId::new(random()), + ); + msg.add_attribute(XorMappedAddress::new(five_tuple.src_addr)); + let fingerprint = + Fingerprint::new(&msg).map_err(|e| Error::Encode(*e.kind()))?; + msg.add_attribute(fingerprint); + + build_and_send(conn, five_tuple.src_addr, msg).await +} + +/// Handle the given [`Message`] as [Refresh Request]. 
+/// +/// [Refresh Request]: https://datatracker.ietf.org/doc/html/rfc5766#section-7.2 +async fn handle_refresh_request( + msg: Message, + conn: &Arc, + allocs: &Arc, + five_tuple: FiveTuple, + uname: Username, + realm: Realm, + pass: Box, +) -> Result<(), Error> { + log::trace!("received RefreshRequest from {}", five_tuple.src_addr); + + let lifetime_duration = get_lifetime(&msg); + if lifetime_duration == Duration::from_secs(0) { + allocs.delete_allocation(&five_tuple).await; + } else if let Some(a) = allocs.get_alloc(&five_tuple) { + // If a server receives a Refresh Request with a + // REQUESTED-ADDRESS-FAMILY attribute, and the + // attribute's value doesn't match the address + // family of the allocation, the server MUST reply with a 443 + // (Peer Address Family Mismatch) Refresh error + // response. [RFC 6156, Section 5.2] + if let Some(family) = msg + .get_attribute::() + .map(RequestedAddressFamily::address_family) + { + if (family == AddressFamily::V6 && !a.relay_addr().is_ipv6()) + || (family == AddressFamily::V4 && !a.relay_addr().is_ipv4()) + { + let mut msg = Message::new( + MessageClass::ErrorResponse, + REFRESH, + msg.transaction_id(), + ); + msg.add_attribute(ErrorCode::from(PeerAddressFamilyMismatch)); + + answer_with_err(conn, five_tuple.src_addr, msg).await?; + + return Err(Error::PeerAddressFamilyMismatch); + } + } + a.refresh(lifetime_duration).await; + } else { + return Err(Error::NoAllocationFound); + } + + let mut msg = Message::new( + MessageClass::SuccessResponse, + REFRESH, + msg.transaction_id(), + ); + msg.add_attribute( + Lifetime::new(lifetime_duration) + .map_err(|e| Error::Encode(*e.kind()))?, + ); + let integrity = + MessageIntegrity::new_long_term_credential(&msg, &uname, &realm, &pass) + .map_err(|e| Error::Encode(*e.kind()))?; + msg.add_attribute(integrity); + + build_and_send(conn, five_tuple.src_addr, msg).await +} + +/// Handles the given [`Message`] as a [CreatePermission Request][1]. +/// +/// [1]: https://datatracker.ietf.org/doc/html/rfc5766#section-9.2 +async fn handle_create_permission_request( + msg: Message, + conn: &Arc, + allocs: &Arc, + five_tuple: FiveTuple, + uname: Username, + realm: Realm, + pass: Box, +) -> Result<(), Error> { + log::trace!("received CreatePermission from {}", five_tuple.src_addr); + + let Some(alloc) = allocs.get_alloc(&five_tuple) else { + return Err(Error::NoAllocationFound); + }; + + let mut add_count = 0; + + for attr in msg.attributes() { + let Attribute::XorPeerAddress(attr) = attr else { + continue; + }; + let addr = attr.address(); + + // If an XOR-PEER-ADDRESS attribute contains an address of an + // address family different than that of the relayed transport + // address for the allocation, the server MUST generate an error + // response with the 443 (Peer Address Family Mismatch) response + // code. 
[RFC 6156, Section 6.2] + if (addr.is_ipv4() && !alloc.relay_addr().is_ipv4()) + || (addr.is_ipv6() && !alloc.relay_addr().is_ipv6()) + { + let mut msg = Message::new( + MessageClass::ErrorResponse, + CREATE_PERMISSION, + msg.transaction_id(), + ); + msg.add_attribute(ErrorCode::from(PeerAddressFamilyMismatch)); + + answer_with_err(conn, five_tuple.src_addr, msg).await?; + + return Err(Error::PeerAddressFamilyMismatch); + } + + log::trace!("adding permission for {}", addr); + + alloc.add_permission(addr.ip()).await; + add_count += 1; + } + + let resp_class = if add_count > 0 { + MessageClass::SuccessResponse + } else { + MessageClass::ErrorResponse + }; + + let msg = { + let mut msg = + Message::new(resp_class, CREATE_PERMISSION, msg.transaction_id()); + let integrity = MessageIntegrity::new_long_term_credential( + &msg, &uname, &realm, &pass, + ) + .map_err(|e| Error::Encode(*e.kind()))?; + msg.add_attribute(integrity); + + msg + }; + + build_and_send(conn, five_tuple.src_addr, msg).await +} + +/// Handles the given [`Message`] as a [Send Indication][1]. +/// +/// [1]: https://datatracker.ietf.org/doc/html/rfc5766#section-10.2 +async fn handle_send_indication( + msg: Message, + allocs: &Arc, + five_tuple: FiveTuple, +) -> Result<(), Error> { + log::trace!("received SendIndication from {}", five_tuple.src_addr); + + let Some(a) = allocs.get_alloc(&five_tuple) else { + return Err(Error::NoAllocationFound); + }; + + let data_attr = + msg.get_attribute::().ok_or(Error::AttributeNotFound)?; + let peer_address = msg + .get_attribute::() + .map(XorPeerAddress::address) + .ok_or(Error::AttributeNotFound)?; + + let has_perm = a.has_permission(&peer_address).await; + if !has_perm { + return Err(Error::NoPermission); + } + + // TODO: dont clone + a.relay(data_attr.data(), peer_address).await.map_err(Into::into) +} + +/// Handles the given [`Message`] as a [ChannelBind Request][1]. +/// +/// [1]: https://datatracker.ietf.org/doc/html/rfc5766#section-11.2 +#[allow(clippy::too_many_arguments)] +async fn handle_channel_bind_request( + msg: Message, + conn: &Arc, + allocs: &Arc, + five_tuple: FiveTuple, + channel_bind_lifetime: Duration, + uname: Username, + realm: Realm, + pass: Box, +) -> Result<(), Error> { + if let Some(alloc) = allocs.get_alloc(&five_tuple) { + let mut bad_request_msg = Message::new( + MessageClass::ErrorResponse, + CHANNEL_BIND, + msg.transaction_id(), + ); + bad_request_msg.add_attribute(ErrorCode::from(BadRequest)); + + let Some(ch_num) = + msg.get_attribute::().map(|a| a.value()) + else { + answer_with_err(conn, five_tuple.src_addr, bad_request_msg).await?; + + return Err(Error::AttributeNotFound); + }; + let Some(peer_addr) = + msg.get_attribute::().map(XorPeerAddress::address) + else { + answer_with_err(conn, five_tuple.src_addr, bad_request_msg).await?; + + return Err(Error::AttributeNotFound); + }; + + // If the XOR-PEER-ADDRESS attribute contains an address of + // an address family different than that + // of the relayed transport address for the + // allocation, the server MUST generate an error response + // with the 443 (Peer Address Family + // Mismatch) response code. 
[RFC 6156, Section 7.2] + if (peer_addr.is_ipv4() && !alloc.relay_addr().is_ipv4()) + || (peer_addr.is_ipv6() && !alloc.relay_addr().is_ipv6()) + { + let mut peer_address_family_mismatch_msg = Message::new( + MessageClass::ErrorResponse, + CHANNEL_BIND, + msg.transaction_id(), + ); + peer_address_family_mismatch_msg + .add_attribute(ErrorCode::from(PeerAddressFamilyMismatch)); + + answer_with_err( + conn, + five_tuple.src_addr, + peer_address_family_mismatch_msg, + ) + .await?; + + return Err(Error::PeerAddressFamilyMismatch); + } + + log::trace!("binding channel {ch_num} to {peer_addr}",); + + if let Err(err) = alloc + .add_channel_bind(ch_num, peer_addr, channel_bind_lifetime) + .await + { + answer_with_err(conn, five_tuple.src_addr, bad_request_msg).await?; + + return Err(err); + } + + let mut msg = Message::new( + MessageClass::SuccessResponse, + CHANNEL_BIND, + msg.transaction_id(), + ); + + let integrity = MessageIntegrity::new_long_term_credential( + &msg, &uname, &realm, &pass, + ) + .map_err(|e| Error::Encode(*e.kind()))?; + msg.add_attribute(integrity); + + build_and_send(conn, five_tuple.src_addr, msg).await + } else { + Err(Error::NoAllocationFound) + } +} + +/// Responds the given [`Message`] with a [`MessageClass::ErrorResponse`] with +/// a new random nonce. +async fn respond_with_nonce( + msg: &Message, + response_code: ErrorCode, + conn: &Arc, + realm: &str, + five_tuple: FiveTuple, + nonces: &Arc>>, +) -> Result<(), Error> { + let nonce: String = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(8) + .map(char::from) + .collect(); + + { + // Nonce has already been taken + let mut nonces = nonces.lock().await; + if nonces.contains_key(&nonce) { + return Err(Error::RequestReplay); + } + _ = nonces.insert(nonce.clone(), Instant::now()); + } + + let mut msg = Message::new( + MessageClass::ErrorResponse, + msg.method(), + msg.transaction_id(), + ); + msg.add_attribute(response_code); + msg.add_attribute(Nonce::new(nonce).map_err(|e| Error::Encode(*e.kind()))?); + msg.add_attribute( + Realm::new(realm.to_owned()).map_err(|e| Error::Encode(*e.kind()))?, + ); + + build_and_send(conn, five_tuple.src_addr, msg).await +} + +/// Encodes and sends the provided [`Message`] to the given [`SocketAddr`] +/// via given [`Conn`]. +async fn build_and_send( + conn: &Arc, + dst: SocketAddr, + msg: Message, +) -> Result<(), Error> { + let bytes = MessageEncoder::new() + .encode_into_bytes(msg) + .map_err(|e| Error::Encode(*e.kind()))?; + + _ = conn.send_to(bytes, dst).await?; + Ok(()) +} + +/// Send a STUN packet and return the original error to the caller +async fn answer_with_err( + conn: &Arc, + dst: SocketAddr, + msg: Message, +) -> Result<(), Error> { + build_and_send(conn, dst, msg).await?; + + Ok(()) +} + +/// Calculates a [`Lifetime`] fetching it from the given [`Message`] and +/// ensuring that it is not greater than configured +/// [`MAXIMUM_ALLOCATION_LIFETIME`]. 
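+///
+/// A missing LIFETIME attribute, or a value above the one-hour cap, falls
+/// back to the ten-minute [`DEFAULT_LIFETIME`].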
+fn get_lifetime(m: &Message) -> Duration { + m.get_attribute::().map(Lifetime::lifetime).map_or( + DEFAULT_LIFETIME, + |lifetime| { + if lifetime > MAXIMUM_ALLOCATION_LIFETIME { + DEFAULT_LIFETIME + } else { + lifetime + } + }, + ) +} + +#[cfg(test)] +mod request_test { + use std::{net::IpAddr, str::FromStr}; + + use tokio::{ + net::UdpSocket, + time::{Duration, Instant}, + }; + + use crate::{allocation::ManagerConfig, relay::RelayAllocator}; + + use super::*; + + const STATIC_KEY: &str = "ABC"; + + #[tokio::test] + async fn test_allocation_lifetime_parsing() { + let lifetime = Lifetime::new(Duration::from_secs(5)).unwrap(); + + let mut m = Message::new( + MessageClass::Request, + ALLOCATE, + TransactionId::new(random()), + ); + let lifetime_duration = get_lifetime(&m); + + assert_eq!( + lifetime_duration, DEFAULT_LIFETIME, + "Allocation lifetime should be default time duration" + ); + + m.add_attribute(lifetime.clone()); + + let lifetime_duration = get_lifetime(&m); + assert_eq!( + lifetime_duration, + lifetime.lifetime(), + "Expect lifetime_duration is {lifetime:?}, but \ + {lifetime_duration:?}" + ); + } + + #[tokio::test] + async fn test_allocation_lifetime_overflow() { + let lifetime = Lifetime::new(MAXIMUM_ALLOCATION_LIFETIME * 2).unwrap(); + + let mut m2 = Message::new( + MessageClass::Request, + ALLOCATE, + TransactionId::new(random()), + ); + m2.add_attribute(lifetime); + + let lifetime_duration = get_lifetime(&m2); + assert_eq!( + lifetime_duration, DEFAULT_LIFETIME, + "Expect lifetime_duration is {DEFAULT_LIFETIME:?}, \ + but {lifetime_duration:?}" + ); + } + + struct TestAuthHandler; + impl AuthHandler for TestAuthHandler { + fn auth_handle( + &self, + _username: &str, + _realm: &str, + _src_addr: SocketAddr, + ) -> Result, Error> { + Ok(STATIC_KEY.to_owned().into()) + } + } + + #[tokio::test] + async fn test_allocation_lifetime_deletion_zero_lifetime() { + let conn: Arc = + Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap()); + + let allocation_manager = Arc::new(Manager::new(ManagerConfig { + relay_addr_generator: RelayAllocator { + relay_address: IpAddr::from([127, 0, 0, 1]), + min_port: 49152, + max_port: 65535, + max_retries: 10, + address: String::from("127.0.0.1"), + }, + alloc_close_notify: None, + })); + + let socket = + SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 5000); + let five_tuple = FiveTuple { + src_addr: socket, + dst_addr: conn.local_addr(), + protocol: conn.proto(), + }; + let nonces = Arc::new(Mutex::new(HashMap::new())); + + nonces.lock().await.insert(STATIC_KEY.to_owned(), Instant::now()); + + allocation_manager + .create_allocation( + five_tuple, + Arc::clone(&conn), + 0, + Duration::from_secs(3600), + Username::new(String::from("user")).unwrap(), + true, + ) + .await + .unwrap(); + + assert!(allocation_manager.get_alloc(&five_tuple).is_some()); + + let mut m: Message = Message::new( + MessageClass::Request, + REFRESH, + TransactionId::new(random()), + ); + m.add_attribute(Lifetime::new(Duration::default()).unwrap()); + m.add_attribute(Nonce::new(STATIC_KEY.to_owned()).unwrap()); + m.add_attribute(Realm::new(STATIC_KEY.to_owned()).unwrap()); + m.add_attribute(Username::new(STATIC_KEY.to_owned()).unwrap()); + let integrity = MessageIntegrity::new_long_term_credential( + &m, + &Username::new(STATIC_KEY.to_owned()).unwrap(), + &Realm::new(STATIC_KEY.to_owned()).unwrap(), + STATIC_KEY, + ) + .unwrap(); + m.add_attribute(integrity); + + let bytes = MessageEncoder::new().encode_into_bytes(m).unwrap(); + + let auth: Arc = + 
Arc::new(TestAuthHandler {}); + handle_message( + bytes, + &conn, + five_tuple, + STATIC_KEY, + Duration::from_secs(60), + &allocation_manager, + &nonces, + &auth, + ) + .await + .unwrap(); + + assert!(allocation_manager.get_alloc(&five_tuple).is_none()); + } +} diff --git a/srtp/.gitignore b/srtp/.gitignore deleted file mode 100644 index 81561ed32..000000000 --- a/srtp/.gitignore +++ /dev/null @@ -1,11 +0,0 @@ -# Generated by Cargo -# will have compiled files and executables -/target/ -/.idea/ - -# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries -# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html -Cargo.lock - -# These are backup files generated by rustfmt -**/*.rs.bk diff --git a/srtp/CHANGELOG.md b/srtp/CHANGELOG.md deleted file mode 100644 index aec0b9870..000000000 --- a/srtp/CHANGELOG.md +++ /dev/null @@ -1,19 +0,0 @@ -# webrtc-srtp changelog - -## Unreleased - -## v0.9.1 - -* Increased minimum support rust version to `1.60.0`. -* Increased required `webrtc-util` version to `0.7.0`. - - -## v0.9.0 - -* [#8 update deps + loosen some requirements](https://github.com/webrtc-rs/srtp/pull/8) by [@melekes](https://github.com/melekes). -* Increased min version of `log` dependency to `0.4.16`. [#250 Fix log at ^0.4.16 to make tests compile](https://github.com/webrtc-rs/webrtc/pull/250) by [@k0nserv](https://github.com/k0nserv). - -## Prior to 0.8.9 - -Before 0.8.9 there was no changelog, previous changes are sometimes, but not always, available in the [GitHub Releases](https://github.com/webrtc-rs/srtp/releases). - diff --git a/srtp/Cargo.toml b/srtp/Cargo.toml deleted file mode 100644 index 166ea3dde..000000000 --- a/srtp/Cargo.toml +++ /dev/null @@ -1,57 +0,0 @@ -[package] -name = "webrtc-srtp" -version = "0.13.0" -authors = ["Rain Liu "] -edition = "2021" -description = "A pure Rust implementation of SRTP" -license = "MIT OR Apache-2.0" -documentation = "https://docs.rs/webrtc-srtp" -homepage = "https://webrtc.rs" -repository = "https://github.com/webrtc-rs/webrtc/tree/master/srtp" - -[features] -openssl = ["dep:openssl"] -vendored-openssl = ["openssl/vendored"] - -[dependencies] -util = { version = "0.9.0", path = "../util", package = "webrtc-util", default-features = false, features = [ - "conn", - "buffer", - "marshal", -] } -rtp = { version = "0.11.0", path = "../rtp" } -rtcp = { version = "0.11.0", path = "../rtcp" } - -byteorder = "1" -bytes = "1" -thiserror = "1" -hmac = { version = "0.12", features = ["std"] } -sha1 = "0.10" -ctr = "0.9" -aes = "0.8" -subtle = "2" -tokio = { version = "1.32.0", features = [ - "fs", - "io-util", - "io-std", - "macros", - "net", - "parking_lot", - "rt", - "rt-multi-thread", - "sync", - "time", -] } -log = "0.4" -aead = { version = "0.5", features = ["std"] } -aes-gcm = { version = "0.10", features = ["std"] } -openssl = { version = "0.10.57", optional = true } - -[dev-dependencies] -criterion = { version = "0.5", features = ["async_futures"] } -tokio-test = "0.4" -lazy_static = "1" - -[[bench]] -name = "srtp_bench" -harness = false diff --git a/srtp/LICENSE-APACHE b/srtp/LICENSE-APACHE deleted file mode 100644 index 16fe87b06..000000000 --- a/srtp/LICENSE-APACHE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - -1. Definitions. 
- - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - -2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - -3. Grant of Patent License. 
Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - -4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - -5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - -6. Trademarks. 
This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - -8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - -END OF TERMS AND CONDITIONS - -APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - -Copyright [yyyy] [name of copyright owner] - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
diff --git a/srtp/LICENSE-MIT b/srtp/LICENSE-MIT deleted file mode 100644 index e11d93bef..000000000 --- a/srtp/LICENSE-MIT +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2021 WebRTC.rs - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/srtp/README.md b/srtp/README.md deleted file mode 100644 index c6eb608be..000000000 --- a/srtp/README.md +++ /dev/null @@ -1,30 +0,0 @@ -

- WebRTC.rs
-
- [badges: License: MIT/Apache 2.0, Discord]
-
- A pure Rust implementation of SRTP. Rewrite Pion SRTP in Rust

diff --git a/srtp/benches/srtp_bench.rs b/srtp/benches/srtp_bench.rs deleted file mode 100644 index 17f2d44c2..000000000 --- a/srtp/benches/srtp_bench.rs +++ /dev/null @@ -1,158 +0,0 @@ -use bytes::BytesMut; -use criterion::{criterion_group, criterion_main, Criterion}; -use util::Marshal; -use webrtc_srtp::{context::Context, protection_profile::ProtectionProfile}; - -const MASTER_KEY: &[u8] = &[ - 96, 180, 31, 4, 119, 137, 128, 252, 75, 194, 252, 44, 63, 56, 61, 55, -]; -const MASTER_SALT: &[u8] = &[247, 26, 49, 94, 99, 29, 79, 94, 5, 111, 252, 216, 62, 195]; -const RAW_RTCP: &[u8] = &[ - 0x81, 0xc8, 0x00, 0x0b, 0xca, 0xfe, 0xba, 0xbe, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, - 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, -]; - -fn benchmark_encrypt_rtp_aes_128_cm_hmac_sha1(c: &mut Criterion) { - let mut ctx = Context::new( - MASTER_KEY, - MASTER_SALT, - ProtectionProfile::Aes128CmHmacSha1_80, - None, - None, - ) - .unwrap(); - - let mut pld = BytesMut::new(); - for i in 0..1200 { - pld.extend_from_slice(&[i as u8]); - } - - c.bench_function("Benchmark RTP encrypt", |b| { - let mut seq = 1; - b.iter_batched( - || { - let pkt = rtp::packet::Packet { - header: rtp::header::Header { - sequence_number: seq, - timestamp: seq.into(), - extension_profile: 48862, - marker: true, - padding: false, - extension: true, - payload_type: 96, - ..Default::default() - }, - payload: pld.clone().into(), - }; - seq += 1; - pkt.marshal().unwrap() - }, - |pkt_raw| { - ctx.encrypt_rtp(&pkt_raw).unwrap(); - }, - criterion::BatchSize::LargeInput, - ); - }); -} - -fn benchmark_decrypt_rtp_aes_128_cm_hmac_sha1(c: &mut Criterion) { - let mut setup_ctx = Context::new( - MASTER_KEY, - MASTER_SALT, - ProtectionProfile::Aes128CmHmacSha1_80, - None, - None, - ) - .unwrap(); - - let mut ctx = Context::new( - MASTER_KEY, - MASTER_SALT, - ProtectionProfile::Aes128CmHmacSha1_80, - None, - None, - ) - .unwrap(); - - let mut pld = BytesMut::new(); - for i in 0..1200 { - pld.extend_from_slice(&[i as u8]); - } - - c.bench_function("Benchmark RTP decrypt", |b| { - let mut seq = 1; - b.iter_batched( - || { - let pkt = rtp::packet::Packet { - header: rtp::header::Header { - sequence_number: seq, - timestamp: seq.into(), - extension_profile: 48862, - marker: true, - padding: false, - extension: true, - payload_type: 96, - ..Default::default() - }, - payload: pld.clone().into(), - }; - seq += 1; - setup_ctx.encrypt_rtp(&pkt.marshal().unwrap()).unwrap() - }, - |encrypted| ctx.decrypt_rtp(&encrypted).unwrap(), - criterion::BatchSize::LargeInput, - ); - }); -} - -fn benchmark_encrypt_rtcp_aes_128_cm_hmac_sha1(c: &mut Criterion) { - let mut ctx = Context::new( - MASTER_KEY, - MASTER_SALT, - ProtectionProfile::Aes128CmHmacSha1_80, - None, - None, - ) - .unwrap(); - - c.bench_function("Benchmark RTCP encrypt", |b| { - b.iter(|| { - ctx.encrypt_rtcp(RAW_RTCP).unwrap(); - }); - }); -} - -fn benchmark_decrypt_rtcp_aes_128_cm_hmac_sha1(c: &mut Criterion) { - let encrypted = Context::new( - MASTER_KEY, - MASTER_SALT, - ProtectionProfile::Aes128CmHmacSha1_80, - None, - None, - ) - .unwrap() - .encrypt_rtcp(RAW_RTCP) - .unwrap(); - - let mut ctx = Context::new( - MASTER_KEY, - MASTER_SALT, - ProtectionProfile::Aes128CmHmacSha1_80, - None, - None, - ) - .unwrap(); - - c.bench_function("Benchmark RTCP decrypt", |b| { - b.iter(|| ctx.decrypt_rtcp(&encrypted).unwrap()); - }); -} - -criterion_group!( - benches, - benchmark_encrypt_rtp_aes_128_cm_hmac_sha1, - benchmark_decrypt_rtp_aes_128_cm_hmac_sha1, - 
benchmark_encrypt_rtcp_aes_128_cm_hmac_sha1, - benchmark_decrypt_rtcp_aes_128_cm_hmac_sha1 -); -criterion_main!(benches); diff --git a/srtp/codecov.yml b/srtp/codecov.yml deleted file mode 100644 index 2be1b3bbb..000000000 --- a/srtp/codecov.yml +++ /dev/null @@ -1,23 +0,0 @@ -codecov: - require_ci_to_pass: yes - max_report_age: off - token: d65de923-7c3d-4836-8d9a-5183b356be4f - -coverage: - precision: 2 - round: down - range: 50..90 - status: - project: - default: - enabled: no - threshold: 0.2 - if_not_found: success - patch: - default: - enabled: no - if_not_found: success - changes: - default: - enabled: no - if_not_found: success diff --git a/srtp/doc/webrtc.rs.png b/srtp/doc/webrtc.rs.png deleted file mode 100644 index 7bf0dda2a..000000000 Binary files a/srtp/doc/webrtc.rs.png and /dev/null differ diff --git a/srtp/src/cipher/cipher_aead_aes_gcm.rs b/srtp/src/cipher/cipher_aead_aes_gcm.rs deleted file mode 100644 index cc880694b..000000000 --- a/srtp/src/cipher/cipher_aead_aes_gcm.rs +++ /dev/null @@ -1,247 +0,0 @@ -use aes_gcm::aead::generic_array::GenericArray; -use aes_gcm::aead::{Aead, Payload}; -use aes_gcm::{Aes128Gcm, KeyInit, Nonce}; -use byteorder::{BigEndian, ByteOrder}; -use bytes::{Bytes, BytesMut}; -use util::marshal::*; - -use super::Cipher; -use crate::error::{Error, Result}; -use crate::key_derivation::*; - -pub const CIPHER_AEAD_AES_GCM_AUTH_TAG_LEN: usize = 16; - -const RTCP_ENCRYPTION_FLAG: u8 = 0x80; - -/// AEAD Cipher based on AES. -pub(crate) struct CipherAeadAesGcm { - srtp_cipher: aes_gcm::Aes128Gcm, - srtcp_cipher: aes_gcm::Aes128Gcm, - srtp_session_salt: Vec, - srtcp_session_salt: Vec, -} - -impl Cipher for CipherAeadAesGcm { - fn auth_tag_len(&self) -> usize { - CIPHER_AEAD_AES_GCM_AUTH_TAG_LEN - } - - fn encrypt_rtp( - &mut self, - payload: &[u8], - header: &rtp::header::Header, - roc: u32, - ) -> Result { - // Grow the given buffer to fit the output. - let header_len = header.marshal_size(); - let mut writer = BytesMut::with_capacity(payload.len() + self.auth_tag_len()); - - // Copy header unencrypted. 
- writer.extend_from_slice(&payload[..header_len]); - - let nonce = self.rtp_initialization_vector(header, roc); - - let encrypted = self.srtp_cipher.encrypt( - Nonce::from_slice(&nonce), - Payload { - msg: &payload[header_len..], - aad: &writer, - }, - )?; - - writer.extend(encrypted); - Ok(writer.freeze()) - } - - fn decrypt_rtp( - &mut self, - ciphertext: &[u8], - header: &rtp::header::Header, - roc: u32, - ) -> Result { - if ciphertext.len() < self.auth_tag_len() { - return Err(Error::ErrFailedToVerifyAuthTag); - } - - let nonce = self.rtp_initialization_vector(header, roc); - let payload_offset = header.marshal_size(); - let decrypted_msg: Vec = self.srtp_cipher.decrypt( - Nonce::from_slice(&nonce), - Payload { - msg: &ciphertext[payload_offset..], - aad: &ciphertext[..payload_offset], - }, - )?; - - let mut writer = BytesMut::with_capacity(payload_offset + decrypted_msg.len()); - writer.extend_from_slice(&ciphertext[..payload_offset]); - writer.extend(decrypted_msg); - - Ok(writer.freeze()) - } - - fn encrypt_rtcp(&mut self, decrypted: &[u8], srtcp_index: usize, ssrc: u32) -> Result { - let iv = self.rtcp_initialization_vector(srtcp_index, ssrc); - let aad = self.rtcp_additional_authenticated_data(decrypted, srtcp_index); - - let encrypted_data = self.srtcp_cipher.encrypt( - Nonce::from_slice(&iv), - Payload { - msg: &decrypted[8..], - aad: &aad, - }, - )?; - - let mut writer = BytesMut::with_capacity(encrypted_data.len() + aad.len()); - writer.extend_from_slice(&decrypted[..8]); - writer.extend(encrypted_data); - writer.extend_from_slice(&aad[8..]); - - Ok(writer.freeze()) - } - - fn decrypt_rtcp(&mut self, encrypted: &[u8], srtcp_index: usize, ssrc: u32) -> Result { - if encrypted.len() < self.auth_tag_len() + SRTCP_INDEX_SIZE { - return Err(Error::ErrFailedToVerifyAuthTag); - } - - let nonce = self.rtcp_initialization_vector(srtcp_index, ssrc); - let aad = self.rtcp_additional_authenticated_data(encrypted, srtcp_index); - - let decrypted_data = self.srtcp_cipher.decrypt( - Nonce::from_slice(&nonce), - Payload { - msg: &encrypted[8..(encrypted.len() - SRTCP_INDEX_SIZE)], - aad: &aad, - }, - )?; - - let mut writer = BytesMut::with_capacity(8 + decrypted_data.len()); - writer.extend_from_slice(&encrypted[..8]); - writer.extend(decrypted_data); - - Ok(writer.freeze()) - } - - fn get_rtcp_index(&self, input: &[u8]) -> usize { - let pos = input.len() - 4; - let val = BigEndian::read_u32(&input[pos..]); - - (val & !((RTCP_ENCRYPTION_FLAG as u32) << 24)) as usize - } -} - -impl CipherAeadAesGcm { - /// Create a new AEAD instance. 
- pub(crate) fn new(master_key: &[u8], master_salt: &[u8]) -> Result { - let srtp_session_key = aes_cm_key_derivation( - LABEL_SRTP_ENCRYPTION, - master_key, - master_salt, - 0, - master_key.len(), - )?; - - let srtp_block = GenericArray::from_slice(&srtp_session_key); - - let srtp_cipher = Aes128Gcm::new(srtp_block); - - let srtcp_session_key = aes_cm_key_derivation( - LABEL_SRTCP_ENCRYPTION, - master_key, - master_salt, - 0, - master_key.len(), - )?; - - let srtcp_block = GenericArray::from_slice(&srtcp_session_key); - - let srtcp_cipher = Aes128Gcm::new(srtcp_block); - - let srtp_session_salt = aes_cm_key_derivation( - LABEL_SRTP_SALT, - master_key, - master_salt, - 0, - master_key.len(), - )?; - - let srtcp_session_salt = aes_cm_key_derivation( - LABEL_SRTCP_SALT, - master_key, - master_salt, - 0, - master_key.len(), - )?; - - Ok(CipherAeadAesGcm { - srtp_cipher, - srtcp_cipher, - srtp_session_salt, - srtcp_session_salt, - }) - } - - /// The 12-octet IV used by AES-GCM SRTP is formed by first concatenating - /// 2 octets of zeroes, the 4-octet SSRC, the 4-octet rollover counter - /// (ROC), and the 2-octet sequence number (SEQ). The resulting 12-octet - /// value is then XORed to the 12-octet salt to form the 12-octet IV. - /// - /// https://tools.ietf.org/html/rfc7714#section-8.1 - pub(crate) fn rtp_initialization_vector( - &self, - header: &rtp::header::Header, - roc: u32, - ) -> Vec { - let mut iv = vec![0u8; 12]; - BigEndian::write_u32(&mut iv[2..], header.ssrc); - BigEndian::write_u32(&mut iv[6..], roc); - BigEndian::write_u16(&mut iv[10..], header.sequence_number); - - for (i, v) in iv.iter_mut().enumerate() { - *v ^= self.srtp_session_salt[i]; - } - - iv - } - - /// The 12-octet IV used by AES-GCM SRTCP is formed by first - /// concatenating 2 octets of zeroes, the 4-octet SSRC identifier, - /// 2 octets of zeroes, a single "0" bit, and the 31-bit SRTCP index. - /// The resulting 12-octet value is then XORed to the 12-octet salt to - /// form the 12-octet IV. 
- /// - /// https://tools.ietf.org/html/rfc7714#section-9.1 - pub(crate) fn rtcp_initialization_vector(&self, srtcp_index: usize, ssrc: u32) -> Vec { - let mut iv = vec![0u8; 12]; - - BigEndian::write_u32(&mut iv[2..], ssrc); - BigEndian::write_u32(&mut iv[8..], srtcp_index as u32); - - for (i, v) in iv.iter_mut().enumerate() { - *v ^= self.srtcp_session_salt[i]; - } - - iv - } - - /// In an SRTCP packet, a 1-bit Encryption flag is prepended to the - /// 31-bit SRTCP index to form a 32-bit value we shall call the - /// "ESRTCP word" - /// - /// https://tools.ietf.org/html/rfc7714#section-17 - pub(crate) fn rtcp_additional_authenticated_data( - &self, - rtcp_packet: &[u8], - srtcp_index: usize, - ) -> Vec { - let mut aad = vec![0u8; 12]; - - aad[..8].copy_from_slice(&rtcp_packet[..8]); - - BigEndian::write_u32(&mut aad[8..], srtcp_index as u32); - - aad[8] |= RTCP_ENCRYPTION_FLAG; - aad - } -} diff --git a/srtp/src/cipher/cipher_aes_cm_hmac_sha1/ctrcipher.rs b/srtp/src/cipher/cipher_aes_cm_hmac_sha1/ctrcipher.rs deleted file mode 100644 index d4a8391a8..000000000 --- a/srtp/src/cipher/cipher_aes_cm_hmac_sha1/ctrcipher.rs +++ /dev/null @@ -1,220 +0,0 @@ -use aes::cipher::generic_array::GenericArray; -use aes::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek}; -use bytes::{BufMut, Bytes}; -use rtcp::header::{HEADER_LENGTH, SSRC_LENGTH}; -use subtle::ConstantTimeEq; -use util::marshal::*; - -use super::{Cipher, CipherInner}; -use crate::error::{Error, Result}; -use crate::key_derivation::*; - -type Aes128Ctr = ctr::Ctr128BE; - -pub(crate) struct CipherAesCmHmacSha1 { - inner: CipherInner, - srtp_session_key: Vec, - srtcp_session_key: Vec, -} - -impl CipherAesCmHmacSha1 { - pub fn new(master_key: &[u8], master_salt: &[u8]) -> Result { - let inner = CipherInner::new(master_key, master_salt)?; - - let srtp_session_key = aes_cm_key_derivation( - LABEL_SRTP_ENCRYPTION, - master_key, - master_salt, - 0, - master_key.len(), - )?; - let srtcp_session_key = aes_cm_key_derivation( - LABEL_SRTCP_ENCRYPTION, - master_key, - master_salt, - 0, - master_key.len(), - )?; - - Ok(CipherAesCmHmacSha1 { - inner, - srtp_session_key, - srtcp_session_key, - }) - } -} - -impl Cipher for CipherAesCmHmacSha1 { - fn auth_tag_len(&self) -> usize { - self.inner.auth_tag_len() - } - - fn get_rtcp_index(&self, input: &[u8]) -> usize { - self.inner.get_rtcp_index(input) - } - - fn encrypt_rtp( - &mut self, - plaintext: &[u8], - header: &rtp::header::Header, - roc: u32, - ) -> Result { - let mut writer = Vec::with_capacity(plaintext.len() + self.auth_tag_len()); - - // Write the plaintext to the destination buffer. - writer.extend_from_slice(plaintext); - - // Encrypt the payload - let counter = generate_counter( - header.sequence_number, - roc, - header.ssrc, - &self.inner.srtp_session_salt, - ); - let key = GenericArray::from_slice(&self.srtp_session_key); - let nonce = GenericArray::from_slice(&counter); - let mut stream = Aes128Ctr::new(key, nonce); - stream.apply_keystream(&mut writer[header.marshal_size()..]); - - // Generate the auth tag. 
- let auth_tag = &self.inner.generate_srtp_auth_tag(&writer, roc)[..self.auth_tag_len()]; - writer.extend(auth_tag); - - Ok(Bytes::from(writer)) - } - - fn decrypt_rtp( - &mut self, - encrypted: &[u8], - header: &rtp::header::Header, - roc: u32, - ) -> Result { - let encrypted_len = encrypted.len(); - if encrypted_len < self.auth_tag_len() { - return Err(Error::SrtpTooSmall(encrypted_len, self.auth_tag_len())); - } - - let mut writer = Vec::with_capacity(encrypted_len - self.auth_tag_len()); - - // Split the auth tag and the cipher text into two parts. - let actual_tag = &encrypted[encrypted_len - self.auth_tag_len()..]; - let cipher_text = &encrypted[..encrypted_len - self.auth_tag_len()]; - - // Generate the auth tag we expect to see from the ciphertext. - let expected_tag = - &self.inner.generate_srtp_auth_tag(cipher_text, roc)[..self.auth_tag_len()]; - - // See if the auth tag actually matches. - // We use a constant time comparison to prevent timing attacks. - if actual_tag.ct_eq(expected_tag).unwrap_u8() != 1 { - return Err(Error::RtpFailedToVerifyAuthTag); - } - - // Write cipher_text to the destination buffer. - writer.extend_from_slice(cipher_text); - - // Decrypt the ciphertext for the payload. - let counter = generate_counter( - header.sequence_number, - roc, - header.ssrc, - &self.inner.srtp_session_salt, - ); - - let key = GenericArray::from_slice(&self.srtp_session_key); - let nonce = GenericArray::from_slice(&counter); - let mut stream = Aes128Ctr::new(key, nonce); - stream.seek(0); - stream.apply_keystream(&mut writer[header.marshal_size()..]); - - Ok(Bytes::from(writer)) - } - - fn encrypt_rtcp(&mut self, decrypted: &[u8], srtcp_index: usize, ssrc: u32) -> Result { - let mut writer = - Vec::with_capacity(decrypted.len() + SRTCP_INDEX_SIZE + self.auth_tag_len()); - - // Write the decrypted to the destination buffer. - writer.extend_from_slice(decrypted); - - // Encrypt everything after header - let counter = generate_counter( - (srtcp_index & 0xFFFF) as u16, - (srtcp_index >> 16) as u32, - ssrc, - &self.inner.srtcp_session_salt, - ); - - let key = GenericArray::from_slice(&self.srtcp_session_key); - let nonce = GenericArray::from_slice(&counter); - let mut stream = Aes128Ctr::new(key, nonce); - - stream.apply_keystream(&mut writer[HEADER_LENGTH + SSRC_LENGTH..]); - - // Add SRTCP index and set Encryption bit - writer.put_u32(srtcp_index as u32 | (1u32 << 31)); - - // Generate the auth tag. - let auth_tag = &self.inner.generate_srtcp_auth_tag(&writer)[..self.auth_tag_len()]; - writer.extend(auth_tag); - - Ok(Bytes::from(writer)) - } - - fn decrypt_rtcp(&mut self, encrypted: &[u8], srtcp_index: usize, ssrc: u32) -> Result { - let encrypted_len = encrypted.len(); - if encrypted_len < self.auth_tag_len() + SRTCP_INDEX_SIZE { - return Err(Error::SrtcpTooSmall( - encrypted_len, - self.auth_tag_len() + SRTCP_INDEX_SIZE, - )); - } - - let tail_offset = encrypted_len - (self.auth_tag_len() + SRTCP_INDEX_SIZE); - - let mut writer = Vec::with_capacity(tail_offset); - - writer.extend_from_slice(&encrypted[0..tail_offset]); - - let is_encrypted = encrypted[tail_offset] >> 7; - if is_encrypted == 0 { - return Ok(Bytes::from(writer)); - } - - // Split the auth tag and the cipher text into two parts. 
- let actual_tag = &encrypted[encrypted_len - self.auth_tag_len()..]; - if actual_tag.len() != self.auth_tag_len() { - return Err(Error::RtcpInvalidLengthAuthTag( - actual_tag.len(), - self.auth_tag_len(), - )); - } - - let cipher_text = &encrypted[..encrypted_len - self.auth_tag_len()]; - - // Generate the auth tag we expect to see from the ciphertext. - let expected_tag = &self.inner.generate_srtcp_auth_tag(cipher_text)[..self.auth_tag_len()]; - - // See if the auth tag actually matches. - // We use a constant time comparison to prevent timing attacks. - if actual_tag.ct_eq(expected_tag).unwrap_u8() != 1 { - return Err(Error::RtcpFailedToVerifyAuthTag); - } - - let counter = generate_counter( - (srtcp_index & 0xFFFF) as u16, - (srtcp_index >> 16) as u32, - ssrc, - &self.inner.srtcp_session_salt, - ); - - let key = GenericArray::from_slice(&self.srtcp_session_key); - let nonce = GenericArray::from_slice(&counter); - let mut stream = Aes128Ctr::new(key, nonce); - - stream.seek(0); - stream.apply_keystream(&mut writer[HEADER_LENGTH + SSRC_LENGTH..]); - - Ok(Bytes::from(writer)) - } -} diff --git a/srtp/src/cipher/cipher_aes_cm_hmac_sha1/mod.rs b/srtp/src/cipher/cipher_aes_cm_hmac_sha1/mod.rs deleted file mode 100644 index 00bae63ae..000000000 --- a/srtp/src/cipher/cipher_aes_cm_hmac_sha1/mod.rs +++ /dev/null @@ -1,133 +0,0 @@ -use byteorder::{BigEndian, ByteOrder}; -use hmac::{Hmac, Mac}; -use sha1::Sha1; - -use super::Cipher; -use crate::error::{Error, Result}; -use crate::key_derivation::*; -use crate::protection_profile::*; - -#[cfg(not(feature = "openssl"))] -mod ctrcipher; - -#[cfg(feature = "openssl")] -mod opensslcipher; - -#[cfg(not(feature = "openssl"))] -pub(crate) use ctrcipher::CipherAesCmHmacSha1; - -#[cfg(feature = "openssl")] -pub(crate) use opensslcipher::CipherAesCmHmacSha1; - -type HmacSha1 = Hmac; - -pub const CIPHER_AES_CM_HMAC_SHA1AUTH_TAG_LEN: usize = 10; - -pub(crate) struct CipherInner { - srtp_session_salt: Vec, - srtp_session_auth: HmacSha1, - srtcp_session_salt: Vec, - srtcp_session_auth: HmacSha1, -} - -impl CipherInner { - pub fn new(master_key: &[u8], master_salt: &[u8]) -> Result { - let srtp_session_salt = aes_cm_key_derivation( - LABEL_SRTP_SALT, - master_key, - master_salt, - 0, - master_salt.len(), - )?; - let srtcp_session_salt = aes_cm_key_derivation( - LABEL_SRTCP_SALT, - master_key, - master_salt, - 0, - master_salt.len(), - )?; - - let auth_key_len = ProtectionProfile::Aes128CmHmacSha1_80.auth_key_len(); - - let srtp_session_auth_tag = aes_cm_key_derivation( - LABEL_SRTP_AUTHENTICATION_TAG, - master_key, - master_salt, - 0, - auth_key_len, - )?; - let srtcp_session_auth_tag = aes_cm_key_derivation( - LABEL_SRTCP_AUTHENTICATION_TAG, - master_key, - master_salt, - 0, - auth_key_len, - )?; - - let srtp_session_auth = HmacSha1::new_from_slice(&srtp_session_auth_tag) - .map_err(|e| Error::Other(e.to_string()))?; - let srtcp_session_auth = HmacSha1::new_from_slice(&srtcp_session_auth_tag) - .map_err(|e| Error::Other(e.to_string()))?; - - Ok(Self { - srtp_session_salt, - srtp_session_auth, - srtcp_session_salt, - srtcp_session_auth, - }) - } - - /// https://tools.ietf.org/html/rfc3711#section-4.2 - /// In the case of SRTP, M SHALL consist of the Authenticated - /// Portion of the packet (as specified in Figure 1) concatenated with - /// the roc, M = Authenticated Portion || roc; - /// - /// The pre-defined authentication transform for SRTP is HMAC-SHA1 - /// [RFC2104]. With HMAC-SHA1, the SRTP_PREFIX_LENGTH (Figure 3) SHALL - /// be 0. 
For SRTP (respectively SRTCP), the HMAC SHALL be applied to - /// the session authentication key and M as specified above, i.e., - /// HMAC(k_a, M). The HMAC output SHALL then be truncated to the n_tag - /// left-most bits. - /// - Authenticated portion of the packet is everything BEFORE MKI - /// - k_a is the session message authentication key - /// - n_tag is the bit-length of the output authentication tag - fn generate_srtp_auth_tag(&self, buf: &[u8], roc: u32) -> [u8; 20] { - let mut signer = self.srtp_session_auth.clone(); - - signer.update(buf); - - // For SRTP only, we need to hash the rollover counter as well. - signer.update(&roc.to_be_bytes()); - - signer.finalize().into_bytes().into() - } - - /// https://tools.ietf.org/html/rfc3711#section-4.2 - /// - /// The pre-defined authentication transform for SRTP is HMAC-SHA1 - /// [RFC2104]. With HMAC-SHA1, the SRTP_PREFIX_LENGTH (Figure 3) SHALL - /// be 0. For SRTP (respectively SRTCP), the HMAC SHALL be applied to - /// the session authentication key and M as specified above, i.e., - /// HMAC(k_a, M). The HMAC output SHALL then be truncated to the n_tag - /// left-most bits. - /// - Authenticated portion of the packet is everything BEFORE MKI - /// - k_a is the session message authentication key - /// - n_tag is the bit-length of the output authentication tag - fn generate_srtcp_auth_tag(&self, buf: &[u8]) -> [u8; 20] { - let mut signer = self.srtcp_session_auth.clone(); - - signer.update(buf); - - signer.finalize().into_bytes().into() - } - - fn auth_tag_len(&self) -> usize { - CIPHER_AES_CM_HMAC_SHA1AUTH_TAG_LEN - } - - fn get_rtcp_index(&self, input: &[u8]) -> usize { - let tail_offset = input.len() - (self.auth_tag_len() + SRTCP_INDEX_SIZE); - (BigEndian::read_u32(&input[tail_offset..tail_offset + SRTCP_INDEX_SIZE]) & !(1 << 31)) - as usize - } -} diff --git a/srtp/src/cipher/cipher_aes_cm_hmac_sha1/opensslcipher.rs b/srtp/src/cipher/cipher_aes_cm_hmac_sha1/opensslcipher.rs deleted file mode 100644 index 48ce6c7a4..000000000 --- a/srtp/src/cipher/cipher_aes_cm_hmac_sha1/opensslcipher.rs +++ /dev/null @@ -1,261 +0,0 @@ -use bytes::{BufMut, Bytes}; -use openssl::cipher_ctx::CipherCtx; -use rtcp::header::{HEADER_LENGTH, SSRC_LENGTH}; -use subtle::ConstantTimeEq; -use util::marshal::*; - -use super::{Cipher, CipherInner}; -use crate::{ - error::{Error, Result}, - key_derivation::*, -}; - -pub(crate) struct CipherAesCmHmacSha1 { - inner: CipherInner, - rtp_ctx: CipherCtx, - rtcp_ctx: CipherCtx, -} - -impl CipherAesCmHmacSha1 { - pub fn new(master_key: &[u8], master_salt: &[u8]) -> Result { - let inner = CipherInner::new(master_key, master_salt)?; - - let srtp_session_key = aes_cm_key_derivation( - LABEL_SRTP_ENCRYPTION, - master_key, - master_salt, - 0, - master_key.len(), - )?; - let srtcp_session_key = aes_cm_key_derivation( - LABEL_SRTCP_ENCRYPTION, - master_key, - master_salt, - 0, - master_key.len(), - )?; - - let t = openssl::cipher::Cipher::aes_128_ctr(); - let mut rtp_ctx = CipherCtx::new().map_err(|e| Error::Other(e.to_string()))?; - rtp_ctx - .encrypt_init(Some(t), Some(&srtp_session_key[..]), None) - .map_err(|e| Error::Other(e.to_string()))?; - - let t = openssl::cipher::Cipher::aes_128_ctr(); - let mut rtcp_ctx = CipherCtx::new().map_err(|e| Error::Other(e.to_string()))?; - rtcp_ctx - .encrypt_init(Some(t), Some(&srtcp_session_key[..]), None) - .map_err(|e| Error::Other(e.to_string()))?; - - Ok(Self { - inner, - rtp_ctx, - rtcp_ctx, - }) - } -} - -impl Cipher for CipherAesCmHmacSha1 { - fn auth_tag_len(&self) -> usize { - 
self.inner.auth_tag_len() - } - - fn get_rtcp_index(&self, input: &[u8]) -> usize { - self.inner.get_rtcp_index(input) - } - - fn encrypt_rtp( - &mut self, - plaintext: &[u8], - header: &rtp::header::Header, - roc: u32, - ) -> Result { - let header_len = header.marshal_size(); - let mut writer = Vec::with_capacity(plaintext.len() + self.auth_tag_len()); - - // Copy the header unencrypted. - writer.extend_from_slice(&plaintext[..header_len]); - - // Encrypt the payload - let nonce = generate_counter( - header.sequence_number, - roc, - header.ssrc, - &self.inner.srtp_session_salt, - ); - writer.resize(plaintext.len(), 0); - self.rtp_ctx.encrypt_init(None, None, Some(&nonce)).unwrap(); - let count = self - .rtp_ctx - .cipher_update(&plaintext[header_len..], Some(&mut writer[header_len..])) - .unwrap(); - self.rtp_ctx - .cipher_final(&mut writer[header_len + count..]) - .unwrap(); - - // Generate and write the auth tag. - let auth_tag = &self.inner.generate_srtp_auth_tag(&writer, roc)[..self.auth_tag_len()]; - writer.extend_from_slice(auth_tag); - - Ok(Bytes::from(writer)) - } - - fn decrypt_rtp( - &mut self, - encrypted: &[u8], - header: &rtp::header::Header, - roc: u32, - ) -> Result { - let encrypted_len = encrypted.len(); - if encrypted_len < self.auth_tag_len() { - return Err(Error::SrtpTooSmall(encrypted_len, self.auth_tag_len())); - } - let header_len = header.marshal_size(); - - let mut writer = Vec::with_capacity(encrypted_len - self.auth_tag_len()); - - // Split the auth tag and the cipher text into two parts. - let actual_tag = &encrypted[encrypted_len - self.auth_tag_len()..]; - let cipher_text = &encrypted[..encrypted_len - self.auth_tag_len()]; - - // Generate the auth tag we expect to see from the ciphertext. - let expected_tag = - &self.inner.generate_srtp_auth_tag(cipher_text, roc)[..self.auth_tag_len()]; - - // See if the auth tag actually matches. - // We use a constant time comparison to prevent timing attacks. - if actual_tag.ct_eq(expected_tag).unwrap_u8() != 1 { - return Err(Error::RtpFailedToVerifyAuthTag); - } - - // Write cipher_text to the destination buffer. - writer.extend_from_slice(&cipher_text[..header_len]); - - // Decrypt the ciphertext for the payload. - let nonce = generate_counter( - header.sequence_number, - roc, - header.ssrc, - &self.inner.srtp_session_salt, - ); - - writer.resize(encrypted_len - self.auth_tag_len(), 0); - self.rtp_ctx.decrypt_init(None, None, Some(&nonce)).unwrap(); - let count = self - .rtp_ctx - .cipher_update(&cipher_text[header_len..], Some(&mut writer[header_len..])) - .unwrap(); - self.rtp_ctx - .cipher_final(&mut writer[header_len + count..]) - .unwrap(); - - Ok(Bytes::from(writer)) - } - - fn encrypt_rtcp(&mut self, decrypted: &[u8], srtcp_index: usize, ssrc: u32) -> Result { - let decrypted_len = decrypted.len(); - - let mut writer = Vec::with_capacity(decrypted_len + SRTCP_INDEX_SIZE + self.auth_tag_len()); - - // Write the decrypted to the destination buffer. 
- writer.extend_from_slice(&decrypted[..HEADER_LENGTH + SSRC_LENGTH]); - - // Encrypt everything after header - let nonce = generate_counter( - (srtcp_index & 0xFFFF) as u16, - (srtcp_index >> 16) as u32, - ssrc, - &self.inner.srtcp_session_salt, - ); - - writer.resize(decrypted_len, 0); - self.rtcp_ctx - .encrypt_init(None, None, Some(&nonce)) - .unwrap(); - let count = self - .rtcp_ctx - .cipher_update( - &decrypted[HEADER_LENGTH + SSRC_LENGTH..], - Some(&mut writer[HEADER_LENGTH + SSRC_LENGTH..]), - ) - .unwrap(); - self.rtcp_ctx - .cipher_final(&mut writer[HEADER_LENGTH + SSRC_LENGTH + count..]) - .unwrap(); - - // Add SRTCP index and set Encryption bit - writer.put_u32(srtcp_index as u32 | (1u32 << 31)); - - // Generate the auth tag. - let auth_tag = &self.inner.generate_srtcp_auth_tag(&writer)[..self.auth_tag_len()]; - writer.extend(auth_tag); - - Ok(Bytes::from(writer)) - } - - fn decrypt_rtcp(&mut self, encrypted: &[u8], srtcp_index: usize, ssrc: u32) -> Result { - let encrypted_len = encrypted.len(); - - if encrypted_len < self.auth_tag_len() + SRTCP_INDEX_SIZE { - return Err(Error::SrtcpTooSmall( - encrypted_len, - self.auth_tag_len() + SRTCP_INDEX_SIZE, - )); - } - - let tail_offset = encrypted_len - (self.auth_tag_len() + SRTCP_INDEX_SIZE); - - let mut writer = Vec::with_capacity(tail_offset); - - writer.extend_from_slice(&encrypted[..HEADER_LENGTH + SSRC_LENGTH]); - - let is_encrypted = encrypted[tail_offset] >> 7; - if is_encrypted == 0 { - return Ok(Bytes::from(writer)); - } - - // Split the auth tag and the cipher text into two parts. - let actual_tag = &encrypted[encrypted_len - self.auth_tag_len()..]; - if actual_tag.len() != self.auth_tag_len() { - return Err(Error::RtcpInvalidLengthAuthTag( - actual_tag.len(), - self.auth_tag_len(), - )); - } - - let cipher_text = &encrypted[..encrypted_len - self.auth_tag_len()]; - - // Generate the auth tag we expect to see from the ciphertext. - let expected_tag = &self.inner.generate_srtcp_auth_tag(cipher_text)[..self.auth_tag_len()]; - - // See if the auth tag actually matches. - // We use a constant time comparison to prevent timing attacks. - if actual_tag.ct_eq(expected_tag).unwrap_u8() != 1 { - return Err(Error::RtcpFailedToVerifyAuthTag); - } - - let nonce = generate_counter( - (srtcp_index & 0xFFFF) as u16, - (srtcp_index >> 16) as u32, - ssrc, - &self.inner.srtcp_session_salt, - ); - - writer.resize(tail_offset, 0); - self.rtcp_ctx - .decrypt_init(None, None, Some(&nonce)) - .unwrap(); - let count = self - .rtcp_ctx - .cipher_update( - &encrypted[HEADER_LENGTH + SSRC_LENGTH..tail_offset], - Some(&mut writer[HEADER_LENGTH + SSRC_LENGTH..]), - ) - .unwrap(); - self.rtcp_ctx - .cipher_final(&mut writer[HEADER_LENGTH + SSRC_LENGTH + count..]) - .unwrap(); - - Ok(Bytes::from(writer)) - } -} diff --git a/srtp/src/cipher/mod.rs b/srtp/src/cipher/mod.rs deleted file mode 100644 index ce3cc192b..000000000 --- a/srtp/src/cipher/mod.rs +++ /dev/null @@ -1,61 +0,0 @@ -pub mod cipher_aead_aes_gcm; -pub mod cipher_aes_cm_hmac_sha1; - -use bytes::Bytes; - -use crate::error::Result; - -///NOTE: Auth tag and AEAD auth tag are placed at the different position in SRTCP -/// -///In non-AEAD cipher, the authentication tag is placed *after* the ESRTCP word -///(Encrypted-flag and SRTCP index). 
-/// -///> AES_128_CM_HMAC_SHA1_80 -///> | RTCP Header | Encrypted payload |E| SRTCP Index | Auth tag | -///> ^ |----------| -///> | ^ -///> | authTagLen=10 -///> aeadAuthTagLen=0 -/// -///In AEAD cipher, the AEAD authentication tag is embedded in the ciphertext. -///It is *before* the ESRTCP word (Encrypted-flag and SRTCP index). -/// -///> AEAD_AES_128_GCM -///> | RTCP Header | Encrypted payload | AEAD auth tag |E| SRTCP Index | -///> |---------------| ^ -///> ^ authTagLen=0 -///> aeadAuthTagLen=16 -/// -///See https://tools.ietf.org/html/rfc7714 for the full specifications. - -/// Cipher represents a implementation of one -/// of the SRTP Specific ciphers. -pub(crate) trait Cipher { - /// Get authenticated tag length. - fn auth_tag_len(&self) -> usize; - - /// Retrieved RTCP index. - fn get_rtcp_index(&self, input: &[u8]) -> usize; - - /// Encrypt RTP payload. - fn encrypt_rtp( - &mut self, - payload: &[u8], - header: &rtp::header::Header, - roc: u32, - ) -> Result; - - /// Decrypt RTP payload. - fn decrypt_rtp( - &mut self, - payload: &[u8], - header: &rtp::header::Header, - roc: u32, - ) -> Result; - - /// Encrypt RTCP payload. - fn encrypt_rtcp(&mut self, payload: &[u8], srtcp_index: usize, ssrc: u32) -> Result; - - /// Decrypt RTCP payload. - fn decrypt_rtcp(&mut self, payload: &[u8], srtcp_index: usize, ssrc: u32) -> Result; -} diff --git a/srtp/src/config.rs b/srtp/src/config.rs deleted file mode 100644 index 8fbab5f22..000000000 --- a/srtp/src/config.rs +++ /dev/null @@ -1,83 +0,0 @@ -use util::KeyingMaterialExporter; - -use crate::error::Result; -use crate::option::*; -use crate::protection_profile::*; - -const LABEL_EXTRACTOR_DTLS_SRTP: &str = "EXTRACTOR-dtls_srtp"; - -/// SessionKeys bundles the keys required to setup an SRTP session -#[derive(Default, Debug, Clone)] -pub struct SessionKeys { - pub local_master_key: Vec, - pub local_master_salt: Vec, - pub remote_master_key: Vec, - pub remote_master_salt: Vec, -} - -/// Config is used to configure a session. -/// You can provide either a KeyingMaterialExporter to export keys -/// or directly pass the keys themselves. -/// After a Config is passed to a session it must not be modified. -#[derive(Default)] -pub struct Config { - pub keys: SessionKeys, - pub profile: ProtectionProfile, - //LoggerFactory: logging.LoggerFactory - /// List of local/remote context options. - /// ReplayProtection is enabled on remote context by default. - /// Default replay protection window size is 64. - pub local_rtp_options: Option, - pub remote_rtp_options: Option, - - pub local_rtcp_options: Option, - pub remote_rtcp_options: Option, -} - -impl Config { - /// ExtractSessionKeysFromDTLS allows setting the Config SessionKeys by - /// extracting them from DTLS. 
This behavior is defined in RFC5764: - /// - pub async fn extract_session_keys_from_dtls( - &mut self, - exporter: impl KeyingMaterialExporter, - is_client: bool, - ) -> Result<()> { - let key_len = self.profile.key_len(); - let salt_len = self.profile.salt_len(); - - let keying_material = exporter - .export_keying_material( - LABEL_EXTRACTOR_DTLS_SRTP, - &[], - (key_len * 2) + (salt_len * 2), - ) - .await?; - - let mut offset = 0; - let client_write_key = keying_material[offset..offset + key_len].to_vec(); - offset += key_len; - - let server_write_key = keying_material[offset..offset + key_len].to_vec(); - offset += key_len; - - let client_write_salt = keying_material[offset..offset + salt_len].to_vec(); - offset += salt_len; - - let server_write_salt = keying_material[offset..offset + salt_len].to_vec(); - - if is_client { - self.keys.local_master_key = client_write_key; - self.keys.local_master_salt = client_write_salt; - self.keys.remote_master_key = server_write_key; - self.keys.remote_master_salt = server_write_salt; - } else { - self.keys.local_master_key = server_write_key; - self.keys.local_master_salt = server_write_salt; - self.keys.remote_master_key = client_write_key; - self.keys.remote_master_salt = client_write_salt; - } - - Ok(()) - } -} diff --git a/srtp/src/context/context_test.rs b/srtp/src/context/context_test.rs deleted file mode 100644 index b2114c981..000000000 --- a/srtp/src/context/context_test.rs +++ /dev/null @@ -1,305 +0,0 @@ -use bytes::Bytes; -use lazy_static::lazy_static; - -use super::*; -use crate::key_derivation::*; - -const CIPHER_CONTEXT_ALGO: ProtectionProfile = ProtectionProfile::Aes128CmHmacSha1_80; -const DEFAULT_SSRC: u32 = 0; - -#[test] -fn test_context_roc() -> Result<()> { - let key_len = CIPHER_CONTEXT_ALGO.key_len(); - let salt_len = CIPHER_CONTEXT_ALGO.salt_len(); - - let mut c = Context::new( - &vec![0; key_len], - &vec![0; salt_len], - CIPHER_CONTEXT_ALGO, - None, - None, - )?; - - let roc = c.get_roc(123); - assert!(roc.is_none(), "ROC must return None for unused SSRC"); - - c.set_roc(123, 100); - let roc = c.get_roc(123); - if let Some(r) = roc { - assert_eq!(r, 100, "ROC is set to 100, but returned {r}") - } else { - panic!("ROC must return value for used SSRC"); - } - - Ok(()) -} - -#[test] -fn test_context_index() -> Result<()> { - let key_len = CIPHER_CONTEXT_ALGO.key_len(); - let salt_len = CIPHER_CONTEXT_ALGO.salt_len(); - - let mut c = Context::new( - &vec![0; key_len], - &vec![0; salt_len], - CIPHER_CONTEXT_ALGO, - None, - None, - )?; - - let index = c.get_index(123); - assert!(index.is_none(), "Index must return None for unused SSRC"); - - c.set_index(123, 100); - let index = c.get_index(123); - if let Some(i) = index { - assert_eq!(i, 100, "Index is set to 100, but returned {i}"); - } else { - panic!("Index must return true for used SSRC") - } - - Ok(()) -} - -#[test] -fn test_key_len() -> Result<()> { - let key_len = CIPHER_CONTEXT_ALGO.key_len(); - let salt_len = CIPHER_CONTEXT_ALGO.salt_len(); - - let result = Context::new(&[], &vec![0; salt_len], CIPHER_CONTEXT_ALGO, None, None); - assert!(result.is_err(), "CreateContext accepted a 0 length key"); - - let result = Context::new(&vec![0; key_len], &[], CIPHER_CONTEXT_ALGO, None, None); - assert!(result.is_err(), "CreateContext accepted a 0 length salt"); - - let result = Context::new( - &vec![0; key_len], - &vec![0; salt_len], - CIPHER_CONTEXT_ALGO, - None, - None, - ); - assert!( - result.is_ok(), - "CreateContext failed with a valid length key and salt" - ); - - Ok(()) -} - 
-#[test] -fn test_valid_packet_counter() -> Result<()> { - let master_key = vec![ - 0x0d, 0xcd, 0x21, 0x3e, 0x4c, 0xbc, 0xf2, 0x8f, 0x01, 0x7f, 0x69, 0x94, 0x40, 0x1e, 0x28, - 0x89, - ]; - let master_salt = vec![ - 0x62, 0x77, 0x60, 0x38, 0xc0, 0x6d, 0xc9, 0x41, 0x9f, 0x6d, 0xd9, 0x43, 0x3e, 0x7c, - ]; - - let srtp_session_salt = aes_cm_key_derivation( - LABEL_SRTP_SALT, - &master_key, - &master_salt, - 0, - master_salt.len(), - )?; - - let s = SrtpSsrcState { - ssrc: 4160032510, - ..Default::default() - }; - let expected_counter = [ - 0xcf, 0x90, 0x1e, 0xa5, 0xda, 0xd3, 0x2c, 0x15, 0x00, 0xa2, 0x24, 0xae, 0xae, 0xaf, 0x00, - 0x00, - ]; - let counter = generate_counter(32846, s.rollover_counter, s.ssrc, &srtp_session_salt); - assert_eq!( - counter, expected_counter, - "Session Key {counter:?} does not match expected {expected_counter:?}", - ); - - Ok(()) -} - -#[test] -fn test_rollover_count() -> Result<()> { - let mut s = SrtpSsrcState { - ssrc: DEFAULT_SSRC, - ..Default::default() - }; - - // Set initial seqnum - let roc = s.next_rollover_count(65530); - assert_eq!(roc, 0, "Initial rolloverCounter must be 0"); - s.update_rollover_count(65530); - - // Invalid packets never update ROC - s.next_rollover_count(0); - s.next_rollover_count(0x4000); - s.next_rollover_count(0x8000); - s.next_rollover_count(0xFFFF); - s.next_rollover_count(0); - - // We rolled over to 0 - let roc = s.next_rollover_count(0); - assert_eq!(roc, 1, "rolloverCounter was not updated after it crossed 0"); - s.update_rollover_count(0); - - let roc = s.next_rollover_count(65530); - assert_eq!( - roc, 0, - "rolloverCounter was not updated when it rolled back, failed to handle out of order" - ); - s.update_rollover_count(65530); - - let roc = s.next_rollover_count(5); - assert_eq!( - roc, 1, - "rolloverCounter was not updated when it rolled over initial, to handle out of order" - ); - s.update_rollover_count(5); - - s.next_rollover_count(6); - s.update_rollover_count(6); - - s.next_rollover_count(7); - s.update_rollover_count(7); - - let roc = s.next_rollover_count(8); - assert_eq!( - roc, 1, - "rolloverCounter was improperly updated for non-significant packets" - ); - s.update_rollover_count(8); - - // valid packets never update ROC - let roc = s.next_rollover_count(0x4000); - assert_eq!( - roc, 1, - "rolloverCounter was improperly updated for non-significant packets" - ); - s.update_rollover_count(0x4000); - - let roc = s.next_rollover_count(0x8000); - assert_eq!( - roc, 1, - "rolloverCounter was improperly updated for non-significant packets" - ); - s.update_rollover_count(0x8000); - - let roc = s.next_rollover_count(0xFFFF); - assert_eq!( - roc, 1, - "rolloverCounter was improperly updated for non-significant packets" - ); - s.update_rollover_count(0xFFFF); - - let roc = s.next_rollover_count(0); - assert_eq!( - roc, 2, - "rolloverCounter must be incremented after wrapping, got {roc}" - ); - - Ok(()) -} - -lazy_static! 
{ - static ref MASTER_KEY: Bytes = Bytes::from_static(&[ - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, - 0x0f, - ]); - static ref MASTER_SALT: Bytes = Bytes::from_static(&[ - 0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xab, - ]); - static ref DECRYPTED_RTP_PACKET: Bytes = Bytes::from_static(&[ - 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, 0xca, 0xfe, 0xba, 0xbe, 0xab, 0xab, 0xab, - 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, - ]); - static ref ENCRYPTED_RTP_PACKET: Bytes = Bytes::from_static(&[ - 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, 0xca, 0xfe, 0xba, 0xbe, 0xc5, 0x00, 0x2e, - 0xde, 0x04, 0xcf, 0xdd, 0x2e, 0xb9, 0x11, 0x59, 0xe0, 0x88, 0x0a, 0xa0, 0x6e, 0xd2, 0x97, - 0x68, 0x26, 0xf7, 0x96, 0xb2, 0x01, 0xdf, 0x31, 0x31, 0xa1, 0x27, 0xe8, 0xa3, 0x92, - ]); - static ref DECRYPTED_RTCP_PACKET: Bytes = Bytes::from_static(&[ - 0x81, 0xc8, 0x00, 0x0b, 0xca, 0xfe, 0xba, 0xbe, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, - 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, - ]); - static ref ENCRYPTED_RTCP_PACKET: Bytes = Bytes::from_static(&[ - 0x81, 0xc8, 0x00, 0x0b, 0xca, 0xfe, 0xba, 0xbe, 0xc9, 0x8b, 0x8b, 0x5d, 0xf0, 0x39, 0x2a, - 0x55, 0x85, 0x2b, 0x6c, 0x21, 0xac, 0x8e, 0x70, 0x25, 0xc5, 0x2c, 0x6f, 0xbe, 0xa2, 0xb3, - 0xb4, 0x46, 0xea, 0x31, 0x12, 0x3b, 0xa8, 0x8c, 0xe6, 0x1e, 0x80, 0x00, 0x00, 0x01, - ]); -} - -#[test] -fn test_encrypt_rtp() { - let mut ctx = Context::new( - &MASTER_KEY, - &MASTER_SALT, - ProtectionProfile::AeadAes128Gcm, - None, - None, - ) - .expect("Error creating srtp context"); - - let gotten_encrypted_rtp_packet = ctx - .encrypt_rtp(&DECRYPTED_RTP_PACKET) - .expect("Error encrypting rtp payload"); - - assert_eq!(gotten_encrypted_rtp_packet, *ENCRYPTED_RTP_PACKET) -} - -#[test] -fn test_decrypt_rtp() { - let mut ctx = Context::new( - &MASTER_KEY, - &MASTER_SALT, - ProtectionProfile::AeadAes128Gcm, - None, - None, - ) - .expect("Error creating srtp context"); - - let gotten_decrypted_rtp_packet = ctx - .decrypt_rtp(&ENCRYPTED_RTP_PACKET) - .expect("Error decrypting rtp payload"); - - assert_eq!(gotten_decrypted_rtp_packet, *DECRYPTED_RTP_PACKET) -} - -#[test] -fn test_encrypt_rtcp() { - let mut ctx = Context::new( - &MASTER_KEY, - &MASTER_SALT, - ProtectionProfile::AeadAes128Gcm, - None, - None, - ) - .expect("Error creating srtp context"); - - let gotten_encrypted_rtcp_packet = ctx - .encrypt_rtcp(&DECRYPTED_RTCP_PACKET) - .expect("Error encrypting rtcp payload"); - - assert_eq!(gotten_encrypted_rtcp_packet, *ENCRYPTED_RTCP_PACKET) -} - -#[test] -fn test_decrypt_rtcp() { - let mut ctx = Context::new( - &MASTER_KEY, - &MASTER_SALT, - ProtectionProfile::AeadAes128Gcm, - None, - None, - ) - .expect("Error creating srtp context"); - - let gotten_decrypted_rtcp_packet = ctx - .decrypt_rtcp(&ENCRYPTED_RTCP_PACKET) - .expect("Error decrypting rtcp payload"); - - assert_eq!(gotten_decrypted_rtcp_packet, *DECRYPTED_RTCP_PACKET) -} diff --git a/srtp/src/context/mod.rs b/srtp/src/context/mod.rs deleted file mode 100644 index ae3f3b1ce..000000000 --- a/srtp/src/context/mod.rs +++ /dev/null @@ -1,201 +0,0 @@ -#[cfg(test)] -mod context_test; -#[cfg(test)] -mod srtcp_test; -#[cfg(test)] -mod srtp_test; - -use std::collections::HashMap; - -use util::replay_detector::*; - -use crate::cipher::cipher_aead_aes_gcm::*; -use crate::cipher::cipher_aes_cm_hmac_sha1::*; -use crate::cipher::*; -use crate::error::{Error, Result}; -use crate::option::*; -use 
crate::protection_profile::*; - -pub mod srtcp; -pub mod srtp; - -const MAX_ROC_DISORDER: u16 = 100; - -/// Encrypt/Decrypt state for a single SRTP SSRC -#[derive(Default)] -pub(crate) struct SrtpSsrcState { - ssrc: u32, - rollover_counter: u32, - rollover_has_processed: bool, - last_sequence_number: u16, - replay_detector: Option>, -} - -/// Encrypt/Decrypt state for a single SRTCP SSRC -#[derive(Default)] -pub(crate) struct SrtcpSsrcState { - srtcp_index: usize, - ssrc: u32, - replay_detector: Option>, -} - -impl SrtpSsrcState { - pub fn next_rollover_count(&self, sequence_number: u16) -> u32 { - let mut roc = self.rollover_counter; - - if !self.rollover_has_processed { - } else if sequence_number == 0 { - // We exactly hit the rollover count - - // Only update rolloverCounter if lastSequenceNumber is greater then MAX_ROCDISORDER - // otherwise we already incremented for disorder - if self.last_sequence_number > MAX_ROC_DISORDER { - roc += 1; - } - } else if self.last_sequence_number < MAX_ROC_DISORDER - && sequence_number > (MAX_SEQUENCE_NUMBER - MAX_ROC_DISORDER) - { - // Our last sequence number incremented because we crossed 0, but then our current number was within MAX_ROCDISORDER of the max - // So we fell behind, drop to account for jitter - roc -= 1; - } else if sequence_number < MAX_ROC_DISORDER - && self.last_sequence_number > (MAX_SEQUENCE_NUMBER - MAX_ROC_DISORDER) - { - // our current is within a MAX_ROCDISORDER of 0 - // and our last sequence number was a high sequence number, increment to account for jitter - roc += 1; - } - - roc - } - - /// https://tools.ietf.org/html/rfc3550#appendix-A.1 - pub fn update_rollover_count(&mut self, sequence_number: u16) { - if !self.rollover_has_processed { - self.rollover_has_processed = true; - } else if sequence_number == 0 { - // We exactly hit the rollover count - - // Only update rolloverCounter if lastSequenceNumber is greater then MAX_ROCDISORDER - // otherwise we already incremented for disorder - if self.last_sequence_number > MAX_ROC_DISORDER { - self.rollover_counter += 1; - } - } else if self.last_sequence_number < MAX_ROC_DISORDER - && sequence_number > (MAX_SEQUENCE_NUMBER - MAX_ROC_DISORDER) - { - // Our last sequence number incremented because we crossed 0, but then our current number was within MAX_ROCDISORDER of the max - // So we fell behind, drop to account for jitter - self.rollover_counter -= 1; - } else if sequence_number < MAX_ROC_DISORDER - && self.last_sequence_number > (MAX_SEQUENCE_NUMBER - MAX_ROC_DISORDER) - { - // our current is within a MAX_ROCDISORDER of 0 - // and our last sequence number was a high sequence number, increment to account for jitter - self.rollover_counter += 1; - } - self.last_sequence_number = sequence_number; - } -} - -/// Context represents a SRTP cryptographic context -/// Context can only be used for one-way operations -/// it must either used ONLY for encryption or ONLY for decryption -pub struct Context { - cipher: Box, - - srtp_ssrc_states: HashMap, - srtcp_ssrc_states: HashMap, - - new_srtp_replay_detector: ContextOption, - new_srtcp_replay_detector: ContextOption, -} - -impl Context { - /// CreateContext creates a new SRTP Context - pub fn new( - master_key: &[u8], - master_salt: &[u8], - profile: ProtectionProfile, - srtp_ctx_opt: Option, - srtcp_ctx_opt: Option, - ) -> Result { - let key_len = profile.key_len(); - let salt_len = profile.salt_len(); - - if master_key.len() != key_len { - return Err(Error::SrtpMasterKeyLength(key_len, master_key.len())); - } else if 
master_salt.len() != salt_len { - return Err(Error::SrtpSaltLength(salt_len, master_salt.len())); - } - - let cipher: Box = match profile { - ProtectionProfile::Aes128CmHmacSha1_80 => { - Box::new(CipherAesCmHmacSha1::new(master_key, master_salt)?) - } - - ProtectionProfile::AeadAes128Gcm => { - Box::new(CipherAeadAesGcm::new(master_key, master_salt)?) - } - }; - - let srtp_ctx_opt = if let Some(ctx_opt) = srtp_ctx_opt { - ctx_opt - } else { - srtp_no_replay_protection() - }; - - let srtcp_ctx_opt = if let Some(ctx_opt) = srtcp_ctx_opt { - ctx_opt - } else { - srtcp_no_replay_protection() - }; - - Ok(Context { - cipher, - srtp_ssrc_states: HashMap::new(), - srtcp_ssrc_states: HashMap::new(), - new_srtp_replay_detector: srtp_ctx_opt, - new_srtcp_replay_detector: srtcp_ctx_opt, - }) - } - - fn get_srtp_ssrc_state(&mut self, ssrc: u32) -> &mut SrtpSsrcState { - let s = SrtpSsrcState { - ssrc, - replay_detector: Some((self.new_srtp_replay_detector)()), - ..Default::default() - }; - - self.srtp_ssrc_states.entry(ssrc).or_insert(s) - } - - fn get_srtcp_ssrc_state(&mut self, ssrc: u32) -> &mut SrtcpSsrcState { - let s = SrtcpSsrcState { - ssrc, - replay_detector: Some((self.new_srtcp_replay_detector)()), - ..Default::default() - }; - self.srtcp_ssrc_states.entry(ssrc).or_insert(s) - } - - /// roc returns SRTP rollover counter value of specified SSRC. - fn get_roc(&self, ssrc: u32) -> Option { - self.srtp_ssrc_states.get(&ssrc).map(|s| s.rollover_counter) - } - - /// set_roc sets SRTP rollover counter value of specified SSRC. - fn set_roc(&mut self, ssrc: u32, roc: u32) { - self.get_srtp_ssrc_state(ssrc).rollover_counter = roc; - } - - /// index returns SRTCP index value of specified SSRC. - fn get_index(&self, ssrc: u32) -> Option { - self.srtcp_ssrc_states.get(&ssrc).map(|s| s.srtcp_index) - } - - /// set_index sets SRTCP index value of specified SSRC. - fn set_index(&mut self, ssrc: u32, index: usize) { - self.get_srtcp_ssrc_state(ssrc).srtcp_index = index; - } -} diff --git a/srtp/src/context/srtcp.rs b/srtp/src/context/srtcp.rs deleted file mode 100644 index 8bc54fda1..000000000 --- a/srtp/src/context/srtcp.rs +++ /dev/null @@ -1,50 +0,0 @@ -use bytes::Bytes; -use util::marshal::*; - -use super::*; -use crate::error::Result; - -impl Context { - /// DecryptRTCP decrypts a RTCP packet with an encrypted payload - pub fn decrypt_rtcp(&mut self, encrypted: &[u8]) -> Result { - let mut buf = encrypted; - rtcp::header::Header::unmarshal(&mut buf)?; - - let index = self.cipher.get_rtcp_index(encrypted); - let ssrc = u32::from_be_bytes([encrypted[4], encrypted[5], encrypted[6], encrypted[7]]); - - if let Some(replay_detector) = &mut self.get_srtcp_ssrc_state(ssrc).replay_detector { - if !replay_detector.check(index as u64) { - return Err(Error::SrtcpSsrcDuplicated(ssrc, index)); - } - } - - let dst = self.cipher.decrypt_rtcp(encrypted, index, ssrc)?; - - if let Some(replay_detector) = &mut self.get_srtcp_ssrc_state(ssrc).replay_detector { - replay_detector.accept(); - } - - Ok(dst) - } - - /// EncryptRTCP marshals and encrypts an RTCP packet, writing to the dst buffer provided. - /// If the dst buffer does not have the capacity to hold `len(plaintext) + 14` bytes, a new one will be allocated and returned. 
- pub fn encrypt_rtcp(&mut self, decrypted: &[u8]) -> Result { - let mut buf = decrypted; - rtcp::header::Header::unmarshal(&mut buf)?; - - let ssrc = u32::from_be_bytes([decrypted[4], decrypted[5], decrypted[6], decrypted[7]]); - - let index = { - let state = self.get_srtcp_ssrc_state(ssrc); - state.srtcp_index += 1; - if state.srtcp_index > MAX_SRTCP_INDEX { - state.srtcp_index = 0; - } - state.srtcp_index - }; - - self.cipher.encrypt_rtcp(decrypted, index, ssrc) - } -} diff --git a/srtp/src/context/srtcp_test.rs b/srtp/src/context/srtcp_test.rs deleted file mode 100644 index 55d772258..000000000 --- a/srtp/src/context/srtcp_test.rs +++ /dev/null @@ -1,257 +0,0 @@ -use bytes::{Buf, Bytes, BytesMut}; -use lazy_static::lazy_static; - -use super::*; -use crate::key_derivation::*; - -pub struct RTCPTestCase { - ssrc: u32, - index: usize, - encrypted: Bytes, - decrypted: Bytes, -} - -lazy_static! { - static ref RTCP_TEST_MASTER_KEY: Bytes = Bytes::from_static(&[ - 0xfd, 0xa6, 0x25, 0x95, 0xd7, 0xf6, 0x92, 0x6f, 0x7d, 0x9c, 0x02, 0x4c, 0xc9, 0x20, 0x9f, - 0x34 - ]); - - static ref RTCP_TEST_MASTER_SALT: Bytes = Bytes::from_static(&[ - 0xa9, 0x65, 0x19, 0x85, 0x54, 0x0b, 0x47, 0xbe, 0x2f, 0x27, 0xa8, 0xb8, 0x81, 0x23 - ]); - - static ref RTCP_TEST_CASES: Vec = vec![ - RTCPTestCase { - ssrc: 0x66ef91ff, - index: 0, - encrypted: Bytes::from_static(&[ - 0x80, 0xc8, 0x00, 0x06, 0x66, 0xef, 0x91, 0xff, 0xcd, 0x34, 0xc5, 0x78, 0xb2, 0x8b, - 0xe1, 0x6b, 0xc5, 0x09, 0xd5, 0x77, 0xe4, 0xce, 0x5f, 0x20, 0x80, 0x21, 0xbd, 0x66, - 0x74, 0x65, 0xe9, 0x5f, 0x49, 0xe5, 0xf5, 0xc0, 0x68, 0x4e, 0xe5, 0x6a, 0x78, 0x07, - 0x75, 0x46, 0xed, 0x90, 0xf6, 0xdc, 0x9d, 0xef, 0x3b, 0xdf, 0xf2, 0x79, 0xa9, 0xd8, - 0x80, 0x00, 0x00, 0x01, 0x60, 0xc0, 0xae, 0xb5, 0x6f, 0x40, 0x88, 0x0e, 0x28, 0xba - ]), - decrypted: Bytes::from_static(&[ - 0x80, 0xc8, 0x00, 0x06, 0x66, 0xef, 0x91, 0xff, 0xdf, 0x48, 0x80, 0xdd, 0x61, 0xa6, - 0x2e, 0xd3, 0xd8, 0xbc, 0xde, 0xbe, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x16, 0x04, - 0x81, 0xca, 0x00, 0x06, 0x66, 0xef, 0x91, 0xff, 0x01, 0x10, 0x52, 0x6e, 0x54, 0x35, - 0x43, 0x6d, 0x4a, 0x68, 0x7a, 0x79, 0x65, 0x74, 0x41, 0x78, 0x77, 0x2b, 0x00, 0x00 - ]), - }, - RTCPTestCase{ - ssrc: 0x11111111, - index: 0, - encrypted: Bytes::from_static(&[ - 0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11, 0x17, 0x8c, 0x15, 0xf1, 0x4b, 0x11, - 0xda, 0xf5, 0x74, 0x53, 0x86, 0x2b, 0xc9, 0x07, 0x29, 0x40, 0xbf, 0x22, 0xf6, 0x46, - 0x11, 0xa4, 0xc1, 0x3a, 0xff, 0x5a, 0xbd, 0xd0, 0xf8, 0x8b, 0x38, 0xe4, 0x95, 0x38, - 0x5d, 0xcf, 0x1b, 0xf5, 0x27, 0x77, 0xfb, 0xdb, 0x3f, 0x10, 0x68, 0x99, 0xd8, 0xad, - 0x80, 0x00, 0x00, 0x01, 0x34, 0x3c, 0x2e, 0x83, 0x17, 0x13, 0x93, 0x69, 0xcf, 0xc0 - ]), - decrypted: Bytes::from_static(&[ - 0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11, 0xdf, 0x48, 0x80, 0xdd, 0x61, 0xa6, - 0x2e, 0xd3, 0xd8, 0xbc, 0xde, 0xbe, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x16, 0x04, - 0x81, 0xca, 0x00, 0x06, 0x66, 0xef, 0x91, 0xff, 0x01, 0x10, 0x52, 0x6e, 0x54, 0x35, - 0x43, 0x6d, 0x4a, 0x68, 0x7a, 0x79, 0x65, 0x74, 0x41, 0x78, 0x77, 0x2b, 0x00, 0x00 - ]), - }, - RTCPTestCase{ - ssrc: 0x11111111, - index: 0x7ffffffe, // Upper boundary of index - encrypted: Bytes::from_static(&[ - 0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11, 0x17, 0x8c, 0x15, 0xf1, 0x4b, 0x11, - 0xda, 0xf5, 0x74, 0x53, 0x86, 0x2b, 0xc9, 0x07, 0x29, 0x40, 0xbf, 0x22, 0xf6, 0x46, - 0x11, 0xa4, 0xc1, 0x3a, 0xff, 0x5a, 0xbd, 0xd0, 0xf8, 0x8b, 0x38, 0xe4, 0x95, 0x38, - 0x5d, 0xcf, 0x1b, 0xf5, 0x27, 0x77, 0xfb, 0xdb, 0x3f, 0x10, 0x68, 0x99, 0xd8, 
0xad, - 0xff, 0xff, 0xff, 0xff, 0x5a, 0x99, 0xce, 0xed, 0x9f, 0x2e, 0x4d, 0x9d, 0xfa, 0x97 - ]), - decrypted: Bytes::from_static(&[ - 0x80, 0xc8, 0x0, 0x6, 0x11, 0x11, 0x11, 0x11, 0x4, 0x99, 0x47, 0x53, 0xc4, 0x1e, - 0xb9, 0xde, 0x52, 0xa3, 0x1d, 0x77, 0x2f, 0xff, 0xcc, 0x75, 0xbb, 0x6a, 0x29, 0xb8, - 0x1, 0xb7, 0x2e, 0x4b, 0x4e, 0xcb, 0xa4, 0x81, 0x2d, 0x46, 0x4, 0x5e, 0x86, 0x90, - 0x17, 0x4f, 0x4d, 0x78, 0x2f, 0x58, 0xb8, 0x67, 0x91, 0x89, 0xe3, 0x61, 0x1, 0x7d - ]), - }, - RTCPTestCase{ - ssrc: 0x11111111, - index: 0x7fffffff, // Will be wrapped to 0 - encrypted: Bytes::from_static(&[ - 0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11, 0x17, 0x8c, 0x15, 0xf1, 0x4b, 0x11, - 0xda, 0xf5, 0x74, 0x53, 0x86, 0x2b, 0xc9, 0x07, 0x29, 0x40, 0xbf, 0x22, 0xf6, 0x46, - 0x11, 0xa4, 0xc1, 0x3a, 0xff, 0x5a, 0xbd, 0xd0, 0xf8, 0x8b, 0x38, 0xe4, 0x95, 0x38, - 0x5d, 0xcf, 0x1b, 0xf5, 0x27, 0x77, 0xfb, 0xdb, 0x3f, 0x10, 0x68, 0x99, 0xd8, 0xad, - 0x80, 0x00, 0x00, 0x00, 0x7d, 0x51, 0xf8, 0x0e, 0x56, 0x40, 0x72, 0x7b, 0x9e, 0x02 - ]), - decrypted: Bytes::from_static(&[ - 0x80, 0xc8, 0x0, 0x6, 0x11, 0x11, 0x11, 0x11, 0xda, 0xb5, 0xe0, 0x56, 0x9a, 0x4a, - 0x74, 0xed, 0x8a, 0x54, 0xc, 0xcf, 0xd5, 0x9, 0xb1, 0x40, 0x1, 0x42, 0xc3, 0x9a, - 0x76, 0x0, 0xa9, 0xd4, 0xf7, 0x29, 0x9e, 0x51, 0xfb, 0x3c, 0xc1, 0x74, 0x72, 0xf9, - 0x52, 0xb1, 0x92, 0x31, 0xca, 0x22, 0xab, 0x3e, 0xc5, 0x5f, 0x83, 0x34, 0xf0, 0x28 - ]), - }, - ]; -} - -#[test] -fn test_rtcp_lifecycle() -> Result<()> { - let mut encrypt_context = Context::new( - &RTCP_TEST_MASTER_KEY, - &RTCP_TEST_MASTER_SALT, - ProtectionProfile::Aes128CmHmacSha1_80, - None, - None, - )?; - let mut decrypt_context = Context::new( - &RTCP_TEST_MASTER_KEY, - &RTCP_TEST_MASTER_SALT, - ProtectionProfile::Aes128CmHmacSha1_80, - None, - None, - )?; - - for test_case in &*RTCP_TEST_CASES { - let decrypt_result = decrypt_context.decrypt_rtcp(&test_case.encrypted)?; - assert_eq!( - decrypt_result, test_case.decrypted, - "RTCP failed to decrypt" - ); - - encrypt_context.set_index(test_case.ssrc, test_case.index); - let encrypt_result = encrypt_context.encrypt_rtcp(&test_case.decrypted)?; - assert_eq!( - encrypt_result, test_case.encrypted, - "RTCP failed to encrypt" - ); - } - - Ok(()) -} - -#[test] -fn test_rtcp_invalid_auth_tag() -> Result<()> { - let auth_tag_len = ProtectionProfile::Aes128CmHmacSha1_80.auth_tag_len(); - - let mut decrypt_context = Context::new( - &RTCP_TEST_MASTER_KEY, - &RTCP_TEST_MASTER_SALT, - ProtectionProfile::Aes128CmHmacSha1_80, - None, - None, - )?; - - let decrypt_result = decrypt_context.decrypt_rtcp(&RTCP_TEST_CASES[0].encrypted)?; - assert_eq!( - decrypt_result, RTCP_TEST_CASES[0].decrypted, - "RTCP failed to decrypt" - ); - - // Zero out auth tag - let mut rtcp_packet = BytesMut::new(); - rtcp_packet.extend_from_slice(&RTCP_TEST_CASES[0].encrypted); - let rtcp_packet_len = rtcp_packet.len(); - rtcp_packet[rtcp_packet_len - auth_tag_len..].copy_from_slice(&vec![0; auth_tag_len]); - let rtcp_packet = rtcp_packet.freeze(); - let decrypt_result = decrypt_context.decrypt_rtcp(&rtcp_packet); - assert!( - decrypt_result.is_err(), - "Was able to decrypt RTCP packet with invalid Auth Tag" - ); - - Ok(()) -} - -#[test] -fn test_rtcp_replay_detector_separation() -> Result<()> { - let mut decrypt_context = Context::new( - &RTCP_TEST_MASTER_KEY, - &RTCP_TEST_MASTER_SALT, - ProtectionProfile::Aes128CmHmacSha1_80, - None, - Some(srtcp_replay_protection(10)), - )?; - - let rtcp_packet1 = RTCP_TEST_CASES[0].encrypted.clone(); - let decrypt_result1 = 
decrypt_context.decrypt_rtcp(&rtcp_packet1)?; - assert_eq!( - decrypt_result1, RTCP_TEST_CASES[0].decrypted, - "RTCP failed to decrypt" - ); - - let rtcp_packet2 = RTCP_TEST_CASES[1].encrypted.clone(); - let decrypt_result2 = decrypt_context.decrypt_rtcp(&rtcp_packet2)?; - assert_eq!( - decrypt_result2, RTCP_TEST_CASES[1].decrypted, - "RTCP failed to decrypt" - ); - - let result = decrypt_context.decrypt_rtcp(&rtcp_packet1); - assert!( - result.is_err(), - "Was able to decrypt duplicated RTCP packet" - ); - - let result = decrypt_context.decrypt_rtcp(&rtcp_packet2); - assert!( - result.is_err(), - "Was able to decrypt duplicated RTCP packet" - ); - - Ok(()) -} - -fn get_rtcp_index(encrypted: &Bytes, auth_tag_len: usize) -> u32 { - let tail_offset = encrypted.len() - (auth_tag_len + SRTCP_INDEX_SIZE); - let reader = &mut encrypted.slice(tail_offset..tail_offset + SRTCP_INDEX_SIZE); - //^(1 << 31) - reader.get_u32() & 0x7FFFFFFF -} - -#[test] -fn test_encrypt_rtcp_separation() -> Result<()> { - let mut encrypt_context = Context::new( - &RTCP_TEST_MASTER_KEY, - &RTCP_TEST_MASTER_SALT, - ProtectionProfile::Aes128CmHmacSha1_80, - None, - None, - )?; - - let auth_tag_len = ProtectionProfile::Aes128CmHmacSha1_80.auth_tag_len(); - - let mut decrypt_context = Context::new( - &RTCP_TEST_MASTER_KEY, - &RTCP_TEST_MASTER_SALT, - ProtectionProfile::Aes128CmHmacSha1_80, - None, - Some(srtcp_replay_protection(10)), - )?; - - let inputs = vec![ - RTCP_TEST_CASES[0].decrypted.clone(), - RTCP_TEST_CASES[1].decrypted.clone(), - RTCP_TEST_CASES[0].decrypted.clone(), - RTCP_TEST_CASES[1].decrypted.clone(), - ]; - let mut encrypted_rctps = vec![]; - - for input in &inputs { - let encrypted = encrypt_context.encrypt_rtcp(input)?; - encrypted_rctps.push(encrypted); - } - - for (i, expected_index) in [1, 1, 2, 2].iter().enumerate() { - assert_eq!( - *expected_index, - get_rtcp_index(&encrypted_rctps[i], auth_tag_len), - "RTCP index does not match" - ); - } - - for (i, output) in encrypted_rctps.iter().enumerate() { - let decrypted = decrypt_context.decrypt_rtcp(output)?; - assert_eq!(inputs[i], decrypted); - } - - Ok(()) -} diff --git a/srtp/src/context/srtp.rs b/srtp/src/context/srtp.rs deleted file mode 100644 index aaf0d9931..000000000 --- a/srtp/src/context/srtp.rs +++ /dev/null @@ -1,70 +0,0 @@ -use bytes::Bytes; -use util::marshal::*; - -use super::*; -use crate::error::Result; - -impl Context { - pub fn decrypt_rtp_with_header( - &mut self, - encrypted: &[u8], - header: &rtp::header::Header, - ) -> Result { - let roc = { - let state = self.get_srtp_ssrc_state(header.ssrc); - if let Some(replay_detector) = &mut state.replay_detector { - if !replay_detector.check(header.sequence_number as u64) { - return Err(Error::SrtpSsrcDuplicated( - header.ssrc, - header.sequence_number, - )); - } - } - - state.next_rollover_count(header.sequence_number) - }; - - let dst = self.cipher.decrypt_rtp(encrypted, header, roc)?; - { - let state = self.get_srtp_ssrc_state(header.ssrc); - if let Some(replay_detector) = &mut state.replay_detector { - replay_detector.accept(); - } - state.update_rollover_count(header.sequence_number); - } - - Ok(dst) - } - - /// DecryptRTP decrypts a RTP packet with an encrypted payload - pub fn decrypt_rtp(&mut self, encrypted: &[u8]) -> Result { - let mut buf = encrypted; - let header = rtp::header::Header::unmarshal(&mut buf)?; - self.decrypt_rtp_with_header(encrypted, &header) - } - - pub fn encrypt_rtp_with_header( - &mut self, - payload: &[u8], - header: &rtp::header::Header, - ) -> Result { - 
let roc = self - .get_srtp_ssrc_state(header.ssrc) - .next_rollover_count(header.sequence_number); - - let dst = self.cipher.encrypt_rtp(payload, header, roc)?; - - self.get_srtp_ssrc_state(header.ssrc) - .update_rollover_count(header.sequence_number); - - Ok(dst) - } - - /// EncryptRTP marshals and encrypts an RTP packet, writing to the dst buffer provided. - /// If the dst buffer does not have the capacity to hold `len(plaintext) + 10` bytes, a new one will be allocated and returned. - pub fn encrypt_rtp(&mut self, plaintext: &[u8]) -> Result { - let mut buf = plaintext; - let header = rtp::header::Header::unmarshal(&mut buf)?; - self.encrypt_rtp_with_header(plaintext, &header) - } -} diff --git a/srtp/src/context/srtp_test.rs b/srtp/src/context/srtp_test.rs deleted file mode 100644 index 8e5e5083a..000000000 --- a/srtp/src/context/srtp_test.rs +++ /dev/null @@ -1,172 +0,0 @@ -use bytes::Bytes; -use lazy_static::lazy_static; -use util::marshal::*; - -use super::*; - -struct RTPTestCase { - sequence_number: u16, - encrypted: Bytes, -} - -lazy_static! { - static ref RTP_TEST_CASE_DECRYPTED: Bytes = Bytes::from_static(&[0x00, 0x01, 0x02, 0x03, 0x04, 0x05]); - static ref RTP_TEST_CASES: Vec = vec![ - RTPTestCase { - sequence_number: 5000, - encrypted: Bytes::from_static(&[ - 0x6d, 0xd3, 0x7e, 0xd5, 0x99, 0xb7, 0x2d, 0x28, 0xb1, 0xf3, 0xa1, 0xf0, 0xc, 0xfb, - 0xfd, 0x8 - ]), - }, - RTPTestCase { - sequence_number: 5001, - encrypted: Bytes::from_static(&[ - 0xda, 0x47, 0xb, 0x2a, 0x74, 0x53, 0x65, 0xbd, 0x2f, 0xeb, 0xdc, 0x4b, 0x6d, 0x23, - 0xf3, 0xde - ]), - }, - RTPTestCase { - sequence_number: 5002, - encrypted: Bytes::from_static(&[ - 0x6e, 0xa7, 0x69, 0x8d, 0x24, 0x6d, 0xdc, 0xbf, 0xec, 0x2, 0x1c, 0xd1, 0x60, 0x76, - 0xc1, 0x0e - ]), - }, - RTPTestCase { - sequence_number: 5003, - encrypted: Bytes::from_static(&[ - 0x24, 0x7e, 0x96, 0xc8, 0x7d, 0x33, 0xa2, 0x92, 0x8d, 0x13, 0x8d, 0xe0, 0x76, 0x9f, - 0x08, 0xdc - ]), - }, - RTPTestCase { - sequence_number: 5004, - encrypted: Bytes::from_static(&[ - 0x75, 0x43, 0x28, 0xe4, 0x3a, 0x77, 0x59, 0x9b, 0x2e, 0xdf, 0x7b, 0x12, 0x68, 0x0b, - 0x57, 0x49 - ]), - }, - RTPTestCase{ - sequence_number: 65535, // upper boundary - encrypted: Bytes::from_static(&[ - 0xaf, 0xf7, 0xc2, 0x70, 0x37, 0x20, 0x83, 0x9c, 0x2c, 0x63, 0x85, 0x15, 0x0e, 0x44, - 0xca, 0x36 - ]), - }, - ]; -} - -fn build_test_context() -> Result { - let master_key = Bytes::from_static(&[ - 0x0d, 0xcd, 0x21, 0x3e, 0x4c, 0xbc, 0xf2, 0x8f, 0x01, 0x7f, 0x69, 0x94, 0x40, 0x1e, 0x28, - 0x89, - ]); - let master_salt = Bytes::from_static(&[ - 0x62, 0x77, 0x60, 0x38, 0xc0, 0x6d, 0xc9, 0x41, 0x9f, 0x6d, 0xd9, 0x43, 0x3e, 0x7c, - ]); - - Context::new( - &master_key, - &master_salt, - ProtectionProfile::Aes128CmHmacSha1_80, - None, - None, - ) -} - -#[test] -fn test_rtp_invalid_auth() -> Result<()> { - let master_key = Bytes::from_static(&[ - 0x0d, 0xcd, 0x21, 0x3e, 0x4c, 0xbc, 0xf2, 0x8f, 0x01, 0x7f, 0x69, 0x94, 0x40, 0x1e, 0x28, - 0x89, - ]); - let invalid_salt = Bytes::from_static(&[ - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - ]); - - let mut encrypt_context = build_test_context()?; - let mut invalid_context = Context::new( - &master_key, - &invalid_salt, - ProtectionProfile::Aes128CmHmacSha1_80, - None, - None, - )?; - - for test_case in &*RTP_TEST_CASES { - let pkt = rtp::packet::Packet { - header: rtp::header::Header { - sequence_number: test_case.sequence_number, - ..Default::default() - }, - payload: RTP_TEST_CASE_DECRYPTED.clone(), - }; - - let 
pkt_raw = pkt.marshal()?; - let out = encrypt_context.encrypt_rtp(&pkt_raw)?; - - let result = invalid_context.decrypt_rtp(&out); - assert!( - result.is_err(), - "Managed to decrypt with incorrect salt for packet with SeqNum: {}", - test_case.sequence_number - ); - } - - Ok(()) -} - -#[test] -fn test_rtp_lifecycle() -> Result<()> { - let mut encrypt_context = build_test_context()?; - let mut decrypt_context = build_test_context()?; - let auth_tag_len = ProtectionProfile::Aes128CmHmacSha1_80.auth_tag_len(); - - for test_case in RTP_TEST_CASES.iter() { - let decrypted_pkt = rtp::packet::Packet { - header: rtp::header::Header { - sequence_number: test_case.sequence_number, - ..Default::default() - }, - payload: RTP_TEST_CASE_DECRYPTED.clone(), - }; - - let decrypted_raw = decrypted_pkt.marshal()?; - - let encrypted_pkt = rtp::packet::Packet { - header: rtp::header::Header { - sequence_number: test_case.sequence_number, - ..Default::default() - }, - payload: test_case.encrypted.clone(), - }; - - let encrypted_raw = encrypted_pkt.marshal()?; - let actual_encrypted = encrypt_context.encrypt_rtp(&decrypted_raw)?; - assert_eq!( - actual_encrypted, encrypted_raw, - "RTP packet with SeqNum invalid encryption: {}", - test_case.sequence_number - ); - - let actual_decrypted = decrypt_context.decrypt_rtp(&encrypted_raw)?; - assert_ne!( - encrypted_raw[..encrypted_raw.len() - auth_tag_len].to_vec(), - actual_decrypted, - "DecryptRTP improperly encrypted in place" - ); - - assert_eq!( - actual_decrypted, decrypted_raw, - "RTP packet with SeqNum invalid decryption: {}", - test_case.sequence_number, - ) - } - - Ok(()) -} - -//TODO: BenchmarkEncryptRTP -//TODO: BenchmarkEncryptRTPInPlace -//TODO: BenchmarkDecryptRTP diff --git a/srtp/src/error.rs b/srtp/src/error.rs deleted file mode 100644 index b4f1a0346..000000000 --- a/srtp/src/error.rs +++ /dev/null @@ -1,120 +0,0 @@ -use std::io; - -use thiserror::Error; -use tokio::sync::mpsc::error::SendError as MpscSendError; - -pub type Result = std::result::Result; - -#[derive(Error, Debug, PartialEq)] -#[non_exhaustive] -pub enum Error { - #[error("duplicated packet")] - ErrDuplicated, - #[error("SRTP master key is not long enough")] - ErrShortSrtpMasterKey, - #[error("SRTP master salt is not long enough")] - ErrShortSrtpMasterSalt, - #[error("no such SRTP Profile")] - ErrNoSuchSrtpProfile, - #[error("indexOverKdr > 0 is not supported yet")] - ErrNonZeroKdrNotSupported, - #[error("exporter called with wrong label")] - ErrExporterWrongLabel, - #[error("no config provided")] - ErrNoConfig, - #[error("no conn provided")] - ErrNoConn, - #[error("failed to verify auth tag")] - ErrFailedToVerifyAuthTag, - #[error("packet is too short to be rtcp packet")] - ErrTooShortRtcp, - #[error("payload differs")] - ErrPayloadDiffers, - #[error("started channel used incorrectly, should only be closed")] - ErrStartedChannelUsedIncorrectly, - #[error("stream has not been inited, unable to close")] - ErrStreamNotInited, - #[error("stream is already closed")] - ErrStreamAlreadyClosed, - #[error("stream is already inited")] - ErrStreamAlreadyInited, - #[error("failed to cast child")] - ErrFailedTypeAssertion, - - #[error("index_over_kdr > 0 is not supported yet")] - UnsupportedIndexOverKdr, - #[error("SRTP Master Key must be len {0}, got {1}")] - SrtpMasterKeyLength(usize, usize), - #[error("SRTP Salt must be len {0}, got {1}")] - SrtpSaltLength(usize, usize), - #[error("SyntaxError: {0}")] - ExtMapParse(String), - #[error("srtp ssrc={0} index={1}: duplicated")] - 
SrtpSsrcDuplicated(u32, u16), - #[error("srtcp ssrc={0} index={1}: duplicated")] - SrtcpSsrcDuplicated(u32, usize), - #[error("ssrc {0} not exist in srtcp_ssrc_state")] - SsrcMissingFromSrtcp(u32), - #[error("Stream with ssrc {0} exists")] - StreamWithSsrcExists(u32), - #[error("Session RTP/RTCP type must be same as input buffer")] - SessionRtpRtcpTypeMismatch, - #[error("Session EOF")] - SessionEof, - #[error("too short SRTP packet: only {0} bytes, expected > {1} bytes")] - SrtpTooSmall(usize, usize), - #[error("too short SRTCP packet: only {0} bytes, expected > {1} bytes")] - SrtcpTooSmall(usize, usize), - #[error("failed to verify rtp auth tag")] - RtpFailedToVerifyAuthTag, - #[error("too short auth tag: only {0} bytes, expected > {1} bytes")] - RtcpInvalidLengthAuthTag(usize, usize), - #[error("failed to verify rtcp auth tag")] - RtcpFailedToVerifyAuthTag, - #[error("SessionSRTP has been closed")] - SessionSrtpAlreadyClosed, - #[error("this stream is not a RTPStream")] - InvalidRtpStream, - #[error("this stream is not a RTCPStream")] - InvalidRtcpStream, - - #[error("{0}")] - Io(#[source] IoError), - #[error("{0}")] - KeyingMaterial(#[from] util::KeyingMaterialExporterError), - #[error("mpsc send: {0}")] - MpscSend(String), - #[error("{0}")] - Util(#[from] util::Error), - #[error("{0}")] - Rtcp(#[from] rtcp::Error), - #[error("aes gcm: {0}")] - AesGcm(#[from] aes_gcm::Error), - - #[error("{0}")] - Other(String), -} - -#[derive(Debug, Error)] -#[error("io error: {0}")] -pub struct IoError(#[from] pub io::Error); - -// Workaround for wanting PartialEq for io::Error. -impl PartialEq for IoError { - fn eq(&self, other: &Self) -> bool { - self.0.kind() == other.0.kind() - } -} - -impl From for Error { - fn from(e: io::Error) -> Self { - Error::Io(IoError(e)) - } -} - -// Because Tokio SendError is parameterized, we sadly lose the backtrace. -impl From> for Error { - fn from(e: MpscSendError) -> Self { - Error::MpscSend(e.to_string()) - } -} diff --git a/srtp/src/key_derivation.rs b/srtp/src/key_derivation.rs deleted file mode 100644 index 343450937..000000000 --- a/srtp/src/key_derivation.rs +++ /dev/null @@ -1,173 +0,0 @@ -use aes::cipher::generic_array::GenericArray; -use aes::cipher::BlockEncrypt; -use aes::Aes128; -use aes_gcm::KeyInit; - -use crate::error::{Error, Result}; - -pub const LABEL_SRTP_ENCRYPTION: u8 = 0x00; -pub const LABEL_SRTP_AUTHENTICATION_TAG: u8 = 0x01; -pub const LABEL_SRTP_SALT: u8 = 0x02; -pub const LABEL_SRTCP_ENCRYPTION: u8 = 0x03; -pub const LABEL_SRTCP_AUTHENTICATION_TAG: u8 = 0x04; -pub const LABEL_SRTCP_SALT: u8 = 0x05; - -pub(crate) const SRTCP_INDEX_SIZE: usize = 4; - -pub(crate) fn aes_cm_key_derivation( - label: u8, - master_key: &[u8], - master_salt: &[u8], - index_over_kdr: usize, - out_len: usize, -) -> Result> { - if index_over_kdr != 0 { - // 24-bit "index DIV kdr" must be xored to prf input. - return Err(Error::UnsupportedIndexOverKdr); - } - - // https://tools.ietf.org/html/rfc3711#appendix-B.3 - // The input block for AES-CM is generated by exclusive-oring the master salt with the - // concatenation of the encryption key label 0x00 with (index DIV kdr), - // - index is 'rollover count' and DIV is 'divided by' - - let n_master_key = master_key.len(); - let n_master_salt = master_salt.len(); - - let mut prf_in = vec![0u8; n_master_key]; - prf_in[..n_master_salt].copy_from_slice(master_salt); - - prf_in[7] ^= label; - - //The resulting value is then AES encrypted using the master key to get the cipher key. 
- let key = GenericArray::from_slice(master_key); - let block = Aes128::new(key); - - let mut out = vec![0u8; ((out_len + n_master_key) / n_master_key) * n_master_key]; - for (i, n) in (0..out_len).step_by(n_master_key).enumerate() { - //BigEndian.PutUint16(prfIn[nMasterKey-2:], i) - prf_in[n_master_key - 2] = ((i >> 8) & 0xFF) as u8; - prf_in[n_master_key - 1] = (i & 0xFF) as u8; - - out[n..n + n_master_key].copy_from_slice(&prf_in); - let out_key = GenericArray::from_mut_slice(&mut out[n..n + n_master_key]); - block.encrypt_block(out_key); - } - - Ok(out[..out_len].to_vec()) -} - -/// Generate IV https://tools.ietf.org/html/rfc3711#section-4.1.1 -/// where the 128-bit integer value IV SHALL be defined by the SSRC, the -/// SRTP packet index i, and the SRTP session salting key k_s, as below. -/// ROC = a 32-bit unsigned rollover counter (roc), which records how many -/// times the 16-bit RTP sequence number has been reset to zero after -/// passing through 65,535 -/// ```nobuild -/// i = 2^16 * roc + SEQ -/// IV = (salt*2 ^ 16) | (ssrc*2 ^ 64) | (i*2 ^ 16) -/// ``` -pub(crate) fn generate_counter( - sequence_number: u16, - rollover_counter: u32, - ssrc: u32, - session_salt: &[u8], -) -> [u8; 16] { - assert!(session_salt.len() <= 16); - - let mut counter = [0; 16]; - - let ssrc_be = ssrc.to_be_bytes(); - let rollover_be = rollover_counter.to_be_bytes(); - let seq_be = ((sequence_number as u32) << 16).to_be_bytes(); - - counter[4..8].copy_from_slice(&ssrc_be); - counter[8..12].copy_from_slice(&rollover_be); - counter[12..16].copy_from_slice(&seq_be); - - for i in 0..session_salt.len() { - counter[i] ^= session_salt[i]; - } - - counter -} - -#[cfg(test)] -mod test { - use super::*; - use crate::protection_profile::*; - - #[test] - fn test_valid_session_keys() -> Result<()> { - // Key Derivation Test Vectors from https://tools.ietf.org/html/rfc3711#appendix-B.3 - let master_key = vec![ - 0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE, - 0x41, 0x39, - ]; - let master_salt = vec![ - 0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6, - ]; - - let expected_session_key = vec![ - 0xC6, 0x1E, 0x7A, 0x93, 0x74, 0x4F, 0x39, 0xEE, 0x10, 0x73, 0x4A, 0xFE, 0x3F, 0xF7, - 0xA0, 0x87, - ]; - let expected_session_salt = vec![ - 0x30, 0xCB, 0xBC, 0x08, 0x86, 0x3D, 0x8C, 0x85, 0xD4, 0x9D, 0xB3, 0x4A, 0x9A, 0xE1, - ]; - let expected_session_auth_tag = vec![ - 0xCE, 0xBE, 0x32, 0x1F, 0x6F, 0xF7, 0x71, 0x6B, 0x6F, 0xD4, 0xAB, 0x49, 0xAF, 0x25, - 0x6A, 0x15, 0x6D, 0x38, 0xBA, 0xA4, - ]; - - let session_key = aes_cm_key_derivation( - LABEL_SRTP_ENCRYPTION, - &master_key, - &master_salt, - 0, - master_key.len(), - )?; - assert_eq!( - session_key, expected_session_key, - "Session Key:\n{session_key:?} \ndoes not match expected:\n{expected_session_key:?}\nMaster Key:\n{master_key:?}\nMaster Salt:\n{master_salt:?}\n", - ); - - let session_salt = aes_cm_key_derivation( - LABEL_SRTP_SALT, - &master_key, - &master_salt, - 0, - master_salt.len(), - )?; - assert_eq!( - session_salt, expected_session_salt, - "Session Salt {session_salt:?} does not match expected {expected_session_salt:?}" - ); - - let auth_key_len = ProtectionProfile::Aes128CmHmacSha1_80.auth_key_len(); - - let session_auth_tag = aes_cm_key_derivation( - LABEL_SRTP_AUTHENTICATION_TAG, - &master_key, - &master_salt, - 0, - auth_key_len, - )?; - assert_eq!( - session_auth_tag, expected_session_auth_tag, - "Session Auth Tag {session_auth_tag:?} does not match expected 
{expected_session_auth_tag:?}", - ); - - Ok(()) - } - - // This test asserts that calling aesCmKeyDerivation with a non-zero indexOverKdr fails - // Currently this isn't supported, but the API makes sure we can add this in the future - #[test] - fn test_index_over_kdr() -> Result<()> { - let result = aes_cm_key_derivation(LABEL_SRTP_AUTHENTICATION_TAG, &[], &[], 1, 0); - assert!(result.is_err()); - - Ok(()) - } -} diff --git a/srtp/src/lib.rs b/srtp/src/lib.rs deleted file mode 100644 index 044042eb2..000000000 --- a/srtp/src/lib.rs +++ /dev/null @@ -1,14 +0,0 @@ -#![warn(rust_2018_idioms)] -#![allow(dead_code)] - -mod cipher; -pub mod config; -pub mod context; -mod error; -mod key_derivation; -pub mod option; -pub mod protection_profile; -pub mod session; -pub mod stream; - -pub use error::Error; diff --git a/srtp/src/option.rs b/srtp/src/option.rs deleted file mode 100644 index d2513a866..000000000 --- a/srtp/src/option.rs +++ /dev/null @@ -1,36 +0,0 @@ -use util::replay_detector::*; - -pub type ContextOption = Box Box) + Send + Sync>; - -pub(crate) const MAX_SEQUENCE_NUMBER: u16 = 65535; -pub(crate) const MAX_SRTCP_INDEX: usize = 0x7FFFFFFF; - -/// srtp_replay_protection sets SRTP replay protection window size. -pub fn srtp_replay_protection(window_size: usize) -> ContextOption { - Box::new(move || -> Box { - Box::new(WrappedSlidingWindowDetector::new( - window_size, - MAX_SEQUENCE_NUMBER as u64, - )) - }) -} - -/// Sets SRTCP replay protection window size. -pub fn srtcp_replay_protection(window_size: usize) -> ContextOption { - Box::new(move || -> Box { - Box::new(WrappedSlidingWindowDetector::new( - window_size, - MAX_SRTCP_INDEX as u64, - )) - }) -} - -/// srtp_no_replay_protection disables SRTP replay protection. -pub fn srtp_no_replay_protection() -> ContextOption { - Box::new(|| -> Box { Box::::default() }) -} - -/// srtcp_no_replay_protection disables SRTCP replay protection. 
-pub fn srtcp_no_replay_protection() -> ContextOption { - Box::new(|| -> Box { Box::::default() }) -} diff --git a/srtp/src/protection_profile.rs b/srtp/src/protection_profile.rs deleted file mode 100644 index 0991ea737..000000000 --- a/srtp/src/protection_profile.rs +++ /dev/null @@ -1,37 +0,0 @@ -/// ProtectionProfile specifies Cipher and AuthTag details, similar to TLS cipher suite -#[derive(Default, Debug, Clone, Copy)] -#[repr(u8)] -pub enum ProtectionProfile { - #[default] - Aes128CmHmacSha1_80 = 0x0001, - AeadAes128Gcm = 0x0007, -} - -impl ProtectionProfile { - pub(crate) fn key_len(&self) -> usize { - match *self { - ProtectionProfile::Aes128CmHmacSha1_80 | ProtectionProfile::AeadAes128Gcm => 16, - } - } - - pub(crate) fn salt_len(&self) -> usize { - match *self { - ProtectionProfile::Aes128CmHmacSha1_80 => 14, - ProtectionProfile::AeadAes128Gcm => 12, - } - } - - pub(crate) fn auth_tag_len(&self) -> usize { - match *self { - ProtectionProfile::Aes128CmHmacSha1_80 => 10, //CIPHER_AES_CM_HMAC_SHA1AUTH_TAG_LEN, - ProtectionProfile::AeadAes128Gcm => 16, //CIPHER_AEAD_AES_GCM_AUTH_TAG_LEN, - } - } - - pub(crate) fn auth_key_len(&self) -> usize { - match *self { - ProtectionProfile::Aes128CmHmacSha1_80 => 20, - ProtectionProfile::AeadAes128Gcm => 0, - } - } -} diff --git a/srtp/src/session/mod.rs b/srtp/src/session/mod.rs deleted file mode 100644 index efee1bef8..000000000 --- a/srtp/src/session/mod.rs +++ /dev/null @@ -1,271 +0,0 @@ -#[cfg(test)] -mod session_rtcp_test; -#[cfg(test)] -mod session_rtp_test; - -use std::collections::{HashMap, HashSet}; -use std::marker::{Send, Sync}; -use std::sync::Arc; - -use bytes::Bytes; -use tokio::sync::{mpsc, Mutex}; -use util::conn::Conn; -use util::marshal::*; - -use crate::config::*; -use crate::context::*; -use crate::error::{Error, Result}; -use crate::option::*; -use crate::stream::*; - -const DEFAULT_SESSION_SRTP_REPLAY_PROTECTION_WINDOW: usize = 64; -const DEFAULT_SESSION_SRTCP_REPLAY_PROTECTION_WINDOW: usize = 64; - -/// Session implements io.ReadWriteCloser and provides a bi-directional SRTP session -/// SRTP itself does not have a design like this, but it is common in most applications -/// for local/remote to each have their own keying material. 
This provides those patterns -/// instead of making everyone re-implement -pub struct Session { - local_context: Arc>, - streams_map: Arc>>>, - new_stream_rx: Arc>>>, - close_stream_tx: mpsc::Sender, - close_session_tx: mpsc::Sender<()>, - pub(crate) udp_tx: Arc, - is_rtp: bool, -} - -impl Session { - pub async fn new( - conn: Arc, - config: Config, - is_rtp: bool, - ) -> Result { - let local_context = Context::new( - &config.keys.local_master_key, - &config.keys.local_master_salt, - config.profile, - config.local_rtp_options, - config.local_rtcp_options, - )?; - - let mut remote_context = Context::new( - &config.keys.remote_master_key, - &config.keys.remote_master_salt, - config.profile, - if config.remote_rtp_options.is_none() { - Some(srtp_replay_protection( - DEFAULT_SESSION_SRTP_REPLAY_PROTECTION_WINDOW, - )) - } else { - config.remote_rtp_options - }, - if config.remote_rtcp_options.is_none() { - Some(srtcp_replay_protection( - DEFAULT_SESSION_SRTCP_REPLAY_PROTECTION_WINDOW, - )) - } else { - config.remote_rtcp_options - }, - )?; - - let streams_map = Arc::new(Mutex::new(HashMap::new())); - let (mut new_stream_tx, new_stream_rx) = mpsc::channel(8); - let (close_stream_tx, mut close_stream_rx) = mpsc::channel(8); - let (close_session_tx, mut close_session_rx) = mpsc::channel(8); - let udp_tx = Arc::clone(&conn); - let udp_rx = Arc::clone(&conn); - let cloned_streams_map = Arc::clone(&streams_map); - let cloned_close_stream_tx = close_stream_tx.clone(); - - tokio::spawn(async move { - let mut buf = vec![0u8; 8192]; - - loop { - let incoming_stream = Session::incoming( - &udp_rx, - &mut buf, - &cloned_streams_map, - &cloned_close_stream_tx, - &mut new_stream_tx, - &mut remote_context, - is_rtp, - ); - let close_stream = close_stream_rx.recv(); - let close_session = close_session_rx.recv(); - - tokio::select! { - result = incoming_stream => match result{ - Ok(()) => {}, - Err(err) => log::info!("{}", err), - }, - opt = close_stream => if let Some(ssrc) = opt { - Session::close_stream(&cloned_streams_map, ssrc).await - }, - _ = close_session => break - } - } - }); - - Ok(Session { - local_context: Arc::new(Mutex::new(local_context)), - streams_map, - new_stream_rx: Arc::new(Mutex::new(new_stream_rx)), - close_stream_tx, - close_session_tx, - udp_tx, - is_rtp, - }) - } - - async fn close_stream(streams_map: &Arc>>>, ssrc: u32) { - let mut streams = streams_map.lock().await; - streams.remove(&ssrc); - } - - async fn incoming( - udp_rx: &Arc, - buf: &mut [u8], - streams_map: &Arc>>>, - close_stream_tx: &mpsc::Sender, - new_stream_tx: &mut mpsc::Sender>, - remote_context: &mut Context, - is_rtp: bool, - ) -> Result<()> { - let n = udp_rx.recv(buf).await?; - if n == 0 { - return Err(Error::SessionEof); - } - - let decrypted = if is_rtp { - remote_context.decrypt_rtp(&buf[0..n])? - } else { - remote_context.decrypt_rtcp(&buf[0..n])? - }; - - let mut buf = &decrypted[..]; - let ssrcs = if is_rtp { - vec![rtp::header::Header::unmarshal(&mut buf)?.ssrc] - } else { - let pkts = rtcp::packet::unmarshal(&mut buf)?; - destination_ssrc(&pkts) - }; - - for ssrc in ssrcs { - let (stream, is_new) = - Session::get_or_create_stream(streams_map, close_stream_tx.clone(), is_rtp, ssrc) - .await; - if is_new { - log::trace!( - "srtp session got new {} stream {}", - if is_rtp { "rtp" } else { "rtcp" }, - ssrc - ); - new_stream_tx.send(Arc::clone(&stream)).await?; - } - - match stream.buffer.write(&decrypted).await { - Ok(_) => {} - Err(err) => { - // Silently drop data when the buffer is full. 
- if util::Error::ErrBufferFull != err { - return Err(err.into()); - } - } - } - } - - Ok(()) - } - - async fn get_or_create_stream( - streams_map: &Arc>>>, - close_stream_tx: mpsc::Sender, - is_rtp: bool, - ssrc: u32, - ) -> (Arc, bool) { - let mut streams = streams_map.lock().await; - - if let Some(stream) = streams.get(&ssrc) { - (Arc::clone(stream), false) - } else { - let stream = Arc::new(Stream::new(ssrc, close_stream_tx, is_rtp)); - streams.insert(ssrc, Arc::clone(&stream)); - (stream, true) - } - } - - /// open on the given SSRC to create a stream, it can be used - /// if you want a certain SSRC, but don't want to wait for Accept - pub async fn open(&self, ssrc: u32) -> Arc { - let (stream, _) = Session::get_or_create_stream( - &self.streams_map, - self.close_stream_tx.clone(), - self.is_rtp, - ssrc, - ) - .await; - - stream - } - - /// accept returns a stream to handle RTCP for a single SSRC - pub async fn accept(&self) -> Result> { - let mut new_stream_rx = self.new_stream_rx.lock().await; - let result = new_stream_rx.recv().await; - if let Some(stream) = result { - Ok(stream) - } else { - Err(Error::SessionSrtpAlreadyClosed) - } - } - - pub async fn close(&self) -> Result<()> { - self.close_session_tx.send(()).await?; - - Ok(()) - } - - pub async fn write(&self, buf: &Bytes, is_rtp: bool) -> Result { - if self.is_rtp != is_rtp { - return Err(Error::SessionRtpRtcpTypeMismatch); - } - - let encrypted = { - let mut local_context = self.local_context.lock().await; - - if is_rtp { - local_context.encrypt_rtp(buf)? - } else { - local_context.encrypt_rtcp(buf)? - } - }; - - Ok(self.udp_tx.send(&encrypted).await?) - } - - pub async fn write_rtp(&self, pkt: &rtp::packet::Packet) -> Result { - let raw = pkt.marshal()?; - self.write(&raw, true).await - } - - pub async fn write_rtcp( - &self, - pkt: &(dyn rtcp::packet::Packet + Send + Sync), - ) -> Result { - let raw = pkt.marshal()?; - self.write(&raw, false).await - } -} - -/// create a list of Destination SSRCs -/// that's a superset of all Destinations in the slice. 
-fn destination_ssrc(pkts: &[Box]) -> Vec { - let mut ssrc_set = HashSet::new(); - for p in pkts { - for ssrc in p.destination_ssrc() { - ssrc_set.insert(ssrc); - } - } - ssrc_set.into_iter().collect() -} diff --git a/srtp/src/session/session_rtcp_test.rs b/srtp/src/session/session_rtcp_test.rs deleted file mode 100644 index 8d880e9e0..000000000 --- a/srtp/src/session/session_rtcp_test.rs +++ /dev/null @@ -1,239 +0,0 @@ -use std::sync::Arc; - -use bytes::{Bytes, BytesMut}; -use rtcp::payload_feedbacks::*; -use tokio::sync::{mpsc, Mutex}; -use util::conn::conn_pipe::*; - -use super::*; -use crate::error::Result; -use crate::protection_profile::*; - -async fn build_session_srtcp_pair() -> Result<(Session, Session)> { - let (ua, ub) = pipe(); - - let ca = Config { - profile: ProtectionProfile::Aes128CmHmacSha1_80, - keys: SessionKeys { - local_master_key: vec![ - 0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE, - 0x41, 0x39, - ], - local_master_salt: vec![ - 0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6, - ], - remote_master_key: vec![ - 0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE, - 0x41, 0x39, - ], - remote_master_salt: vec![ - 0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6, - ], - }, - - local_rtp_options: None, - remote_rtp_options: None, - - local_rtcp_options: None, - remote_rtcp_options: None, - }; - - let cb = Config { - profile: ProtectionProfile::Aes128CmHmacSha1_80, - keys: SessionKeys { - local_master_key: vec![ - 0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE, - 0x41, 0x39, - ], - local_master_salt: vec![ - 0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6, - ], - remote_master_key: vec![ - 0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE, - 0x41, 0x39, - ], - remote_master_salt: vec![ - 0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6, - ], - }, - - local_rtp_options: None, - remote_rtp_options: None, - - local_rtcp_options: None, - remote_rtcp_options: None, - }; - - let sa = Session::new(Arc::new(ua), ca, false).await?; - let sb = Session::new(Arc::new(ub), cb, false).await?; - - Ok((sa, sb)) -} - -const TEST_SSRC: u32 = 5000; - -#[tokio::test] -async fn test_session_srtcp_accept() -> Result<()> { - let (sa, sb) = build_session_srtcp_pair().await?; - - let rtcp_packet = picture_loss_indication::PictureLossIndication { - media_ssrc: TEST_SSRC, - ..Default::default() - }; - - let test_payload = rtcp_packet.marshal()?; - sa.write_rtcp(&rtcp_packet).await?; - - let read_stream = sb.accept().await?; - let ssrc = read_stream.get_ssrc(); - assert_eq!( - ssrc, TEST_SSRC, - "SSRC mismatch during accept exp({TEST_SSRC}) actual({ssrc})" - ); - - let mut read_buffer = BytesMut::with_capacity(test_payload.len()); - read_buffer.resize(test_payload.len(), 0u8); - read_stream.read(&mut read_buffer).await?; - - assert_eq!( - &test_payload[..], - &read_buffer[..], - "Sent buffer does not match the one received exp({:?}) actual({:?})", - &test_payload[..], - &read_buffer[..] 
- ); - - sa.close().await?; - sb.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_session_srtcp_listen() -> Result<()> { - let (sa, sb) = build_session_srtcp_pair().await?; - - let rtcp_packet = picture_loss_indication::PictureLossIndication { - media_ssrc: TEST_SSRC, - ..Default::default() - }; - - let test_payload = rtcp_packet.marshal()?; - let read_stream = sb.open(TEST_SSRC).await; - - sa.write_rtcp(&rtcp_packet).await?; - - let mut read_buffer = BytesMut::with_capacity(test_payload.len()); - read_buffer.resize(test_payload.len(), 0u8); - read_stream.read(&mut read_buffer).await?; - - assert_eq!( - &test_payload[..], - &read_buffer[..], - "Sent buffer does not match the one received exp({:?}) actual({:?})", - &test_payload[..], - &read_buffer[..] - ); - - sa.close().await?; - sb.close().await?; - - Ok(()) -} - -fn encrypt_srtcp( - context: &mut Context, - pkt: &(dyn rtcp::packet::Packet + Send + Sync), -) -> Result { - let decrypted = pkt.marshal()?; - let encrypted = context.encrypt_rtcp(&decrypted)?; - Ok(encrypted) -} - -const PLI_PACKET_SIZE: usize = 8; - -async fn get_sender_ssrc(read_stream: &Arc) -> Result { - let auth_tag_size = ProtectionProfile::Aes128CmHmacSha1_80.auth_tag_len(); - - let mut read_buffer = BytesMut::with_capacity(PLI_PACKET_SIZE + auth_tag_size); - read_buffer.resize(PLI_PACKET_SIZE + auth_tag_size, 0u8); - - let pkts = read_stream.read_rtcp(&mut read_buffer).await?; - let mut bytes = &pkts[0].marshal()?[..]; - let pli = picture_loss_indication::PictureLossIndication::unmarshal(&mut bytes)?; - - Ok(pli.sender_ssrc) -} - -#[tokio::test] -async fn test_session_srtcp_replay_protection() -> Result<()> { - let (sa, sb) = build_session_srtcp_pair().await?; - - let read_stream = sb.open(TEST_SSRC).await; - - // Generate test packets - let mut packets = vec![]; - let mut expected_ssrc = vec![]; - { - let mut local_context = sa.local_context.lock().await; - for i in 0..0x10u32 { - expected_ssrc.push(i); - - let packet = picture_loss_indication::PictureLossIndication { - media_ssrc: TEST_SSRC, - sender_ssrc: i, - }; - - let encrypted = encrypt_srtcp(&mut local_context, &packet)?; - - packets.push(encrypted); - } - } - - let (done_tx, mut done_rx) = mpsc::channel::<()>(1); - - let received_ssrc = Arc::new(Mutex::new(vec![])); - let cloned_received_ssrc = Arc::clone(&received_ssrc); - let count = expected_ssrc.len(); - - tokio::spawn(async move { - let mut i = 0; - while i < count { - match get_sender_ssrc(&read_stream).await { - Ok(ssrc) => { - let mut r = cloned_received_ssrc.lock().await; - r.push(ssrc); - - i += 1; - } - Err(_) => break, - } - } - - drop(done_tx); - }); - - // Write with replay attack - for packet in &packets { - sa.udp_tx.send(packet).await?; - - // Immediately replay - sa.udp_tx.send(packet).await?; - } - for packet in &packets { - // Delayed replay - sa.udp_tx.send(packet).await?; - } - - done_rx.recv().await; - - sa.close().await?; - sb.close().await?; - - { - let received_ssrc = received_ssrc.lock().await; - assert_eq!(&expected_ssrc[..], &received_ssrc[..]); - } - - Ok(()) -} diff --git a/srtp/src/session/session_rtp_test.rs b/srtp/src/session/session_rtp_test.rs deleted file mode 100644 index 5764d1346..000000000 --- a/srtp/src/session/session_rtp_test.rs +++ /dev/null @@ -1,308 +0,0 @@ -use std::collections::HashMap; -use std::sync::Arc; - -use bytes::{Bytes, BytesMut}; -use tokio::net::UdpSocket; -use tokio::sync::{mpsc, Mutex}; - -use super::*; -use crate::error::Result; -use crate::protection_profile::*; - -async fn 
build_session_srtp_pair() -> Result<(Session, Session)> { - let ua = UdpSocket::bind("127.0.0.1:0").await?; - let ub = UdpSocket::bind("127.0.0.1:0").await?; - - ua.connect(ub.local_addr()?).await?; - ub.connect(ua.local_addr()?).await?; - - let ca = Config { - profile: ProtectionProfile::Aes128CmHmacSha1_80, - keys: SessionKeys { - local_master_key: vec![ - 0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE, - 0x41, 0x39, - ], - local_master_salt: vec![ - 0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6, - ], - remote_master_key: vec![ - 0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE, - 0x41, 0x39, - ], - remote_master_salt: vec![ - 0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6, - ], - }, - - local_rtp_options: None, - remote_rtp_options: None, - - local_rtcp_options: None, - remote_rtcp_options: None, - }; - - let cb = Config { - profile: ProtectionProfile::Aes128CmHmacSha1_80, - keys: SessionKeys { - local_master_key: vec![ - 0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE, - 0x41, 0x39, - ], - local_master_salt: vec![ - 0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6, - ], - remote_master_key: vec![ - 0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE, - 0x41, 0x39, - ], - remote_master_salt: vec![ - 0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6, - ], - }, - - local_rtp_options: None, - remote_rtp_options: None, - - local_rtcp_options: None, - remote_rtcp_options: None, - }; - - let sa = Session::new(Arc::new(ua), ca, true).await?; - let sb = Session::new(Arc::new(ub), cb, true).await?; - - Ok((sa, sb)) -} - -const TEST_SSRC: u32 = 5000; -const RTP_HEADER_SIZE: usize = 12; - -#[tokio::test] -async fn test_session_srtp_accept() -> Result<()> { - let test_payload = Bytes::from_static(&[0x00, 0x01, 0x03, 0x04]); - let mut read_buffer = BytesMut::with_capacity(RTP_HEADER_SIZE + test_payload.len()); - read_buffer.resize(RTP_HEADER_SIZE + test_payload.len(), 0u8); - let (sa, sb) = build_session_srtp_pair().await?; - - let packet = rtp::packet::Packet { - header: rtp::header::Header { - ssrc: TEST_SSRC, - ..Default::default() - }, - payload: test_payload.clone(), - }; - sa.write_rtp(&packet).await?; - - let read_stream = sb.accept().await?; - let ssrc = read_stream.get_ssrc(); - assert_eq!( - ssrc, TEST_SSRC, - "SSRC mismatch during accept exp({TEST_SSRC}) actual({ssrc})" - ); - - read_stream.read(&mut read_buffer).await?; - - assert_eq!( - &test_payload[..], - &read_buffer[RTP_HEADER_SIZE..], - "Sent buffer does not match the one received exp({:?}) actual({:?})", - &test_payload[..], - &read_buffer[RTP_HEADER_SIZE..] 
- ); - - sa.close().await?; - sb.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_session_srtp_listen() -> Result<()> { - let test_payload = Bytes::from_static(&[0x00, 0x01, 0x03, 0x04]); - let mut read_buffer = BytesMut::with_capacity(RTP_HEADER_SIZE + test_payload.len()); - read_buffer.resize(RTP_HEADER_SIZE + test_payload.len(), 0u8); - let (sa, sb) = build_session_srtp_pair().await?; - - let packet = rtp::packet::Packet { - header: rtp::header::Header { - ssrc: TEST_SSRC, - ..Default::default() - }, - payload: test_payload.clone(), - }; - - let read_stream = sb.open(TEST_SSRC).await; - - sa.write_rtp(&packet).await?; - - read_stream.read(&mut read_buffer).await?; - - assert_eq!( - &test_payload[..], - &read_buffer[RTP_HEADER_SIZE..], - "Sent buffer does not match the one received exp({:?}) actual({:?})", - &test_payload[..], - &read_buffer[RTP_HEADER_SIZE..] - ); - - sa.close().await?; - sb.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_session_srtp_multi_ssrc() -> Result<()> { - let ssrcs = vec![5000, 5001, 5002]; - let test_payload = Bytes::from_static(&[0x00, 0x01, 0x03, 0x04]); - let mut read_buffer = BytesMut::with_capacity(RTP_HEADER_SIZE + test_payload.len()); - read_buffer.resize(RTP_HEADER_SIZE + test_payload.len(), 0u8); - let (sa, sb) = build_session_srtp_pair().await?; - - let mut read_streams = HashMap::new(); - for ssrc in &ssrcs { - let read_stream = sb.open(*ssrc).await; - read_streams.insert(*ssrc, read_stream); - } - - for ssrc in &ssrcs { - let packet = rtp::packet::Packet { - header: rtp::header::Header { - ssrc: *ssrc, - ..Default::default() - }, - payload: test_payload.clone(), - }; - sa.write_rtp(&packet).await?; - - if let Some(read_stream) = read_streams.get_mut(ssrc) { - read_stream.read(&mut read_buffer).await?; - - assert_eq!( - &test_payload[..], - &read_buffer[RTP_HEADER_SIZE..], - "Sent buffer does not match the one received exp({:?}) actual({:?})", - &test_payload[..], - &read_buffer[RTP_HEADER_SIZE..] - ); - } else { - panic!("ssrc {} not found", *ssrc); - } - } - - sa.close().await?; - sb.close().await?; - - Ok(()) -} - -fn encrypt_srtp(context: &mut Context, pkt: &rtp::packet::Packet) -> Result { - let decrypted = pkt.marshal()?; - let encrypted = context.encrypt_rtp(&decrypted)?; - Ok(encrypted) -} - -async fn payload_srtp( - read_stream: &Arc, - header_size: usize, - expected_payload: &[u8], -) -> Result { - let mut read_buffer = BytesMut::with_capacity(header_size + expected_payload.len()); - read_buffer.resize(header_size + expected_payload.len(), 0u8); - - let pkt = read_stream.read_rtp(&mut read_buffer).await?; - - assert_eq!( - expected_payload, - &pkt.payload[..], - "Sent buffer does not match the one received exp({:?}) actual({:?})", - expected_payload, - &pkt.payload[..] 
- ); - - Ok(pkt.header.sequence_number) -} - -#[tokio::test] -async fn test_session_srtp_replay_protection() -> Result<()> { - let test_payload = Bytes::from_static(&[0x00, 0x01, 0x03, 0x04]); - - let (sa, sb) = build_session_srtp_pair().await?; - - let read_stream = sb.open(TEST_SSRC).await; - - // Generate test packets - let mut packets = vec![]; - let mut expected_sequence_number = vec![]; - { - let mut local_context = sa.local_context.lock().await; - let mut i = 0xFFF0u16; - while i != 0x10 { - expected_sequence_number.push(i); - - let packet = rtp::packet::Packet { - header: rtp::header::Header { - ssrc: TEST_SSRC, - sequence_number: i, - ..Default::default() - }, - payload: test_payload.clone(), - }; - - let encrypted = encrypt_srtp(&mut local_context, &packet)?; - - packets.push(encrypted); - - if i == 0xFFFF { - i = 0; - } else { - i += 1; - } - } - } - - let (done_tx, mut done_rx) = mpsc::channel::<()>(1); - - let received_sequence_number = Arc::new(Mutex::new(vec![])); - let cloned_received_sequence_number = Arc::clone(&received_sequence_number); - let count = expected_sequence_number.len(); - - tokio::spawn(async move { - let mut i = 0; - while i < count { - let seq = payload_srtp(&read_stream, RTP_HEADER_SIZE, &test_payload) - .await - .unwrap(); - let mut r = cloned_received_sequence_number.lock().await; - r.push(seq); - - i += 1; - } - - drop(done_tx); - }); - - // Write with replay attack - for packet in &packets { - sa.udp_tx.send(packet).await?; - - // Immediately replay - sa.udp_tx.send(packet).await?; - } - for packet in &packets { - // Delayed replay - sa.udp_tx.send(packet).await?; - } - - done_rx.recv().await; - - sa.close().await?; - sb.close().await?; - - { - let received_sequence_number = received_sequence_number.lock().await; - assert_eq!(&received_sequence_number[..], &expected_sequence_number[..]); - } - - Ok(()) -} diff --git a/srtp/src/stream.rs b/srtp/src/stream.rs deleted file mode 100644 index 735f12cf3..000000000 --- a/srtp/src/stream.rs +++ /dev/null @@ -1,91 +0,0 @@ -use tokio::sync::mpsc; -use util::marshal::*; -use util::Buffer; - -use crate::error::{Error, Result}; - -/// Limit the buffer size to 1MB -pub const SRTP_BUFFER_SIZE: usize = 1000 * 1000; - -/// Limit the buffer size to 100KB -pub const SRTCP_BUFFER_SIZE: usize = 100 * 1000; - -/// Stream handles decryption for a single RTP/RTCP SSRC -#[derive(Debug)] -pub struct Stream { - ssrc: u32, - tx: mpsc::Sender, - pub(crate) buffer: Buffer, - is_rtp: bool, -} - -impl Stream { - /// Create a new stream - pub fn new(ssrc: u32, tx: mpsc::Sender, is_rtp: bool) -> Self { - Stream { - ssrc, - tx, - // Create a buffer with a 1MB limit - buffer: Buffer::new( - 0, - if is_rtp { - SRTP_BUFFER_SIZE - } else { - SRTCP_BUFFER_SIZE - }, - ), - is_rtp, - } - } - - /// GetSSRC returns the SSRC we are demuxing for - pub fn get_ssrc(&self) -> u32 { - self.ssrc - } - - /// Check if RTP is a stream. - pub fn is_rtp_stream(&self) -> bool { - self.is_rtp - } - - /// Read reads and decrypts full RTP packet from the nextConn - pub async fn read(&self, buf: &mut [u8]) -> Result { - Ok(self.buffer.read(buf, None).await?) 
- } - - /// ReadRTP reads and decrypts full RTP packet and its header from the nextConn - pub async fn read_rtp(&self, buf: &mut [u8]) -> Result { - if !self.is_rtp { - return Err(Error::InvalidRtpStream); - } - - let n = self.buffer.read(buf, None).await?; - let mut b = &buf[..n]; - let pkt = rtp::packet::Packet::unmarshal(&mut b)?; - - Ok(pkt) - } - - /// read_rtcp reads and decrypts full RTP packet and its header from the nextConn - pub async fn read_rtcp( - &self, - buf: &mut [u8], - ) -> Result>> { - if self.is_rtp { - return Err(Error::InvalidRtcpStream); - } - - let n = self.buffer.read(buf, None).await?; - let mut b = &buf[..n]; - let pkt = rtcp::packet::unmarshal(&mut b)?; - - Ok(pkt) - } - - /// Close removes the ReadStream from the session and cleans up any associated state - pub async fn close(&self) -> Result<()> { - self.buffer.close().await; - let _ = self.tx.send(self.ssrc).await; - Ok(()) - } -} diff --git a/stun/.gitignore b/stun/.gitignore deleted file mode 100644 index 81561ed32..000000000 --- a/stun/.gitignore +++ /dev/null @@ -1,11 +0,0 @@ -# Generated by Cargo -# will have compiled files and executables -/target/ -/.idea/ - -# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries -# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html -Cargo.lock - -# These are backup files generated by rustfmt -**/*.rs.bk diff --git a/stun/CHANGELOG.md b/stun/CHANGELOG.md deleted file mode 100644 index 2a6eaf605..000000000 --- a/stun/CHANGELOG.md +++ /dev/null @@ -1,16 +0,0 @@ -# webrtc-stun changelog - -## Unreleased - -## v0.4.4 - -* Increased minimum support rust version to `1.60.0`. -* Increased required `webrtc-util` version to `0.7.0`. - -## v0.4.3 - -* [#9 update deps + loosen some requirements](https://github.com/webrtc-rs/stun/pull/9) by [@melekes](https://github.com/melekes). - -## Prior to 0.4.3 - -Before 0.4.3 there was no changelog, previous changes are sometimes, but not always, available in the [GitHub Releases](https://github.com/webrtc-rs/stun/releases). 
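For reference, the srtp `Context` removed above is a one-way object: one instance protects outbound packets while a second instance, built from the same master key and salt, unprotects inbound ones. Below is a minimal sketch of that round trip, mirroring the deleted `srtp_test.rs` cases; the crate aliases (`srtp`, `rtp`, `util`, `bytes`) and the all-zero keying material are assumptions for illustration, not values from this repository.

```rust
// Sketch only: crate aliases and the all-zero key/salt are placeholders.
use bytes::Bytes;
use util::marshal::*;

use srtp::context::Context;
use srtp::protection_profile::ProtectionProfile;

fn srtp_round_trip() -> Result<(), srtp::Error> {
    // Aes128CmHmacSha1_80 expects a 16-byte master key and a 14-byte master salt.
    let master_key = [0u8; 16];
    let master_salt = [0u8; 14];

    // A Context is one-way, so the sender and receiver each build their own.
    let mut sender = Context::new(
        &master_key,
        &master_salt,
        ProtectionProfile::Aes128CmHmacSha1_80,
        None,
        None,
    )?;
    let mut receiver = Context::new(
        &master_key,
        &master_salt,
        ProtectionProfile::Aes128CmHmacSha1_80,
        None,
        None,
    )?;

    // encrypt_rtp/decrypt_rtp operate on a fully marshaled RTP packet.
    let packet = rtp::packet::Packet {
        header: rtp::header::Header {
            sequence_number: 5000,
            ..Default::default()
        },
        payload: Bytes::from_static(&[0x00, 0x01, 0x02, 0x03]),
    };
    let plaintext = packet.marshal()?;
    let protected = sender.encrypt_rtp(&plaintext)?;
    let recovered = receiver.decrypt_rtp(&protected)?;
    assert_eq!(plaintext, recovered);
    Ok(())
}
```

Passing `None` for the two context options falls back to `srtp_no_replay_protection()` and `srtcp_no_replay_protection()`, as shown in the deleted `context/mod.rs`.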
diff --git a/stun/Cargo.toml b/stun/Cargo.toml deleted file mode 100644 index efe071965..000000000 --- a/stun/Cargo.toml +++ /dev/null @@ -1,59 +0,0 @@ -[package] -name = "stun" -version = "0.6.0" -authors = ["Rain Liu "] -edition = "2021" -description = "A pure Rust implementation of STUN" -license = "MIT OR Apache-2.0" -documentation = "https://docs.rs/stun" -homepage = "https://webrtc.rs" -repository = "https://github.com/webrtc-rs/webrtc/tree/master/stun" - -[features] -default = [] -bench = [] - -[dependencies] -util = { version = "0.9.0", path = "../util", package = "webrtc-util", default-features = false, features = ["conn"] } - -tokio = { version = "1.32.0", features = [ - "fs", - "io-util", - "io-std", - "macros", - "net", - "parking_lot", - "rt", - "rt-multi-thread", - "sync", - "time", -] } -lazy_static = "1" -url = "2" -rand = "0.8" -base64 = "0.21" -subtle = "2.4" -crc = "3" -ring = "0.17" -md-5 = "0.10" -thiserror = "1" - -[dev-dependencies] -tokio-test = "0.4" -clap = "3" -criterion = "0.5" - - -[[bench]] -name = "bench" -harness = false - -[[example]] -name = "stun_client" -path = "examples/stun_client.rs" -bench = false - -[[example]] -name = "stun_decode" -path = "examples/stun_decode.rs" -bench = false diff --git a/stun/LICENSE-APACHE b/stun/LICENSE-APACHE deleted file mode 100644 index 16fe87b06..000000000 --- a/stun/LICENSE-APACHE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - -1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. 
- - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - -2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - -3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - -4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - -5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - -6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - -8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - -END OF TERMS AND CONDITIONS - -APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - -Copyright [yyyy] [name of copyright owner] - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. diff --git a/stun/LICENSE-MIT b/stun/LICENSE-MIT deleted file mode 100644 index e11d93bef..000000000 --- a/stun/LICENSE-MIT +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2021 WebRTC.rs - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/stun/README.md b/stun/README.md deleted file mode 100644 index 1f70e3642..000000000 --- a/stun/README.md +++ /dev/null @@ -1,30 +0,0 @@ -

- WebRTC.rs
-
- [badge/link HTML stripped: License: MIT/Apache 2.0, Discord]
-
- A pure Rust implementation of STUN. Rewrite Pion STUN in Rust

diff --git a/stun/benches/bench.rs b/stun/benches/bench.rs deleted file mode 100644 index 3c967880d..000000000 --- a/stun/benches/bench.rs +++ /dev/null @@ -1,659 +0,0 @@ -use std::io::Cursor; -use std::net::Ipv4Addr; -use std::ops::{Add, Sub}; -use std::time::Duration; - -use base64::prelude::BASE64_STANDARD; -use base64::Engine; -use criterion::{criterion_group, criterion_main, Criterion}; -use rand::rngs::StdRng; -use rand::{Rng, SeedableRng}; -use stun::addr::{AlternateServer, MappedAddress}; -use stun::agent::{noop_handler, Agent, TransactionId}; -use stun::attributes::{ - ATTR_CHANNEL_NUMBER, ATTR_DONT_FRAGMENT, ATTR_ERROR_CODE, ATTR_MESSAGE_INTEGRITY, ATTR_NONCE, - ATTR_REALM, ATTR_SOFTWARE, ATTR_USERNAME, ATTR_XORMAPPED_ADDRESS, -}; -use stun::error_code::{ErrorCode, ErrorCodeAttribute, CODE_STALE_NONCE}; -use stun::fingerprint::{FINGERPRINT, FINGERPRINT_SIZE}; -use stun::integrity::MessageIntegrity; -use stun::message::{ - is_message, Getter, Message, MessageType, Setter, ATTRIBUTE_HEADER_SIZE, BINDING_REQUEST, - CLASS_REQUEST, MESSAGE_HEADER_SIZE, METHOD_BINDING, -}; -use stun::textattrs::{Nonce, Realm, Software, Username}; -use stun::uattrs::UnknownAttributes; -use stun::xoraddr::{xor_bytes, XorMappedAddress}; -use tokio::time::Instant; - -// AGENT_COLLECT_CAP is initial capacity for Agent.Collect slices, -// sufficient to make function zero-alloc in most cases. -const AGENT_COLLECT_CAP: usize = 100; - -fn benchmark_addr(c: &mut Criterion) { - let mut m = Message::new(); - - let ma_addr = MappedAddress { - ip: "122.12.34.5".parse().unwrap(), - port: 5412, - }; - c.bench_function("BenchmarkMappedAddress_AddTo", |b| { - b.iter(|| { - ma_addr.add_to(&mut m).unwrap(); - m.reset(); - }) - }); - - let as_addr = AlternateServer { - ip: "122.12.34.5".parse().unwrap(), - port: 5412, - }; - c.bench_function("BenchmarkAlternateServer_AddTo", |b| { - b.iter(|| { - as_addr.add_to(&mut m).unwrap(); - m.reset(); - }) - }); -} - -fn benchmark_agent(c: &mut Criterion) { - let deadline = Instant::now().add(Duration::from_secs(60 * 60 * 24)); - let gc_deadline = deadline.sub(Duration::from_secs(1)); - - { - let mut a = Agent::new(noop_handler()); - for _ in 0..AGENT_COLLECT_CAP { - a.start(TransactionId::new(), deadline).unwrap(); - } - - c.bench_function("BenchmarkAgent_GC", |b| { - b.iter(|| { - a.collect(gc_deadline).unwrap(); - }) - }); - - a.close().unwrap(); - } - - { - let mut a = Agent::new(noop_handler()); - for _ in 0..AGENT_COLLECT_CAP { - a.start(TransactionId::new(), deadline).unwrap(); - } - - let mut m = Message::new(); - m.build(&[Box::::default()]).unwrap(); - c.bench_function("BenchmarkAgent_Process", |b| { - b.iter(|| { - a.process(m.clone()).unwrap(); - }) - }); - - a.close().unwrap(); - } -} - -fn benchmark_attributes(c: &mut Criterion) { - { - let m = Message::new(); - c.bench_function("BenchmarkMessage_GetNotFound", |b| { - b.iter(|| { - let _ = m.get(ATTR_REALM); - }) - }); - } - - { - let mut m = Message::new(); - m.add(ATTR_USERNAME, &[1, 2, 3, 4, 5, 6, 7]); - c.bench_function("BenchmarkMessage_Get", |b| { - b.iter(|| { - let _ = m.get(ATTR_USERNAME); - }) - }); - } -} - -//TODO: add benchmark_client - -fn benchmark_error_code(c: &mut Criterion) { - { - let mut m = Message::new(); - c.bench_function("BenchmarkErrorCode_AddTo", |b| { - b.iter(|| { - let _ = CODE_STALE_NONCE.add_to(&mut m); - m.reset(); - }) - }); - } - - { - let mut m = Message::new(); - let a = ErrorCodeAttribute { - code: ErrorCode(404), - reason: b"not found!".to_vec(), - }; - 
c.bench_function("BenchmarkErrorCodeAttribute_AddTo", |b| { - b.iter(|| { - let _ = a.add_to(&mut m); - m.reset(); - }) - }); - } - - { - let mut m = Message::new(); - let mut a = ErrorCodeAttribute { - code: ErrorCode(404), - reason: b"not found!".to_vec(), - }; - let _ = a.add_to(&mut m); - c.bench_function("BenchmarkErrorCodeAttribute_GetFrom", |b| { - b.iter(|| { - a.get_from(&m).unwrap(); - }) - }); - } -} - -fn benchmark_fingerprint(c: &mut Criterion) { - { - let mut m = Message::new(); - let s = Software::new(ATTR_SOFTWARE, "software".to_owned()); - let addr = XorMappedAddress { - ip: Ipv4Addr::new(213, 1, 223, 5).into(), - port: 0, - }; - let _ = addr.add_to(&mut m); - let _ = s.add_to(&mut m); - c.bench_function("BenchmarkFingerprint_AddTo", |b| { - b.iter(|| { - let _ = FINGERPRINT.add_to(&mut m); - m.write_length(); - m.length -= (ATTRIBUTE_HEADER_SIZE + FINGERPRINT_SIZE) as u32; - m.raw.drain(m.length as usize + MESSAGE_HEADER_SIZE..); - m.attributes.0.drain(m.attributes.0.len() - 1..); - }) - }); - } - - { - let mut m = Message::new(); - let s = Software::new(ATTR_SOFTWARE, "software".to_owned()); - let addr = XorMappedAddress { - ip: Ipv4Addr::new(213, 1, 223, 5).into(), - port: 0, - }; - let _ = addr.add_to(&mut m); - let _ = s.add_to(&mut m); - m.write_header(); - FINGERPRINT.add_to(&mut m).unwrap(); - m.write_header(); - c.bench_function("BenchmarkFingerprint_Check", |b| { - b.iter(|| { - FINGERPRINT.check(&m).unwrap(); - }) - }); - } -} - -fn benchmark_message_build_overhead(c: &mut Criterion) { - let t = BINDING_REQUEST; - let username = Username::new(ATTR_USERNAME, "username".to_owned()); - let nonce = Nonce::new(ATTR_NONCE, "nonce".to_owned()); - let realm = Realm::new(ATTR_REALM, "example.org".to_owned()); - - { - let mut m = Message::new(); - c.bench_function("BenchmarkBuildOverhead/Build", |b| { - b.iter(|| { - let _ = m.build(&[ - Box::new(username.clone()), - Box::new(nonce.clone()), - Box::new(realm.clone()), - Box::new(FINGERPRINT), - ]); - }) - }); - } - - { - let mut m = Message::new(); - c.bench_function("BenchmarkBuildOverhead/Raw", |b| { - b.iter(|| { - m.reset(); - m.write_header(); - m.set_type(t); - let _ = username.add_to(&mut m); - let _ = nonce.add_to(&mut m); - let _ = realm.add_to(&mut m); - let _ = FINGERPRINT.add_to(&mut m); - }) - }); - } -} - -fn benchmark_message_integrity(c: &mut Criterion) { - { - let mut m = Message::new(); - let integrity = MessageIntegrity::new_short_term_integrity("password".to_owned()); - m.write_header(); - c.bench_function("BenchmarkMessageIntegrity_AddTo", |b| { - b.iter(|| { - m.write_header(); - integrity.add_to(&mut m).unwrap(); - m.reset(); - }) - }); - } - - { - let mut m = Message::new(); - m.raw = Vec::with_capacity(1024); - let software = Software::new(ATTR_SOFTWARE, "software".to_owned()); - let _ = software.add_to(&mut m); - let integrity = MessageIntegrity::new_short_term_integrity("password".to_owned()); - m.write_header(); - integrity.add_to(&mut m).unwrap(); - m.write_header(); - c.bench_function("BenchmarkMessageIntegrity_Check", |b| { - b.iter(|| { - integrity.check(&mut m).unwrap(); - }) - }); - } -} - -fn benchmark_message(c: &mut Criterion) { - { - let mut m = Message::new(); - c.bench_function("BenchmarkMessage_Write", |b| { - b.iter(|| { - m.add(ATTR_ERROR_CODE, &[0xff, 0x11, 0x12, 0x34]); - m.transaction_id = TransactionId::new(); - m.typ = MessageType { - method: METHOD_BINDING, - class: CLASS_REQUEST, - }; - m.write_header(); - m.reset(); - }) - }); - } - - { - let m = MessageType { - method: 
METHOD_BINDING, - class: CLASS_REQUEST, - }; - c.bench_function("BenchmarkMessageType_Value", |b| { - b.iter(|| { - let _ = m.value(); - }) - }); - } - - { - let typ = MessageType { - method: METHOD_BINDING, - class: CLASS_REQUEST, - }; - let mut m = Message { - typ, - length: 0, - transaction_id: TransactionId([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), - ..Default::default() - }; - m.write_header(); - let mut buf = vec![]; - c.bench_function("BenchmarkMessage_WriteTo", |b| { - b.iter(|| { - { - let mut writer = Cursor::new(&mut buf); - m.write_to(&mut writer).unwrap(); - } - buf.clear(); - }) - }); - } - - { - let typ = MessageType { - method: METHOD_BINDING, - class: CLASS_REQUEST, - }; - let mut m = Message { - typ, - length: 0, - transaction_id: TransactionId([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), - ..Default::default() - }; - m.write_header(); - let mut mrec = Message::new(); - c.bench_function("BenchmarkMessage_ReadFrom", |b| { - b.iter(|| { - let mut reader = Cursor::new(&m.raw); - mrec.read_from(&mut reader).unwrap(); - mrec.reset(); - }) - }); - } - - { - let typ = MessageType { - method: METHOD_BINDING, - class: CLASS_REQUEST, - }; - let mut m = Message { - typ, - length: 0, - transaction_id: TransactionId([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), - ..Default::default() - }; - m.write_header(); - let mut mrec = Message::new(); - c.bench_function("BenchmarkMessage_ReadBytes", |b| { - b.iter(|| { - mrec.write(&m.raw).unwrap(); - mrec.reset(); - }) - }); - } - - { - let typ = MessageType { - method: METHOD_BINDING, - class: CLASS_REQUEST, - }; - let mut m = Message { - typ, - transaction_id: TransactionId::new(), - ..Default::default() - }; - let software = Software::new(ATTR_SOFTWARE, "cydev/stun test".to_owned()); - software.add_to(&mut m).unwrap(); - m.write_header(); - c.bench_function("BenchmarkIsMessage", |b| { - b.iter(|| { - assert!(is_message(&m.raw), "Should be message"); - }) - }); - } - - { - let mut m = Message::new(); - m.write_header(); - c.bench_function("BenchmarkMessage_NewTransactionID", |b| { - b.iter(|| { - m.new_transaction_id().unwrap(); - }) - }); - } - - { - let mut m = Message::new(); - let s = Software::new(ATTR_SOFTWARE, "software".to_owned()); - let addr = XorMappedAddress { - ip: Ipv4Addr::new(213, 1, 223, 5).into(), - ..Default::default() - }; - c.bench_function("BenchmarkMessageFull", |b| { - b.iter(|| { - addr.add_to(&mut m).unwrap(); - s.add_to(&mut m).unwrap(); - m.write_attributes(); - m.write_header(); - FINGERPRINT.add_to(&mut m).unwrap(); - m.write_header(); - m.reset(); - }) - }); - } - - { - let mut m = Message::new(); - let s = Software::new(ATTR_SOFTWARE, "software".to_owned()); - let addr = XorMappedAddress { - ip: Ipv4Addr::new(213, 1, 223, 5).into(), - ..Default::default() - }; - c.bench_function("BenchmarkMessageFullHardcore", |b| { - b.iter(|| { - addr.add_to(&mut m).unwrap(); - s.add_to(&mut m).unwrap(); - m.write_header(); - m.reset(); - }) - }); - } - - { - let typ = MessageType { - method: METHOD_BINDING, - class: CLASS_REQUEST, - }; - let mut m = Message { - typ, - transaction_id: TransactionId::new(), - raw: vec![0u8; 128], - ..Default::default() - }; - c.bench_function("BenchmarkMessage_WriteHeader", |b| { - b.iter(|| { - m.write_header(); - }) - }); - } - - { - let mut m = Message::new(); - m.build(&[ - Box::new(BINDING_REQUEST), - Box::new(TransactionId([1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2])), - Box::new(Software::new(ATTR_SOFTWARE, "webrtc-rs/stun".to_owned())), - Box::new(MessageIntegrity::new_long_term_integrity( - 
"username".to_owned(), - "realm".to_owned(), - "password".to_owned(), - )), - Box::new(FINGERPRINT), - ]) - .unwrap(); - let mut a = Message::new(); - m.clone_to(&mut a).unwrap(); - c.bench_function("BenchmarkMessage_CloneTo", |b| { - b.iter(|| { - m.clone_to(&mut a).unwrap(); - }) - }); - } - - { - let mut m = Message::new(); - m.build(&[ - Box::new(BINDING_REQUEST), - Box::new(TransactionId([1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2])), - Box::new(FINGERPRINT), - ]) - .unwrap(); - let mut a = Message::new(); - m.clone_to(&mut a).unwrap(); - c.bench_function("BenchmarkMessage_AddTo", |b| { - b.iter(|| { - m.add_to(&mut a).unwrap(); - }) - }); - } - - { - let typ = MessageType { - method: METHOD_BINDING, - class: CLASS_REQUEST, - }; - let mut m = Message { - typ, - transaction_id: TransactionId::new(), - ..Default::default() - }; - m.add(ATTR_ERROR_CODE, &[0xff, 0xfe, 0xfa]); - m.write_header(); - let mut mdecoded = Message::new(); - c.bench_function("BenchmarkDecode", |b| { - b.iter(|| { - mdecoded.reset(); - mdecoded.raw.clone_from(&m.raw); - mdecoded.decode().unwrap(); - }) - }); - } -} - -fn benchmark_text_attributes(c: &mut Criterion) { - { - let mut m = Message::new(); - let u = Username::new(ATTR_USERNAME, "test".to_owned()); - c.bench_function("BenchmarkUsername_AddTo", |b| { - b.iter(|| { - u.add_to(&mut m).unwrap(); - m.reset(); - }) - }); - } - - { - let mut m = Message::new(); - let mut u = Username::new(ATTR_USERNAME, "test".to_owned()); - u.add_to(&mut m).unwrap(); - c.bench_function("BenchmarkUsername_GetFrom", |b| { - b.iter(|| { - u.get_from(&m).unwrap(); - u.text.clear(); - }) - }); - } - - { - let mut m = Message::new(); - let n = Nonce::new(ATTR_NONCE, "nonce".to_owned()); - c.bench_function("BenchmarkNonce_AddTo", |b| { - b.iter(|| { - n.add_to(&mut m).unwrap(); - m.reset(); - }) - }); - } - - { - let mut m = Message::new(); - let nonce = String::from_utf8(vec![b'a'; 2048]).unwrap(); - let n = Nonce::new(ATTR_NONCE, nonce); - c.bench_function("BenchmarkNonce_AddTo_BadLength", |b| { - b.iter(|| { - assert!(n.add_to(&mut m).is_err()); - m.reset(); - }) - }); - } - - { - let mut m = Message::new(); - let mut n = Nonce::new(ATTR_NONCE, "nonce".to_owned()); - n.add_to(&mut m).unwrap(); - c.bench_function("BenchmarkNonce_GetFrom", |b| { - b.iter(|| { - n.get_from(&m).unwrap(); - }) - }); - } -} - -fn benchmark_unknown_attributes(c: &mut Criterion) { - let mut m = Message::new(); - let a = UnknownAttributes(vec![ - ATTR_DONT_FRAGMENT, - ATTR_CHANNEL_NUMBER, - ATTR_REALM, - ATTR_MESSAGE_INTEGRITY, - ]); - - { - c.bench_function("BenchmarkUnknownAttributes/AddTo", |b| { - b.iter(|| { - a.add_to(&mut m).unwrap(); - m.reset(); - }) - }); - } - - { - a.add_to(&mut m).unwrap(); - let mut attrs = UnknownAttributes(Vec::with_capacity(10)); - c.bench_function("BenchmarkUnknownAttributes/GetFrom", |b| { - b.iter(|| { - attrs.get_from(&m).unwrap(); - attrs.0.clear(); - }) - }); - } -} - -fn benchmark_xor(c: &mut Criterion) { - let mut r = StdRng::seed_from_u64(666); - let mut a = [0u8; 1024]; - let mut d = [0u8; 1024]; - r.fill(&mut a); - r.fill(&mut d); - let mut dst = [0u8; 1024]; - c.bench_function("BenchmarkXOR", |b| { - b.iter(|| { - let _ = xor_bytes(&mut dst, &a, &d); - }) - }); -} - -fn benchmark_xoraddr(c: &mut Criterion) { - { - let mut m = Message::new(); - let ip = "192.168.1.32".parse().unwrap(); - c.bench_function("BenchmarkXORMappedAddress_AddTo", |b| { - b.iter(|| { - let addr = XorMappedAddress { ip, port: 3654 }; - addr.add_to(&mut m).unwrap(); - m.reset(); - }) - }); - } - - 
{ - let mut m = Message::new(); - let transaction_id = BASE64_STANDARD.decode("jxhBARZwX+rsC6er").unwrap(); - - m.transaction_id.0.copy_from_slice(&transaction_id); - let addr_value = [0, 1, 156, 213, 244, 159, 56, 174]; //hex.DecodeString("00019cd5f49f38ae") - m.add(ATTR_XORMAPPED_ADDRESS, &addr_value); - let mut addr = XorMappedAddress::default(); - c.bench_function("BenchmarkXORMappedAddress_GetFrom", |b| { - b.iter(|| { - addr.get_from(&m).unwrap(); - }) - }); - } -} - -criterion_group!( - benches, - benchmark_addr, - benchmark_agent, - benchmark_attributes, - //TODO: benchmark_client, - benchmark_error_code, - benchmark_fingerprint, - benchmark_message_build_overhead, - benchmark_message_integrity, - benchmark_message, - benchmark_text_attributes, - benchmark_unknown_attributes, - benchmark_xor, - benchmark_xoraddr, -); -criterion_main!(benches); diff --git a/stun/codecov.yml b/stun/codecov.yml deleted file mode 100644 index 4b006c24c..000000000 --- a/stun/codecov.yml +++ /dev/null @@ -1,23 +0,0 @@ -codecov: - require_ci_to_pass: yes - max_report_age: off - token: 5ed548cd-073b-4748-b584-ca2d637027bf - -coverage: - precision: 2 - round: down - range: 50..90 - status: - project: - default: - enabled: no - threshold: 0.2 - if_not_found: success - patch: - default: - enabled: no - if_not_found: success - changes: - default: - enabled: no - if_not_found: success diff --git a/stun/doc/webrtc.rs.png b/stun/doc/webrtc.rs.png deleted file mode 100644 index 7bf0dda2a..000000000 Binary files a/stun/doc/webrtc.rs.png and /dev/null differ diff --git a/stun/examples/stun_client.rs b/stun/examples/stun_client.rs deleted file mode 100644 index 48e2c0415..000000000 --- a/stun/examples/stun_client.rs +++ /dev/null @@ -1,65 +0,0 @@ -use std::sync::Arc; - -use clap::{App, Arg}; -use stun::agent::*; -use stun::client::*; -use stun::message::*; -use stun::xoraddr::*; -use stun::Error; -use tokio::net::UdpSocket; - -#[tokio::main] -async fn main() -> Result<(), Error> { - let mut app = App::new("STUN Client") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of STUN Client") - .arg( - Arg::with_name("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::with_name("server") - .required_unless("FULLHELP") - .takes_value(true) - .default_value("stun.l.google.com:19302") - .long("server") - .help("STUN Server"), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let server = matches.value_of("server").unwrap(); - - let (handler_tx, mut handler_rx) = tokio::sync::mpsc::unbounded_channel(); - - let conn = UdpSocket::bind("0:0").await?; - println!("Local address: {}", conn.local_addr()?); - - println!("Connecting to: {server}"); - conn.connect(server).await?; - - let mut client = ClientBuilder::new().with_conn(Arc::new(conn)).build()?; - - let mut msg = Message::new(); - msg.build(&[Box::::default(), Box::new(BINDING_REQUEST)])?; - - client.send(&msg, Some(Arc::new(handler_tx))).await?; - - if let Some(event) = handler_rx.recv().await { - let msg = event.event_body?; - let mut xor_addr = XorMappedAddress::default(); - xor_addr.get_from(&msg)?; - println!("Got response: {xor_addr}"); - } - - client.close().await?; - - Ok(()) -} diff --git a/stun/examples/stun_decode.rs b/stun/examples/stun_decode.rs deleted file mode 100644 index 959cbc15a..000000000 --- a/stun/examples/stun_decode.rs +++ /dev/null @@ -1,44 +0,0 @@ -use 
base64::prelude::BASE64_STANDARD; -use base64::Engine; -use clap::{App, Arg}; -use stun::message::Message; - -fn main() { - let mut app = App::new("STUN decode") - .version("0.1.0") - .author("Jtplouffe ") - .about("An example of STUN decode") - .arg( - Arg::with_name("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::with_name("data") - .required_unless("FULLHELP") - .takes_value(true) - .index(1) - .help("base64 encoded message, e.g. 'AAEAHCESpEJML0JTQWsyVXkwcmGALwAWaHR0cDovL2xvY2FsaG9zdDozMDAwLwAA'"), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let encoded_data = matches.value_of("data").unwrap(); - let decoded_data = match BASE64_STANDARD.decode(encoded_data) { - Ok(d) => d, - Err(e) => panic!("Unable to decode base64 value: {e}"), - }; - - let mut message = Message::new(); - message.raw = decoded_data; - - match message.decode() { - Ok(_) => println!("{message}"), - Err(e) => panic!("Unable to decode message: {e}"), - } -} diff --git a/stun/src/addr.rs b/stun/src/addr.rs deleted file mode 100644 index cafe609ff..000000000 --- a/stun/src/addr.rs +++ /dev/null @@ -1,128 +0,0 @@ -#[cfg(test)] -mod addr_test; - -use std::fmt; -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; - -use crate::attributes::*; -use crate::error::*; -use crate::message::*; - -pub(crate) const FAMILY_IPV4: u16 = 0x01; -pub(crate) const FAMILY_IPV6: u16 = 0x02; -pub(crate) const IPV4LEN: usize = 4; -pub(crate) const IPV6LEN: usize = 16; - -/// MappedAddress represents MAPPED-ADDRESS attribute. -/// -/// This attribute is used only by servers for achieving backwards -/// compatibility with RFC 3489 clients. -/// -/// RFC 5389 Section 15.1 -pub struct MappedAddress { - pub ip: IpAddr, - pub port: u16, -} - -impl fmt::Display for MappedAddress { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let family = match self.ip { - IpAddr::V4(_) => FAMILY_IPV4, - IpAddr::V6(_) => FAMILY_IPV6, - }; - if family == FAMILY_IPV4 { - write!(f, "{}:{}", self.ip, self.port) - } else { - write!(f, "[{}]:{}", self.ip, self.port) - } - } -} - -impl Default for MappedAddress { - fn default() -> Self { - MappedAddress { - ip: IpAddr::V4(Ipv4Addr::from(0)), - port: 0, - } - } -} - -impl Setter for MappedAddress { - /// add_to adds MAPPED-ADDRESS to message. - fn add_to(&self, m: &mut Message) -> Result<()> { - self.add_to_as(m, ATTR_MAPPED_ADDRESS) - } -} - -impl Getter for MappedAddress { - /// get_from decodes MAPPED-ADDRESS from message. - fn get_from(&mut self, m: &Message) -> Result<()> { - self.get_from_as(m, ATTR_MAPPED_ADDRESS) - } -} - -impl MappedAddress { - /// get_from_as decodes MAPPED-ADDRESS value in message m as an attribute of type t. 
- pub fn get_from_as(&mut self, m: &Message, t: AttrType) -> Result<()> { - let v = m.get(t)?; - if v.len() <= 4 { - return Err(Error::ErrUnexpectedEof); - } - - let family = u16::from_be_bytes([v[0], v[1]]); - if family != FAMILY_IPV6 && family != FAMILY_IPV4 { - return Err(Error::Other(format!("bad value {family}"))); - } - self.port = u16::from_be_bytes([v[2], v[3]]); - - if family == FAMILY_IPV6 { - let mut ip = [0; IPV6LEN]; - let l = std::cmp::min(ip.len(), v[4..].len()); - ip[..l].copy_from_slice(&v[4..4 + l]); - self.ip = IpAddr::V6(Ipv6Addr::from(ip)); - } else { - let mut ip = [0; IPV4LEN]; - let l = std::cmp::min(ip.len(), v[4..].len()); - ip[..l].copy_from_slice(&v[4..4 + l]); - self.ip = IpAddr::V4(Ipv4Addr::from(ip)); - }; - - Ok(()) - } - - /// add_to_as adds MAPPED-ADDRESS value to m as t attribute. - pub fn add_to_as(&self, m: &mut Message, t: AttrType) -> Result<()> { - let family = match self.ip { - IpAddr::V4(_) => FAMILY_IPV4, - IpAddr::V6(_) => FAMILY_IPV6, - }; - - let mut value = vec![0u8; 4]; - //value[0] = 0 // first 8 bits are zeroes - value[0..2].copy_from_slice(&family.to_be_bytes()); - value[2..4].copy_from_slice(&self.port.to_be_bytes()); - - match self.ip { - IpAddr::V4(ipv4) => value.extend_from_slice(&ipv4.octets()), - IpAddr::V6(ipv6) => value.extend_from_slice(&ipv6.octets()), - }; - - m.add(t, &value); - Ok(()) - } -} - -/// AlternateServer represents ALTERNATE-SERVER attribute. -/// -/// RFC 5389 Section 15.11 -pub type AlternateServer = MappedAddress; - -/// ResponseOrigin represents RESPONSE-ORIGIN attribute. -/// -/// RFC 5780 Section 7.3 -pub type ResponseOrigin = MappedAddress; - -/// OtherAddress represents OTHER-ADDRESS attribute. -/// -/// RFC 5780 Section 7.4 -pub type OtherAddress = MappedAddress; diff --git a/stun/src/addr/addr_test.rs b/stun/src/addr/addr_test.rs deleted file mode 100644 index 77f5ac6d9..000000000 --- a/stun/src/addr/addr_test.rs +++ /dev/null @@ -1,183 +0,0 @@ -use super::*; -use crate::error::*; - -#[test] -fn test_mapped_address() -> Result<()> { - let mut m = Message::new(); - let addr = MappedAddress { - ip: "122.12.34.5".parse().unwrap(), - port: 5412, - }; - assert_eq!(addr.to_string(), "122.12.34.5:5412", "bad string {addr}"); - - //"add_to" - { - addr.add_to(&mut m)?; - - //"GetFrom" - { - let mut got = MappedAddress::default(); - got.get_from(&m)?; - assert_eq!(got.ip, addr.ip, "got bad IP: {}", got.ip); - - //"Not found" - { - let message = Message::new(); - let result = got.get_from(&message); - if let Err(err) = result { - assert_eq!( - Error::ErrAttributeNotFound, - err, - "should be not found: {err}" - ); - } else { - panic!("expected error, but got ok"); - } - } - //"Bad family" - { - let (mut v, _) = m.attributes.get(ATTR_MAPPED_ADDRESS); - v.value[0] = 32; - got.get_from(&m)? 
- } - //"Bad length" - { - let mut message = Message::new(); - message.add(ATTR_MAPPED_ADDRESS, &[1, 2, 3]); - let result = got.get_from(&message); - if let Err(err) = result { - assert_eq!( - Error::ErrUnexpectedEof, - err, - "<{}> should be <{}>", - err, - Error::ErrUnexpectedEof - ); - } else { - panic!("expected error, but got ok"); - } - } - } - } - - Ok(()) -} - -#[test] -fn test_mapped_address_v6() -> Result<()> { - let mut m = Message::new(); - let addr = MappedAddress { - ip: "::".parse().unwrap(), - port: 5412, - }; - - //"add_to" - { - addr.add_to(&mut m)?; - - //"GetFrom" - { - let mut got = MappedAddress::default(); - got.get_from(&m)?; - assert_eq!(got.ip, addr.ip, "got bad IP: {}", got.ip); - - //"Not found" - { - let message = Message::new(); - let result = got.get_from(&message); - if let Err(err) = result { - assert_eq!( - Error::ErrAttributeNotFound, - err, - "<{}> should be <{}>", - err, - Error::ErrAttributeNotFound, - ); - } else { - panic!("expected error, but got ok"); - } - } - } - } - Ok(()) -} - -#[test] -fn test_alternate_server() -> Result<()> { - let mut m = Message::new(); - let addr = MappedAddress { - ip: "122.12.34.5".parse().unwrap(), - port: 5412, - }; - - //"add_to" - { - addr.add_to(&mut m)?; - - //"GetFrom" - { - let mut got = AlternateServer::default(); - got.get_from(&m)?; - assert_eq!(got.ip, addr.ip, "got bad IP: {}", got.ip); - - //"Not found" - { - let message = Message::new(); - let result = got.get_from(&message); - if let Err(err) = result { - assert_eq!( - Error::ErrAttributeNotFound, - err, - "<{}> should be <{}>", - err, - Error::ErrAttributeNotFound, - ); - } else { - panic!("expected error, but got ok"); - } - } - } - } - - Ok(()) -} - -#[test] -fn test_other_address() -> Result<()> { - let mut m = Message::new(); - let addr = OtherAddress { - ip: "122.12.34.5".parse().unwrap(), - port: 5412, - }; - - //"add_to" - { - addr.add_to(&mut m)?; - - //"GetFrom" - { - let mut got = OtherAddress::default(); - got.get_from(&m)?; - assert_eq!(got.ip, addr.ip, "got bad IP: {}", got.ip); - - //"Not found" - { - let message = Message::new(); - let result = got.get_from(&message); - if let Err(err) = result { - assert_eq!( - Error::ErrAttributeNotFound, - err, - "<{}> should be <{}>", - err, - Error::ErrAttributeNotFound, - ); - } else { - panic!("expected error, but got ok"); - } - } - } - } - - Ok(()) -} diff --git a/stun/src/agent.rs b/stun/src/agent.rs deleted file mode 100644 index 4562df722..000000000 --- a/stun/src/agent.rs +++ /dev/null @@ -1,283 +0,0 @@ -#[cfg(test)] -mod agent_test; - -use std::collections::HashMap; -use std::sync::Arc; - -use rand::Rng; -use tokio::sync::mpsc; -use tokio::time::Instant; - -use crate::client::ClientTransaction; -use crate::error::*; -use crate::message::*; - -/// Handler handles state changes of transaction. -/// Handler is called on transaction state change. -/// Usage of e is valid only during call, user must -/// copy needed fields explicitly. -pub type Handler = Option>>; - -/// noop_handler just discards any event. -pub fn noop_handler() -> Handler { - None -} - -/// Agent is low-level abstraction over transaction list that -/// handles concurrency (all calls are goroutine-safe) and -/// time outs (via Collect call). -pub struct Agent { - /// transactions is map of transactions that are currently - /// in progress. 
Event handling is done in such way when - /// transaction is unregistered before AgentTransaction access, - /// minimizing mux lock and protecting AgentTransaction from - /// data races via unexpected concurrent access. - transactions: HashMap, - /// all calls are invalid if true - closed: bool, - /// handles transactions - handler: Handler, -} - -#[derive(Debug, Clone)] -pub enum EventType { - Callback(TransactionId), - Insert(ClientTransaction), - Remove(TransactionId), - Close, -} - -impl Default for EventType { - fn default() -> Self { - EventType::Callback(TransactionId::default()) - } -} - -/// Event is passed to Handler describing the transaction event. -/// Do not reuse outside Handler. -#[derive(Debug)] //Clone -pub struct Event { - pub event_type: EventType, - pub event_body: Result, -} - -impl Default for Event { - fn default() -> Self { - Event { - event_type: EventType::default(), - event_body: Ok(Message::default()), - } - } -} - -/// AgentTransaction represents transaction in progress. -/// Concurrent access is invalid. -pub(crate) struct AgentTransaction { - id: TransactionId, - deadline: Instant, -} - -/// AGENT_COLLECT_CAP is initial capacity for Agent.Collect slices, -/// sufficient to make function zero-alloc in most cases. -const AGENT_COLLECT_CAP: usize = 100; - -#[derive(PartialEq, Eq, Hash, Copy, Clone, Default, Debug)] -pub struct TransactionId(pub [u8; TRANSACTION_ID_SIZE]); - -impl TransactionId { - /// new returns new random transaction ID using crypto/rand - /// as source. - pub fn new() -> Self { - let mut b = TransactionId([0u8; TRANSACTION_ID_SIZE]); - rand::thread_rng().fill(&mut b.0); - b - } -} - -impl Setter for TransactionId { - fn add_to(&self, m: &mut Message) -> Result<()> { - m.transaction_id = *self; - m.write_transaction_id(); - Ok(()) - } -} - -/// ClientAgent is Agent implementation that is used by Client to -/// process transactions. -#[derive(Debug)] -pub enum ClientAgent { - Process(Message), - Collect(Instant), - Start(TransactionId, Instant), - Stop(TransactionId), - Close, -} - -impl Agent { - /// new initializes and returns new Agent with provided handler. - pub fn new(handler: Handler) -> Self { - Agent { - transactions: HashMap::new(), - closed: false, - handler, - } - } - - /// stop_with_error removes transaction from list and calls handler with - /// provided error. Can return ErrTransactionNotExists and ErrAgentClosed. - pub fn stop_with_error(&mut self, id: TransactionId, error: Error) -> Result<()> { - if self.closed { - return Err(Error::ErrAgentClosed); - } - - let v = self.transactions.remove(&id); - if let Some(t) = v { - if let Some(handler) = &self.handler { - handler.send(Event { - event_type: EventType::Callback(t.id), - event_body: Err(error), - })?; - } - Ok(()) - } else { - Err(Error::ErrTransactionNotExists) - } - } - - /// process incoming message, synchronously passing it to handler. - pub fn process(&mut self, message: Message) -> Result<()> { - if self.closed { - return Err(Error::ErrAgentClosed); - } - - self.transactions.remove(&message.transaction_id); - - let e = Event { - event_type: EventType::Callback(message.transaction_id), - event_body: Ok(message), - }; - - if let Some(handler) = &self.handler { - handler.send(e)?; - } - - Ok(()) - } - - /// close terminates all transactions with ErrAgentClosed and renders Agent to - /// closed state. 
- pub fn close(&mut self) -> Result<()> { - if self.closed { - return Err(Error::ErrAgentClosed); - } - - for id in self.transactions.keys() { - let e = Event { - event_type: EventType::Callback(*id), - event_body: Err(Error::ErrAgentClosed), - }; - if let Some(handler) = &self.handler { - handler.send(e)?; - } - } - self.transactions = HashMap::new(); - self.closed = true; - self.handler = noop_handler(); - - Ok(()) - } - - /// start registers transaction with provided id and deadline. - /// Could return ErrAgentClosed, ErrTransactionExists. - /// - /// Agent handler is guaranteed to be eventually called. - pub fn start(&mut self, id: TransactionId, deadline: Instant) -> Result<()> { - if self.closed { - return Err(Error::ErrAgentClosed); - } - if self.transactions.contains_key(&id) { - return Err(Error::ErrTransactionExists); - } - - self.transactions - .insert(id, AgentTransaction { id, deadline }); - - Ok(()) - } - - /// stop stops transaction by id with ErrTransactionStopped, blocking - /// until handler returns. - pub fn stop(&mut self, id: TransactionId) -> Result<()> { - self.stop_with_error(id, Error::ErrTransactionStopped) - } - - /// collect terminates all transactions that have deadline before provided - /// time, blocking until all handlers will process ErrTransactionTimeOut. - /// Will return ErrAgentClosed if agent is already closed. - /// - /// It is safe to call Collect concurrently but makes no sense. - pub fn collect(&mut self, deadline: Instant) -> Result<()> { - if self.closed { - // Doing nothing if agent is closed. - // All transactions should be already closed - // during Close() call. - return Err(Error::ErrAgentClosed); - } - - let mut to_remove: Vec = Vec::with_capacity(AGENT_COLLECT_CAP); - - // Adding all transactions with deadline before gc_time - // to toCall and to_remove slices. - // No allocs if there are less than AGENT_COLLECT_CAP - // timed out transactions. - for (id, t) in &self.transactions { - if t.deadline < deadline { - to_remove.push(*id); - } - } - // Un-registering timed out transactions. - for id in &to_remove { - self.transactions.remove(id); - } - - for id in to_remove { - let event = Event { - event_type: EventType::Callback(id), - event_body: Err(Error::ErrTransactionTimeOut), - }; - if let Some(handler) = &self.handler { - handler.send(event)?; - } - } - - Ok(()) - } - - /// set_handler sets agent handler to h. 
- pub fn set_handler(&mut self, h: Handler) -> Result<()> { - if self.closed { - return Err(Error::ErrAgentClosed); - } - self.handler = h; - - Ok(()) - } - - pub(crate) async fn run(mut agent: Agent, mut rx: mpsc::Receiver) { - while let Some(client_agent) = rx.recv().await { - let result = match client_agent { - ClientAgent::Process(message) => agent.process(message), - ClientAgent::Collect(deadline) => agent.collect(deadline), - ClientAgent::Start(tid, deadline) => agent.start(tid, deadline), - ClientAgent::Stop(tid) => agent.stop(tid), - ClientAgent::Close => agent.close(), - }; - - if let Err(err) = result { - if Error::ErrAgentClosed == err { - break; - } - } - } - } -} diff --git a/stun/src/agent/agent_test.rs b/stun/src/agent/agent_test.rs deleted file mode 100644 index 8ca1afa23..000000000 --- a/stun/src/agent/agent_test.rs +++ /dev/null @@ -1,195 +0,0 @@ -use std::ops::Add; - -use tokio::time::Duration; - -use super::*; -use crate::error::*; - -#[tokio::test] -async fn test_agent_process_in_transaction() -> Result<()> { - let mut m = Message::new(); - let (handler_tx, mut handler_rx) = tokio::sync::mpsc::unbounded_channel(); - let mut a = Agent::new(Some(Arc::new(handler_tx))); - m.transaction_id = TransactionId([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); - a.start(m.transaction_id, Instant::now())?; - a.process(m)?; - a.close()?; - - while let Some(e) = handler_rx.recv().await { - assert!(e.event_body.is_ok(), "got error: {:?}", e.event_body); - - let tid = TransactionId([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); - assert_eq!( - e.event_body.as_ref().unwrap().transaction_id, - tid, - "{:?} (got) != {:?} (expected)", - e.event_body.as_ref().unwrap().transaction_id, - tid - ); - } - - Ok(()) -} - -#[tokio::test] -async fn test_agent_process() -> Result<()> { - let mut m = Message::new(); - let (handler_tx, mut handler_rx) = tokio::sync::mpsc::unbounded_channel(); - let mut a = Agent::new(Some(Arc::new(handler_tx))); - m.transaction_id = TransactionId([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); - a.process(m.clone())?; - a.close()?; - - while let Some(e) = handler_rx.recv().await { - assert!(e.event_body.is_ok(), "got error: {:?}", e.event_body); - - let tid = TransactionId([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); - assert_eq!( - e.event_body.as_ref().unwrap().transaction_id, - tid, - "{:?} (got) != {:?} (expected)", - e.event_body.as_ref().unwrap().transaction_id, - tid - ); - } - - let result = a.process(m); - if let Err(err) = result { - assert_eq!( - err, - Error::ErrAgentClosed, - "closed agent should return <{}>, but got <{}>", - Error::ErrAgentClosed, - err, - ); - } else { - panic!("expected error, but got ok"); - } - - Ok(()) -} - -#[test] -fn test_agent_start() -> Result<()> { - let mut a = Agent::new(noop_handler()); - let id = TransactionId::new(); - let deadline = Instant::now().add(Duration::from_secs(3600)); - a.start(id, deadline)?; - - let result = a.start(id, deadline); - if let Err(err) = result { - assert_eq!( - err, - Error::ErrTransactionExists, - "duplicate start should return <{}>, got <{}>", - Error::ErrTransactionExists, - err, - ); - } else { - panic!("expected error, but got ok"); - } - a.close()?; - - let id = TransactionId::new(); - let result = a.start(id, deadline); - if let Err(err) = result { - assert_eq!( - err, - Error::ErrAgentClosed, - "start on closed agent should return <{}>, got <{}>", - Error::ErrAgentClosed, - err, - ); - } else { - panic!("expected error, but got ok"); - } - - let result = a.set_handler(noop_handler()); - if let Err(err) = 
result { - assert_eq!( - err, - Error::ErrAgentClosed, - "SetHandler on closed agent should return <{}>, got <{}>", - Error::ErrAgentClosed, - err, - ); - } else { - panic!("expected error, but got ok"); - } - - Ok(()) -} - -#[tokio::test] -async fn test_agent_stop() -> Result<()> { - let (handler_tx, mut handler_rx) = tokio::sync::mpsc::unbounded_channel(); - let mut a = Agent::new(Some(Arc::new(handler_tx))); - - let result = a.stop(TransactionId::default()); - if let Err(err) = result { - assert_eq!( - err, - Error::ErrTransactionNotExists, - "unexpected error: {}, should be {}", - Error::ErrTransactionNotExists, - err, - ); - } else { - panic!("expected error, but got ok"); - } - - let id = TransactionId::new(); - let deadline = Instant::now().add(Duration::from_millis(200)); - a.start(id, deadline)?; - a.stop(id)?; - - let timeout = tokio::time::sleep(Duration::from_millis(400)); - tokio::pin!(timeout); - - tokio::select! { - evt = handler_rx.recv() => { - if let Err(err) = evt.unwrap().event_body{ - assert_eq!( - err, - Error::ErrTransactionStopped, - "unexpected error: {}, should be {}", - err, - Error::ErrTransactionStopped - ); - }else{ - panic!("expected error, got ok"); - } - } - _ = timeout.as_mut() => panic!("timed out"), - } - - a.close()?; - - let result = a.close(); - if let Err(err) = result { - assert_eq!( - err, - Error::ErrAgentClosed, - "a.Close returned {} instead of {}", - Error::ErrAgentClosed, - err, - ); - } else { - panic!("expected error, but got ok"); - } - - let result = a.stop(TransactionId::default()); - if let Err(err) = result { - assert_eq!( - err, - Error::ErrAgentClosed, - "unexpected error: {}, should be {}", - Error::ErrAgentClosed, - err, - ); - } else { - panic!("expected error, but got ok"); - } - - Ok(()) -} diff --git a/stun/src/attributes.rs b/stun/src/attributes.rs deleted file mode 100644 index f51a98edb..000000000 --- a/stun/src/attributes.rs +++ /dev/null @@ -1,209 +0,0 @@ -#[cfg(test)] -mod attributes_test; - -use std::fmt; - -use crate::error::*; -use crate::message::*; - -/// Attributes is list of message attributes. -#[derive(Default, PartialEq, Eq, Debug, Clone)] -pub struct Attributes(pub Vec); - -impl Attributes { - /// get returns first attribute from list by the type. - /// If attribute is present the RawAttribute is returned and the - /// boolean is true. Otherwise the returned RawAttribute will be - /// empty and boolean will be false. - pub fn get(&self, t: AttrType) -> (RawAttribute, bool) { - for candidate in &self.0 { - if candidate.typ == t { - return (candidate.clone(), true); - } - } - - (RawAttribute::default(), false) - } -} - -/// AttrType is attribute type. 
-#[derive(PartialEq, Debug, Eq, Default, Copy, Clone)] -pub struct AttrType(pub u16); - -impl fmt::Display for AttrType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let other = format!("0x{:x}", self.0); - - let s = match *self { - ATTR_MAPPED_ADDRESS => "MAPPED-ADDRESS", - ATTR_USERNAME => "USERNAME", - ATTR_ERROR_CODE => "ERROR-CODE", - ATTR_MESSAGE_INTEGRITY => "MESSAGE-INTEGRITY", - ATTR_UNKNOWN_ATTRIBUTES => "UNKNOWN-ATTRIBUTES", - ATTR_REALM => "REALM", - ATTR_NONCE => "NONCE", - ATTR_XORMAPPED_ADDRESS => "XOR-MAPPED-ADDRESS", - ATTR_SOFTWARE => "SOFTWARE", - ATTR_ALTERNATE_SERVER => "ALTERNATE-SERVER", - ATTR_FINGERPRINT => "FINGERPRINT", - ATTR_PRIORITY => "PRIORITY", - ATTR_USE_CANDIDATE => "USE-CANDIDATE", - ATTR_ICE_CONTROLLED => "ICE-CONTROLLED", - ATTR_ICE_CONTROLLING => "ICE-CONTROLLING", - ATTR_CHANNEL_NUMBER => "CHANNEL-NUMBER", - ATTR_LIFETIME => "LIFETIME", - ATTR_XOR_PEER_ADDRESS => "XOR-PEER-ADDRESS", - ATTR_DATA => "DATA", - ATTR_XOR_RELAYED_ADDRESS => "XOR-RELAYED-ADDRESS", - ATTR_EVEN_PORT => "EVEN-PORT", - ATTR_REQUESTED_TRANSPORT => "REQUESTED-TRANSPORT", - ATTR_DONT_FRAGMENT => "DONT-FRAGMENT", - ATTR_RESERVATION_TOKEN => "RESERVATION-TOKEN", - ATTR_CONNECTION_ID => "CONNECTION-ID", - ATTR_REQUESTED_ADDRESS_FAMILY => "REQUESTED-ADDRESS-FAMILY", - ATTR_MESSAGE_INTEGRITY_SHA256 => "MESSAGE-INTEGRITY-SHA256", - ATTR_PASSWORD_ALGORITHM => "PASSWORD-ALGORITHM", - ATTR_USER_HASH => "USERHASH", - ATTR_PASSWORD_ALGORITHMS => "PASSWORD-ALGORITHMS", - ATTR_ALTERNATE_DOMAIN => "ALTERNATE-DOMAIN", - _ => other.as_str(), - }; - - write!(f, "{s}") - } -} - -impl AttrType { - /// required returns true if type is from comprehension-required range (0x0000-0x7FFF). - pub fn required(&self) -> bool { - self.0 <= 0x7FFF - } - - /// optional returns true if type is from comprehension-optional range (0x8000-0xFFFF). - pub fn optional(&self) -> bool { - self.0 >= 0x8000 - } - - /// value returns uint16 representation of attribute type. - pub fn value(&self) -> u16 { - self.0 - } -} - -/// Attributes from comprehension-required range (0x0000-0x7FFF). -pub const ATTR_MAPPED_ADDRESS: AttrType = AttrType(0x0001); // MAPPED-ADDRESS -pub const ATTR_USERNAME: AttrType = AttrType(0x0006); // USERNAME -pub const ATTR_MESSAGE_INTEGRITY: AttrType = AttrType(0x0008); // MESSAGE-INTEGRITY -pub const ATTR_ERROR_CODE: AttrType = AttrType(0x0009); // ERROR-CODE -pub const ATTR_UNKNOWN_ATTRIBUTES: AttrType = AttrType(0x000A); // UNKNOWN-ATTRIBUTES -pub const ATTR_REALM: AttrType = AttrType(0x0014); // REALM -pub const ATTR_NONCE: AttrType = AttrType(0x0015); // NONCE -pub const ATTR_XORMAPPED_ADDRESS: AttrType = AttrType(0x0020); // XOR-MAPPED-ADDRESS - -/// Attributes from comprehension-optional range (0x8000-0xFFFF). -pub const ATTR_SOFTWARE: AttrType = AttrType(0x8022); // SOFTWARE -pub const ATTR_ALTERNATE_SERVER: AttrType = AttrType(0x8023); // ALTERNATE-SERVER -pub const ATTR_FINGERPRINT: AttrType = AttrType(0x8028); // FINGERPRINT - -/// Attributes from RFC 5245 ICE. -pub const ATTR_PRIORITY: AttrType = AttrType(0x0024); // PRIORITY -pub const ATTR_USE_CANDIDATE: AttrType = AttrType(0x0025); // USE-CANDIDATE -pub const ATTR_ICE_CONTROLLED: AttrType = AttrType(0x8029); // ICE-CONTROLLED -pub const ATTR_ICE_CONTROLLING: AttrType = AttrType(0x802A); // ICE-CONTROLLING - -/// Attributes from RFC 5766 TURN. 
-pub const ATTR_CHANNEL_NUMBER: AttrType = AttrType(0x000C); // CHANNEL-NUMBER -pub const ATTR_LIFETIME: AttrType = AttrType(0x000D); // LIFETIME -pub const ATTR_XOR_PEER_ADDRESS: AttrType = AttrType(0x0012); // XOR-PEER-ADDRESS -pub const ATTR_DATA: AttrType = AttrType(0x0013); // DATA -pub const ATTR_XOR_RELAYED_ADDRESS: AttrType = AttrType(0x0016); // XOR-RELAYED-ADDRESS -pub const ATTR_EVEN_PORT: AttrType = AttrType(0x0018); // EVEN-PORT -pub const ATTR_REQUESTED_TRANSPORT: AttrType = AttrType(0x0019); // REQUESTED-TRANSPORT -pub const ATTR_DONT_FRAGMENT: AttrType = AttrType(0x001A); // DONT-FRAGMENT -pub const ATTR_RESERVATION_TOKEN: AttrType = AttrType(0x0022); // RESERVATION-TOKEN - -/// Attributes from RFC 5780 NAT Behavior Discovery -pub const ATTR_CHANGE_REQUEST: AttrType = AttrType(0x0003); // CHANGE-REQUEST -pub const ATTR_PADDING: AttrType = AttrType(0x0026); // PADDING -pub const ATTR_RESPONSE_PORT: AttrType = AttrType(0x0027); // RESPONSE-PORT -pub const ATTR_CACHE_TIMEOUT: AttrType = AttrType(0x8027); // CACHE-TIMEOUT -pub const ATTR_RESPONSE_ORIGIN: AttrType = AttrType(0x802b); // RESPONSE-ORIGIN -pub const ATTR_OTHER_ADDRESS: AttrType = AttrType(0x802C); // OTHER-ADDRESS - -/// Attributes from RFC 3489, removed by RFC 5389, -/// but still used by RFC5389-implementing software like Vovida.org, reTURNServer, etc. -pub const ATTR_SOURCE_ADDRESS: AttrType = AttrType(0x0004); // SOURCE-ADDRESS -pub const ATTR_CHANGED_ADDRESS: AttrType = AttrType(0x0005); // CHANGED-ADDRESS - -/// Attributes from RFC 6062 TURN Extensions for TCP Allocations. -pub const ATTR_CONNECTION_ID: AttrType = AttrType(0x002a); // CONNECTION-ID - -/// Attributes from RFC 6156 TURN IPv6. -pub const ATTR_REQUESTED_ADDRESS_FAMILY: AttrType = AttrType(0x0017); // REQUESTED-ADDRESS-FAMILY - -/// Attributes from An Origin Attribute for the STUN Protocol. -pub const ATTR_ORIGIN: AttrType = AttrType(0x802F); - -/// Attributes from RFC 8489 STUN. -pub const ATTR_MESSAGE_INTEGRITY_SHA256: AttrType = AttrType(0x001C); // MESSAGE-INTEGRITY-SHA256 -pub const ATTR_PASSWORD_ALGORITHM: AttrType = AttrType(0x001D); // PASSWORD-ALGORITHM -pub const ATTR_USER_HASH: AttrType = AttrType(0x001E); // USER-HASH -pub const ATTR_PASSWORD_ALGORITHMS: AttrType = AttrType(0x8002); // PASSWORD-ALGORITHMS -pub const ATTR_ALTERNATE_DOMAIN: AttrType = AttrType(0x8003); // ALTERNATE-DOMAIN - -/// RawAttribute is a Type-Length-Value (TLV) object that -/// can be added to a STUN message. Attributes are divided into two -/// types: comprehension-required and comprehension-optional. STUN -/// agents can safely ignore comprehension-optional attributes they -/// don't understand, but cannot successfully process a message if it -/// contains comprehension-required attributes that are not -/// understood. -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub struct RawAttribute { - pub typ: AttrType, - pub length: u16, // ignored while encoding - pub value: Vec, -} - -impl fmt::Display for RawAttribute { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}: {:?}", self.typ, self.value) - } -} - -impl Setter for RawAttribute { - /// add_to implements Setter, adding attribute as a.Type with a.Value and ignoring - /// the Length field. 
- fn add_to(&self, m: &mut Message) -> Result<()> { - m.add(self.typ, &self.value); - Ok(()) - } -} - -pub(crate) const PADDING: usize = 4; - -/// STUN aligns attributes on 32-bit boundaries, attributes whose content -/// is not a multiple of 4 bytes are padded with 1, 2, or 3 bytes of -/// padding so that its value contains a multiple of 4 bytes. The -/// padding bits are ignored, and may be any value. -/// -/// https://tools.ietf.org/html/rfc5389#section-15 -pub(crate) fn nearest_padded_value_length(l: usize) -> usize { - let mut n = PADDING * (l / PADDING); - if n < l { - n += PADDING - } - n -} - -/// This method converts uint16 vlue to AttrType. If it finds an old attribute -/// type value, it also translates it to the new value to enable backward -/// compatibility. (See: https://github.com/pion/stun/issues/21) -pub(crate) fn compat_attr_type(val: u16) -> AttrType { - if val == 0x8020 { - // draft-ietf-behave-rfc3489bis-02, MS-TURN - ATTR_XORMAPPED_ADDRESS // new: 0x0020 (from draft-ietf-behave-rfc3489bis-03 on) - } else { - AttrType(val) - } -} diff --git a/stun/src/attributes/attributes_test.rs b/stun/src/attributes/attributes_test.rs deleted file mode 100644 index 3be540f06..000000000 --- a/stun/src/attributes/attributes_test.rs +++ /dev/null @@ -1,86 +0,0 @@ -use super::*; -use crate::textattrs::TextAttribute; - -#[test] -fn test_raw_attribute_add_to() -> Result<()> { - let v = vec![1, 2, 3, 4]; - let mut m = Message::new(); - let ra = Box::new(RawAttribute { - typ: ATTR_DATA, - value: v.clone(), - ..Default::default() - }); - m.build(&[ra])?; - let got_v = m.get(ATTR_DATA)?; - assert_eq!(got_v, v, "value mismatch"); - - Ok(()) -} - -#[test] -fn test_message_get_no_allocs() -> Result<()> { - let mut m = Message::new(); - let a = TextAttribute { - attr: ATTR_SOFTWARE, - text: "c".to_owned(), - }; - a.add_to(&mut m)?; - m.write_header(); - - //"Default" - { - m.get(ATTR_SOFTWARE)?; - } - //"Not found" - { - let result = m.get(ATTR_ORIGIN); - assert!(result.is_err(), "should error"); - } - - Ok(()) -} - -#[test] -fn test_padding() -> Result<()> { - let tt = vec![ - (4, 4), // 0 - (2, 4), // 1 - (5, 8), // 2 - (8, 8), // 3 - (11, 12), // 4 - (1, 4), // 5 - (3, 4), // 6 - (6, 8), // 7 - (7, 8), // 8 - (0, 0), // 9 - (40, 40), // 10 - ]; - - for (i, o) in tt { - let got = nearest_padded_value_length(i); - assert_eq!(got, o, "padded({i}) {got} (got) != {o} (expected)",); - } - - Ok(()) -} - -#[test] -fn test_attr_type_range() -> Result<()> { - let tests = vec![ - ATTR_PRIORITY, - ATTR_ERROR_CODE, - ATTR_USE_CANDIDATE, - ATTR_EVEN_PORT, - ATTR_REQUESTED_ADDRESS_FAMILY, - ]; - for a in tests { - assert!(!a.optional() && a.required(), "should be required"); - } - - let tests = vec![ATTR_SOFTWARE, ATTR_ICE_CONTROLLED, ATTR_ORIGIN]; - for a in tests { - assert!(!a.required() && a.optional(), "should be optional"); - } - - Ok(()) -} diff --git a/stun/src/checks.rs b/stun/src/checks.rs deleted file mode 100644 index 9f4c134be..000000000 --- a/stun/src/checks.rs +++ /dev/null @@ -1,48 +0,0 @@ -use subtle::ConstantTimeEq; - -use crate::attributes::*; -use crate::error::*; - -// check_size returns ErrAttrSizeInvalid if got is not equal to expected. -pub fn check_size(_at: AttrType, got: usize, expected: usize) -> Result<()> { - if got == expected { - Ok(()) - } else { - Err(Error::ErrAttributeSizeInvalid) - } -} - -// is_attr_size_invalid returns true if error means that attribute size is invalid. 
-pub fn is_attr_size_invalid(err: &Error) -> bool { - Error::ErrAttributeSizeInvalid == *err -} - -pub(crate) fn check_hmac(got: &[u8], expected: &[u8]) -> Result<()> { - if got.ct_eq(expected).unwrap_u8() != 1 { - Err(Error::ErrIntegrityMismatch) - } else { - Ok(()) - } -} - -pub(crate) fn check_fingerprint(got: u32, expected: u32) -> Result<()> { - if got == expected { - Ok(()) - } else { - Err(Error::ErrFingerprintMismatch) - } -} - -// check_overflow returns ErrAttributeSizeOverflow if got is bigger that max. -pub fn check_overflow(_at: AttrType, got: usize, max: usize) -> Result<()> { - if got <= max { - Ok(()) - } else { - Err(Error::ErrAttributeSizeOverflow) - } -} - -// is_attr_size_overflow returns true if error means that attribute size is too big. -pub fn is_attr_size_overflow(err: &Error) -> bool { - Error::ErrAttributeSizeOverflow == *err -} diff --git a/stun/src/client.rs b/stun/src/client.rs deleted file mode 100644 index 13d8bd627..000000000 --- a/stun/src/client.rs +++ /dev/null @@ -1,473 +0,0 @@ -#[cfg(test)] -mod client_test; - -use std::collections::HashMap; -use std::io::BufReader; -use std::marker::{Send, Sync}; -use std::ops::Add; -use std::sync::Arc; - -use tokio::sync::mpsc; -use tokio::time::{self, Duration, Instant}; -use util::Conn; - -use crate::agent::*; -use crate::error::*; -use crate::message::*; - -const DEFAULT_TIMEOUT_RATE: Duration = Duration::from_millis(5); -const DEFAULT_RTO: Duration = Duration::from_millis(300); -const DEFAULT_MAX_ATTEMPTS: u32 = 7; -const DEFAULT_MAX_BUFFER_SIZE: usize = 8; - -/// Collector calls function f with constant rate. -/// -/// The simple Collector is ticker which calls function on each tick. -pub trait Collector { - fn start( - &mut self, - rate: Duration, - client_agent_tx: Arc>, - ) -> Result<()>; - fn close(&mut self) -> Result<()>; -} - -#[derive(Default)] -struct TickerCollector { - close_tx: Option>, -} - -impl Collector for TickerCollector { - fn start( - &mut self, - rate: Duration, - client_agent_tx: Arc>, - ) -> Result<()> { - let (close_tx, mut close_rx) = mpsc::channel(1); - self.close_tx = Some(close_tx); - - tokio::spawn(async move { - let mut interval = time::interval(rate); - - loop { - tokio::select! { - _ = close_rx.recv() => break, - _ = interval.tick() => { - if client_agent_tx.send(ClientAgent::Collect(Instant::now())).await.is_err() { - break; - } - } - } - } - }); - - Ok(()) - } - - fn close(&mut self) -> Result<()> { - if self.close_tx.is_none() { - return Err(Error::ErrCollectorClosed); - } - self.close_tx.take(); - Ok(()) - } -} - -/// ClientTransaction represents transaction in progress. -/// If transaction is succeed or failed, f will be called -/// provided by event. -/// Concurrent access is invalid. 
-#[derive(Debug, Clone)] -pub struct ClientTransaction { - id: TransactionId, - attempt: u32, - calls: u32, - handler: Handler, - start: Instant, - rto: Duration, - raw: Vec, -} - -impl ClientTransaction { - pub(crate) fn handle(&mut self, e: Event) -> Result<()> { - self.calls += 1; - if self.calls == 1 { - if let Some(handler) = &self.handler { - handler.send(e)?; - } - } - Ok(()) - } - - pub(crate) fn next_timeout(&self, now: Instant) -> Instant { - now.add((self.attempt + 1) * self.rto) - } -} - -struct ClientSettings { - buffer_size: usize, - rto: Duration, - rto_rate: Duration, - max_attempts: u32, - closed: bool, - //handler: Handler, - collector: Option>, - c: Option>, -} - -impl Default for ClientSettings { - fn default() -> Self { - ClientSettings { - buffer_size: DEFAULT_MAX_BUFFER_SIZE, - rto: DEFAULT_RTO, - rto_rate: DEFAULT_TIMEOUT_RATE, - max_attempts: DEFAULT_MAX_ATTEMPTS, - closed: false, - //handler: None, - collector: None, - c: None, - } - } -} - -#[derive(Default)] -pub struct ClientBuilder { - settings: ClientSettings, -} - -impl ClientBuilder { - // WithHandler sets client handler which is called if Agent emits the Event - // with TransactionID that is not currently registered by Client. - // Useful for handling Data indications from TURN server. - //pub fn with_handler(mut self, handler: Handler) -> Self { - // self.settings.handler = handler; - // self - //} - - /// with_rto sets client RTO as defined in STUN RFC. - pub fn with_rto(mut self, rto: Duration) -> Self { - self.settings.rto = rto; - self - } - - /// with_timeout_rate sets RTO timer minimum resolution. - pub fn with_timeout_rate(mut self, d: Duration) -> Self { - self.settings.rto_rate = d; - self - } - - /// with_buffer_size sets buffer size. - pub fn with_buffer_size(mut self, buffer_size: usize) -> Self { - self.settings.buffer_size = buffer_size; - self - } - - /// with_collector rests client timeout collector, the implementation - /// of ticker which calls function on each tick. - pub fn with_collector(mut self, coll: Box) -> Self { - self.settings.collector = Some(coll); - self - } - - /// with_conn sets transport connection - pub fn with_conn(mut self, conn: Arc) -> Self { - self.settings.c = Some(conn); - self - } - - /// with_no_retransmit disables retransmissions and sets RTO to - /// DEFAULT_MAX_ATTEMPTS * DEFAULT_RTO which will be effectively time out - /// if not set. - /// Useful for TCP connections where transport handles RTO. - pub fn with_no_retransmit(mut self) -> Self { - self.settings.max_attempts = 0; - if self.settings.rto == Duration::from_secs(0) { - self.settings.rto = DEFAULT_MAX_ATTEMPTS * DEFAULT_RTO; - } - self - } - - pub fn new() -> Self { - ClientBuilder { - settings: ClientSettings::default(), - } - } - - pub fn build(self) -> Result { - if self.settings.c.is_none() { - return Err(Error::ErrNoConnection); - } - - let client = Client { - settings: self.settings, - ..Default::default() - } - .run()?; - - Ok(client) - } -} - -/// Client simulates "connection" to STUN server. -#[derive(Default)] -pub struct Client { - settings: ClientSettings, - close_tx: Option>, - client_agent_tx: Option>>, - handler_tx: Option>>, -} - -impl Client { - async fn read_until_closed( - mut close_rx: mpsc::Receiver<()>, - c: Arc, - client_agent_tx: Arc>, - ) { - let mut msg = Message::new(); - let mut buf = vec![0; 1024]; - - loop { - tokio::select! 
{ - _ = close_rx.recv() => return, - res = c.recv(&mut buf) => { - if let Ok(n) = res { - let mut reader = BufReader::new(&buf[..n]); - let result = msg.read_from(&mut reader); - if result.is_err() { - continue; - } - - if client_agent_tx.send(ClientAgent::Process(msg.clone())).await.is_err(){ - return; - } - } - } - } - } - } - - fn insert(&mut self, ct: ClientTransaction) -> Result<()> { - if self.settings.closed { - return Err(Error::ErrClientClosed); - } - - if let Some(handler_tx) = &mut self.handler_tx { - handler_tx.send(Event { - event_type: EventType::Insert(ct), - ..Default::default() - })?; - } - - Ok(()) - } - - fn remove(&mut self, id: TransactionId) -> Result<()> { - if self.settings.closed { - return Err(Error::ErrClientClosed); - } - - if let Some(handler_tx) = &mut self.handler_tx { - handler_tx.send(Event { - event_type: EventType::Remove(id), - ..Default::default() - })?; - } - - Ok(()) - } - - fn start( - conn: Option>, - mut handler_rx: mpsc::UnboundedReceiver, - client_agent_tx: Arc>, - mut t: HashMap, - max_attempts: u32, - ) { - tokio::spawn(async move { - while let Some(event) = handler_rx.recv().await { - match event.event_type { - EventType::Close => { - break; - } - EventType::Insert(ct) => { - if t.contains_key(&ct.id) { - continue; - } - t.insert(ct.id, ct); - } - EventType::Remove(id) => { - t.remove(&id); - } - EventType::Callback(id) => { - let mut ct = if t.contains_key(&id) { - t.remove(&id).unwrap() - } else { - /*if c.handler != nil && !errors.Is(e.Error, ErrTransactionStopped) { - c.handler(e) - }*/ - continue; - }; - - if ct.attempt >= max_attempts || event.event_body.is_ok() { - if let Some(handler) = ct.handler { - let _ = handler.send(event); - } - continue; - } - - // Doing re-transmission. - ct.attempt += 1; - - let raw = ct.raw.clone(); - let timeout = ct.next_timeout(Instant::now()); - let id = ct.id; - - // Starting client transaction. - t.insert(ct.id, ct); - - // Starting agent transaction. - if client_agent_tx - .send(ClientAgent::Start(id, timeout)) - .await - .is_err() - { - let ct = t.remove(&id).unwrap(); - if let Some(handler) = ct.handler { - let _ = handler.send(event); - } - continue; - } - - // Writing message to connection again. - if let Some(c) = &conn { - if c.send(&raw).await.is_err() { - let _ = client_agent_tx.send(ClientAgent::Stop(id)).await; - - let ct = t.remove(&id).unwrap(); - if let Some(handler) = ct.handler { - let _ = handler.send(event); - } - continue; - } - } - } - }; - } - }); - } - - /// close stops internal connection and agent, returning CloseErr on error. 
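    /// A second call is an error: the first call sets the closed flag, and any
    /// further call returns ErrClientClosed.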
- pub async fn close(&mut self) -> Result<()> { - if self.settings.closed { - return Err(Error::ErrClientClosed); - } - - self.settings.closed = true; - - if let Some(collector) = &mut self.settings.collector { - let _ = collector.close(); - } - self.settings.collector.take(); - - self.close_tx.take(); //drop close channel - if let Some(client_agent_tx) = &mut self.client_agent_tx { - let _ = client_agent_tx.send(ClientAgent::Close).await; - } - self.client_agent_tx.take(); - - if let Some(c) = self.settings.c.take() { - c.close().await?; - } - - Ok(()) - } - - fn run(mut self) -> Result { - let (close_tx, close_rx) = mpsc::channel(1); - let (client_agent_tx, client_agent_rx) = mpsc::channel(self.settings.buffer_size); - let (handler_tx, handler_rx) = mpsc::unbounded_channel(); - let t: HashMap = HashMap::new(); - - let client_agent_tx = Arc::new(client_agent_tx); - let handler_tx = Arc::new(handler_tx); - self.client_agent_tx = Some(Arc::clone(&client_agent_tx)); - self.handler_tx = Some(Arc::clone(&handler_tx)); - self.close_tx = Some(close_tx); - - let conn = if let Some(conn) = &self.settings.c { - Arc::clone(conn) - } else { - return Err(Error::ErrNoConnection); - }; - - Client::start( - self.settings.c.clone(), - handler_rx, - Arc::clone(&client_agent_tx), - t, - self.settings.max_attempts, - ); - - let agent = Agent::new(Some(handler_tx)); - tokio::spawn(async move { Agent::run(agent, client_agent_rx).await }); - - if self.settings.collector.is_none() { - self.settings.collector = Some(Box::::default()); - } - if let Some(collector) = &mut self.settings.collector { - collector.start(self.settings.rto_rate, Arc::clone(&client_agent_tx))?; - } - - let conn_rx = Arc::clone(&conn); - tokio::spawn( - async move { Client::read_until_closed(close_rx, conn_rx, client_agent_tx).await }, - ); - - Ok(self) - } - - pub async fn send(&mut self, m: &Message, handler: Handler) -> Result<()> { - if self.settings.closed { - return Err(Error::ErrClientClosed); - } - - let has_handler = handler.is_some(); - - if handler.is_some() { - let t = ClientTransaction { - id: m.transaction_id, - attempt: 0, - calls: 0, - handler, - start: Instant::now(), - rto: self.settings.rto, - raw: m.raw.clone(), - }; - let d = t.next_timeout(t.start); - self.insert(t)?; - - if let Some(client_agent_tx) = &mut self.client_agent_tx { - client_agent_tx - .send(ClientAgent::Start(m.transaction_id, d)) - .await?; - } - } - - if let Some(c) = &self.settings.c { - let result = c.send(&m.raw).await; - if result.is_err() && has_handler { - self.remove(m.transaction_id)?; - - if let Some(client_agent_tx) = &mut self.client_agent_tx { - client_agent_tx - .send(ClientAgent::Stop(m.transaction_id)) - .await?; - } - } else if let Err(err) = result { - return Err(Error::Other(err.to_string())); - } - } - - Ok(()) - } -} diff --git a/stun/src/client/client_test.rs b/stun/src/client/client_test.rs deleted file mode 100644 index c7bad84ef..000000000 --- a/stun/src/client/client_test.rs +++ /dev/null @@ -1,12 +0,0 @@ -use super::*; - -#[test] -fn ensure_client_settings_is_send() { - let client = ClientSettings::default(); - - ensure_send(client); -} - -fn ensure_send(_: T) {} - -//TODO: add more client tests diff --git a/stun/src/error.rs b/stun/src/error.rs deleted file mode 100644 index 083b3adb0..000000000 --- a/stun/src/error.rs +++ /dev/null @@ -1,98 +0,0 @@ -use std::io; -use std::string::FromUtf8Error; - -use thiserror::Error; -use tokio::sync::mpsc::error::SendError as MpscSendError; - -pub type Result = std::result::Result; - 
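// Illustrative sketch (not part of the deleted sources): the retransmission
// timing implied by ClientTransaction::next_timeout above. Each retry is
// scheduled (attempt + 1) * RTO after the event that triggered it, so with
// the defaults (RTO = 300 ms, up to 7 attempts) the per-attempt waits grow
// linearly: 300 ms, 600 ms, 900 ms, and so on.
use std::time::Duration;

fn main() {
    let rto = Duration::from_millis(300); // DEFAULT_RTO
    let max_attempts = 7u32; // DEFAULT_MAX_ATTEMPTS
    for attempt in 0..max_attempts {
        let wait = (attempt + 1) * rto;
        println!("attempt {attempt}: next timeout scheduled {wait:?} later");
    }
}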
-#[derive(Debug, Error, PartialEq)] -#[non_exhaustive] -pub enum Error { - #[error("attribute not found")] - ErrAttributeNotFound, - #[error("transaction is stopped")] - ErrTransactionStopped, - #[error("transaction not exists")] - ErrTransactionNotExists, - #[error("transaction exists with same id")] - ErrTransactionExists, - #[error("agent is closed")] - ErrAgentClosed, - #[error("transaction is timed out")] - ErrTransactionTimeOut, - #[error("no default reason for ErrorCode")] - ErrNoDefaultReason, - #[error("unexpected EOF")] - ErrUnexpectedEof, - #[error("attribute size is invalid")] - ErrAttributeSizeInvalid, - #[error("attribute size overflow")] - ErrAttributeSizeOverflow, - #[error("attempt to decode to nil message")] - ErrDecodeToNil, - #[error("unexpected EOF: not enough bytes to read header")] - ErrUnexpectedHeaderEof, - #[error("integrity check failed")] - ErrIntegrityMismatch, - #[error("fingerprint check failed")] - ErrFingerprintMismatch, - #[error("FINGERPRINT before MESSAGE-INTEGRITY attribute")] - ErrFingerprintBeforeIntegrity, - #[error("bad UNKNOWN-ATTRIBUTES size")] - ErrBadUnknownAttrsSize, - #[error("invalid length of IP value")] - ErrBadIpLength, - #[error("no connection provided")] - ErrNoConnection, - #[error("client is closed")] - ErrClientClosed, - #[error("no agent is set")] - ErrNoAgent, - #[error("collector is closed")] - ErrCollectorClosed, - #[error("unsupported network")] - ErrUnsupportedNetwork, - #[error("invalid url")] - ErrInvalidUrl, - #[error("unknown scheme type")] - ErrSchemeType, - #[error("invalid hostname")] - ErrHost, - #[error("{0}")] - Other(String), - #[error("url parse: {0}")] - Url(#[from] url::ParseError), - #[error("utf8: {0}")] - Utf8(#[from] FromUtf8Error), - #[error("{0}")] - Io(#[source] IoError), - #[error("mpsc send: {0}")] - MpscSend(String), - #[error("{0}")] - Util(#[from] util::Error), -} - -#[derive(Debug, Error)] -#[error("io error: {0}")] -pub struct IoError(#[from] pub io::Error); - -// Workaround for wanting PartialEq for io::Error. -impl PartialEq for IoError { - fn eq(&self, other: &Self) -> bool { - self.0.kind() == other.0.kind() - } -} - -impl From for Error { - fn from(e: io::Error) -> Self { - Error::Io(IoError(e)) - } -} - -// Because Tokio SendError is parameterized, we sadly lose the backtrace. -impl From> for Error { - fn from(e: MpscSendError) -> Self { - Error::MpscSend(e.to_string()) - } -} diff --git a/stun/src/error_code.rs b/stun/src/error_code.rs deleted file mode 100644 index 8142a54cc..000000000 --- a/stun/src/error_code.rs +++ /dev/null @@ -1,162 +0,0 @@ -use std::collections::HashMap; -use std::fmt; - -use crate::attributes::*; -use crate::checks::*; -use crate::error::*; -use crate::message::*; - -// ErrorCodeAttribute represents ERROR-CODE attribute. -// -// RFC 5389 Section 15.6 -#[derive(Default)] -pub struct ErrorCodeAttribute { - pub code: ErrorCode, - pub reason: Vec, -} - -impl fmt::Display for ErrorCodeAttribute { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let reason = match String::from_utf8(self.reason.clone()) { - Ok(reason) => reason, - Err(_) => return Err(fmt::Error {}), - }; - - write!(f, "{}: {}", self.code.0, reason) - } -} - -// constants for ERROR-CODE encoding. -const ERROR_CODE_CLASS_BYTE: usize = 2; -const ERROR_CODE_NUMBER_BYTE: usize = 3; -const ERROR_CODE_REASON_START: usize = 4; -const ERROR_CODE_REASON_MAX_B: usize = 763; -const ERROR_CODE_MODULO: u16 = 100; - -impl Setter for ErrorCodeAttribute { - // add_to adds ERROR-CODE to m. 
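    // For example, CODE_STALE_NONCE (438) is encoded with class = 438 / 100 = 4
    // and number = 438 % 100 = 38, so the attribute value begins with the four
    // bytes [0x00, 0x00, 0x04, 0x26], followed by the UTF-8 reason phrase.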
- fn add_to(&self, m: &mut Message) -> Result<()> { - check_overflow( - ATTR_ERROR_CODE, - self.reason.len() + ERROR_CODE_REASON_START, - ERROR_CODE_REASON_MAX_B + ERROR_CODE_REASON_START, - )?; - - let mut value: Vec = Vec::with_capacity(ERROR_CODE_REASON_MAX_B); - - let number = (self.code.0 % ERROR_CODE_MODULO) as u8; // error code modulo 100 - let class = (self.code.0 / ERROR_CODE_MODULO) as u8; // hundred digit - value.extend_from_slice(&[0, 0]); - value.push(class); // [ERROR_CODE_CLASS_BYTE] - value.push(number); //[ERROR_CODE_NUMBER_BYTE] = - value.extend_from_slice(&self.reason); //[ERROR_CODE_REASON_START:] - - m.add(ATTR_ERROR_CODE, &value); - - Ok(()) - } -} - -impl Getter for ErrorCodeAttribute { - // GetFrom decodes ERROR-CODE from m. Reason is valid until m.Raw is valid. - fn get_from(&mut self, m: &Message) -> Result<()> { - let v = m.get(ATTR_ERROR_CODE)?; - - if v.len() < ERROR_CODE_REASON_START { - return Err(Error::ErrUnexpectedEof); - } - - let class = v[ERROR_CODE_CLASS_BYTE] as u16; - let number = v[ERROR_CODE_NUMBER_BYTE] as u16; - let code = class * ERROR_CODE_MODULO + number; - self.code = ErrorCode(code); - self.reason = v[ERROR_CODE_REASON_START..].to_vec(); - - Ok(()) - } -} - -// ErrorCode is code for ERROR-CODE attribute. -#[derive(PartialEq, Eq, Hash, Copy, Clone, Default)] -pub struct ErrorCode(pub u16); - -impl Setter for ErrorCode { - // add_to adds ERROR-CODE with default reason to m. If there - // is no default reason, returns ErrNoDefaultReason. - fn add_to(&self, m: &mut Message) -> Result<()> { - if let Some(reason) = ERROR_REASONS.get(self) { - let a = ErrorCodeAttribute { - code: *self, - reason: reason.clone(), - }; - a.add_to(m) - } else { - Err(Error::ErrNoDefaultReason) - } - } -} - -// Possible error codes. -pub const CODE_TRY_ALTERNATE: ErrorCode = ErrorCode(300); -pub const CODE_BAD_REQUEST: ErrorCode = ErrorCode(400); -pub const CODE_UNAUTHORIZED: ErrorCode = ErrorCode(401); -pub const CODE_UNKNOWN_ATTRIBUTE: ErrorCode = ErrorCode(420); -pub const CODE_STALE_NONCE: ErrorCode = ErrorCode(438); -pub const CODE_ROLE_CONFLICT: ErrorCode = ErrorCode(487); -pub const CODE_SERVER_ERROR: ErrorCode = ErrorCode(500); - -// DEPRECATED constants. -// DEPRECATED, use CODE_UNAUTHORIZED. -pub const CODE_UNAUTHORISED: ErrorCode = CODE_UNAUTHORIZED; - -// Error codes from RFC 5766. -// -// RFC 5766 Section 15 -pub const CODE_FORBIDDEN: ErrorCode = ErrorCode(403); // Forbidden -pub const CODE_ALLOC_MISMATCH: ErrorCode = ErrorCode(437); // Allocation Mismatch -pub const CODE_WRONG_CREDENTIALS: ErrorCode = ErrorCode(441); // Wrong Credentials -pub const CODE_UNSUPPORTED_TRANS_PROTO: ErrorCode = ErrorCode(442); // Unsupported Transport Protocol -pub const CODE_ALLOC_QUOTA_REACHED: ErrorCode = ErrorCode(486); // Allocation Quota Reached -pub const CODE_INSUFFICIENT_CAPACITY: ErrorCode = ErrorCode(508); // Insufficient Capacity - -// Error codes from RFC 6062. -// -// RFC 6062 Section 6.3 -pub const CODE_CONN_ALREADY_EXISTS: ErrorCode = ErrorCode(446); -pub const CODE_CONN_TIMEOUT_OR_FAILURE: ErrorCode = ErrorCode(447); - -// Error codes from RFC 6156. -// -// RFC 6156 Section 10.2 -pub const CODE_ADDR_FAMILY_NOT_SUPPORTED: ErrorCode = ErrorCode(440); // Address Family not Supported -pub const CODE_PEER_ADDR_FAMILY_MISMATCH: ErrorCode = ErrorCode(443); // Peer Address Family Mismatch - -lazy_static! 
{ - pub static ref ERROR_REASONS:HashMap> = - [ - (CODE_TRY_ALTERNATE, b"Try Alternate".to_vec()), - (CODE_BAD_REQUEST, b"Bad Request".to_vec()), - (CODE_UNAUTHORIZED, b"Unauthorized".to_vec()), - (CODE_UNKNOWN_ATTRIBUTE, b"Unknown Attribute".to_vec()), - (CODE_STALE_NONCE, b"Stale Nonce".to_vec()), - (CODE_SERVER_ERROR, b"Server Error".to_vec()), - (CODE_ROLE_CONFLICT, b"Role Conflict".to_vec()), - - // RFC 5766. - (CODE_FORBIDDEN, b"Forbidden".to_vec()), - (CODE_ALLOC_MISMATCH, b"Allocation Mismatch".to_vec()), - (CODE_WRONG_CREDENTIALS, b"Wrong Credentials".to_vec()), - (CODE_UNSUPPORTED_TRANS_PROTO, b"Unsupported Transport Protocol".to_vec()), - (CODE_ALLOC_QUOTA_REACHED, b"Allocation Quota Reached".to_vec()), - (CODE_INSUFFICIENT_CAPACITY, b"Insufficient Capacity".to_vec()), - - // RFC 6062. - (CODE_CONN_ALREADY_EXISTS, b"Connection Already Exists".to_vec()), - (CODE_CONN_TIMEOUT_OR_FAILURE, b"Connection Timeout or Failure".to_vec()), - - // RFC 6156. - (CODE_ADDR_FAMILY_NOT_SUPPORTED, b"Address Family not Supported".to_vec()), - (CODE_PEER_ADDR_FAMILY_MISMATCH, b"Peer Address Family Mismatch".to_vec()), - ].iter().cloned().collect(); - -} diff --git a/stun/src/fingerprint.rs b/stun/src/fingerprint.rs deleted file mode 100644 index 648c288ec..000000000 --- a/stun/src/fingerprint.rs +++ /dev/null @@ -1,64 +0,0 @@ -#[cfg(test)] -mod fingerprint_test; - -use crc::{Crc, CRC_32_ISO_HDLC}; - -use crate::attributes::ATTR_FINGERPRINT; -use crate::checks::*; -use crate::error::*; -use crate::message::*; - -// FingerprintAttr represents FINGERPRINT attribute. -// -// RFC 5389 Section 15.5 -pub struct FingerprintAttr; - -// FINGERPRINT is shorthand for FingerprintAttr. -// -// Example: -// -// m := New() -// FINGERPRINT.add_to(m) -pub const FINGERPRINT: FingerprintAttr = FingerprintAttr {}; - -pub const FINGERPRINT_XOR_VALUE: u32 = 0x5354554e; -pub const FINGERPRINT_SIZE: usize = 4; // 32 bit - -// FingerprintValue returns CRC-32 of b XOR-ed by 0x5354554e. -// -// The value of the attribute is computed as the CRC-32 of the STUN message -// up to (but excluding) the FINGERPRINT attribute itself, XOR'ed with -// the 32-bit value 0x5354554e (the XOR helps in cases where an -// application packet is also using CRC-32 in it). -pub fn fingerprint_value(b: &[u8]) -> u32 { - let checksum = Crc::::new(&CRC_32_ISO_HDLC).checksum(b); - checksum ^ FINGERPRINT_XOR_VALUE // XOR -} - -impl Setter for FingerprintAttr { - // add_to adds fingerprint to message. - fn add_to(&self, m: &mut Message) -> Result<()> { - let l = m.length; - // length in header should include size of fingerprint attribute - m.length += (FINGERPRINT_SIZE + ATTRIBUTE_HEADER_SIZE) as u32; // increasing length - m.write_length(); // writing Length to Raw - let val = fingerprint_value(&m.raw); - let b = val.to_be_bytes(); - m.length = l; - m.add(ATTR_FINGERPRINT, &b); - Ok(()) - } -} - -impl FingerprintAttr { - // Check reads fingerprint value from m and checks it, returning error if any. - // Can return *AttrLengthErr, ErrAttributeNotFound, and *CRCMismatch. 
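    // The XOR constant 0x5354_554e is simply "STUN" in ASCII; since the
    // CRC-32/ISO-HDLC of an empty input is 0, fingerprint_value(&[]) returns
    // exactly that constant.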
- pub fn check(&self, m: &Message) -> Result<()> { - let b = m.get(ATTR_FINGERPRINT)?; - check_size(ATTR_FINGERPRINT, b.len(), FINGERPRINT_SIZE)?; - let val = u32::from_be_bytes([b[0], b[1], b[2], b[3]]); - let attr_start = m.raw.len() - (FINGERPRINT_SIZE + ATTRIBUTE_HEADER_SIZE); - let expected = fingerprint_value(&m.raw[..attr_start]); - check_fingerprint(val, expected) - } -} diff --git a/stun/src/fingerprint/fingerprint_test.rs b/stun/src/fingerprint/fingerprint_test.rs deleted file mode 100644 index 1ac589d54..000000000 --- a/stun/src/fingerprint/fingerprint_test.rs +++ /dev/null @@ -1,73 +0,0 @@ -use super::*; -use crate::attributes::ATTR_SOFTWARE; -use crate::textattrs::TextAttribute; - -#[test] -fn fingerprint_uses_crc_32_iso_hdlc() -> Result<()> { - let mut m = Message::new(); - - let a = TextAttribute { - attr: ATTR_SOFTWARE, - text: "software".to_owned(), - }; - a.add_to(&mut m)?; - m.write_header(); - - FINGERPRINT.add_to(&mut m)?; - m.write_header(); - - assert_eq!(&m.raw[0..m.raw.len()-8], b"\x00\x00\x00\x14\x21\x12\xA4\x42\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80\x22\x00\x08\x73\x6F\x66\x74\x77\x61\x72\x65"); - - assert_eq!(m.raw[m.raw.len() - 4..], [0xe4, 0x4c, 0x33, 0xd9]); - - Ok(()) -} - -#[test] -fn test_fingerprint_check() -> Result<()> { - let mut m = Message::new(); - let a = TextAttribute { - attr: ATTR_SOFTWARE, - text: "software".to_owned(), - }; - a.add_to(&mut m)?; - m.write_header(); - - FINGERPRINT.add_to(&mut m)?; - m.write_header(); - FINGERPRINT.check(&m)?; - m.raw[3] += 1; - - let result = FINGERPRINT.check(&m); - assert!(result.is_err(), "should error"); - - Ok(()) -} - -#[test] -fn test_fingerprint_check_bad() -> Result<()> { - let mut m = Message::new(); - let a = TextAttribute { - attr: ATTR_SOFTWARE, - text: "software".to_owned(), - }; - a.add_to(&mut m)?; - m.write_header(); - - let result = FINGERPRINT.check(&m); - assert!(result.is_err(), "should error"); - - m.add(ATTR_FINGERPRINT, &[1, 2, 3]); - - let result = FINGERPRINT.check(&m); - if let Err(err) = result { - assert!( - is_attr_size_invalid(&err), - "IsAttrSizeInvalid should be true" - ); - } else { - panic!("Expected error, but got ok"); - } - - Ok(()) -} diff --git a/stun/src/integrity.rs b/stun/src/integrity.rs deleted file mode 100644 index cd692da18..000000000 --- a/stun/src/integrity.rs +++ /dev/null @@ -1,118 +0,0 @@ -#[cfg(test)] -mod integrity_test; - -use std::fmt; - -use md5::{Digest, Md5}; -use ring::hmac; - -use crate::attributes::*; -use crate::checks::*; -use crate::error::*; -use crate::message::*; - -// separator for credentials. -pub(crate) const CREDENTIALS_SEP: &str = ":"; - -// MessageIntegrity represents MESSAGE-INTEGRITY attribute. -// -// add_to and Check methods are using zero-allocation version of hmac, see -// newHMAC function and internal/hmac/pool.go. -// -// RFC 5389 Section 15.4 -#[derive(Default, Clone)] -pub struct MessageIntegrity(pub Vec); - -fn new_hmac(key: &[u8], message: &[u8]) -> Vec { - let mac = hmac::Key::new(hmac::HMAC_SHA1_FOR_LEGACY_USE_ONLY, key); - hmac::sign(&mac, message).as_ref().to_vec() -} - -impl fmt::Display for MessageIntegrity { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "KEY: 0x{:x?}", self.0) - } -} - -impl Setter for MessageIntegrity { - // add_to adds MESSAGE-INTEGRITY attribute to message. - // - // CPU costly, see BenchmarkMessageIntegrity_AddTo. 
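    // add_to refuses to run if a FINGERPRINT attribute is already present,
    // because FINGERPRINT must be the last attribute of a message (RFC 5389
    // Section 15.5) and therefore has to be appended after MESSAGE-INTEGRITY.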
- fn add_to(&self, m: &mut Message) -> Result<()> { - for a in &m.attributes.0 { - // Message should not contain FINGERPRINT attribute - // before MESSAGE-INTEGRITY. - if a.typ == ATTR_FINGERPRINT { - return Err(Error::ErrFingerprintBeforeIntegrity); - } - } - // The text used as input to HMAC is the STUN message, - // including the header, up to and including the attribute preceding the - // MESSAGE-INTEGRITY attribute. - let length = m.length; - // Adjusting m.Length to contain MESSAGE-INTEGRITY TLV. - m.length += (MESSAGE_INTEGRITY_SIZE + ATTRIBUTE_HEADER_SIZE) as u32; - m.write_length(); // writing length to m.Raw - let v = new_hmac(&self.0, &m.raw); // calculating HMAC for adjusted m.Raw - m.length = length; // changing m.Length back - - m.add(ATTR_MESSAGE_INTEGRITY, &v); - - Ok(()) - } -} - -pub(crate) const MESSAGE_INTEGRITY_SIZE: usize = 20; - -impl MessageIntegrity { - // new_long_term_integrity returns new MessageIntegrity with key for long-term - // credentials. Password, username, and realm must be SASL-prepared. - pub fn new_long_term_integrity(username: String, realm: String, password: String) -> Self { - let s = [username, realm, password].join(CREDENTIALS_SEP); - - let mut h = Md5::new(); - h.update(s.as_bytes()); - - MessageIntegrity(h.finalize().as_slice().to_vec()) - } - - // new_short_term_integrity returns new MessageIntegrity with key for short-term - // credentials. Password must be SASL-prepared. - pub fn new_short_term_integrity(password: String) -> Self { - MessageIntegrity(password.as_bytes().to_vec()) - } - - // Check checks MESSAGE-INTEGRITY attribute. - // - // CPU costly, see BenchmarkMessageIntegrity_Check. - pub fn check(&self, m: &mut Message) -> Result<()> { - let v = m.get(ATTR_MESSAGE_INTEGRITY)?; - - // Adjusting length in header to match m.Raw that was - // used when computing HMAC. - - let length = m.length as usize; - let mut after_integrity = false; - let mut size_reduced = 0; - - for a in &m.attributes.0 { - if after_integrity { - size_reduced += nearest_padded_value_length(a.length as usize); - size_reduced += ATTRIBUTE_HEADER_SIZE; - } - if a.typ == ATTR_MESSAGE_INTEGRITY { - after_integrity = true; - } - } - m.length -= size_reduced as u32; - m.write_length(); - // start_of_hmac should be first byte of integrity attribute. 
- let start_of_hmac = MESSAGE_HEADER_SIZE + m.length as usize - - (ATTRIBUTE_HEADER_SIZE + MESSAGE_INTEGRITY_SIZE); - let b = &m.raw[..start_of_hmac]; // data before integrity attribute - let expected = new_hmac(&self.0, b); - m.length = length as u32; - m.write_length(); // writing length back - check_hmac(&v, &expected) - } -} diff --git a/stun/src/integrity/integrity_test.rs b/stun/src/integrity/integrity_test.rs deleted file mode 100644 index e085cfa26..000000000 --- a/stun/src/integrity/integrity_test.rs +++ /dev/null @@ -1,93 +0,0 @@ -use super::*; -use crate::agent::TransactionId; -use crate::attributes::ATTR_SOFTWARE; -use crate::fingerprint::FINGERPRINT; -use crate::textattrs::TextAttribute; - -#[test] -fn test_message_integrity_add_to_simple() -> Result<()> { - let i = MessageIntegrity::new_long_term_integrity( - "user".to_owned(), - "realm".to_owned(), - "pass".to_owned(), - ); - let expected = vec![ - 0x84, 0x93, 0xfb, 0xc5, 0x3b, 0xa5, 0x82, 0xfb, 0x4c, 0x04, 0x4c, 0x45, 0x6b, 0xdc, 0x40, - 0xeb, - ]; - assert_eq!(i.0, expected, "{}", Error::ErrIntegrityMismatch); - - //"Check" - { - let mut m = Message::new(); - m.write_header(); - i.add_to(&mut m)?; - let a = TextAttribute { - attr: ATTR_SOFTWARE, - text: "software".to_owned(), - }; - a.add_to(&mut m)?; - m.write_header(); - - let mut d_m = Message::new(); - d_m.raw.clone_from(&m.raw); - d_m.decode()?; - i.check(&mut d_m)?; - - d_m.raw[24] += 12; // HMAC now invalid - d_m.decode()?; - let result = i.check(&mut d_m); - assert!(result.is_err(), "should be invalid"); - } - - Ok(()) -} - -#[test] -fn test_message_integrity_with_fingerprint() -> Result<()> { - let mut m = Message::new(); - m.transaction_id = TransactionId([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0]); - m.write_header(); - let a = TextAttribute { - attr: ATTR_SOFTWARE, - text: "software".to_owned(), - }; - a.add_to(&mut m)?; - - let i = MessageIntegrity::new_short_term_integrity("pwd".to_owned()); - assert_eq!(i.to_string(), "KEY: 0x[70, 77, 64]", "bad string {i}"); - let result = i.check(&mut m); - assert!(result.is_err(), "should error"); - - i.add_to(&mut m)?; - FINGERPRINT.add_to(&mut m)?; - i.check(&mut m)?; - m.raw[24] = 33; - m.decode()?; - let result = i.check(&mut m); - assert!(result.is_err(), "mismatch expected"); - - Ok(()) -} - -#[test] -fn test_message_integrity() -> Result<()> { - let mut m = Message::new(); - let i = MessageIntegrity::new_short_term_integrity("password".to_owned()); - m.write_header(); - i.add_to(&mut m)?; - m.get(ATTR_MESSAGE_INTEGRITY)?; - Ok(()) -} - -#[test] -fn test_message_integrity_before_fingerprint() -> Result<()> { - let mut m = Message::new(); - m.write_header(); - FINGERPRINT.add_to(&mut m)?; - let i = MessageIntegrity::new_short_term_integrity("password".to_owned()); - let result = i.add_to(&mut m); - assert!(result.is_err(), "should error"); - - Ok(()) -} diff --git a/stun/src/lib.rs b/stun/src/lib.rs deleted file mode 100644 index f13ff34d4..000000000 --- a/stun/src/lib.rs +++ /dev/null @@ -1,26 +0,0 @@ -#![warn(rust_2018_idioms)] -#![allow(dead_code)] - -#[macro_use] -extern crate lazy_static; - -pub mod addr; -pub mod agent; -pub mod attributes; -pub mod checks; -pub mod client; -mod error; -pub mod error_code; -pub mod fingerprint; -pub mod integrity; -pub mod message; -pub mod textattrs; -pub mod uattrs; -pub mod uri; -pub mod xoraddr; - -// IANA assigned ports for "stun" protocol. 
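// 3478 is the registered port for plain STUN (and TURN) over UDP/TCP;
// 5349 is the port registered for STUN over TLS ("stuns").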
-pub const DEFAULT_PORT: u16 = 3478; -pub const DEFAULT_TLS_PORT: u16 = 5349; - -pub use error::Error; diff --git a/stun/src/message.rs b/stun/src/message.rs deleted file mode 100644 index 6ac245e08..000000000 --- a/stun/src/message.rs +++ /dev/null @@ -1,626 +0,0 @@ -#[cfg(test)] -mod message_test; - -use std::fmt; -use std::io::{Read, Write}; - -use base64::prelude::BASE64_STANDARD; -use base64::Engine; -use rand::Rng; - -use crate::agent::*; -use crate::attributes::*; -use crate::error::*; - -// MAGIC_COOKIE is fixed value that aids in distinguishing STUN packets -// from packets of other protocols when STUN is multiplexed with those -// other protocols on the same Port. -// -// The magic cookie field MUST contain the fixed value 0x2112A442 in -// network byte order. -// -// Defined in "STUN Message Structure", section 6. -pub const MAGIC_COOKIE: u32 = 0x2112A442; -pub const ATTRIBUTE_HEADER_SIZE: usize = 4; -pub const MESSAGE_HEADER_SIZE: usize = 20; - -// TRANSACTION_ID_SIZE is length of transaction id array (in bytes). -pub const TRANSACTION_ID_SIZE: usize = 12; // 96 bit - -// Interfaces that are implemented by message attributes, shorthands for them, -// or helpers for message fields as type or transaction id. -pub trait Setter { - // Setter sets *Message attribute. - fn add_to(&self, m: &mut Message) -> Result<()>; -} - -// Getter parses attribute from *Message. -pub trait Getter { - fn get_from(&mut self, m: &Message) -> Result<()>; -} - -// Checker checks *Message attribute. -pub trait Checker { - fn check(&self, m: &Message) -> Result<()>; -} - -// is_message returns true if b looks like STUN message. -// Useful for multiplexing. is_message does not guarantee -// that decoding will be successful. -pub fn is_message(b: &[u8]) -> bool { - b.len() >= MESSAGE_HEADER_SIZE && u32::from_be_bytes([b[4], b[5], b[6], b[7]]) == MAGIC_COOKIE -} -// Message represents a single STUN packet. It uses aggressive internal -// buffering to enable zero-allocation encoding and decoding, -// so there are some usage constraints: -// -// Message, its fields, results of m.Get or any attribute a.GetFrom -// are valid only until Message.Raw is not modified. -#[derive(Default, Debug, Clone)] -pub struct Message { - pub typ: MessageType, - pub length: u32, // len(Raw) not including header - pub transaction_id: TransactionId, - pub attributes: Attributes, - pub raw: Vec, -} - -impl fmt::Display for Message { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let t_id = BASE64_STANDARD.encode(self.transaction_id.0); - write!( - f, - "{} l={} attrs={} id={}", - self.typ, - self.length, - self.attributes.0.len(), - t_id - ) - } -} - -// Equal returns true if Message b equals to m. -// Ignores m.Raw. -impl PartialEq for Message { - fn eq(&self, other: &Self) -> bool { - if self.typ != other.typ { - return false; - } - if self.transaction_id != other.transaction_id { - return false; - } - if self.length != other.length { - return false; - } - if self.attributes != other.attributes { - return false; - } - true - } -} - -const DEFAULT_RAW_CAPACITY: usize = 120; - -impl Setter for Message { - // add_to sets b.TransactionID to m.TransactionID. - // - // Implements Setter to aid in crafting responses. - fn add_to(&self, b: &mut Message) -> Result<()> { - b.transaction_id = self.transaction_id; - b.write_transaction_id(); - Ok(()) - } -} - -impl Message { - // New returns *Message with pre-allocated Raw. 
- pub fn new() -> Self { - Message { - raw: { - let mut raw = Vec::with_capacity(DEFAULT_RAW_CAPACITY); - raw.extend_from_slice(&[0; MESSAGE_HEADER_SIZE]); - raw - }, - ..Default::default() - } - } - - // marshal_binary implements the encoding.BinaryMarshaler interface. - pub fn marshal_binary(&self) -> Result> { - // We can't return m.Raw, allocation is expected by implicit interface - // contract induced by other implementations. - Ok(self.raw.clone()) - } - - // unmarshal_binary implements the encoding.BinaryUnmarshaler interface. - pub fn unmarshal_binary(&mut self, data: &[u8]) -> Result<()> { - // We can't retain data, copy is expected by interface contract. - self.raw.clear(); - self.raw.extend_from_slice(data); - self.decode() - } - - // NewTransactionID sets m.TransactionID to random value from crypto/rand - // and returns error if any. - pub fn new_transaction_id(&mut self) -> Result<()> { - rand::thread_rng().fill(&mut self.transaction_id.0); - self.write_transaction_id(); - Ok(()) - } - - // Reset resets Message, attributes and underlying buffer length. - pub fn reset(&mut self) { - self.raw.clear(); - self.length = 0; - self.attributes.0.clear(); - } - - // grow ensures that internal buffer has n length. - fn grow(&mut self, n: usize, resize: bool) { - if self.raw.len() >= n { - if resize { - self.raw.resize(n, 0); - } - return; - } - self.raw.extend_from_slice(&vec![0; n - self.raw.len()]); - } - - // Add appends new attribute to message. Not goroutine-safe. - // - // Value of attribute is copied to internal buffer so - // it is safe to reuse v. - pub fn add(&mut self, t: AttrType, v: &[u8]) { - // Allocating buffer for TLV (type-length-value). - // T = t, L = len(v), V = v. - // m.Raw will look like: - // [0:20] <- message header - // [20:20+m.Length] <- existing message attributes - // [20+m.Length:20+m.Length+len(v) + 4] <- allocated buffer for new TLV - // [first:last] <- same as previous - // [0 1|2 3|4 4 + len(v)] <- mapping for allocated buffer - // T L V - let alloc_size = ATTRIBUTE_HEADER_SIZE + v.len(); // ~ len(TLV) = len(TL) + len(V) - let first = MESSAGE_HEADER_SIZE + self.length as usize; // first byte number - let mut last = first + alloc_size; // last byte number - self.grow(last, true); // growing cap(Raw) to fit TLV - self.length += alloc_size as u32; // rendering length change - - // Encoding attribute TLV to allocated buffer. - let buf = &mut self.raw[first..last]; - buf[0..2].copy_from_slice(&t.value().to_be_bytes()); // T - buf[2..4].copy_from_slice(&(v.len() as u16).to_be_bytes()); // L - - let value = &mut buf[ATTRIBUTE_HEADER_SIZE..]; - value.copy_from_slice(v); // V - - let attr = RawAttribute { - typ: t, // T - length: v.len() as u16, // L - value: value.to_vec(), // V - }; - - // Checking that attribute value needs padding. - if attr.length as usize % PADDING != 0 { - // Performing padding. - let bytes_to_add = nearest_padded_value_length(v.len()) - v.len(); - last += bytes_to_add; - self.grow(last, true); - // setting all padding bytes to zero - // to prevent data leak from previous - // data in next bytes_to_add bytes - let buf = &mut self.raw[last - bytes_to_add..last]; - for b in buf { - *b = 0; - } - self.length += bytes_to_add as u32; // rendering length change - } - self.attributes.0.push(attr); - self.write_length(); - } - - // WriteLength writes m.Length to m.Raw. 
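    // To make the padding bookkeeping in add() concrete: appending a 5-byte
    // value writes a 4-byte TLV header with L = 5, the 5 value bytes, and
    // 3 zeroed padding bytes, so m.length grows by 12 while the attribute's
    // recorded length stays 5.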
- pub fn write_length(&mut self) { - self.grow(4, false); - self.raw[2..4].copy_from_slice(&(self.length as u16).to_be_bytes()); - } - - // WriteHeader writes header to underlying buffer. Not goroutine-safe. - pub fn write_header(&mut self) { - self.grow(MESSAGE_HEADER_SIZE, false); - - self.write_type(); - self.write_length(); - self.raw[4..8].copy_from_slice(&MAGIC_COOKIE.to_be_bytes()); // magic cookie - self.raw[8..MESSAGE_HEADER_SIZE].copy_from_slice(&self.transaction_id.0); - // transaction ID - } - - // WriteTransactionID writes m.TransactionID to m.Raw. - pub fn write_transaction_id(&mut self) { - self.raw[8..MESSAGE_HEADER_SIZE].copy_from_slice(&self.transaction_id.0); - // transaction ID - } - - // WriteAttributes encodes all m.Attributes to m. - pub fn write_attributes(&mut self) { - let attributes: Vec = self.attributes.0.drain(..).collect(); - for a in &attributes { - self.add(a.typ, &a.value); - } - self.attributes = Attributes(attributes); - } - - // WriteType writes m.Type to m.Raw. - pub fn write_type(&mut self) { - self.grow(2, false); - self.raw[..2].copy_from_slice(&self.typ.value().to_be_bytes()); // message type - } - - // SetType sets m.Type and writes it to m.Raw. - pub fn set_type(&mut self, t: MessageType) { - self.typ = t; - self.write_type(); - } - - // Encode re-encodes message into m.Raw. - pub fn encode(&mut self) { - self.raw.clear(); - self.write_header(); - self.length = 0; - self.write_attributes(); - } - - // Decode decodes m.Raw into m. - pub fn decode(&mut self) -> Result<()> { - // decoding message header - let buf = &self.raw; - if buf.len() < MESSAGE_HEADER_SIZE { - return Err(Error::ErrUnexpectedHeaderEof); - } - - let t = u16::from_be_bytes([buf[0], buf[1]]); // first 2 bytes - let size = u16::from_be_bytes([buf[2], buf[3]]) as usize; // second 2 bytes - let cookie = u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]); // last 4 bytes - let full_size = MESSAGE_HEADER_SIZE + size; // len(m.Raw) - - if cookie != MAGIC_COOKIE { - return Err(Error::Other(format!( - "{cookie:x} is invalid magic cookie (should be {MAGIC_COOKIE:x})" - ))); - } - if buf.len() < full_size { - return Err(Error::Other(format!( - "buffer length {} is less than {} (expected message size)", - buf.len(), - full_size - ))); - } - - // saving header data - self.typ.read_value(t); - self.length = size as u32; - self.transaction_id - .0 - .copy_from_slice(&buf[8..MESSAGE_HEADER_SIZE]); - - self.attributes.0.clear(); - let mut offset = 0; - let mut b = &buf[MESSAGE_HEADER_SIZE..full_size]; - - while offset < size { - // checking that we have enough bytes to read header - if b.len() < ATTRIBUTE_HEADER_SIZE { - return Err(Error::Other(format!( - "buffer length {} is less than {} (expected header size)", - b.len(), - ATTRIBUTE_HEADER_SIZE - ))); - } - - let mut a = RawAttribute { - typ: compat_attr_type(u16::from_be_bytes([b[0], b[1]])), // first 2 bytes - length: u16::from_be_bytes([b[2], b[3]]), // second 2 bytes - ..Default::default() - }; - let a_l = a.length as usize; // attribute length - let a_buff_l = nearest_padded_value_length(a_l); // expected buffer length (with padding) - - b = &b[ATTRIBUTE_HEADER_SIZE..]; // slicing again to simplify value read - offset += ATTRIBUTE_HEADER_SIZE; - if b.len() < a_buff_l { - // checking size - return Err(Error::Other(format!( - "buffer length {} is less than {} (expected value size for {})", - b.len(), - a_buff_l, - a.typ - ))); - } - a.value = b[..a_l].to_vec(); - offset += a_buff_l; - b = &b[a_buff_l..]; - - self.attributes.0.push(a); - } 
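        // Each iteration advanced the cursor by the padded length (a_buff_l)
        // but copied only the declared a_l bytes, so attribute padding is
        // dropped here and never reaches RawAttribute::value.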
- - Ok(()) - } - - // WriteTo implements WriterTo via calling Write(m.Raw) on w and returning - // call result. - pub fn write_to(&self, writer: &mut W) -> Result { - let n = writer.write(&self.raw)?; - Ok(n) - } - - // ReadFrom implements ReaderFrom. Reads message from r into m.Raw, - // Decodes it and return error if any. If m.Raw is too small, will return - // ErrUnexpectedEOF, ErrUnexpectedHeaderEOF or *DecodeErr. - // - // Can return *DecodeErr while decoding too. - pub fn read_from(&mut self, reader: &mut R) -> Result { - let mut t_buf = vec![0; DEFAULT_RAW_CAPACITY]; - let n = reader.read(&mut t_buf)?; - self.raw = t_buf[..n].to_vec(); - self.decode()?; - Ok(n) - } - - // Write decodes message and return error if any. - // - // Any error is unrecoverable, but message could be partially decoded. - pub fn write(&mut self, t_buf: &[u8]) -> Result { - self.raw.clear(); - self.raw.extend_from_slice(t_buf); - self.decode()?; - Ok(t_buf.len()) - } - - // CloneTo clones m to b securing any further m mutations. - pub fn clone_to(&self, b: &mut Message) -> Result<()> { - b.raw.clear(); - b.raw.extend_from_slice(&self.raw); - b.decode() - } - - // Contains return true if message contain t attribute. - pub fn contains(&self, t: AttrType) -> bool { - for a in &self.attributes.0 { - if a.typ == t { - return true; - } - } - false - } - - // get returns byte slice that represents attribute value, - // if there is no attribute with such type, - // ErrAttributeNotFound is returned. - pub fn get(&self, t: AttrType) -> Result> { - let (v, ok) = self.attributes.get(t); - if ok { - Ok(v.value) - } else { - Err(Error::ErrAttributeNotFound) - } - } - - // Build resets message and applies setters to it in batch, returning on - // first error. To prevent allocations, pass pointers to values. - // - // Example: - // var ( - // t = BindingRequest - // username = NewUsername("username") - // nonce = NewNonce("nonce") - // realm = NewRealm("example.org") - // ) - // m := new(Message) - // m.Build(t, username, nonce, realm) // 4 allocations - // m.Build(&t, &username, &nonce, &realm) // 0 allocations - // - // See BenchmarkBuildOverhead. - pub fn build(&mut self, setters: &[Box]) -> Result<()> { - self.reset(); - self.write_header(); - for s in setters { - s.add_to(self)?; - } - Ok(()) - } - - // Check applies checkers to message in batch, returning on first error. - pub fn check(&self, checkers: &[C]) -> Result<()> { - for c in checkers { - c.check(self)?; - } - Ok(()) - } - - // Parse applies getters to message in batch, returning on first error. - pub fn parse(&self, getters: &mut [G]) -> Result<()> { - for c in getters { - c.get_from(self)?; - } - Ok(()) - } -} - -// MessageClass is 8-bit representation of 2-bit class of STUN Message Class. -#[derive(Default, PartialEq, Eq, Debug, Copy, Clone)] -pub struct MessageClass(u8); - -// Possible values for message class in STUN Message Type. 
-pub const CLASS_REQUEST: MessageClass = MessageClass(0x00); // 0b00 -pub const CLASS_INDICATION: MessageClass = MessageClass(0x01); // 0b01 -pub const CLASS_SUCCESS_RESPONSE: MessageClass = MessageClass(0x02); // 0b10 -pub const CLASS_ERROR_RESPONSE: MessageClass = MessageClass(0x03); // 0b11 - -impl fmt::Display for MessageClass { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - CLASS_REQUEST => "request", - CLASS_INDICATION => "indication", - CLASS_SUCCESS_RESPONSE => "success response", - CLASS_ERROR_RESPONSE => "error response", - _ => "unknown message class", - }; - - write!(f, "{s}") - } -} - -// Method is uint16 representation of 12-bit STUN method. -#[derive(Default, PartialEq, Eq, Debug, Copy, Clone)] -pub struct Method(u16); - -// Possible methods for STUN Message. -pub const METHOD_BINDING: Method = Method(0x001); -pub const METHOD_ALLOCATE: Method = Method(0x003); -pub const METHOD_REFRESH: Method = Method(0x004); -pub const METHOD_SEND: Method = Method(0x006); -pub const METHOD_DATA: Method = Method(0x007); -pub const METHOD_CREATE_PERMISSION: Method = Method(0x008); -pub const METHOD_CHANNEL_BIND: Method = Method(0x009); - -// Methods from RFC 6062. -pub const METHOD_CONNECT: Method = Method(0x000a); -pub const METHOD_CONNECTION_BIND: Method = Method(0x000b); -pub const METHOD_CONNECTION_ATTEMPT: Method = Method(0x000c); - -impl fmt::Display for Method { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let unknown = format!("0x{:x}", self.0); - - let s = match *self { - METHOD_BINDING => "Binding", - METHOD_ALLOCATE => "Allocate", - METHOD_REFRESH => "Refresh", - METHOD_SEND => "Send", - METHOD_DATA => "Data", - METHOD_CREATE_PERMISSION => "CreatePermission", - METHOD_CHANNEL_BIND => "ChannelBind", - - // RFC 6062. - METHOD_CONNECT => "Connect", - METHOD_CONNECTION_BIND => "ConnectionBind", - METHOD_CONNECTION_ATTEMPT => "ConnectionAttempt", - _ => unknown.as_str(), - }; - - write!(f, "{s}") - } -} - -// MessageType is STUN Message Type Field. -#[derive(Default, Debug, PartialEq, Eq, Clone, Copy)] -pub struct MessageType { - pub method: Method, // e.g. binding - pub class: MessageClass, // e.g. request -} - -// Common STUN message types. -// Binding request message type. -pub const BINDING_REQUEST: MessageType = MessageType { - method: METHOD_BINDING, - class: CLASS_REQUEST, -}; -// Binding success response message type -pub const BINDING_SUCCESS: MessageType = MessageType { - method: METHOD_BINDING, - class: CLASS_SUCCESS_RESPONSE, -}; -// Binding error response message type. -pub const BINDING_ERROR: MessageType = MessageType { - method: METHOD_BINDING, - class: CLASS_ERROR_RESPONSE, -}; - -impl fmt::Display for MessageType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{} {}", self.method, self.class) - } -} - -const METHOD_ABITS: u16 = 0xf; // 0b0000000000001111 -const METHOD_BBITS: u16 = 0x70; // 0b0000000001110000 -const METHOD_DBITS: u16 = 0xf80; // 0b0000111110000000 - -const METHOD_BSHIFT: u16 = 1; -const METHOD_DSHIFT: u16 = 2; - -const FIRST_BIT: u16 = 0x1; -const SECOND_BIT: u16 = 0x2; - -const C0BIT: u16 = FIRST_BIT; -const C1BIT: u16 = SECOND_BIT; - -const CLASS_C0SHIFT: u16 = 4; -const CLASS_C1SHIFT: u16 = 7; - -impl Setter for MessageType { - // add_to sets m type to t. - fn add_to(&self, m: &mut Message) -> Result<()> { - m.set_type(*self); - Ok(()) - } -} - -impl MessageType { - // NewType returns new message type with provided method and class. 
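    // For reference, the bit packing implemented by value()/read_value() below
    // maps a Binding error response to 0x0111: the method is 0x001
    // (A = 0x1, B = 0, D = 0) and the class is 0b11, so C0 = 1 << 4 = 0x010 and
    // C1 = 0b10 << 7 = 0x100, giving 0x001 + 0x110 = 0x0111.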
- pub fn new(method: Method, class: MessageClass) -> Self { - MessageType { method, class } - } - - // Value returns bit representation of messageType. - pub fn value(&self) -> u16 { - // 0 1 - // 2 3 4 5 6 7 8 9 0 1 2 3 4 5 - // +--+--+-+-+-+-+-+-+-+-+-+-+-+-+ - // |M |M |M|M|M|C|M|M|M|C|M|M|M|M| - // |11|10|9|8|7|1|6|5|4|0|3|2|1|0| - // +--+--+-+-+-+-+-+-+-+-+-+-+-+-+ - // Figure 3: Format of STUN Message Type Field - - // Warning: Abandon all hope ye who enter here. - // Splitting M into A(M0-M3), B(M4-M6), D(M7-M11). - let method = self.method.0; - let a = method & METHOD_ABITS; // A = M * 0b0000000000001111 (right 4 bits) - let b = method & METHOD_BBITS; // B = M * 0b0000000001110000 (3 bits after A) - let d = method & METHOD_DBITS; // D = M * 0b0000111110000000 (5 bits after B) - - // Shifting to add "holes" for C0 (at 4 bit) and C1 (8 bit). - let method = a + (b << METHOD_BSHIFT) + (d << METHOD_DSHIFT); - - // C0 is zero bit of C, C1 is first bit. - // C0 = C * 0b01, C1 = (C * 0b10) >> 1 - // Ct = C0 << 4 + C1 << 8. - // Optimizations: "((C * 0b10) >> 1) << 8" as "(C * 0b10) << 7" - // We need C0 shifted by 4, and C1 by 8 to fit "11" and "7" positions - // (see figure 3). - let c = self.class.0 as u16; - let c0 = (c & C0BIT) << CLASS_C0SHIFT; - let c1 = (c & C1BIT) << CLASS_C1SHIFT; - let class = c0 + c1; - - method + class - } - - // ReadValue decodes uint16 into MessageType. - pub fn read_value(&mut self, value: u16) { - // Decoding class. - // We are taking first bit from v >> 4 and second from v >> 7. - let c0 = (value >> CLASS_C0SHIFT) & C0BIT; - let c1 = (value >> CLASS_C1SHIFT) & C1BIT; - let class = c0 + c1; - self.class = MessageClass(class as u8); - - // Decoding method. - let a = value & METHOD_ABITS; // A(M0-M3) - let b = (value >> METHOD_BSHIFT) & METHOD_BBITS; // B(M4-M6) - let d = (value >> METHOD_DSHIFT) & METHOD_DBITS; // D(M7-M11) - let m = a + b + d; - self.method = Method(m); - } -} diff --git a/stun/src/message/message_test.rs b/stun/src/message/message_test.rs deleted file mode 100644 index afd388ef5..000000000 --- a/stun/src/message/message_test.rs +++ /dev/null @@ -1,744 +0,0 @@ -use std::io::{BufReader, BufWriter}; - -use super::*; -use crate::fingerprint::FINGERPRINT; -use crate::integrity::MessageIntegrity; -use crate::textattrs::TextAttribute; -use crate::xoraddr::*; - -#[test] -fn test_message_buffer() -> Result<()> { - let mut m = Message::new(); - m.typ = MessageType { - method: METHOD_BINDING, - class: CLASS_REQUEST, - }; - m.transaction_id = TransactionId::new(); - m.add(ATTR_ERROR_CODE, &[0xff, 0xfe, 0xfa]); - m.write_header(); - - let mut m_decoded = Message::new(); - let mut reader = BufReader::new(m.raw.as_slice()); - m_decoded.read_from(&mut reader)?; - - assert_eq!(m_decoded, m, "{m_decoded} != {m}"); - - Ok(()) -} - -#[test] -fn test_message_type_value() -> Result<()> { - let tests = vec![ - ( - MessageType { - method: METHOD_BINDING, - class: CLASS_REQUEST, - }, - 0x0001, - ), - ( - MessageType { - method: METHOD_BINDING, - class: CLASS_SUCCESS_RESPONSE, - }, - 0x0101, - ), - ( - MessageType { - method: METHOD_BINDING, - class: CLASS_ERROR_RESPONSE, - }, - 0x0111, - ), - ( - MessageType { - method: Method(0xb6d), - class: MessageClass(0x3), - }, - 0x2ddd, - ), - ]; - - for (input, output) in tests { - let b = input.value(); - assert_eq!(b, output, "Value({input}) -> {b}, want {output}"); - } - - Ok(()) -} - -#[test] -fn test_message_type_read_value() -> Result<()> { - let tests = vec![ - ( - 0x0001, - MessageType { - method: METHOD_BINDING, - 
class: CLASS_REQUEST, - }, - ), - ( - 0x0101, - MessageType { - method: METHOD_BINDING, - class: CLASS_SUCCESS_RESPONSE, - }, - ), - ( - 0x0111, - MessageType { - method: METHOD_BINDING, - class: CLASS_ERROR_RESPONSE, - }, - ), - ]; - - for (input, output) in tests { - let mut m = MessageType::default(); - m.read_value(input); - assert_eq!(m, output, "ReadValue({input}) -> {m}, want {output}"); - } - - Ok(()) -} - -#[test] -fn test_message_type_read_write_value() -> Result<()> { - let tests = vec![ - MessageType { - method: METHOD_BINDING, - class: CLASS_REQUEST, - }, - MessageType { - method: METHOD_BINDING, - class: CLASS_SUCCESS_RESPONSE, - }, - MessageType { - method: METHOD_BINDING, - class: CLASS_ERROR_RESPONSE, - }, - MessageType { - method: Method(0x12), - class: CLASS_ERROR_RESPONSE, - }, - ]; - - for test in tests { - let mut m = MessageType::default(); - let v = test.value(); - m.read_value(v); - assert_eq!(m, test, "ReadValue({test} -> {v}) = {m}, should be {test}"); - } - - Ok(()) -} - -#[test] -fn test_message_write_to() -> Result<()> { - let mut m = Message::new(); - m.typ = MessageType { - method: METHOD_BINDING, - class: CLASS_REQUEST, - }; - m.transaction_id = TransactionId::new(); - m.add(ATTR_ERROR_CODE, &[0xff, 0xfe, 0xfa]); - m.write_header(); - let mut buf = vec![]; - { - let mut writer = BufWriter::<&mut Vec>::new(buf.as_mut()); - m.write_to(&mut writer)?; - } - - let mut m_decoded = Message::new(); - let mut reader = BufReader::new(buf.as_slice()); - m_decoded.read_from(&mut reader)?; - assert_eq!(m_decoded, m, "{m_decoded} != {m}"); - - Ok(()) -} - -#[test] -fn test_message_cookie() -> Result<()> { - let buf = vec![0; 20]; - let mut m_decoded = Message::new(); - let mut reader = BufReader::new(buf.as_slice()); - let result = m_decoded.read_from(&mut reader); - assert!(result.is_err(), "should error"); - - Ok(()) -} - -#[test] -fn test_message_length_less_header_size() -> Result<()> { - let buf = vec![0; 8]; - let mut m_decoded = Message::new(); - let mut reader = BufReader::new(buf.as_slice()); - let result = m_decoded.read_from(&mut reader); - assert!(result.is_err(), "should error"); - - Ok(()) -} - -#[test] -fn test_message_bad_length() -> Result<()> { - let m_type = MessageType { - method: METHOD_BINDING, - class: CLASS_REQUEST, - }; - let mut m = Message { - typ: m_type, - length: 4, - transaction_id: TransactionId([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), - ..Default::default() - }; - m.add(AttrType(0x1), &[1, 2]); - m.write_header(); - m.raw[20 + 3] = 10; // set attr length = 10 - - let mut m_decoded = Message::new(); - let result = m_decoded.write(&m.raw); - assert!(result.is_err(), "should error"); - - Ok(()) -} - -#[test] -fn test_message_attr_length_less_than_header() -> Result<()> { - let m_type = MessageType { - method: METHOD_BINDING, - class: CLASS_REQUEST, - }; - let message_attribute = RawAttribute { - length: 2, - value: vec![1, 2], - typ: AttrType(0x1), - }; - let message_attributes = Attributes(vec![message_attribute]); - let mut m = Message { - typ: m_type, - transaction_id: TransactionId::new(), - attributes: message_attributes, - ..Default::default() - }; - m.encode(); - - let mut m_decoded = Message::new(); - m.raw[3] = 2; // rewrite to bad length - - let mut reader = BufReader::new(&m.raw[..20 + 2]); - let result = m_decoded.read_from(&mut reader); - assert!(result.is_err(), "should be error"); - - Ok(()) -} - -#[test] -fn test_message_attr_size_less_than_length() -> Result<()> { - let m_type = MessageType { - method: METHOD_BINDING, - 
class: CLASS_REQUEST, - }; - let message_attribute = RawAttribute { - length: 4, - value: vec![1, 2, 3, 4], - typ: AttrType(0x1), - }; - let message_attributes = Attributes(vec![message_attribute]); - let mut m = Message { - typ: m_type, - transaction_id: TransactionId::new(), - attributes: message_attributes, - ..Default::default() - }; - m.write_attributes(); - m.write_header(); - m.raw[3] = 5; // rewrite to bad length - - let mut m_decoded = Message::new(); - let mut reader = BufReader::new(&m.raw[..20 + 5]); - let result = m_decoded.read_from(&mut reader); - assert!(result.is_err(), "should be error"); - - Ok(()) -} - -#[test] -fn test_message_read_from_error() -> Result<()> { - let mut m_decoded = Message::new(); - let buf = vec![]; - let mut reader = BufReader::new(buf.as_slice()); - let result = m_decoded.read_from(&mut reader); - assert!(result.is_err(), "should be error"); - - Ok(()) -} - -#[test] -fn test_message_class_string() -> Result<()> { - let v = vec![ - CLASS_REQUEST, - CLASS_ERROR_RESPONSE, - CLASS_SUCCESS_RESPONSE, - CLASS_INDICATION, - ]; - - for k in v { - if k.to_string() == *"unknown message class" { - panic!("bad stringer {k}"); - } - } - - // should panic - let p = MessageClass(0x05).to_string(); - assert_eq!(p, "unknown message class", "should be error {p}"); - - Ok(()) -} - -#[test] -fn test_attr_type_string() -> Result<()> { - let v = vec![ - ATTR_MAPPED_ADDRESS, - ATTR_USERNAME, - ATTR_ERROR_CODE, - ATTR_MESSAGE_INTEGRITY, - ATTR_UNKNOWN_ATTRIBUTES, - ATTR_REALM, - ATTR_NONCE, - ATTR_XORMAPPED_ADDRESS, - ATTR_SOFTWARE, - ATTR_ALTERNATE_SERVER, - ATTR_FINGERPRINT, - ]; - for k in v { - assert!(!k.to_string().starts_with("0x"), "bad stringer"); - } - - let v_non_standard = AttrType(0x512); - assert!( - v_non_standard.to_string().starts_with("0x512"), - "bad prefix" - ); - - Ok(()) -} - -#[test] -fn test_method_string() -> Result<()> { - assert_eq!( - METHOD_BINDING.to_string(), - "Binding".to_owned(), - "binding is not binding!" 
- ); - assert_eq!( - Method(0x616).to_string(), - "0x616".to_owned(), - "Bad stringer {}", - Method(0x616) - ); - - Ok(()) -} - -#[test] -fn test_attribute_equal() -> Result<()> { - let a = RawAttribute { - length: 2, - value: vec![0x1, 0x2], - ..Default::default() - }; - let b = RawAttribute { - length: 2, - value: vec![0x1, 0x2], - ..Default::default() - }; - assert_eq!(a, b, "should equal"); - - assert_ne!( - a, - RawAttribute { - typ: AttrType(0x2), - ..Default::default() - }, - "should not equal" - ); - assert_ne!( - a, - RawAttribute { - length: 0x2, - ..Default::default() - }, - "should not equal" - ); - assert_ne!( - a, - RawAttribute { - length: 0x3, - ..Default::default() - }, - "should not equal" - ); - assert_ne!( - a, - RawAttribute { - length: 0x2, - value: vec![0x1, 0x3], - ..Default::default() - }, - "should not equal" - ); - - Ok(()) -} - -#[test] -fn test_message_equal() -> Result<()> { - let attr = RawAttribute { - length: 2, - value: vec![0x1, 0x2], - typ: AttrType(0x1), - }; - let attrs = Attributes(vec![attr]); - let a = Message { - attributes: attrs.clone(), - length: 4 + 2, - ..Default::default() - }; - let b = Message { - attributes: attrs.clone(), - length: 4 + 2, - ..Default::default() - }; - assert_eq!(a, b, "should equal"); - assert_ne!( - a, - Message { - typ: MessageType { - class: MessageClass(128), - ..Default::default() - }, - ..Default::default() - }, - "should not equal" - ); - - let t_id = TransactionId([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); - - assert_ne!( - a, - Message { - transaction_id: t_id, - ..Default::default() - }, - "should not equal" - ); - assert_ne!( - a, - Message { - length: 3, - ..Default::default() - }, - "should not equal" - ); - - let t_attrs = Attributes(vec![RawAttribute { - length: 1, - value: vec![0x1], - typ: AttrType(0x1), - }]); - assert_ne!( - a, - Message { - attributes: t_attrs, - length: 4 + 2, - ..Default::default() - }, - "should not equal" - ); - - let t_attrs = Attributes(vec![RawAttribute { - length: 2, - value: vec![0x1, 0x1], - typ: AttrType(0x2), - }]); - assert_ne!( - a, - Message { - attributes: t_attrs, - length: 4 + 2, - ..Default::default() - }, - "should not equal" - ); - - //"Nil attributes" - { - let a = Message { - length: 4 + 2, - ..Default::default() - }; - let mut b = Message { - attributes: attrs, - length: 4 + 2, - ..Default::default() - }; - - assert_ne!(a, b, "should not equal"); - assert_ne!(b, a, "should not equal"); - b.attributes = Attributes::default(); - assert_eq!(a, b, "should equal"); - } - - //"Attributes length" - { - let attr = RawAttribute { - length: 2, - value: vec![0x1, 0x2], - typ: AttrType(0x1), - }; - let attr1 = RawAttribute { - length: 2, - value: vec![0x1, 0x2], - typ: AttrType(0x1), - }; - let a = Message { - attributes: Attributes(vec![attr.clone()]), - length: 4 + 2, - ..Default::default() - }; - let b = Message { - attributes: Attributes(vec![attr, attr1]), - length: 4 + 2, - ..Default::default() - }; - assert_ne!(a, b, "should not equal"); - } - - //"Attributes values" - { - let attr = RawAttribute { - length: 2, - value: vec![0x1, 0x2], - typ: AttrType(0x1), - }; - let attr1 = RawAttribute { - length: 2, - value: vec![0x1, 0x1], - typ: AttrType(0x1), - }; - let a = Message { - attributes: Attributes(vec![attr.clone(), attr.clone()]), - length: 4 + 2, - ..Default::default() - }; - let b = Message { - attributes: Attributes(vec![attr, attr1]), - length: 4 + 2, - ..Default::default() - }; - assert_ne!(a, b, "should not equal"); - } - - Ok(()) -} - -#[test] -fn 
test_message_grow() -> Result<()> { - let mut m = Message::new(); - m.grow(512, false); - assert_eq!(m.raw.len(), 512, "Bad length {}", m.raw.len()); - - Ok(()) -} - -#[test] -fn test_message_grow_smaller() -> Result<()> { - let mut m = Message::new(); - m.grow(2, false); - assert!(m.raw.capacity() >= 20, "Bad capacity {}", m.raw.capacity()); - - assert!(m.raw.len() >= 20, "Bad length {}", m.raw.len()); - - Ok(()) -} - -#[test] -fn test_message_string() -> Result<()> { - let m = Message::new(); - assert_ne!(m.to_string(), "", "bad string"); - - Ok(()) -} - -#[test] -fn test_is_message() -> Result<()> { - let mut m = Message::new(); - let a = TextAttribute { - attr: ATTR_SOFTWARE, - text: "software".to_owned(), - }; - a.add_to(&mut m)?; - m.write_header(); - - let tests = vec![ - (vec![], false), // 0 - (vec![1, 2, 3], false), // 1 - (vec![1, 2, 4], false), // 2 - (vec![1, 2, 4, 5, 6, 7, 8, 9, 20], false), // 3 - (m.raw.to_vec(), true), // 5 - ( - vec![ - 0, 0, 0, 0, 33, 18, 164, 66, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ], - true, - ), // 6 - ]; - - for (input, output) in tests { - let got = is_message(&input); - assert_eq!(got, output, "IsMessage({input:?}) {got} != {output}"); - } - - Ok(()) -} - -#[test] -fn test_message_contains() -> Result<()> { - let mut m = Message::new(); - m.add(ATTR_SOFTWARE, "value".as_bytes()); - - assert!(m.contains(ATTR_SOFTWARE), "message should contain software"); - assert!(!m.contains(ATTR_NONCE), "message should not contain nonce"); - - Ok(()) -} - -#[test] -fn test_message_full_size() -> Result<()> { - let mut m = Message::new(); - m.build(&[ - Box::new(BINDING_REQUEST), - Box::new(TransactionId([1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 0])), - Box::new(TextAttribute::new(ATTR_SOFTWARE, "pion/stun".to_owned())), - Box::new(MessageIntegrity::new_long_term_integrity( - "username".to_owned(), - "realm".to_owned(), - "password".to_owned(), - )), - Box::new(FINGERPRINT), - ])?; - let l = m.raw.len(); - m.raw = m.raw[..l - 10].to_vec(); - - let mut decoder = Message::new(); - let l = m.raw.len(); - decoder.raw = m.raw[..l - 10].to_vec(); - let result = decoder.decode(); - assert!(result.is_err(), "decode on truncated buffer should error"); - - Ok(()) -} - -#[test] -fn test_message_clone_to() -> Result<()> { - let mut m = Message::new(); - m.build(&[ - Box::new(BINDING_REQUEST), - Box::new(TransactionId([1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 0])), - Box::new(TextAttribute::new(ATTR_SOFTWARE, "pion/stun".to_owned())), - Box::new(MessageIntegrity::new_long_term_integrity( - "username".to_owned(), - "realm".to_owned(), - "password".to_owned(), - )), - Box::new(FINGERPRINT), - ])?; - m.encode(); - - let mut b = Message::new(); - m.clone_to(&mut b)?; - assert_eq!(b, m, "not equal"); - - //TODO: Corrupting m and checking that b is not corrupted. 
- /*let (mut s, ok) = b.attributes.get(ATTR_SOFTWARE); - assert!(ok, "no software attribute"); - s.value[0] = b'k'; - s.add_to(&mut b)?; - assert_ne!(b, m, "should not be equal");*/ - - Ok(()) -} - -#[test] -fn test_message_add_to() -> Result<()> { - let mut m = Message::new(); - m.build(&[ - Box::new(BINDING_REQUEST), - Box::new(TransactionId([1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 0])), - Box::new(FINGERPRINT), - ])?; - m.encode(); - - let mut b = Message::new(); - m.clone_to(&mut b)?; - - m.transaction_id = TransactionId([1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 2, 0]); - assert_ne!(b, m, "should not be equal"); - - m.add_to(&mut b)?; - assert_eq!(b, m, "should be equal"); - - Ok(()) -} - -#[test] -fn test_decode() -> Result<()> { - let mut m = Message::new(); - m.typ = MessageType { - method: METHOD_BINDING, - class: CLASS_REQUEST, - }; - m.transaction_id = TransactionId::new(); - m.add(ATTR_ERROR_CODE, &[0xff, 0xfe, 0xfa]); - m.write_header(); - - let mut m_decoded = Message::new(); - m_decoded.raw.clear(); - m_decoded.raw.extend_from_slice(&m.raw); - m_decoded.decode()?; - assert_eq!( - m_decoded, m, - "decoded result is not equal to encoded message" - ); - - Ok(()) -} - -#[test] -fn test_message_marshal_binary() -> Result<()> { - let mut m = Message::new(); - m.build(&[ - Box::new(TextAttribute::new(ATTR_SOFTWARE, "software".to_owned())), - Box::new(XorMappedAddress { - ip: "213.1.223.5".parse().unwrap(), - port: 0, - }), - ])?; - - let mut data = m.marshal_binary()?; - // Reset m.Raw to check retention. - for i in 0..m.raw.len() { - m.raw[i] = 0; - } - m.unmarshal_binary(&data)?; - - // Reset data to check retention. - #[allow(clippy::needless_range_loop)] - for i in 0..data.len() { - data[i] = 0; - } - - m.decode()?; - - Ok(()) -} diff --git a/stun/src/textattrs.rs b/stun/src/textattrs.rs deleted file mode 100644 index 54c5e77c7..000000000 --- a/stun/src/textattrs.rs +++ /dev/null @@ -1,95 +0,0 @@ -#[cfg(test)] -mod textattrs_test; - -use std::fmt; - -use crate::attributes::*; -use crate::checks::*; -use crate::error::*; -use crate::message::*; - -const MAX_USERNAME_B: usize = 513; -const MAX_REALM_B: usize = 763; -const MAX_SOFTWARE_B: usize = 763; -const MAX_NONCE_B: usize = 763; - -// Username represents USERNAME attribute. -// -// RFC 5389 Section 15.3 -pub type Username = TextAttribute; - -// Realm represents REALM attribute. -// -// RFC 5389 Section 15.7 -pub type Realm = TextAttribute; - -// Nonce represents NONCE attribute. -// -// RFC 5389 Section 15.8 -pub type Nonce = TextAttribute; - -// Software is SOFTWARE attribute. -// -// RFC 5389 Section 15.10 -pub type Software = TextAttribute; - -// TextAttribute is helper for adding and getting text attributes. -#[derive(Clone, Default)] -pub struct TextAttribute { - pub attr: AttrType, - pub text: String, -} - -impl fmt::Display for TextAttribute { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.text) - } -} - -impl Setter for TextAttribute { - // add_to_as adds attribute with type t to m, checking maximum length. If max_len - // is less than 0, no check is performed. 
- fn add_to(&self, m: &mut Message) -> Result<()> { - let text = self.text.as_bytes(); - let max_len = match self.attr { - ATTR_USERNAME => MAX_USERNAME_B, - ATTR_REALM => MAX_REALM_B, - ATTR_SOFTWARE => MAX_SOFTWARE_B, - ATTR_NONCE => MAX_NONCE_B, - _ => return Err(Error::Other(format!("Unsupported AttrType {}", self.attr))), - }; - - check_overflow(self.attr, text.len(), max_len)?; - m.add(self.attr, text); - Ok(()) - } -} - -impl Getter for TextAttribute { - fn get_from(&mut self, m: &Message) -> Result<()> { - let attr = self.attr; - *self = TextAttribute::get_from_as(m, attr)?; - Ok(()) - } -} - -impl TextAttribute { - pub fn new(attr: AttrType, text: String) -> Self { - TextAttribute { attr, text } - } - - // get_from_as gets t attribute from m and appends its value to reset v. - pub fn get_from_as(m: &Message, attr: AttrType) -> Result { - match attr { - ATTR_USERNAME => {} - ATTR_REALM => {} - ATTR_SOFTWARE => {} - ATTR_NONCE => {} - _ => return Err(Error::Other(format!("Unsupported AttrType {attr}"))), - }; - - let a = m.get(attr)?; - let text = String::from_utf8(a)?; - Ok(TextAttribute { attr, text }) - } -} diff --git a/stun/src/textattrs/textattrs_test.rs b/stun/src/textattrs/textattrs_test.rs deleted file mode 100644 index e0a01aba1..000000000 --- a/stun/src/textattrs/textattrs_test.rs +++ /dev/null @@ -1,307 +0,0 @@ -use std::io::BufReader; - -use super::*; -use crate::checks::*; -use crate::error::*; - -#[test] -fn test_software_get_from() -> Result<()> { - let mut m = Message::new(); - let v = "Client v0.0.1".to_owned(); - m.add(ATTR_SOFTWARE, v.as_bytes()); - m.write_header(); - - let mut m2 = Message { - raw: Vec::with_capacity(256), - ..Default::default() - }; - - let mut reader = BufReader::new(m.raw.as_slice()); - m2.read_from(&mut reader)?; - let software = TextAttribute::get_from_as(&m, ATTR_SOFTWARE)?; - assert_eq!(software.to_string(), v, "Expected {v}, got {software}."); - - let (s_attr, ok) = m.attributes.get(ATTR_SOFTWARE); - assert!(ok, "sowfware attribute should be found"); - - let s = s_attr.to_string(); - assert!(s.starts_with("SOFTWARE:"), "bad string representation {s}"); - - Ok(()) -} - -#[test] -fn test_software_add_to_invalid() -> Result<()> { - let mut m = Message::new(); - let s = TextAttribute { - attr: ATTR_SOFTWARE, - text: String::from_utf8(vec![0; 1024]).unwrap(), - }; - let result = s.add_to(&mut m); - if let Err(err) = result { - assert!( - is_attr_size_overflow(&err), - "add_to should return AttrOverflowErr, got: {err}" - ); - } else { - panic!("expected error, but got ok"); - } - - let result = TextAttribute::get_from_as(&m, ATTR_SOFTWARE); - if let Err(err) = result { - assert_eq!( - Error::ErrAttributeNotFound, - err, - "GetFrom should return {}, got: {}", - Error::ErrAttributeNotFound, - err - ); - } else { - panic!("expected error, but got ok"); - } - - Ok(()) -} - -#[test] -fn test_software_add_to_regression() -> Result<()> { - // s.add_to checked len(m.Raw) instead of len(s.Raw). 
- let mut m = Message { - raw: vec![0u8; 2048], - ..Default::default() - }; - let s = TextAttribute { - attr: ATTR_SOFTWARE, - text: String::from_utf8(vec![0; 100]).unwrap(), - }; - s.add_to(&mut m)?; - - Ok(()) -} - -#[test] -fn test_username() -> Result<()> { - let username = "username".to_owned(); - let u = TextAttribute { - attr: ATTR_USERNAME, - text: username.clone(), - }; - let mut m = Message::new(); - m.write_header(); - //"Bad length" - { - let bad_u = TextAttribute { - attr: ATTR_USERNAME, - text: String::from_utf8(vec![0; 600]).unwrap(), - }; - let result = bad_u.add_to(&mut m); - if let Err(err) = result { - assert!( - is_attr_size_overflow(&err), - "add_to should return *AttrOverflowErr, got: {err}" - ); - } else { - panic!("expected error, but got ok"); - } - } - //"add_to" - { - u.add_to(&mut m)?; - - //"GetFrom" - { - let got = TextAttribute::get_from_as(&m, ATTR_USERNAME)?; - assert_eq!( - got.to_string(), - username, - "expedted: {username}, got: {got}" - ); - //"Not found" - { - let m = Message::new(); - let result = TextAttribute::get_from_as(&m, ATTR_USERNAME); - if let Err(err) = result { - assert_eq!(Error::ErrAttributeNotFound, err, "Should error"); - } else { - panic!("expected error, but got ok"); - } - } - } - } - - //"No allocations" - { - let mut m = Message::new(); - m.write_header(); - let u = TextAttribute { - attr: ATTR_USERNAME, - text: "username".to_owned(), - }; - - u.add_to(&mut m)?; - m.reset(); - } - - Ok(()) -} - -#[test] -fn test_realm_get_from() -> Result<()> { - let mut m = Message::new(); - let v = "realm".to_owned(); - m.add(ATTR_REALM, v.as_bytes()); - m.write_header(); - - let mut m2 = Message { - raw: Vec::with_capacity(256), - ..Default::default() - }; - - let result = TextAttribute::get_from_as(&m2, ATTR_REALM); - if let Err(err) = result { - assert_eq!( - Error::ErrAttributeNotFound, - err, - "GetFrom should return {}, got: {}", - Error::ErrAttributeNotFound, - err - ); - } else { - panic!("Expected error, but got ok"); - } - - let mut reader = BufReader::new(m.raw.as_slice()); - m2.read_from(&mut reader)?; - - let r = TextAttribute::get_from_as(&m, ATTR_REALM)?; - assert_eq!(r.to_string(), v, "Expected {v}, got {r}."); - - let (r_attr, ok) = m.attributes.get(ATTR_REALM); - assert!(ok, "realm attribute should be found"); - - let s = r_attr.to_string(); - assert!(s.starts_with("REALM:"), "bad string representation {s}"); - - Ok(()) -} - -#[test] -fn test_realm_add_to_invalid() -> Result<()> { - let mut m = Message::new(); - let s = TextAttribute { - attr: ATTR_REALM, - text: String::from_utf8(vec![0; 1024]).unwrap(), - }; - let result = s.add_to(&mut m); - if let Err(err) = result { - assert!( - is_attr_size_overflow(&err), - "add_to should return AttrOverflowErr, got: {err}" - ); - } else { - panic!("expected error, but got ok"); - } - - let result = TextAttribute::get_from_as(&m, ATTR_REALM); - if let Err(err) = result { - assert_eq!( - Error::ErrAttributeNotFound, - err, - "GetFrom should return {}, got: {}", - Error::ErrAttributeNotFound, - err - ); - } else { - panic!("expected error, but got ok"); - } - - Ok(()) -} - -#[test] -fn test_nonce_get_from() -> Result<()> { - let mut m = Message::new(); - let v = "example.org".to_owned(); - m.add(ATTR_NONCE, v.as_bytes()); - m.write_header(); - - let mut m2 = Message { - raw: Vec::with_capacity(256), - ..Default::default() - }; - - let result = TextAttribute::get_from_as(&m2, ATTR_NONCE); - if let Err(err) = result { - assert_eq!( - Error::ErrAttributeNotFound, - err, - "GetFrom should 
return {}, got: {}", - Error::ErrAttributeNotFound, - err - ); - } else { - panic!("Expected error, but got ok"); - } - - let mut reader = BufReader::new(m.raw.as_slice()); - m2.read_from(&mut reader)?; - - let r = TextAttribute::get_from_as(&m, ATTR_NONCE)?; - assert_eq!(r.to_string(), v, "Expected {v}, got {r}."); - - let (r_attr, ok) = m.attributes.get(ATTR_NONCE); - assert!(ok, "realm attribute should be found"); - - let s = r_attr.to_string(); - assert!(s.starts_with("NONCE:"), "bad string representation {s}"); - - Ok(()) -} - -#[test] -fn test_nonce_add_to_invalid() -> Result<()> { - let mut m = Message::new(); - let s = TextAttribute { - attr: ATTR_NONCE, - text: String::from_utf8(vec![0; 1024]).unwrap(), - }; - let result = s.add_to(&mut m); - if let Err(err) = result { - assert!( - is_attr_size_overflow(&err), - "add_to should return AttrOverflowErr, got: {err}" - ); - } else { - panic!("expected error, but got ok"); - } - - let result = TextAttribute::get_from_as(&m, ATTR_NONCE); - if let Err(err) = result { - assert_eq!( - Error::ErrAttributeNotFound, - err, - "GetFrom should return {}, got: {}", - Error::ErrAttributeNotFound, - err - ); - } else { - panic!("expected error, but got ok"); - } - - Ok(()) -} - -#[test] -fn test_nonce_add_to() -> Result<()> { - let mut m = Message::new(); - let n = TextAttribute { - attr: ATTR_NONCE, - text: "example.org".to_owned(), - }; - n.add_to(&mut m)?; - - let v = m.get(ATTR_NONCE)?; - assert_eq!(v.as_slice(), b"example.org", "bad nonce {v:?}"); - - Ok(()) -} diff --git a/stun/src/uattrs.rs b/stun/src/uattrs.rs deleted file mode 100644 index 087d8099f..000000000 --- a/stun/src/uattrs.rs +++ /dev/null @@ -1,62 +0,0 @@ -#[cfg(test)] -mod uattrs_test; - -use std::fmt; - -use crate::attributes::*; -use crate::error::*; -use crate::message::*; - -// UnknownAttributes represents UNKNOWN-ATTRIBUTES attribute. -// -// RFC 5389 Section 15.9 -pub struct UnknownAttributes(pub Vec); - -impl fmt::Display for UnknownAttributes { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if self.0.is_empty() { - write!(f, "") - } else { - let mut s = vec![]; - for t in &self.0 { - s.push(t.to_string()); - } - write!(f, "{}", s.join(", ")) - } - } -} - -// type size is 16 bit. -const ATTR_TYPE_SIZE: usize = 2; - -impl Setter for UnknownAttributes { - // add_to adds UNKNOWN-ATTRIBUTES attribute to message. - fn add_to(&self, m: &mut Message) -> Result<()> { - let mut v = Vec::with_capacity(ATTR_TYPE_SIZE * 20); // 20 should be enough - // If len(a.Types) > 20, there will be allocations. - for t in &self.0 { - v.extend_from_slice(&t.value().to_be_bytes()); - } - m.add(ATTR_UNKNOWN_ATTRIBUTES, &v); - Ok(()) - } -} - -impl Getter for UnknownAttributes { - // GetFrom parses UNKNOWN-ATTRIBUTES from message. 
- fn get_from(&mut self, m: &Message) -> Result<()> { - let v = m.get(ATTR_UNKNOWN_ATTRIBUTES)?; - if v.len() % ATTR_TYPE_SIZE != 0 { - return Err(Error::ErrBadUnknownAttrsSize); - } - self.0.clear(); - let mut first = 0usize; - while first < v.len() { - let last = first + ATTR_TYPE_SIZE; - self.0 - .push(AttrType(u16::from_be_bytes([v[first], v[first + 1]]))); - first = last; - } - Ok(()) - } -} diff --git a/stun/src/uattrs/uattrs_test.rs b/stun/src/uattrs/uattrs_test.rs deleted file mode 100644 index 2351d555d..000000000 --- a/stun/src/uattrs/uattrs_test.rs +++ /dev/null @@ -1,37 +0,0 @@ -use super::*; - -#[test] -fn test_unknown_attributes() -> Result<()> { - let mut m = Message::new(); - let a = UnknownAttributes(vec![ATTR_DONT_FRAGMENT, ATTR_CHANNEL_NUMBER]); - assert_eq!( - a.to_string(), - "DONT-FRAGMENT, CHANNEL-NUMBER", - "bad String:{a}" - ); - assert_eq!( - UnknownAttributes(vec![]).to_string(), - "", - "bad blank string" - ); - - a.add_to(&mut m)?; - - //"GetFrom" - { - let mut attrs = UnknownAttributes(Vec::with_capacity(10)); - attrs.get_from(&m)?; - for i in 0..a.0.len() { - assert_eq!(a.0[i], attrs.0[i], "expected {} != {}", a.0[i], attrs.0[i]); - } - let mut m_blank = Message::new(); - let result = attrs.get_from(&m_blank); - assert!(result.is_err(), "should error"); - - m_blank.add(ATTR_UNKNOWN_ATTRIBUTES, &[1, 2, 3]); - let result = attrs.get_from(&m_blank); - assert!(result.is_err(), "should error"); - } - - Ok(()) -} diff --git a/stun/src/uri.rs b/stun/src/uri.rs deleted file mode 100644 index 5dce476d9..000000000 --- a/stun/src/uri.rs +++ /dev/null @@ -1,73 +0,0 @@ -#[cfg(test)] -mod uri_test; - -use std::fmt; - -use crate::error::*; - -// SCHEME definitions from RFC 7064 Section 3.2. - -pub const SCHEME: &str = "stun"; -pub const SCHEME_SECURE: &str = "stuns"; - -// URI as defined in RFC 7064. -#[derive(PartialEq, Eq, Debug)] -pub struct Uri { - pub scheme: String, - pub host: String, - pub port: Option, -} - -impl fmt::Display for Uri { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let host = if self.host.contains("::") { - "[".to_owned() + self.host.as_str() + "]" - } else { - self.host.clone() - }; - - if let Some(port) = self.port { - write!(f, "{}:{}:{}", self.scheme, host, port) - } else { - write!(f, "{}:{}", self.scheme, host) - } - } -} - -impl Uri { - // parse_uri parses URI from string. 
- pub fn parse_uri(raw: &str) -> Result { - // work around for url crate - if raw.contains("//") { - return Err(Error::ErrInvalidUrl); - } - - let mut s = raw.to_string(); - let pos = raw.find(':'); - if let Some(p) = pos { - s.replace_range(p..p + 1, "://"); - } else { - return Err(Error::ErrSchemeType); - } - - let raw_parts = url::Url::parse(&s)?; - - let scheme = raw_parts.scheme().into(); - if scheme != SCHEME && scheme != SCHEME_SECURE { - return Err(Error::ErrSchemeType); - } - - let host = if let Some(host) = raw_parts.host_str() { - host.trim() - .trim_start_matches('[') - .trim_end_matches(']') - .to_owned() - } else { - return Err(Error::ErrHost); - }; - - let port = raw_parts.port(); - - Ok(Uri { scheme, host, port }) - } -} diff --git a/stun/src/uri/uri_test.rs b/stun/src/uri/uri_test.rs deleted file mode 100644 index 20f13d17e..000000000 --- a/stun/src/uri/uri_test.rs +++ /dev/null @@ -1,68 +0,0 @@ -use super::*; - -#[test] -fn test_parse_uri() -> Result<()> { - let tests = vec![ - ( - "default", - "stun:example.org", - Uri { - host: "example.org".to_owned(), - scheme: SCHEME.to_owned(), - port: None, - }, - "stun:example.org", - ), - ( - "secure", - "stuns:example.org", - Uri { - host: "example.org".to_owned(), - scheme: SCHEME_SECURE.to_owned(), - port: None, - }, - "stuns:example.org", - ), - ( - "with port", - "stun:example.org:8000", - Uri { - host: "example.org".to_owned(), - scheme: SCHEME.to_owned(), - port: Some(8000), - }, - "stun:example.org:8000", - ), - ( - "ipv6 address", - "stun:[::1]:123", - Uri { - host: "::1".to_owned(), - scheme: SCHEME.to_owned(), - port: Some(123), - }, - "stun:[::1]:123", - ), - ]; - - for (name, input, output, expected_str) in tests { - let out = Uri::parse_uri(input)?; - assert_eq!(out, output, "{name}: {out} != {output}"); - assert_eq!(out.to_string(), expected_str, "{name}"); - } - - //"MustFail" - { - let tests = vec![ - ("hierarchical", "stun://example.org"), - ("bad scheme", "tcp:example.org"), - ("invalid uri scheme", "stun_s:test"), - ]; - for (name, input) in tests { - let result = Uri::parse_uri(input); - assert!(result.is_err(), "{name} should fail, but did not"); - } - } - - Ok(()) -} diff --git a/stun/src/xoraddr.rs b/stun/src/xoraddr.rs deleted file mode 100644 index 0a86bb35c..000000000 --- a/stun/src/xoraddr.rs +++ /dev/null @@ -1,173 +0,0 @@ -#[cfg(test)] -mod xoraddr_test; - -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; -use std::{fmt, mem}; - -use crate::addr::*; -use crate::attributes::*; -use crate::checks::*; -use crate::error::*; -use crate::message::*; - -const WORD_SIZE: usize = mem::size_of::(); - -//var supportsUnaligned = runtime.GOARCH == "386" || runtime.GOARCH == "amd64" // nolint:gochecknoglobals - -// fast_xor_bytes xors in bulk. It only works on architectures that -// support unaligned read/writes. -/*TODO: fn fast_xor_bytes(dst:&[u8], a:&[u8], b:&[u8]) ->usize { - let mut n = a.len(); - if b.len() < n { - n = b.len(); - } - - let w = n / WORD_SIZE; - if w > 0 { - let dw = *(*[]uintptr)(unsafe.Pointer(&dst)) - let aw = *(*[]uintptr)(unsafe.Pointer(&a)) - let bw = *(*[]uintptr)(unsafe.Pointer(&b)) - for i := 0; i < w; i++ { - dw[i] = aw[i] ^ bw[i] - } - } - - for i := n - n%WORD_SIZE; i < n; i++ { - dst[i] = a[i] ^ b[i] - } - - return n -}*/ - -fn safe_xor_bytes(dst: &mut [u8], a: &[u8], b: &[u8]) -> usize { - let mut n = a.len(); - if b.len() < n { - n = b.len(); - } - if dst.len() < n { - n = dst.len(); - } - for i in 0..n { - dst[i] = a[i] ^ b[i]; - } - n -} - -/// xor_bytes xors the bytes in a and b. 
The destination is assumed to have enough -/// space. Returns the number of bytes xor'd. -pub fn xor_bytes(dst: &mut [u8], a: &[u8], b: &[u8]) -> usize { - //TODO: if supportsUnaligned { - // return fastXORBytes(dst, a, b) - //} - safe_xor_bytes(dst, a, b) -} - -/// XORMappedAddress implements XOR-MAPPED-ADDRESS attribute. -/// -/// RFC 5389 Section 15.2 -pub struct XorMappedAddress { - pub ip: IpAddr, - pub port: u16, -} - -impl Default for XorMappedAddress { - fn default() -> Self { - XorMappedAddress { - ip: IpAddr::V4(Ipv4Addr::from(0)), - port: 0, - } - } -} - -impl fmt::Display for XorMappedAddress { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let family = match self.ip { - IpAddr::V4(_) => FAMILY_IPV4, - IpAddr::V6(_) => FAMILY_IPV6, - }; - if family == FAMILY_IPV4 { - write!(f, "{}:{}", self.ip, self.port) - } else { - write!(f, "[{}]:{}", self.ip, self.port) - } - } -} - -impl Setter for XorMappedAddress { - /// add_to adds XOR-MAPPED-ADDRESS to m. Can return ErrBadIPLength - /// if len(a.IP) is invalid. - fn add_to(&self, m: &mut Message) -> Result<()> { - self.add_to_as(m, ATTR_XORMAPPED_ADDRESS) - } -} - -impl Getter for XorMappedAddress { - /// get_from decodes XOR-MAPPED-ADDRESS attribute in message and returns - /// error if any. While decoding, a.IP is reused if possible and can be - /// rendered to invalid state (e.g. if a.IP was set to IPv6 and then - /// IPv4 value were decoded into it), be careful. - fn get_from(&mut self, m: &Message) -> Result<()> { - self.get_from_as(m, ATTR_XORMAPPED_ADDRESS) - } -} - -impl XorMappedAddress { - /// add_to_as adds XOR-MAPPED-ADDRESS value to m as t attribute. - pub fn add_to_as(&self, m: &mut Message, t: AttrType) -> Result<()> { - let (family, ip_len, ip) = match self.ip { - IpAddr::V4(ipv4) => (FAMILY_IPV4, IPV4LEN, ipv4.octets().to_vec()), - IpAddr::V6(ipv6) => (FAMILY_IPV6, IPV6LEN, ipv6.octets().to_vec()), - }; - - let mut value = [0; 32 + 128]; - //value[0] = 0 // first 8 bits are zeroes - let mut xor_value = vec![0; IPV6LEN]; - xor_value[4..].copy_from_slice(&m.transaction_id.0); - xor_value[0..4].copy_from_slice(&MAGIC_COOKIE.to_be_bytes()); - value[0..2].copy_from_slice(&family.to_be_bytes()); - value[2..4].copy_from_slice(&(self.port ^ (MAGIC_COOKIE >> 16) as u16).to_be_bytes()); - xor_bytes(&mut value[4..4 + ip_len], &ip, &xor_value); - m.add(t, &value[..4 + ip_len]); - Ok(()) - } - - /// get_from_as decodes XOR-MAPPED-ADDRESS attribute value in message - /// getting it as for t type. 
- pub fn get_from_as(&mut self, m: &Message, t: AttrType) -> Result<()> { - let v = m.get(t)?; - if v.len() <= 4 { - return Err(Error::ErrUnexpectedEof); - } - - let family = u16::from_be_bytes([v[0], v[1]]); - if family != FAMILY_IPV6 && family != FAMILY_IPV4 { - return Err(Error::Other(format!("bad value {family}"))); - } - - check_overflow( - t, - v[4..].len(), - if family == FAMILY_IPV4 { - IPV4LEN - } else { - IPV6LEN - }, - )?; - self.port = u16::from_be_bytes([v[2], v[3]]) ^ (MAGIC_COOKIE >> 16) as u16; - let mut xor_value = vec![0; 4 + TRANSACTION_ID_SIZE]; - xor_value[0..4].copy_from_slice(&MAGIC_COOKIE.to_be_bytes()); - xor_value[4..].copy_from_slice(&m.transaction_id.0); - - if family == FAMILY_IPV6 { - let mut ip = [0; IPV6LEN]; - xor_bytes(&mut ip, &v[4..], &xor_value); - self.ip = IpAddr::V6(Ipv6Addr::from(ip)); - } else { - let mut ip = [0; IPV4LEN]; - xor_bytes(&mut ip, &v[4..], &xor_value); - self.ip = IpAddr::V4(Ipv4Addr::from(ip)); - }; - - Ok(()) - } -} diff --git a/stun/src/xoraddr/xoraddr_test.rs b/stun/src/xoraddr/xoraddr_test.rs deleted file mode 100644 index 2d5544a33..000000000 --- a/stun/src/xoraddr/xoraddr_test.rs +++ /dev/null @@ -1,250 +0,0 @@ -use std::io::BufReader; - -use base64::prelude::BASE64_STANDARD; -use base64::Engine; - -use super::*; -use crate::checks::*; - -#[test] -fn test_xor_safe() -> Result<()> { - let mut dst = vec![0; 8]; - let a = vec![1, 2, 3, 4, 5, 6, 7, 8]; - let b = vec![8, 7, 7, 6, 6, 3, 4, 1]; - safe_xor_bytes(&mut dst, &a, &b); - let c = dst.clone(); - safe_xor_bytes(&mut dst, &c, &a); - for i in 0..dst.len() { - assert_eq!(b[i], dst[i], "{} != {}", b[i], dst[i]); - } - - Ok(()) -} - -#[test] -fn test_xor_safe_bsmaller() -> Result<()> { - let mut dst = vec![0; 5]; - let a = vec![1, 2, 3, 4, 5, 6, 7, 8]; - let b = vec![8, 7, 7, 6, 6]; - safe_xor_bytes(&mut dst, &a, &b); - let c = dst.clone(); - safe_xor_bytes(&mut dst, &c, &a); - for i in 0..dst.len() { - assert_eq!(b[i], dst[i], "{} != {}", b[i], dst[i]); - } - - Ok(()) -} - -#[test] -fn test_xormapped_address_get_from() -> Result<()> { - let mut m = Message::new(); - let transaction_id = BASE64_STANDARD.decode("jxhBARZwX+rsC6er").unwrap(); - m.transaction_id.0.copy_from_slice(&transaction_id); - let addr_value = vec![0x00, 0x01, 0x9c, 0xd5, 0xf4, 0x9f, 0x38, 0xae]; - m.add(ATTR_XORMAPPED_ADDRESS, &addr_value); - let mut addr = XorMappedAddress { - ip: "0.0.0.0".parse().unwrap(), - port: 0, - }; - addr.get_from(&m)?; - assert_eq!( - addr.ip.to_string(), - "213.141.156.236", - "bad IP {} != 213.141.156.236", - addr.ip - ); - assert_eq!(addr.port, 48583, "bad Port {} != 48583", addr.port); - - //"UnexpectedEOF" - { - let mut m = Message::new(); - // {0, 1} is correct addr family. - m.add(ATTR_XORMAPPED_ADDRESS, &[0, 1, 3, 4]); - let mut addr = XorMappedAddress { - ip: "0.0.0.0".parse().unwrap(), - port: 0, - }; - let result = addr.get_from(&m); - if let Err(err) = result { - assert_eq!( - Error::ErrUnexpectedEof, - err, - "len(v) = 4 should render <{}> error, got <{}>", - Error::ErrUnexpectedEof, - err - ); - } else { - panic!("expected error, got ok"); - } - } - //"AttrOverflowErr" - { - let mut m = Message::new(); - // {0, 1} is correct addr family. 
- m.add( - ATTR_XORMAPPED_ADDRESS, - &[0, 1, 3, 4, 5, 6, 7, 8, 9, 1, 1, 1, 1, 1, 2, 3, 4], - ); - let mut addr = XorMappedAddress { - ip: "0.0.0.0".parse().unwrap(), - port: 0, - }; - let result = addr.get_from(&m); - if let Err(err) = result { - assert!( - is_attr_size_overflow(&err), - "AddTo should return AttrOverflowErr, got: {err}" - ); - } else { - panic!("expected error, got ok"); - } - } - - Ok(()) -} - -#[test] -fn test_xormapped_address_get_from_invalid() -> Result<()> { - let mut m = Message::new(); - let transaction_id = BASE64_STANDARD.decode("jxhBARZwX+rsC6er").unwrap(); - m.transaction_id.0.copy_from_slice(&transaction_id); - let expected_ip: IpAddr = "213.141.156.236".parse().unwrap(); - let expected_port = 21254u16; - let mut addr = XorMappedAddress { - ip: "0.0.0.0".parse().unwrap(), - port: 0, - }; - let result = addr.get_from(&m); - assert!(result.is_err(), "should be error"); - - addr.ip = expected_ip; - addr.port = expected_port; - addr.add_to(&mut m)?; - m.write_header(); - - let mut m_res = Message::new(); - m.raw[20 + 4 + 1] = 0x21; - m.decode()?; - let mut reader = BufReader::new(m.raw.as_slice()); - m_res.read_from(&mut reader)?; - let result = addr.get_from(&m); - assert!(result.is_err(), "should be error"); - - Ok(()) -} - -#[test] -fn test_xormapped_address_add_to() -> Result<()> { - let mut m = Message::new(); - let transaction_id = BASE64_STANDARD.decode("jxhBARZwX+rsC6er").unwrap(); - m.transaction_id.0.copy_from_slice(&transaction_id); - let expected_ip: IpAddr = "213.141.156.236".parse().unwrap(); - let expected_port = 21254u16; - let mut addr = XorMappedAddress { - ip: "213.141.156.236".parse().unwrap(), - port: expected_port, - }; - addr.add_to(&mut m)?; - m.write_header(); - - let mut m_res = Message::new(); - m_res.write(&m.raw)?; - addr.get_from(&m_res)?; - assert_eq!( - addr.ip, expected_ip, - "{} (got) != {} (expected)", - addr.ip, expected_ip - ); - - assert_eq!( - addr.port, expected_port, - "bad Port {} != {}", - addr.port, expected_port - ); - - Ok(()) -} - -#[test] -fn test_xormapped_address_add_to_ipv6() -> Result<()> { - let mut m = Message::new(); - let transaction_id = BASE64_STANDARD.decode("jxhBARZwX+rsC6er").unwrap(); - m.transaction_id.0.copy_from_slice(&transaction_id); - let expected_ip: IpAddr = "fe80::dc2b:44ff:fe20:6009".parse().unwrap(); - let expected_port = 21254u16; - let addr = XorMappedAddress { - ip: "fe80::dc2b:44ff:fe20:6009".parse().unwrap(), - port: 21254, - }; - addr.add_to(&mut m)?; - m.write_header(); - - let mut m_res = Message::new(); - let mut reader = BufReader::new(m.raw.as_slice()); - m_res.read_from(&mut reader)?; - - let mut got_addr = XorMappedAddress { - ip: "0.0.0.0".parse().unwrap(), - port: 0, - }; - got_addr.get_from(&m)?; - - assert_eq!( - got_addr.ip, expected_ip, - "bad IP {} != {}", - got_addr.ip, expected_ip - ); - assert_eq!( - got_addr.port, expected_port, - "bad Port {} != {}", - got_addr.port, expected_port - ); - - Ok(()) -} - -/* -#[test] -fn TestXORMappedAddress_AddTo_Invalid() -> Result<()> { - let mut m = Message::new(); - let mut addr = XORMappedAddress{ - ip: 1, 2, 3, 4, 5, 6, 7, 8}, - port: 21254, - } - if err := addr.AddTo(m); !errors.Is(err, ErrBadIPLength) { - t.Errorf("AddTo should return %q, got: %v", ErrBadIPLength, err) - } -}*/ - -#[test] -fn test_xormapped_address_string() -> Result<()> { - let tests = vec![ - ( - // 0 - XorMappedAddress { - ip: "fe80::dc2b:44ff:fe20:6009".parse().unwrap(), - port: 124, - }, - "[fe80::dc2b:44ff:fe20:6009]:124", - ), - ( - // 1 - 
XorMappedAddress { - ip: "213.141.156.236".parse().unwrap(), - port: 8147, - }, - "213.141.156.236:8147", - ), - ]; - - for (addr, ip) in tests { - assert_eq!( - addr.to_string(), - ip, - " XORMappesAddress.String() {addr} (got) != {ip} (expected)", - ); - } - - Ok(()) -} diff --git a/turn/.gitignore b/turn/.gitignore deleted file mode 100644 index 81561ed32..000000000 --- a/turn/.gitignore +++ /dev/null @@ -1,11 +0,0 @@ -# Generated by Cargo -# will have compiled files and executables -/target/ -/.idea/ - -# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries -# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html -Cargo.lock - -# These are backup files generated by rustfmt -**/*.rs.bk diff --git a/turn/Cargo.toml b/turn/Cargo.toml deleted file mode 100644 index 7f54faacb..000000000 --- a/turn/Cargo.toml +++ /dev/null @@ -1,62 +0,0 @@ -[package] -name = "turn" -version = "0.8.0" -authors = ["Rain Liu "] -edition = "2021" -description = "A pure Rust implementation of TURN" -license = "MIT OR Apache-2.0" -documentation = "https://docs.rs/turn" -homepage = "https://webrtc.rs" -repository = "https://github.com/webrtc-rs/webrtc/tree/master/turn" - -[dependencies] -util = { version = "0.9.0", path = "../util", package = "webrtc-util", default-features = false, features = ["conn", "vnet"] } -stun = { version = "0.6.0", path = "../stun" } - -tokio = { version = "1.32.0", features = [ - "fs", - "io-util", - "io-std", - "macros", - "net", - "parking_lot", - "rt", - "rt-multi-thread", - "sync", - "time", -] } -tokio-util = "0.7" -futures = "0.3" -async-trait = "0.1" -log = "0.4" -base64 = "0.21" -rand = "0.8" -ring = "0.17" -md-5 = "0.10" -thiserror = "1" -portable-atomic = "1.6" - -[dev-dependencies] -tokio-test = "0.4" -env_logger = "0.10" -chrono = "0.4.28" -hex = "0.4" -clap = "3" -criterion = "0.5" - -[features] -metrics = [] - -[[bench]] -name = "bench" -harness = false - -[[example]] -name = "turn_client_udp" -path = "examples/turn_client_udp.rs" -bench = false - -[[example]] -name = "turn_server_udp" -path = "examples/turn_server_udp.rs" -bench = false diff --git a/turn/LICENSE-APACHE b/turn/LICENSE-APACHE deleted file mode 100644 index 16fe87b06..000000000 --- a/turn/LICENSE-APACHE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - -1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. 
- - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - -2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - -3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - -4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - -5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - -6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - -8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - -END OF TERMS AND CONDITIONS - -APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - -Copyright [yyyy] [name of copyright owner] - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. diff --git a/turn/LICENSE-MIT b/turn/LICENSE-MIT deleted file mode 100644 index e11d93bef..000000000 --- a/turn/LICENSE-MIT +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2021 WebRTC.rs - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/turn/README.md b/turn/README.md deleted file mode 100644 index 13218f8c6..000000000 --- a/turn/README.md +++ /dev/null @@ -1,30 +0,0 @@ -

- WebRTC.rs
-
- License: MIT/Apache 2.0
- Discord
-
- A pure Rust implementation of TURN. Rewrite Pion TURN in Rust

diff --git a/turn/benches/bench.rs b/turn/benches/bench.rs deleted file mode 100644 index 48485ba58..000000000 --- a/turn/benches/bench.rs +++ /dev/null @@ -1,137 +0,0 @@ -use std::time::Duration; - -use criterion::{criterion_group, criterion_main, Criterion}; -use stun::attributes::ATTR_DATA; -use stun::message::{Getter, Message, Setter}; -use turn::proto::chandata::ChannelData; -use turn::proto::channum::{ChannelNumber, MIN_CHANNEL_NUMBER}; -use turn::proto::data::Data; -use turn::proto::lifetime::Lifetime; - -fn benchmark_chan_data(c: &mut Criterion) { - { - let buf = [64, 0, 0, 0, 0, 4, 0, 0, 1, 2, 3]; - c.bench_function("BenchmarkIsChannelData", |b| { - b.iter(|| { - assert!(ChannelData::is_channel_data(&buf)); - }) - }); - } - - { - let mut d = ChannelData { - data: vec![1, 2, 3, 4], - number: ChannelNumber(MIN_CHANNEL_NUMBER + 1), - raw: vec![], - }; - c.bench_function("BenchmarkChannelData_Encode", |b| { - b.iter(|| { - d.encode(); - d.reset(); - }) - }); - } - - { - let mut d = ChannelData { - data: vec![1, 2, 3, 4], - number: ChannelNumber(MIN_CHANNEL_NUMBER + 1), - raw: vec![], - }; - d.encode(); - let mut buf = vec![0u8; d.raw.len()]; - buf.copy_from_slice(&d.raw); - c.bench_function("BenchmarkChannelData_Decode", |b| { - b.iter(|| { - d.reset(); - d.raw.clone_from(&buf); - d.decode().unwrap(); - }) - }); - } -} - -fn benchmark_chan(c: &mut Criterion) { - { - let mut m = Message::new(); - c.bench_function("BenchmarkChannelNumber/AddTo", |b| { - b.iter(|| { - let n = ChannelNumber(12); - n.add_to(&mut m).unwrap(); - m.reset(); - }) - }); - } - - { - let mut m = Message::new(); - let expected = ChannelNumber(12); - expected.add_to(&mut m).unwrap(); - let mut n = ChannelNumber::default(); - c.bench_function("BenchmarkChannelNumber/GetFrom", |b| { - b.iter(|| { - n.get_from(&m).unwrap(); - assert_eq!(n, expected); - }) - }); - } -} - -fn benchmark_data(c: &mut Criterion) { - { - let mut m = Message::new(); - let d = Data(vec![0u8; 10]); - c.bench_function("BenchmarkData/AddTo", |b| { - b.iter(|| { - d.add_to(&mut m).unwrap(); - m.reset(); - }) - }); - } - - { - let mut m = Message::new(); - let d = Data(vec![0u8; 10]); - c.bench_function("BenchmarkData/AddToRaw", |b| { - b.iter(|| { - m.add(ATTR_DATA, &d.0); - m.reset(); - }) - }); - } -} - -fn benchmark_lifetime(c: &mut Criterion) { - { - let mut m = Message::new(); - let l = Lifetime(Duration::from_secs(1)); - c.bench_function("BenchmarkLifetime/AddTo", |b| { - b.iter(|| { - l.add_to(&mut m).unwrap(); - m.reset(); - }) - }); - } - - { - let mut m = Message::new(); - let expected = Lifetime(Duration::from_secs(60)); - expected.add_to(&mut m).unwrap(); - let mut l = Lifetime::default(); - c.bench_function("BenchmarkLifetime/GetFrom", |b| { - b.iter(|| { - l.get_from(&m).unwrap(); - assert_eq!(l, expected); - }) - }); - } -} - -criterion_group!( - benches, - benchmark_chan_data, - benchmark_chan, - benchmark_data, - benchmark_lifetime -); -criterion_main!(benches); diff --git a/turn/codecov.yml b/turn/codecov.yml deleted file mode 100644 index bf7afa148..000000000 --- a/turn/codecov.yml +++ /dev/null @@ -1,23 +0,0 @@ -codecov: - require_ci_to_pass: yes - max_report_age: off - token: 640e45ed-ce83-43e1-9eee-473aa65dc136 - -coverage: - precision: 2 - round: down - range: 50..90 - status: - project: - default: - enabled: no - threshold: 0.2 - if_not_found: success - patch: - default: - enabled: no - if_not_found: success - changes: - default: - enabled: no - if_not_found: success diff --git a/turn/doc/webrtc.rs.png 
b/turn/doc/webrtc.rs.png deleted file mode 100644 index 7bf0dda2a..000000000 Binary files a/turn/doc/webrtc.rs.png and /dev/null differ diff --git a/turn/examples/turn_client_udp.rs b/turn/examples/turn_client_udp.rs deleted file mode 100644 index ed71976fd..000000000 --- a/turn/examples/turn_client_udp.rs +++ /dev/null @@ -1,197 +0,0 @@ -use std::sync::Arc; - -use clap::{App, AppSettings, Arg}; -use tokio::net::UdpSocket; -use tokio::time::Duration; -use turn::client::*; -use turn::Error; -use util::Conn; - -// RUST_LOG=trace cargo run --color=always --package turn --example turn_client_udp -- --host 0.0.0.0 --user user=pass --ping - -#[tokio::main] -async fn main() -> Result<(), Error> { - env_logger::init(); - - let mut app = App::new("TURN Client UDP") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of TURN Client UDP") - .setting(AppSettings::DeriveDisplayOrder) - .setting(AppSettings::SubcommandsNegateReqs) - .arg( - Arg::with_name("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::with_name("host") - .required_unless("FULLHELP") - .takes_value(true) - .long("host") - .help("TURN Server name."), - ) - .arg( - Arg::with_name("user") - .required_unless("FULLHELP") - .takes_value(true) - .long("user") - .help("A pair of username and password (e.g. \"user=pass\")"), - ) - .arg( - Arg::with_name("realm") - .default_value("webrtc.rs") - .takes_value(true) - .long("realm") - .help("Realm (defaults to \"webrtc.rs\")"), - ) - .arg( - Arg::with_name("port") - .takes_value(true) - .default_value("3478") - .long("port") - .help("Listening port."), - ) - .arg( - Arg::with_name("ping") - .long("ping") - .takes_value(false) - .help("Run ping test"), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let host = matches.value_of("host").unwrap(); - let port = matches.value_of("port").unwrap(); - let user = matches.value_of("user").unwrap(); - let cred: Vec<&str> = user.splitn(2, '=').collect(); - let ping = matches.is_present("ping"); - let realm = matches.value_of("realm").unwrap(); - - // TURN client won't create a local listening socket by itself. - let conn = UdpSocket::bind("0.0.0.0:0").await?; - - let turn_server_addr = format!("{host}:{port}"); - - let cfg = ClientConfig { - stun_serv_addr: turn_server_addr.clone(), - turn_serv_addr: turn_server_addr, - username: cred[0].to_string(), - password: cred[1].to_string(), - realm: realm.to_string(), - software: String::new(), - rto_in_ms: 0, - conn: Arc::new(conn), - vnet: None, - }; - - let client = Client::new(cfg).await?; - - // Start listening on the conn provided. - client.listen().await?; - - // Allocate a relay socket on the TURN server. On success, it - // will return a net.PacketConn which represents the remote - // socket. - let relay_conn = client.allocate().await?; - - // The relayConn's local address is actually the transport - // address assigned on the TURN server. - println!("relayed-address={}", relay_conn.local_addr()?); - - // If you provided `-ping`, perform a ping test against the - // relayConn we have just allocated. 
- if ping { - do_ping_test(&client, relay_conn).await?; - } - - client.close().await?; - - Ok(()) -} - -async fn do_ping_test( - client: &Client, - relay_conn: impl Conn + std::marker::Send + std::marker::Sync + 'static, -) -> Result<(), Error> { - // Send BindingRequest to learn our external IP - let mapped_addr = client.send_binding_request().await?; - - // Set up pinger socket (pingerConn) - //println!("bind..."); - let pinger_conn_tx = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); - - // Punch a UDP hole for the relay_conn by sending a data to the mapped_addr. - // This will trigger a TURN client to generate a permission request to the - // TURN server. After this, packets from the IP address will be accepted by - // the TURN server. - //println!("relay_conn send hello to mapped_addr {}", mapped_addr); - relay_conn.send_to("Hello".as_bytes(), mapped_addr).await?; - let relay_addr = relay_conn.local_addr()?; - - let pinger_conn_rx = Arc::clone(&pinger_conn_tx); - - // Start read-loop on pingerConn - tokio::spawn(async move { - let mut buf = vec![0u8; 1500]; - loop { - let (n, from) = match pinger_conn_rx.recv_from(&mut buf).await { - Ok((n, from)) => (n, from), - Err(_) => break, - }; - - let msg = match String::from_utf8(buf[..n].to_vec()) { - Ok(msg) => msg, - Err(_) => break, - }; - - println!("pingerConn read-loop: {msg} from {from}"); - /*if sentAt, pingerErr := time.Parse(time.RFC3339Nano, msg); pingerErr == nil { - rtt := time.Since(sentAt) - log.Printf("%d bytes from from %s time=%d ms\n", n, from.String(), int(rtt.Seconds()*1000)) - }*/ - } - }); - - // Start read-loop on relay_conn - tokio::spawn(async move { - let mut buf = vec![0u8; 1500]; - loop { - let (n, from) = match relay_conn.recv_from(&mut buf).await { - Err(_) => break, - Ok((n, from)) => (n, from), - }; - - println!("relay_conn read-loop: {:?} from {}", &buf[..n], from); - - // Echo back - if relay_conn.send_to(&buf[..n], from).await.is_err() { - break; - } - } - }); - - tokio::time::sleep(Duration::from_millis(500)).await; - - /*println!( - "pinger_conn_tx send 10 packets to relay addr {}...", - relay_addr - );*/ - // Send 10 packets from relay_conn to the echo server - for _ in 0..2 { - let msg = "12345678910".to_owned(); //format!("{:?}", tokio::time::Instant::now()); - println!("sending msg={} with size={}", msg, msg.as_bytes().len()); - pinger_conn_tx.send_to(msg.as_bytes(), relay_addr).await?; - - // For simplicity, this example does not wait for the pong (reply). - // Instead, sleep 1 second. 
- tokio::time::sleep(Duration::from_secs(1)).await; - } - - Ok(()) -} diff --git a/turn/examples/turn_server_udp.rs b/turn/examples/turn_server_udp.rs deleted file mode 100644 index ae8f88c00..000000000 --- a/turn/examples/turn_server_udp.rs +++ /dev/null @@ -1,139 +0,0 @@ -use std::collections::HashMap; -use std::net::{IpAddr, SocketAddr}; -use std::str::FromStr; -use std::sync::Arc; - -use clap::{App, AppSettings, Arg}; -use tokio::net::UdpSocket; -use tokio::signal; -use tokio::time::Duration; -use turn::auth::*; -use turn::relay::relay_static::*; -use turn::server::config::*; -use turn::server::*; -use turn::Error; -use util::vnet::net::*; - -struct MyAuthHandler { - cred_map: HashMap>, -} - -impl MyAuthHandler { - fn new(cred_map: HashMap>) -> Self { - MyAuthHandler { cred_map } - } -} - -impl AuthHandler for MyAuthHandler { - fn auth_handle( - &self, - username: &str, - _realm: &str, - _src_addr: SocketAddr, - ) -> Result, Error> { - if let Some(pw) = self.cred_map.get(username) { - //log::debug!("username={}, password={:?}", username, pw); - Ok(pw.to_vec()) - } else { - Err(Error::ErrFakeErr) - } - } -} - -// RUST_LOG=trace cargo run --color=always --package turn --example turn_server_udp -- --public-ip 0.0.0.0 --users user=pass - -#[tokio::main] -async fn main() -> Result<(), Error> { - env_logger::init(); - - let mut app = App::new("TURN Server UDP") - .version("0.1.0") - .author("Rain Liu ") - .about("An example of TURN Server UDP") - .setting(AppSettings::DeriveDisplayOrder) - .setting(AppSettings::SubcommandsNegateReqs) - .arg( - Arg::with_name("FULLHELP") - .help("Prints more detailed help information") - .long("fullhelp"), - ) - .arg( - Arg::with_name("public-ip") - .required_unless("FULLHELP") - .takes_value(true) - .long("public-ip") - .help("IP Address that TURN can be contacted by."), - ) - .arg( - Arg::with_name("users") - .required_unless("FULLHELP") - .takes_value(true) - .long("users") - .help("List of username and password (e.g. 
\"user=pass,user=pass\")"), - ) - .arg( - Arg::with_name("realm") - .default_value("webrtc.rs") - .takes_value(true) - .long("realm") - .help("Realm (defaults to \"webrtc.rs\")"), - ) - .arg( - Arg::with_name("port") - .takes_value(true) - .default_value("3478") - .long("port") - .help("Listening port."), - ); - - let matches = app.clone().get_matches(); - - if matches.is_present("FULLHELP") { - app.print_long_help().unwrap(); - std::process::exit(0); - } - - let public_ip = matches.value_of("public-ip").unwrap(); - let port = matches.value_of("port").unwrap(); - let users = matches.value_of("users").unwrap(); - let realm = matches.value_of("realm").unwrap(); - - // Cache -users flag for easy lookup later - // If passwords are stored they should be saved to your DB hashed using turn.GenerateAuthKey - let creds: Vec<&str> = users.split(',').collect(); - let mut cred_map = HashMap::new(); - for user in creds { - let cred: Vec<&str> = user.splitn(2, '=').collect(); - let key = generate_auth_key(cred[0], realm, cred[1]); - cred_map.insert(cred[0].to_owned(), key); - } - - // Create a UDP listener to pass into pion/turn - // turn itself doesn't allocate any UDP sockets, but lets the user pass them in - // this allows us to add logging, storage or modify inbound/outbound traffic - let conn = Arc::new(UdpSocket::bind(format!("0.0.0.0:{port}")).await?); - println!("listening {}...", conn.local_addr()?); - - let server = Server::new(ServerConfig { - conn_configs: vec![ConnConfig { - conn, - relay_addr_generator: Box::new(RelayAddressGeneratorStatic { - relay_address: IpAddr::from_str(public_ip)?, - address: "0.0.0.0".to_owned(), - net: Arc::new(Net::new(None)), - }), - }], - realm: realm.to_owned(), - auth_handler: Arc::new(MyAuthHandler::new(cred_map)), - channel_bind_timeout: Duration::from_secs(0), - alloc_close_notify: None, - }) - .await?; - - println!("Waiting for Ctrl-C..."); - signal::ctrl_c().await.expect("failed to listen for event"); - println!("\nClosing connection now..."); - server.close().await?; - - Ok(()) -} diff --git a/turn/src/allocation/allocation_manager.rs b/turn/src/allocation/allocation_manager.rs deleted file mode 100644 index f3c443689..000000000 --- a/turn/src/allocation/allocation_manager.rs +++ /dev/null @@ -1,198 +0,0 @@ -#[cfg(test)] -mod allocation_manager_test; - -use std::collections::HashMap; - -use futures::future; -use stun::textattrs::Username; -use tokio::sync::mpsc; -use util::Conn; - -use super::*; -use crate::error::*; -use crate::relay::*; - -/// `ManagerConfig` a bag of config params for `Manager`. -pub struct ManagerConfig { - pub relay_addr_generator: Box, - pub alloc_close_notify: Option>, -} - -/// `Manager` is used to hold active allocations. -pub struct Manager { - allocations: AllocationMap, - reservations: Arc>>, - relay_addr_generator: Box, - alloc_close_notify: Option>, -} - -impl Manager { - /// Creates a new [`Manager`]. - pub fn new(config: ManagerConfig) -> Self { - Manager { - allocations: Arc::new(Mutex::new(HashMap::new())), - reservations: Arc::new(Mutex::new(HashMap::new())), - relay_addr_generator: config.relay_addr_generator, - alloc_close_notify: config.alloc_close_notify, - } - } - - /// Closes this [`manager`] and closes all [`Allocation`]s it manages. - pub async fn close(&self) -> Result<()> { - let allocations = self.allocations.lock().await; - for a in allocations.values() { - a.close().await?; - } - Ok(()) - } - - /// Returns the information about the all [`Allocation`]s associated with - /// the specified [`FiveTuple`]s. 
- pub async fn get_allocations_info( - &self, - five_tuples: Option>, - ) -> HashMap { - let mut infos = HashMap::new(); - - let guarded = self.allocations.lock().await; - - guarded.iter().for_each(|(five_tuple, alloc)| { - if five_tuples.is_none() || five_tuples.as_ref().unwrap().contains(five_tuple) { - infos.insert( - *five_tuple, - AllocationInfo::new( - *five_tuple, - alloc.username.text.clone(), - #[cfg(feature = "metrics")] - alloc.relayed_bytes.load(Ordering::Acquire), - ), - ); - } - }); - - infos - } - - /// Fetches the [`Allocation`] matching the passed [`FiveTuple`]. - pub async fn get_allocation(&self, five_tuple: &FiveTuple) -> Option> { - let allocations = self.allocations.lock().await; - allocations.get(five_tuple).cloned() - } - - /// Creates a new [`Allocation`] and starts relaying. - pub async fn create_allocation( - &self, - five_tuple: FiveTuple, - turn_socket: Arc, - requested_port: u16, - lifetime: Duration, - username: Username, - use_ipv4: bool, - ) -> Result> { - if lifetime == Duration::from_secs(0) { - return Err(Error::ErrLifetimeZero); - } - - if self.get_allocation(&five_tuple).await.is_some() { - return Err(Error::ErrDupeFiveTuple); - } - - let (relay_socket, relay_addr) = self - .relay_addr_generator - .allocate_conn(use_ipv4, requested_port) - .await?; - let mut a = Allocation::new( - turn_socket, - relay_socket, - relay_addr, - five_tuple, - username, - self.alloc_close_notify.clone(), - ); - a.allocations = Some(Arc::clone(&self.allocations)); - - log::debug!("listening on relay addr: {:?}", a.relay_addr); - a.start(lifetime).await; - a.packet_handler().await; - - let a = Arc::new(a); - { - let mut allocations = self.allocations.lock().await; - allocations.insert(five_tuple, Arc::clone(&a)); - } - - Ok(a) - } - - /// Removes an [`Allocation`]. - pub async fn delete_allocation(&self, five_tuple: &FiveTuple) { - let allocation = self.allocations.lock().await.remove(five_tuple); - - if let Some(a) = allocation { - if let Err(err) = a.close().await { - log::error!("Failed to close allocation: {}", err); - } - } - } - - /// Deletes the [`Allocation`]s according to the specified username `name`. - pub async fn delete_allocations_by_username(&self, name: &str) { - let to_delete = { - let mut allocations = self.allocations.lock().await; - - let mut to_delete = Vec::new(); - - // TODO(logist322): Use `.drain_filter()` once stabilized. - allocations.retain(|_, allocation| { - let match_name = allocation.username.text == name; - - if match_name { - to_delete.push(Arc::clone(allocation)); - } - - !match_name - }); - - to_delete - }; - - future::join_all(to_delete.iter().map(|a| async move { - if let Err(err) = a.close().await { - log::error!("Failed to close allocation: {}", err); - } - })) - .await; - } - - /// Stores the reservation for the token+port. - pub async fn create_reservation(&self, reservation_token: String, port: u16) { - let reservations = Arc::clone(&self.reservations); - let reservation_token2 = reservation_token.clone(); - - tokio::spawn(async move { - let sleep = tokio::time::sleep(Duration::from_secs(30)); - tokio::pin!(sleep); - tokio::select! { - _ = &mut sleep => { - let mut reservations = reservations.lock().await; - reservations.remove(&reservation_token2); - }, - } - }); - - let mut reservations = self.reservations.lock().await; - reservations.insert(reservation_token, port); - } - - /// Returns the port for a given reservation if it exists. 
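    // (Illustrative sketch, not from the original source; the token and port
    // are made up.) A reservation pairs a token with a relay port and, per the
    // timer spawned in `create_reservation` above, is dropped after roughly
    // 30 seconds:
    //
    //     manager.create_reservation("token-123".to_owned(), 50_000).await;
    //     assert_eq!(manager.get_reservation("token-123").await, Some(50_000));
    //     // ~30 s later the same lookup returns `None`.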
- pub async fn get_reservation(&self, reservation_token: &str) -> Option { - let reservations = self.reservations.lock().await; - reservations.get(reservation_token).copied() - } - - /// Returns a random un-allocated udp4 port. - pub async fn get_random_even_port(&self) -> Result { - let (_, addr) = self.relay_addr_generator.allocate_conn(true, 0).await?; - Ok(addr.port()) - } -} diff --git a/turn/src/allocation/allocation_manager/allocation_manager_test.rs b/turn/src/allocation/allocation_manager/allocation_manager_test.rs deleted file mode 100644 index c994d6e55..000000000 --- a/turn/src/allocation/allocation_manager/allocation_manager_test.rs +++ /dev/null @@ -1,613 +0,0 @@ -use std::net::{IpAddr, Ipv4Addr}; -use std::str::FromStr; - -use stun::attributes::ATTR_USERNAME; -use stun::textattrs::TextAttribute; -use tokio::net::UdpSocket; -use tokio::sync::mpsc::Sender; -use util::vnet::net::*; - -use super::*; -use crate::auth::{generate_auth_key, AuthHandler}; -use crate::client::{Client, ClientConfig}; -use crate::error::Result; -use crate::proto::lifetime::DEFAULT_LIFETIME; -use crate::relay::relay_none::*; -use crate::relay::relay_static::RelayAddressGeneratorStatic; -use crate::server::config::{ConnConfig, ServerConfig}; -use crate::server::Server; - -fn new_test_manager() -> Manager { - let config = ManagerConfig { - relay_addr_generator: Box::new(RelayAddressGeneratorNone { - address: "0.0.0.0".to_owned(), - net: Arc::new(Net::new(None)), - }), - alloc_close_notify: None, - }; - Manager::new(config) -} - -fn random_five_tuple() -> FiveTuple { - /* #nosec */ - FiveTuple { - src_addr: SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), rand::random()), - dst_addr: SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), rand::random()), - ..Default::default() - } -} - -#[tokio::test] -async fn test_packet_handler() -> Result<()> { - //env_logger::init(); - - // turn server initialization - let turn_socket = UdpSocket::bind("127.0.0.1:0").await?; - - // client listener initialization - let client_listener = UdpSocket::bind("127.0.0.1:0").await?; - let src_addr = client_listener.local_addr()?; - let (data_ch_tx, mut data_ch_rx) = mpsc::channel(1); - // client listener read data - tokio::spawn(async move { - let mut buffer = vec![0u8; RTP_MTU]; - loop { - let n = match client_listener.recv_from(&mut buffer).await { - Ok((n, _)) => n, - Err(_) => break, - }; - - let _ = data_ch_tx.send(buffer[..n].to_vec()).await; - } - }); - - let m = new_test_manager(); - let a = m - .create_allocation( - FiveTuple { - src_addr, - dst_addr: turn_socket.local_addr()?, - ..Default::default() - }, - Arc::new(turn_socket), - 0, - DEFAULT_LIFETIME, - TextAttribute::new(ATTR_USERNAME, "user".into()), - true, - ) - .await?; - - let peer_listener1 = UdpSocket::bind("127.0.0.1:0").await?; - let peer_listener2 = UdpSocket::bind("127.0.0.1:0").await?; - - let channel_bind = ChannelBind::new( - ChannelNumber(MIN_CHANNEL_NUMBER), - peer_listener2.local_addr()?, - ); - - let port = { - // add permission with peer1 address - a.add_permission(Permission::new(peer_listener1.local_addr()?)) - .await; - // add channel with min channel number and peer2 address - a.add_channel_bind(channel_bind.clone(), DEFAULT_LIFETIME) - .await?; - - a.relay_socket.local_addr()?.port() - }; - - let relay_addr_with_host_str = format!("127.0.0.1:{port}"); - let relay_addr_with_host = SocketAddr::from_str(&relay_addr_with_host_str)?; - - // test for permission and data message - let target_text = "permission"; - let _ = peer_listener1 - 
.send_to(target_text.as_bytes(), relay_addr_with_host) - .await?; - let data = data_ch_rx - .recv() - .await - .ok_or(Error::Other("data ch closed".to_owned()))?; - - // resolve stun data message - assert!(is_message(&data), "should be stun message"); - - let mut msg = Message::new(); - msg.raw = data; - msg.decode()?; - - let mut msg_data = Data::default(); - msg_data.get_from(&msg)?; - assert_eq!( - target_text.as_bytes(), - &msg_data.0, - "get message doesn't equal the target text" - ); - - // test for channel bind and channel data - let target_text2 = "channel bind"; - let _ = peer_listener2 - .send_to(target_text2.as_bytes(), relay_addr_with_host) - .await?; - let data = data_ch_rx - .recv() - .await - .ok_or(Error::Other("data ch closed".to_owned()))?; - - // resolve channel data - assert!( - ChannelData::is_channel_data(&data), - "should be channel data" - ); - - let mut channel_data = ChannelData { - raw: data, - ..Default::default() - }; - channel_data.decode()?; - assert_eq!( - channel_bind.number, channel_data.number, - "get channel data's number is invalid" - ); - assert_eq!( - target_text2.as_bytes(), - &channel_data.data, - "get data doesn't equal the target text." - ); - - // listeners close - m.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_create_allocation_duplicate_five_tuple() -> Result<()> { - //env_logger::init(); - - // turn server initialization - let turn_socket: Arc = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); - - let m = new_test_manager(); - - let five_tuple = random_five_tuple(); - - let _ = m - .create_allocation( - five_tuple, - Arc::clone(&turn_socket), - 0, - DEFAULT_LIFETIME, - TextAttribute::new(ATTR_USERNAME, "user".into()), - true, - ) - .await?; - - let result = m - .create_allocation( - five_tuple, - Arc::clone(&turn_socket), - 0, - DEFAULT_LIFETIME, - TextAttribute::new(ATTR_USERNAME, "user".into()), - true, - ) - .await; - assert!(result.is_err(), "expected error, but got ok"); - - Ok(()) -} - -#[tokio::test] -async fn test_delete_allocation() -> Result<()> { - //env_logger::init(); - - // turn server initialization - let turn_socket: Arc = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); - - let m = new_test_manager(); - - let five_tuple = random_five_tuple(); - - let _ = m - .create_allocation( - five_tuple, - Arc::clone(&turn_socket), - 0, - DEFAULT_LIFETIME, - TextAttribute::new(ATTR_USERNAME, "user".into()), - true, - ) - .await?; - - assert!( - m.get_allocation(&five_tuple).await.is_some(), - "Failed to get allocation right after creation" - ); - - m.delete_allocation(&five_tuple).await; - - assert!( - m.get_allocation(&five_tuple).await.is_none(), - "Get allocation with {five_tuple} should be nil after delete" - ); - - Ok(()) -} - -#[tokio::test] -async fn test_allocation_timeout() -> Result<()> { - //env_logger::init(); - - // turn server initialization - let turn_socket: Arc = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); - - let m = new_test_manager(); - - let mut allocations = vec![]; - let lifetime = Duration::from_millis(100); - - for _ in 0..5 { - let five_tuple = random_five_tuple(); - - let a = m - .create_allocation( - five_tuple, - Arc::clone(&turn_socket), - 0, - lifetime, - TextAttribute::new(ATTR_USERNAME, "user".into()), - true, - ) - .await?; - - allocations.push(a); - } - - let mut count = 0; - - 'outer: loop { - count += 1; - - if count >= 10 { - panic!("Allocations didn't timeout"); - } - - tokio::time::sleep(lifetime + Duration::from_millis(100)).await; - - let any_outstanding = false; - - for a in 
&allocations { - if a.close().await.is_ok() { - continue 'outer; - } - } - - if !any_outstanding { - return Ok(()); - } - } -} - -#[tokio::test] -async fn test_manager_close() -> Result<()> { - // env_logger::init(); - - // turn server initialization - let turn_socket: Arc = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); - - let m = new_test_manager(); - - let mut allocations = vec![]; - - let a1 = m - .create_allocation( - random_five_tuple(), - Arc::clone(&turn_socket), - 0, - Duration::from_millis(100), - TextAttribute::new(ATTR_USERNAME, "user".into()), - true, - ) - .await?; - allocations.push(a1); - - let a2 = m - .create_allocation( - random_five_tuple(), - Arc::clone(&turn_socket), - 0, - Duration::from_millis(200), - TextAttribute::new(ATTR_USERNAME, "user".into()), - true, - ) - .await?; - allocations.push(a2); - - tokio::time::sleep(Duration::from_millis(150)).await; - - log::trace!("Mgr is going to be closed..."); - - m.close().await?; - - for a in allocations { - assert!( - a.close().await.is_err(), - "Allocation should be closed if lifetime timeout" - ); - } - - Ok(()) -} - -#[tokio::test] -async fn test_delete_allocation_by_username() -> Result<()> { - let turn_socket: Arc = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); - - let m = new_test_manager(); - - let five_tuple1 = random_five_tuple(); - let five_tuple2 = random_five_tuple(); - let five_tuple3 = random_five_tuple(); - - let _ = m - .create_allocation( - five_tuple1, - Arc::clone(&turn_socket), - 0, - DEFAULT_LIFETIME, - TextAttribute::new(ATTR_USERNAME, "user".into()), - true, - ) - .await?; - let _ = m - .create_allocation( - five_tuple2, - Arc::clone(&turn_socket), - 0, - DEFAULT_LIFETIME, - TextAttribute::new(ATTR_USERNAME, "user".into()), - true, - ) - .await?; - let _ = m - .create_allocation( - five_tuple3, - Arc::clone(&turn_socket), - 0, - DEFAULT_LIFETIME, - TextAttribute::new(ATTR_USERNAME, "user2".into()), - true, - ) - .await?; - - assert_eq!(m.allocations.lock().await.len(), 3); - - m.delete_allocations_by_username("user").await; - - assert_eq!(m.allocations.lock().await.len(), 1); - - assert!( - m.get_allocation(&five_tuple1).await.is_none() - && m.get_allocation(&five_tuple2).await.is_none() - && m.get_allocation(&five_tuple3).await.is_some() - ); - - Ok(()) -} - -struct TestAuthHandler; -impl AuthHandler for TestAuthHandler { - fn auth_handle(&self, username: &str, realm: &str, _src_addr: SocketAddr) -> Result> { - Ok(generate_auth_key(username, realm, "pass")) - } -} - -async fn create_server( - alloc_close_notify: Option>, -) -> Result<(Server, u16)> { - let conn = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); - let server_port = conn.local_addr()?.port(); - - let server = Server::new(ServerConfig { - conn_configs: vec![ConnConfig { - conn, - relay_addr_generator: Box::new(RelayAddressGeneratorStatic { - relay_address: IpAddr::from_str("127.0.0.1")?, - address: "0.0.0.0".to_owned(), - net: Arc::new(Net::new(None)), - }), - }], - realm: "webrtc.rs".to_owned(), - auth_handler: Arc::new(TestAuthHandler {}), - channel_bind_timeout: Duration::from_secs(0), - alloc_close_notify, - }) - .await?; - - Ok((server, server_port)) -} - -async fn create_client(username: String, server_port: u16) -> Result { - let conn = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); - - Client::new(ClientConfig { - stun_serv_addr: format!("127.0.0.1:{server_port}"), - turn_serv_addr: format!("127.0.0.1:{server_port}"), - username, - password: "pass".to_owned(), - realm: String::new(), - software: String::new(), - rto_in_ms: 0, 
- conn, - vnet: None, - }) - .await -} - -#[cfg(feature = "metrics")] -#[tokio::test] -async fn test_get_allocations_info() -> Result<()> { - let (server, server_port) = create_server(None).await?; - - let client1 = create_client("user1".to_owned(), server_port).await?; - client1.listen().await?; - - let client2 = create_client("user2".to_owned(), server_port).await?; - client2.listen().await?; - - let client3 = create_client("user3".to_owned(), server_port).await?; - client3.listen().await?; - - assert!(server.get_allocations_info(None).await?.is_empty()); - - let user1 = client1.allocate().await?; - let user2 = client2.allocate().await?; - let user3 = client3.allocate().await?; - - assert_eq!(server.get_allocations_info(None).await?.len(), 3); - - let addr1 = client1 - .send_binding_request_to(format!("127.0.0.1:{server_port}").as_str()) - .await?; - let addr2 = client2 - .send_binding_request_to(format!("127.0.0.1:{server_port}").as_str()) - .await?; - let addr3 = client3 - .send_binding_request_to(format!("127.0.0.1:{server_port}").as_str()) - .await?; - - user1.send_to(b"1", addr1).await?; - user2.send_to(b"12", addr2).await?; - user3.send_to(b"123", addr3).await?; - - tokio::time::sleep(Duration::from_millis(100)).await; - - server - .get_allocations_info(None) - .await? - .iter() - .for_each(|(_, ai)| match ai.username.as_str() { - "user1" => assert_eq!(ai.relayed_bytes, 1), - "user2" => assert_eq!(ai.relayed_bytes, 2), - "user3" => assert_eq!(ai.relayed_bytes, 3), - _ => unreachable!(), - }); - - Ok(()) -} - -#[cfg(feature = "metrics")] -#[tokio::test] -async fn test_get_allocations_info_bytes_count() -> Result<()> { - let (server, server_port) = create_server(None).await?; - - let client = create_client("foo".to_owned(), server_port).await?; - - client.listen().await?; - - assert!(server.get_allocations_info(None).await?.is_empty()); - - let conn = client.allocate().await?; - let addr = client - .send_binding_request_to(format!("127.0.0.1:{server_port}").as_str()) - .await?; - - assert!(!server.get_allocations_info(None).await?.is_empty()); - - assert_eq!( - server - .get_allocations_info(None) - .await? - .values() - .last() - .unwrap() - .relayed_bytes, - 0 - ); - - for _ in 0..10 { - conn.send_to(b"Hello", addr).await?; - - tokio::time::sleep(Duration::from_millis(100)).await; - } - - tokio::time::sleep(Duration::from_millis(1000)).await; - - assert_eq!( - server - .get_allocations_info(None) - .await? - .values() - .last() - .unwrap() - .relayed_bytes, - 50 - ); - - for _ in 0..10 { - conn.send_to(b"Hello", addr).await?; - - tokio::time::sleep(Duration::from_millis(100)).await; - } - - tokio::time::sleep(Duration::from_millis(1000)).await; - - assert_eq!( - server - .get_allocations_info(None) - .await? 
- .values() - .last() - .unwrap() - .relayed_bytes, - 100 - ); - - client.close().await?; - server.close().await?; - - Ok(()) -} - -#[cfg(feature = "metrics")] -#[tokio::test] -async fn test_alloc_close_notify() -> Result<()> { - let (tx, mut rx) = mpsc::channel::(1); - - tokio::spawn(async move { - if let Some(alloc) = rx.recv().await { - assert_eq!(alloc.relayed_bytes, 50); - } - }); - - let (server, server_port) = create_server(Some(tx)).await?; - - let client = create_client("foo".to_owned(), server_port).await?; - - client.listen().await?; - - assert!(server.get_allocations_info(None).await?.is_empty()); - - let conn = client.allocate().await?; - let addr = client - .send_binding_request_to(format!("127.0.0.1:{server_port}").as_str()) - .await?; - - assert!(!server.get_allocations_info(None).await?.is_empty()); - - for _ in 0..10 { - conn.send_to(b"Hello", addr).await?; - - tokio::time::sleep(Duration::from_millis(100)).await; - } - - tokio::time::sleep(Duration::from_millis(1000)).await; - - client.close().await?; - server.close().await?; - - tokio::time::sleep(Duration::from_millis(1000)).await; - - Ok(()) -} diff --git a/turn/src/allocation/allocation_test.rs b/turn/src/allocation/allocation_test.rs deleted file mode 100644 index fc7caf386..000000000 --- a/turn/src/allocation/allocation_test.rs +++ /dev/null @@ -1,296 +0,0 @@ -use std::str::FromStr; - -use stun::attributes::ATTR_USERNAME; -use stun::textattrs::TextAttribute; -use tokio::net::UdpSocket; - -use super::*; -use crate::proto::lifetime::DEFAULT_LIFETIME; - -#[tokio::test] -async fn test_has_permission() -> Result<()> { - let turn_socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); - let relay_socket = Arc::clone(&turn_socket); - let relay_addr = relay_socket.local_addr()?; - let a = Allocation::new( - turn_socket, - relay_socket, - relay_addr, - FiveTuple::default(), - TextAttribute::new(ATTR_USERNAME, "user".into()), - None, - ); - - let addr1 = SocketAddr::from_str("127.0.0.1:3478")?; - let addr2 = SocketAddr::from_str("127.0.0.1:3479")?; - let addr3 = SocketAddr::from_str("127.0.0.2:3478")?; - - let p1 = Permission::new(addr1); - let p2 = Permission::new(addr2); - let p3 = Permission::new(addr3); - - a.add_permission(p1).await; - a.add_permission(p2).await; - a.add_permission(p3).await; - - let found_p1 = a.has_permission(&addr1).await; - assert!(found_p1, "Should keep the first one."); - - let found_p2 = a.has_permission(&addr2).await; - assert!(found_p2, "Second one should be ignored."); - - let found_p3 = a.has_permission(&addr3).await; - assert!(found_p3, "Permission with another IP should be found"); - - Ok(()) -} - -#[tokio::test] -async fn test_add_permission() -> Result<()> { - let turn_socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); - let relay_socket = Arc::clone(&turn_socket); - let relay_addr = relay_socket.local_addr()?; - let a = Allocation::new( - turn_socket, - relay_socket, - relay_addr, - FiveTuple::default(), - TextAttribute::new(ATTR_USERNAME, "user".into()), - None, - ); - - let addr = SocketAddr::from_str("127.0.0.1:3478")?; - let p = Permission::new(addr); - a.add_permission(p).await; - - let found_p = a.has_permission(&addr).await; - assert!(found_p, "Should keep the first one."); - - Ok(()) -} - -#[tokio::test] -async fn test_remove_permission() -> Result<()> { - let turn_socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); - let relay_socket = Arc::clone(&turn_socket); - let relay_addr = relay_socket.local_addr()?; - let a = Allocation::new( - turn_socket, - relay_socket, - 
relay_addr, - FiveTuple::default(), - TextAttribute::new(ATTR_USERNAME, "user".into()), - None, - ); - - let addr = SocketAddr::from_str("127.0.0.1:3478")?; - - let p = Permission::new(addr); - a.add_permission(p).await; - - let found_p = a.has_permission(&addr).await; - assert!(found_p, "Should keep the first one."); - - a.remove_permission(&addr).await; - - let found_permission = a.has_permission(&addr).await; - assert!( - !found_permission, - "Got permission should be nil after removed." - ); - - Ok(()) -} - -#[tokio::test] -async fn test_add_channel_bind() -> Result<()> { - let turn_socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); - let relay_socket = Arc::clone(&turn_socket); - let relay_addr = relay_socket.local_addr()?; - let a = Allocation::new( - turn_socket, - relay_socket, - relay_addr, - FiveTuple::default(), - TextAttribute::new(ATTR_USERNAME, "user".into()), - None, - ); - - let addr = SocketAddr::from_str("127.0.0.1:3478")?; - let c = ChannelBind::new(ChannelNumber(MIN_CHANNEL_NUMBER), addr); - - a.add_channel_bind(c, DEFAULT_LIFETIME).await?; - - let c2 = ChannelBind::new(ChannelNumber(MIN_CHANNEL_NUMBER + 1), addr); - let result = a.add_channel_bind(c2, DEFAULT_LIFETIME).await; - assert!( - result.is_err(), - "should failed with conflicted peer address" - ); - - let addr2 = SocketAddr::from_str("127.0.0.1:3479")?; - let c3 = ChannelBind::new(ChannelNumber(MIN_CHANNEL_NUMBER), addr2); - let result = a.add_channel_bind(c3, DEFAULT_LIFETIME).await; - assert!(result.is_err(), "should fail with conflicted number."); - - Ok(()) -} - -#[tokio::test] -async fn test_get_channel_by_number() -> Result<()> { - let turn_socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); - let relay_socket = Arc::clone(&turn_socket); - let relay_addr = relay_socket.local_addr()?; - let a = Allocation::new( - turn_socket, - relay_socket, - relay_addr, - FiveTuple::default(), - TextAttribute::new(ATTR_USERNAME, "user".into()), - None, - ); - - let addr = SocketAddr::from_str("127.0.0.1:3478")?; - let c = ChannelBind::new(ChannelNumber(MIN_CHANNEL_NUMBER), addr); - - a.add_channel_bind(c, DEFAULT_LIFETIME).await?; - - let exist_channel_addr = a - .get_channel_addr(&ChannelNumber(MIN_CHANNEL_NUMBER)) - .await - .unwrap(); - assert_eq!(addr, exist_channel_addr); - - let not_exist_channel = a - .get_channel_addr(&ChannelNumber(MIN_CHANNEL_NUMBER + 1)) - .await; - assert!( - not_exist_channel.is_none(), - "should be nil for not existed channel." - ); - - Ok(()) -} - -#[tokio::test] -async fn test_get_channel_by_addr() -> Result<()> { - let turn_socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); - let relay_socket = Arc::clone(&turn_socket); - let relay_addr = relay_socket.local_addr()?; - let a = Allocation::new( - turn_socket, - relay_socket, - relay_addr, - FiveTuple::default(), - TextAttribute::new(ATTR_USERNAME, "user".into()), - None, - ); - - let addr = SocketAddr::from_str("127.0.0.1:3478")?; - let addr2 = SocketAddr::from_str("127.0.0.1:3479")?; - let c = ChannelBind::new(ChannelNumber(MIN_CHANNEL_NUMBER), addr); - - a.add_channel_bind(c, DEFAULT_LIFETIME).await?; - - let exist_channel_number = a.get_channel_number(&addr).await.unwrap(); - assert_eq!(ChannelNumber(MIN_CHANNEL_NUMBER), exist_channel_number); - - let not_exist_channel = a.get_channel_number(&addr2).await; - assert!( - not_exist_channel.is_none(), - "should be nil for not existed channel." 
- ); - - Ok(()) -} - -#[tokio::test] -async fn test_remove_channel_bind() -> Result<()> { - let turn_socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); - let relay_socket = Arc::clone(&turn_socket); - let relay_addr = relay_socket.local_addr()?; - let a = Allocation::new( - turn_socket, - relay_socket, - relay_addr, - FiveTuple::default(), - TextAttribute::new(ATTR_USERNAME, "user".into()), - None, - ); - - let addr = SocketAddr::from_str("127.0.0.1:3478")?; - let number = ChannelNumber(MIN_CHANNEL_NUMBER); - let c = ChannelBind::new(number, addr); - - a.add_channel_bind(c, DEFAULT_LIFETIME).await?; - - a.remove_channel_bind(number).await; - - let not_exist_channel = a.get_channel_addr(&number).await; - assert!( - not_exist_channel.is_none(), - "should be nil for not existed channel." - ); - - let not_exist_channel = a.get_channel_number(&addr).await; - assert!( - not_exist_channel.is_none(), - "should be nil for not existed channel." - ); - - Ok(()) -} - -#[tokio::test] -async fn test_allocation_refresh() -> Result<()> { - let turn_socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); - let relay_socket = Arc::clone(&turn_socket); - let relay_addr = relay_socket.local_addr()?; - let a = Allocation::new( - turn_socket, - relay_socket, - relay_addr, - FiveTuple::default(), - TextAttribute::new(ATTR_USERNAME, "user".into()), - None, - ); - - a.start(DEFAULT_LIFETIME).await; - a.refresh(Duration::from_secs(0)).await; - - assert!(!a.stop(), "lifetimeTimer has expired"); - - Ok(()) -} - -#[tokio::test] -async fn test_allocation_close() -> Result<()> { - let turn_socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); - let relay_socket = Arc::clone(&turn_socket); - let relay_addr = relay_socket.local_addr()?; - let a = Allocation::new( - turn_socket, - relay_socket, - relay_addr, - FiveTuple::default(), - TextAttribute::new(ATTR_USERNAME, "user".into()), - None, - ); - - // add mock lifetimeTimer - a.start(DEFAULT_LIFETIME).await; - - // add channel - let addr = SocketAddr::from_str("127.0.0.1:3478")?; - let number = ChannelNumber(MIN_CHANNEL_NUMBER); - let c = ChannelBind::new(number, addr); - - a.add_channel_bind(c, DEFAULT_LIFETIME).await?; - - // add permission - a.add_permission(Permission::new(addr)).await; - - a.close().await?; - - Ok(()) -} diff --git a/turn/src/allocation/channel_bind.rs b/turn/src/allocation/channel_bind.rs deleted file mode 100644 index e613d07d4..000000000 --- a/turn/src/allocation/channel_bind.rs +++ /dev/null @@ -1,87 +0,0 @@ -#[cfg(test)] -mod channel_bind_test; - -use std::sync::atomic::Ordering; -use std::sync::Arc; - -use portable_atomic::AtomicBool; -use tokio::sync::Mutex; -use tokio::time::{Duration, Instant}; - -use super::*; -use crate::proto::channum::*; - -/// `ChannelBind` represents a TURN Channel. -/// -/// https://tools.ietf.org/html/rfc5766#section-2.5. 
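// (Illustrative sketch, not from the original source.) Channel numbers must
// fall in the 0x4000..=0x7FFF range; a binding is attached to an allocation
// through `Allocation::add_channel_bind`, which also refreshes the peer's
// permission (mirroring the tests elsewhere in this crate):
//
//     let peer = SocketAddr::from_str("127.0.0.1:3478")?;
//     let cb = ChannelBind::new(ChannelNumber(MIN_CHANNEL_NUMBER), peer);
//     allocation.add_channel_bind(cb, DEFAULT_LIFETIME).await?;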
-#[derive(Clone)] -pub struct ChannelBind { - pub(crate) peer: SocketAddr, - pub(crate) number: ChannelNumber, - pub(crate) channel_bindings: Option>>>, - reset_tx: Option>, - timer_expired: Arc, -} - -impl ChannelBind { - /// Creates a new [`ChannelBind`] - pub fn new(number: ChannelNumber, peer: SocketAddr) -> Self { - ChannelBind { - number, - peer, - channel_bindings: None, - reset_tx: None, - timer_expired: Arc::new(AtomicBool::new(false)), - } - } - - pub(crate) async fn start(&mut self, lifetime: Duration) { - let (reset_tx, mut reset_rx) = mpsc::channel(1); - self.reset_tx = Some(reset_tx); - - let channel_bindings = self.channel_bindings.clone(); - let number = self.number; - let timer_expired = Arc::clone(&self.timer_expired); - - tokio::spawn(async move { - let timer = tokio::time::sleep(lifetime); - tokio::pin!(timer); - let mut done = false; - - while !done { - tokio::select! { - _ = &mut timer => { - if let Some(cbs) = &channel_bindings{ - let mut cb = cbs.lock().await; - if cb.remove(&number).is_none() { - log::error!("Failed to remove ChannelBind for {}", number); - } - } - done = true; - }, - result = reset_rx.recv() => { - if let Some(d) = result { - timer.as_mut().reset(Instant::now() + d); - } else { - done = true; - } - }, - } - } - - timer_expired.store(true, Ordering::SeqCst); - }); - } - - pub(crate) fn stop(&mut self) -> bool { - let expired = self.reset_tx.is_none() || self.timer_expired.load(Ordering::SeqCst); - self.reset_tx.take(); - expired - } - - pub(crate) async fn refresh(&self, lifetime: Duration) { - if let Some(tx) = &self.reset_tx { - let _ = tx.send(lifetime).await; - } - } -} diff --git a/turn/src/allocation/channel_bind/channel_bind_test.rs b/turn/src/allocation/channel_bind/channel_bind_test.rs deleted file mode 100644 index 365c3744d..000000000 --- a/turn/src/allocation/channel_bind/channel_bind_test.rs +++ /dev/null @@ -1,76 +0,0 @@ -use std::net::Ipv4Addr; - -use stun::attributes::ATTR_USERNAME; -use stun::textattrs::TextAttribute; -use tokio::net::UdpSocket; - -use super::*; -use crate::allocation::*; -use crate::error::Result; - -async fn create_channel_bind(lifetime: Duration) -> Result { - let turn_socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); - let relay_socket = Arc::clone(&turn_socket); - let relay_addr = relay_socket.local_addr()?; - let a = Allocation::new( - turn_socket, - relay_socket, - relay_addr, - FiveTuple::default(), - TextAttribute::new(ATTR_USERNAME, "user".into()), - None, - ); - - let addr = SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0); - let c = ChannelBind::new(ChannelNumber(MIN_CHANNEL_NUMBER), addr); - - a.add_channel_bind(c, lifetime).await?; - - Ok(a) -} - -#[tokio::test] -async fn test_channel_bind() -> Result<()> { - let a = create_channel_bind(Duration::from_millis(20)).await?; - - let result = a.get_channel_addr(&ChannelNumber(MIN_CHANNEL_NUMBER)).await; - if let Some(addr) = result { - assert_eq!(addr.ip().to_string(), "0.0.0.0"); - } else { - panic!("expected some, but got none"); - } - - Ok(()) -} - -async fn test_channel_bind_start() -> Result<()> { - let a = create_channel_bind(Duration::from_millis(20)).await?; - tokio::time::sleep(Duration::from_millis(30)).await; - - assert!(a - .get_channel_addr(&ChannelNumber(MIN_CHANNEL_NUMBER)) - .await - .is_none()); - - Ok(()) -} - -async fn test_channel_bind_reset() -> Result<()> { - let a = create_channel_bind(Duration::from_millis(30)).await?; - - tokio::time::sleep(Duration::from_millis(20)).await; - { - let channel_bindings = 
a.channel_bindings.lock().await; - if let Some(c) = channel_bindings.get(&ChannelNumber(MIN_CHANNEL_NUMBER)) { - c.refresh(Duration::from_millis(30)).await; - } - } - tokio::time::sleep(Duration::from_millis(20)).await; - - assert!(a - .get_channel_addr(&ChannelNumber(MIN_CHANNEL_NUMBER)) - .await - .is_some()); - - Ok(()) -} diff --git a/turn/src/allocation/five_tuple.rs b/turn/src/allocation/five_tuple.rs deleted file mode 100644 index d28eb48ff..000000000 --- a/turn/src/allocation/five_tuple.rs +++ /dev/null @@ -1,46 +0,0 @@ -#[cfg(test)] -mod five_tuple_test; - -use std::fmt; -use std::net::{Ipv4Addr, SocketAddr}; - -use crate::proto::*; - -/// `FiveTuple` is the combination (client IP address and port, server IP -/// address and port, and transport protocol (currently one of UDP, -/// TCP, or TLS)) used to communicate between the client and the -/// server. The 5-tuple uniquely identifies this communication -/// stream. The 5-tuple also uniquely identifies the Allocation on -/// the server. -#[derive(PartialEq, Eq, Clone, Copy, Hash)] -pub struct FiveTuple { - pub protocol: Protocol, - pub src_addr: SocketAddr, - pub dst_addr: SocketAddr, -} - -impl Default for FiveTuple { - fn default() -> Self { - FiveTuple { - protocol: PROTO_UDP, - src_addr: SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0), - dst_addr: SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0), - } - } -} - -impl fmt::Display for FiveTuple { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}_{}_{}", self.protocol, self.src_addr, self.dst_addr) - } -} - -impl fmt::Debug for FiveTuple { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("FiveTuple") - .field("protocol", &self.protocol) - .field("src_addr", &self.src_addr) - .field("dst_addr", &self.dst_addr) - .finish() - } -} diff --git a/turn/src/allocation/five_tuple/five_tuple_test.rs b/turn/src/allocation/five_tuple/five_tuple_test.rs deleted file mode 100644 index 3e33e5d52..000000000 --- a/turn/src/allocation/five_tuple/five_tuple_test.rs +++ /dev/null @@ -1,100 +0,0 @@ -use super::*; -use crate::error::Result; - -#[test] -fn test_five_tuple_protocol() -> Result<()> { - let udp_expect = PROTO_UDP; - let tcp_expect = PROTO_TCP; - - assert_eq!( - udp_expect, PROTO_UDP, - "Invalid UDP Protocol value, expect {udp_expect} but {PROTO_UDP}" - ); - assert_eq!( - tcp_expect, PROTO_TCP, - "Invalid TCP Protocol value, expect {tcp_expect} but {PROTO_TCP}" - ); - - assert_eq!(udp_expect.to_string(), "UDP"); - assert_eq!(tcp_expect.to_string(), "TCP"); - - Ok(()) -} - -#[test] -fn test_five_tuple_equal() -> Result<()> { - let src_addr1: SocketAddr = "0.0.0.0:3478".parse::()?; - let src_addr2: SocketAddr = "0.0.0.0:3479".parse::()?; - - let dst_addr1: SocketAddr = "0.0.0.0:3480".parse::()?; - let dst_addr2: SocketAddr = "0.0.0.0:3481".parse::()?; - - let tests = vec![ - ( - "Equal", - true, - FiveTuple { - protocol: PROTO_UDP, - src_addr: src_addr1, - dst_addr: dst_addr1, - }, - FiveTuple { - protocol: PROTO_UDP, - src_addr: src_addr1, - dst_addr: dst_addr1, - }, - ), - ( - "DifferentProtocol", - false, - FiveTuple { - protocol: PROTO_TCP, - src_addr: src_addr1, - dst_addr: dst_addr1, - }, - FiveTuple { - protocol: PROTO_UDP, - src_addr: src_addr1, - dst_addr: dst_addr1, - }, - ), - ( - "DifferentSrcAddr", - false, - FiveTuple { - protocol: PROTO_UDP, - src_addr: src_addr1, - dst_addr: dst_addr1, - }, - FiveTuple { - protocol: PROTO_UDP, - src_addr: src_addr2, - dst_addr: dst_addr1, - }, - ), - ( - "DifferentDstAddr", - 
false, - FiveTuple { - protocol: PROTO_UDP, - src_addr: src_addr1, - dst_addr: dst_addr1, - }, - FiveTuple { - protocol: PROTO_UDP, - src_addr: src_addr1, - dst_addr: dst_addr2, - }, - ), - ]; - - for (name, expect, a, b) in tests { - let fact = a == b; - assert_eq!( - expect, fact, - "{name}: {a}, {b} equal check should be {expect}, but {fact}" - ); - } - - Ok(()) -} diff --git a/turn/src/allocation/mod.rs b/turn/src/allocation/mod.rs deleted file mode 100644 index b8b758bd6..000000000 --- a/turn/src/allocation/mod.rs +++ /dev/null @@ -1,468 +0,0 @@ -#[cfg(test)] -mod allocation_test; - -pub mod allocation_manager; -pub mod channel_bind; -pub mod five_tuple; -pub mod permission; - -use std::collections::HashMap; -use std::marker::{Send, Sync}; -use std::net::SocketAddr; -use std::sync::atomic::Ordering; -use std::sync::Arc; - -use channel_bind::*; -use five_tuple::*; -use permission::*; -use portable_atomic::{AtomicBool, AtomicUsize}; -use stun::agent::*; -use stun::message::*; -use stun::textattrs::Username; -use tokio::sync::oneshot::{self, Sender}; -use tokio::sync::{mpsc, Mutex}; -use tokio::time::{Duration, Instant}; -use util::sync::Mutex as SyncMutex; -use util::Conn; - -use crate::error::*; -use crate::proto::chandata::*; -use crate::proto::channum::*; -use crate::proto::data::*; -use crate::proto::peeraddr::*; -use crate::proto::*; - -const RTP_MTU: usize = 1500; - -pub type AllocationMap = Arc>>>; - -/// Information about an [`Allocation`]. -#[derive(Debug, Clone)] -pub struct AllocationInfo { - /// [`FiveTuple`] of this [`Allocation`]. - pub five_tuple: FiveTuple, - - /// Username of this [`Allocation`]. - pub username: String, - - /// Relayed bytes with this [`Allocation`]. - #[cfg(feature = "metrics")] - pub relayed_bytes: usize, -} - -impl AllocationInfo { - /// Creates a new [`AllocationInfo`]. - pub fn new( - five_tuple: FiveTuple, - username: String, - #[cfg(feature = "metrics")] relayed_bytes: usize, - ) -> Self { - Self { - five_tuple, - username, - #[cfg(feature = "metrics")] - relayed_bytes, - } - } -} - -/// `Allocation` is tied to a FiveTuple and relays traffic -/// use create_allocation and get_allocation to operate. -pub struct Allocation { - protocol: Protocol, - turn_socket: Arc, - pub(crate) relay_addr: SocketAddr, - pub(crate) relay_socket: Arc, - five_tuple: FiveTuple, - username: Username, - permissions: Arc>>, - channel_bindings: Arc>>, - pub(crate) allocations: Option, - reset_tx: SyncMutex>>, - timer_expired: Arc, - closed: AtomicBool, // Option>, - pub(crate) relayed_bytes: AtomicUsize, - drop_tx: Option>, - alloc_close_notify: Option>, -} - -fn addr2ipfingerprint(addr: &SocketAddr) -> String { - addr.ip().to_string() -} - -impl Allocation { - /// Creates a new [`Allocation`]. - pub fn new( - turn_socket: Arc, - relay_socket: Arc, - relay_addr: SocketAddr, - five_tuple: FiveTuple, - username: Username, - alloc_close_notify: Option>, - ) -> Self { - Allocation { - protocol: PROTO_UDP, - turn_socket, - relay_addr, - relay_socket, - five_tuple, - username, - permissions: Arc::new(Mutex::new(HashMap::new())), - channel_bindings: Arc::new(Mutex::new(HashMap::new())), - allocations: None, - reset_tx: SyncMutex::new(None), - timer_expired: Arc::new(AtomicBool::new(false)), - closed: AtomicBool::new(false), - relayed_bytes: Default::default(), - drop_tx: None, - alloc_close_notify, - } - } - - /// Checks the Permission for the `addr`. 
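    // (Illustrative sketch, not from the original source; the addresses are
    // made up.) Permissions are keyed by the peer's IP alone (see
    // `addr2ipfingerprint` above), so installing one for a single port of a
    // peer authorises every port of that IP:
    //
    //     allocation.add_permission(Permission::new("10.0.0.9:7000".parse().unwrap())).await;
    //     assert!(allocation.has_permission(&"10.0.0.9:8000".parse().unwrap()).await);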
- pub async fn has_permission(&self, addr: &SocketAddr) -> bool { - let permissions = self.permissions.lock().await; - permissions.get(&addr2ipfingerprint(addr)).is_some() - } - - /// Adds a new [`Permission`] to this [`Allocation`]. - pub async fn add_permission(&self, mut p: Permission) { - let fingerprint = addr2ipfingerprint(&p.addr); - - { - let permissions = self.permissions.lock().await; - if let Some(existed_permission) = permissions.get(&fingerprint) { - existed_permission.refresh(PERMISSION_TIMEOUT).await; - return; - } - } - - p.permissions = Some(Arc::clone(&self.permissions)); - p.start(PERMISSION_TIMEOUT).await; - - { - let mut permissions = self.permissions.lock().await; - permissions.insert(fingerprint, p); - } - } - - /// Removes the `addr`'s fingerprint from this [`Allocation`]'s permissions. - pub async fn remove_permission(&self, addr: &SocketAddr) -> bool { - let mut permissions = self.permissions.lock().await; - permissions.remove(&addr2ipfingerprint(addr)).is_some() - } - - /// Adds a new [`ChannelBind`] to this [`Allocation`], it also updates the - /// permissions needed for this [`ChannelBind`]. - pub async fn add_channel_bind(&self, mut c: ChannelBind, lifetime: Duration) -> Result<()> { - { - if let Some(addr) = self.get_channel_addr(&c.number).await { - if addr != c.peer { - return Err(Error::ErrSameChannelDifferentPeer); - } - } - - if let Some(number) = self.get_channel_number(&c.peer).await { - if number != c.number { - return Err(Error::ErrSameChannelDifferentPeer); - } - } - } - - { - let channel_bindings = self.channel_bindings.lock().await; - if let Some(cb) = channel_bindings.get(&c.number) { - cb.refresh(lifetime).await; - - // Channel binds also refresh permissions. - self.add_permission(Permission::new(cb.peer)).await; - - return Ok(()); - } - } - - let peer = c.peer; - - // Add or refresh this channel. - c.channel_bindings = Some(Arc::clone(&self.channel_bindings)); - c.start(lifetime).await; - - { - let mut channel_bindings = self.channel_bindings.lock().await; - channel_bindings.insert(c.number, c); - } - - // Channel binds also refresh permissions. - self.add_permission(Permission::new(peer)).await; - - Ok(()) - } - - /// Removes the [`ChannelBind`] from this [`Allocation`] by `number`. - pub async fn remove_channel_bind(&self, number: ChannelNumber) -> bool { - let mut channel_bindings = self.channel_bindings.lock().await; - channel_bindings.remove(&number).is_some() - } - - /// Gets the [`ChannelBind`]'s address by `number`. - pub async fn get_channel_addr(&self, number: &ChannelNumber) -> Option { - let channel_bindings = self.channel_bindings.lock().await; - channel_bindings.get(number).map(|cb| cb.peer) - } - - /// Gets the [`ChannelBind`]'s number from this [`Allocation`] by `addr`. - pub async fn get_channel_number(&self, addr: &SocketAddr) -> Option { - let channel_bindings = self.channel_bindings.lock().await; - for cb in channel_bindings.values() { - if cb.peer == *addr { - return Some(cb.number); - } - } - None - } - - /// Closes the [`Allocation`]. 
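    // (Illustrative sketch, not from the original source.) Closing stops the
    // lifetime timer, every permission and channel binding, and both the TURN
    // and relay sockets; a second call sees the `closed` flag and returns
    // `Error::ErrClosed`:
    //
    //     allocation.close().await?;
    //     assert!(allocation.close().await.is_err());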
- pub async fn close(&self) -> Result<()> { - if self.closed.load(Ordering::Acquire) { - return Err(Error::ErrClosed); - } - - self.closed.store(true, Ordering::Release); - self.stop(); - - { - let mut permissions = self.permissions.lock().await; - for p in permissions.values_mut() { - p.stop(); - } - } - - { - let mut channel_bindings = self.channel_bindings.lock().await; - for c in channel_bindings.values_mut() { - c.stop(); - } - } - - log::trace!("allocation with {} closed!", self.five_tuple); - - let _ = self.turn_socket.close().await; - let _ = self.relay_socket.close().await; - - if let Some(notify_tx) = &self.alloc_close_notify { - let _ = notify_tx - .send(AllocationInfo { - five_tuple: self.five_tuple, - username: self.username.text.clone(), - #[cfg(feature = "metrics")] - relayed_bytes: self.relayed_bytes.load(Ordering::Acquire), - }) - .await; - } - - Ok(()) - } - - pub async fn start(&self, lifetime: Duration) { - let (reset_tx, mut reset_rx) = mpsc::channel(1); - self.reset_tx.lock().replace(reset_tx); - - let allocations = self.allocations.clone(); - let five_tuple = self.five_tuple; - let timer_expired = Arc::clone(&self.timer_expired); - - tokio::spawn(async move { - let timer = tokio::time::sleep(lifetime); - tokio::pin!(timer); - let mut done = false; - - while !done { - tokio::select! { - _ = &mut timer => { - if let Some(allocs) = &allocations{ - let mut allocs = allocs.lock().await; - if let Some(a) = allocs.remove(&five_tuple) { - let _ = a.close().await; - } - } - done = true; - }, - result = reset_rx.recv() => { - if let Some(d) = result { - timer.as_mut().reset(Instant::now() + d); - } else { - done = true; - } - }, - } - } - - timer_expired.store(true, Ordering::SeqCst); - }); - } - - fn stop(&self) -> bool { - let reset_tx = self.reset_tx.lock().take(); - reset_tx.is_none() || self.timer_expired.load(Ordering::SeqCst) - } - - /// Updates the allocations lifetime. - pub async fn refresh(&self, lifetime: Duration) { - let reset_tx = self.reset_tx.lock().clone(); - if let Some(tx) = reset_tx { - let _ = tx.send(lifetime).await; - } - } - - // https://tools.ietf.org/html/rfc5766#section-10.3 - // When the server receives a UDP datagram at a currently allocated - // relayed transport address, the server looks up the allocation - // associated with the relayed transport address. The server then - // checks to see whether the set of permissions for the allocation allow - // the relaying of the UDP datagram as described in Section 8. - // - // If relaying is permitted, then the server checks if there is a - // channel bound to the peer that sent the UDP datagram (see - // Section 11). If a channel is bound, then processing proceeds as - // described in Section 11.7. - // - // If relaying is permitted but no channel is bound to the peer, then - // the server forms and sends a Data indication. The Data indication - // MUST contain both an XOR-PEER-ADDRESS and a DATA attribute. The DATA - // attribute is set to the value of the 'data octets' field from the - // datagram, and the XOR-PEER-ADDRESS attribute is set to the source - // transport address of the received UDP datagram. The Data indication - // is then sent on the 5-tuple associated with the allocation. 
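    // (Condensed, illustrative sketch of the decision described above — not
    // the real implementation.) For a datagram read from the relayed address
    // with source `src_addr`:
    //
    //     if let Some(_number) = self.get_channel_number(&src_addr).await {
    //         // a channel is bound: wrap the payload in ChannelData and send
    //         // it to the client over the allocation's 5-tuple
    //     } else if self.has_permission(&src_addr).await {
    //         // permission only: build a Data indication carrying
    //         // XOR-PEER-ADDRESS and DATA and send that to the client
    //     } else {
    //         // neither: the datagram is not relayed
    //     }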
- async fn packet_handler(&mut self) { - let five_tuple = self.five_tuple; - let relay_addr = self.relay_addr; - let relay_socket = Arc::clone(&self.relay_socket); - let turn_socket = Arc::clone(&self.turn_socket); - let allocations = self.allocations.clone(); - let channel_bindings = Arc::clone(&self.channel_bindings); - let permissions = Arc::clone(&self.permissions); - let (drop_tx, drop_rx) = oneshot::channel::(); - self.drop_tx = Some(drop_tx); - - tokio::spawn(async move { - let mut buffer = vec![0u8; RTP_MTU]; - - tokio::pin!(drop_rx); - - loop { - let (n, src_addr) = tokio::select! { - result = relay_socket.recv_from(&mut buffer) => { - match result { - Ok((n, src_addr)) => (n, src_addr), - Err(_) => { - if let Some(allocs) = &allocations { - let mut allocs = allocs.lock().await; - allocs.remove(&five_tuple); - } - break; - } - } - } - _ = drop_rx.as_mut() => { - log::trace!("allocation has stopped, stop packet_handler. five_tuple: {:?}", five_tuple); - break; - } - }; - - log::debug!( - "relay socket {:?} received {} bytes from {}", - relay_socket.local_addr(), - n, - src_addr - ); - - let cb_number = { - let mut cb_number = None; - let cbs = channel_bindings.lock().await; - for cb in cbs.values() { - if cb.peer == src_addr { - cb_number = Some(cb.number); - break; - } - } - cb_number - }; - - if let Some(number) = cb_number { - let mut channel_data = ChannelData { - data: buffer[..n].to_vec(), - number, - raw: vec![], - }; - channel_data.encode(); - - if let Err(err) = turn_socket - .send_to(&channel_data.raw, five_tuple.src_addr) - .await - { - log::error!( - "Failed to send ChannelData from allocation {} {}", - src_addr, - err - ); - } - } else { - let exist = { - let ps = permissions.lock().await; - ps.get(&addr2ipfingerprint(&src_addr)).is_some() - }; - - if exist { - let msg = { - let peer_address_attr = PeerAddress { - ip: src_addr.ip(), - port: src_addr.port(), - }; - let data_attr = Data(buffer[..n].to_vec()); - - let mut msg = Message::new(); - if let Err(err) = msg.build(&[ - Box::new(TransactionId::new()), - Box::new(MessageType::new(METHOD_DATA, CLASS_INDICATION)), - Box::new(peer_address_attr), - Box::new(data_attr), - ]) { - log::error!( - "Failed to send DataIndication from allocation {} {}", - src_addr, - err - ); - None - } else { - Some(msg) - } - }; - - if let Some(msg) = msg { - log::debug!( - "relaying message from {} to client at {}", - src_addr, - five_tuple.src_addr - ); - if let Err(err) = - turn_socket.send_to(&msg.raw, five_tuple.src_addr).await - { - log::error!( - "Failed to send DataIndication from allocation {} {}", - src_addr, - err - ); - } - } - } else { - log::info!( - "No Permission or Channel exists for {} on allocation {}", - src_addr, - relay_addr - ); - } - } - } - }); - } -} diff --git a/turn/src/allocation/permission.rs b/turn/src/allocation/permission.rs deleted file mode 100644 index 08013523b..000000000 --- a/turn/src/allocation/permission.rs +++ /dev/null @@ -1,81 +0,0 @@ -use std::sync::atomic::Ordering; -use std::sync::Arc; - -use portable_atomic::AtomicBool; -use tokio::sync::Mutex; -use tokio::time::{Duration, Instant}; - -use super::*; - -pub(crate) const PERMISSION_TIMEOUT: Duration = Duration::from_secs(5 * 60); - -/// `Permission` represents a TURN permission. TURN permissions mimic the address-restricted -/// filtering mechanism of NATs that comply with [RFC4787]. 
-/// -/// https://tools.ietf.org/html/rfc5766#section-2.3 -pub struct Permission { - pub(crate) addr: SocketAddr, - pub(crate) permissions: Option>>>, - reset_tx: Option>, - timer_expired: Arc, -} - -impl Permission { - /// Creates a new [`Permission`]. - pub fn new(addr: SocketAddr) -> Self { - Permission { - addr, - permissions: None, - reset_tx: None, - timer_expired: Arc::new(AtomicBool::new(false)), - } - } - - pub(crate) async fn start(&mut self, lifetime: Duration) { - let (reset_tx, mut reset_rx) = mpsc::channel(1); - self.reset_tx = Some(reset_tx); - - let permissions = self.permissions.clone(); - let addr = self.addr; - let timer_expired = Arc::clone(&self.timer_expired); - - tokio::spawn(async move { - let timer = tokio::time::sleep(lifetime); - tokio::pin!(timer); - let mut done = false; - - while !done { - tokio::select! { - _ = &mut timer => { - if let Some(perms) = &permissions{ - let mut p = perms.lock().await; - p.remove(&addr2ipfingerprint(&addr)); - } - done = true; - }, - result = reset_rx.recv() => { - if let Some(d) = result { - timer.as_mut().reset(Instant::now() + d); - } else { - done = true; - } - }, - } - } - - timer_expired.store(true, Ordering::SeqCst); - }); - } - - pub(crate) fn stop(&mut self) -> bool { - let expired = self.reset_tx.is_none() || self.timer_expired.load(Ordering::SeqCst); - self.reset_tx.take(); - expired - } - - pub(crate) async fn refresh(&self, lifetime: Duration) { - if let Some(tx) = &self.reset_tx { - let _ = tx.send(lifetime).await; - } - } -} diff --git a/turn/src/auth/auth_test.rs b/turn/src/auth/auth_test.rs deleted file mode 100644 index df1e51d8b..000000000 --- a/turn/src/auth/auth_test.rs +++ /dev/null @@ -1,103 +0,0 @@ -use super::*; - -#[test] -fn test_lt_cred() -> Result<()> { - let username = "1599491771"; - let shared_secret = "foobar"; - - let expected_password = "Tpz/nKkyvX/vMSLKvL4sbtBt8Vs="; - let actual_password = long_term_credentials(username, shared_secret); - assert_eq!( - expected_password, actual_password, - "Expected {expected_password}, got {actual_password}" - ); - - Ok(()) -} - -#[test] -fn test_generate_auth_key() -> Result<()> { - let username = "60"; - let password = "HWbnm25GwSj6jiHTEDMTO5D7aBw="; - let realm = "webrtc.rs"; - - let expected_key = vec![ - 56, 22, 47, 139, 198, 127, 13, 188, 171, 80, 23, 29, 195, 148, 216, 224, - ]; - let actual_key = generate_auth_key(username, realm, password); - assert_eq!( - expected_key, actual_key, - "Expected {expected_key:?}, got {actual_key:?}" - ); - - Ok(()) -} - -#[cfg(target_family = "unix")] -#[tokio::test] -async fn test_new_long_term_auth_handler() -> Result<()> { - use std::net::IpAddr; - use std::str::FromStr; - use std::sync::Arc; - - use tokio::net::UdpSocket; - use util::vnet::net::*; - - use crate::client::*; - use crate::relay::relay_static::*; - use crate::server::config::*; - use crate::server::*; - - //env_logger::init(); - - const SHARED_SECRET: &str = "HELLO_WORLD"; - - // here, it should use static port, like "0.0.0.0:3478", - // but, due to different test environment, let's fake it by using "0.0.0.0:0" - // to auto assign a "static" port - let conn = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); - let server_port = conn.local_addr()?.port(); - - let server = Server::new(ServerConfig { - conn_configs: vec![ConnConfig { - conn, - relay_addr_generator: Box::new(RelayAddressGeneratorStatic { - relay_address: IpAddr::from_str("127.0.0.1")?, - address: "0.0.0.0".to_owned(), - net: Arc::new(Net::new(None)), - }), - }], - realm: "webrtc.rs".to_owned(), 
- auth_handler: Arc::new(LongTermAuthHandler::new(SHARED_SECRET.to_string())), - channel_bind_timeout: Duration::from_secs(0), - alloc_close_notify: None, - }) - .await?; - - let (username, password) = - generate_long_term_credentials(SHARED_SECRET, Duration::from_secs(60))?; - - let conn = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); - - let client = Client::new(ClientConfig { - stun_serv_addr: format!("0.0.0.0:{server_port}"), - turn_serv_addr: format!("0.0.0.0:{server_port}"), - username, - password, - realm: "webrtc.rs".to_owned(), - software: String::new(), - rto_in_ms: 0, - conn, - vnet: None, - }) - .await?; - - client.listen().await?; - - let _allocation = client.allocate().await?; - - client.close().await?; - server.close().await?; - - Ok(()) -} diff --git a/turn/src/auth/mod.rs b/turn/src/auth/mod.rs deleted file mode 100644 index 537983d7d..000000000 --- a/turn/src/auth/mod.rs +++ /dev/null @@ -1,77 +0,0 @@ -#[cfg(test)] -mod auth_test; - -use std::net::SocketAddr; -use std::time::{Duration, SystemTime, UNIX_EPOCH}; - -use base64::prelude::BASE64_STANDARD; -use base64::Engine; -use md5::{Digest, Md5}; -use ring::hmac; - -use crate::error::*; - -pub trait AuthHandler { - fn auth_handle(&self, username: &str, realm: &str, src_addr: SocketAddr) -> Result>; -} - -/// `generate_long_term_credentials()` can be used to create credentials valid for `duration` time/ -pub fn generate_long_term_credentials( - shared_secret: &str, - duration: Duration, -) -> Result<(String, String)> { - let t = SystemTime::now().duration_since(UNIX_EPOCH)? + duration; - let username = format!("{}", t.as_secs()); - let password = long_term_credentials(&username, shared_secret); - Ok((username, password)) -} - -fn long_term_credentials(username: &str, shared_secret: &str) -> String { - let mac = hmac::Key::new( - hmac::HMAC_SHA1_FOR_LEGACY_USE_ONLY, - shared_secret.as_bytes(), - ); - let password = hmac::sign(&mac, username.as_bytes()).as_ref().to_vec(); - BASE64_STANDARD.encode(password) -} - -/// A convenience function to easily generate keys in the format used by [`AuthHandler`]. -pub fn generate_auth_key(username: &str, realm: &str, password: &str) -> Vec { - let s = format!("{username}:{realm}:{password}"); - - let mut h = Md5::new(); - h.update(s.as_bytes()); - h.finalize().as_slice().to_vec() -} - -pub struct LongTermAuthHandler { - shared_secret: String, -} - -impl AuthHandler for LongTermAuthHandler { - fn auth_handle(&self, username: &str, realm: &str, src_addr: SocketAddr) -> Result> { - log::trace!( - "Authentication username={} realm={} src_addr={}", - username, - realm, - src_addr - ); - - let t = Duration::from_secs(username.parse::()?); - if t < SystemTime::now().duration_since(UNIX_EPOCH)? 
{ - return Err(Error::Other(format!( - "Expired time-windowed username {username}" - ))); - } - - let password = long_term_credentials(username, &self.shared_secret); - Ok(generate_auth_key(username, realm, &password)) - } -} - -impl LongTermAuthHandler { - /// https://tools.ietf.org/search/rfc5389#section-10.2 - pub fn new(shared_secret: String) -> Self { - LongTermAuthHandler { shared_secret } - } -} diff --git a/turn/src/client/binding.rs b/turn/src/client/binding.rs deleted file mode 100644 index 2de2a03cb..000000000 --- a/turn/src/client/binding.rs +++ /dev/null @@ -1,136 +0,0 @@ -#[cfg(test)] -mod binding_test; - -use std::collections::HashMap; -use std::net::SocketAddr; - -use tokio::time::Instant; - -// Channel number: -// 0x4000 through 0x7FFF: These values are the allowed channel -// numbers (16,383 possible values). -const MIN_CHANNEL_NUMBER: u16 = 0x4000; -const MAX_CHANNEL_NUMBER: u16 = 0x7fff; - -#[derive(Copy, Clone, Debug, PartialEq)] -pub(crate) enum BindingState { - Idle, - Request, - Ready, - Refresh, - Failed, -} - -#[derive(Copy, Clone, Debug, PartialEq)] -pub(crate) struct Binding { - pub(crate) number: u16, - pub(crate) st: BindingState, - pub(crate) addr: SocketAddr, - pub(crate) refreshed_at: Instant, -} - -impl Binding { - pub(crate) fn set_state(&mut self, state: BindingState) { - //atomic.StoreInt32((*int32)(&b.st), int32(state)) - self.st = state; - } - - pub(crate) fn state(&self) -> BindingState { - //return BindingState(atomic.LoadInt32((*int32)(&b.st))) - self.st - } - - pub(crate) fn set_refreshed_at(&mut self, at: Instant) { - self.refreshed_at = at; - } - - pub(crate) fn refreshed_at(&self) -> Instant { - self.refreshed_at - } -} -/// Thread-safe Binding map. -#[derive(Default)] -pub(crate) struct BindingManager { - chan_map: HashMap, - addr_map: HashMap, - next: u16, -} - -impl BindingManager { - pub(crate) fn new() -> Self { - BindingManager { - chan_map: HashMap::new(), - addr_map: HashMap::new(), - next: MIN_CHANNEL_NUMBER, - } - } - - pub(crate) fn assign_channel_number(&mut self) -> u16 { - let n = self.next; - if self.next == MAX_CHANNEL_NUMBER { - self.next = MIN_CHANNEL_NUMBER; - } else { - self.next += 1; - } - n - } - - pub(crate) fn create(&mut self, addr: SocketAddr) -> Option<&Binding> { - let b = Binding { - number: self.assign_channel_number(), - st: BindingState::Idle, - addr, - refreshed_at: Instant::now(), - }; - - self.chan_map.insert(b.number, b.addr.to_string()); - self.addr_map.insert(b.addr.to_string(), b); - self.addr_map.get(&addr.to_string()) - } - - pub(crate) fn find_by_addr(&self, addr: &SocketAddr) -> Option<&Binding> { - self.addr_map.get(&addr.to_string()) - } - - pub(crate) fn get_by_addr(&mut self, addr: &SocketAddr) -> Option<&mut Binding> { - self.addr_map.get_mut(&addr.to_string()) - } - - pub(crate) fn find_by_number(&self, number: u16) -> Option<&Binding> { - if let Some(s) = self.chan_map.get(&number) { - self.addr_map.get(s) - } else { - None - } - } - - pub(crate) fn get_by_number(&mut self, number: u16) -> Option<&mut Binding> { - if let Some(s) = self.chan_map.get(&number) { - self.addr_map.get_mut(s) - } else { - None - } - } - - pub(crate) fn delete_by_addr(&mut self, addr: &SocketAddr) -> bool { - if let Some(b) = self.addr_map.remove(&addr.to_string()) { - self.chan_map.remove(&b.number); - true - } else { - false - } - } - - pub(crate) fn delete_by_number(&mut self, number: u16) -> bool { - if let Some(s) = self.chan_map.remove(&number) { - self.addr_map.remove(&s); - true - } else { - false - } - } - - 
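For reference, the assignment rule used by `assign_channel_number` above (and exercised by `test_binding_manager_number_assignment` further down) is: channel numbers are handed out sequentially from the RFC 5766 range 0x4000..=0x7FFF and wrap back to the minimum once the maximum has been used. A minimal standalone sketch of that rule:

    // Channel numbers are confined to the RFC 5766 range 0x4000..=0x7FFF;
    // once the maximum has been assigned, assignment wraps back to the minimum.
    const MIN_CHANNEL_NUMBER: u16 = 0x4000;
    const MAX_CHANNEL_NUMBER: u16 = 0x7fff;

    fn next_channel_number(current: u16) -> u16 {
        if current == MAX_CHANNEL_NUMBER {
            MIN_CHANNEL_NUMBER
        } else {
            current + 1
        }
    }

    fn main() {
        assert_eq!(next_channel_number(0x4000), 0x4001);
        assert_eq!(next_channel_number(MAX_CHANNEL_NUMBER), MIN_CHANNEL_NUMBER);
    }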
pub(crate) fn size(&self) -> usize { - self.addr_map.len() - } -} diff --git a/turn/src/client/binding/binding_test.rs b/turn/src/client/binding/binding_test.rs deleted file mode 100644 index d8bae6863..000000000 --- a/turn/src/client/binding/binding_test.rs +++ /dev/null @@ -1,83 +0,0 @@ -use std::net::{Ipv4Addr, SocketAddrV4}; - -use super::*; -use crate::error::Result; - -#[test] -fn test_binding_manager_number_assignment() -> Result<()> { - let mut m = BindingManager::new(); - let mut n: u16; - for i in 0..10 { - n = m.assign_channel_number(); - assert_eq!(MIN_CHANNEL_NUMBER + i, n, "should match"); - } - - m.next = 0x7ff0; - for i in 0..16 { - n = m.assign_channel_number(); - assert_eq!(0x7ff0 + i, n, "should match"); - } - // back to min - n = m.assign_channel_number(); - assert_eq!(MIN_CHANNEL_NUMBER, n, "should match"); - - Ok(()) -} - -#[test] -fn test_binding_manager_method() -> Result<()> { - let lo = Ipv4Addr::new(127, 0, 0, 1); - let count = 100; - let mut m = BindingManager::new(); - for i in 0..count { - let addr = SocketAddr::V4(SocketAddrV4::new(lo, 10000 + i)); - let b0 = { - let b = m.create(addr); - *b.unwrap() - }; - let b1 = m.find_by_addr(&addr); - assert!(b1.is_some(), "should succeed"); - let b2 = m.find_by_number(b0.number); - assert!(b2.is_some(), "should succeed"); - - assert_eq!(b0, *b1.unwrap(), "should match"); - assert_eq!(b0, *b2.unwrap(), "should match"); - } - - assert_eq!(count, m.size() as u16, "should match"); - assert_eq!(count, m.addr_map.len() as u16, "should match"); - - for i in 0..count { - let addr = SocketAddr::V4(SocketAddrV4::new(lo, 10000 + i)); - if i % 2 == 0 { - assert!(m.delete_by_addr(&addr), "should return true"); - } else { - assert!( - m.delete_by_number(MIN_CHANNEL_NUMBER + i), - "should return true" - ); - } - } - - assert_eq!(0, m.size(), "should match"); - assert_eq!(0, m.addr_map.len(), "should match"); - - Ok(()) -} - -#[test] -fn test_binding_manager_failure() -> Result<()> { - let ipv4 = Ipv4Addr::new(127, 0, 0, 1); - let addr = SocketAddr::V4(SocketAddrV4::new(ipv4, 7777)); - let mut m = BindingManager::new(); - let b = m.find_by_addr(&addr); - assert!(b.is_none(), "should fail"); - let b = m.find_by_number(5555); - assert!(b.is_none(), "should fail"); - let ok = m.delete_by_addr(&addr); - assert!(!ok, "should fail"); - let ok = m.delete_by_number(5555); - assert!(!ok, "should fail"); - - Ok(()) -} diff --git a/turn/src/client/client_test.rs b/turn/src/client/client_test.rs deleted file mode 100644 index 20516bd25..000000000 --- a/turn/src/client/client_test.rs +++ /dev/null @@ -1,191 +0,0 @@ -use std::net::IpAddr; - -use tokio::net::UdpSocket; -use tokio::time::Duration; -use util::vnet::net::*; - -use super::*; -use crate::auth::*; -use crate::relay::relay_static::*; -use crate::server::config::*; -use crate::server::*; - -async fn create_listening_test_client(rto_in_ms: u16) -> Result { - let conn = UdpSocket::bind("0.0.0.0:0").await?; - - let c = Client::new(ClientConfig { - stun_serv_addr: String::new(), - turn_serv_addr: String::new(), - username: String::new(), - password: String::new(), - realm: String::new(), - software: "TEST SOFTWARE".to_owned(), - rto_in_ms, - conn: Arc::new(conn), - vnet: None, - }) - .await?; - - c.listen().await?; - - Ok(c) -} - -async fn create_listening_test_client_with_stun_serv() -> Result { - let conn = UdpSocket::bind("0.0.0.0:0").await?; - - let c = Client::new(ClientConfig { - stun_serv_addr: "stun1.l.google.com:19302".to_owned(), - turn_serv_addr: String::new(), - username: 
String::new(), - password: String::new(), - realm: String::new(), - software: "TEST SOFTWARE".to_owned(), - rto_in_ms: 0, - conn: Arc::new(conn), - vnet: None, - }) - .await?; - - c.listen().await?; - - Ok(c) -} - -#[tokio::test] -async fn test_client_with_stun_send_binding_request() -> Result<()> { - //env_logger::init(); - - let c = create_listening_test_client_with_stun_serv().await?; - - let resp = c.send_binding_request().await?; - log::debug!("mapped-addr: {}", resp); - { - let ci = c.client_internal.lock().await; - let tm = ci.tr_map.lock().await; - assert_eq!(0, tm.size(), "should be no transaction left"); - } - - c.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_client_with_stun_send_binding_request_to_parallel() -> Result<()> { - env_logger::init(); - - let c1 = create_listening_test_client(0).await?; - let c2 = c1.clone(); - - let (stared_tx, mut started_rx) = mpsc::channel::<()>(1); - let (finished_tx, mut finished_rx) = mpsc::channel::<()>(1); - - let to = lookup_host(true, "stun1.l.google.com:19302").await?; - - tokio::spawn(async move { - drop(stared_tx); - if let Ok(resp) = c2.send_binding_request_to(&to.to_string()).await { - log::debug!("mapped-addr: {}", resp); - } - drop(finished_tx); - }); - - let _ = started_rx.recv().await; - - let resp = c1.send_binding_request_to(&to.to_string()).await?; - log::debug!("mapped-addr: {}", resp); - - let _ = finished_rx.recv().await; - - c1.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_client_with_stun_send_binding_request_to_timeout() -> Result<()> { - //env_logger::init(); - - let c = create_listening_test_client(10).await?; - - let to = lookup_host(true, "127.0.0.1:9").await?; - - let result = c.send_binding_request_to(&to.to_string()).await; - assert!(result.is_err(), "expected error, but got ok"); - - c.close().await?; - - Ok(()) -} - -struct TestAuthHandler; -impl AuthHandler for TestAuthHandler { - fn auth_handle(&self, username: &str, realm: &str, _src_addr: SocketAddr) -> Result> { - Ok(generate_auth_key(username, realm, "pass")) - } -} - -// Create an allocation, and then delete all nonces -// The subsequent Write on the allocation will cause a CreatePermission -// which will be forced to handle a stale nonce response -#[tokio::test] -async fn test_client_nonce_expiration() -> Result<()> { - // env_logger::init(); - - // here, it should use static port, like "0.0.0.0:3478", - // but, due to different test environment, let's fake it by using "0.0.0.0:0" - // to auto assign a "static" port - let conn = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); - let server_port = conn.local_addr()?.port(); - - let server = Server::new(ServerConfig { - conn_configs: vec![ConnConfig { - conn, - relay_addr_generator: Box::new(RelayAddressGeneratorStatic { - relay_address: IpAddr::from_str("127.0.0.1")?, - address: "0.0.0.0".to_owned(), - net: Arc::new(Net::new(None)), - }), - }], - realm: "webrtc.rs".to_owned(), - auth_handler: Arc::new(TestAuthHandler {}), - channel_bind_timeout: Duration::from_secs(0), - alloc_close_notify: None, - }) - .await?; - - let conn = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); - - let client = Client::new(ClientConfig { - stun_serv_addr: format!("127.0.0.1:{server_port}"), - turn_serv_addr: format!("127.0.0.1:{server_port}"), - username: "foo".to_owned(), - password: "pass".to_owned(), - realm: String::new(), - software: String::new(), - rto_in_ms: 0, - conn, - vnet: None, - }) - .await?; - - client.listen().await?; - - let allocation = client.allocate().await?; - - { - let mut 
nonces = server.nonces.lock().await; - nonces.clear(); - } - - allocation - .send_to(&[0x00], SocketAddr::from_str("127.0.0.1:8080")?) - .await?; - - // Shutdown - client.close().await?; - server.close().await?; - - Ok(()) -} diff --git a/turn/src/client/mod.rs b/turn/src/client/mod.rs deleted file mode 100644 index 9ae590d8f..000000000 --- a/turn/src/client/mod.rs +++ /dev/null @@ -1,652 +0,0 @@ -#[cfg(test)] -mod client_test; - -pub mod binding; -pub mod periodic_timer; -pub mod permission; -pub mod relay_conn; -pub mod transaction; - -use std::net::SocketAddr; -use std::str::FromStr; -use std::sync::Arc; - -use async_trait::async_trait; -use base64::prelude::BASE64_STANDARD; -use base64::Engine; -use binding::*; -use relay_conn::*; -use stun::agent::*; -use stun::attributes::*; -use stun::error_code::*; -use stun::fingerprint::*; -use stun::integrity::*; -use stun::message::*; -use stun::textattrs::*; -use stun::xoraddr::*; -use tokio::pin; -use tokio::select; -use tokio::sync::{mpsc, Mutex}; -use tokio_util::sync::CancellationToken; -use transaction::*; -use util::conn::*; -use util::vnet::net::*; - -use crate::error::*; -use crate::proto::chandata::*; -use crate::proto::data::*; -use crate::proto::lifetime::*; -use crate::proto::peeraddr::*; -use crate::proto::relayaddr::*; -use crate::proto::reqtrans::*; -use crate::proto::PROTO_UDP; - -const DEFAULT_RTO_IN_MS: u16 = 200; -const MAX_DATA_BUFFER_SIZE: usize = u16::MAX as usize; // message size limit for Chromium -const MAX_READ_QUEUE_SIZE: usize = 1024; - -// interval [msec] -// 0: 0 ms +500 -// 1: 500 ms +1000 -// 2: 1500 ms +2000 -// 3: 3500 ms +4000 -// 4: 7500 ms +8000 -// 5: 15500 ms +16000 -// 6: 31500 ms +32000 -// -: 63500 ms failed - -/// ClientConfig is a bag of config parameters for Client. -pub struct ClientConfig { - pub stun_serv_addr: String, // STUN server address (e.g. "stun.abc.com:3478") - pub turn_serv_addr: String, // TURN server address (e.g. "turn.abc.com:3478") - pub username: String, - pub password: String, - pub realm: String, - pub software: String, - pub rto_in_ms: u16, - pub conn: Arc, - pub vnet: Option>, -} - -struct ClientInternal { - conn: Arc, - stun_serv_addr: String, - turn_serv_addr: String, - username: Username, - password: String, - realm: Realm, - integrity: MessageIntegrity, - software: Software, - tr_map: Arc>, - binding_mgr: Arc>, - rto_in_ms: u16, - read_ch_tx: Arc>>>, - close_notify: CancellationToken, -} - -#[async_trait] -impl RelayConnObserver for ClientInternal { - /// Returns the TURN server address. - fn turn_server_addr(&self) -> String { - self.turn_serv_addr.clone() - } - - /// Returns the `username`. - fn username(&self) -> Username { - self.username.clone() - } - - /// Return the `realm`. - fn realm(&self) -> Realm { - self.realm.clone() - } - - /// Sends data to the specified destination using the base socket. - async fn write_to(&self, data: &[u8], to: &str) -> std::result::Result { - let n = self.conn.send_to(data, SocketAddr::from_str(to)?).await?; - Ok(n) - } - - /// Performs STUN transaction. 
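The interval table above is the STUN-style retransmission schedule: starting from an initial RTO of 500 ms, the wait doubles after each retry until the transaction is declared failed at 63500 ms. A minimal sketch that reproduces those send offsets (note that the `Transaction` timer further down additionally caps the interval at `MAX_RTX_INTERVAL_IN_MS`):

    // Reproduces the send-time offsets from the interval table above,
    // assuming an initial RTO of 500 ms that doubles after every retry.
    fn retransmit_offsets_ms(rto_ms: u64, max_rtx: u32) -> Vec<u64> {
        let mut offsets = Vec::with_capacity(max_rtx as usize + 1);
        let mut at = 0u64;
        let mut interval = rto_ms;
        for _ in 0..=max_rtx {
            offsets.push(at);
            at += interval;
            interval *= 2;
        }
        offsets
    }

    fn main() {
        // Sends at 0, 500, 1500, 3500, 7500, 15500 and 31500 ms; if no response
        // has arrived by 63500 ms, the transaction is declared failed.
        assert_eq!(
            retransmit_offsets_ms(500, 6),
            vec![0, 500, 1500, 3500, 7500, 15500, 31500]
        );
    }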
- async fn perform_transaction( - &mut self, - msg: &Message, - to: &str, - ignore_result: bool, - ) -> Result { - let tr_key = BASE64_STANDARD.encode(msg.transaction_id.0); - - let mut tr = Transaction::new(TransactionConfig { - key: tr_key.clone(), - raw: msg.raw.clone(), - to: to.to_string(), - interval: self.rto_in_ms, - ignore_result, - }); - let result_ch_rx = tr.get_result_channel(); - - log::trace!("start {} transaction {} to {}", msg.typ, tr_key, tr.to); - { - let mut tm = self.tr_map.lock().await; - tm.insert(tr_key.clone(), tr); - } - - self.conn - .send_to(&msg.raw, SocketAddr::from_str(to)?) - .await?; - - let conn2 = Arc::clone(&self.conn); - let tr_map2 = Arc::clone(&self.tr_map); - { - let mut tm = self.tr_map.lock().await; - if let Some(tr) = tm.get(&tr_key) { - tr.start_rtx_timer(conn2, tr_map2).await; - } - } - - // If dontWait is true, get the transaction going and return immediately - if ignore_result { - return Ok(TransactionResult::default()); - } - - // wait_for_result waits for the transaction result - if let Some(mut result_ch_rx) = result_ch_rx { - match result_ch_rx.recv().await { - Some(tr) => Ok(tr), - None => Err(Error::ErrTransactionClosed), - } - } else { - Err(Error::ErrWaitForResultOnNonResultTransaction) - } - } -} - -impl ClientInternal { - /// Creates a new [`ClientInternal`]. - async fn new(config: ClientConfig) -> Result { - let net = if let Some(vnet) = config.vnet { - if vnet.is_virtual() { - log::warn!("vnet is enabled"); - } - vnet - } else { - Arc::new(Net::new(None)) - }; - - let stun_serv_addr = if config.stun_serv_addr.is_empty() { - String::new() - } else { - log::debug!("resolving {}", config.stun_serv_addr); - let local_addr = config.conn.local_addr()?; - let stun_serv = net - .resolve_addr(local_addr.is_ipv4(), &config.stun_serv_addr) - .await?; - log::debug!("stunServ: {}", stun_serv); - stun_serv.to_string() - }; - - let turn_serv_addr = if config.turn_serv_addr.is_empty() { - String::new() - } else { - log::debug!("resolving {}", config.turn_serv_addr); - let local_addr = config.conn.local_addr()?; - let turn_serv = net - .resolve_addr(local_addr.is_ipv4(), &config.turn_serv_addr) - .await?; - log::debug!("turnServ: {}", turn_serv); - turn_serv.to_string() - }; - - Ok(ClientInternal { - conn: Arc::clone(&config.conn), - stun_serv_addr, - turn_serv_addr, - username: Username::new(ATTR_USERNAME, config.username), - password: config.password, - realm: Realm::new(ATTR_REALM, config.realm), - software: Software::new(ATTR_SOFTWARE, config.software), - tr_map: Arc::new(Mutex::new(TransactionMap::new())), - binding_mgr: Arc::new(Mutex::new(BindingManager::new())), - rto_in_ms: if config.rto_in_ms != 0 { - config.rto_in_ms - } else { - DEFAULT_RTO_IN_MS - }, - integrity: MessageIntegrity::new_short_term_integrity(String::new()), - read_ch_tx: Arc::new(Mutex::new(None)), - close_notify: CancellationToken::new(), - }) - } - - /// Returns the STUN server address. - fn stun_server_addr(&self) -> String { - self.stun_serv_addr.clone() - } - - /// `listen()` will have this client start listening on the `relay_conn` provided via the config. - /// This is optional. If not used, you will need to call `handle_inbound` method - /// to supply incoming data, instead. 
- async fn listen(&self) -> Result<()> { - let conn = Arc::clone(&self.conn); - let stun_serv_str = self.stun_serv_addr.clone(); - let tr_map = Arc::clone(&self.tr_map); - let read_ch_tx = Arc::clone(&self.read_ch_tx); - let binding_mgr = Arc::clone(&self.binding_mgr); - let close_notify = self.close_notify.clone(); - - tokio::spawn(async move { - let mut buf = vec![0u8; MAX_DATA_BUFFER_SIZE]; - let wait_cancel = close_notify.cancelled(); - pin!(wait_cancel); - - loop { - let (n, from) = select! { - biased; - - _ = &mut wait_cancel => { - log::debug!("exiting read loop"); - break; - }, - result = conn.recv_from(&mut buf) => match result { - Ok((n, from)) => (n, from), - Err(err) => { - log::debug!("exiting read loop: {}", err); - break; - } - } - }; - log::debug!("received {} bytes of udp from {}", n, from); - - select! { - biased; - - _ = &mut wait_cancel => { - log::debug!("exiting read loop"); - break; - }, - result = ClientInternal::handle_inbound( - &read_ch_tx, - &buf[..n], - from, - &stun_serv_str, - &tr_map, - &binding_mgr, - ) => { - if let Err(err) = result { - log::debug!("exiting read loop: {}", err); - break; - } - } - } - } - }); - - Ok(()) - } - - /// Handles data received. - /// - /// This method handles incoming packet demultiplex it by the source address - /// and the types of the message. - /// Caller should check if the packet was handled by this client or not. - /// If not handled, it is assumed that the packet is application data. - /// If an error is returned, the caller should discard the packet regardless. - async fn handle_inbound( - read_ch_tx: &Arc>>>, - data: &[u8], - from: SocketAddr, - stun_serv_str: &str, - tr_map: &Arc>, - binding_mgr: &Arc>, - ) -> Result<()> { - // +-------------------+-------------------------------+ - // | Return Values | | - // +-------------------+ Meaning / Action | - // | handled | error | | - // |=========+=========+===============================+ - // | false | nil | Handle the packet as app data | - // |---------+---------+-------------------------------+ - // | true | nil | Nothing to do | - // |---------+---------+-------------------------------+ - // | false | error | (shouldn't happen) | - // |---------+---------+-------------------------------+ - // | true | error | Error occurred while handling | - // +---------+---------+-------------------------------+ - // Possible causes of the error: - // - Malformed packet (parse error) - // - STUN message was a request - // - Non-STUN message from the STUN server - - if is_message(data) { - ClientInternal::handle_stun_message(tr_map, read_ch_tx, data, from).await - } else if ChannelData::is_channel_data(data) { - ClientInternal::handle_channel_data(binding_mgr, read_ch_tx, data).await - } else if !stun_serv_str.is_empty() && from.to_string() == *stun_serv_str { - // received from STUN server but it is not a STUN message - Err(Error::ErrNonStunmessage) - } else { - // assume, this is an application data - log::trace!("non-STUN/TURN packect, unhandled"); - Ok(()) - } - } - - async fn handle_stun_message( - tr_map: &Arc>, - read_ch_tx: &Arc>>>, - data: &[u8], - mut from: SocketAddr, - ) -> Result<()> { - let mut msg = Message::new(); - msg.raw = data.to_vec(); - msg.decode()?; - - if msg.typ.class == CLASS_REQUEST { - return Err(Error::Other(format!( - "{:?} : {}", - Error::ErrUnexpectedStunrequestMessage, - msg - ))); - } - - if msg.typ.class == CLASS_INDICATION { - if msg.typ.method == METHOD_DATA { - let mut peer_addr = PeerAddress::default(); - peer_addr.get_from(&msg)?; - from = 
SocketAddr::new(peer_addr.ip, peer_addr.port); - - let mut data = Data::default(); - data.get_from(&msg)?; - - log::debug!("data indication received from {}", from); - - let _ = ClientInternal::handle_inbound_relay_conn(read_ch_tx, &data.0, from).await; - } - - return Ok(()); - } - - // This is a STUN response message (transactional) - // The type is either: - // - stun.ClassSuccessResponse - // - stun.ClassErrorResponse - - let tr_key = BASE64_STANDARD.encode(msg.transaction_id.0); - - let mut tm = tr_map.lock().await; - if tm.find(&tr_key).is_none() { - // silently discard - log::debug!("no transaction for {}", msg); - return Ok(()); - } - - if let Some(mut tr) = tm.delete(&tr_key) { - // End the transaction - tr.stop_rtx_timer(); - - if !tr - .write_result(TransactionResult { - msg, - from, - retries: tr.retries(), - ..Default::default() - }) - .await - { - log::debug!("no listener for msg.raw {:?}", data); - } - } - - Ok(()) - } - - async fn handle_channel_data( - binding_mgr: &Arc>, - read_ch_tx: &Arc>>>, - data: &[u8], - ) -> Result<()> { - let mut ch_data = ChannelData { - raw: data.to_vec(), - ..Default::default() - }; - ch_data.decode()?; - - let addr = ClientInternal::find_addr_by_channel_number(binding_mgr, ch_data.number.0) - .await - .ok_or(Error::ErrChannelBindNotFound)?; - - log::trace!( - "channel data received from {} (ch={})", - addr, - ch_data.number.0 - ); - - let _ = ClientInternal::handle_inbound_relay_conn(read_ch_tx, &ch_data.data, addr).await; - - Ok(()) - } - - /// Passes inbound data in RelayConn. - async fn handle_inbound_relay_conn( - read_ch_tx: &Arc>>>, - data: &[u8], - from: SocketAddr, - ) -> Result<()> { - let read_ch_tx_opt = read_ch_tx.lock().await; - log::debug!("read_ch_tx_opt = {}", read_ch_tx_opt.is_some()); - if let Some(tx) = &*read_ch_tx_opt { - log::debug!("try_send data = {:?}, from = {}", data, from); - if tx - .try_send(InboundData { - data: data.to_vec(), - from, - }) - .is_err() - { - log::warn!("receive buffer full"); - } - Ok(()) - } else { - Err(Error::ErrAlreadyClosed) - } - } - - /// Closes this client. - async fn close(&mut self) { - self.close_notify.cancel(); - { - let mut read_ch_tx = self.read_ch_tx.lock().await; - read_ch_tx.take(); - } - { - let mut tm = self.tr_map.lock().await; - tm.close_and_delete_all(); - } - } - - /// Sends a new STUN request to the given transport address. - async fn send_binding_request_to(&mut self, to: &str) -> Result { - let msg = { - let attrs: Vec> = if !self.software.text.is_empty() { - vec![ - Box::new(TransactionId::new()), - Box::new(BINDING_REQUEST), - Box::new(self.software.clone()), - ] - } else { - vec![Box::new(TransactionId::new()), Box::new(BINDING_REQUEST)] - }; - - let mut msg = Message::new(); - msg.build(&attrs)?; - msg - }; - - log::debug!("client.SendBindingRequestTo call PerformTransaction 1"); - let tr_res = self.perform_transaction(&msg, to, false).await?; - - let mut refl_addr = XorMappedAddress::default(); - refl_addr.get_from(&tr_res.msg)?; - - Ok(SocketAddr::new(refl_addr.ip, refl_addr.port)) - } - - /// Sends a new STUN request to the STUN server. 
- async fn send_binding_request(&mut self) -> Result { - if self.stun_serv_addr.is_empty() { - Err(Error::ErrStunserverAddressNotSet) - } else { - self.send_binding_request_to(&self.stun_serv_addr.clone()) - .await - } - } - - /// Returns a peer address associated with the - // channel number on this UDPConn - async fn find_addr_by_channel_number( - binding_mgr: &Arc>, - ch_num: u16, - ) -> Option { - let bm = binding_mgr.lock().await; - bm.find_by_number(ch_num).map(|b| b.addr) - } - - /// Sends a TURN allocation request to the given transport address. - async fn allocate(&mut self) -> Result { - { - let read_ch_tx = self.read_ch_tx.lock().await; - log::debug!("allocate check: read_ch_tx_opt = {}", read_ch_tx.is_some()); - if read_ch_tx.is_some() { - return Err(Error::ErrOneAllocateOnly); - } - } - - let mut msg = Message::new(); - msg.build(&[ - Box::new(TransactionId::new()), - Box::new(MessageType::new(METHOD_ALLOCATE, CLASS_REQUEST)), - Box::new(RequestedTransport { - protocol: PROTO_UDP, - }), - Box::new(FINGERPRINT), - ])?; - - log::debug!("client.Allocate call PerformTransaction 1"); - let tr_res = self - .perform_transaction(&msg, &self.turn_serv_addr.clone(), false) - .await?; - let res = tr_res.msg; - - // Anonymous allocate failed, trying to authenticate. - let nonce = Nonce::get_from_as(&res, ATTR_NONCE)?; - self.realm = Realm::get_from_as(&res, ATTR_REALM)?; - - self.integrity = MessageIntegrity::new_long_term_integrity( - self.username.text.clone(), - self.realm.text.clone(), - self.password.clone(), - ); - - // Trying to authorize. - msg.build(&[ - Box::new(TransactionId::new()), - Box::new(MessageType::new(METHOD_ALLOCATE, CLASS_REQUEST)), - Box::new(RequestedTransport { - protocol: PROTO_UDP, - }), - Box::new(self.username.clone()), - Box::new(self.realm.clone()), - Box::new(nonce.clone()), - Box::new(self.integrity.clone()), - Box::new(FINGERPRINT), - ])?; - - log::debug!("client.Allocate call PerformTransaction 2"); - let tr_res = self - .perform_transaction(&msg, &self.turn_serv_addr.clone(), false) - .await?; - let res = tr_res.msg; - - if res.typ.class == CLASS_ERROR_RESPONSE { - let mut code = ErrorCodeAttribute::default(); - let result = code.get_from(&res); - if result.is_err() { - return Err(Error::Other(format!("{}", res.typ))); - } else { - return Err(Error::Other(format!("{} (error {})", res.typ, code))); - } - } - - // Getting relayed addresses from response. - let mut relayed = RelayedAddress::default(); - relayed.get_from(&res)?; - let relayed_addr = SocketAddr::new(relayed.ip, relayed.port); - - // Getting lifetime from response - let mut lifetime = Lifetime::default(); - lifetime.get_from(&res)?; - - let (read_ch_tx, read_ch_rx) = mpsc::channel(MAX_READ_QUEUE_SIZE); - { - let mut read_ch_tx_opt = self.read_ch_tx.lock().await; - *read_ch_tx_opt = Some(read_ch_tx); - log::debug!("allocate: read_ch_tx_opt = {}", read_ch_tx_opt.is_some()); - } - - Ok(RelayConnConfig { - relayed_addr, - integrity: self.integrity.clone(), - nonce, - lifetime: lifetime.0, - binding_mgr: Arc::clone(&self.binding_mgr), - read_ch_rx: Arc::new(Mutex::new(read_ch_rx)), - }) - } -} - -/// Client is a STUN server client. 
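A minimal usage sketch of the `Client` declared just below, following the same lifecycle as the tests above (create, `listen()`, `allocate()`, write through the relayed connection, close). The server address and credentials are placeholders, and `Client`, `ClientConfig` and the `Conn` trait are assumed to be in scope:

    use std::sync::Arc;
    use tokio::net::UdpSocket;

    async fn run() -> Result<(), Box<dyn std::error::Error>> {
        let conn = Arc::new(UdpSocket::bind("0.0.0.0:0").await?);

        let client = Client::new(ClientConfig {
            stun_serv_addr: "turn.example.com:3478".to_owned(),
            turn_serv_addr: "turn.example.com:3478".to_owned(),
            username: "user".to_owned(),
            password: "pass".to_owned(),
            realm: String::new(),
            software: String::new(),
            rto_in_ms: 0, // 0 falls back to DEFAULT_RTO_IN_MS
            conn,
            vnet: None,
        })
        .await?;

        client.listen().await?;                 // start the demultiplexing read loop
        let relayed = client.allocate().await?; // RelayConn bound to the relayed address
        relayed.send_to(b"hello", "192.0.2.10:9000".parse()?).await?;
        client.close().await?;
        Ok(())
    }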
-#[derive(Clone)] -pub struct Client { - client_internal: Arc>, -} - -impl Client { - pub async fn new(config: ClientConfig) -> Result { - let ci = ClientInternal::new(config).await?; - Ok(Client { - client_internal: Arc::new(Mutex::new(ci)), - }) - } - - pub async fn listen(&self) -> Result<()> { - let ci = self.client_internal.lock().await; - ci.listen().await - } - - pub async fn allocate(&self) -> Result { - let config = { - let mut ci = self.client_internal.lock().await; - ci.allocate().await? - }; - - Ok(RelayConn::new(Arc::clone(&self.client_internal), config).await) - } - - pub async fn close(&self) -> Result<()> { - let mut ci = self.client_internal.lock().await; - ci.close().await; - Ok(()) - } - - /// Sends a new STUN request to the given transport address. - pub async fn send_binding_request_to(&self, to: &str) -> Result { - let mut ci = self.client_internal.lock().await; - ci.send_binding_request_to(to).await - } - - /// Sends a new STUN request to the STUN server. - pub async fn send_binding_request(&self) -> Result { - let mut ci = self.client_internal.lock().await; - ci.send_binding_request().await - } -} diff --git a/turn/src/client/periodic_timer.rs b/turn/src/client/periodic_timer.rs deleted file mode 100644 index d0e5cdf44..000000000 --- a/turn/src/client/periodic_timer.rs +++ /dev/null @@ -1,93 +0,0 @@ -#[cfg(test)] -mod periodic_timer_test; - -use std::sync::Arc; - -use async_trait::async_trait; -use tokio::sync::{mpsc, Mutex}; -use tokio::time::Duration; - -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] -pub enum TimerIdRefresh { - #[default] - Alloc, - Perms, -} - -/// `PeriodicTimerTimeoutHandler` is a handler called on timeout. -#[async_trait] -pub trait PeriodicTimerTimeoutHandler { - async fn on_timeout(&mut self, id: TimerIdRefresh); -} - -/// `PeriodicTimer` is a periodic timer. -#[derive(Default)] -pub struct PeriodicTimer { - id: TimerIdRefresh, - interval: Duration, - close_tx: Mutex>>, -} - -impl PeriodicTimer { - /// create a new [`PeriodicTimer`]. - pub fn new(id: TimerIdRefresh, interval: Duration) -> Self { - PeriodicTimer { - id, - interval, - close_tx: Mutex::new(None), - } - } - - /// Starts the timer. - pub async fn start( - &self, - timeout_handler: Arc>, - ) -> bool { - // this is a noop if the timer is always running - { - let close_tx = self.close_tx.lock().await; - if close_tx.is_some() { - return false; - } - } - - let (close_tx, mut close_rx) = mpsc::channel(1); - let interval = self.interval; - let id = self.id; - - tokio::spawn(async move { - loop { - let timer = tokio::time::sleep(interval); - tokio::pin!(timer); - - tokio::select! { - _ = timer.as_mut() => { - let mut handler = timeout_handler.lock().await; - handler.on_timeout(id).await; - } - _ = close_rx.recv() => break, - } - } - }); - - { - let mut close = self.close_tx.lock().await; - *close = Some(close_tx); - } - - true - } - - /// Stops the timer. - pub async fn stop(&self) { - let mut close_tx = self.close_tx.lock().await; - close_tx.take(); - } - - /// Tests if the timer is running. - /// Debug purpose only. 
- pub async fn is_running(&self) -> bool { - let close_tx = self.close_tx.lock().await; - close_tx.is_some() - } -} diff --git a/turn/src/client/periodic_timer/periodic_timer_test.rs b/turn/src/client/periodic_timer/periodic_timer_test.rs deleted file mode 100644 index b67b31dd3..000000000 --- a/turn/src/client/periodic_timer/periodic_timer_test.rs +++ /dev/null @@ -1,37 +0,0 @@ -use super::*; -use crate::error::Result; - -struct DummyPeriodicTimerTimeoutHandler; - -#[async_trait] -impl PeriodicTimerTimeoutHandler for DummyPeriodicTimerTimeoutHandler { - async fn on_timeout(&mut self, id: TimerIdRefresh) { - assert_eq!(id, TimerIdRefresh::Perms); - } -} - -#[tokio::test] -async fn test_periodic_timer() -> Result<()> { - let timer_id = TimerIdRefresh::Perms; - let rt = PeriodicTimer::new(timer_id, Duration::from_millis(50)); - let dummy1 = Arc::new(Mutex::new(DummyPeriodicTimerTimeoutHandler {})); - let dummy2 = Arc::clone(&dummy1); - - assert!(!rt.is_running().await, "should not be running yet"); - - let ok = rt.start(dummy1).await; - assert!(ok, "should be true"); - assert!(rt.is_running().await, "should be running"); - - tokio::time::sleep(Duration::from_millis(100)).await; - - let ok = rt.start(dummy2).await; - assert!(!ok, "start again is noop"); - - tokio::time::sleep(Duration::from_millis(120)).await; - rt.stop().await; - - assert!(!rt.is_running().await, "should not be running"); - - Ok(()) -} diff --git a/turn/src/client/permission.rs b/turn/src/client/permission.rs deleted file mode 100644 index d34b1fdba..000000000 --- a/turn/src/client/permission.rs +++ /dev/null @@ -1,73 +0,0 @@ -use std::collections::HashMap; -use std::net::SocketAddr; -use std::sync::atomic::Ordering; -use std::sync::Arc; - -use portable_atomic::AtomicU8; - -#[derive(Default, Copy, Clone, PartialEq, Debug)] -pub(crate) enum PermState { - #[default] - Idle = 0, - Permitted = 1, -} - -impl From for PermState { - fn from(v: u8) -> Self { - match v { - 0 => PermState::Idle, - _ => PermState::Permitted, - } - } -} - -#[derive(Default)] -pub(crate) struct Permission { - st: AtomicU8, //PermState, -} - -impl Permission { - pub(crate) fn set_state(&self, state: PermState) { - self.st.store(state as u8, Ordering::SeqCst); - } - - pub(crate) fn state(&self) -> PermState { - self.st.load(Ordering::SeqCst).into() - } -} - -/// Thread-safe Permission map. 
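One detail of the `PermissionMap` defined just below: permissions are keyed by the peer's IP address only, so two peers with the same IP but different ports share one permission entry. A minimal sketch of that keying (addresses are placeholders):

    use std::net::SocketAddr;

    // Permissions are keyed by the peer's IP only; the port is ignored,
    // mirroring PermissionMap's use of addr.ip().to_string() as the key.
    fn permission_key(addr: &SocketAddr) -> String {
        addr.ip().to_string()
    }

    fn main() {
        let a: SocketAddr = "203.0.113.5:4000".parse().unwrap();
        let b: SocketAddr = "203.0.113.5:5000".parse().unwrap();
        assert_eq!(permission_key(&a), permission_key(&b));
    }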
-#[derive(Default)] -pub(crate) struct PermissionMap { - perm_map: HashMap>, -} - -impl PermissionMap { - pub(crate) fn new() -> PermissionMap { - PermissionMap { - perm_map: HashMap::new(), - } - } - - pub(crate) fn insert(&mut self, addr: &SocketAddr, p: Arc) { - self.perm_map.insert(addr.ip().to_string(), p); - } - - pub(crate) fn find(&self, addr: &SocketAddr) -> Option<&Arc> { - self.perm_map.get(&addr.ip().to_string()) - } - - pub(crate) fn delete(&mut self, addr: &SocketAddr) { - self.perm_map.remove(&addr.ip().to_string()); - } - - pub(crate) fn addrs(&self) -> Vec { - let mut a = vec![]; - for k in self.perm_map.keys() { - if let Ok(ip) = k.parse() { - a.push(SocketAddr::new(ip, 0)); - } - } - a - } -} diff --git a/turn/src/client/relay_conn.rs b/turn/src/client/relay_conn.rs deleted file mode 100644 index dd990a879..000000000 --- a/turn/src/client/relay_conn.rs +++ /dev/null @@ -1,630 +0,0 @@ -#[cfg(test)] -mod relay_conn_test; - -// client implements the API for a TURN client -use std::io; -use std::net::SocketAddr; -use std::sync::Arc; - -use async_trait::async_trait; -use stun::agent::*; -use stun::attributes::*; -use stun::error_code::*; -use stun::fingerprint::*; -use stun::integrity::*; -use stun::message::*; -use stun::textattrs::*; -use tokio::sync::{mpsc, Mutex}; -use tokio::time::{Duration, Instant}; -use util::Conn; - -use super::binding::*; -use super::periodic_timer::*; -use super::permission::*; -use super::transaction::*; -use crate::{proto, Error}; - -const PERM_REFRESH_INTERVAL: Duration = Duration::from_secs(120); -const MAX_RETRY_ATTEMPTS: u16 = 3; - -pub(crate) struct InboundData { - pub(crate) data: Vec, - pub(crate) from: SocketAddr, -} - -/// `RelayConnObserver` is an interface to [`RelayConn`] observer. -#[async_trait] -pub trait RelayConnObserver { - fn turn_server_addr(&self) -> String; - fn username(&self) -> Username; - fn realm(&self) -> Realm; - async fn write_to(&self, data: &[u8], to: &str) -> Result; - async fn perform_transaction( - &mut self, - msg: &Message, - to: &str, - ignore_result: bool, - ) -> Result; -} - -/// `RelayConnConfig` is a set of configuration params used by [`RelayConn::new()`]. -pub(crate) struct RelayConnConfig { - pub(crate) relayed_addr: SocketAddr, - pub(crate) integrity: MessageIntegrity, - pub(crate) nonce: Nonce, - pub(crate) lifetime: Duration, - pub(crate) binding_mgr: Arc>, - pub(crate) read_ch_rx: Arc>>, -} - -pub struct RelayConnInternal { - obs: Arc>, - relayed_addr: SocketAddr, - perm_map: PermissionMap, - binding_mgr: Arc>, - integrity: MessageIntegrity, - nonce: Nonce, - lifetime: Duration, -} - -/// `RelayConn` is the implementation of the Conn interfaces for UDP Relayed network connections. -pub struct RelayConn { - relayed_addr: SocketAddr, - read_ch_rx: Arc>>, - relay_conn: Arc>>, - refresh_alloc_timer: PeriodicTimer, - refresh_perms_timer: PeriodicTimer, -} - -impl RelayConn { - /// Creates a new [`RelayConn`]. 
- pub(crate) async fn new(obs: Arc>, config: RelayConnConfig) -> Self { - log::debug!("initial lifetime: {} seconds", config.lifetime.as_secs()); - - let c = RelayConn { - refresh_alloc_timer: PeriodicTimer::new(TimerIdRefresh::Alloc, config.lifetime / 2), - refresh_perms_timer: PeriodicTimer::new(TimerIdRefresh::Perms, PERM_REFRESH_INTERVAL), - relayed_addr: config.relayed_addr, - read_ch_rx: Arc::clone(&config.read_ch_rx), - relay_conn: Arc::new(Mutex::new(RelayConnInternal::new(obs, config))), - }; - - let rci1 = Arc::clone(&c.relay_conn); - let rci2 = Arc::clone(&c.relay_conn); - - if c.refresh_alloc_timer.start(rci1).await { - log::debug!("refresh_alloc_timer started"); - } - if c.refresh_perms_timer.start(rci2).await { - log::debug!("refresh_perms_timer started"); - } - - c - } -} - -#[async_trait] -impl Conn for RelayConn { - async fn connect(&self, _addr: SocketAddr) -> Result<(), util::Error> { - Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) - } - - async fn recv(&self, _buf: &mut [u8]) -> Result { - Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) - } - - /// Reads a packet from the connection, - /// copying the payload into `p`. It returns the number of - /// bytes copied into `p` and the return address that - /// was on the packet. - /// It returns the number of bytes read `(0 <= n <= len(p))` - /// and any error encountered. Callers should always process - /// the `n > 0` bytes returned before considering the error. - /// It can be made to time out and return - /// an Error with Timeout() == true after a fixed time limit; - /// see SetDeadline and SetReadDeadline. - async fn recv_from(&self, p: &mut [u8]) -> Result<(usize, SocketAddr), util::Error> { - let mut read_ch_rx = self.read_ch_rx.lock().await; - - if let Some(ib_data) = read_ch_rx.recv().await { - let n = ib_data.data.len(); - if p.len() < n { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - Error::ErrShortBuffer.to_string(), - ) - .into()); - } - p[..n].copy_from_slice(&ib_data.data); - Ok((n, ib_data.from)) - } else { - Err(io::Error::new( - io::ErrorKind::ConnectionAborted, - Error::ErrAlreadyClosed.to_string(), - ) - .into()) - } - } - - async fn send(&self, _buf: &[u8]) -> Result { - Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) - } - - /// Writes a packet with payload `p` to `addr`. - /// It can be made to time out and return - /// an Error with Timeout() == true after a fixed time limit; - /// see SetDeadline and SetWriteDeadline. - /// On packet-oriented connections, write timeouts are rare. - async fn send_to(&self, p: &[u8], addr: SocketAddr) -> Result { - let mut relay_conn = self.relay_conn.lock().await; - match relay_conn.send_to(p, addr).await { - Ok(n) => Ok(n), - Err(err) => Err(io::Error::new(io::ErrorKind::Other, err.to_string()).into()), - } - } - - /// Returns the local network address. - fn local_addr(&self) -> Result { - Ok(self.relayed_addr) - } - - fn remote_addr(&self) -> Option { - None - } - - /// Closes the connection. - /// Any blocked [`Self::recv_from()`] or [`Self::send_to()`] operations - /// will be unblocked and return errors. 
- async fn close(&self) -> Result<(), util::Error> { - self.refresh_alloc_timer.stop().await; - self.refresh_perms_timer.stop().await; - - let mut relay_conn = self.relay_conn.lock().await; - let _ = relay_conn - .close() - .await - .map_err(|err| util::Error::Other(format!("{err}"))); - Ok(()) - } - - fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) { - self - } -} - -impl RelayConnInternal { - /// Creates a new [`RelayConnInternal`]. - fn new(obs: Arc>, config: RelayConnConfig) -> Self { - RelayConnInternal { - obs, - relayed_addr: config.relayed_addr, - perm_map: PermissionMap::new(), - binding_mgr: config.binding_mgr, - integrity: config.integrity, - nonce: config.nonce, - lifetime: config.lifetime, - } - } - - /// Writes a packet with payload `p` to `addr`. - /// It can be made to time out and return - /// an Error with Timeout() == true after a fixed time limit; - /// see SetDeadline and SetWriteDeadline. - /// On packet-oriented connections, write timeouts are rare. - async fn send_to(&mut self, p: &[u8], addr: SocketAddr) -> Result { - // check if we have a permission for the destination IP addr - let perm = if let Some(perm) = self.perm_map.find(&addr) { - Arc::clone(perm) - } else { - let perm = Arc::new(Permission::default()); - self.perm_map.insert(&addr, Arc::clone(&perm)); - perm - }; - - let mut result = Ok(()); - for _ in 0..MAX_RETRY_ATTEMPTS { - result = self.create_perm(&perm, addr).await; - if let Err(err) = &result { - if Error::ErrTryAgain != *err { - break; - } - } - } - result?; - - let number = { - let (bind_st, bind_at, bind_number, bind_addr) = { - let mut binding_mgr = self.binding_mgr.lock().await; - let b = if let Some(b) = binding_mgr.find_by_addr(&addr) { - b - } else { - binding_mgr - .create(addr) - .ok_or_else(|| Error::Other("Addr not found".to_owned()))? - }; - (b.state(), b.refreshed_at(), b.number, b.addr) - }; - - if bind_st == BindingState::Idle - || bind_st == BindingState::Request - || bind_st == BindingState::Failed - { - // block only callers with the same binding until - // the binding transaction has been complete - // binding state may have been changed while waiting. check again. - if bind_st == BindingState::Idle { - let binding_mgr = Arc::clone(&self.binding_mgr); - let rc_obs = Arc::clone(&self.obs); - let nonce = self.nonce.clone(); - let integrity = self.integrity.clone(); - { - let mut bm = binding_mgr.lock().await; - if let Some(b) = bm.get_by_addr(&bind_addr) { - b.set_state(BindingState::Request); - } - } - tokio::spawn(async move { - let result = RelayConnInternal::bind( - rc_obs, - bind_addr, - bind_number, - nonce, - integrity, - ) - .await; - - { - let mut bm = binding_mgr.lock().await; - if let Err(err) = result { - if Error::ErrUnexpectedResponse != err { - bm.delete_by_addr(&bind_addr); - } else if let Some(b) = bm.get_by_addr(&bind_addr) { - b.set_state(BindingState::Failed); - } - - // keep going... 
- log::warn!("bind() failed: {}", err); - } else if let Some(b) = bm.get_by_addr(&bind_addr) { - b.set_state(BindingState::Ready); - } - } - }); - } - - // send data using SendIndication - let peer_addr = socket_addr2peer_address(&addr); - let mut msg = Message::new(); - msg.build(&[ - Box::new(TransactionId::new()), - Box::new(MessageType::new(METHOD_SEND, CLASS_INDICATION)), - Box::new(proto::data::Data(p.to_vec())), - Box::new(peer_addr), - Box::new(FINGERPRINT), - ])?; - - // indication has no transaction (fire-and-forget) - let obs = self.obs.lock().await; - let turn_server_addr = obs.turn_server_addr(); - return Ok(obs.write_to(&msg.raw, &turn_server_addr).await?); - } - - // binding is either ready - - // check if the binding needs a refresh - if bind_st == BindingState::Ready - && Instant::now() - .checked_duration_since(bind_at) - .unwrap_or_else(|| Duration::from_secs(0)) - > Duration::from_secs(5 * 60) - { - let binding_mgr = Arc::clone(&self.binding_mgr); - let rc_obs = Arc::clone(&self.obs); - let nonce = self.nonce.clone(); - let integrity = self.integrity.clone(); - { - let mut bm = binding_mgr.lock().await; - if let Some(b) = bm.get_by_addr(&bind_addr) { - b.set_state(BindingState::Refresh); - } - } - tokio::spawn(async move { - let result = - RelayConnInternal::bind(rc_obs, bind_addr, bind_number, nonce, integrity) - .await; - - { - let mut bm = binding_mgr.lock().await; - if let Err(err) = result { - if Error::ErrUnexpectedResponse != err { - bm.delete_by_addr(&bind_addr); - } else if let Some(b) = bm.get_by_addr(&bind_addr) { - b.set_state(BindingState::Failed); - } - - // keep going... - log::warn!("bind() for refresh failed: {}", err); - } else if let Some(b) = bm.get_by_addr(&bind_addr) { - b.set_refreshed_at(Instant::now()); - b.set_state(BindingState::Ready); - } - } - }); - } - - bind_number - }; - - // send via ChannelData - self.send_channel_data(p, number).await - } - - /// This func-block would block, per destination IP (, or perm), until - /// the perm state becomes "requested". Purpose of this is to guarantee - /// the order of packets (within the same perm). - /// Note that CreatePermission transaction may not be complete before - /// all the data transmission. This is done assuming that the request - /// will be mostly likely successful and we can tolerate some loss of - /// UDP packet (or reorder), inorder to minimize the latency in most cases. - async fn create_perm(&mut self, perm: &Arc, addr: SocketAddr) -> Result<(), Error> { - if perm.state() == PermState::Idle { - // punch a hole! (this would block a bit..) - if let Err(err) = self.create_permissions(&[addr]).await { - self.perm_map.delete(&addr); - return Err(err); - } - perm.set_state(PermState::Permitted); - } - Ok(()) - } - - async fn send_channel_data(&self, data: &[u8], ch_num: u16) -> Result { - let mut ch_data = proto::chandata::ChannelData { - data: data.to_vec(), - number: proto::channum::ChannelNumber(ch_num), - ..Default::default() - }; - ch_data.encode(); - - let obs = self.obs.lock().await; - Ok(obs.write_to(&ch_data.raw, &obs.turn_server_addr()).await?) 
- } - - async fn create_permissions(&mut self, addrs: &[SocketAddr]) -> Result<(), Error> { - let res = { - let msg = { - let obs = self.obs.lock().await; - let mut setters: Vec> = vec![ - Box::new(TransactionId::new()), - Box::new(MessageType::new(METHOD_CREATE_PERMISSION, CLASS_REQUEST)), - ]; - - for addr in addrs { - setters.push(Box::new(socket_addr2peer_address(addr))); - } - - setters.push(Box::new(obs.username())); - setters.push(Box::new(obs.realm())); - setters.push(Box::new(self.nonce.clone())); - setters.push(Box::new(self.integrity.clone())); - setters.push(Box::new(FINGERPRINT)); - - let mut msg = Message::new(); - msg.build(&setters)?; - msg - }; - - let mut obs = self.obs.lock().await; - let turn_server_addr = obs.turn_server_addr(); - - log::debug!("UDPConn.createPermissions call PerformTransaction 1"); - let tr_res = obs - .perform_transaction(&msg, &turn_server_addr, false) - .await?; - - tr_res.msg - }; - - if res.typ.class == CLASS_ERROR_RESPONSE { - let mut code = ErrorCodeAttribute::default(); - let result = code.get_from(&res); - if result.is_err() { - return Err(Error::Other(format!("{}", res.typ))); - } else if code.code == CODE_STALE_NONCE { - self.set_nonce_from_msg(&res); - return Err(Error::ErrTryAgain); - } else { - return Err(Error::Other(format!("{} (error {})", res.typ, code))); - } - } - - Ok(()) - } - - pub fn set_nonce_from_msg(&mut self, msg: &Message) { - // Update nonce - match Nonce::get_from_as(msg, ATTR_NONCE) { - Ok(nonce) => { - self.nonce = nonce; - log::debug!("refresh allocation: 438, got new nonce."); - } - Err(_) => log::warn!("refresh allocation: 438 but no nonce."), - } - } - - /// Closes the connection. - /// Any blocked `recv_from` or `send_to` operations will be unblocked and return errors. - pub async fn close(&mut self) -> Result<(), Error> { - self.refresh_allocation(Duration::from_secs(0), true /* dontWait=true */) - .await - } - - async fn refresh_allocation( - &mut self, - lifetime: Duration, - dont_wait: bool, - ) -> Result<(), Error> { - let res = { - let mut obs = self.obs.lock().await; - - let mut msg = Message::new(); - msg.build(&[ - Box::new(TransactionId::new()), - Box::new(MessageType::new(METHOD_REFRESH, CLASS_REQUEST)), - Box::new(proto::lifetime::Lifetime(lifetime)), - Box::new(obs.username()), - Box::new(obs.realm()), - Box::new(self.nonce.clone()), - Box::new(self.integrity.clone()), - Box::new(FINGERPRINT), - ])?; - - log::debug!("send refresh request (dont_wait={})", dont_wait); - let turn_server_addr = obs.turn_server_addr(); - let tr_res = obs - .perform_transaction(&msg, &turn_server_addr, dont_wait) - .await?; - - if dont_wait { - log::debug!("refresh request sent"); - return Ok(()); - } - - log::debug!("refresh request sent, and waiting response"); - - tr_res.msg - }; - - if res.typ.class == CLASS_ERROR_RESPONSE { - let mut code = ErrorCodeAttribute::default(); - let result = code.get_from(&res); - if result.is_err() { - return Err(Error::Other(format!("{}", res.typ))); - } else if code.code == CODE_STALE_NONCE { - self.set_nonce_from_msg(&res); - return Err(Error::ErrTryAgain); - } else { - return Ok(()); - } - } - - // Getting lifetime from response - let mut updated_lifetime = proto::lifetime::Lifetime::default(); - updated_lifetime.get_from(&res)?; - - self.lifetime = updated_lifetime.0; - log::debug!("updated lifetime: {} seconds", self.lifetime.as_secs()); - Ok(()) - } - - async fn refresh_permissions(&mut self) -> Result<(), Error> { - let addrs = self.perm_map.addrs(); - if addrs.is_empty() { - 
log::debug!("no permission to refresh"); - return Ok(()); - } - - if let Err(err) = self.create_permissions(&addrs).await { - if Error::ErrTryAgain != err { - log::error!("fail to refresh permissions: {}", err); - } - return Err(err); - } - - log::debug!("refresh permissions successful"); - Ok(()) - } - - async fn bind( - rc_obs: Arc>, - bind_addr: SocketAddr, - bind_number: u16, - nonce: Nonce, - integrity: MessageIntegrity, - ) -> Result<(), Error> { - let (msg, turn_server_addr) = { - let obs = rc_obs.lock().await; - - let setters: Vec> = vec![ - Box::new(TransactionId::new()), - Box::new(MessageType::new(METHOD_CHANNEL_BIND, CLASS_REQUEST)), - Box::new(socket_addr2peer_address(&bind_addr)), - Box::new(proto::channum::ChannelNumber(bind_number)), - Box::new(obs.username()), - Box::new(obs.realm()), - Box::new(nonce), - Box::new(integrity), - Box::new(FINGERPRINT), - ]; - - let mut msg = Message::new(); - msg.build(&setters)?; - - (msg, obs.turn_server_addr()) - }; - - log::debug!("UDPConn.bind call PerformTransaction 1"); - let tr_res = { - let mut obs = rc_obs.lock().await; - obs.perform_transaction(&msg, &turn_server_addr, false) - .await? - }; - - let res = tr_res.msg; - - if res.typ != MessageType::new(METHOD_CHANNEL_BIND, CLASS_SUCCESS_RESPONSE) { - return Err(Error::ErrUnexpectedResponse); - } - - log::debug!("channel binding successful: {} {}", bind_addr, bind_number); - - // Success. - Ok(()) - } -} - -#[async_trait] -impl PeriodicTimerTimeoutHandler for RelayConnInternal { - async fn on_timeout(&mut self, id: TimerIdRefresh) { - log::debug!("refresh timer {:?} expired", id); - match id { - TimerIdRefresh::Alloc => { - let lifetime = self.lifetime; - // limit the max retries on errTryAgain to 3 - // when stale nonce returns, second retry should succeed - let mut result = Ok(()); - for _ in 0..MAX_RETRY_ATTEMPTS { - result = self.refresh_allocation(lifetime, false).await; - if let Err(err) = &result { - if Error::ErrTryAgain != *err { - break; - } - } - } - if result.is_err() { - log::warn!("refresh allocation failed"); - } - } - TimerIdRefresh::Perms => { - let mut result = Ok(()); - for _ in 0..MAX_RETRY_ATTEMPTS { - result = self.refresh_permissions().await; - if let Err(err) = &result { - if Error::ErrTryAgain != *err { - break; - } - } - } - if result.is_err() { - log::warn!("refresh permissions failed"); - } - } - } - } -} - -fn socket_addr2peer_address(addr: &SocketAddr) -> proto::peeraddr::PeerAddress { - proto::peeraddr::PeerAddress { - ip: addr.ip(), - port: addr.port(), - } -} diff --git a/turn/src/client/relay_conn/relay_conn_test.rs b/turn/src/client/relay_conn/relay_conn_test.rs deleted file mode 100644 index cc0e6f59a..000000000 --- a/turn/src/client/relay_conn/relay_conn_test.rs +++ /dev/null @@ -1,84 +0,0 @@ -use std::net::Ipv4Addr; - -use super::*; -use crate::error::Result; - -struct DummyRelayConnObserver { - turn_server_addr: String, - username: Username, - realm: Realm, -} - -#[async_trait] -impl RelayConnObserver for DummyRelayConnObserver { - fn turn_server_addr(&self) -> String { - self.turn_server_addr.clone() - } - - fn username(&self) -> Username { - self.username.clone() - } - - fn realm(&self) -> Realm { - self.realm.clone() - } - - async fn write_to(&self, _data: &[u8], _to: &str) -> std::result::Result { - Ok(0) - } - - async fn perform_transaction( - &mut self, - _msg: &Message, - _to: &str, - _dont_wait: bool, - ) -> Result { - Err(Error::ErrFakeErr) - } -} - -#[tokio::test] -async fn test_relay_conn() -> Result<()> { - let obs = 
DummyRelayConnObserver { - turn_server_addr: String::new(), - username: Username::new(ATTR_USERNAME, "username".to_owned()), - realm: Realm::new(ATTR_REALM, "realm".to_owned()), - }; - - let (_read_ch_tx, read_ch_rx) = mpsc::channel(100); - - let config = RelayConnConfig { - relayed_addr: SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0), - integrity: MessageIntegrity::default(), - nonce: Nonce::new(ATTR_NONCE, "nonce".to_owned()), - lifetime: Duration::from_secs(0), - binding_mgr: Arc::new(Mutex::new(BindingManager::new())), - read_ch_rx: Arc::new(Mutex::new(read_ch_rx)), - }; - - let rc = RelayConn::new(Arc::new(Mutex::new(obs)), config).await; - - let rci = rc.relay_conn.lock().await; - let (bind_addr, bind_number) = { - let mut bm = rci.binding_mgr.lock().await; - let b = bm - .create(SocketAddr::new(Ipv4Addr::new(127, 0, 0, 1).into(), 1234)) - .unwrap(); - (b.addr, b.number) - }; - - //let binding_mgr = Arc::clone(&rci.binding_mgr); - let rc_obs = Arc::clone(&rci.obs); - let nonce = rci.nonce.clone(); - let integrity = rci.integrity.clone(); - - if let Err(err) = - RelayConnInternal::bind(rc_obs, bind_addr, bind_number, nonce, integrity).await - { - assert!(Error::ErrUnexpectedResponse != err); - } else { - panic!("should fail"); - } - - Ok(()) -} diff --git a/turn/src/client/transaction.rs b/turn/src/client/transaction.rs deleted file mode 100644 index 557269d7e..000000000 --- a/turn/src/client/transaction.rs +++ /dev/null @@ -1,282 +0,0 @@ -use std::collections::HashMap; -use std::net::{IpAddr, Ipv4Addr, SocketAddr}; -use std::str::FromStr; -use std::sync::atomic::Ordering; -use std::sync::Arc; - -use portable_atomic::AtomicU16; -use stun::message::*; -use tokio::sync::{mpsc, Mutex}; -use tokio::time::Duration; -use util::Conn; - -use crate::error::*; - -const MAX_RTX_INTERVAL_IN_MS: u16 = 1600; -const MAX_RTX_COUNT: u16 = 7; // total 7 requests (Rc) - -async fn on_rtx_timeout( - conn: &Arc, - tr_map: &Arc>, - tr_key: &str, - n_rtx: u16, -) -> bool { - let mut tm = tr_map.lock().await; - let (tr_raw, tr_to) = match tm.find(tr_key) { - Some(tr) => (tr.raw.clone(), tr.to.clone()), - None => return true, // already gone - }; - - if n_rtx == MAX_RTX_COUNT { - // all retransmisstions failed - if let Some(tr) = tm.delete(tr_key) { - if !tr - .write_result(TransactionResult { - err: Some(Error::Other(format!( - "{:?} {}", - Error::ErrAllRetransmissionsFailed, - tr_key - ))), - ..Default::default() - }) - .await - { - log::debug!("no listener for transaction"); - } - } - return true; - } - - log::trace!( - "retransmitting transaction {} to {} (n_rtx={})", - tr_key, - tr_to, - n_rtx - ); - - let dst = match SocketAddr::from_str(&tr_to) { - Ok(dst) => dst, - Err(_) => return false, - }; - - if conn.send_to(&tr_raw, dst).await.is_err() { - if let Some(tr) = tm.delete(tr_key) { - if !tr - .write_result(TransactionResult { - err: Some(Error::Other(format!( - "{:?} {}", - Error::ErrAllRetransmissionsFailed, - tr_key - ))), - ..Default::default() - }) - .await - { - log::debug!("no listener for transaction"); - } - } - return true; - } - - false -} - -/// `TransactionResult` is a bag of result values of a transaction. 
-#[derive(Debug)] //Clone -pub struct TransactionResult { - pub msg: Message, - pub from: SocketAddr, - pub retries: u16, - pub err: Option, -} - -impl Default for TransactionResult { - fn default() -> Self { - TransactionResult { - msg: Message::default(), - from: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0), - retries: 0, - err: None, - } - } -} - -/// `TransactionConfig` is a set of config params used by [`Transaction::new()`]. -#[derive(Default)] -pub struct TransactionConfig { - pub key: String, - pub raw: Vec, - pub to: String, - pub interval: u16, - pub ignore_result: bool, // true to throw away the result of this transaction (it will not be readable using wait_for_result) -} - -/// `Transaction` represents a transaction. -#[derive(Debug)] -pub struct Transaction { - pub key: String, - pub raw: Vec, - pub to: String, - pub n_rtx: Arc, - pub interval: Arc, - timer_ch_tx: Option>, - result_ch_tx: Option>, - result_ch_rx: Option>, -} - -impl Default for Transaction { - fn default() -> Self { - Transaction { - key: String::new(), - raw: vec![], - to: String::new(), - n_rtx: Arc::new(AtomicU16::new(0)), - interval: Arc::new(AtomicU16::new(0)), - //timer: None, - timer_ch_tx: None, - result_ch_tx: None, - result_ch_rx: None, - } - } -} - -impl Transaction { - /// Creates a new [`Transaction`] using the given `config`. - pub fn new(config: TransactionConfig) -> Self { - let (result_ch_tx, result_ch_rx) = if !config.ignore_result { - let (tx, rx) = mpsc::channel(1); - (Some(tx), Some(rx)) - } else { - (None, None) - }; - - Transaction { - key: config.key, - raw: config.raw, - to: config.to, - interval: Arc::new(AtomicU16::new(config.interval)), - result_ch_tx, - result_ch_rx, - ..Default::default() - } - } - - /// Starts the transaction timer. - pub async fn start_rtx_timer( - &mut self, - conn: Arc, - tr_map: Arc>, - ) { - let (timer_ch_tx, mut timer_ch_rx) = mpsc::channel(1); - self.timer_ch_tx = Some(timer_ch_tx); - let (n_rtx, interval, key) = (self.n_rtx.clone(), self.interval.clone(), self.key.clone()); - - tokio::spawn(async move { - let mut done = false; - while !done { - let timer = tokio::time::sleep(Duration::from_millis( - interval.load(Ordering::SeqCst) as u64, - )); - tokio::pin!(timer); - - tokio::select! { - _ = timer.as_mut() => { - let rtx = n_rtx.fetch_add(1, Ordering::SeqCst); - - let mut val = interval.load(Ordering::SeqCst); - val *= 2; - if val > MAX_RTX_INTERVAL_IN_MS { - val = MAX_RTX_INTERVAL_IN_MS; - } - interval.store(val, Ordering::SeqCst); - - done = on_rtx_timeout(&conn, &tr_map, &key, rtx + 1).await; - } - _ = timer_ch_rx.recv() => done = true, - } - } - }); - } - - /// Stops the transaction timer. - pub fn stop_rtx_timer(&mut self) { - if self.timer_ch_tx.is_some() { - self.timer_ch_tx.take(); - } - } - - /// Writes the result to the result channel. - pub async fn write_result(&self, res: TransactionResult) -> bool { - if let Some(result_ch) = &self.result_ch_tx { - result_ch.send(res).await.is_ok() - } else { - false - } - } - - /// Returns the result channel. - pub fn get_result_channel(&mut self) -> Option> { - self.result_ch_rx.take() - } - - /// Closes the transaction. - pub fn close(&mut self) { - if self.result_ch_tx.is_some() { - self.result_ch_tx.take(); - } - } - - /// Returns the number of retransmission it has made. - pub fn retries(&self) -> u16 { - self.n_rtx.load(Ordering::SeqCst) - } -} - -/// `TransactionMap` is a thread-safe transaction map. 
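The `TransactionMap` just below is keyed by the base64 encoding of the 12-byte STUN transaction ID; this is how `perform_transaction()` and `handle_stun_message()` above match responses to outstanding requests. A minimal sketch of the key derivation, assuming the same `base64` crate used above:

    use base64::prelude::BASE64_STANDARD;
    use base64::Engine;

    // The map key is the standard base64 encoding of the raw transaction ID.
    fn transaction_key(transaction_id: [u8; 12]) -> String {
        BASE64_STANDARD.encode(transaction_id)
    }

    fn main() {
        // 12 bytes encode to 16 base64 characters with no padding.
        assert_eq!(transaction_key([0u8; 12]).len(), 16);
    }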
-#[derive(Default, Debug)] -pub struct TransactionMap { - tr_map: HashMap, -} - -impl TransactionMap { - /// Create a new [`TransactionMap`]. - pub fn new() -> TransactionMap { - TransactionMap { - tr_map: HashMap::new(), - } - } - - /// Inserts a [`Transaction`] to the map. - pub fn insert(&mut self, key: String, tr: Transaction) -> bool { - self.tr_map.insert(key, tr); - true - } - - /// Looks up a [`Transaction`] by its key. - pub fn find(&self, key: &str) -> Option<&Transaction> { - self.tr_map.get(key) - } - - /// Gets the [`Transaction`] associated with the given `key`. - pub fn get(&mut self, key: &str) -> Option<&mut Transaction> { - self.tr_map.get_mut(key) - } - - /// Deletes a [`Transaction`] by its key. - pub fn delete(&mut self, key: &str) -> Option { - self.tr_map.remove(key) - } - - /// Closes and deletes all [`Transaction`]s. - pub fn close_and_delete_all(&mut self) { - for tr in self.tr_map.values_mut() { - tr.close(); - } - self.tr_map.clear(); - } - - /// Returns its length. - pub fn size(&self) -> usize { - self.tr_map.len() - } -} diff --git a/turn/src/error.rs b/turn/src/error.rs deleted file mode 100644 index 415bd51ff..000000000 --- a/turn/src/error.rs +++ /dev/null @@ -1,193 +0,0 @@ -use std::num::ParseIntError; -use std::time::SystemTimeError; -use std::{io, net}; - -use thiserror::Error; - -pub type Result = std::result::Result; - -#[derive(Debug, Error, PartialEq)] -#[non_exhaustive] -pub enum Error { - #[error("turn: RelayAddress must be valid IP to use RelayAddressGeneratorStatic")] - ErrRelayAddressInvalid, - #[error("turn: PacketConnConfigs and ConnConfigs are empty, unable to proceed")] - ErrNoAvailableConns, - #[error("turn: PacketConnConfig must have a non-nil Conn")] - ErrConnUnset, - #[error("turn: ListenerConfig must have a non-nil Listener")] - ErrListenerUnset, - #[error("turn: RelayAddressGenerator has invalid ListeningAddress")] - ErrListeningAddressInvalid, - #[error("turn: RelayAddressGenerator in RelayConfig is unset")] - ErrRelayAddressGeneratorUnset, - #[error("turn: max retries exceeded")] - ErrMaxRetriesExceeded, - #[error("turn: MaxPort must be not 0")] - ErrMaxPortNotZero, - #[error("turn: MaxPort must be not 0")] - ErrMinPortNotZero, - #[error("turn: MaxPort less than MinPort")] - ErrMaxPortLessThanMinPort, - #[error("turn: relay_conn cannot not be nil")] - ErrNilConn, - #[error("turn: TODO")] - ErrTodo, - #[error("turn: already listening")] - ErrAlreadyListening, - #[error("turn: Server failed to close")] - ErrFailedToClose, - #[error("turn: failed to retransmit transaction")] - ErrFailedToRetransmitTransaction, - #[error("all retransmissions failed")] - ErrAllRetransmissionsFailed, - #[error("no binding found for channel")] - ErrChannelBindNotFound, - #[error("STUN server address is not set for the client")] - ErrStunserverAddressNotSet, - #[error("only one Allocate() caller is allowed")] - ErrOneAllocateOnly, - #[error("already allocated")] - ErrAlreadyAllocated, - #[error("non-STUN message from STUN server")] - ErrNonStunmessage, - #[error("failed to decode STUN message")] - ErrFailedToDecodeStun, - #[error("unexpected STUN request message")] - ErrUnexpectedStunrequestMessage, - #[error("channel number not in [0x4000, 0x7FFF]")] - ErrInvalidChannelNumber, - #[error("channelData length != len(Data)")] - ErrBadChannelDataLength, - #[error("unexpected EOF")] - ErrUnexpectedEof, - #[error("invalid value for requested family attribute")] - ErrInvalidRequestedFamilyValue, - #[error("error code 443: peer address family mismatch")] - 
ErrPeerAddressFamilyMismatch, - #[error("fake error")] - ErrFakeErr, - #[error("try again")] - ErrTryAgain, - #[error("use of closed network connection")] - ErrClosed, - #[error("addr is not a net.UDPAddr")] - ErrUdpaddrCast, - #[error("already closed")] - ErrAlreadyClosed, - #[error("try-lock is already locked")] - ErrDoubleLock, - #[error("transaction closed")] - ErrTransactionClosed, - #[error("wait_for_result called on non-result transaction")] - ErrWaitForResultOnNonResultTransaction, - #[error("failed to build refresh request")] - ErrFailedToBuildRefreshRequest, - #[error("failed to refresh allocation")] - ErrFailedToRefreshAllocation, - #[error("failed to get lifetime from refresh response")] - ErrFailedToGetLifetime, - #[error("too short buffer")] - ErrShortBuffer, - #[error("unexpected response type")] - ErrUnexpectedResponse, - #[error("AllocatePacketConn must be set")] - ErrAllocatePacketConnMustBeSet, - #[error("AllocateConn must be set")] - ErrAllocateConnMustBeSet, - #[error("LeveledLogger must be set")] - ErrLeveledLoggerMustBeSet, - #[error("you cannot use the same channel number with different peer")] - ErrSameChannelDifferentPeer, - #[error("allocations must not be created with nil FivTuple")] - ErrNilFiveTuple, - #[error("allocations must not be created with nil FiveTuple.src_addr")] - ErrNilFiveTupleSrcAddr, - #[error("allocations must not be created with nil FiveTuple.dst_addr")] - ErrNilFiveTupleDstAddr, - #[error("allocations must not be created with nil turnSocket")] - ErrNilTurnSocket, - #[error("allocations must not be created with a lifetime of 0")] - ErrLifetimeZero, - #[error("allocation attempt created with duplicate FiveTuple")] - ErrDupeFiveTuple, - #[error("failed to cast net.Addr to *net.UDPAddr")] - ErrFailedToCastUdpaddr, - #[error("failed to generate nonce")] - ErrFailedToGenerateNonce, - #[error("failed to send error message")] - ErrFailedToSendError, - #[error("duplicated Nonce generated, discarding request")] - ErrDuplicatedNonce, - #[error("no such user exists")] - ErrNoSuchUser, - #[error("unexpected class")] - ErrUnexpectedClass, - #[error("unexpected method")] - ErrUnexpectedMethod, - #[error("failed to handle")] - ErrFailedToHandle, - #[error("unhandled STUN packet")] - ErrUnhandledStunpacket, - #[error("unable to handle ChannelData")] - ErrUnableToHandleChannelData, - #[error("failed to create stun message from packet")] - ErrFailedToCreateStunpacket, - #[error("failed to create channel data from packet")] - ErrFailedToCreateChannelData, - #[error("relay already allocated for 5-TUPLE")] - ErrRelayAlreadyAllocatedForFiveTuple, - #[error("RequestedTransport must be UDP")] - ErrRequestedTransportMustBeUdp, - #[error("no support for DONT-FRAGMENT")] - ErrNoDontFragmentSupport, - #[error("Request must not contain RESERVATION-TOKEN and EVEN-PORT")] - ErrRequestWithReservationTokenAndEvenPort, - #[error("Request must not contain RESERVATION-TOKEN and REQUESTED-ADDRESS-FAMILY")] - ErrRequestWithReservationTokenAndReqAddressFamily, - #[error("no allocation found")] - ErrNoAllocationFound, - #[error("unable to handle send-indication, no permission added")] - ErrNoPermission, - #[error("packet write smaller than packet")] - ErrShortWrite, - #[error("no such channel bind")] - ErrNoSuchChannelBind, - #[error("failed writing to socket")] - ErrFailedWriteSocket, - #[error("parse int: {0}")] - ParseInt(#[from] ParseIntError), - #[error("parse addr: {0}")] - ParseIp(#[from] net::AddrParseError), - #[error("{0}")] - Io(#[source] IoError), - #[error("{0}")] - 
Util(#[from] util::Error), - #[error("{0}")] - Stun(#[from] stun::Error), - #[error("{0}")] - Other(String), -} - -#[derive(Debug, Error)] -#[error("io error: {0}")] -pub struct IoError(#[from] pub io::Error); - -// Workaround for wanting PartialEq for io::Error. -impl PartialEq for IoError { - fn eq(&self, other: &Self) -> bool { - self.0.kind() == other.0.kind() - } -} - -impl From for Error { - fn from(e: io::Error) -> Self { - Error::Io(IoError(e)) - } -} - -impl From for Error { - fn from(e: SystemTimeError) -> Self { - Error::Other(e.to_string()) - } -} diff --git a/turn/src/lib.rs b/turn/src/lib.rs deleted file mode 100644 index 2436eeee0..000000000 --- a/turn/src/lib.rs +++ /dev/null @@ -1,13 +0,0 @@ -#![warn(rust_2018_idioms)] -#![allow(dead_code)] -#![recursion_limit = "256"] - -pub mod allocation; -pub mod auth; -pub mod client; -mod error; -pub mod proto; -pub mod relay; -pub mod server; - -pub use error::Error; diff --git a/turn/src/proto/addr.rs b/turn/src/proto/addr.rs deleted file mode 100644 index 70cc4f45f..000000000 --- a/turn/src/proto/addr.rs +++ /dev/null @@ -1,62 +0,0 @@ -#[cfg(test)] -mod addr_test; - -use std::net::{IpAddr, Ipv4Addr, SocketAddr}; - -use super::*; - -/// `Addr` is `ip:port`. -#[derive(PartialEq, Eq, Debug)] -pub struct Addr { - ip: IpAddr, - port: u16, -} - -impl Default for Addr { - fn default() -> Self { - Addr { - ip: IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), - port: 0, - } - } -} - -impl fmt::Display for Addr { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}:{}", self.ip, self.port) - } -} - -impl Addr { - /// Returns this network. - pub fn network(&self) -> String { - "turn".to_owned() - } - - /// Creates a new [`Addr`] from `n`. - pub fn from_socket_addr(n: &SocketAddr) -> Self { - let ip = n.ip(); - let port = n.port(); - - Addr { ip, port } - } - - /// Returns `true` if the `other` has the same IP address. - pub fn equal_ip(&self, other: &Addr) -> bool { - self.ip == other.ip - } -} - -// FiveTuple represents 5-TUPLE value. 
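// Illustrative aside (not part of the original diff): a FiveTuple (defined just
// below) is normally built from the client and server socket addresses with
// Addr::from_socket_addr. The addresses and ports here are arbitrary example
// values; PROTO_UDP comes from proto/mod.rs via the `use super::*` above.
fn example_five_tuple() -> FiveTuple {
    let client: SocketAddr = "192.0.2.10:5000".parse().unwrap();
    let server: SocketAddr = "192.0.2.20:3478".parse().unwrap();
    FiveTuple {
        client: Addr::from_socket_addr(&client),
        server: Addr::from_socket_addr(&server),
        proto: PROTO_UDP,
    }
}
// Its Display impl renders this as "192.0.2.10:5000->192.0.2.20:3478 (UDP)".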
-#[derive(PartialEq, Eq, Default)] -pub struct FiveTuple { - pub client: Addr, - pub server: Addr, - pub proto: Protocol, -} - -impl fmt::Display for FiveTuple { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}->{} ({})", self.client, self.server, self.proto) - } -} diff --git a/turn/src/proto/addr/addr_test.rs b/turn/src/proto/addr/addr_test.rs deleted file mode 100644 index ba05a9b12..000000000 --- a/turn/src/proto/addr/addr_test.rs +++ /dev/null @@ -1,104 +0,0 @@ -use std::net::Ipv4Addr; - -use super::*; -use crate::error::Result; - -#[test] -fn test_addr_from_socket_addr() -> Result<()> { - let u = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 1234); - - let a = Addr::from_socket_addr(&u); - assert!( - u.ip() == a.ip || u.port() != a.port || u.to_string() != a.to_string(), - "not equal" - ); - assert_eq!(a.network(), "turn", "unexpected network"); - - Ok(()) -} - -#[test] -fn test_addr_equal_ip() -> Result<()> { - let a = Addr { - ip: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), - port: 1337, - }; - let b = Addr { - ip: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), - port: 1338, - }; - assert_ne!(a, b, "a != b"); - assert!(a.equal_ip(&b), "a.IP should equal to b.IP"); - - Ok(()) -} - -#[test] -fn test_five_tuple_equal() -> Result<()> { - let tests = vec![ - ("blank", FiveTuple::default(), FiveTuple::default(), true), - ( - "proto", - FiveTuple { - proto: PROTO_UDP, - ..Default::default() - }, - FiveTuple::default(), - false, - ), - ( - "server", - FiveTuple { - server: Addr { - port: 100, - ..Default::default() - }, - ..Default::default() - }, - FiveTuple::default(), - false, - ), - ( - "client", - FiveTuple { - client: Addr { - port: 100, - ..Default::default() - }, - ..Default::default() - }, - FiveTuple::default(), - false, - ), - ]; - - for (name, a, b, r) in tests { - let v = a == b; - assert_eq!(v, r, "({name}) {a} [{v}!={r}] {b}"); - } - - Ok(()) -} - -#[test] -fn test_five_tuple_string() -> Result<()> { - let s = FiveTuple { - proto: PROTO_UDP, - server: Addr { - port: 100, - ip: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), - }, - client: Addr { - port: 200, - ip: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), - }, - } - .to_string(); - - assert_eq!( - s, "127.0.0.1:200->127.0.0.1:100 (UDP)", - "unexpected stringer output" - ); - - Ok(()) -} diff --git a/turn/src/proto/chandata.rs b/turn/src/proto/chandata.rs deleted file mode 100644 index 58b11549b..000000000 --- a/turn/src/proto/chandata.rs +++ /dev/null @@ -1,110 +0,0 @@ -#[cfg(test)] -mod chandata_test; - -use super::channum::*; -use crate::error::*; - -const PADDING: usize = 4; - -fn nearest_padded_value_length(l: usize) -> usize { - let mut n = PADDING * (l / PADDING); - if n < l { - n += PADDING; - } - n -} - -const CHANNEL_DATA_LENGTH_SIZE: usize = 2; -const CHANNEL_DATA_NUMBER_SIZE: usize = CHANNEL_DATA_LENGTH_SIZE; -const CHANNEL_DATA_HEADER_SIZE: usize = CHANNEL_DATA_LENGTH_SIZE + CHANNEL_DATA_NUMBER_SIZE; - -/// `ChannelData` represents the `ChannelData` Message defined in -/// [RFC 5766 Section 11.4](https://www.rfc-editor.org/rfc/rfc5766#section-11.4). -#[derive(Default, Debug)] -pub struct ChannelData { - pub data: Vec, // can be subslice of Raw - pub number: ChannelNumber, - pub raw: Vec, -} - -impl PartialEq for ChannelData { - fn eq(&self, other: &Self) -> bool { - self.data == other.data && self.number == other.number - } -} - -impl ChannelData { - /// Resets length, [`Self::data`] and [`Self::raw`] length. 
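// Illustrative aside (not part of the original diff): the ChannelData wire
// format above is a 2-byte channel number and a 2-byte payload length (both
// big-endian), the payload itself, then zero padding up to the next 4-byte
// boundary. Channel numbers must fall in 0x4000..=0x7FFF (see channum.rs later
// in this diff). A worked example of the header and padding math, using
// encode() defined just below:
fn channel_data_padding_example() {
    // 4-byte header + 5-byte payload = 9 bytes, padded up to 12 on the wire.
    assert_eq!(nearest_padded_value_length(9), 12);
    assert_eq!(nearest_padded_value_length(8), 8); // already aligned

    let mut cd = ChannelData {
        number: ChannelNumber(0x4000), // lowest valid channel number
        data: vec![1, 2, 3, 4, 5],
        ..Default::default()
    };
    cd.encode();
    assert_eq!(cd.raw.len(), 12); // header + 5 data bytes + 3 padding bytes
    assert_eq!(&cd.raw[..2], &0x4000u16.to_be_bytes()[..]); // channel number
    assert_eq!(&cd.raw[2..4], &5u16.to_be_bytes()[..]); // unpadded data length
}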
- #[inline] - pub fn reset(&mut self) { - self.raw.clear(); - self.data.clear(); - } - - /// Encodes this to [`Self::raw`]. - pub fn encode(&mut self) { - self.raw.clear(); - self.write_header(); - self.raw.extend_from_slice(&self.data); - let padded = nearest_padded_value_length(self.raw.len()); - let bytes_to_add = padded - self.raw.len(); - if bytes_to_add > 0 { - self.raw.extend_from_slice(&vec![0; bytes_to_add]); - } - } - - /// Decodes this from [`Self::raw`]. - pub fn decode(&mut self) -> Result<()> { - let buf = &self.raw; - if buf.len() < CHANNEL_DATA_HEADER_SIZE { - return Err(Error::ErrUnexpectedEof); - } - let num = u16::from_be_bytes([buf[0], buf[1]]); - self.number = ChannelNumber(num); - if !self.number.valid() { - return Err(Error::ErrInvalidChannelNumber); - } - let l = u16::from_be_bytes([ - buf[CHANNEL_DATA_NUMBER_SIZE], - buf[CHANNEL_DATA_NUMBER_SIZE + 1], - ]) as usize; - if l > buf[CHANNEL_DATA_HEADER_SIZE..].len() { - return Err(Error::ErrBadChannelDataLength); - } - self.data = buf[CHANNEL_DATA_HEADER_SIZE..CHANNEL_DATA_HEADER_SIZE + l].to_vec(); - - Ok(()) - } - - /// Writes channel number and length. - pub fn write_header(&mut self) { - if self.raw.len() < CHANNEL_DATA_HEADER_SIZE { - // Making WriteHeader call valid even when c.Raw - // is nil or len(c.Raw) is less than needed for header. - self.raw - .resize(self.raw.len() + CHANNEL_DATA_HEADER_SIZE, 0); - } - self.raw[..CHANNEL_DATA_NUMBER_SIZE].copy_from_slice(&self.number.0.to_be_bytes()); - self.raw[CHANNEL_DATA_NUMBER_SIZE..CHANNEL_DATA_HEADER_SIZE] - .copy_from_slice(&(self.data.len() as u16).to_be_bytes()); - } - - /// Returns `true` if `buf` looks like the `ChannelData` Message. - pub fn is_channel_data(buf: &[u8]) -> bool { - if buf.len() < CHANNEL_DATA_HEADER_SIZE { - return false; - } - - if u16::from_be_bytes([ - buf[CHANNEL_DATA_NUMBER_SIZE], - buf[CHANNEL_DATA_NUMBER_SIZE + 1], - ]) > buf[CHANNEL_DATA_HEADER_SIZE..].len() as u16 - { - return false; - } - - // Quick check for channel number. 
- let num = ChannelNumber(u16::from_be_bytes([buf[0], buf[1]])); - num.valid() - } -} diff --git a/turn/src/proto/chandata/chandata_test.rs b/turn/src/proto/chandata/chandata_test.rs deleted file mode 100644 index 9376f6f4d..000000000 --- a/turn/src/proto/chandata/chandata_test.rs +++ /dev/null @@ -1,211 +0,0 @@ -use super::*; - -#[test] -fn test_channel_data_encode() -> Result<()> { - let mut d = ChannelData { - data: vec![1, 2, 3, 4], - number: ChannelNumber(MIN_CHANNEL_NUMBER + 1), - ..Default::default() - }; - d.encode(); - - let mut b = ChannelData::default(); - b.raw.extend_from_slice(&d.raw); - b.decode()?; - - assert_eq!(b, d, "not equal"); - - assert!( - ChannelData::is_channel_data(&b.raw) && ChannelData::is_channel_data(&d.raw), - "unexpected IsChannelData" - ); - - Ok(()) -} - -#[test] -fn test_channel_data_equal() -> Result<()> { - let tests = vec![ - ( - "equal", - ChannelData { - number: ChannelNumber(MIN_CHANNEL_NUMBER), - data: vec![1, 2, 3], - ..Default::default() - }, - ChannelData { - number: ChannelNumber(MIN_CHANNEL_NUMBER), - data: vec![1, 2, 3], - ..Default::default() - }, - true, - ), - ( - "number", - ChannelData { - number: ChannelNumber(MIN_CHANNEL_NUMBER + 1), - data: vec![1, 2, 3], - ..Default::default() - }, - ChannelData { - number: ChannelNumber(MIN_CHANNEL_NUMBER), - data: vec![1, 2, 3], - ..Default::default() - }, - false, - ), - ( - "length", - ChannelData { - number: ChannelNumber(MIN_CHANNEL_NUMBER), - data: vec![1, 2, 3, 4], - ..Default::default() - }, - ChannelData { - number: ChannelNumber(MIN_CHANNEL_NUMBER), - data: vec![1, 2, 3], - ..Default::default() - }, - false, - ), - ( - "data", - ChannelData { - number: ChannelNumber(MIN_CHANNEL_NUMBER), - data: vec![1, 2, 2], - ..Default::default() - }, - ChannelData { - number: ChannelNumber(MIN_CHANNEL_NUMBER), - data: vec![1, 2, 3], - ..Default::default() - }, - false, - ), - ]; - - for (name, a, b, r) in tests { - let v = a == b; - assert_eq!(v, r, "unexpected: ({name}) {r} != {r}"); - } - - Ok(()) -} - -#[test] -fn test_channel_data_decode() -> Result<()> { - let tests = vec![ - ("small", vec![1, 2, 3], Error::ErrUnexpectedEof), - ( - "zeroes", - vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - Error::ErrInvalidChannelNumber, - ), - ( - "bad chan number", - vec![63, 255, 0, 0, 0, 4, 0, 0, 1, 2, 3, 4], - Error::ErrInvalidChannelNumber, - ), - ( - "bad length", - vec![0x40, 0x40, 0x02, 0x23, 0x16, 0, 0, 0, 0, 0, 0, 0], - Error::ErrBadChannelDataLength, - ), - ]; - - for (name, buf, want_err) in tests { - let mut m = ChannelData { - raw: buf, - ..Default::default() - }; - if let Err(err) = m.decode() { - assert_eq!(want_err, err, "unexpected: ({name}) {want_err} != {err}"); - } else { - panic!("expected error, but got ok"); - } - } - - Ok(()) -} - -#[test] -fn test_channel_data_reset() -> Result<()> { - let mut d = ChannelData { - data: vec![1, 2, 3, 4], - number: ChannelNumber(MIN_CHANNEL_NUMBER + 1), - ..Default::default() - }; - d.encode(); - let mut buf = vec![0; d.raw.len()]; - buf.copy_from_slice(&d.raw); - d.reset(); - d.raw = buf; - d.decode()?; - - Ok(()) -} - -#[test] -fn test_is_channel_data() -> Result<()> { - let tests = vec![ - ("small", vec![1, 2, 3, 4], false), - ("zeroes", vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], false), - ]; - - for (name, buf, r) in tests { - let v = ChannelData::is_channel_data(&buf); - assert_eq!(v, r, "unexpected: ({name}) {r} != {v}"); - } - - Ok(()) -} - -const CHANDATA_TEST_HEX: [&str; 2] = [ - 
"40000064000100502112a442453731722f2b322b6e4e7a5800060009443758343a33776c59000000c0570004000003e7802a00081d5136dab65b169300250000002400046e001eff0008001465d11a330e104a9f5f598af4abc6a805f26003cf802800046b334442", - "4000022316fefd0000000000000011012c0b000120000100000000012000011d00011a308201163081bda003020102020900afe52871340bd13e300a06082a8648ce3d0403023011310f300d06035504030c06576562525443301e170d3138303831313033353230305a170d3138303931313033353230305a3011310f300d06035504030c065765625254433059301306072a8648ce3d020106082a8648ce3d030107034200048080e348bd41469cfb7a7df316676fd72a06211765a50a0f0b07526c872dcf80093ed5caa3f5a40a725dd74b41b79bdd19ee630c5313c8601d6983286c8722c1300a06082a8648ce3d0403020348003045022100d13a0a131bc2a9f27abd3d4c547f7ef172996a0c0755c707b6a3e048d8762ded0220055fc8182818a644a3d3b5b157304cc3f1421fadb06263bfb451cd28be4bc9ee16fefd0000000000000012002d10000021000200000000002120f7e23c97df45a96e13cb3e76b37eff5e73e2aee0b6415d29443d0bd24f578b7e16fefd000000000000001300580f00004c000300000000004c040300483046022100fdbb74eab1aca1532e6ac0ab267d5b83a24bb4d5d7d504936e2785e6e388b2bd022100f6a457b9edd9ead52a9d0e9a19240b3a68b95699546c044f863cf8349bc8046214fefd000000000000001400010116fefd0001000000000004003000010000000000040aae2421e7d549632a7def8ed06898c3c5b53f5b812a963a39ab6cdd303b79bdb237f3314c1da21b", -]; - -#[test] -fn test_chrome_channel_data() -> Result<()> { - let mut data = vec![]; - let mut messages = vec![]; - - // Decoding hex data into binary. - for h in &CHANDATA_TEST_HEX { - let b = match hex::decode(h) { - Ok(b) => b, - Err(_) => return Err(Error::Other("hex decode error".to_owned())), - }; - data.push(b); - } - - // All hex streams decoded to raw binary format and stored in data slice. - // Decoding packets to messages. - for packet in data { - let mut m = ChannelData { - raw: packet, - ..Default::default() - }; - - m.decode()?; - let mut encoded = ChannelData { - data: m.data.clone(), - number: m.number, - ..Default::default() - }; - encoded.encode(); - - let mut decoded = ChannelData { - raw: encoded.raw.clone(), - ..Default::default() - }; - - decoded.decode()?; - assert_eq!(decoded, m, "should be equal"); - - messages.push(m); - } - assert_eq!(messages.len(), 2, "unexpected message slice list"); - - Ok(()) -} diff --git a/turn/src/proto/channum.rs b/turn/src/proto/channum.rs deleted file mode 100644 index 406bbcfb2..000000000 --- a/turn/src/proto/channum.rs +++ /dev/null @@ -1,70 +0,0 @@ -#[cfg(test)] -mod channnum_test; - -use std::fmt; - -use stun::attributes::*; -use stun::checks::*; -use stun::message::*; - -// 16 bits of uint + 16 bits of RFFU = 0. -const CHANNEL_NUMBER_SIZE: usize = 4; - -// See https://tools.ietf.org/html/rfc5766#section-11: -// -// 0x4000 through 0x7FFF: These values are the allowed channel -// numbers (16,383 possible values). -pub const MIN_CHANNEL_NUMBER: u16 = 0x4000; -pub const MAX_CHANNEL_NUMBER: u16 = 0x7FFF; - -/// `ChannelNumber` represents `CHANNEL-NUMBER` attribute. Encoded as `u16`. -/// -/// The `CHANNEL-NUMBER` attribute contains the number of the channel. -/// -/// [RFC 5766 Section 14.1](https://www.rfc-editor.org/rfc/rfc5766#section-14.1). -#[derive(Default, Eq, PartialEq, Debug, Copy, Clone, Hash)] -pub struct ChannelNumber(pub u16); - -impl fmt::Display for ChannelNumber { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.0) - } -} - -impl Setter for ChannelNumber { - /// Adds `CHANNEL-NUMBER` to message. 
- fn add_to(&self, m: &mut Message) -> Result<(), stun::Error> { - let mut v = vec![0; CHANNEL_NUMBER_SIZE]; - v[..2].copy_from_slice(&self.0.to_be_bytes()); - // v[2:4] are zeroes (RFFU = 0) - m.add(ATTR_CHANNEL_NUMBER, &v); - Ok(()) - } -} - -impl Getter for ChannelNumber { - /// Decodes `CHANNEL-NUMBER` from message. - fn get_from(&mut self, m: &Message) -> Result<(), stun::Error> { - let v = m.get(ATTR_CHANNEL_NUMBER)?; - - check_size(ATTR_CHANNEL_NUMBER, v.len(), CHANNEL_NUMBER_SIZE)?; - - //_ = v[CHANNEL_NUMBER_SIZE-1] // asserting length - self.0 = u16::from_be_bytes([v[0], v[1]]); - // v[2:4] is RFFU and equals to 0. - Ok(()) - } -} - -impl ChannelNumber { - /// Returns true if c in `[0x4000, 0x7FFF]`. - fn is_channel_number_valid(&self) -> bool { - self.0 >= MIN_CHANNEL_NUMBER && self.0 <= MAX_CHANNEL_NUMBER - } - - /// returns `true` if channel number has correct value that complies - /// [RFC 5766 Section 11](https://www.rfc-editor.org/rfc/rfc5766#section-11) range. - pub fn valid(&self) -> bool { - self.is_channel_number_valid() - } -} diff --git a/turn/src/proto/channum/channnum_test.rs b/turn/src/proto/channum/channnum_test.rs deleted file mode 100644 index f05756802..000000000 --- a/turn/src/proto/channum/channnum_test.rs +++ /dev/null @@ -1,81 +0,0 @@ -use super::*; - -#[test] -fn test_channel_number_string() -> Result<(), stun::Error> { - let n = ChannelNumber(112); - assert_eq!(n.to_string(), "112", "bad string {n}, expected 112"); - Ok(()) -} - -/* -#[test] -fn test_channel_number_NoAlloc() -> Result<(), stun::Error> { - let mut m = Message::default(); - - if wasAllocs(func() { - // Case with ChannelNumber on stack. - n: = ChannelNumber(6) - n.AddTo(m) //nolint - m.Reset() - }) { - t.Error("Unexpected allocations") - } - - n: = ChannelNumber(12) - nP: = &n - if wasAllocs(func() { - // On heap. - nP.AddTo(m) //nolint - m.Reset() - }) { - t.Error("Unexpected allocations") - } - Ok(()) -} -*/ - -#[test] -fn test_channel_number_add_to() -> Result<(), stun::Error> { - let mut m = Message::new(); - let n = ChannelNumber(6); - n.add_to(&mut m)?; - m.write_header(); - - //"GetFrom" - { - let mut decoded = Message::new(); - decoded.write(&m.raw)?; - - let mut num_decoded = ChannelNumber::default(); - num_decoded.get_from(&decoded)?; - assert_eq!(num_decoded, n, "Decoded {num_decoded}, expected {n}"); - - //"HandleErr" - { - let mut m = Message::new(); - let mut n_handle = ChannelNumber::default(); - if let Err(err) = n_handle.get_from(&m) { - assert_eq!( - stun::Error::ErrAttributeNotFound, - err, - "{err} should be not found" - ); - } else { - panic!("expected error, but got ok"); - } - - m.add(ATTR_CHANNEL_NUMBER, &[1, 2, 3]); - - if let Err(err) = n_handle.get_from(&m) { - assert!( - is_attr_size_invalid(&err), - "IsAttrSizeInvalid should be true" - ); - } else { - panic!("expected error, but got ok"); - } - } - } - - Ok(()) -} diff --git a/turn/src/proto/data.rs b/turn/src/proto/data.rs deleted file mode 100644 index 03fdf9b4d..000000000 --- a/turn/src/proto/data.rs +++ /dev/null @@ -1,33 +0,0 @@ -#[cfg(test)] -mod data_test; - -use stun::attributes::*; -use stun::message::*; - -/// `Data` represents `DATA` attribute. -/// -/// The `DATA` attribute is present in all Send and Data indications. The -/// value portion of this attribute is variable length and consists of -/// the application data (that is, the data that would immediately follow -/// the UDP header if the data was been sent directly between the client -/// and the peer). 
-/// -/// [RFC 5766 Section 14.4](https://www.rfc-editor.org/rfc/rfc5766#section-14.4). -#[derive(Default, Debug, PartialEq, Eq)] -pub struct Data(pub Vec); - -impl Setter for Data { - /// Adds `DATA` to message. - fn add_to(&self, m: &mut Message) -> Result<(), stun::Error> { - m.add(ATTR_DATA, &self.0); - Ok(()) - } -} - -impl Getter for Data { - /// Decodes `DATA` from message. - fn get_from(&mut self, m: &Message) -> Result<(), stun::Error> { - self.0 = m.get(ATTR_DATA)?; - Ok(()) - } -} diff --git a/turn/src/proto/data/data_test.rs b/turn/src/proto/data/data_test.rs deleted file mode 100644 index 813047af2..000000000 --- a/turn/src/proto/data/data_test.rs +++ /dev/null @@ -1,33 +0,0 @@ -use super::*; - -#[test] -fn test_data_add_to() -> Result<(), stun::Error> { - let mut m = Message::new(); - let d = Data(vec![1, 2, 33, 44, 0x13, 0xaf]); - d.add_to(&mut m)?; - m.write_header(); - - //"GetFrom" - { - let mut decoded = Message::new(); - decoded.write(&m.raw)?; - - let mut data_decoded = Data::default(); - data_decoded.get_from(&decoded)?; - assert_eq!(data_decoded, d); - - //"HandleErr" - { - let m = Message::new(); - let mut handle = Data::default(); - if let Err(err) = handle.get_from(&m) { - assert_eq!( - stun::Error::ErrAttributeNotFound, - err, - "{err} should be not found" - ); - } - } - } - Ok(()) -} diff --git a/turn/src/proto/dontfrag.rs b/turn/src/proto/dontfrag.rs deleted file mode 100644 index 621586ff2..000000000 --- a/turn/src/proto/dontfrag.rs +++ /dev/null @@ -1,25 +0,0 @@ -#[cfg(test)] -mod dontfrag_test; - -use stun::attributes::*; -use stun::message::*; - -/// `DontFragmentAttr` represents `DONT-FRAGMENT` attribute. -#[derive(Debug, Default, PartialEq, Eq)] -pub struct DontFragmentAttr; - -impl Setter for DontFragmentAttr { - /// Adds `DONT-FRAGMENT` attribute to message. - fn add_to(&self, m: &mut Message) -> Result<(), stun::Error> { - m.add(ATTR_DONT_FRAGMENT, &[]); - Ok(()) - } -} - -impl Getter for DontFragmentAttr { - /// Returns true if `DONT-FRAGMENT` attribute is set. - fn get_from(&mut self, m: &Message) -> Result<(), stun::Error> { - let _ = m.get(ATTR_DONT_FRAGMENT)?; - Ok(()) - } -} diff --git a/turn/src/proto/dontfrag/dontfrag_test.rs b/turn/src/proto/dontfrag/dontfrag_test.rs deleted file mode 100644 index a981cd778..000000000 --- a/turn/src/proto/dontfrag/dontfrag_test.rs +++ /dev/null @@ -1,27 +0,0 @@ -use super::*; - -#[test] -fn test_dont_fragment_false() -> Result<(), stun::Error> { - let mut dont_fragment = DontFragmentAttr; - - let mut m = Message::new(); - m.write_header(); - assert!(dont_fragment.get_from(&m).is_err(), "should not be set"); - - Ok(()) -} - -#[test] -fn test_dont_fragment_add_to() -> Result<(), stun::Error> { - let mut dont_fragment = DontFragmentAttr; - - let mut m = Message::new(); - dont_fragment.add_to(&mut m)?; - m.write_header(); - - let mut decoded = Message::new(); - decoded.write(&m.raw)?; - assert!(dont_fragment.get_from(&m).is_ok(), "should be set"); - - Ok(()) -} diff --git a/turn/src/proto/evenport.rs b/turn/src/proto/evenport.rs deleted file mode 100644 index fe364d18a..000000000 --- a/turn/src/proto/evenport.rs +++ /dev/null @@ -1,63 +0,0 @@ -#[cfg(test)] -mod evenport_test; - -use std::fmt; - -use stun::attributes::*; -use stun::checks::*; -use stun::message::*; - -/// `EvenPort` represents `EVEN-PORT` attribute. -/// -/// This attribute allows the client to request that the port in the -/// relayed transport address be even, and (optionally) that the server -/// reserve the next-higher port number. 
-/// -/// [RFC 5766 Section 14.6](https://www.rfc-editor.org/rfc/rfc5766#section-14.6). -#[derive(Default, Debug, PartialEq, Eq)] -pub struct EvenPort { - /// `reserve_port` means that the server is requested to reserve - /// the next-higher port number (on the same IP address) - /// for a subsequent allocation. - reserve_port: bool, -} - -impl fmt::Display for EvenPort { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if self.reserve_port { - write!(f, "reserve: true") - } else { - write!(f, "reserve: false") - } - } -} - -const EVEN_PORT_SIZE: usize = 1; -const FIRST_BIT_SET: u8 = 0b10000000; //FIXME? (1 << 8) - 1; - -impl Setter for EvenPort { - /// Adds `EVEN-PORT` to message. - fn add_to(&self, m: &mut Message) -> Result<(), stun::Error> { - let mut v = vec![0; EVEN_PORT_SIZE]; - if self.reserve_port { - // Set first bit to 1. - v[0] = FIRST_BIT_SET; - } - m.add(ATTR_EVEN_PORT, &v); - Ok(()) - } -} - -impl Getter for EvenPort { - /// Decodes `EVEN-PORT` from message. - fn get_from(&mut self, m: &Message) -> Result<(), stun::Error> { - let v = m.get(ATTR_EVEN_PORT)?; - - check_size(ATTR_EVEN_PORT, v.len(), EVEN_PORT_SIZE)?; - - if v[0] & FIRST_BIT_SET > 0 { - self.reserve_port = true; - } - Ok(()) - } -} diff --git a/turn/src/proto/evenport/evenport_test.rs b/turn/src/proto/evenport/evenport_test.rs deleted file mode 100644 index 57cc8bd9f..000000000 --- a/turn/src/proto/evenport/evenport_test.rs +++ /dev/null @@ -1,78 +0,0 @@ -use super::*; - -#[test] -fn test_even_port_string() -> Result<(), stun::Error> { - let mut p = EvenPort::default(); - assert_eq!( - p.to_string(), - "reserve: false", - "bad value {p} for reselve: false" - ); - - p.reserve_port = true; - assert_eq!( - p.to_string(), - "reserve: true", - "bad value {p} for reselve: true" - ); - - Ok(()) -} - -#[test] -fn test_even_port_false() -> Result<(), stun::Error> { - let mut m = Message::new(); - let p = EvenPort { - reserve_port: false, - }; - p.add_to(&mut m)?; - m.write_header(); - - let mut decoded = Message::new(); - let mut port = EvenPort::default(); - decoded.write(&m.raw)?; - port.get_from(&m)?; - assert_eq!(port, p); - - Ok(()) -} - -#[test] -fn test_even_port_add_to() -> Result<(), stun::Error> { - let mut m = Message::new(); - let p = EvenPort { reserve_port: true }; - p.add_to(&mut m)?; - m.write_header(); - //"GetFrom" - { - let mut decoded = Message::new(); - decoded.write(&m.raw)?; - let mut port = EvenPort::default(); - port.get_from(&decoded)?; - assert_eq!(port, p, "Decoded {port}, expected {p}"); - - //"HandleErr" - { - let mut m = Message::new(); - let mut handle = EvenPort::default(); - if let Err(err) = handle.get_from(&m) { - assert_eq!( - stun::Error::ErrAttributeNotFound, - err, - "{err} should be not found" - ); - } - m.add(ATTR_EVEN_PORT, &[1, 2, 3]); - if let Err(err) = handle.get_from(&m) { - assert!( - is_attr_size_invalid(&err), - "IsAttrSizeInvalid should be true" - ); - } else { - panic!("expected error, but got ok"); - } - } - } - - Ok(()) -} diff --git a/turn/src/proto/lifetime.rs b/turn/src/proto/lifetime.rs deleted file mode 100644 index 449bc277e..000000000 --- a/turn/src/proto/lifetime.rs +++ /dev/null @@ -1,59 +0,0 @@ -#[cfg(test)] -mod lifetime_test; - -use std::fmt; -use std::time::Duration; - -use stun::attributes::*; -use stun::checks::*; -use stun::message::*; - -/// `DEFAULT_LIFETIME` in RFC 5766 is 10 minutes. -/// -/// [RFC 5766 Section 2.2](https://www.rfc-editor.org/rfc/rfc5766#section-2.2). 
-pub const DEFAULT_LIFETIME: Duration = Duration::from_secs(10 * 60); - -/// `Lifetime` represents `LIFETIME` attribute. -/// -/// The `LIFETIME` attribute represents the duration for which the server -/// will maintain an allocation in the absence of a refresh. The value -/// portion of this attribute is 4-bytes long and consists of a 32-bit -/// unsigned integral value representing the number of seconds remaining -/// until expiration. -/// -/// [RFC 5766 Section 14.2](https://www.rfc-editor.org/rfc/rfc5766#section-14.2). -#[derive(Default, Debug, PartialEq, Eq)] -pub struct Lifetime(pub Duration); - -impl fmt::Display for Lifetime { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}s", self.0.as_secs()) - } -} - -// uint32 seconds -const LIFETIME_SIZE: usize = 4; // 4 bytes, 32 bits - -impl Setter for Lifetime { - /// Adds `LIFETIME` to message. - fn add_to(&self, m: &mut Message) -> Result<(), stun::Error> { - let mut v = vec![0; LIFETIME_SIZE]; - v.copy_from_slice(&(self.0.as_secs() as u32).to_be_bytes()); - m.add(ATTR_LIFETIME, &v); - Ok(()) - } -} - -impl Getter for Lifetime { - /// Decodes `LIFETIME` from message. - fn get_from(&mut self, m: &Message) -> Result<(), stun::Error> { - let v = m.get(ATTR_LIFETIME)?; - - check_size(ATTR_LIFETIME, v.len(), LIFETIME_SIZE)?; - - let seconds = u32::from_be_bytes([v[0], v[1], v[2], v[3]]); - self.0 = Duration::from_secs(seconds as u64); - - Ok(()) - } -} diff --git a/turn/src/proto/lifetime/lifetime_test.rs b/turn/src/proto/lifetime/lifetime_test.rs deleted file mode 100644 index 2ec54ac31..000000000 --- a/turn/src/proto/lifetime/lifetime_test.rs +++ /dev/null @@ -1,54 +0,0 @@ -use super::*; - -#[test] -fn test_lifetime_string() -> Result<(), stun::Error> { - let l = Lifetime(Duration::from_secs(10)); - assert_eq!(l.to_string(), "10s", "bad string {l}, expected 10s"); - - Ok(()) -} - -#[test] -fn test_lifetime_add_to() -> Result<(), stun::Error> { - let mut m = Message::new(); - let l = Lifetime(Duration::from_secs(10)); - l.add_to(&mut m)?; - m.write_header(); - - //"GetFrom" - { - let mut decoded = Message::new(); - decoded.write(&m.raw)?; - - let mut life = Lifetime::default(); - life.get_from(&decoded)?; - assert_eq!(life, l, "Decoded {life}, expected {l}"); - - //"HandleErr" - { - let mut m = Message::new(); - let mut n_handle = Lifetime::default(); - if let Err(err) = n_handle.get_from(&m) { - assert_eq!( - stun::Error::ErrAttributeNotFound, - err, - "{err} should be not found" - ); - } else { - panic!("expected error, but got ok"); - } - m.add(ATTR_LIFETIME, &[1, 2, 3]); - - if let Err(err) = n_handle.get_from(&m) { - assert!( - is_attr_size_invalid(&err), - "IsAttrSizeInvalid should be true" - ); - } else { - panic!("expected error, but got ok"); - } - } - } - - Ok(()) -} diff --git a/turn/src/proto/mod.rs b/turn/src/proto/mod.rs deleted file mode 100644 index ada03ffd8..000000000 --- a/turn/src/proto/mod.rs +++ /dev/null @@ -1,70 +0,0 @@ -#[cfg(test)] -mod proto_test; - -pub mod addr; -pub mod chandata; -pub mod channum; -pub mod data; -pub mod dontfrag; -pub mod evenport; -pub mod lifetime; -pub mod peeraddr; -pub mod relayaddr; -pub mod reqfamily; -pub mod reqtrans; -pub mod rsrvtoken; - -use std::fmt; - -use stun::message::*; - -// proto implements RFC 5766 Traversal Using Relays around NAT. - -/// `Protocol` is IANA assigned protocol number. -#[derive(PartialEq, Eq, Default, Debug, Clone, Copy, Hash)] -pub struct Protocol(pub u8); - -/// `PROTO_TCP` is IANA assigned protocol number for TCP. 
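// Illustrative aside (not part of the original diff): the LIFETIME attribute
// (lifetime.rs above) is written as a 4-byte big-endian count of whole
// seconds, so sub-second precision is dropped on encode. Two quick worked
// values:
fn lifetime_wire_examples() {
    use std::time::Duration;
    // DEFAULT_LIFETIME is 10 minutes = 600 seconds = 0x00000258 on the wire.
    assert_eq!(600u32.to_be_bytes(), [0x00, 0x00, 0x02, 0x58]);
    // A requested lifetime of 1.5 s encodes as just 1 second.
    assert_eq!(Duration::from_millis(1500).as_secs() as u32, 1);
}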
-pub const PROTO_TCP: Protocol = Protocol(6); -/// `PROTO_UDP` is IANA assigned protocol number for UDP. -pub const PROTO_UDP: Protocol = Protocol(17); - -impl fmt::Display for Protocol { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let others = format!("{}", self.0); - let s = match *self { - PROTO_UDP => "UDP", - PROTO_TCP => "TCP", - _ => others.as_str(), - }; - - write!(f, "{s}") - } -} - -// Default ports for TURN from RFC 5766 Section 4. - -/// `DEFAULT_PORT` for TURN is same as STUN. -pub const DEFAULT_PORT: u16 = stun::DEFAULT_PORT; -/// `DEFAULT_TLSPORT` is for TURN over TLS and is same as STUN. -pub const DEFAULT_TLS_PORT: u16 = stun::DEFAULT_TLS_PORT; - -/// Shorthand for create permission request type. -pub fn create_permission_request() -> MessageType { - MessageType::new(METHOD_CREATE_PERMISSION, CLASS_REQUEST) -} - -/// Shorthand for allocation request message type. -pub fn allocate_request() -> MessageType { - MessageType::new(METHOD_ALLOCATE, CLASS_REQUEST) -} - -/// Shorthand for send indication message type. -pub fn send_indication() -> MessageType { - MessageType::new(METHOD_SEND, CLASS_INDICATION) -} - -/// Shorthand for refresh request message type. -pub fn refresh_request() -> MessageType { - MessageType::new(METHOD_REFRESH, CLASS_REQUEST) -} diff --git a/turn/src/proto/peeraddr.rs b/turn/src/proto/peeraddr.rs deleted file mode 100644 index 58c1f20a3..000000000 --- a/turn/src/proto/peeraddr.rs +++ /dev/null @@ -1,71 +0,0 @@ -#[cfg(test)] -mod peeraddr_test; - -use std::fmt; -use std::net::{IpAddr, Ipv4Addr}; - -use stun::attributes::*; -use stun::message::*; -use stun::xoraddr::*; - -/// `PeerAddress` implements `XOR-PEER-ADDRESS` attribute. -/// -/// The `XOR-PEER-ADDRESS` specifies the address and port of the peer as -/// seen from the TURN server. (For example, the peer's server-reflexive -/// transport address if the peer is behind a NAT.) -/// -/// [RFC 5766 Section 14.3](https://www.rfc-editor.org/rfc/rfc5766#section-14.3). -#[derive(PartialEq, Eq, Debug)] -pub struct PeerAddress { - pub ip: IpAddr, - pub port: u16, -} - -impl Default for PeerAddress { - fn default() -> Self { - PeerAddress { - ip: IpAddr::V4(Ipv4Addr::from(0)), - port: 0, - } - } -} - -impl fmt::Display for PeerAddress { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self.ip { - IpAddr::V4(_) => write!(f, "{}:{}", self.ip, self.port), - IpAddr::V6(_) => write!(f, "[{}]:{}", self.ip, self.port), - } - } -} - -impl Setter for PeerAddress { - /// Adds `XOR-PEER-ADDRESS` to message. - fn add_to(&self, m: &mut Message) -> Result<(), stun::Error> { - let a = XorMappedAddress { - ip: self.ip, - port: self.port, - }; - a.add_to_as(m, ATTR_XOR_PEER_ADDRESS) - } -} - -impl Getter for PeerAddress { - /// Decodes `XOR-PEER-ADDRESS` from message. - fn get_from(&mut self, m: &Message) -> Result<(), stun::Error> { - let mut a = XorMappedAddress::default(); - a.get_from_as(m, ATTR_XOR_PEER_ADDRESS)?; - self.ip = a.ip; - self.port = a.port; - Ok(()) - } -} - -/// `PeerAddress` implements `XOR-PEER-ADDRESS` attribute. -/// -/// The `XOR-PEER-ADDRESS` specifies the address and port of the peer as -/// seen from the TURN server. (For example, the peer's server-reflexive -/// transport address if the peer is behind a NAT.) -/// -/// [RFC 5766 Section 14.3](https://www.rfc-editor.org/rfc/rfc5766#section-14.3). 
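// Illustrative aside (not part of the original diff): a sketch of combining the
// attribute types in this module tree, here XOR-PEER-ADDRESS plus DATA, the
// pair a client places in a Send indication (message type send_indication()
// above). Setting the message type itself is left out so the example sticks to
// APIs shown in this diff; `super::data::Data` is the DATA attribute from
// data.rs.
fn peer_and_data_roundtrip(
    peer: std::net::SocketAddr,
    payload: Vec<u8>,
) -> Result<(), stun::Error> {
    use super::data::Data;

    let mut m = Message::new();
    PeerAddress { ip: peer.ip(), port: peer.port() }.add_to(&mut m)?;
    Data(payload.clone()).add_to(&mut m)?;
    m.write_header();

    // Decode it back out, following the same pattern the attribute tests
    // elsewhere in this diff use.
    let mut decoded = Message::new();
    decoded.write(&m.raw)?;
    let mut got = Data::default();
    got.get_from(&decoded)?;
    assert_eq!(got.0, payload);
    Ok(())
}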
-pub type XorPeerAddress = PeerAddress; diff --git a/turn/src/proto/peeraddr/peeraddr_test.rs b/turn/src/proto/peeraddr/peeraddr_test.rs deleted file mode 100644 index 353788c99..000000000 --- a/turn/src/proto/peeraddr/peeraddr_test.rs +++ /dev/null @@ -1,26 +0,0 @@ -use std::net::Ipv4Addr; - -use super::*; - -#[test] -fn test_peer_address() -> Result<(), stun::Error> { - // Simple tests because already tested in stun. - let a = PeerAddress { - ip: IpAddr::V4(Ipv4Addr::new(111, 11, 1, 2)), - port: 333, - }; - - assert_eq!(a.to_string(), "111.11.1.2:333", "invalid string"); - - let mut m = Message::new(); - a.add_to(&mut m)?; - m.write_header(); - - let mut decoded = Message::new(); - decoded.write(&m.raw)?; - - let mut a_got = PeerAddress::default(); - a_got.get_from(&decoded)?; - - Ok(()) -} diff --git a/turn/src/proto/proto_test.rs b/turn/src/proto/proto_test.rs deleted file mode 100644 index c40d1a755..000000000 --- a/turn/src/proto/proto_test.rs +++ /dev/null @@ -1,35 +0,0 @@ -use super::*; -use crate::error::*; - -const CHROME_ALLOC_REQ_TEST_HEX: [&str; 4] = [ - "000300242112a442626b4a6849664c3630526863802f0016687474703a2f2f6c6f63616c686f73743a333030302f00000019000411000000", - "011300582112a442626b4a6849664c36305268630009001000000401556e617574686f72697a656400150010356130323039623563623830363130360014000b61312e63796465762e7275758022001a436f7475726e2d342e352e302e33202764616e204569646572272300", - "0003006c2112a442324e50695a437a4634535034802f0016687474703a2f2f6c6f63616c686f73743a333030302f000000190004110000000006000665726e61646f00000014000b61312e63796465762e7275000015001035613032303962356362383036313036000800145c8743f3b64bec0880cdd8d476d37b801a6c3d33", - "010300582112a442324e50695a437a4634535034001600080001fb922b1ab211002000080001adb2f49f38ae000d0004000002588022001a436f7475726e2d342e352e302e33202764616e204569646572277475000800145d7e85b767a519ffce91dbf0a96775e370db92e3", -]; - -#[test] -fn test_chrome_alloc_request() -> Result<()> { - let mut data = vec![]; - let mut messages = vec![]; - - // Decoding hex data into binary. - for h in &CHROME_ALLOC_REQ_TEST_HEX { - let b = match hex::decode(h) { - Ok(b) => b, - Err(_) => return Err(Error::Other("hex decode error".to_owned())), - }; - data.push(b); - } - - // All hex streams decoded to raw binary format and stored in data slice. - // Decoding packets to messages. - for packet in data { - let mut m = Message::new(); - m.write(&packet)?; - messages.push(m); - } - assert_eq!(messages.len(), 4, "unexpected message slice list"); - - Ok(()) -} diff --git a/turn/src/proto/relayaddr.rs b/turn/src/proto/relayaddr.rs deleted file mode 100644 index a5675e16f..000000000 --- a/turn/src/proto/relayaddr.rs +++ /dev/null @@ -1,69 +0,0 @@ -#[cfg(test)] -mod relayaddr_test; - -use std::fmt; -use std::net::{IpAddr, Ipv4Addr}; - -use stun::attributes::*; -use stun::message::*; -use stun::xoraddr::*; - -/// `RelayedAddress` implements `XOR-RELAYED-ADDRESS` attribute. -/// -/// It specifies the address and port that the server allocated to the -/// client. It is encoded in the same way as `XOR-MAPPED-ADDRESS`. -/// -/// [RFC 5766 Section 14.5](https://www.rfc-editor.org/rfc/rfc5766#section-14.5). 
-#[derive(PartialEq, Eq, Debug)] -pub struct RelayedAddress { - pub ip: IpAddr, - pub port: u16, -} - -impl Default for RelayedAddress { - fn default() -> Self { - RelayedAddress { - ip: IpAddr::V4(Ipv4Addr::from(0)), - port: 0, - } - } -} - -impl fmt::Display for RelayedAddress { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self.ip { - IpAddr::V4(_) => write!(f, "{}:{}", self.ip, self.port), - IpAddr::V6(_) => write!(f, "[{}]:{}", self.ip, self.port), - } - } -} - -impl Setter for RelayedAddress { - /// Adds `XOR-PEER-ADDRESS` to message. - fn add_to(&self, m: &mut Message) -> Result<(), stun::Error> { - let a = XorMappedAddress { - ip: self.ip, - port: self.port, - }; - a.add_to_as(m, ATTR_XOR_RELAYED_ADDRESS) - } -} - -impl Getter for RelayedAddress { - /// Decodes `XOR-PEER-ADDRESS` from message. - fn get_from(&mut self, m: &Message) -> Result<(), stun::Error> { - let mut a = XorMappedAddress::default(); - a.get_from_as(m, ATTR_XOR_RELAYED_ADDRESS)?; - self.ip = a.ip; - self.port = a.port; - Ok(()) - } -} - -/// `XorRelayedAddress` implements `XOR-RELAYED-ADDRESS` attribute. -/// -/// It specifies the address and port that the server allocated to the -/// client. It is encoded in the same way as `XOR-MAPPED-ADDRESS`. -/// -/// [RFC 5766 Section 14.5](https://www.rfc-editor.org/rfc/rfc5766#section-14.5). -pub type XorRelayedAddress = RelayedAddress; diff --git a/turn/src/proto/relayaddr/relayaddr_test.rs b/turn/src/proto/relayaddr/relayaddr_test.rs deleted file mode 100644 index 8ea3e04bb..000000000 --- a/turn/src/proto/relayaddr/relayaddr_test.rs +++ /dev/null @@ -1,26 +0,0 @@ -use std::net::Ipv4Addr; - -use super::*; - -#[test] -fn test_relayed_address() -> Result<(), stun::Error> { - // Simple tests because already tested in stun. - let a = RelayedAddress { - ip: IpAddr::V4(Ipv4Addr::new(111, 11, 1, 2)), - port: 333, - }; - - assert_eq!(a.to_string(), "111.11.1.2:333", "invalid string"); - - let mut m = Message::new(); - a.add_to(&mut m)?; - m.write_header(); - - let mut decoded = Message::new(); - decoded.write(&m.raw)?; - - let mut a_got = RelayedAddress::default(); - a_got.get_from(&decoded)?; - - Ok(()) -} diff --git a/turn/src/proto/reqfamily.rs b/turn/src/proto/reqfamily.rs deleted file mode 100644 index 409ddc516..000000000 --- a/turn/src/proto/reqfamily.rs +++ /dev/null @@ -1,61 +0,0 @@ -#[cfg(test)] -mod reqfamily_test; - -use std::fmt; - -use stun::attributes::*; -use stun::checks::*; -use stun::message::*; - -// Values for RequestedAddressFamily as defined in RFC 6156 Section 4.1.1. -pub const REQUESTED_FAMILY_IPV4: RequestedAddressFamily = RequestedAddressFamily(0x01); -pub const REQUESTED_FAMILY_IPV6: RequestedAddressFamily = RequestedAddressFamily(0x02); - -/// `RequestedAddressFamily` represents the `REQUESTED-ADDRESS-FAMILY` Attribute as -/// defined in [RFC 6156 Section 4.1.1](https://www.rfc-editor.org/rfc/rfc6156#section-4.1.1). -#[derive(Debug, Default, PartialEq, Eq)] -pub struct RequestedAddressFamily(pub u8); - -impl fmt::Display for RequestedAddressFamily { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - REQUESTED_FAMILY_IPV4 => "IPv4", - REQUESTED_FAMILY_IPV6 => "IPv6", - _ => "unknown", - }; - write!(f, "{s}") - } -} - -const REQUESTED_FAMILY_SIZE: usize = 4; - -impl Setter for RequestedAddressFamily { - /// Adds `REQUESTED-ADDRESS-FAMILY` to message. 
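// Illustrative aside (not part of the original diff): on the wire this
// attribute is one family octet followed by three RFFU octets that must be
// zero, so REQUESTED_FAMILY_IPV4 encodes as [0x01, 0x00, 0x00, 0x00] and
// REQUESTED_FAMILY_IPV6 as [0x02, 0x00, 0x00, 0x00]; get_from() below rejects
// any other family octet.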
- fn add_to(&self, m: &mut Message) -> Result<(), stun::Error> { - let mut v = vec![0; REQUESTED_FAMILY_SIZE]; - v[0] = self.0; - // b[1:4] is RFFU = 0. - // The RFFU field MUST be set to zero on transmission and MUST be - // ignored on reception. It is reserved for future uses. - m.add(ATTR_REQUESTED_ADDRESS_FAMILY, &v); - Ok(()) - } -} - -impl Getter for RequestedAddressFamily { - /// Decodes `REQUESTED-ADDRESS-FAMILY` from message. - fn get_from(&mut self, m: &Message) -> Result<(), stun::Error> { - let v = m.get(ATTR_REQUESTED_ADDRESS_FAMILY)?; - check_size( - ATTR_REQUESTED_ADDRESS_FAMILY, - v.len(), - REQUESTED_FAMILY_SIZE, - )?; - - if v[0] != REQUESTED_FAMILY_IPV4.0 && v[0] != REQUESTED_FAMILY_IPV6.0 { - return Err(stun::Error::Other("ErrInvalidRequestedFamilyValue".into())); - } - self.0 = v[0]; - Ok(()) - } -} diff --git a/turn/src/proto/reqfamily/reqfamily_test.rs b/turn/src/proto/reqfamily/reqfamily_test.rs deleted file mode 100644 index aa56bef35..000000000 --- a/turn/src/proto/reqfamily/reqfamily_test.rs +++ /dev/null @@ -1,77 +0,0 @@ -use super::*; - -#[test] -fn test_requested_address_family_string() -> Result<(), stun::Error> { - assert_eq!( - REQUESTED_FAMILY_IPV4.to_string(), - "IPv4", - "bad string {}, expected {}", - REQUESTED_FAMILY_IPV4, - "IPv4" - ); - - assert_eq!( - REQUESTED_FAMILY_IPV6.to_string(), - "IPv6", - "bad string {}, expected {}", - REQUESTED_FAMILY_IPV6, - "IPv6" - ); - - assert_eq!( - RequestedAddressFamily(0x04).to_string(), - "unknown", - "should be unknown" - ); - - Ok(()) -} - -#[test] -fn test_requested_address_family_add_to() -> Result<(), stun::Error> { - let mut m = Message::new(); - let r = REQUESTED_FAMILY_IPV4; - r.add_to(&mut m)?; - m.write_header(); - - //"GetFrom" - { - let mut decoded = Message::new(); - decoded.write(&m.raw)?; - let mut req = RequestedAddressFamily::default(); - req.get_from(&decoded)?; - assert_eq!(req, r, "Decoded {req}, expected {r}"); - - //"HandleErr" - { - let mut m = Message::new(); - let mut handle = RequestedAddressFamily::default(); - if let Err(err) = handle.get_from(&m) { - assert_eq!( - stun::Error::ErrAttributeNotFound, - err, - "{err} should be not found" - ); - } else { - panic!("expected error, but got ok"); - } - m.add(ATTR_REQUESTED_ADDRESS_FAMILY, &[1, 2, 3]); - if let Err(err) = handle.get_from(&m) { - assert!( - is_attr_size_invalid(&err), - "IsAttrSizeInvalid should be true" - ); - } else { - panic!("expected error, but got ok"); - } - m.reset(); - m.add(ATTR_REQUESTED_ADDRESS_FAMILY, &[5, 0, 0, 0]); - assert!( - handle.get_from(&m).is_err(), - "should error on invalid value" - ); - } - } - - Ok(()) -} diff --git a/turn/src/proto/reqtrans.rs b/turn/src/proto/reqtrans.rs deleted file mode 100644 index 0d2a64349..000000000 --- a/turn/src/proto/reqtrans.rs +++ /dev/null @@ -1,54 +0,0 @@ -#[cfg(test)] -mod reqtrans_test; - -use std::fmt; - -use stun::attributes::*; -use stun::checks::*; -use stun::message::*; - -use super::*; - -/// `RequestedTransport` represents `REQUESTED-TRANSPORT` attribute. -/// -/// This attribute is used by the client to request a specific transport -/// protocol for the allocated transport address. RFC 5766 only allows the use of -/// codepoint 17 (User Datagram protocol). -/// -/// [RFC 5766 Section 14.7](https://www.rfc-editor.org/rfc/rfc5766#section-14.7). 
-#[derive(Default, Debug, PartialEq, Eq)] -pub struct RequestedTransport { - pub protocol: Protocol, -} - -impl fmt::Display for RequestedTransport { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "protocol: {}", self.protocol) - } -} - -const REQUESTED_TRANSPORT_SIZE: usize = 4; - -impl Setter for RequestedTransport { - /// Adds `REQUESTED-TRANSPORT` to message. - fn add_to(&self, m: &mut Message) -> Result<(), stun::Error> { - let mut v = vec![0; REQUESTED_TRANSPORT_SIZE]; - v[0] = self.protocol.0; - // b[1:4] is RFFU = 0. - // The RFFU field MUST be set to zero on transmission and MUST be - // ignored on reception. It is reserved for future uses. - m.add(ATTR_REQUESTED_TRANSPORT, &v); - Ok(()) - } -} - -impl Getter for RequestedTransport { - /// Decodes `REQUESTED-TRANSPORT` from message. - fn get_from(&mut self, m: &Message) -> Result<(), stun::Error> { - let v = m.get(ATTR_REQUESTED_TRANSPORT)?; - - check_size(ATTR_REQUESTED_TRANSPORT, v.len(), REQUESTED_TRANSPORT_SIZE)?; - self.protocol = Protocol(v[0]); - Ok(()) - } -} diff --git a/turn/src/proto/reqtrans/reqtrans_test.rs b/turn/src/proto/reqtrans/reqtrans_test.rs deleted file mode 100644 index 30a3e9463..000000000 --- a/turn/src/proto/reqtrans/reqtrans_test.rs +++ /dev/null @@ -1,75 +0,0 @@ -use super::*; - -#[test] -fn test_requested_transport_string() -> Result<(), stun::Error> { - let mut r = RequestedTransport { - protocol: PROTO_UDP, - }; - assert_eq!( - r.to_string(), - "protocol: UDP", - "bad string {}, expected {}", - r, - "protocol: UDP", - ); - r.protocol = Protocol(254); - if r.to_string() != "protocol: 254" { - assert_eq!( - r.to_string(), - "protocol: UDP", - "bad string {}, expected {}", - r, - "protocol: 254", - ); - } - - Ok(()) -} - -#[test] -fn test_requested_transport_add_to() -> Result<(), stun::Error> { - let mut m = Message::new(); - let r = RequestedTransport { - protocol: PROTO_UDP, - }; - r.add_to(&mut m)?; - m.write_header(); - - //"GetFrom" - { - let mut decoded = Message::new(); - decoded.write(&m.raw)?; - let mut req = RequestedTransport { - protocol: PROTO_UDP, - }; - req.get_from(&decoded)?; - assert_eq!(req, r, "Decoded {req}, expected {r}"); - - //"HandleErr" - { - let mut m = Message::new(); - let mut handle = RequestedTransport::default(); - if let Err(err) = handle.get_from(&m) { - assert_eq!( - stun::Error::ErrAttributeNotFound, - err, - "{err} should be not found" - ); - } else { - panic!("expected error, got ok"); - } - - m.add(ATTR_REQUESTED_TRANSPORT, &[1, 2, 3]); - if let Err(err) = handle.get_from(&m) { - assert!( - is_attr_size_invalid(&err), - "IsAttrSizeInvalid should be true" - ); - } else { - panic!("expected error, got ok"); - } - } - } - - Ok(()) -} diff --git a/turn/src/proto/rsrvtoken.rs b/turn/src/proto/rsrvtoken.rs deleted file mode 100644 index 6e4ac9663..000000000 --- a/turn/src/proto/rsrvtoken.rs +++ /dev/null @@ -1,40 +0,0 @@ -#[cfg(test)] -mod rsrvtoken_test; - -use stun::attributes::*; -use stun::checks::*; -use stun::message::*; - -/// `ReservationToken` represents `RESERVATION-TOKEN` attribute. -/// -/// The `RESERVATION-TOKEN` attribute contains a token that uniquely -/// identifies a relayed transport address being held in reserve by the -/// server. The server includes this attribute in a success response to -/// tell the client about the token, and the client includes this -/// attribute in a subsequent Allocate request to request the server use -/// that relayed transport address for the allocation. 
-/// -/// [RFC 5766 Section 14.9](https://www.rfc-editor.org/rfc/rfc5766#section-14.9). -#[derive(Debug, Default, PartialEq, Eq)] -pub struct ReservationToken(pub Vec); - -const RESERVATION_TOKEN_SIZE: usize = 8; // 8 bytes - -impl Setter for ReservationToken { - /// Adds `RESERVATION-TOKEN` to message. - fn add_to(&self, m: &mut Message) -> Result<(), stun::Error> { - check_size(ATTR_RESERVATION_TOKEN, self.0.len(), RESERVATION_TOKEN_SIZE)?; - m.add(ATTR_RESERVATION_TOKEN, &self.0); - Ok(()) - } -} - -impl Getter for ReservationToken { - /// Decodes `RESERVATION-TOKEN` from message. - fn get_from(&mut self, m: &Message) -> Result<(), stun::Error> { - let v = m.get(ATTR_RESERVATION_TOKEN)?; - check_size(ATTR_RESERVATION_TOKEN, v.len(), RESERVATION_TOKEN_SIZE)?; - self.0 = v; - Ok(()) - } -} diff --git a/turn/src/proto/rsrvtoken/rsrvtoken_test.rs b/turn/src/proto/rsrvtoken/rsrvtoken_test.rs deleted file mode 100644 index 4990d8982..000000000 --- a/turn/src/proto/rsrvtoken/rsrvtoken_test.rs +++ /dev/null @@ -1,60 +0,0 @@ -use super::*; - -#[test] -fn test_reservation_token() -> Result<(), stun::Error> { - let mut m = Message::new(); - let mut v = vec![0; 8]; - v[2] = 33; - v[7] = 1; - let tk = ReservationToken(v); - tk.add_to(&mut m)?; - m.write_header(); - - //"HandleErr" - { - let bad_tk = ReservationToken(vec![34, 45]); - if let Err(err) = bad_tk.add_to(&mut m) { - assert!( - is_attr_size_invalid(&err), - "IsAttrSizeInvalid should be true" - ); - } else { - panic!("expected error, but got ok"); - } - } - - //"GetFrom" - { - let mut decoded = Message::new(); - decoded.write(&m.raw)?; - let mut tok = ReservationToken::default(); - tok.get_from(&decoded)?; - assert_eq!(tok, tk, "Decoded {tok:?}, expected {tk:?}"); - - //"HandleErr" - { - let mut m = Message::new(); - let mut handle = ReservationToken::default(); - if let Err(err) = handle.get_from(&m) { - assert_eq!( - stun::Error::ErrAttributeNotFound, - err, - "{err} should be not found" - ); - } else { - panic!("expected error, but got ok"); - } - m.add(ATTR_RESERVATION_TOKEN, &[1, 2, 3]); - if let Err(err) = handle.get_from(&m) { - assert!( - is_attr_size_invalid(&err), - "IsAttrSizeInvalid should be true" - ); - } else { - panic!("expected error, got ok"); - } - } - } - - Ok(()) -} diff --git a/turn/src/relay/mod.rs b/turn/src/relay/mod.rs deleted file mode 100644 index 928db108a..000000000 --- a/turn/src/relay/mod.rs +++ /dev/null @@ -1,26 +0,0 @@ -pub mod relay_none; -pub mod relay_range; -pub mod relay_static; - -use std::net::SocketAddr; -use std::sync::Arc; - -use async_trait::async_trait; -use util::Conn; - -use crate::error::Result; - -/// `RelayAddressGenerator` is used to generate a Relay Address when creating an allocation. -/// You can use one of the provided ones or provide your own. -#[async_trait] -pub trait RelayAddressGenerator { - /// Confirms that this is properly initialized - fn validate(&self) -> Result<()>; - - /// Allocates a Relay Address - async fn allocate_conn( - &self, - use_ipv4: bool, - requested_port: u16, - ) -> Result<(Arc, SocketAddr)>; -} diff --git a/turn/src/relay/relay_none.rs b/turn/src/relay/relay_none.rs deleted file mode 100644 index 1cf0a54a6..000000000 --- a/turn/src/relay/relay_none.rs +++ /dev/null @@ -1,37 +0,0 @@ -use async_trait::async_trait; -use util::vnet::net::*; - -use super::*; -use crate::error::*; - -/// `RelayAddressGeneratorNone` returns the listener with no modifications. 
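// Illustrative aside (not part of the original diff): a minimal sketch of a
// custom RelayAddressGenerator (the trait is defined in relay/mod.rs above).
// `FixedPortGenerator` is a hypothetical type used only for illustration; it
// follows the same Net::resolve_addr / Net::bind pattern as the bundled
// generators and assumes the crate's usual Arc<dyn Conn + Send + Sync> bound.
use std::net::SocketAddr;
use std::sync::Arc;
use util::Conn;

struct FixedPortGenerator {
    /// Local IP to bind, e.g. "0.0.0.0".
    address: String,
    /// The single relay port this generator ever hands out.
    port: u16,
    net: Arc<Net>,
}

#[async_trait]
impl RelayAddressGenerator for FixedPortGenerator {
    fn validate(&self) -> Result<()> {
        if self.address.is_empty() {
            Err(Error::ErrListeningAddressInvalid)
        } else {
            Ok(())
        }
    }

    async fn allocate_conn(
        &self,
        use_ipv4: bool,
        _requested_port: u16,
    ) -> Result<(Arc<dyn Conn + Send + Sync>, SocketAddr)> {
        // Ignore the requested port and always bind the configured one.
        let addr = self
            .net
            .resolve_addr(use_ipv4, &format!("{}:{}", self.address, self.port))
            .await?;
        let conn = self.net.bind(addr).await?;
        let relay_addr = conn.local_addr()?;
        Ok((conn, relay_addr))
    }
}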
-pub struct RelayAddressGeneratorNone { - /// `address` is passed to Listen/ListenPacket when creating the Relay. - pub address: String, - pub net: Arc, -} - -#[async_trait] -impl RelayAddressGenerator for RelayAddressGeneratorNone { - fn validate(&self) -> Result<()> { - if self.address.is_empty() { - Err(Error::ErrListeningAddressInvalid) - } else { - Ok(()) - } - } - - async fn allocate_conn( - &self, - use_ipv4: bool, - requested_port: u16, - ) -> Result<(Arc, SocketAddr)> { - let addr = self - .net - .resolve_addr(use_ipv4, &format!("{}:{}", self.address, requested_port)) - .await?; - let conn = self.net.bind(addr).await?; - let relay_addr = conn.local_addr()?; - Ok((conn, relay_addr)) - } -} diff --git a/turn/src/relay/relay_range.rs b/turn/src/relay/relay_range.rs deleted file mode 100644 index c612100bd..000000000 --- a/turn/src/relay/relay_range.rs +++ /dev/null @@ -1,85 +0,0 @@ -use std::net::IpAddr; - -use async_trait::async_trait; -use util::vnet::net::*; - -use super::*; -use crate::error::*; - -/// `RelayAddressGeneratorRanges` can be used to only allocate connections inside a defined port range. -pub struct RelayAddressGeneratorRanges { - /// `relay_address` is the IP returned to the user when the relay is created. - pub relay_address: IpAddr, - - /// `min_port` the minimum port to allocate. - pub min_port: u16, - - /// `max_port` the maximum (inclusive) port to allocate. - pub max_port: u16, - - /// `max_retries` the amount of tries to allocate a random port in the defined range. - pub max_retries: u16, - - /// `address` is passed to Listen/ListenPacket when creating the Relay. - pub address: String, - - pub net: Arc, -} - -#[async_trait] -impl RelayAddressGenerator for RelayAddressGeneratorRanges { - fn validate(&self) -> Result<()> { - if self.min_port == 0 { - Err(Error::ErrMinPortNotZero) - } else if self.max_port == 0 { - Err(Error::ErrMaxPortNotZero) - } else if self.max_port < self.min_port { - Err(Error::ErrMaxPortLessThanMinPort) - } else if self.address.is_empty() { - Err(Error::ErrListeningAddressInvalid) - } else { - Ok(()) - } - } - - async fn allocate_conn( - &self, - use_ipv4: bool, - requested_port: u16, - ) -> Result<(Arc, SocketAddr)> { - let max_retries = if self.max_retries == 0 { - 10 - } else { - self.max_retries - }; - - if requested_port != 0 { - let addr = self - .net - .resolve_addr(use_ipv4, &format!("{}:{}", self.address, requested_port)) - .await?; - let conn = self.net.bind(addr).await?; - let mut relay_addr = conn.local_addr()?; - relay_addr.set_ip(self.relay_address); - return Ok((conn, relay_addr)); - } - - for _ in 0..max_retries { - let port = self.min_port + rand::random::() % (self.max_port - self.min_port + 1); - let addr = self - .net - .resolve_addr(use_ipv4, &format!("{}:{}", self.address, port)) - .await?; - let conn = match self.net.bind(addr).await { - Ok(conn) => conn, - Err(_) => continue, - }; - - let mut relay_addr = conn.local_addr()?; - relay_addr.set_ip(self.relay_address); - return Ok((conn, relay_addr)); - } - - Err(Error::ErrMaxRetriesExceeded) - } -} diff --git a/turn/src/relay/relay_static.rs b/turn/src/relay/relay_static.rs deleted file mode 100644 index 72056691c..000000000 --- a/turn/src/relay/relay_static.rs +++ /dev/null @@ -1,45 +0,0 @@ -use std::net::IpAddr; - -use async_trait::async_trait; -use util::vnet::net::*; - -use super::*; -use crate::error::*; - -/// `RelayAddressGeneratorStatic` can be used to return static IP address each time a relay is created. 
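// Illustrative aside (not part of the original diff): the ranged generator
// above draws a candidate port uniformly from [min_port, max_port] and retries
// the bind up to max_retries times (10 when left at 0). The draw reduces to
// the expression below; validate() guarantees min_port > 0 and
// max_port >= min_port, so the modulus can neither be zero nor overflow u16.
fn pick_candidate_port(min_port: u16, max_port: u16, random: u16) -> u16 {
    min_port + random % (max_port - min_port + 1)
}
// pick_candidate_port(49152, 65535, r) always lands in 49152..=65535 because
// the modulus is 65535 - 49152 + 1 = 16384.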
-/// This can be used when you have a single static IP address that you want to use. -pub struct RelayAddressGeneratorStatic { - /// `relay_address` is the IP returned to the user when the relay is created. - pub relay_address: IpAddr, - - /// `address` is passed to Listen/ListenPacket when creating the Relay. - pub address: String, - - pub net: Arc, -} - -#[async_trait] -impl RelayAddressGenerator for RelayAddressGeneratorStatic { - fn validate(&self) -> Result<()> { - if self.address.is_empty() { - Err(Error::ErrListeningAddressInvalid) - } else { - Ok(()) - } - } - - async fn allocate_conn( - &self, - use_ipv4: bool, - requested_port: u16, - ) -> Result<(Arc, SocketAddr)> { - let addr = self - .net - .resolve_addr(use_ipv4, &format!("{}:{}", self.address, requested_port)) - .await?; - let conn = self.net.bind(addr).await?; - let mut relay_addr = conn.local_addr()?; - relay_addr.set_ip(self.relay_address); - return Ok((conn, relay_addr)); - } -} diff --git a/turn/src/server/config.rs b/turn/src/server/config.rs deleted file mode 100644 index bf1a36c93..000000000 --- a/turn/src/server/config.rs +++ /dev/null @@ -1,58 +0,0 @@ -use std::sync::Arc; - -use tokio::sync::mpsc; -use tokio::time::Duration; -use util::Conn; - -use crate::allocation::*; -use crate::auth::*; -use crate::error::*; -use crate::relay::*; - -/// ConnConfig is used for UDP listeners -pub struct ConnConfig { - pub conn: Arc, - - // When an allocation is generated the RelayAddressGenerator - // creates the net.PacketConn and returns the IP/Port it is available at - pub relay_addr_generator: Box, -} - -impl ConnConfig { - pub fn validate(&self) -> Result<()> { - self.relay_addr_generator.validate() - } -} - -/// ServerConfig configures the TURN Server -pub struct ServerConfig { - /// `conn_configs` are a list of all the turn listeners. - /// Each listener can have custom behavior around the creation of Relays. - pub conn_configs: Vec, - - /// `realm` sets the realm for this server - pub realm: String, - - /// `auth_handler` is a callback used to handle incoming auth requests, - /// allowing users to customize Pion TURN with custom behavior. - pub auth_handler: Arc, - - /// `channel_bind_timeout` sets the lifetime of channel binding. Defaults to 10 minutes. - pub channel_bind_timeout: Duration, - - /// To receive notify on allocation close event, with metrics data. 
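Tying `ConnConfig` and `ServerConfig` together, here is a sketch of starting a server with a static-IP relay generator and a fixed-credential `AuthHandler`; the addresses, realm, and credentials are placeholders.

```rust
use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::str::FromStr;
use std::sync::Arc;

use tokio::net::UdpSocket;
use tokio::time::Duration;
use util::vnet::net::Net;

use crate::auth::{generate_auth_key, AuthHandler};
use crate::error::{Error, Result};
use crate::relay::relay_static::RelayAddressGeneratorStatic;
use crate::server::config::{ConnConfig, ServerConfig};
use crate::server::Server;

// Single hard-coded user; a real handler would look credentials up.
struct StaticAuth {
    creds: HashMap<String, Vec<u8>>,
}

impl AuthHandler for StaticAuth {
    fn auth_handle(&self, username: &str, _realm: &str, _src: SocketAddr) -> Result<Vec<u8>> {
        self.creds.get(username).cloned().ok_or(Error::ErrNoSuchUser)
    }
}

async fn start_server() -> Result<Server> {
    let conn = Arc::new(UdpSocket::bind("0.0.0.0:3478").await?);

    let mut creds = HashMap::new();
    creds.insert("user".to_owned(), generate_auth_key("user", "webrtc.rs", "pass"));

    Server::new(ServerConfig {
        conn_configs: vec![ConnConfig {
            conn,
            relay_addr_generator: Box::new(RelayAddressGeneratorStatic {
                relay_address: IpAddr::from_str("203.0.113.10").unwrap(),
                address: "0.0.0.0".to_owned(),
                net: Arc::new(Net::new(None)),
            }),
        }],
        realm: "webrtc.rs".to_owned(),
        auth_handler: Arc::new(StaticAuth { creds }),
        channel_bind_timeout: Duration::from_secs(0), // 0 falls back to DEFAULT_LIFETIME
        alloc_close_notify: None,
    })
    .await
}
```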
- pub alloc_close_notify: Option>, -} - -impl ServerConfig { - pub fn validate(&self) -> Result<()> { - if self.conn_configs.is_empty() { - return Err(Error::ErrNoAvailableConns); - } - - for cc in &self.conn_configs { - cc.validate()?; - } - Ok(()) - } -} diff --git a/turn/src/server/mod.rs b/turn/src/server/mod.rs deleted file mode 100644 index c8604fc01..000000000 --- a/turn/src/server/mod.rs +++ /dev/null @@ -1,261 +0,0 @@ -#[cfg(test)] -mod server_test; - -pub mod config; -pub mod request; - -use std::collections::HashMap; -use std::sync::Arc; - -use config::*; -use request::*; -use tokio::sync::broadcast::error::RecvError; -use tokio::sync::broadcast::{self}; -use tokio::sync::{mpsc, oneshot, Mutex}; -use tokio::time::{Duration, Instant}; -use util::Conn; - -use crate::allocation::allocation_manager::*; -use crate::allocation::five_tuple::FiveTuple; -use crate::allocation::AllocationInfo; -use crate::auth::AuthHandler; -use crate::error::*; -use crate::proto::lifetime::DEFAULT_LIFETIME; - -const INBOUND_MTU: usize = 1500; - -/// Server is an instance of the TURN Server -pub struct Server { - auth_handler: Arc, - realm: String, - channel_bind_timeout: Duration, - pub(crate) nonces: Arc>>, - command_tx: Mutex>>, -} - -impl Server { - /// creates a new TURN server - pub async fn new(config: ServerConfig) -> Result { - config.validate()?; - - let (command_tx, _) = broadcast::channel(16); - let mut s = Server { - auth_handler: config.auth_handler, - realm: config.realm, - channel_bind_timeout: config.channel_bind_timeout, - nonces: Arc::new(Mutex::new(HashMap::new())), - command_tx: Mutex::new(Some(command_tx.clone())), - }; - - if s.channel_bind_timeout == Duration::from_secs(0) { - s.channel_bind_timeout = DEFAULT_LIFETIME; - } - - for p in config.conn_configs.into_iter() { - let nonces = Arc::clone(&s.nonces); - let auth_handler = Arc::clone(&s.auth_handler); - let realm = s.realm.clone(); - let channel_bind_timeout = s.channel_bind_timeout; - let handle_rx = command_tx.subscribe(); - let conn = p.conn; - let allocation_manager = Arc::new(Manager::new(ManagerConfig { - relay_addr_generator: p.relay_addr_generator, - alloc_close_notify: config.alloc_close_notify.clone(), - })); - - tokio::spawn(Server::read_loop( - conn, - allocation_manager, - nonces, - auth_handler, - realm, - channel_bind_timeout, - handle_rx, - )); - } - - Ok(s) - } - - /// Deletes all existing [`Allocation`][`Allocation`]s by the provided `username`. - /// - /// [`Allocation`]: crate::allocation::Allocation - pub async fn delete_allocations_by_username(&self, username: String) -> Result<()> { - let tx = { - let command_tx = self.command_tx.lock().await; - command_tx.clone() - }; - if let Some(tx) = tx { - let (closed_tx, closed_rx) = mpsc::channel(1); - tx.send(Command::DeleteAllocations(username, Arc::new(closed_rx))) - .map_err(|_| Error::ErrClosed)?; - - closed_tx.closed().await; - - Ok(()) - } else { - Err(Error::ErrClosed) - } - } - - /// Get information of [`Allocation`][`Allocation`]s by specified [`FiveTuple`]s. - /// - /// If `five_tuples` is: - /// - [`None`]: It returns information about the all - /// [`Allocation`][`Allocation`]s. - /// - [`Some`] and not empty: It returns information about - /// the [`Allocation`][`Allocation`]s associated with - /// the specified [`FiveTuples`]. - /// - [`Some`], but empty: It returns an empty [`HashMap`]. 
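As a usage sketch for the query described above, passing `None` asks for every live allocation:

```rust
use crate::error::Result;
use crate::server::Server;

async fn dump_allocations(server: &Server) -> Result<()> {
    // None => information about all current allocations.
    let infos = server.get_allocations_info(None).await?;
    println!("{} live allocation(s)", infos.len());
    Ok(())
}
```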
- /// - /// [`Allocation`]: crate::allocation::Allocation - pub async fn get_allocations_info( - &self, - five_tuples: Option>, - ) -> Result> { - if let Some(five_tuples) = &five_tuples { - if five_tuples.is_empty() { - return Ok(HashMap::new()); - } - } - - let tx = { - let command_tx = self.command_tx.lock().await; - command_tx.clone() - }; - if let Some(tx) = tx { - let (infos_tx, mut infos_rx) = mpsc::channel(1); - tx.send(Command::GetAllocationsInfo(five_tuples, infos_tx)) - .map_err(|_| Error::ErrClosed)?; - - let mut info: HashMap = HashMap::new(); - - for _ in 0..tx.receiver_count() { - info.extend(infos_rx.recv().await.ok_or(Error::ErrClosed)?); - } - - Ok(info) - } else { - Err(Error::ErrClosed) - } - } - - async fn read_loop( - conn: Arc, - allocation_manager: Arc, - nonces: Arc>>, - auth_handler: Arc, - realm: String, - channel_bind_timeout: Duration, - mut handle_rx: broadcast::Receiver, - ) { - let mut buf = vec![0u8; INBOUND_MTU]; - - let (mut close_tx, mut close_rx) = oneshot::channel::<()>(); - - tokio::spawn({ - let allocation_manager = Arc::clone(&allocation_manager); - - async move { - loop { - match handle_rx.recv().await { - Ok(Command::DeleteAllocations(name, _)) => { - allocation_manager - .delete_allocations_by_username(name.as_str()) - .await; - continue; - } - Ok(Command::GetAllocationsInfo(five_tuples, tx)) => { - let infos = allocation_manager.get_allocations_info(five_tuples).await; - let _ = tx.send(infos).await; - - continue; - } - Err(RecvError::Closed) | Ok(Command::Close(_)) => { - close_rx.close(); - break; - } - Err(RecvError::Lagged(n)) => { - log::warn!("Turn server has lagged by {} messages", n); - continue; - } - } - } - } - }); - - loop { - let (n, addr) = tokio::select! { - v = conn.recv_from(&mut buf) => { - match v { - Ok(v) => v, - Err(err) => { - log::debug!("exit read loop on error: {}", err); - break; - } - } - }, - _ = close_tx.closed() => break - }; - - let mut r = Request { - conn: Arc::clone(&conn), - src_addr: addr, - buff: buf[..n].to_vec(), - allocation_manager: Arc::clone(&allocation_manager), - nonces: Arc::clone(&nonces), - auth_handler: Arc::clone(&auth_handler), - realm: realm.clone(), - channel_bind_timeout, - }; - - if let Err(err) = r.handle_request().await { - log::error!("error when handling datagram: {}", err); - } - } - - let _ = allocation_manager.close().await; - let _ = conn.close().await; - } - - /// Close stops the TURN Server. It cleans up any associated state and closes all connections it is managing. - pub async fn close(&self) -> Result<()> { - let tx = { - let mut command_tx = self.command_tx.lock().await; - command_tx.take() - }; - - if let Some(tx) = tx { - if tx.receiver_count() == 0 { - return Ok(()); - } - - let (closed_tx, closed_rx) = mpsc::channel(1); - let _ = tx.send(Command::Close(Arc::new(closed_rx))); - closed_tx.closed().await - } - - Ok(()) - } -} - -/// The protocol to communicate between the [`Server`]'s public methods -/// and the tasks spawned in the [`Server::read_loop`] method. -#[derive(Clone)] -enum Command { - /// Command to delete [`Allocation`][`Allocation`] by provided `username`. - /// - /// [`Allocation`]: `crate::allocation::Allocation` - DeleteAllocations(String, Arc>), - - /// Command to get information of [`Allocation`][`Allocation`]s by provided [`FiveTuple`]s. - /// - /// [`Allocation`]: `crate::allocation::Allocation` - GetAllocationsInfo( - Option>, - mpsc::Sender>, - ), - - /// Command to close the [`Server`]. 
- Close(Arc>), -} diff --git a/turn/src/server/request.rs b/turn/src/server/request.rs deleted file mode 100644 index 9d1a3104f..000000000 --- a/turn/src/server/request.rs +++ /dev/null @@ -1,1031 +0,0 @@ -#[cfg(test)] -mod request_test; - -use std::collections::HashMap; -use std::marker::{Send, Sync}; -use std::net::SocketAddr; -#[cfg(feature = "metrics")] -use std::sync::atomic::Ordering; -use std::sync::Arc; -use std::time::SystemTime; - -use md5::{Digest, Md5}; -use stun::agent::*; -use stun::attributes::*; -use stun::error_code::*; -use stun::fingerprint::*; -use stun::integrity::*; -use stun::message::*; -use stun::textattrs::*; -use stun::uattrs::*; -use stun::xoraddr::*; -use tokio::sync::Mutex; -use tokio::time::{Duration, Instant}; -use util::Conn; - -use crate::allocation::allocation_manager::*; -use crate::allocation::channel_bind::ChannelBind; -use crate::allocation::five_tuple::*; -use crate::allocation::permission::Permission; -use crate::auth::*; -use crate::error::*; -use crate::proto::chandata::ChannelData; -use crate::proto::channum::ChannelNumber; -use crate::proto::data::Data; -use crate::proto::evenport::EvenPort; -use crate::proto::lifetime::*; -use crate::proto::peeraddr::PeerAddress; -use crate::proto::relayaddr::RelayedAddress; -use crate::proto::reqfamily::{ - RequestedAddressFamily, REQUESTED_FAMILY_IPV4, REQUESTED_FAMILY_IPV6, -}; -use crate::proto::reqtrans::RequestedTransport; -use crate::proto::rsrvtoken::ReservationToken; -use crate::proto::*; - -pub(crate) const MAXIMUM_ALLOCATION_LIFETIME: Duration = Duration::from_secs(3600); // https://tools.ietf.org/html/rfc5766#section-6.2 defines 3600 seconds recommendation -pub(crate) const NONCE_LIFETIME: Duration = Duration::from_secs(3600); // https://tools.ietf.org/html/rfc5766#section-4 - -/// Request contains all the state needed to process a single incoming datagram -pub struct Request { - // Current Request State - pub conn: Arc, - pub src_addr: SocketAddr, - pub buff: Vec, - - // Server State - pub allocation_manager: Arc, - pub nonces: Arc>>, - - // User Configuration - pub auth_handler: Arc, - pub realm: String, - pub channel_bind_timeout: Duration, -} - -impl Request { - pub fn new( - conn: Arc, - src_addr: SocketAddr, - allocation_manager: Arc, - auth_handler: Arc, - ) -> Self { - Request { - conn, - src_addr, - buff: vec![], - allocation_manager, - nonces: Arc::new(Mutex::new(HashMap::new())), - auth_handler, - realm: String::new(), - channel_bind_timeout: Duration::from_secs(0), - } - } - - /// Processes the give [`Request`] - pub async fn handle_request(&mut self) -> Result<()> { - /*log::debug!( - "received {} bytes of udp from {} on {}", - self.buff.len(), - self.src_addr, - self.conn.local_addr().await? 
- );*/ - - if ChannelData::is_channel_data(&self.buff) { - self.handle_data_packet().await - } else { - self.handle_turn_packet().await - } - } - - async fn handle_data_packet(&mut self) -> Result<()> { - log::debug!("received DataPacket from {}", self.src_addr); - let mut c = ChannelData { - raw: self.buff.clone(), - ..Default::default() - }; - c.decode()?; - self.handle_channel_data(&c).await - } - - async fn handle_turn_packet(&mut self) -> Result<()> { - log::debug!("handle_turn_packet"); - let mut m = Message { - raw: self.buff.clone(), - ..Default::default() - }; - m.decode()?; - - self.process_message_handler(&m).await - } - - async fn process_message_handler(&mut self, m: &Message) -> Result<()> { - if m.typ.class == CLASS_INDICATION { - match m.typ.method { - METHOD_SEND => self.handle_send_indication(m).await, - _ => Err(Error::ErrUnexpectedClass), - } - } else if m.typ.class == CLASS_REQUEST { - match m.typ.method { - METHOD_ALLOCATE => self.handle_allocate_request(m).await, - METHOD_REFRESH => self.handle_refresh_request(m).await, - METHOD_CREATE_PERMISSION => self.handle_create_permission_request(m).await, - METHOD_CHANNEL_BIND => self.handle_channel_bind_request(m).await, - METHOD_BINDING => self.handle_binding_request(m).await, - _ => Err(Error::ErrUnexpectedClass), - } - } else { - Err(Error::ErrUnexpectedClass) - } - } - - pub(crate) async fn authenticate_request( - &mut self, - m: &Message, - calling_method: Method, - ) -> Result> { - if !m.contains(ATTR_MESSAGE_INTEGRITY) { - self.respond_with_nonce(m, calling_method, CODE_UNAUTHORIZED) - .await?; - return Ok(None); - } - - let mut nonce_attr = Nonce::new(ATTR_NONCE, String::new()); - let mut username_attr = Username::new(ATTR_USERNAME, String::new()); - let mut realm_attr = Realm::new(ATTR_REALM, String::new()); - let bad_request_msg = build_msg( - m.transaction_id, - MessageType::new(calling_method, CLASS_ERROR_RESPONSE), - vec![Box::new(ErrorCodeAttribute { - code: CODE_BAD_REQUEST, - reason: vec![], - })], - )?; - - if let Err(err) = nonce_attr.get_from(m) { - build_and_send_err(&self.conn, self.src_addr, bad_request_msg, err.into()).await?; - return Ok(None); - } - - let to_be_deleted = { - // Assert Nonce exists and is not expired - let mut nonces = self.nonces.lock().await; - - let to_be_deleted = if let Some(nonce_creation_time) = nonces.get(&nonce_attr.text) { - Instant::now() - .checked_duration_since(*nonce_creation_time) - .unwrap_or_else(|| Duration::from_secs(0)) - >= NONCE_LIFETIME - } else { - true - }; - - if to_be_deleted { - nonces.remove(&nonce_attr.text); - } - to_be_deleted - }; - - if to_be_deleted { - self.respond_with_nonce(m, calling_method, CODE_STALE_NONCE) - .await?; - return Ok(None); - } - - if let Err(err) = realm_attr.get_from(m) { - build_and_send_err(&self.conn, self.src_addr, bad_request_msg, err.into()).await?; - return Ok(None); - } - if let Err(err) = username_attr.get_from(m) { - build_and_send_err(&self.conn, self.src_addr, bad_request_msg, err.into()).await?; - return Ok(None); - } - - let our_key = match self.auth_handler.auth_handle( - &username_attr.to_string(), - &realm_attr.to_string(), - self.src_addr, - ) { - Ok(key) => key, - Err(_) => { - build_and_send_err( - &self.conn, - self.src_addr, - bad_request_msg, - Error::ErrNoSuchUser, - ) - .await?; - return Ok(None); - } - }; - - let mi = MessageIntegrity(our_key); - if let Err(err) = mi.check(&mut m.clone()) { - build_and_send_err(&self.conn, self.src_addr, bad_request_msg, err.into()).await?; - Ok(None) - } else { - 
Ok(Some((username_attr, mi))) - } - } - - async fn respond_with_nonce( - &mut self, - m: &Message, - calling_method: Method, - response_code: ErrorCode, - ) -> Result<()> { - let nonce = build_nonce()?; - - { - // Nonce has already been taken - let mut nonces = self.nonces.lock().await; - if nonces.contains_key(&nonce) { - return Err(Error::ErrDuplicatedNonce); - } - nonces.insert(nonce.clone(), Instant::now()); - } - - let msg = build_msg( - m.transaction_id, - MessageType::new(calling_method, CLASS_ERROR_RESPONSE), - vec![ - Box::new(ErrorCodeAttribute { - code: response_code, - reason: vec![], - }), - Box::new(Nonce::new(ATTR_NONCE, nonce)), - Box::new(Realm::new(ATTR_REALM, self.realm.clone())), - ], - )?; - - build_and_send(&self.conn, self.src_addr, msg).await - } - - pub(crate) async fn handle_binding_request(&mut self, m: &Message) -> Result<()> { - log::debug!("received BindingRequest from {}", self.src_addr); - - let (ip, port) = (self.src_addr.ip(), self.src_addr.port()); - - let msg = build_msg( - m.transaction_id, - BINDING_SUCCESS, - vec![ - Box::new(XorMappedAddress { ip, port }), - Box::new(FINGERPRINT), - ], - )?; - - build_and_send(&self.conn, self.src_addr, msg).await - } - - /// https://tools.ietf.org/html/rfc5766#section-6.2 - pub(crate) async fn handle_allocate_request(&mut self, m: &Message) -> Result<()> { - log::debug!("received AllocateRequest from {}", self.src_addr); - - // 1. The server MUST require that the request be authenticated. This - // authentication MUST be done using the long-term credential - // mechanism of [https://tools.ietf.org/html/rfc5389#section-10.2.2] - // unless the client and server agree to use another mechanism through - // some procedure outside the scope of this document. - let (username, message_integrity) = - if let Some(mi) = self.authenticate_request(m, METHOD_ALLOCATE).await? { - mi - } else { - log::debug!("no MessageIntegrity"); - return Ok(()); - }; - - let five_tuple = FiveTuple { - src_addr: self.src_addr, - dst_addr: self.conn.local_addr()?, - protocol: PROTO_UDP, - }; - let mut requested_port = 0; - let mut reservation_token = "".to_owned(); - let mut use_ipv4 = true; - - // 2. The server checks if the 5-tuple is currently in use by an - // existing allocation. If yes, the server rejects the request with - // a 437 (Allocation Mismatch) error. - if self - .allocation_manager - .get_allocation(&five_tuple) - .await - .is_some() - { - let msg = build_msg( - m.transaction_id, - MessageType::new(METHOD_ALLOCATE, CLASS_ERROR_RESPONSE), - vec![Box::new(ErrorCodeAttribute { - code: CODE_ALLOC_MISMATCH, - reason: vec![], - })], - )?; - return build_and_send_err( - &self.conn, - self.src_addr, - msg, - Error::ErrRelayAlreadyAllocatedForFiveTuple, - ) - .await; - } - - // 3. The server checks if the request contains a REQUESTED-TRANSPORT - // attribute. If the REQUESTED-TRANSPORT attribute is not included - // or is malformed, the server rejects the request with a 400 (Bad - // Request) error. Otherwise, if the attribute is included but - // specifies a protocol other that UDP, the server rejects the - // request with a 442 (Unsupported Transport Protocol) error. 
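The `our_key` returned by `auth_handle` for the MESSAGE-INTEGRITY check above is the RFC 5389 long-term credential key, i.e. `MD5(username ":" realm ":" password)`. A sketch of that derivation follows; it mirrors what the crate's `generate_auth_key` helper is used for in the tests further down, with SASLprep of the password omitted for brevity.

```rust
use md5::{Digest, Md5};

/// Long-term credential key per RFC 5389, Section 15.4:
/// key = MD5(username ":" realm ":" SASLprep(password)).
fn long_term_credential_key(username: &str, realm: &str, password: &str) -> Vec<u8> {
    let mut h = Md5::new();
    h.update(format!("{username}:{realm}:{password}").as_bytes());
    h.finalize().to_vec()
}
```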
- let mut requested_transport = RequestedTransport::default(); - if let Err(err) = requested_transport.get_from(m) { - let bad_request_msg = build_msg( - m.transaction_id, - MessageType::new(METHOD_ALLOCATE, CLASS_ERROR_RESPONSE), - vec![Box::new(ErrorCodeAttribute { - code: CODE_BAD_REQUEST, - reason: vec![], - })], - )?; - return build_and_send_err(&self.conn, self.src_addr, bad_request_msg, err.into()) - .await; - } else if requested_transport.protocol != PROTO_UDP { - let msg = build_msg( - m.transaction_id, - MessageType::new(METHOD_ALLOCATE, CLASS_ERROR_RESPONSE), - vec![Box::new(ErrorCodeAttribute { - code: CODE_UNSUPPORTED_TRANS_PROTO, - reason: vec![], - })], - )?; - return build_and_send_err( - &self.conn, - self.src_addr, - msg, - Error::ErrRequestedTransportMustBeUdp, - ) - .await; - } - - // 4. The request may contain a DONT-FRAGMENT attribute. If it does, - // but the server does not support sending UDP datagrams with the DF - // bit set to 1 (see Section 12), then the server treats the DONT- - // FRAGMENT attribute in the Allocate request as an unknown - // comprehension-required attribute. - if m.contains(ATTR_DONT_FRAGMENT) { - let msg = build_msg( - m.transaction_id, - MessageType::new(METHOD_ALLOCATE, CLASS_ERROR_RESPONSE), - vec![ - Box::new(ErrorCodeAttribute { - code: CODE_UNKNOWN_ATTRIBUTE, - reason: vec![], - }), - Box::new(UnknownAttributes(vec![ATTR_DONT_FRAGMENT])), - ], - )?; - return build_and_send_err( - &self.conn, - self.src_addr, - msg, - Error::ErrNoDontFragmentSupport, - ) - .await; - } - - // 5. The server checks if the request contains a RESERVATION-TOKEN - // attribute. If yes, and the request also contains an EVEN-PORT - // attribute, then the server rejects the request with a 400 (Bad - // Request) error. Otherwise, it checks to see if the token is - // valid (i.e., the token is in range and has not expired and the - // corresponding relayed transport address is still available). If - // the token is not valid for some reason, the server rejects the - // request with a 508 (Insufficient Capacity) error. - let mut reservation_token_attr = ReservationToken::default(); - let reservation_token_attr_result = reservation_token_attr.get_from(m); - if reservation_token_attr_result.is_ok() { - let mut even_port = EvenPort::default(); - if even_port.get_from(m).is_ok() { - let bad_request_msg = build_msg( - m.transaction_id, - MessageType::new(METHOD_ALLOCATE, CLASS_ERROR_RESPONSE), - vec![Box::new(ErrorCodeAttribute { - code: CODE_BAD_REQUEST, - reason: vec![], - })], - )?; - return build_and_send_err( - &self.conn, - self.src_addr, - bad_request_msg, - Error::ErrRequestWithReservationTokenAndEvenPort, - ) - .await; - } - } - - // RFC 6156, Section 4.2: - // - // If it contains both a RESERVATION-TOKEN and a - // REQUESTED-ADDRESS-FAMILY, the server replies with a 400 - // (Bad Request) Allocate error response. - // - // 4.2.1. Unsupported Address Family - // This document defines the following new error response code: - // 440 (Address Family not Supported): The server does not support the - // address family requested by the client. - let mut req_family = RequestedAddressFamily::default(); - match req_family.get_from(m) { - Err(err) => { - // Currently, the RequestedAddressFamily::get_from() function returns - // Err::Other only when it is an unsupported address family. 
- if let stun::Error::Other(_) = err { - let addr_family_not_supported_msg = build_msg( - m.transaction_id, - MessageType::new(METHOD_ALLOCATE, CLASS_ERROR_RESPONSE), - vec![Box::new(ErrorCodeAttribute { - code: CODE_ADDR_FAMILY_NOT_SUPPORTED, - reason: vec![], - })], - )?; - return build_and_send_err( - &self.conn, - self.src_addr, - addr_family_not_supported_msg, - Error::ErrInvalidRequestedFamilyValue, - ) - .await; - } - } - Ok(()) => { - if reservation_token_attr_result.is_ok() { - let bad_request_msg = build_msg( - m.transaction_id, - MessageType::new(METHOD_ALLOCATE, CLASS_ERROR_RESPONSE), - vec![Box::new(ErrorCodeAttribute { - code: CODE_BAD_REQUEST, - reason: vec![], - })], - )?; - - return build_and_send_err( - &self.conn, - self.src_addr, - bad_request_msg, - Error::ErrRequestWithReservationTokenAndReqAddressFamily, - ) - .await; - } - - if req_family == REQUESTED_FAMILY_IPV6 { - use_ipv4 = false; - } - } - } - - // 6. The server checks if the request contains an EVEN-PORT attribute. - // If yes, then the server checks that it can satisfy the request - // (i.e., can allocate a relayed transport address as described - // below). If the server cannot satisfy the request, then the - // server rejects the request with a 508 (Insufficient Capacity) - // error. - let mut even_port = EvenPort::default(); - if even_port.get_from(m).is_ok() { - let mut random_port = 1; - - while random_port % 2 != 0 { - random_port = match self.allocation_manager.get_random_even_port().await { - Ok(port) => port, - Err(err) => { - let insufficient_capacity_msg = build_msg( - m.transaction_id, - MessageType::new(METHOD_ALLOCATE, CLASS_ERROR_RESPONSE), - vec![Box::new(ErrorCodeAttribute { - code: CODE_INSUFFICIENT_CAPACITY, - reason: vec![], - })], - )?; - return build_and_send_err( - &self.conn, - self.src_addr, - insufficient_capacity_msg, - err, - ) - .await; - } - }; - } - - requested_port = random_port; - reservation_token = rand_seq(8); - } - - // 7. At any point, the server MAY choose to reject the request with a - // 486 (Allocation Quota Reached) error if it feels the client is - // trying to exceed some locally defined allocation quota. The - // server is free to define this allocation quota any way it wishes, - // but SHOULD define it based on the username used to authenticate - // the request, and not on the client's transport address. - - // 8. Also at any point, the server MAY choose to reject the request - // with a 300 (Try Alternate) error if it wishes to redirect the - // client to a different server. The use of this error code and - // attribute follow the specification in [RFC5389]. - let lifetime_duration = allocation_lifetime(m); - let a = match self - .allocation_manager - .create_allocation( - five_tuple, - Arc::clone(&self.conn), - requested_port, - lifetime_duration, - username, - use_ipv4, - ) - .await - { - Ok(a) => a, - Err(err) => { - let insufficient_capacity_msg = build_msg( - m.transaction_id, - MessageType::new(METHOD_ALLOCATE, CLASS_ERROR_RESPONSE), - vec![Box::new(ErrorCodeAttribute { - code: CODE_INSUFFICIENT_CAPACITY, - reason: vec![], - })], - )?; - return build_and_send_err( - &self.conn, - self.src_addr, - insufficient_capacity_msg, - err, - ) - .await; - } - }; - - // Once the allocation is created, the server replies with a success - // response. The success response contains: - // * An XOR-RELAYED-ADDRESS attribute containing the relayed transport - // address. - // * A LIFETIME attribute containing the current value of the time-to- - // expiry timer. 
- // * A RESERVATION-TOKEN attribute (if a second relayed transport - // address was reserved). - // * An XOR-MAPPED-ADDRESS attribute containing the client's IP address - // and port (from the 5-tuple). - - let (src_ip, src_port) = (self.src_addr.ip(), self.src_addr.port()); - let relay_ip = a.relay_addr.ip(); - let relay_port = a.relay_addr.port(); - - let msg = { - if !reservation_token.is_empty() { - self.allocation_manager - .create_reservation(reservation_token.clone(), relay_port) - .await; - } - - let mut response_attrs: Vec> = vec![ - Box::new(RelayedAddress { - ip: relay_ip, - port: relay_port, - }), - Box::new(Lifetime(lifetime_duration)), - Box::new(XorMappedAddress { - ip: src_ip, - port: src_port, - }), - ]; - - if !reservation_token.is_empty() { - response_attrs.push(Box::new(ReservationToken( - reservation_token.as_bytes().to_vec(), - ))); - } - - response_attrs.push(Box::new(message_integrity)); - build_msg( - m.transaction_id, - MessageType::new(METHOD_ALLOCATE, CLASS_SUCCESS_RESPONSE), - response_attrs, - )? - }; - - build_and_send(&self.conn, self.src_addr, msg).await - } - - pub(crate) async fn handle_refresh_request(&mut self, m: &Message) -> Result<()> { - log::debug!("received RefreshRequest from {}", self.src_addr); - - let (_, message_integrity) = - if let Some(mi) = self.authenticate_request(m, METHOD_REFRESH).await? { - mi - } else { - log::debug!("no MessageIntegrity"); - return Ok(()); - }; - - let lifetime_duration = allocation_lifetime(m); - let five_tuple = FiveTuple { - src_addr: self.src_addr, - dst_addr: self.conn.local_addr()?, - protocol: PROTO_UDP, - }; - - if lifetime_duration != Duration::from_secs(0) { - let a = self.allocation_manager.get_allocation(&five_tuple).await; - if let Some(a) = a { - // If a server receives a Refresh Request with a REQUESTED-ADDRESS-FAMILY - // attribute, and the attribute's value doesn't match the address - // family of the allocation, the server MUST reply with a 443 (Peer - // Address Family Mismatch) Refresh error response. 
[RFC 6156, Section 5.2] - let mut req_family = RequestedAddressFamily::default(); - if req_family.get_from(m).is_ok() - && ((req_family == REQUESTED_FAMILY_IPV6 && !a.relay_addr.is_ipv6()) - || (req_family == REQUESTED_FAMILY_IPV4 && !a.relay_addr.is_ipv4())) - { - let peer_address_family_mismatch_msg = build_msg( - m.transaction_id, - MessageType::new(METHOD_REFRESH, CLASS_ERROR_RESPONSE), - vec![Box::new(ErrorCodeAttribute { - code: CODE_PEER_ADDR_FAMILY_MISMATCH, - reason: vec![], - })], - )?; - return build_and_send_err( - &self.conn, - self.src_addr, - peer_address_family_mismatch_msg, - Error::ErrPeerAddressFamilyMismatch, - ) - .await; - } - a.refresh(lifetime_duration).await; - } else { - return Err(Error::ErrNoAllocationFound); - } - } else { - self.allocation_manager.delete_allocation(&five_tuple).await; - } - - let msg = build_msg( - m.transaction_id, - MessageType::new(METHOD_REFRESH, CLASS_SUCCESS_RESPONSE), - vec![ - Box::new(Lifetime(lifetime_duration)), - Box::new(message_integrity), - ], - )?; - - build_and_send(&self.conn, self.src_addr, msg).await - } - - pub(crate) async fn handle_create_permission_request(&mut self, m: &Message) -> Result<()> { - log::debug!("received CreatePermission from {}", self.src_addr); - - let a = self - .allocation_manager - .get_allocation(&FiveTuple { - src_addr: self.src_addr, - dst_addr: self.conn.local_addr()?, - protocol: PROTO_UDP, - }) - .await; - - if let Some(a) = a { - let (_, message_integrity) = if let Some(mi) = self - .authenticate_request(m, METHOD_CREATE_PERMISSION) - .await? - { - mi - } else { - log::debug!("no MessageIntegrity"); - return Ok(()); - }; - let mut add_count = 0; - - { - for attr in &m.attributes.0 { - if attr.typ != ATTR_XOR_PEER_ADDRESS { - continue; - } - - let mut peer_address = PeerAddress::default(); - if peer_address.get_from(m).is_err() { - add_count = 0; - break; - } - - // If an XOR-PEER-ADDRESS attribute contains an address of an address - // family different than that of the relayed transport address for the - // allocation, the server MUST generate an error response with the 443 - // (Peer Address Family Mismatch) response code. 
[RFC 6156, Section 6.2] - if (peer_address.ip.is_ipv4() && !a.relay_addr.is_ipv4()) - || (peer_address.ip.is_ipv6() && !a.relay_addr.is_ipv6()) - { - let peer_address_family_mismatch_msg = build_msg( - m.transaction_id, - MessageType::new(METHOD_CREATE_PERMISSION, CLASS_ERROR_RESPONSE), - vec![Box::new(ErrorCodeAttribute { - code: CODE_PEER_ADDR_FAMILY_MISMATCH, - reason: vec![], - })], - )?; - return build_and_send_err( - &self.conn, - self.src_addr, - peer_address_family_mismatch_msg, - Error::ErrPeerAddressFamilyMismatch, - ) - .await; - } - - log::debug!( - "adding permission for {}", - format!("{}:{}", peer_address.ip, peer_address.port) - ); - - a.add_permission(Permission::new(SocketAddr::new( - peer_address.ip, - peer_address.port, - ))) - .await; - add_count += 1; - } - } - - let mut resp_class = CLASS_SUCCESS_RESPONSE; - if add_count == 0 { - resp_class = CLASS_ERROR_RESPONSE; - } - - let msg = build_msg( - m.transaction_id, - MessageType::new(METHOD_CREATE_PERMISSION, resp_class), - vec![Box::new(message_integrity)], - )?; - - build_and_send(&self.conn, self.src_addr, msg).await - } else { - Err(Error::ErrNoAllocationFound) - } - } - - pub(crate) async fn handle_send_indication(&mut self, m: &Message) -> Result<()> { - log::debug!("received SendIndication from {}", self.src_addr); - - let a = self - .allocation_manager - .get_allocation(&FiveTuple { - src_addr: self.src_addr, - dst_addr: self.conn.local_addr()?, - protocol: PROTO_UDP, - }) - .await; - - if let Some(a) = a { - let mut data_attr = Data::default(); - data_attr.get_from(m)?; - - let mut peer_address = PeerAddress::default(); - peer_address.get_from(m)?; - - let msg_dst = SocketAddr::new(peer_address.ip, peer_address.port); - - let has_perm = a.has_permission(&msg_dst).await; - if !has_perm { - return Err(Error::ErrNoPermission); - } - - let l = a.relay_socket.send_to(&data_attr.0, msg_dst).await?; - if l != data_attr.0.len() { - Err(Error::ErrShortWrite) - } else { - #[cfg(feature = "metrics")] - a.relayed_bytes - .fetch_add(data_attr.0.len(), Ordering::AcqRel); - - Ok(()) - } - } else { - Err(Error::ErrNoAllocationFound) - } - } - - pub(crate) async fn handle_channel_bind_request(&mut self, m: &Message) -> Result<()> { - log::debug!("received ChannelBindRequest from {}", self.src_addr); - - let a = self - .allocation_manager - .get_allocation(&FiveTuple { - src_addr: self.src_addr, - dst_addr: self.conn.local_addr()?, - protocol: PROTO_UDP, - }) - .await; - - if let Some(a) = a { - let bad_request_msg = build_msg( - m.transaction_id, - MessageType::new(METHOD_CHANNEL_BIND, CLASS_ERROR_RESPONSE), - vec![Box::new(ErrorCodeAttribute { - code: CODE_BAD_REQUEST, - reason: vec![], - })], - )?; - - let (_, message_integrity) = - if let Some(mi) = self.authenticate_request(m, METHOD_CHANNEL_BIND).await? 
{ - mi - } else { - log::debug!("no MessageIntegrity"); - return Ok(()); - }; - let mut channel = ChannelNumber::default(); - if let Err(err) = channel.get_from(m) { - return build_and_send_err(&self.conn, self.src_addr, bad_request_msg, err.into()) - .await; - } - - let mut peer_addr = PeerAddress::default(); - match peer_addr.get_from(m) { - Err(err) => { - return build_and_send_err( - &self.conn, - self.src_addr, - bad_request_msg, - err.into(), - ) - .await; - } - _ => { - // If the XOR-PEER-ADDRESS attribute contains an address of an address - // family different than that of the relayed transport address for the - // allocation, the server MUST generate an error response with the 443 - // (Peer Address Family Mismatch) response code. [RFC 6156, Section 7.2] - if (peer_addr.ip.is_ipv4() && !a.relay_addr.is_ipv4()) - || (peer_addr.ip.is_ipv6() && !a.relay_addr.is_ipv6()) - { - let peer_address_family_mismatch_msg = build_msg( - m.transaction_id, - MessageType::new(METHOD_CHANNEL_BIND, CLASS_ERROR_RESPONSE), - vec![Box::new(ErrorCodeAttribute { - code: CODE_PEER_ADDR_FAMILY_MISMATCH, - reason: vec![], - })], - )?; - return build_and_send_err( - &self.conn, - self.src_addr, - peer_address_family_mismatch_msg, - Error::ErrPeerAddressFamilyMismatch, - ) - .await; - } - } - } - - log::debug!( - "binding channel {} to {}", - channel, - format!("{}:{}", peer_addr.ip, peer_addr.port) - ); - - let result = { - a.add_channel_bind( - ChannelBind::new(channel, SocketAddr::new(peer_addr.ip, peer_addr.port)), - self.channel_bind_timeout, - ) - .await - }; - if let Err(err) = result { - return build_and_send_err(&self.conn, self.src_addr, bad_request_msg, err).await; - } - - let msg = build_msg( - m.transaction_id, - MessageType::new(METHOD_CHANNEL_BIND, CLASS_SUCCESS_RESPONSE), - vec![Box::new(message_integrity)], - )?; - build_and_send(&self.conn, self.src_addr, msg).await - } else { - Err(Error::ErrNoAllocationFound) - } - } - - pub(crate) async fn handle_channel_data(&mut self, c: &ChannelData) -> Result<()> { - log::debug!("received ChannelData from {}", self.src_addr); - - let a = self - .allocation_manager - .get_allocation(&FiveTuple { - src_addr: self.src_addr, - dst_addr: self.conn.local_addr()?, - protocol: PROTO_UDP, - }) - .await; - - if let Some(a) = a { - let channel = a.get_channel_addr(&c.number).await; - if let Some(peer) = channel { - let l = a.relay_socket.send_to(&c.data, peer).await?; - if l != c.data.len() { - Err(Error::ErrShortWrite) - } else { - #[cfg(feature = "metrics")] - a.relayed_bytes.fetch_add(c.data.len(), Ordering::AcqRel); - - Ok(()) - } - } else { - Err(Error::ErrNoSuchChannelBind) - } - } else { - Err(Error::ErrNoAllocationFound) - } - } -} - -pub(crate) fn rand_seq(n: usize) -> String { - let letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ".as_bytes(); - let mut buf = vec![0u8; n]; - for b in &mut buf { - *b = letters[rand::random::() % letters.len()]; - } - if let Ok(s) = String::from_utf8(buf) { - s - } else { - String::new() - } -} - -pub(crate) fn build_nonce() -> Result { - /* #nosec */ - let mut s = String::new(); - s.push_str( - format!( - "{}", - SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH)? 
- .as_nanos() - ) - .as_str(), - ); - s.push_str(format!("{}", rand::random::()).as_str()); - - let mut h = Md5::new(); - h.update(s.as_bytes()); - Ok(format!("{:x}", h.finalize())) -} - -pub(crate) async fn build_and_send( - conn: &Arc, - dst: SocketAddr, - msg: Message, -) -> Result<()> { - let _ = conn.send_to(&msg.raw, dst).await?; - Ok(()) -} - -/// Send a STUN packet and return the original error to the caller -pub(crate) async fn build_and_send_err( - conn: &Arc, - dst: SocketAddr, - msg: Message, - err: Error, -) -> Result<()> { - build_and_send(conn, dst, msg).await?; - - Err(err) -} - -pub(crate) fn build_msg( - transaction_id: TransactionId, - msg_type: MessageType, - mut additional: Vec>, -) -> Result { - let mut attrs: Vec> = vec![ - Box::new(Message { - transaction_id, - ..Default::default() - }), - Box::new(msg_type), - ]; - - attrs.append(&mut additional); - - let mut msg = Message::new(); - msg.build(&attrs)?; - Ok(msg) -} - -pub(crate) fn allocation_lifetime(m: &Message) -> Duration { - let mut lifetime_duration = DEFAULT_LIFETIME; - - let mut lifetime = Lifetime::default(); - if lifetime.get_from(m).is_ok() && lifetime.0 < MAXIMUM_ALLOCATION_LIFETIME { - lifetime_duration = lifetime.0; - } - - lifetime_duration -} diff --git a/turn/src/server/request/request_test.rs b/turn/src/server/request/request_test.rs deleted file mode 100644 index cfe012055..000000000 --- a/turn/src/server/request/request_test.rs +++ /dev/null @@ -1,119 +0,0 @@ -use std::net::IpAddr; -use std::str::FromStr; - -use tokio::net::UdpSocket; -use tokio::time::{Duration, Instant}; -use util::vnet::net::*; - -use super::*; -use crate::relay::relay_none::*; - -const STATIC_KEY: &str = "ABC"; - -#[tokio::test] -async fn test_allocation_lifetime_parsing() -> Result<()> { - let lifetime = Lifetime(Duration::from_secs(5)); - - let mut m = Message::new(); - let lifetime_duration = allocation_lifetime(&m); - - assert_eq!( - lifetime_duration, DEFAULT_LIFETIME, - "Allocation lifetime should be default time duration" - ); - - lifetime.add_to(&mut m)?; - - let lifetime_duration = allocation_lifetime(&m); - assert_eq!( - lifetime_duration, lifetime.0, - "Expect lifetime_duration is {lifetime}, but {lifetime_duration:?}" - ); - - Ok(()) -} - -#[tokio::test] -async fn test_allocation_lifetime_overflow() -> Result<()> { - let lifetime = Lifetime(MAXIMUM_ALLOCATION_LIFETIME * 2); - - let mut m2 = Message::new(); - lifetime.add_to(&mut m2)?; - - let lifetime_duration = allocation_lifetime(&m2); - assert_eq!( - lifetime_duration, DEFAULT_LIFETIME, - "Expect lifetime_duration is {DEFAULT_LIFETIME:?}, but {lifetime_duration:?}" - ); - - Ok(()) -} - -struct TestAuthHandler; -impl AuthHandler for TestAuthHandler { - fn auth_handle(&self, _username: &str, _realm: &str, _src_addr: SocketAddr) -> Result> { - Ok(STATIC_KEY.as_bytes().to_vec()) - } -} - -#[tokio::test] -async fn test_allocation_lifetime_deletion_zero_lifetime() -> Result<()> { - //env_logger::init(); - - let l = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); - - let allocation_manager = Arc::new(Manager::new(ManagerConfig { - relay_addr_generator: Box::new(RelayAddressGeneratorNone { - address: "0.0.0.0".to_owned(), - net: Arc::new(Net::new(None)), - }), - alloc_close_notify: None, - })); - - let socket = SocketAddr::new(IpAddr::from_str("127.0.0.1")?, 5000); - - let mut r = Request::new(l, socket, allocation_manager, Arc::new(TestAuthHandler {})); - - { - let mut nonces = r.nonces.lock().await; - nonces.insert(STATIC_KEY.to_owned(), Instant::now()); - } - - let 
five_tuple = FiveTuple { - src_addr: r.src_addr, - dst_addr: r.conn.local_addr()?, - protocol: PROTO_UDP, - }; - - r.allocation_manager - .create_allocation( - five_tuple, - Arc::clone(&r.conn), - 0, - Duration::from_secs(3600), - TextAttribute::new(ATTR_USERNAME, "user".into()), - true, - ) - .await?; - assert!(r - .allocation_manager - .get_allocation(&five_tuple) - .await - .is_some()); - - let mut m = Message::new(); - Lifetime::default().add_to(&mut m)?; - MessageIntegrity(STATIC_KEY.as_bytes().to_vec()).add_to(&mut m)?; - Nonce::new(ATTR_NONCE, STATIC_KEY.to_owned()).add_to(&mut m)?; - Realm::new(ATTR_REALM, STATIC_KEY.to_owned()).add_to(&mut m)?; - Username::new(ATTR_USERNAME, STATIC_KEY.to_owned()).add_to(&mut m)?; - - r.handle_refresh_request(&m).await?; - assert!(r - .allocation_manager - .get_allocation(&five_tuple) - .await - .is_none()); - - Ok(()) -} diff --git a/turn/src/server/server_test.rs b/turn/src/server/server_test.rs deleted file mode 100644 index 1505a12da..000000000 --- a/turn/src/server/server_test.rs +++ /dev/null @@ -1,338 +0,0 @@ -use std::net::{IpAddr, Ipv4Addr, SocketAddr}; -use std::str::FromStr; - -use tokio::net::UdpSocket; -use tokio::sync::mpsc; -use util::vnet::router::Nic; -use util::vnet::*; - -use super::config::*; -use super::*; -use crate::auth::generate_auth_key; -use crate::client::*; -use crate::error::*; -use crate::relay::relay_none::RelayAddressGeneratorNone; -use crate::relay::relay_static::*; - -struct TestAuthHandler { - cred_map: HashMap>, -} - -impl TestAuthHandler { - fn new() -> Self { - let mut cred_map = HashMap::new(); - cred_map.insert( - "user".to_owned(), - generate_auth_key("user", "webrtc.rs", "pass"), - ); - - TestAuthHandler { cred_map } - } -} - -impl AuthHandler for TestAuthHandler { - fn auth_handle(&self, username: &str, _realm: &str, _src_addr: SocketAddr) -> Result> { - if let Some(pw) = self.cred_map.get(username) { - Ok(pw.to_vec()) - } else { - Err(Error::ErrFakeErr) - } - } -} - -#[tokio::test] -async fn test_server_simple() -> Result<()> { - // here, it should use static port, like "0.0.0.0:3478", - // but, due to different test environment, let's fake it by using "0.0.0.0:0" - // to auto assign a "static" port - let conn = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); - let server_port = conn.local_addr()?.port(); - - let server = Server::new(ServerConfig { - conn_configs: vec![ConnConfig { - conn, - relay_addr_generator: Box::new(RelayAddressGeneratorStatic { - relay_address: IpAddr::from_str("127.0.0.1")?, - address: "0.0.0.0".to_owned(), - net: Arc::new(net::Net::new(None)), - }), - }], - realm: "webrtc.rs".to_owned(), - auth_handler: Arc::new(TestAuthHandler::new()), - channel_bind_timeout: Duration::from_secs(0), - alloc_close_notify: None, - }) - .await?; - - assert_eq!( - DEFAULT_LIFETIME, server.channel_bind_timeout, - "should match" - ); - - let conn = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); - - let client = Client::new(ClientConfig { - stun_serv_addr: String::new(), - turn_serv_addr: String::new(), - username: String::new(), - password: String::new(), - realm: String::new(), - software: String::new(), - rto_in_ms: 0, - conn, - vnet: None, - }) - .await?; - - client.listen().await?; - - client - .send_binding_request_to(format!("127.0.0.1:{server_port}").as_str()) - .await?; - - client.close().await?; - server.close().await?; - - Ok(()) -} - -struct VNet { - wan: Arc>, - net0: Arc, - net1: Arc, - netl0: Arc, - server: Server, -} - -async fn build_vnet() -> Result { - // WAN - let wan = 
Arc::new(Mutex::new(router::Router::new(router::RouterConfig { - cidr: "0.0.0.0/0".to_owned(), - ..Default::default() - })?)); - - let net0 = Arc::new(net::Net::new(Some(net::NetConfig { - static_ip: "1.2.3.4".to_owned(), // will be assigned to eth0 - ..Default::default() - }))); - - let net1 = Arc::new(net::Net::new(Some(net::NetConfig { - static_ip: "1.2.3.5".to_owned(), // will be assigned to eth0 - ..Default::default() - }))); - - { - let nic0 = net0.get_nic()?; - let nic1 = net1.get_nic()?; - - { - let mut w = wan.lock().await; - w.add_net(Arc::clone(&nic0)).await?; - w.add_net(Arc::clone(&nic1)).await?; - } - - let n0 = nic0.lock().await; - n0.set_router(Arc::clone(&wan)).await?; - - let n1 = nic1.lock().await; - n1.set_router(Arc::clone(&wan)).await?; - } - - // LAN - let lan = Arc::new(Mutex::new(router::Router::new(router::RouterConfig { - static_ip: "5.6.7.8".to_owned(), // this router's external IP on eth0 - cidr: "192.168.0.0/24".to_owned(), - nat_type: Some(nat::NatType { - mapping_behavior: nat::EndpointDependencyType::EndpointIndependent, - filtering_behavior: nat::EndpointDependencyType::EndpointIndependent, - ..Default::default() - }), - ..Default::default() - })?)); - - let netl0 = Arc::new(net::Net::new(Some(net::NetConfig::default()))); - - { - let nic = netl0.get_nic()?; - - { - let mut l = lan.lock().await; - l.add_net(Arc::clone(&nic)).await?; - } - - let n = nic.lock().await; - n.set_router(Arc::clone(&lan)).await?; - } - - { - { - let mut w = wan.lock().await; - w.add_router(Arc::clone(&lan)).await?; - } - - { - let l = lan.lock().await; - l.set_router(Arc::clone(&wan)).await?; - } - } - - { - let mut w = wan.lock().await; - w.start().await?; - } - - // start server... - let conn = net0.bind(SocketAddr::from_str("0.0.0.0:3478")?).await?; - - let server = Server::new(ServerConfig { - conn_configs: vec![ConnConfig { - conn, - relay_addr_generator: Box::new(RelayAddressGeneratorNone { - address: "1.2.3.4".to_owned(), - net: Arc::clone(&net0), - }), - }], - realm: "webrtc.rs".to_owned(), - auth_handler: Arc::new(TestAuthHandler::new()), - channel_bind_timeout: Duration::from_secs(0), - alloc_close_notify: None, - }) - .await?; - - // register host names - { - let mut w = wan.lock().await; - w.add_host("stun.webrtc.rs".to_owned(), "1.2.3.4".to_owned()) - .await?; - w.add_host("turn.webrtc.rs".to_owned(), "1.2.3.4".to_owned()) - .await?; - w.add_host("echo.webrtc.rs".to_owned(), "1.2.3.5".to_owned()) - .await?; - } - - Ok(VNet { - wan, - net0, - net1, - netl0, - server, - }) -} - -#[tokio::test] -async fn test_server_vnet_send_binding_request() -> Result<()> { - let v = build_vnet().await?; - - let lconn = v.netl0.bind(SocketAddr::from_str("0.0.0.0:0")?).await?; - log::debug!("creating a client."); - let client = Client::new(ClientConfig { - stun_serv_addr: "1.2.3.4:3478".to_owned(), - turn_serv_addr: String::new(), - username: String::new(), - password: String::new(), - realm: String::new(), - software: String::new(), - rto_in_ms: 0, - conn: lconn, - vnet: Some(Arc::clone(&v.netl0)), - }) - .await?; - - client.listen().await?; - - log::debug!("sending a binding request."); - let refl_addr = client.send_binding_request().await?; - log::debug!("mapped-address: {}", refl_addr); - - // The mapped-address should have IP address that was assigned - // to the LAN router. 
- assert_eq!( - refl_addr.ip().to_string(), - Ipv4Addr::new(5, 6, 7, 8).to_string(), - "should match", - ); - - client.close().await?; - Ok(()) -} - -#[tokio::test] -async fn test_server_vnet_echo_via_relay() -> Result<()> { - let v = build_vnet().await?; - - let lconn = v.netl0.bind(SocketAddr::from_str("0.0.0.0:0")?).await?; - log::debug!("creating a client."); - let client = Client::new(ClientConfig { - stun_serv_addr: "stun.webrtc.rs:3478".to_owned(), - turn_serv_addr: "turn.webrtc.rs:3478".to_owned(), - username: "user".to_owned(), - password: "pass".to_owned(), - realm: String::new(), - software: String::new(), - rto_in_ms: 0, - conn: lconn, - vnet: Some(Arc::clone(&v.netl0)), - }) - .await?; - - client.listen().await?; - - log::debug!("sending a binding request."); - let conn = client.allocate().await?; - let local_addr = conn.local_addr()?; - - log::debug!("laddr: {}", conn.local_addr()?); - - let echo_conn = v.net1.bind(SocketAddr::from_str("1.2.3.5:5678")?).await?; - let echo_addr = echo_conn.local_addr()?; - - let (done_tx, mut done_rx) = mpsc::channel::<()>(1); - - tokio::spawn(async move { - let mut buf = vec![0u8; 1500]; - let mut n; - let mut from; - loop { - tokio::select! { - _ = done_rx.recv() => break, - result = echo_conn.recv_from(&mut buf) => { - match result { - Ok((s, addr)) => { - n = s; - from = addr; - } - Err(_) => break, - } - } - } - - // verify the message was received from the relay address - assert_eq!(local_addr.to_string(), from.to_string(), "should match"); - assert_eq!(b"Hello", &buf[..n], "should match"); - - // echo the data - let _ = echo_conn.send_to(&buf[..n], from).await; - } - }); - - let mut buf = vec![0u8; 1500]; - - for _ in 0..10 { - log::debug!("sending \"Hello\".."); - conn.send_to(b"Hello", echo_addr).await?; - - let (_, from) = conn.recv_from(&mut buf).await?; - - // verify the message was received from the relay address - assert_eq!(echo_addr.to_string(), from.to_string(), "should match"); - - tokio::time::sleep(Duration::from_millis(100)).await; - } - - tokio::time::sleep(Duration::from_millis(100)).await; - - client.close().await?; - drop(done_tx); - - Ok(()) -} diff --git a/util/.gitignore b/util/.gitignore deleted file mode 100644 index 81561ed32..000000000 --- a/util/.gitignore +++ /dev/null @@ -1,11 +0,0 @@ -# Generated by Cargo -# will have compiled files and executables -/target/ -/.idea/ - -# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries -# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html -Cargo.lock - -# These are backup files generated by rustfmt -**/*.rs.bk diff --git a/util/CHANGELOG.md b/util/CHANGELOG.md deleted file mode 100644 index c7946330b..000000000 --- a/util/CHANGELOG.md +++ /dev/null @@ -1,24 +0,0 @@ -# webrtc-util changelog - -## v0.7.0 - -### Breaking changes - -* Make functions non-async [#338](https://github.com/webrtc-rs/webrtc/pull/338): - - `Bridge`: - - `drop_next_nwrites`; - - `reorder_next_nwrites`. - - `Conn`: - - `local_addr`; - - `remote_addr`. - - -## v0.6.0 - -* Increase min version of `log` dependency to `0.4.16`. [#250 Fix log at ^0.4.16 to make tests compile](https://github.com/webrtc-rs/webrtc/pull/250) by [@k0nserv](https://github.com/k0nserv). -* Increased minimum support rust version to `1.60.0`. - -## Prior to 0.6.0 - -Before 0.6.0 there was no changelog, previous changes are sometimes, but not always, available in the [GitHub Releases](https://github.com/webrtc-rs/util/releases). 
- diff --git a/util/Cargo.toml b/util/Cargo.toml deleted file mode 100644 index a902e8d9c..000000000 --- a/util/Cargo.toml +++ /dev/null @@ -1,66 +0,0 @@ -[package] -name = "webrtc-util" -version = "0.9.0" -authors = ["Rain Liu "] -edition = "2021" -description = "Utilities for WebRTC.rs stack" -license = "MIT OR Apache-2.0" -documentation = "https://docs.rs/webrtc-util" -homepage = "https://webrtc.rs" -repository = "https://github.com/webrtc-rs/webrtc/tree/master/util" - -[features] -default = ["buffer", "conn", "ifaces", "vnet", "marshal", "sync"] -buffer = [] -conn = ["buffer", "sync"] -ifaces = [] -vnet = ["ifaces"] -marshal = [] -sync = [] - -[dependencies] -tokio = { version = "1.32.0", features = [ - "fs", - "io-util", - "io-std", - "macros", - "net", - "parking_lot", - "rt", - "rt-multi-thread", - "sync", - "time", -] } -lazy_static = "1" -async-trait = "0.1" -ipnet = "2.6.0" -log = "0.4" -rand = "0.8" -bytes = "1" -thiserror = "1" -portable-atomic = "1.6" - -[target.'cfg(not(windows))'.dependencies] -nix = "0.26.2" -libc = "0.2.126" - -[target.'cfg(windows)'.dependencies] -bitflags = "1.3" -winapi = { version = "0.3.9", features = [ - "basetsd", - "guiddef", - "ws2def", - "winerror", - "ws2ipdef", -] } - -[dev-dependencies] -tokio-test = "0.4" -env_logger = "0.10" -chrono = "0.4.28" -criterion = { version = "0.5", features = ["async_futures"] } -async-global-executor = "2" - -[[bench]] -name = "bench" -harness = false diff --git a/util/LICENSE-APACHE b/util/LICENSE-APACHE deleted file mode 100644 index 16fe87b06..000000000 --- a/util/LICENSE-APACHE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - -1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). 
- - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - -2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - -3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - -4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - -5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - -6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - -8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - -END OF TERMS AND CONDITIONS - -APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - -Copyright [yyyy] [name of copyright owner] - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. diff --git a/util/LICENSE-MIT b/util/LICENSE-MIT deleted file mode 100644 index e11d93bef..000000000 --- a/util/LICENSE-MIT +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2021 WebRTC.rs - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/util/README.md b/util/README.md deleted file mode 100644 index 03a6bd7b2..000000000 --- a/util/README.md +++ /dev/null @@ -1,30 +0,0 @@ -

- WebRTC.rs
-
- [badge markup stripped in extraction; surviving labels: License: MIT/Apache 2.0, Discord]
-
- Utilities for WebRTC.rs stack. Rewrite Pion Util/Transport in Rust

diff --git a/util/benches/bench.rs b/util/benches/bench.rs deleted file mode 100644 index ff5eb5e71..000000000 --- a/util/benches/bench.rs +++ /dev/null @@ -1,33 +0,0 @@ -use criterion::async_executor::FuturesExecutor; -use criterion::{criterion_group, criterion_main, Criterion}; -use webrtc_util::Buffer; - -async fn buffer_write_then_read(times: u32) { - let buffer = Buffer::new(0, 0); - let mut packet: Vec = vec![0; 4]; - for _ in 0..times { - buffer.write(&[0, 1]).await.unwrap(); - buffer.read(&mut packet, None).await.unwrap(); - } -} - -fn benchmark_buffer(c: &mut Criterion) { - /////////////////////////////////////////////////////////////////////////////////////////////// - c.bench_function("Benchmark Buffer WriteThenRead 1", |b| { - b.to_async(FuturesExecutor) - .iter(|| buffer_write_then_read(1)); - }); - - c.bench_function("Benchmark Buffer WriteThenRead 10", |b| { - b.to_async(FuturesExecutor) - .iter(|| buffer_write_then_read(10)); - }); - - c.bench_function("Benchmark Buffer WriteThenRead 100", |b| { - b.to_async(FuturesExecutor) - .iter(|| buffer_write_then_read(100)); - }); -} - -criterion_group!(benches, benchmark_buffer); -criterion_main!(benches); diff --git a/util/codecov.yml b/util/codecov.yml deleted file mode 100644 index 2961b8e43..000000000 --- a/util/codecov.yml +++ /dev/null @@ -1,23 +0,0 @@ -codecov: - require_ci_to_pass: yes - max_report_age: off - token: 5dbbc458-896e-486d-af8e-96fc2fbbbcff - -coverage: - precision: 2 - round: down - range: 50..90 - status: - project: - default: - enabled: no - threshold: 0.2 - if_not_found: success - patch: - default: - enabled: no - if_not_found: success - changes: - default: - enabled: no - if_not_found: success diff --git a/util/doc/webrtc.rs.png b/util/doc/webrtc.rs.png deleted file mode 100644 index 7bf0dda2a..000000000 Binary files a/util/doc/webrtc.rs.png and /dev/null differ diff --git a/util/examples/display-interfaces.rs b/util/examples/display-interfaces.rs deleted file mode 100644 index f04badbb7..000000000 --- a/util/examples/display-interfaces.rs +++ /dev/null @@ -1,11 +0,0 @@ -use std::error::Error; - -use webrtc_util::ifaces::ifaces; - -fn main() -> Result<(), Box> { - let interfaces = ifaces()?; - for (index, interface) in interfaces.iter().enumerate() { - println!("{index} {interface:?}"); - } - Ok(()) -} diff --git a/util/src/buffer/buffer_test.rs b/util/src/buffer/buffer_test.rs deleted file mode 100644 index 375098824..000000000 --- a/util/src/buffer/buffer_test.rs +++ /dev/null @@ -1,358 +0,0 @@ -use tokio::sync::mpsc; -use tokio::time::{sleep, Duration}; -use tokio_test::assert_ok; - -use super::*; -use crate::error::Error; - -#[tokio::test] -async fn test_buffer() { - let buffer = Buffer::new(0, 0); - let mut packet: Vec = vec![0; 4]; - - // Write once - let n = assert_ok!(buffer.write(&[0, 1]).await); - assert_eq!(n, 2, "n must be 2"); - - // Read once - let n = assert_ok!(buffer.read(&mut packet, None).await); - assert_eq!(n, 2, "n must be 2"); - assert_eq!(&packet[..n], &[0, 1]); - - // Read deadline - let result = buffer.read(&mut packet, Some(Duration::new(0, 1))).await; - assert!(result.is_err()); - assert_eq!(result.unwrap_err(), Error::ErrTimeout); - - // Write twice - let n = assert_ok!(buffer.write(&[2, 3, 4]).await); - assert_eq!(n, 3, "n must be 3"); - - let n = assert_ok!(buffer.write(&[5, 6, 7]).await); - assert_eq!(n, 3, "n must be 3"); - - // Read twice - let n = assert_ok!(buffer.read(&mut packet, None).await); - assert_eq!(n, 3, "n must be 3"); - assert_eq!(&packet[..n], &[2, 3, 4]); - - 
let n = assert_ok!(buffer.read(&mut packet, None).await); - assert_eq!(n, 3, "n must be 3"); - assert_eq!(&packet[..n], &[5, 6, 7]); - - // Write once prior to close. - let n = assert_ok!(buffer.write(&[3]).await); - assert_eq!(n, 1, "n must be 1"); - - // Close - buffer.close().await; - - // Future writes will error - let result = buffer.write(&[4]).await; - assert!(result.is_err()); - - // But we can read the remaining data. - let n = assert_ok!(buffer.read(&mut packet, None).await); - assert_eq!(n, 1, "n must be 1"); - assert_eq!(&packet[..n], &[3]); - - // Until EOF - let result = buffer.read(&mut packet, None).await; - assert!(result.is_err()); - assert_eq!(Error::ErrBufferClosed, result.unwrap_err()); -} - -async fn test_wraparound(grow: bool) { - let buffer = Buffer::new(0, 0); - { - let mut b = buffer.buffer.lock().await; - let result = b.grow(); - assert!(result.is_ok()); - - b.head = b.data.len() - 13; - b.tail = b.head; - } - - let p1 = vec![1, 2, 3]; - let p2 = vec![4, 5, 6]; - let p3 = vec![7, 8, 9]; - let p4 = vec![10, 11, 12]; - - assert_ok!(buffer.write(&p1).await); - assert_ok!(buffer.write(&p2).await); - assert_ok!(buffer.write(&p3).await); - - let mut p = vec![0; 10]; - - let n = assert_ok!(buffer.read(&mut p, None).await); - assert_eq!(&p1[..], &p[..n]); - - if grow { - let mut b = buffer.buffer.lock().await; - let result = b.grow(); - assert!(result.is_ok()); - } - - let n = assert_ok!(buffer.read(&mut p, None).await); - assert_eq!(&p2[..], &p[..n]); - - assert_ok!(buffer.write(&p4).await); - - let n = assert_ok!(buffer.read(&mut p, None).await); - assert_eq!(&p3[..], &p[..n]); - let n = assert_ok!(buffer.read(&mut p, None).await); - assert_eq!(&p4[..], &p[..n]); - - { - let b = buffer.buffer.lock().await; - if !grow { - assert_eq!(b.data.len(), MIN_SIZE); - } else { - assert_eq!(b.data.len(), 2 * MIN_SIZE); - } - } -} - -#[tokio::test] -async fn test_buffer_wraparound() { - test_wraparound(false).await; -} - -#[tokio::test] -async fn test_buffer_wraparound_grow() { - test_wraparound(true).await; -} - -#[tokio::test] -async fn test_buffer_async() { - let buffer = Buffer::new(0, 0); - - let (done_tx, mut done_rx) = mpsc::channel::<()>(1); - - let buffer2 = buffer.clone(); - tokio::spawn(async move { - let mut packet: Vec = vec![0; 4]; - - let n = assert_ok!(buffer2.read(&mut packet, None).await); - assert_eq!(n, 2, "n must be 2"); - assert_eq!(&packet[..n], &[0, 1]); - - let result = buffer2.read(&mut packet, None).await; - assert!(result.is_err()); - assert_eq!(result.unwrap_err(), Error::ErrBufferClosed); - - drop(done_tx); - }); - - // Wait for the reader to start reading. - sleep(Duration::from_micros(1)).await; - - // Write once - let n = assert_ok!(buffer.write(&[0, 1]).await); - assert_eq!(n, 2, "n must be 2"); - - // Wait for the reader to start reading again. - sleep(Duration::from_micros(1)).await; - - // Close will unblock the reader. 
- buffer.close().await; - - done_rx.recv().await; -} - -#[tokio::test] -async fn test_buffer_limit_count() { - let buffer = Buffer::new(2, 0); - - assert_eq!(buffer.count().await, 0); - - // Write twice - let n = assert_ok!(buffer.write(&[0, 1]).await); - assert_eq!(n, 2, "n must be 2"); - assert_eq!(buffer.count().await, 1); - - let n = assert_ok!(buffer.write(&[2, 3]).await); - assert_eq!(n, 2, "n must be 2"); - assert_eq!(buffer.count().await, 2); - - // Over capacity - let result = buffer.write(&[4, 5]).await; - assert!(result.is_err()); - if let Err(err) = result { - assert_eq!(err, Error::ErrBufferFull); - } - assert_eq!(buffer.count().await, 2); - - // Read once - let mut packet: Vec = vec![0; 4]; - let n = assert_ok!(buffer.read(&mut packet, None).await); - assert_eq!(n, 2, "n must be 2"); - assert_eq!(&packet[..n], &[0, 1]); - assert_eq!(buffer.count().await, 1); - - // Write once - let n = assert_ok!(buffer.write(&[6, 7]).await); - assert_eq!(n, 2, "n must be 2"); - assert_eq!(buffer.count().await, 2); - - // Over capacity - let result = buffer.write(&[8, 9]).await; - assert!(result.is_err()); - if let Err(err) = result { - assert_eq!(Error::ErrBufferFull, err); - } - assert_eq!(buffer.count().await, 2); - - // Read twice - let n = assert_ok!(buffer.read(&mut packet, None).await); - assert_eq!(n, 2, "n must be 2"); - assert_eq!(&packet[..n], &[2, 3]); - assert_eq!(buffer.count().await, 1); - - let n = assert_ok!(buffer.read(&mut packet, None).await); - assert_eq!(n, 2, "n must be 2"); - assert_eq!(&packet[..n], &[6, 7]); - assert_eq!(buffer.count().await, 0); - - // Nothing left. - buffer.close().await; -} - -#[tokio::test] -async fn test_buffer_limit_size() { - let buffer = Buffer::new(0, 11); - - assert_eq!(buffer.size().await, 0); - - // Write twice - let n = assert_ok!(buffer.write(&[0, 1]).await); - assert_eq!(n, 2, "n must be 2"); - assert_eq!(buffer.size().await, 4); - - let n = assert_ok!(buffer.write(&[2, 3]).await); - assert_eq!(n, 2, "n must be 2"); - assert_eq!(buffer.size().await, 8); - - // Over capacity - let result = buffer.write(&[4, 5]).await; - assert!(result.is_err()); - if let Err(err) = result { - assert_eq!(Error::ErrBufferFull, err); - } - assert_eq!(buffer.size().await, 8); - - // Cheeky write at exact size. - let n = assert_ok!(buffer.write(&[6]).await); - assert_eq!(n, 1, "n must be 1"); - assert_eq!(buffer.size().await, 11); - - // Read once - let mut packet: Vec = vec![0; 4]; - let n = assert_ok!(buffer.read(&mut packet, None).await); - assert_eq!(n, 2, "n must be 2"); - assert_eq!(&packet[..n], &[0, 1]); - assert_eq!(buffer.size().await, 7); - - // Write once - let n = assert_ok!(buffer.write(&[7, 8]).await); - assert_eq!(n, 2, "n must be 2"); - assert_eq!(buffer.size().await, 11); - - // Over capacity - let result = buffer.write(&[9, 10]).await; - assert!(result.is_err()); - if let Err(err) = result { - assert_eq!(Error::ErrBufferFull, err); - } - assert_eq!(buffer.size().await, 11); - - // Read everything - let n = assert_ok!(buffer.read(&mut packet, None).await); - assert_eq!(n, 2, "n must be 2"); - assert_eq!(&packet[..n], &[2, 3]); - assert_eq!(buffer.size().await, 7); - - let n = assert_ok!(buffer.read(&mut packet, None).await); - assert_eq!(n, 1, "n must be 1"); - assert_eq!(&packet[..n], &[6]); - assert_eq!(buffer.size().await, 4); - - let n = assert_ok!(buffer.read(&mut packet, None).await); - assert_eq!(n, 2, "n must be 2"); - assert_eq!(&packet[..n], &[7, 8]); - assert_eq!(buffer.size().await, 0); - - // Nothing left. 
- buffer.close().await; -} - -#[tokio::test] -async fn test_buffer_limit_sizes() { - let sizes = vec![ - 128 * 1024, - 1024 * 1024, - 8 * 1024 * 1024, - 0, // default - ]; - const HEADER_SIZE: usize = 2; - const PACKET_SIZE: usize = 0x8000; - - for mut size in sizes { - let mut name = "default".to_owned(); - if size > 0 { - name = format!("{}kbytes", size / 1024); - } - - let buffer = Buffer::new(0, 0); - if size == 0 { - size = MAX_SIZE; - } else { - buffer.set_limit_size(size + HEADER_SIZE).await; - } - - //assert.NoError(buffer.SetReadDeadline(now.Add(5 * time.Second))) // Set deadline to avoid test deadlock - - let n_packets = size / (PACKET_SIZE + HEADER_SIZE); - let pkt = vec![0; PACKET_SIZE]; - for _ in 0..n_packets { - assert_ok!(buffer.write(&pkt).await); - } - - // Next write is expected to be errored. - let result = buffer.write(&pkt).await; - assert!(result.is_err(), "{}", name); - assert_eq!(result.unwrap_err(), Error::ErrBufferFull, "{name}"); - - let mut packet = vec![0; size]; - for _ in 0..n_packets { - let n = assert_ok!(buffer.read(&mut packet, Some(Duration::new(5, 0))).await); - assert_eq!(n, PACKET_SIZE, "{name}"); - } - } -} - -#[tokio::test] -async fn test_buffer_misc() { - let buffer = Buffer::new(0, 0); - - // Write once - let n = assert_ok!(buffer.write(&[0, 1, 2, 3]).await); - assert_eq!(n, 4, "n must be 4"); - - // Try to read with a short buffer - let mut packet: Vec = vec![0; 3]; - let result = buffer.read(&mut packet, None).await; - assert!(result.is_err()); - if let Err(err) = result { - assert_eq!(err, Error::ErrBufferShort); - } - - // Close - buffer.close().await; - - // check is_close - assert!(buffer.is_closed().await); - - // Make sure you can Close twice - buffer.close().await; -} diff --git a/util/src/buffer/mod.rs b/util/src/buffer/mod.rs deleted file mode 100644 index a39a3e9ec..000000000 --- a/util/src/buffer/mod.rs +++ /dev/null @@ -1,322 +0,0 @@ -#[cfg(test)] -mod buffer_test; - -use std::sync::Arc; - -use tokio::sync::{Mutex, Notify}; -use tokio::time::{timeout, Duration}; - -use crate::error::{Error, Result}; - -const MIN_SIZE: usize = 2048; -const CUTOFF_SIZE: usize = 128 * 1024; -const MAX_SIZE: usize = 4 * 1024 * 1024; - -/// Buffer allows writing packets to an intermediate buffer, which can then be read form. -/// This is verify similar to bytes.Buffer but avoids combining multiple writes into a single read. -#[derive(Debug)] -struct BufferInternal { - data: Vec, - head: usize, - tail: usize, - - closed: bool, - subs: bool, - - count: usize, - limit_count: usize, - limit_size: usize, -} - -impl BufferInternal { - /// available returns true if the buffer is large enough to fit a packet - /// of the given size, taking overhead into account. - fn available(&self, size: usize) -> bool { - let mut available = self.head as isize - self.tail as isize; - if available <= 0 { - available += self.data.len() as isize; - } - // we interpret head=tail as empty, so always keep a byte free - size as isize + 2 < available - } - - /// grow increases the size of the buffer. If it returns nil, then the - /// buffer has been grown. It returns ErrFull if hits a limit. 
- fn grow(&mut self) -> Result<()> { - let mut newsize = if self.data.len() < CUTOFF_SIZE { - 2 * self.data.len() - } else { - 5 * self.data.len() / 4 - }; - - if newsize < MIN_SIZE { - newsize = MIN_SIZE - } - if (self.limit_size == 0/*|| sizeHardlimit*/) && newsize > MAX_SIZE { - newsize = MAX_SIZE - } - - // one byte slack - if self.limit_size > 0 && newsize > self.limit_size + 1 { - newsize = self.limit_size + 1 - } - - if newsize <= self.data.len() { - return Err(Error::ErrBufferFull); - } - - let mut newdata: Vec = vec![0; newsize]; - - let mut n; - if self.head <= self.tail { - // data was contiguous - n = self.tail - self.head; - newdata[..n].copy_from_slice(&self.data[self.head..self.tail]); - } else { - // data was discontiguous - n = self.data.len() - self.head; - newdata[..n].copy_from_slice(&self.data[self.head..]); - newdata[n..n + self.tail].copy_from_slice(&self.data[..self.tail]); - n += self.tail; - } - self.head = 0; - self.tail = n; - self.data = newdata; - - Ok(()) - } - - fn size(&self) -> usize { - let mut size = self.tail as isize - self.head as isize; - if size < 0 { - size += self.data.len() as isize; - } - size as usize - } -} - -#[derive(Debug, Clone)] -pub struct Buffer { - buffer: Arc>, - notify: Arc, -} - -impl Buffer { - pub fn new(limit_count: usize, limit_size: usize) -> Self { - Buffer { - buffer: Arc::new(Mutex::new(BufferInternal { - data: vec![], - head: 0, - tail: 0, - - closed: false, - subs: false, - - count: 0, - limit_count, - limit_size, - })), - notify: Arc::new(Notify::new()), - } - } - - /// Write appends a copy of the packet data to the buffer. - /// Returns ErrFull if the packet doesn't fit. - /// Note that the packet size is limited to 65536 bytes since v0.11.0 - /// due to the internal data structure. - pub async fn write(&self, packet: &[u8]) -> Result { - if packet.len() >= 0x10000 { - return Err(Error::ErrPacketTooBig); - } - - let mut b = self.buffer.lock().await; - - if b.closed { - return Err(Error::ErrBufferClosed); - } - - if (b.limit_count > 0 && b.count >= b.limit_count) - || (b.limit_size > 0 && b.size() + 2 + packet.len() > b.limit_size) - { - return Err(Error::ErrBufferFull); - } - - // grow the buffer until the packet fits - while !b.available(packet.len()) { - b.grow()?; - } - - // store the length of the packet - let tail = b.tail; - b.data[tail] = (packet.len() >> 8) as u8; - b.tail += 1; - if b.tail >= b.data.len() { - b.tail = 0; - } - - let tail = b.tail; - b.data[tail] = packet.len() as u8; - b.tail += 1; - if b.tail >= b.data.len() { - b.tail = 0; - } - - // store the packet - let end = std::cmp::min(b.data.len(), b.tail + packet.len()); - let n = end - b.tail; - let tail = b.tail; - b.data[tail..end].copy_from_slice(&packet[..n]); - b.tail += n; - if b.tail >= b.data.len() { - // we reached the end, wrap around - let m = packet.len() - n; - b.data[..m].copy_from_slice(&packet[n..]); - b.tail = m; - } - b.count += 1; - - if b.subs { - // we have other are waiting data - self.notify.notify_one(); - b.subs = false; - } - - Ok(packet.len()) - } - - // Read populates the given byte slice, returning the number of bytes read. - // Blocks until data is available or the buffer is closed. - // Returns io.ErrShortBuffer is the packet is too small to copy the Write. - // Returns io.EOF if the buffer is closed. 
- pub async fn read(&self, packet: &mut [u8], duration: Option) -> Result { - loop { - { - // use {} to let LockGuard RAII - let mut b = self.buffer.lock().await; - - if b.head != b.tail { - // decode the packet size - let n1 = b.data[b.head]; - b.head += 1; - if b.head >= b.data.len() { - b.head = 0; - } - let n2 = b.data[b.head]; - b.head += 1; - if b.head >= b.data.len() { - b.head = 0; - } - let count = ((n1 as usize) << 8) | n2 as usize; - - // determine the number of bytes we'll actually copy - let mut copied = count; - if copied > packet.len() { - copied = packet.len(); - } - - // copy the data - if b.head + copied < b.data.len() { - packet[..copied].copy_from_slice(&b.data[b.head..b.head + copied]); - } else { - let k = b.data.len() - b.head; - packet[..k].copy_from_slice(&b.data[b.head..]); - packet[k..copied].copy_from_slice(&b.data[..copied - k]); - } - - // advance head, discarding any data that wasn't copied - b.head += count; - if b.head >= b.data.len() { - b.head -= b.data.len(); - } - - if b.head == b.tail { - // the buffer is empty, reset to beginning - // in order to improve cache locality. - b.head = 0; - b.tail = 0; - } - - b.count -= 1; - - if copied < count { - return Err(Error::ErrBufferShort); - } - return Ok(copied); - } else { - // Dont have data -> need wait - b.subs = true; - } - - if b.closed { - return Err(Error::ErrBufferClosed); - } - } - - // Wait for signal. - if let Some(d) = duration { - if timeout(d, self.notify.notified()).await.is_err() { - return Err(Error::ErrTimeout); - } - } else { - self.notify.notified().await; - } - } - } - - // Close will unblock any readers and prevent future writes. - // Data in the buffer can still be read, returning io.EOF when fully depleted. - pub async fn close(&self) { - // note: We don't use defer so we can close the notify channel after unlocking. - // This will unblock goroutines that can grab the lock immediately, instead of blocking again. - let mut b = self.buffer.lock().await; - - if b.closed { - return; - } - - b.closed = true; - self.notify.notify_waiters(); - } - - pub async fn is_closed(&self) -> bool { - let b = self.buffer.lock().await; - - b.closed - } - - // Count returns the number of packets in the buffer. - pub async fn count(&self) -> usize { - let b = self.buffer.lock().await; - - b.count - } - - // set_limit_count controls the maximum number of packets that can be buffered. - // Causes Write to return ErrFull when this limit is reached. - // A zero value will disable this limit. - pub async fn set_limit_count(&self, limit: usize) { - let mut b = self.buffer.lock().await; - - b.limit_count = limit - } - - // Size returns the total byte size of packets in the buffer. - pub async fn size(&self) -> usize { - let b = self.buffer.lock().await; - - b.size() - } - - // set_limit_size controls the maximum number of bytes that can be buffered. - // Causes Write to return ErrFull when this limit is reached. - // A zero value means 4MB since v0.11.0. - // - // User can set packetioSizeHardlimit build tag to enable 4MB hardlimit. - // When packetioSizeHardlimit build tag is set, set_limit_size exceeding - // the hardlimit will be silently discarded. 
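A rough usage sketch of the Buffer API laid out in this deleted module may help when reading the diff. It assumes the pre-removal webrtc_util crate with Buffer re-exported at the crate root (as the deleted bench above uses it) and a tokio runtime; it is illustrative only, not part of the change:

    use std::time::Duration;
    use webrtc_util::Buffer;

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn std::error::Error>> {
        // 0 disables the packet-count limit and keeps the default size limit.
        let buffer = Buffer::new(0, 0);

        // Each write is kept as a discrete packet; reads never merge writes.
        buffer.write(&[0, 1]).await?;
        buffer.write(&[2, 3, 4]).await?;

        let mut pkt = vec![0u8; 1500];
        let n = buffer.read(&mut pkt, Some(Duration::from_millis(10))).await?;
        assert_eq!(&pkt[..n], &[0, 1]); // the first write comes back intact

        // close() unblocks pending readers; remaining data can still be drained.
        buffer.close().await;
        Ok(())
    }
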
- pub async fn set_limit_size(&self, limit: usize) { - let mut b = self.buffer.lock().await; - - b.limit_size = limit - } -} diff --git a/util/src/conn/conn_bridge.rs b/util/src/conn/conn_bridge.rs deleted file mode 100644 index a16ed8931..000000000 --- a/util/src/conn/conn_bridge.rs +++ /dev/null @@ -1,236 +0,0 @@ -use std::collections::VecDeque; -use std::io::{Error, ErrorKind}; -use std::str::FromStr; -use std::sync::atomic::Ordering; -use std::sync::Arc; - -use bytes::Bytes; -use portable_atomic::AtomicUsize; -use tokio::sync::{mpsc, Mutex}; -use tokio::time::Duration; - -use super::*; - -const TICK_WAIT: Duration = Duration::from_micros(10); - -/// BridgeConn is a Conn that represents an endpoint of the bridge. -struct BridgeConn { - br: Arc, - id: usize, - rd_rx: Mutex>, - loss_chance: u8, -} - -#[async_trait] -impl Conn for BridgeConn { - async fn connect(&self, _addr: SocketAddr) -> Result<()> { - Err(Error::new(ErrorKind::Other, "Not applicable").into()) - } - - async fn recv(&self, b: &mut [u8]) -> Result { - let mut rd_rx = self.rd_rx.lock().await; - let v = match rd_rx.recv().await { - Some(v) => v, - None => return Err(Error::new(ErrorKind::UnexpectedEof, "Unexpected EOF").into()), - }; - let l = std::cmp::min(v.len(), b.len()); - b[..l].copy_from_slice(&v[..l]); - Ok(l) - } - - async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> { - let n = self.recv(buf).await?; - Ok((n, SocketAddr::from_str("0.0.0.0:0")?)) - } - - async fn send(&self, b: &[u8]) -> Result { - if rand::random::() % 100 < self.loss_chance { - return Ok(b.len()); - } - - self.br.push(b, self.id).await - } - - async fn send_to(&self, _buf: &[u8], _target: SocketAddr) -> Result { - Err(Error::new(ErrorKind::Other, "Not applicable").into()) - } - - fn local_addr(&self) -> Result { - Err(Error::new(ErrorKind::AddrNotAvailable, "Addr Not Available").into()) - } - - fn remote_addr(&self) -> Option { - None - } - - async fn close(&self) -> Result<()> { - Ok(()) - } - - fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) { - self - } -} - -pub type FilterCbFn = Box bool + Send + Sync>; - -/// Bridge represents a network between the two endpoints. -#[derive(Default)] -pub struct Bridge { - drop_nwrites: [AtomicUsize; 2], - reorder_nwrites: [AtomicUsize; 2], - - stack: [Mutex>; 2], - queue: [Mutex>; 2], - - wr_tx: [Option>; 2], - filter_cb: [Option; 2], -} - -impl Bridge { - pub fn new( - loss_chance: u8, - filter_cb0: Option, - filter_cb1: Option, - ) -> (Arc, impl Conn, impl Conn) { - let (wr_tx0, rd_rx0) = mpsc::channel(1024); - let (wr_tx1, rd_rx1) = mpsc::channel(1024); - - let br = Arc::new(Bridge { - wr_tx: [Some(wr_tx0), Some(wr_tx1)], - filter_cb: [filter_cb0, filter_cb1], - ..Default::default() - }); - let conn0 = BridgeConn { - br: Arc::clone(&br), - id: 0, - rd_rx: Mutex::new(rd_rx0), - loss_chance, - }; - let conn1 = BridgeConn { - br: Arc::clone(&br), - id: 1, - rd_rx: Mutex::new(rd_rx1), - loss_chance, - }; - - (br, conn0, conn1) - } - - /// Len returns number of queued packets. - #[allow(clippy::len_without_is_empty)] - pub async fn len(&self, id: usize) -> usize { - let q = self.queue[id].lock().await; - q.len() - } - - pub async fn push(&self, b: &[u8], id: usize) -> Result { - // Push rate should be limited as same as Tick rate. - // Otherwise, queue grows too fast on free running Write. 
- tokio::time::sleep(TICK_WAIT).await; - - let d = Bytes::from(b.to_vec()); - if self.drop_nwrites[id].load(Ordering::SeqCst) > 0 { - self.drop_nwrites[id].fetch_sub(1, Ordering::SeqCst); - } else if self.reorder_nwrites[id].load(Ordering::SeqCst) > 0 { - let mut stack = self.stack[id].lock().await; - stack.push_back(d); - if self.reorder_nwrites[id].fetch_sub(1, Ordering::SeqCst) == 1 { - let ok = inverse(&mut stack); - if ok { - let mut queue = self.queue[id].lock().await; - queue.append(&mut stack); - } - } - } else if let Some(filter_cb) = &self.filter_cb[id] { - if filter_cb(&d) { - let mut queue = self.queue[id].lock().await; - queue.push_back(d); - } - } else { - //log::debug!("queue [{}] enter lock", id); - let mut queue = self.queue[id].lock().await; - queue.push_back(d); - //log::debug!("queue [{}] exit lock", id); - } - - Ok(b.len()) - } - - /// Reorder inverses the order of packets currently in the specified queue. - pub async fn reorder(&self, id: usize) -> bool { - let mut queue = self.queue[id].lock().await; - inverse(&mut queue) - } - - /// Drop drops the specified number of packets from the given offset index - /// of the specified queue. - pub async fn drop_offset(&self, id: usize, offset: usize, n: usize) { - let mut queue = self.queue[id].lock().await; - queue.drain(offset..offset + n); - } - - /// drop_next_nwrites drops the next n packets that will be written - /// to the specified queue. - pub fn drop_next_nwrites(&self, id: usize, n: usize) { - self.drop_nwrites[id].store(n, Ordering::SeqCst); - } - - /// reorder_next_nwrites drops the next n packets that will be written - /// to the specified queue. - pub fn reorder_next_nwrites(&self, id: usize, n: usize) { - self.reorder_nwrites[id].store(n, Ordering::SeqCst); - } - - pub async fn clear(&self) { - for id in 0..2 { - let mut queue = self.queue[id].lock().await; - queue.clear(); - } - } - - /// Tick attempts to hand a packet from the queue for each directions, to readers, - /// if there are waiting on the queue. If there's no reader, it will return - /// immediately. - pub async fn tick(&self) -> usize { - let mut n = 0; - - for id in 0..2 { - let mut queue = self.queue[id].lock().await; - if let Some(d) = queue.pop_front() { - n += 1; - if let Some(wr_tx) = &self.wr_tx[1 - id] { - let _ = wr_tx.send(d).await; - } - } - } - - n - } - - /// Process repeats tick() calls until no more outstanding packet in the queues. 
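For orientation, a hedged sketch of how this Bridge helper is driven in the deleted tests further below: two in-memory Conn endpoints whose packets sit in per-direction queues until tick()/process() pumps them across. The webrtc_util::conn::conn_bridge path is inferred from the module layout in this diff and is an assumption:

    use webrtc_util::conn::conn_bridge::Bridge;
    use webrtc_util::conn::Conn;

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn std::error::Error>> {
        // No artificial loss, no filter callbacks in either direction.
        let (bridge, conn0, conn1) = Bridge::new(0, None, None);

        // Writes are queued inside the bridge, not delivered immediately.
        conn0.send(b"hello").await?;

        let reader = tokio::spawn(async move {
            let mut buf = vec![0u8; 256];
            let n = conn1.recv(&mut buf).await.unwrap();
            buf.truncate(n);
            buf
        });

        // process() keeps ticking until both queues are drained.
        bridge.process().await;
        assert_eq!(reader.await?, b"hello".to_vec());
        Ok(())
    }
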
- pub async fn process(&self) { - loop { - tokio::time::sleep(TICK_WAIT).await; - self.tick().await; - if self.len(0).await == 0 && self.len(1).await == 0 { - break; - } - } - } -} - -pub(crate) fn inverse(s: &mut VecDeque) -> bool { - if s.len() < 2 { - return false; - } - - let (mut i, mut j) = (0, s.len() - 1); - while i < j { - s.swap(i, j); - i += 1; - j -= 1; - } - - true -} diff --git a/util/src/conn/conn_bridge_test.rs b/util/src/conn/conn_bridge_test.rs deleted file mode 100644 index d775a169d..000000000 --- a/util/src/conn/conn_bridge_test.rs +++ /dev/null @@ -1,255 +0,0 @@ -use std::collections::VecDeque; -use std::sync::Arc; - -use bytes::Bytes; -use tokio::sync::mpsc; - -use super::conn_bridge::*; -use super::*; - -static MSG1: Bytes = Bytes::from_static(b"ADC"); -static MSG2: Bytes = Bytes::from_static(b"DEFG"); - -#[tokio::test] -async fn test_bridge_normal() -> Result<()> { - let (br, conn0, conn1) = Bridge::new(0, None, None); - - let n = conn0.send(&MSG1).await?; - assert_eq!(n, MSG1.len(), "unexpected length"); - - let (tx, mut rx) = mpsc::channel(1); - - tokio::spawn(async move { - let mut buf = vec![0u8; 256]; - let n = conn1.recv(&mut buf).await?; - let _ = tx.send(n).await; - Result::<()>::Ok(()) - }); - - br.process().await; - - let n = rx.recv().await.unwrap(); - assert_eq!(n, MSG1.len(), "unexpected length"); - - Ok(()) -} - -#[tokio::test] -async fn test_bridge_drop_first_packet_from_conn0() -> Result<()> { - let (br, conn0, conn1) = Bridge::new(0, None, None); - - let n = conn0.send(&MSG1).await?; - assert_eq!(n, MSG1.len(), "unexpected length"); - let n = conn0.send(&MSG2).await?; - assert_eq!(n, MSG2.len(), "unexpected length"); - - let (tx, mut rx) = mpsc::channel(1); - - tokio::spawn(async move { - let mut buf = vec![0u8; 256]; - let n = conn1.recv(&mut buf).await?; - let _ = tx.send(n).await; - Result::<()>::Ok(()) - }); - - br.drop_offset(0, 0, 1).await; - br.process().await; - - let n = rx.recv().await.unwrap(); - assert_eq!(n, MSG2.len(), "unexpected length"); - - Ok(()) -} - -#[tokio::test] -async fn test_bridge_drop_second_packet_from_conn0() -> Result<()> { - let (br, conn0, conn1) = Bridge::new(0, None, None); - - let n = conn0.send(&MSG1).await?; - assert_eq!(n, MSG1.len(), "unexpected length"); - let n = conn0.send(&MSG2).await?; - assert_eq!(n, MSG2.len(), "unexpected length"); - - let (tx, mut rx) = mpsc::channel(1); - - tokio::spawn(async move { - let mut buf = vec![0u8; 256]; - let n = conn1.recv(&mut buf).await?; - let _ = tx.send(n).await; - Result::<()>::Ok(()) - }); - - br.drop_offset(0, 1, 1).await; - br.process().await; - - let n = rx.recv().await.unwrap(); - assert_eq!(n, MSG1.len(), "unexpected length"); - - Ok(()) -} - -#[tokio::test] -async fn test_bridge_drop_first_packet_from_conn1() -> Result<()> { - let (br, conn0, conn1) = Bridge::new(0, None, None); - - let n = conn1.send(&MSG1).await?; - assert_eq!(n, MSG1.len(), "unexpected length"); - let n = conn1.send(&MSG2).await?; - assert_eq!(n, MSG2.len(), "unexpected length"); - - let (tx, mut rx) = mpsc::channel(1); - - tokio::spawn(async move { - let mut buf = vec![0u8; 256]; - let n = conn0.recv(&mut buf).await?; - let _ = tx.send(n).await; - Result::<()>::Ok(()) - }); - - br.drop_offset(1, 0, 1).await; - br.process().await; - - let n = rx.recv().await.unwrap(); - assert_eq!(n, MSG2.len(), "unexpected length"); - - Ok(()) -} - -#[tokio::test] -async fn test_bridge_drop_second_packet_from_conn1() -> Result<()> { - let (br, conn0, conn1) = Bridge::new(0, None, None); - - let n = 
conn1.send(&MSG1).await?; - assert_eq!(n, MSG1.len(), "unexpected length"); - let n = conn1.send(&MSG2).await?; - assert_eq!(n, MSG2.len(), "unexpected length"); - - let (tx, mut rx) = mpsc::channel(1); - - tokio::spawn(async move { - let mut buf = vec![0u8; 256]; - let n = conn0.recv(&mut buf).await?; - let _ = tx.send(n).await; - Result::<()>::Ok(()) - }); - - br.drop_offset(1, 1, 1).await; - br.process().await; - - let n = rx.recv().await.unwrap(); - assert_eq!(n, MSG1.len(), "unexpected length"); - - Ok(()) -} - -#[tokio::test] -async fn test_bridge_reorder_packets_from_conn0() -> Result<()> { - let (br, conn0, conn1) = Bridge::new(0, None, None); - - let n = conn0.send(&MSG1).await?; - assert_eq!(n, MSG1.len(), "unexpected length"); - let n = conn0.send(&MSG2).await?; - assert_eq!(n, MSG2.len(), "unexpected length"); - - let (tx, mut rx) = mpsc::channel(1); - - tokio::spawn(async move { - let mut buf = vec![0u8; 256]; - let n = conn1.recv(&mut buf).await?; - assert_eq!(n, MSG2.len(), "unexpected length"); - let n = conn1.recv(&mut buf).await?; - assert_eq!(n, MSG1.len(), "unexpected length"); - - let _ = rx.recv().await; - - Result::<()>::Ok(()) - }); - - br.reorder(0).await; - br.process().await; - - let _ = tx.send(()).await; - - Ok(()) -} - -#[tokio::test] -async fn test_bridge_reorder_packets_from_conn1() -> Result<()> { - let (br, conn0, conn1) = Bridge::new(0, None, None); - - let n = conn1.send(&MSG1).await?; - assert_eq!(n, MSG1.len(), "unexpected length"); - let n = conn1.send(&MSG2).await?; - assert_eq!(n, MSG2.len(), "unexpected length"); - - let (tx, mut rx) = mpsc::channel(1); - - tokio::spawn(async move { - let mut buf = vec![0u8; 256]; - let n = conn0.recv(&mut buf).await?; - assert_eq!(n, MSG2.len(), "unexpected length"); - let n = conn0.recv(&mut buf).await?; - assert_eq!(n, MSG1.len(), "unexpected length"); - - let _ = rx.recv().await; - - Result::<()>::Ok(()) - }); - - br.reorder(1).await; - br.process().await; - - let _ = tx.send(()).await; - - Ok(()) -} - -#[tokio::test] -async fn test_bridge_inverse_error() -> Result<()> { - let mut q = VecDeque::new(); - q.push_back(MSG1.clone()); - assert!(!inverse(&mut q)); - Ok(()) -} - -#[tokio::test] -async fn test_bridge_drop_next_n_packets() -> Result<()> { - for id in 0..2 { - let (br, conn0, conn1) = Bridge::new(0, None, None); - br.drop_next_nwrites(id, 3); - let conns: Vec> = vec![Arc::new(conn0), Arc::new(conn1)]; - let src_conn = Arc::clone(&conns[id]); - let dst_conn = Arc::clone(&conns[1 - id]); - - let (tx, mut rx) = mpsc::channel(5); - - tokio::spawn(async move { - let mut buf = vec![0u8; 256]; - for _ in 0..2u8 { - let n = dst_conn.recv(&mut buf).await?; - let _ = tx.send(buf[..n].to_vec()).await; - } - - Result::<()>::Ok(()) - }); - - let mut msgs = vec![]; - for i in 0..5u8 { - let msg = format!("msg{i}"); - let n = src_conn.send(msg.as_bytes()).await?; - assert_eq!(n, msg.len(), "[{id}] unexpected length"); - msgs.push(msg); - br.process().await; - } - - for i in 0..2 { - if let Some(buf) = rx.recv().await { - assert_eq!(msgs[i + 3].as_bytes(), &buf); - } else { - panic!("{id} unexpected number of packets"); - } - } - } - - Ok(()) -} diff --git a/util/src/conn/conn_disconnected_packet.rs b/util/src/conn/conn_disconnected_packet.rs deleted file mode 100644 index ec4230304..000000000 --- a/util/src/conn/conn_disconnected_packet.rs +++ /dev/null @@ -1,65 +0,0 @@ -use std::net::Ipv4Addr; -use std::sync::Arc; - -use super::*; -use crate::sync::RwLock; - -/// Since UDP is connectionless, as a server, it doesn't 
know how to reply -/// simply using the `Write` method. So, to make it work, `disconnectedPacketConn` -/// will infer the last packet that it reads as the reply address for `Write` -pub struct DisconnectedPacketConn { - raddr: RwLock, - pconn: Arc, -} - -impl DisconnectedPacketConn { - pub fn new(conn: Arc) -> Self { - DisconnectedPacketConn { - raddr: RwLock::new(SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0)), - pconn: conn, - } - } -} - -#[async_trait] -impl Conn for DisconnectedPacketConn { - async fn connect(&self, addr: SocketAddr) -> Result<()> { - self.pconn.connect(addr).await - } - - async fn recv(&self, buf: &mut [u8]) -> Result { - let (n, addr) = self.pconn.recv_from(buf).await?; - *self.raddr.write() = addr; - Ok(n) - } - - async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> { - self.pconn.recv_from(buf).await - } - - async fn send(&self, buf: &[u8]) -> Result { - let addr = *self.raddr.read(); - self.pconn.send_to(buf, addr).await - } - - async fn send_to(&self, buf: &[u8], target: SocketAddr) -> Result { - self.pconn.send_to(buf, target).await - } - - fn local_addr(&self) -> Result { - self.pconn.local_addr() - } - - fn remote_addr(&self) -> Option { - let raddr = *self.raddr.read(); - Some(raddr) - } - - async fn close(&self) -> Result<()> { - self.pconn.close().await - } - - fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) { - self - } -} diff --git a/util/src/conn/conn_pipe.rs b/util/src/conn/conn_pipe.rs deleted file mode 100644 index 3427a9100..000000000 --- a/util/src/conn/conn_pipe.rs +++ /dev/null @@ -1,80 +0,0 @@ -use std::io::{Error, ErrorKind}; -use std::str::FromStr; - -use tokio::sync::{mpsc, Mutex}; - -use super::*; - -struct Pipe { - rd_rx: Mutex>>, - wr_tx: Mutex>>, -} - -pub fn pipe() -> (impl Conn, impl Conn) { - let (cb1_tx, cb1_rx) = mpsc::channel(16); - let (cb2_tx, cb2_rx) = mpsc::channel(16); - - let p1 = Pipe { - rd_rx: Mutex::new(cb1_rx), - wr_tx: Mutex::new(cb2_tx), - }; - - let p2 = Pipe { - rd_rx: Mutex::new(cb2_rx), - wr_tx: Mutex::new(cb1_tx), - }; - - (p1, p2) -} - -#[async_trait] -impl Conn for Pipe { - async fn connect(&self, _addr: SocketAddr) -> Result<()> { - Err(Error::new(ErrorKind::Other, "Not applicable").into()) - } - - async fn recv(&self, b: &mut [u8]) -> Result { - let mut rd_rx = self.rd_rx.lock().await; - let v = match rd_rx.recv().await { - Some(v) => v, - None => return Err(Error::new(ErrorKind::UnexpectedEof, "Unexpected EOF").into()), - }; - let l = std::cmp::min(v.len(), b.len()); - b[..l].copy_from_slice(&v[..l]); - Ok(l) - } - - async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> { - let n = self.recv(buf).await?; - Ok((n, SocketAddr::from_str("0.0.0.0:0")?)) - } - - async fn send(&self, b: &[u8]) -> Result { - let wr_tx = self.wr_tx.lock().await; - match wr_tx.send(b.to_vec()).await { - Ok(_) => {} - Err(err) => return Err(Error::new(ErrorKind::Other, err.to_string()).into()), - }; - Ok(b.len()) - } - - async fn send_to(&self, _buf: &[u8], _target: SocketAddr) -> Result { - Err(Error::new(ErrorKind::Other, "Not applicable").into()) - } - - fn local_addr(&self) -> Result { - Err(Error::new(ErrorKind::AddrNotAvailable, "Addr Not Available").into()) - } - - fn remote_addr(&self) -> Option { - None - } - - async fn close(&self) -> Result<()> { - Ok(()) - } - - fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) { - self - } -} diff --git a/util/src/conn/conn_pipe_test.rs b/util/src/conn/conn_pipe_test.rs deleted file mode 100644 index 6fc896070..000000000 --- 
a/util/src/conn/conn_pipe_test.rs +++ /dev/null @@ -1,27 +0,0 @@ -use super::conn_pipe::*; -use super::*; - -#[tokio::test] -async fn test_pipe() -> Result<()> { - let (c1, c2) = pipe(); - let mut b1 = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; - let n = c1.send(&b1).await?; - assert_eq!(n, 10); - - let mut b2 = vec![133; 100]; - let n = c2.recv(&mut b2).await?; - assert_eq!(n, 10); - assert_eq!(&b2[..n], &b1[..]); - - let n = c2.send(&b2[..10]).await?; - assert_eq!(n, 10); - let n = c2.send(&b2[..5]).await?; - assert_eq!(n, 5); - - let n = c1.recv(&mut b1).await?; - assert_eq!(n, 10); - let n = c1.recv(&mut b1).await?; - assert_eq!(n, 5); - - Ok(()) -} diff --git a/util/src/conn/conn_test.rs b/util/src/conn/conn_test.rs deleted file mode 100644 index 1e82688f8..000000000 --- a/util/src/conn/conn_test.rs +++ /dev/null @@ -1,22 +0,0 @@ -use super::*; - -#[tokio::test] -async fn test_conn_lookup_host() -> Result<()> { - let stun_serv_addr = "stun1.l.google.com:19302"; - - if let Ok(ipv4_addr) = lookup_host(true, stun_serv_addr).await { - assert!( - ipv4_addr.is_ipv4(), - "expected ipv4 but got ipv6: {ipv4_addr}" - ); - } - - if let Ok(ipv6_addr) = lookup_host(false, stun_serv_addr).await { - assert!( - ipv6_addr.is_ipv6(), - "expected ipv6 but got ipv4: {ipv6_addr}" - ); - } - - Ok(()) -} diff --git a/util/src/conn/conn_udp.rs b/util/src/conn/conn_udp.rs deleted file mode 100644 index a6fa8d542..000000000 --- a/util/src/conn/conn_udp.rs +++ /dev/null @@ -1,42 +0,0 @@ -use tokio::net::UdpSocket; - -use super::*; - -#[async_trait] -impl Conn for UdpSocket { - async fn connect(&self, addr: SocketAddr) -> Result<()> { - Ok(self.connect(addr).await?) - } - - async fn recv(&self, buf: &mut [u8]) -> Result { - Ok(self.recv(buf).await?) - } - - async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> { - Ok(self.recv_from(buf).await?) - } - - async fn send(&self, buf: &[u8]) -> Result { - Ok(self.send(buf).await?) - } - - async fn send_to(&self, buf: &[u8], target: SocketAddr) -> Result { - Ok(self.send_to(buf, target).await?) - } - - fn local_addr(&self) -> Result { - Ok(self.local_addr()?) - } - - fn remote_addr(&self) -> Option { - None - } - - async fn close(&self) -> Result<()> { - Ok(()) - } - - fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) { - self - } -} diff --git a/util/src/conn/conn_udp_listener.rs b/util/src/conn/conn_udp_listener.rs deleted file mode 100644 index 47dcd5c86..000000000 --- a/util/src/conn/conn_udp_listener.rs +++ /dev/null @@ -1,298 +0,0 @@ -use core::sync::atomic::Ordering; -use std::collections::HashMap; -use std::future::Future; -use std::pin::Pin; - -use portable_atomic::AtomicBool; -use tokio::net::UdpSocket; -use tokio::sync::{mpsc, watch, Mutex}; - -use super::*; -use crate::error::Error; -use crate::Buffer; - -const RECEIVE_MTU: usize = 8192; -const DEFAULT_LISTEN_BACKLOG: usize = 128; // same as Linux default - -pub type AcceptFilterFn = - Box Pin + Send + 'static>>) + Send + Sync>; - -type AcceptDoneCh = (mpsc::Receiver>, watch::Receiver<()>); - -/// listener is used in the [DTLS](https://github.com/webrtc-rs/dtls) and -/// [SCTP](https://github.com/webrtc-rs/sctp) transport to provide a connection-oriented -/// listener over a UDP. -struct ListenerImpl { - pconn: Arc, - accepting: Arc, - accept_ch_tx: Arc>>>>, - done_ch_tx: Arc>>>, - ch_rx: Arc>, - conns: Arc>>>, -} - -#[async_trait] -impl Listener for ListenerImpl { - /// accept waits for and returns the next connection to the listener. 
- async fn accept(&self) -> Result<(Arc, SocketAddr)> { - let (accept_ch_rx, done_ch_rx) = &mut *self.ch_rx.lock().await; - - tokio::select! { - c = accept_ch_rx.recv() =>{ - if let Some(c) = c{ - let raddr = c.raddr; - Ok((c, raddr)) - }else{ - Err(Error::ErrClosedListenerAcceptCh) - } - } - _ = done_ch_rx.changed() => Err(Error::ErrClosedListener), - } - } - - /// close closes the listener. - /// Any blocked Accept operations will be unblocked and return errors. - async fn close(&self) -> Result<()> { - if self.accepting.load(Ordering::SeqCst) { - self.accepting.store(false, Ordering::SeqCst); - { - let mut done_ch = self.done_ch_tx.lock().await; - done_ch.take(); - } - { - let mut accept_ch = self.accept_ch_tx.lock().await; - accept_ch.take(); - } - } - - Ok(()) - } - - /// Addr returns the listener's network address. - async fn addr(&self) -> Result { - self.pconn.local_addr() - } -} - -/// ListenConfig stores options for listening to an address. -#[derive(Default)] -pub struct ListenConfig { - /// Backlog defines the maximum length of the queue of pending - /// connections. It is equivalent of the backlog argument of - /// POSIX listen function. - /// If a connection request arrives when the queue is full, - /// the request will be silently discarded, unlike TCP. - /// Set zero to use default value 128 which is same as Linux default. - pub backlog: usize, - - /// AcceptFilter determines whether the new conn should be made for - /// the incoming packet. If not set, any packet creates new conn. - pub accept_filter: Option, -} - -pub async fn listen(laddr: A) -> Result { - ListenConfig::default().listen(laddr).await -} - -impl ListenConfig { - /// Listen creates a new listener based on the ListenConfig. - pub async fn listen(&mut self, laddr: A) -> Result { - if self.backlog == 0 { - self.backlog = DEFAULT_LISTEN_BACKLOG; - } - - let pconn = Arc::new(UdpSocket::bind(laddr).await?); - let (accept_ch_tx, accept_ch_rx) = mpsc::channel(self.backlog); - let (done_ch_tx, done_ch_rx) = watch::channel(()); - - let l = ListenerImpl { - pconn, - accepting: Arc::new(AtomicBool::new(true)), - accept_ch_tx: Arc::new(Mutex::new(Some(accept_ch_tx))), - done_ch_tx: Arc::new(Mutex::new(Some(done_ch_tx))), - ch_rx: Arc::new(Mutex::new((accept_ch_rx, done_ch_rx.clone()))), - conns: Arc::new(Mutex::new(HashMap::new())), - }; - - let pconn = Arc::clone(&l.pconn); - let accepting = Arc::clone(&l.accepting); - let accept_filter = self.accept_filter.take(); - let accept_ch_tx = Arc::clone(&l.accept_ch_tx); - let conns = Arc::clone(&l.conns); - tokio::spawn(async move { - ListenConfig::read_loop( - done_ch_rx, - pconn, - accepting, - accept_filter, - accept_ch_tx, - conns, - ) - .await; - }); - - Ok(l) - } - - /// read_loop has to tasks: - /// 1. Dispatching incoming packets to the correct Conn. - /// It can therefore not be ended until all Conns are closed. - /// 2. Creating a new Conn when receiving from a new remote. - async fn read_loop( - mut done_ch_rx: watch::Receiver<()>, - pconn: Arc, - accepting: Arc, - accept_filter: Option, - accept_ch_tx: Arc>>>>, - conns: Arc>>>, - ) { - let mut buf = vec![0u8; RECEIVE_MTU]; - - loop { - tokio::select! 
{ - _ = done_ch_rx.changed() => { - break; - } - result = pconn.recv_from(&mut buf) => { - match result { - Ok((n, raddr)) => { - let udp_conn = match ListenConfig::get_udp_conn( - &pconn, - &accepting, - &accept_filter, - &accept_ch_tx, - &conns, - raddr, - &buf[..n], - ) - .await - { - Ok(conn) => conn, - Err(_) => continue, - }; - - if let Some(conn) = udp_conn { - let _ = conn.buffer.write(&buf[..n]).await; - } - } - Err(err) => { - log::warn!("ListenConfig pconn.recv_from error: {}", err); - break; - } - }; - } - } - } - } - - async fn get_udp_conn( - pconn: &Arc, - accepting: &Arc, - accept_filter: &Option, - accept_ch_tx: &Arc>>>>, - conns: &Arc>>>, - raddr: SocketAddr, - buf: &[u8], - ) -> Result>> { - { - let m = conns.lock().await; - if let Some(conn) = m.get(raddr.to_string().as_str()) { - return Ok(Some(conn.clone())); - } - } - - if !accepting.load(Ordering::SeqCst) { - return Err(Error::ErrClosedListener); - } - - if let Some(f) = accept_filter { - if !(f(buf).await) { - return Ok(None); - } - } - - let udp_conn = Arc::new(UdpConn::new(Arc::clone(pconn), Arc::clone(conns), raddr)); - { - let accept_ch = accept_ch_tx.lock().await; - if let Some(tx) = &*accept_ch { - if tx.try_send(Arc::clone(&udp_conn)).is_err() { - return Err(Error::ErrListenQueueExceeded); - } - } else { - return Err(Error::ErrClosedListenerAcceptCh); - } - } - - { - let mut m = conns.lock().await; - m.insert(raddr.to_string(), Arc::clone(&udp_conn)); - } - - Ok(Some(udp_conn)) - } -} - -/// UdpConn augments a connection-oriented connection over a UdpSocket -pub struct UdpConn { - pconn: Arc, - conns: Arc>>>, - raddr: SocketAddr, - buffer: Buffer, -} - -impl UdpConn { - fn new( - pconn: Arc, - conns: Arc>>>, - raddr: SocketAddr, - ) -> Self { - UdpConn { - pconn, - conns, - raddr, - buffer: Buffer::new(0, 0), - } - } -} - -#[async_trait] -impl Conn for UdpConn { - async fn connect(&self, addr: SocketAddr) -> Result<()> { - self.pconn.connect(addr).await - } - - async fn recv(&self, buf: &mut [u8]) -> Result { - Ok(self.buffer.read(buf, None).await?) 
- } - - async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> { - let n = self.buffer.read(buf, None).await?; - Ok((n, self.raddr)) - } - - async fn send(&self, buf: &[u8]) -> Result { - self.pconn.send_to(buf, self.raddr).await - } - - async fn send_to(&self, buf: &[u8], target: SocketAddr) -> Result { - self.pconn.send_to(buf, target).await - } - - fn local_addr(&self) -> Result { - self.pconn.local_addr() - } - - fn remote_addr(&self) -> Option { - Some(self.raddr) - } - - async fn close(&self) -> Result<()> { - let mut conns = self.conns.lock().await; - conns.remove(self.raddr.to_string().as_str()); - Ok(()) - } - - fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) { - self - } -} diff --git a/util/src/conn/conn_udp_listener_test.rs b/util/src/conn/conn_udp_listener_test.rs deleted file mode 100644 index 3ac0a93c0..000000000 --- a/util/src/conn/conn_udp_listener_test.rs +++ /dev/null @@ -1,220 +0,0 @@ -use std::future::Future; -use std::pin::Pin; - -use tokio::net::UdpSocket; -use tokio::sync::mpsc; -use tokio::time::Duration; - -use super::conn_udp_listener::*; -use super::*; -use crate::error::{Error, Result}; - -async fn pipe() -> Result<( - Arc, - Arc, - UdpSocket, -)> { - // Start listening - let listener = Arc::new(listen("0.0.0.0:0").await?); - - // Open a connection - let d_conn = UdpSocket::bind("0.0.0.0:0").await?; - d_conn.connect(listener.addr().await?).await?; - - // Write to the connection to initiate it - let handshake = "hello"; - d_conn.send(handshake.as_bytes()).await?; - let daddr = d_conn.local_addr()?; - - // Accept the connection - let (l_conn, raddr) = listener.accept().await?; - assert_eq!(daddr, raddr, "remote address should be match"); - - let raddr = l_conn.remote_addr(); - if let Some(raddr) = raddr { - assert_eq!(daddr, raddr, "remote address should be match"); - } else { - panic!("expected Some, but got None, for remote_addr()"); - } - - let mut buf = vec![0u8; handshake.len()]; - let n = l_conn.recv(&mut buf).await?; - - let result = String::from_utf8(buf[..n].to_vec())?; - if handshake != result { - Err(Error::Other(format!( - "errHandshakeFailed: {handshake} != {result}" - ))) - } else { - Ok((listener, l_conn, d_conn)) - } -} - -#[tokio::test] -async fn test_listener_close_timeout() -> Result<()> { - let (listener, ca, _) = pipe().await?; - - listener.close().await?; - - // Close client after server closes to cleanup - ca.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_listener_close_unaccepted() -> Result<()> { - const BACKLOG: usize = 2; - - let listener = ListenConfig { - backlog: BACKLOG, - ..Default::default() - } - .listen("0.0.0.0:0") - .await?; - - for i in 0..BACKLOG as u8 { - let conn = UdpSocket::bind("0.0.0.0:0").await?; - conn.connect(listener.addr().await?).await?; - conn.send(&[i]).await?; - conn.close().await?; - } - - // Wait all packets being processed by readLoop - tokio::time::sleep(Duration::from_millis(100)).await; - - // Unaccepted connections must be closed by listener.Close() - listener.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_listener_accept_filter() -> Result<()> { - let tests = vec![("CreateConn", &[0xAA], true), ("Discarded", &[0x00], false)]; - - for (name, packet, expected) in tests { - let accept_filter: Option = Some(Box::new( - |pkt: &[u8]| -> Pin + Send + 'static>> { - let p0 = pkt[0]; - Box::pin(async move { p0 == 0xAA }) - }, - )); - - let listener = Arc::new( - ListenConfig { - accept_filter, - ..Default::default() - } - .listen("0.0.0.0:0") - .await?, - 
); - - let conn = UdpSocket::bind("0.0.0.0:0").await?; - conn.connect(listener.addr().await?).await?; - conn.send(packet).await?; - - let (ch_accepted_tx, mut ch_accepted_rx) = mpsc::channel::<()>(1); - let mut ch_accepted_tx = Some(ch_accepted_tx); - let listener2 = Arc::clone(&listener); - tokio::spawn(async move { - let (c, _raddr) = match listener2.accept().await { - Ok((c, raddr)) => (c, raddr), - Err(err) => { - assert_eq!(Error::ErrClosedListener, err); - return Result::<()>::Ok(()); - } - }; - - ch_accepted_tx.take(); - c.close().await?; - - Result::<()>::Ok(()) - }); - - let mut accepted = false; - let mut timeout = false; - let timer = tokio::time::sleep(Duration::from_millis(10)); - tokio::pin!(timer); - tokio::select! { - _= ch_accepted_rx.recv()=>{ - accepted = true; - } - _ = timer.as_mut() => { - timeout = true; - } - } - - assert_eq!(accepted, expected, "{name}: unexpected result"); - assert_eq!(!timeout, expected, "{name}: unexpected result"); - - conn.close().await?; - listener.close().await?; - } - Ok(()) -} - -#[tokio::test] -async fn test_listener_concurrent() -> Result<()> { - const BACKLOG: usize = 2; - - let listener = Arc::new( - ListenConfig { - backlog: BACKLOG, - ..Default::default() - } - .listen("0.0.0.0:0") - .await?, - ); - - for i in 0..BACKLOG as u8 + 1 { - let conn = UdpSocket::bind("0.0.0.0:0").await?; - conn.connect(listener.addr().await?).await?; - conn.send(&[i]).await?; - conn.close().await?; - } - - // Wait all packets being processed by readLoop - tokio::time::sleep(Duration::from_millis(100)).await; - - let mut b = vec![0u8; 1]; - for i in 0..BACKLOG as u8 { - let (conn, _raddr) = listener.accept().await?; - let n = conn.recv(&mut b).await?; - assert_eq!( - &b[..n], - &[i], - "Packet from connection {} is wrong, expected: [{}], got: {:?}", - i, - i, - &b[..n] - ); - conn.close().await?; - } - - let (done_tx, mut done_rx) = mpsc::channel::<()>(1); - let mut done_tx = Some(done_tx); - let listener2 = Arc::clone(&listener); - tokio::spawn(async move { - match listener2.accept().await { - Ok((conn, _raddr)) => { - conn.close().await?; - } - Err(err) => { - assert!(Error::ErrClosedListener == err || Error::ErrClosedListenerAcceptCh == err); - } - } - - done_tx.take(); - - Result::<()>::Ok(()) - }); - - tokio::time::sleep(Duration::from_millis(100)).await; - - listener.close().await?; - - let _ = done_rx.recv().await; - - Ok(()) -} diff --git a/util/src/conn/mod.rs b/util/src/conn/mod.rs deleted file mode 100644 index 5f72ff17f..000000000 --- a/util/src/conn/mod.rs +++ /dev/null @@ -1,73 +0,0 @@ -pub mod conn_bridge; -pub mod conn_disconnected_packet; -pub mod conn_pipe; -pub mod conn_udp; -pub mod conn_udp_listener; - -#[cfg(test)] -mod conn_bridge_test; -#[cfg(test)] -mod conn_pipe_test; -#[cfg(test)] -mod conn_test; - -//TODO: remove this conditional test -#[cfg(not(target_os = "windows"))] -#[cfg(test)] -mod conn_udp_listener_test; - -use std::net::SocketAddr; -use std::sync::Arc; - -use async_trait::async_trait; -use tokio::net::ToSocketAddrs; - -use crate::error::Result; - -#[async_trait] -pub trait Conn { - async fn connect(&self, addr: SocketAddr) -> Result<()>; - async fn recv(&self, buf: &mut [u8]) -> Result; - async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)>; - async fn send(&self, buf: &[u8]) -> Result; - async fn send_to(&self, buf: &[u8], target: SocketAddr) -> Result; - fn local_addr(&self) -> Result; - fn remote_addr(&self) -> Option; - async fn close(&self) -> Result<()>; - fn as_any(&self) -> &(dyn std::any::Any 
+ Send + Sync); -} - -/// A Listener is a generic network listener for connection-oriented protocols. -/// Multiple connections may invoke methods on a Listener simultaneously. -#[async_trait] -pub trait Listener { - /// accept waits for and returns the next connection to the listener. - async fn accept(&self) -> Result<(Arc, SocketAddr)>; - - /// close closes the listener. - /// Any blocked accept operations will be unblocked and return errors. - async fn close(&self) -> Result<()>; - - /// addr returns the listener's network address. - async fn addr(&self) -> Result; -} - -pub async fn lookup_host(use_ipv4: bool, host: T) -> Result -where - T: ToSocketAddrs, -{ - for remote_addr in tokio::net::lookup_host(host).await? { - if (use_ipv4 && remote_addr.is_ipv4()) || (!use_ipv4 && remote_addr.is_ipv6()) { - return Ok(remote_addr); - } - } - - Err(std::io::Error::new( - std::io::ErrorKind::Other, - format!( - "No available {} IP address found!", - if use_ipv4 { "ipv4" } else { "ipv6" }, - ), - ) - .into()) -} diff --git a/util/src/error.rs b/util/src/error.rs deleted file mode 100644 index 71899d4c1..000000000 --- a/util/src/error.rs +++ /dev/null @@ -1,174 +0,0 @@ -#![allow(dead_code)] - -use std::num::ParseIntError; -use std::string::FromUtf8Error; -use std::{io, net}; - -use thiserror::Error; - -pub type Result = std::result::Result; - -#[derive(Error, Debug, PartialEq)] -#[non_exhaustive] -pub enum Error { - #[error("buffer: full")] - ErrBufferFull, - #[error("buffer: closed")] - ErrBufferClosed, - #[error("buffer: short")] - ErrBufferShort, - #[error("packet too big")] - ErrPacketTooBig, - #[error("i/o timeout")] - ErrTimeout, - #[error("udp: listener closed")] - ErrClosedListener, - #[error("udp: listen queue exceeded")] - ErrListenQueueExceeded, - #[error("udp: listener accept ch closed")] - ErrClosedListenerAcceptCh, - #[error("obs cannot be nil")] - ErrObsCannotBeNil, - #[error("se of closed network connection")] - ErrUseClosedNetworkConn, - #[error("addr is not a net.UDPAddr")] - ErrAddrNotUdpAddr, - #[error("something went wrong with locAddr")] - ErrLocAddr, - #[error("already closed")] - ErrAlreadyClosed, - #[error("no remAddr defined")] - ErrNoRemAddr, - #[error("address already in use")] - ErrAddressAlreadyInUse, - #[error("no such UDPConn")] - ErrNoSuchUdpConn, - #[error("cannot remove unspecified IP by the specified IP")] - ErrCannotRemoveUnspecifiedIp, - #[error("no address assigned")] - ErrNoAddressAssigned, - #[error("1:1 NAT requires more than one mapping")] - ErrNatRequiresMapping, - #[error("length mismtach between mappedIPs and localIPs")] - ErrMismatchLengthIp, - #[error("non-udp translation is not supported yet")] - ErrNonUdpTranslationNotSupported, - #[error("no associated local address")] - ErrNoAssociatedLocalAddress, - #[error("no NAT binding found")] - ErrNoNatBindingFound, - #[error("has no permission")] - ErrHasNoPermission, - #[error("host name must not be empty")] - ErrHostnameEmpty, - #[error("failed to parse IP address")] - ErrFailedToParseIpaddr, - #[error("no interface is available")] - ErrNoInterface, - #[error("not found")] - ErrNotFound, - #[error("unexpected network")] - ErrUnexpectedNetwork, - #[error("can't assign requested address")] - ErrCantAssignRequestedAddr, - #[error("unknown network")] - ErrUnknownNetwork, - #[error("no router linked")] - ErrNoRouterLinked, - #[error("invalid port number")] - ErrInvalidPortNumber, - #[error("unexpected type-switch failure")] - ErrUnexpectedTypeSwitchFailure, - #[error("bind failed")] - ErrBindFailed, - 
#[error("end port is less than the start")] - ErrEndPortLessThanStart, - #[error("port space exhausted")] - ErrPortSpaceExhausted, - #[error("vnet is not enabled")] - ErrVnetDisabled, - #[error("invalid local IP in static_ips")] - ErrInvalidLocalIpInStaticIps, - #[error("mapped in static_ips is beyond subnet")] - ErrLocalIpBeyondStaticIpsSubset, - #[error("all static_ips must have associated local IPs")] - ErrLocalIpNoStaticsIpsAssociated, - #[error("router already started")] - ErrRouterAlreadyStarted, - #[error("router already stopped")] - ErrRouterAlreadyStopped, - #[error("static IP is beyond subnet")] - ErrStaticIpIsBeyondSubnet, - #[error("address space exhausted")] - ErrAddressSpaceExhausted, - #[error("no IP address is assigned for eth0")] - ErrNoIpaddrEth0, - #[error("Invalid mask")] - ErrInvalidMask, - #[error("parse ipnet: {0}")] - ParseIpnet(#[from] ipnet::AddrParseError), - #[error("parse ip: {0}")] - ParseIp(#[from] net::AddrParseError), - #[error("parse int: {0}")] - ParseInt(#[from] ParseIntError), - #[error("{0}")] - Io(#[source] IoError), - #[error("utf8: {0}")] - Utf8(#[from] FromUtf8Error), - #[error("{0}")] - Std(#[source] StdError), - #[error("{0}")] - Other(String), -} - -impl Error { - pub fn from_std(error: T) -> Self - where - T: std::error::Error + Send + Sync + 'static, - { - Error::Std(StdError(Box::new(error))) - } - - pub fn downcast_ref(&self) -> Option<&T> { - if let Error::Std(s) = self { - return s.0.downcast_ref(); - } - - None - } -} - -#[derive(Debug, Error)] -#[error("io error: {0}")] -pub struct IoError(#[from] pub io::Error); - -// Workaround for wanting PartialEq for io::Error. -impl PartialEq for IoError { - fn eq(&self, other: &Self) -> bool { - self.0.kind() == other.0.kind() - } -} - -impl From for Error { - fn from(e: io::Error) -> Self { - Error::Io(IoError(e)) - } -} - -/// An escape hatch to preserve stack traces when we don't know the error. -/// -/// This crate exports some traits such as `Conn` and `Listener`. The trait functions -/// produce the local error `util::Error`. However when used in crates higher up the stack, -/// we are forced to handle errors that are local to that crate. For example we use -/// `Listener` the `dtls` crate and it needs to handle `dtls::Error`. -/// -/// By using `util::Error::from_std` we can preserve the underlying error (and stack trace!). 
-#[derive(Debug, Error)] -#[error("{0}")] -pub struct StdError(pub Box); - -impl PartialEq for StdError { - fn eq(&self, _: &Self) -> bool { - false - } -} diff --git a/util/src/fixed_big_int/fixed_big_int_test.rs b/util/src/fixed_big_int/fixed_big_int_test.rs deleted file mode 100644 index 4bad3c7b3..000000000 --- a/util/src/fixed_big_int/fixed_big_int_test.rs +++ /dev/null @@ -1,78 +0,0 @@ -use super::*; - -#[test] -fn test_fixed_big_int_set_bit() { - let mut bi = FixedBigInt::new(224); - - bi.set_bit(0); - assert_eq!( - bi.to_string(), - "0000000000000000000000000000000000000000000000000000000000000001" - ); - - bi.lsh(1); - assert_eq!( - bi.to_string(), - "0000000000000000000000000000000000000000000000000000000000000002" - ); - - bi.lsh(0); - assert_eq!( - bi.to_string(), - "0000000000000000000000000000000000000000000000000000000000000002" - ); - - bi.set_bit(10); - assert_eq!( - bi.to_string(), - "0000000000000000000000000000000000000000000000000000000000000402" - ); - bi.lsh(20); - assert_eq!( - bi.to_string(), - "0000000000000000000000000000000000000000000000000000000040200000" - ); - - bi.set_bit(80); - assert_eq!( - bi.to_string(), - "0000000000000000000000000000000000000000000100000000000040200000" - ); - bi.lsh(4); - assert_eq!( - bi.to_string(), - "0000000000000000000000000000000000000000001000000000000402000000" - ); - - bi.set_bit(130); - assert_eq!( - bi.to_string(), - "0000000000000000000000000000000400000000001000000000000402000000" - ); - bi.lsh(64); - assert_eq!( - bi.to_string(), - "0000000000000004000000000010000000000004020000000000000000000000" - ); - - bi.set_bit(7); - assert_eq!( - bi.to_string(), - "0000000000000004000000000010000000000004020000000000000000000080" - ); - - bi.lsh(129); - assert_eq!( - bi.to_string(), - "0000000004000000000000000000010000000000000000000000000000000000" - ); - - for _ in 0..256 { - bi.lsh(1); - bi.set_bit(0); - } - assert_eq!( - bi.to_string(), - "00000000FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF" - ); -} diff --git a/util/src/fixed_big_int/mod.rs b/util/src/fixed_big_int/mod.rs deleted file mode 100644 index 400e04851..000000000 --- a/util/src/fixed_big_int/mod.rs +++ /dev/null @@ -1,96 +0,0 @@ -#[cfg(test)] -mod fixed_big_int_test; - -use std::fmt; - -// FixedBigInt is the fix-sized multi-word integer. -pub(crate) struct FixedBigInt { - bits: Vec, - n: usize, - msb_mask: u64, -} - -impl fmt::Display for FixedBigInt { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut out = String::new(); - for i in (0..self.bits.len()).rev() { - out += format!("{:016X}", self.bits[i]).as_str(); - } - - write!(f, "{out}") - } -} - -impl FixedBigInt { - pub(crate) fn new(n: usize) -> Self { - let mut chunk_size = (n + 63) / 64; - if chunk_size == 0 { - chunk_size = 1; - } - - FixedBigInt { - bits: vec![0; chunk_size], - n, - msb_mask: if n % 64 == 0 { - u64::MAX - } else { - (1 << (64 - n % 64)) - 1 - }, - } - } - - // lsh is the left shift operation. 
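// In other words, lsh(n) slides every bit of the fixed-width integer up by n positions,
// dropping whatever is pushed past the configured width; e.g. in the removed test above,
// a 224-bit value holding only bit 0 ("...0001") reads "...0002" after lsh(1).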
- pub(crate) fn lsh(&mut self, n: usize) { - if n == 0 { - return; - } - let n_chunk = (n / 64) as isize; - let n_n = n % 64; - - for i in (0..self.bits.len() as isize).rev() { - let mut carry: u64 = 0; - if i - n_chunk >= 0 { - carry = if n_n >= 64 { - 0 - } else { - self.bits[(i - n_chunk) as usize] << n_n - }; - if i - n_chunk > 0 { - carry |= if n_n == 0 { - 0 - } else { - self.bits[(i - n_chunk - 1) as usize] >> (64 - n_n) - }; - } - } - self.bits[i as usize] = if n >= 64 { - carry - } else { - (self.bits[i as usize] << n) | carry - }; - } - - let last = self.bits.len() - 1; - self.bits[last] &= self.msb_mask; - } - - // bit returns i-th bit of the fixedBigInt. - pub(crate) fn bit(&self, i: usize) -> usize { - if i >= self.n { - return 0; - } - let chunk = i / 64; - let pos = i % 64; - usize::from(self.bits[chunk] & (1 << pos) != 0) - } - - // set_bit sets i-th bit to 1. - pub(crate) fn set_bit(&mut self, i: usize) { - if i >= self.n { - return; - } - let chunk = i / 64; - let pos = i % 64; - self.bits[chunk] |= 1 << pos; - } -} diff --git a/util/src/ifaces/ffi/mod.rs b/util/src/ifaces/ffi/mod.rs deleted file mode 100644 index c8323f313..000000000 --- a/util/src/ifaces/ffi/mod.rs +++ /dev/null @@ -1,9 +0,0 @@ -#[cfg(target_family = "windows")] -mod windows; -#[cfg(target_family = "windows")] -pub use self::windows::ifaces; - -#[cfg(target_family = "unix")] -mod unix; -#[cfg(target_family = "unix")] -pub use self::unix::ifaces; diff --git a/util/src/ifaces/ffi/unix/mod.rs b/util/src/ifaces/ffi/unix/mod.rs deleted file mode 100644 index 1d834e187..000000000 --- a/util/src/ifaces/ffi/unix/mod.rs +++ /dev/null @@ -1,79 +0,0 @@ -use crate::ifaces::{Interface, Kind, NextHop}; - -use nix::sys::socket::{AddressFamily, SockaddrLike, SockaddrStorage}; -use std::io::Error; -use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; - -fn ss_to_netsa(ss: &SockaddrStorage) -> Option { - match ss.family() { - Some(AddressFamily::Inet) => ss.as_sockaddr_in().map(|sin| { - SocketAddr::V4(SocketAddrV4::new( - std::net::Ipv4Addr::from(sin.ip()), - sin.port(), - )) - }), - Some(AddressFamily::Inet6) => ss.as_sockaddr_in6().map(|sin6| { - SocketAddr::V6(SocketAddrV6::new( - sin6.ip(), - sin6.port(), - sin6.flowinfo(), - sin6.scope_id(), - )) - }), - _ => None, - } -} - -/// Query the local system for all interface addresses. -pub fn ifaces() -> Result, Error> { - let mut ret = Vec::new(); - for ifa in nix::ifaddrs::getifaddrs()? 
{ - if let Some(kind) = ifa - .address - .as_ref() - .and_then(SockaddrStorage::family) - .and_then(|af| match af { - AddressFamily::Inet => Some(Kind::Ipv4), - AddressFamily::Inet6 => Some(Kind::Ipv6), - #[cfg(any( - target_os = "android", - target_os = "linux", - target_os = "illumos", - target_os = "fuchsia", - target_os = "solaris" - ))] - AddressFamily::Packet => Some(Kind::Packet), - #[cfg(any( - target_os = "dragonfly", - target_os = "freebsd", - target_os = "ios", - target_os = "macos", - target_os = "illumos", - target_os = "netbsd", - target_os = "openbsd" - ))] - AddressFamily::Link => Some(Kind::Link), - _ => None, - }) - { - let name = ifa.interface_name; - let dst = ifa.destination.as_ref().and_then(ss_to_netsa); - let broadcast = ifa.broadcast.as_ref().and_then(ss_to_netsa); - let hop = dst - .map(NextHop::Destination) - .or(broadcast.map(NextHop::Broadcast)); - let addr = ifa.address.as_ref().and_then(ss_to_netsa); - let mask = ifa.netmask.as_ref().and_then(ss_to_netsa); - - ret.push(Interface { - name, - kind, - addr, - mask, - hop, - }); - } - } - - Ok(ret) -} diff --git a/util/src/ifaces/ffi/windows/mod.rs b/util/src/ifaces/ffi/windows/mod.rs deleted file mode 100644 index 415a73e98..000000000 --- a/util/src/ifaces/ffi/windows/mod.rs +++ /dev/null @@ -1,393 +0,0 @@ -#![allow(unused, non_upper_case_globals)] - -use winapi::shared::basetsd::{UINT32, UINT8, ULONG64}; -use winapi::shared::guiddef::GUID; -use winapi::shared::minwindef::{BYTE, DWORD, PULONG, ULONG}; -use winapi::shared::ws2def::SOCKET_ADDRESS; -use winapi::um::winnt::{PCHAR, PVOID, PWCHAR, WCHAR}; - -const MAX_ADAPTER_ADDRESS_LENGTH: usize = 8; -const ZONE_INDICES_LENGTH: usize = 16; -const MAX_DHCPV6_DUID_LENGTH: usize = 130; -const MAX_DNS_SUFFIX_STRING_LENGTH: usize = 256; - -pub const IP_ADAPTER_IPV4_ENABLED: DWORD = 0x0080; -pub const IP_ADAPTER_IPV6_ENABLED: DWORD = 0x0100; - -use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; -use std::{io, mem, ptr}; - -use winapi::shared::winerror::{ - ERROR_ADDRESS_NOT_ASSOCIATED, ERROR_BUFFER_OVERFLOW, ERROR_INVALID_PARAMETER, - ERROR_NOT_ENOUGH_MEMORY, ERROR_NO_DATA, ERROR_SUCCESS, -}; -use winapi::shared::ws2def::{AF_INET, AF_INET6, AF_UNSPEC, SOCKADDR_IN}; -use winapi::shared::ws2ipdef::SOCKADDR_IN6; - -const PREALLOC_ADAPTERS_LEN: usize = 15 * 1024; - -use crate::ifaces::{Interface, Kind, NextHop}; - -#[link(name = "iphlpapi")] -extern "system" { - pub fn GetAdaptersAddresses( - family: ULONG, - flags: ULONG, - reserved: PVOID, - addresses: *mut u8, - size: PULONG, - ) -> ULONG; -} - -#[repr(C)] -pub struct IpAdapterAddresses { - pub head: IpAdapterAddressesHead, - pub all: IpAdaptersAddressesAll, - pub xp: IpAdaptersAddressesXp, - pub vista: IpAdaptersAddressesVista, -} - -#[repr(C)] -pub struct IpAdapterAddressesHead { - pub length: ULONG, - if_index: DWORD, -} - -/// All Windows & Later -#[repr(C)] -pub struct IpAdaptersAddressesAll { - pub next: *const IpAdapterAddresses, - pub adapter_name: PCHAR, - pub first_unicast_address: *const IpAdapterUnicastAddress, - first_anycast_address: *const IpAdapterAnycastAddress, - first_multicast_address: *const IpAdapterMulticastAddress, - first_dns_server_address: *const IpAdapterDnsServerAddress, - dns_suffix: PWCHAR, - pub description: PWCHAR, - friendly_name: PWCHAR, - pub physical_address: [BYTE; MAX_ADAPTER_ADDRESS_LENGTH], - pub physical_address_length: DWORD, - pub flags: DWORD, - mtu: DWORD, - pub if_type: DWORD, - oper_status: IfOperStatus, -} - -/// Windows XP & Later -#[repr(C)] -pub struct 
IpAdaptersAddressesXp { - pub ipv6_if_index: DWORD, - pub zone_indices: [DWORD; ZONE_INDICES_LENGTH], - first_prefix: *const IpAdapterPrefix, -} - -/// Windows Vista & Later -#[repr(C)] -pub struct IpAdaptersAddressesVista { - transmit_link_speed: ULONG64, - receive_link_speed: ULONG64, - first_wins_server_address: *const IpAdapterWinsServerAddress, - first_gateway_address: *const IpAdapterGatewayAddress, - ipv4_metric: ULONG, - ipv6_metric: ULONG, - luid: IfLuid, - dhcpv4_server: SOCKET_ADDRESS, - compartment_id: UINT32, - network_guid: GUID, - connection_type: NetIfConnectionType, - tunnel_type: TunnelType, - dhcpv6_server: SOCKET_ADDRESS, - dhcpv6_client_duid: [BYTE; MAX_DHCPV6_DUID_LENGTH], - dhcpv6_client_duid_length: ULONG, - dhcpv6_iaid: ULONG, - first_dns_suffix: *const IpAdapterDnsSuffix, -} - -#[repr(C)] -pub struct IpAdapterUnicastAddress { - pub length: ULONG, - flags: DWORD, - pub next: *const IpAdapterUnicastAddress, - pub address: SOCKET_ADDRESS, - prefix_origin: IpPrefixOrigin, - suffix_origin: IpSuffixOrigin, - pub dad_state: IpDadState, - valid_lifetime: ULONG, - preferred_lifetime: ULONG, - lease_lifetime: ULONG, - on_link_prefix_length: UINT8, -} - -#[repr(C)] -pub struct IpAdapterAnycastAddress { - length: ULONG, - flags: DWORD, - next: *const IpAdapterAnycastAddress, - address: SOCKET_ADDRESS, -} - -#[repr(C)] -pub struct IpAdapterMulticastAddress { - length: ULONG, - flags: DWORD, - next: *const IpAdapterMulticastAddress, - address: SOCKET_ADDRESS, -} - -#[repr(C)] -pub struct IpAdapterDnsServerAddress { - length: ULONG, - reserved: DWORD, - next: *const IpAdapterDnsServerAddress, - address: SOCKET_ADDRESS, -} - -#[repr(C)] -pub struct IpAdapterPrefix { - length: ULONG, - flags: DWORD, - next: *const IpAdapterPrefix, - address: SOCKET_ADDRESS, - prefix_length: ULONG, -} - -#[repr(C)] -pub struct IpAdapterWinsServerAddress { - length: ULONG, - reserved: DWORD, - next: *const IpAdapterWinsServerAddress, - address: SOCKET_ADDRESS, -} - -#[repr(C)] -pub struct IpAdapterGatewayAddress { - length: ULONG, - reserved: DWORD, - next: *const IpAdapterGatewayAddress, - address: SOCKET_ADDRESS, -} - -#[repr(C)] -pub struct IpAdapterDnsSuffix { - next: *const IpAdapterDnsSuffix, - string: [WCHAR; MAX_DNS_SUFFIX_STRING_LENGTH], -} - -bitflags! 
{ - struct IfLuid: ULONG64 { - const Reserved = 0x0000000000FFFFFF; - const NetLuidIndex = 0x0000FFFFFF000000; - const IfType = 0xFFFF00000000000; - } -} - -#[repr(C)] -pub enum IpPrefixOrigin { - IpPrefixOriginOther = 0, - IpPrefixOriginManual, - IpPrefixOriginWellKnown, - IpPrefixOriginDhcp, - IpPrefixOriginRouterAdvertisement, - IpPrefixOriginUnchanged = 16, -} - -#[repr(C)] -pub enum IpSuffixOrigin { - IpSuffixOriginOther = 0, - IpSuffixOriginManual, - IpSuffixOriginWellKnown, - IpSuffixOriginDhcp, - IpSuffixOriginLinkLayerAddress, - IpSuffixOriginRandom, - IpSuffixOriginUnchanged = 16, -} - -#[derive(PartialEq, Eq)] -#[repr(C)] -pub enum IpDadState { - IpDadStateInvalid = 0, - IpDadStateTentative, - IpDadStateDuplicate, - IpDadStateDeprecated, - IpDadStatePreferred, -} - -#[repr(C)] -pub enum IfOperStatus { - IfOperStatusUp = 1, - IfOperStatusDown = 2, - IfOperStatusTesting = 3, - IfOperStatusUnknown = 4, - IfOperStatusDormant = 5, - IfOperStatusNotPresent = 6, - IfOperStatusLowerLayerDown = 7, -} - -#[repr(C)] -pub enum NetIfConnectionType { - NetIfConnectionDedicated = 1, - NetIfConnectionPassive = 2, - NetIfConnectionDemand = 3, - NetIfConnectionMaximum = 4, -} - -#[repr(C)] -pub enum TunnelType { - TunnelTypeNone = 0, - TunnelTypeOther = 1, - TunnelTypeDirect = 2, - TunnelType6To4 = 11, - TunnelTypeIsatap = 13, - TunnelTypeTeredo = 14, - TunnelTypeIpHttps = 15, -} - -unsafe fn v4_socket_from_adapter(unicast_addr: &IpAdapterUnicastAddress) -> SocketAddrV4 { - let socket_addr = &unicast_addr.address; - - let in_addr: SOCKADDR_IN = mem::transmute(*socket_addr.lpSockaddr); - let sin_addr = in_addr.sin_addr.S_un; - - let v4_addr = Ipv4Addr::new( - *sin_addr.S_addr() as u8, - (*sin_addr.S_addr() >> 8) as u8, - (*sin_addr.S_addr() >> 16) as u8, - (*sin_addr.S_addr() >> 24) as u8, - ); - - SocketAddrV4::new(v4_addr, 0) -} - -unsafe fn v6_socket_from_adapter(unicast_addr: &IpAdapterUnicastAddress) -> SocketAddrV6 { - let socket_addr = &unicast_addr.address; - - let sock_addr6: *const SOCKADDR_IN6 = socket_addr.lpSockaddr as *const SOCKADDR_IN6; - let in6_addr: SOCKADDR_IN6 = *sock_addr6; - - let v6_addr = (*in6_addr.sin6_addr.u.Word()).into(); - - SocketAddrV6::new( - v6_addr, - 0, - in6_addr.sin6_flowinfo, - *in6_addr.u.sin6_scope_id(), - ) -} - -unsafe fn local_ifaces_with_buffer(buffer: &mut Vec) -> io::Result<()> { - let mut length = buffer.capacity() as u32; - - let ret_code = GetAdaptersAddresses( - AF_UNSPEC as u32, - 0, - ptr::null_mut(), - buffer.as_mut_ptr(), - &mut length, - ); - match ret_code { - ERROR_SUCCESS => Ok(()), - ERROR_ADDRESS_NOT_ASSOCIATED => Err(io::Error::new( - io::ErrorKind::AddrNotAvailable, - "An address has not yet been associated with the network endpoint.", - )), - ERROR_BUFFER_OVERFLOW => { - buffer.reserve_exact(length as usize); - - local_ifaces_with_buffer(buffer) - } - ERROR_INVALID_PARAMETER => Err(io::Error::new( - io::ErrorKind::InvalidInput, - "One of the parameters is invalid.", - )), - ERROR_NOT_ENOUGH_MEMORY => Err(io::Error::new( - io::ErrorKind::Other, - "Insufficient memory resources are available to complete the operation.", - )), - ERROR_NO_DATA => Err(io::Error::new( - io::ErrorKind::AddrNotAvailable, - "No addresses were found for the requested parameters.", - )), - _ => Err(io::Error::new( - io::ErrorKind::Other, - "Some Other Error Occurred.", - )), - } -} - -unsafe fn map_adapter_addresses(mut adapter_addr: *const IpAdapterAddresses) -> Vec { - let mut adapter_addresses = Vec::new(); - - while !adapter_addr.is_null() { - let 
curr_adapter_addr = &*adapter_addr; - - let mut unicast_addr = curr_adapter_addr.all.first_unicast_address; - while !unicast_addr.is_null() { - let curr_unicast_addr = &*unicast_addr; - - // For some reason, some IpDadState::IpDadStateDeprecated addresses are return - // These contain BOGUS interface indices and will cause problesm if used - if curr_unicast_addr.dad_state != IpDadState::IpDadStateDeprecated { - if is_ipv4_enabled(curr_unicast_addr) { - adapter_addresses.push(Interface { - name: "".to_string(), - kind: Kind::Ipv4, - addr: Some(SocketAddr::V4(v4_socket_from_adapter(curr_unicast_addr))), - mask: None, - hop: None, - }); - } else if is_ipv6_enabled(curr_unicast_addr) { - let mut v6_sock = v6_socket_from_adapter(curr_unicast_addr); - // Make sure the scope id is set for ALL interfaces, not just link-local - v6_sock.set_scope_id(curr_adapter_addr.xp.ipv6_if_index); - adapter_addresses.push(Interface { - name: "".to_string(), - kind: Kind::Ipv6, - addr: Some(SocketAddr::V6(v6_sock)), - mask: None, - hop: None, - }); - } - } - - unicast_addr = curr_unicast_addr.next; - } - - adapter_addr = curr_adapter_addr.all.next; - } - - adapter_addresses -} - -/// Query the local system for all interface addresses. -pub fn ifaces() -> Result, ::std::io::Error> { - let mut adapters_list = Vec::with_capacity(PREALLOC_ADAPTERS_LEN); - unsafe { - local_ifaces_with_buffer(&mut adapters_list)?; - - Ok(map_adapter_addresses( - adapters_list.as_ptr() as *const IpAdapterAddresses - )) - } -} - -unsafe fn is_ipv4_enabled(unicast_addr: &IpAdapterUnicastAddress) -> bool { - if unicast_addr.length != 0 { - let socket_addr = &unicast_addr.address; - let sa_family = (*socket_addr.lpSockaddr).sa_family; - - sa_family == AF_INET as u16 - } else { - false - } -} - -unsafe fn is_ipv6_enabled(unicast_addr: &IpAdapterUnicastAddress) -> bool { - if unicast_addr.length != 0 { - let socket_addr = &unicast_addr.address; - let sa_family = (*socket_addr.lpSockaddr).sa_family; - - sa_family == AF_INET6 as u16 - } else { - false - } -} diff --git a/util/src/ifaces/mod.rs b/util/src/ifaces/mod.rs deleted file mode 100644 index b786659e6..000000000 --- a/util/src/ifaces/mod.rs +++ /dev/null @@ -1,26 +0,0 @@ -pub mod ffi; -pub use ffi::ifaces; - -#[derive(PartialEq, Eq, Debug, Clone)] -pub enum NextHop { - Broadcast(::std::net::SocketAddr), - Destination(::std::net::SocketAddr), -} - -#[derive(PartialEq, Eq, Debug, Clone)] -pub enum Kind { - Packet, - Link, - Ipv4, - Ipv6, - Unknow(i32), -} - -#[derive(Debug, Clone)] -pub struct Interface { - pub name: String, - pub kind: Kind, - pub addr: Option<::std::net::SocketAddr>, - pub mask: Option<::std::net::SocketAddr>, - pub hop: Option, -} diff --git a/util/src/lib.rs b/util/src/lib.rs deleted file mode 100644 index b149bbb50..000000000 --- a/util/src/lib.rs +++ /dev/null @@ -1,88 +0,0 @@ -#![warn(rust_2018_idioms)] -#![allow(dead_code)] - -use std::io; - -use async_trait::async_trait; -use thiserror::Error; - -#[cfg(feature = "vnet")] -#[macro_use] -extern crate lazy_static; - -#[cfg(target_family = "windows")] -#[macro_use] -extern crate bitflags; - -pub mod fixed_big_int; -pub mod replay_detector; - -/// KeyingMaterialExporter to extract keying material. -/// -/// This trait sits here to avoid getting a direct dependency between -/// the dtls and srtp crates. 
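
A hedged sketch of how a consumer might use the `KeyingMaterialExporter` trait declared just below to pull SRTP keying material out of a finished DTLS handshake. The "EXTRACTOR-dtls_srtp" label comes from RFC 5764; the 60-byte length is only an example, since the real length depends on the negotiated SRTP protection profile:

async fn derive_srtp_keying_material(
    exporter: &dyn KeyingMaterialExporter,
) -> std::result::Result<Vec<u8>, KeyingMaterialExporterError> {
    // An empty context is the common case for DTLS-SRTP; some implementations may
    // reject a non-empty one with KeyingMaterialExporterError::ContextUnsupported.
    exporter
        .export_keying_material("EXTRACTOR-dtls_srtp", &[], 60)
        .await
}
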
-#[async_trait] -pub trait KeyingMaterialExporter { - async fn export_keying_material( - &self, - label: &str, - context: &[u8], - length: usize, - ) -> std::result::Result, KeyingMaterialExporterError>; -} - -/// Possible errors while exporting keying material. -/// -/// These errors might have been more logically kept in the dtls -/// crate, but that would have required a direct dependency between -/// srtp and dtls. -#[derive(Debug, Error, PartialEq)] -#[non_exhaustive] -pub enum KeyingMaterialExporterError { - #[error("tls handshake is in progress")] - HandshakeInProgress, - #[error("context is not supported for export_keying_material")] - ContextUnsupported, - #[error("export_keying_material can not be used with a reserved label")] - ReservedExportKeyingMaterial, - #[error("no cipher suite for export_keying_material")] - CipherSuiteUnset, - #[error("export_keying_material io: {0}")] - Io(#[source] error::IoError), - #[error("export_keying_material hash: {0}")] - Hash(String), -} - -impl From for KeyingMaterialExporterError { - fn from(e: io::Error) -> Self { - KeyingMaterialExporterError::Io(error::IoError(e)) - } -} - -#[cfg(feature = "buffer")] -pub mod buffer; - -#[cfg(feature = "conn")] -pub mod conn; - -#[cfg(feature = "ifaces")] -pub mod ifaces; - -#[cfg(feature = "vnet")] -pub mod vnet; - -#[cfg(feature = "marshal")] -pub mod marshal; - -#[cfg(feature = "buffer")] -pub use crate::buffer::Buffer; -#[cfg(feature = "conn")] -pub use crate::conn::Conn; -#[cfg(feature = "marshal")] -pub use crate::marshal::{exact_size_buf::ExactSizeBuf, Marshal, MarshalSize, Unmarshal}; - -mod error; -pub use error::{Error, Result}; - -#[cfg(feature = "sync")] -pub mod sync; diff --git a/util/src/marshal/exact_size_buf.rs b/util/src/marshal/exact_size_buf.rs deleted file mode 100644 index daf826759..000000000 --- a/util/src/marshal/exact_size_buf.rs +++ /dev/null @@ -1,96 +0,0 @@ -// FIXME(regexident): -// Replace with `bytes::ExactSizeBuf` once merged: -// https://github.com/tokio-rs/bytes/pull/496 - -use bytes::buf::{Chain, Take}; -use bytes::{Bytes, BytesMut}; - -/// A trait for buffers that know their exact length. -pub trait ExactSizeBuf { - /// Returns the exact length of the buffer. - fn len(&self) -> usize; - - /// Returns `true` if the buffer is empty. - /// - /// This method has a default implementation using `ExactSizeBuf::len()`, - /// so you don't need to implement it yourself. 
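
A small sketch of writing generic code against the `ExactSizeBuf` trait above: because `len()` is exact and does not consume the buffer, a caller can size its allocation up front. `collect_exact` is a hypothetical helper, not part of this crate:

fn collect_exact<B: ExactSizeBuf + bytes::Buf>(mut buf: B) -> Vec<u8> {
    // Preallocate exactly once, then drain the buffer chunk by chunk.
    let mut out = Vec::with_capacity(ExactSizeBuf::len(&buf));
    while buf.has_remaining() {
        let chunk = buf.chunk();
        out.extend_from_slice(chunk);
        let advanced = chunk.len();
        buf.advance(advanced);
    }
    out
}
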
- #[inline] - fn is_empty(&self) -> bool { - self.len() == 0 - } -} - -impl ExactSizeBuf for Bytes { - #[inline] - fn len(&self) -> usize { - Bytes::len(self) - } - - #[inline] - fn is_empty(&self) -> bool { - Bytes::is_empty(self) - } -} - -impl ExactSizeBuf for BytesMut { - #[inline] - fn len(&self) -> usize { - BytesMut::len(self) - } - - #[inline] - fn is_empty(&self) -> bool { - BytesMut::is_empty(self) - } -} - -impl ExactSizeBuf for [u8] { - #[inline] - fn len(&self) -> usize { - <[u8]>::len(self) - } - - #[inline] - fn is_empty(&self) -> bool { - <[u8]>::is_empty(self) - } -} - -impl ExactSizeBuf for Chain -where - T: ExactSizeBuf, - U: ExactSizeBuf, -{ - fn len(&self) -> usize { - let first_ref = self.first_ref(); - let last_ref = self.last_ref(); - - first_ref.len() + last_ref.len() - } - - fn is_empty(&self) -> bool { - let first_ref = self.first_ref(); - let last_ref = self.last_ref(); - - first_ref.is_empty() && last_ref.is_empty() - } -} - -impl ExactSizeBuf for Take -where - T: ExactSizeBuf, -{ - fn len(&self) -> usize { - let inner_ref = self.get_ref(); - let limit = self.limit(); - - limit.min(inner_ref.len()) - } - - fn is_empty(&self) -> bool { - let inner_ref = self.get_ref(); - let limit = self.limit(); - - (limit == 0) || inner_ref.is_empty() - } -} diff --git a/util/src/marshal/mod.rs b/util/src/marshal/mod.rs deleted file mode 100644 index aa4bb56cb..000000000 --- a/util/src/marshal/mod.rs +++ /dev/null @@ -1,34 +0,0 @@ -pub mod exact_size_buf; - -use bytes::{Buf, Bytes, BytesMut}; - -use crate::error::{Error, Result}; - -pub trait MarshalSize { - fn marshal_size(&self) -> usize; -} - -pub trait Marshal: MarshalSize { - fn marshal_to(&self, buf: &mut [u8]) -> Result; - - fn marshal(&self) -> Result { - let l = self.marshal_size(); - let mut buf = BytesMut::with_capacity(l); - buf.resize(l, 0); - let n = self.marshal_to(&mut buf)?; - if n != l { - Err(Error::Other(format!( - "marshal_to output size {n}, but expect {l}" - ))) - } else { - Ok(buf.freeze()) - } - } -} - -pub trait Unmarshal: MarshalSize { - fn unmarshal(buf: &mut B) -> Result - where - Self: Sized, - B: Buf; -} diff --git a/util/src/replay_detector/mod.rs b/util/src/replay_detector/mod.rs deleted file mode 100644 index 127707be6..000000000 --- a/util/src/replay_detector/mod.rs +++ /dev/null @@ -1,177 +0,0 @@ -#[cfg(test)] -mod replay_detector_test; - -use super::fixed_big_int::*; - -// ReplayDetector is the interface of sequence replay detector. -pub trait ReplayDetector { - // Check returns true if given sequence number is not replayed. - // Call accept() to mark the packet is received properly. - fn check(&mut self, seq: u64) -> bool; - fn accept(&mut self); -} - -pub struct SlidingWindowDetector { - accepted: bool, - seq: u64, - latest_seq: u64, - max_seq: u64, - window_size: usize, - mask: FixedBigInt, -} - -impl SlidingWindowDetector { - // New creates ReplayDetector. - // Created ReplayDetector doesn't allow wrapping. - // It can handle monotonically increasing sequence number up to - // full 64bit number. It is suitable for DTLS replay protection. - pub fn new(window_size: usize, max_seq: u64) -> Self { - SlidingWindowDetector { - accepted: false, - seq: 0, - latest_seq: 0, - max_seq, - window_size, - mask: FixedBigInt::new(window_size), - } - } -} - -impl ReplayDetector for SlidingWindowDetector { - fn check(&mut self, seq: u64) -> bool { - self.accepted = false; - - if seq > self.max_seq { - // Exceeded upper limit. 
- return false; - } - - if seq <= self.latest_seq { - if self.latest_seq >= self.window_size as u64 + seq { - return false; - } - if self.mask.bit((self.latest_seq - seq) as usize) != 0 { - // The sequence number is duplicated. - return false; - } - } - - self.accepted = true; - self.seq = seq; - true - } - - fn accept(&mut self) { - if !self.accepted { - return; - } - - if self.seq > self.latest_seq { - // Update the head of the window. - self.mask.lsh((self.seq - self.latest_seq) as usize); - self.latest_seq = self.seq; - } - let diff = (self.latest_seq - self.seq) % self.max_seq; - self.mask.set_bit(diff as usize); - } -} - -pub struct WrappedSlidingWindowDetector { - accepted: bool, - seq: u64, - latest_seq: u64, - max_seq: u64, - window_size: usize, - mask: FixedBigInt, - init: bool, -} - -impl WrappedSlidingWindowDetector { - // WithWrap creates ReplayDetector allowing sequence wrapping. - // This is suitable for short bitwidth counter like SRTP and SRTCP. - pub fn new(window_size: usize, max_seq: u64) -> Self { - WrappedSlidingWindowDetector { - accepted: false, - seq: 0, - latest_seq: 0, - max_seq, - window_size, - mask: FixedBigInt::new(window_size), - init: false, - } - } -} - -impl ReplayDetector for WrappedSlidingWindowDetector { - fn check(&mut self, seq: u64) -> bool { - self.accepted = false; - - if seq > self.max_seq { - // Exceeded upper limit. - return false; - } - if !self.init { - if seq != 0 { - self.latest_seq = seq - 1; - } else { - self.latest_seq = self.max_seq; - } - self.init = true; - } - - let mut diff = self.latest_seq as i64 - seq as i64; - // Wrap the number. - if diff > self.max_seq as i64 / 2 { - diff -= (self.max_seq + 1) as i64; - } else if diff <= -(self.max_seq as i64 / 2) { - diff += (self.max_seq + 1) as i64; - } - - if diff >= self.window_size as i64 { - // Too old. - return false; - } - if diff >= 0 && self.mask.bit(diff as usize) != 0 { - // The sequence number is duplicated. - return false; - } - - self.accepted = true; - self.seq = seq; - true - } - - fn accept(&mut self) { - if !self.accepted { - return; - } - - let mut diff = self.latest_seq as i64 - self.seq as i64; - // Wrap the number. - if diff > self.max_seq as i64 / 2 { - diff -= (self.max_seq + 1) as i64; - } else if diff <= -(self.max_seq as i64 / 2) { - diff += (self.max_seq + 1) as i64; - } - - assert!(diff < self.window_size as i64); - - if diff < 0 { - // Update the head of the window. 
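// That is: when the accepted sequence number is newer than anything seen so far, the
// window is slid forward so that bit 0 lines up with the new latest_seq, and the bit
// for this sequence number is then recorded below.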
- self.mask.lsh((-diff) as usize); - self.latest_seq = self.seq; - } - self.mask - .set_bit((self.latest_seq as isize - self.seq as isize) as usize); - } -} - -#[derive(Default)] -pub struct NoOpReplayDetector; - -impl ReplayDetector for NoOpReplayDetector { - fn check(&mut self, _: u64) -> bool { - true - } - fn accept(&mut self) {} -} diff --git a/util/src/replay_detector/replay_detector_test.rs b/util/src/replay_detector/replay_detector_test.rs deleted file mode 100644 index 9537e8987..000000000 --- a/util/src/replay_detector/replay_detector_test.rs +++ /dev/null @@ -1,278 +0,0 @@ -use super::*; - -#[test] -fn test_replay_detector() { - const LARGE_SEQ: u64 = 0x100000000000; - - #[allow(clippy::type_complexity)] - let tests: Vec<(&str, usize, u64, Vec, Vec, Vec, Vec)> = vec![ - ( - "Continuous", - 16, - 0x0000FFFFFFFFFFFF, - vec![ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, - ], - vec![ - true, true, true, true, true, true, true, true, true, true, true, true, true, true, - true, true, true, true, true, true, true, - ], - vec![ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, - ], - vec![], - ), - ( - "ValidLargeJump", - 16, - 0x0000FFFFFFFFFFFF, - vec![ - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - LARGE_SEQ, - 11, - LARGE_SEQ + 1, - LARGE_SEQ + 2, - LARGE_SEQ + 3, - ], - vec![ - true, true, true, true, true, true, true, true, true, true, true, true, true, true, - true, - ], - vec![ - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - LARGE_SEQ, - LARGE_SEQ + 1, - LARGE_SEQ + 2, - LARGE_SEQ + 3, - ], - vec![], - ), - ( - "InvalidLargeJump", - 16, - 0x0000FFFFFFFFFFFF, - vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, LARGE_SEQ, 11, 12, 13, 14, 15], - vec![ - true, true, true, true, true, true, true, true, true, true, false, true, true, - true, true, true, - ], - vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15], - vec![], - ), - ( - "DuplicateAfterValidJump", - 196, - 0x0000FFFFFFFFFFFF, - vec![0, 1, 2, 129, 0, 1, 2], - vec![true, true, true, true, true, true, true], - vec![0, 1, 2, 129], - vec![], - ), - ( - "DuplicateAfterInvalidJump", - 196, - 0x0000FFFFFFFFFFFF, - vec![0, 1, 2, 128, 0, 1, 2], - vec![true, true, true, false, true, true, true], - vec![0, 1, 2], - vec![], - ), - ( - "ContinuousOffset", - 16, - 0x0000FFFFFFFFFFFF, - vec![ - 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, - ], - vec![ - true, true, true, true, true, true, true, true, true, true, true, true, true, true, - true, - ], - vec![ - 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, - ], - vec![], - ), - ( - "Reordered", - 128, - 0x0000FFFFFFFFFFFF, - vec![ - 96, 64, 16, 80, 32, 48, 8, 24, 88, 40, 128, 56, 72, 112, 104, 120, - ], - vec![ - true, true, true, true, true, true, true, true, true, true, true, true, true, true, - true, true, - ], - vec![ - 96, 64, 16, 80, 32, 48, 8, 24, 88, 40, 128, 56, 72, 112, 104, 120, - ], - vec![], - ), - ( - "Old", - 100, - 0x0000FFFFFFFFFFFF, - vec![ - 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 8, 16, - ], - vec![ - true, true, true, true, true, true, true, true, true, true, true, true, true, true, - true, true, - ], - vec![24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128], - vec![], - ), - ( - "ContinuouesReplayed", - 8, - 0x0000FFFFFFFFFFFF, - vec![ - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, - ], - vec![ - true, true, true, true, true, true, true, true, true, true, true, true, true, true, - true, true, 
true, true, true, true, - ], - vec![16, 17, 18, 19, 20, 21, 22, 23, 24, 25], - vec![], - ), - ( - "ReplayedLater", - 128, - 0x0000FFFFFFFFFFFF, - vec![ - 16, 32, 48, 64, 80, 96, 112, 128, 16, 32, 48, 64, 80, 96, 112, 128, - ], - vec![ - true, true, true, true, true, true, true, true, true, true, true, true, true, true, - true, true, - ], - vec![16, 32, 48, 64, 80, 96, 112, 128], - vec![], - ), - ( - "ReplayedQuick", - 128, - 0x0000FFFFFFFFFFFF, - vec![ - 16, 16, 32, 32, 48, 48, 64, 64, 80, 80, 96, 96, 112, 112, 128, 128, - ], - vec![ - true, true, true, true, true, true, true, true, true, true, true, true, true, true, - true, true, - ], - vec![16, 32, 48, 64, 80, 96, 112, 128], - vec![], - ), - ( - "Strict", - 0, - 0x0000FFFFFFFFFFFF, - vec![1, 3, 2, 4, 5, 6, 7, 8, 9, 10], - vec![true, true, true, true, true, true, true, true, true, true], - vec![1, 3, 4, 5, 6, 7, 8, 9, 10], - vec![], - ), - ( - "Overflow", - 128, - 0x0000FFFFFFFFFFFF, - vec![ - 0x0000FFFFFFFFFFFE, - 0x0000FFFFFFFFFFFF, - 0x0001000000000000, - 0x0001000000000001, - ], - vec![true, true, true, true], - vec![0x0000FFFFFFFFFFFE, 0x0000FFFFFFFFFFFF], - vec![], - ), - ( - "WrapContinuous", - 64, - 0xFFFF, - vec![ - 0xFFFC, 0xFFFD, 0xFFFE, 0xFFFF, 0x0000, 0x0001, 0x0002, 0x0003, - ], - vec![true, true, true, true, true, true, true, true], - vec![0xFFFC, 0xFFFD, 0xFFFE, 0xFFFF], - vec![ - 0xFFFC, 0xFFFD, 0xFFFE, 0xFFFF, 0x0000, 0x0001, 0x0002, 0x0003, - ], - ), - ( - "WrapReordered", - 64, - 0xFFFF, - vec![ - 0xFFFD, 0xFFFC, 0x0002, 0xFFFE, 0x0000, 0x0001, 0xFFFF, 0x0003, - ], - vec![true, true, true, true, true, true, true, true], - vec![0xFFFD, 0xFFFC, 0xFFFE, 0xFFFF], - vec![ - 0xFFFD, 0xFFFC, 0x0002, 0xFFFE, 0x0000, 0x0001, 0xFFFF, 0x0003, - ], - ), - ( - "WrapReorderedReplayed", - 64, - 0xFFFF, - vec![ - 0xFFFD, 0xFFFC, 0xFFFC, 0x0002, 0xFFFE, 0xFFFC, 0x0000, 0x0001, 0x0001, 0xFFFF, - 0x0001, 0x0003, - ], - vec![ - true, true, true, true, true, true, true, true, true, true, true, true, - ], - vec![0xFFFD, 0xFFFC, 0xFFFE, 0xFFFF], - vec![ - 0xFFFD, 0xFFFC, 0x0002, 0xFFFE, 0x0000, 0x0001, 0xFFFF, 0x0003, - ], - ), - ]; - - for (name, windows_size, max_seq, input, valid, expected, mut expected_wrap) in tests { - if expected_wrap.is_empty() { - expected_wrap.extend_from_slice(&expected); - } - - for k in 0..2 { - let mut det: Box = if k == 0 { - Box::new(SlidingWindowDetector::new(windows_size, max_seq)) - } else { - Box::new(WrappedSlidingWindowDetector::new(windows_size, max_seq)) - }; - let exp = if k == 0 { &expected } else { &expected_wrap }; - - let mut out = vec![]; - for (i, seq) in input.iter().enumerate() { - let ok = det.check(*seq); - if ok && valid[i] { - out.push(*seq); - det.accept(); - } - } - - assert_eq!(&out, exp, "{name} failed"); - } - } -} diff --git a/util/src/sync/mod.rs b/util/src/sync/mod.rs deleted file mode 100644 index cab9050cf..000000000 --- a/util/src/sync/mod.rs +++ /dev/null @@ -1,99 +0,0 @@ -use std::{ops, sync}; - -/// A synchronous mutual exclusion primitive useful for protecting shared data. -#[derive(Default, Debug)] -pub struct Mutex(sync::Mutex); - -impl Mutex { - /// Creates a new mutex in an unlocked state ready for use. - pub fn new(value: T) -> Self { - Self(sync::Mutex::new(value)) - } - - /// Acquires a mutex, blocking the current thread until it is able to do so. - pub fn lock(&self) -> MutexGuard<'_, T> { - let guard = self.0.lock().unwrap(); - - MutexGuard(guard) - } - - /// Consumes this mutex, returning the underlying data. 
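
A minimal usage sketch of the wrapper above: it forwards to `std::sync::Mutex` but unwraps the poison error, so callers receive a guard directly rather than a `LockResult`:

fn sync_mutex_demo() {
    let counter = Mutex::new(0u32);
    {
        // No `.unwrap()` needed here; a poisoned lock would surface as a panic instead.
        let mut guard = counter.lock();
        *guard += 1;
    } // guard dropped, lock released
    assert_eq!(counter.into_inner(), 1);
}
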
- pub fn into_inner(self) -> T { - self.0.into_inner().unwrap() - } -} - -/// An RAII implementation of a "scoped lock" of a mutex. When this structure is -/// dropped (falls out of scope), the lock will be unlocked. -pub struct MutexGuard<'a, T>(sync::MutexGuard<'a, T>); - -impl<'a, T> ops::Deref for MutexGuard<'a, T> { - type Target = T; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl<'a, T> ops::DerefMut for MutexGuard<'a, T> { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -/// A synchronous reader-writer lock. -#[derive(Default, Debug)] -pub struct RwLock(sync::RwLock); - -impl RwLock { - /// Creates a new mutex in an unlocked state ready for use. - pub fn new(value: T) -> Self { - Self(sync::RwLock::new(value)) - } - - /// Locks this rwlock with shared read access, blocking the current thread - /// until it can be acquired. - pub fn read(&self) -> RwLockReadGuard<'_, T> { - let guard = self.0.read().unwrap(); - - RwLockReadGuard(guard) - } - - /// Locks this rwlock with exclusive write access, blocking the current - /// thread until it can be acquired. - pub fn write(&self) -> RwLockWriteGuard<'_, T> { - let guard = self.0.write().unwrap(); - - RwLockWriteGuard(guard) - } -} - -/// RAII structure used to release the shared read access of a lock when -/// dropped. -pub struct RwLockReadGuard<'a, T>(sync::RwLockReadGuard<'a, T>); - -impl<'a, T> ops::Deref for RwLockReadGuard<'a, T> { - type Target = T; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -/// RAII structure used to release the exclusive write access of a lock when -/// dropped. -pub struct RwLockWriteGuard<'a, T>(sync::RwLockWriteGuard<'a, T>); - -impl<'a, T> ops::Deref for RwLockWriteGuard<'a, T> { - type Target = T; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl<'a, T> ops::DerefMut for RwLockWriteGuard<'a, T> { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} diff --git a/util/src/vnet/chunk.rs b/util/src/vnet/chunk.rs deleted file mode 100644 index 6300aff8d..000000000 --- a/util/src/vnet/chunk.rs +++ /dev/null @@ -1,352 +0,0 @@ -#[cfg(test)] -mod chunk_test; - -use std::fmt; -use std::net::{IpAddr, SocketAddr}; -use std::ops::{BitAnd, BitOr}; -use std::str::FromStr; -use std::sync::atomic::Ordering; -use std::time::SystemTime; - -use portable_atomic::AtomicU64; - -use super::net::*; -use crate::error::Result; - -lazy_static! { - static ref TAG_CTR: AtomicU64 = AtomicU64::new(0); -} - -/// Encodes a u64 value to a lowercase base 36 string. 
-pub fn base36(value: impl Into) -> String { - let mut digits: Vec = vec![]; - - let mut value = value.into(); - while value > 0 { - let digit = (value % 36) as usize; - value /= 36; - - digits.push(b"0123456789abcdefghijklmnopqrstuvwxyz"[digit]); - } - - digits.reverse(); - format!("{:0>8}", String::from_utf8(digits).unwrap()) -} - -// Generate a base36-encoded unique tag -// See: https://play.golang.org/p/0ZaAID1q-HN -fn assign_chunk_tag() -> String { - let n = TAG_CTR.fetch_add(1, Ordering::SeqCst); - base36(n) -} - -#[derive(Copy, Clone, PartialEq, Debug)] -pub(crate) struct TcpFlag(pub(crate) u8); - -pub(crate) const TCP_FLAG_ZERO: TcpFlag = TcpFlag(0x00); -pub(crate) const TCP_FLAG_FIN: TcpFlag = TcpFlag(0x01); -pub(crate) const TCP_FLAG_SYN: TcpFlag = TcpFlag(0x02); -pub(crate) const TCP_FLAG_RST: TcpFlag = TcpFlag(0x04); -pub(crate) const TCP_FLAG_PSH: TcpFlag = TcpFlag(0x08); -pub(crate) const TCP_FLAG_ACK: TcpFlag = TcpFlag(0x10); - -impl BitOr for TcpFlag { - type Output = Self; - - // rhs is the "right-hand side" of the expression `a | b` - fn bitor(self, rhs: Self) -> Self::Output { - Self(self.0 | rhs.0) - } -} - -impl BitAnd for TcpFlag { - type Output = Self; - - // rhs is the "right-hand side" of the expression `a & b` - fn bitand(self, rhs: Self) -> Self::Output { - Self(self.0 & rhs.0) - } -} - -impl fmt::Display for TcpFlag { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut sa = vec![]; - if *self & TCP_FLAG_FIN != TCP_FLAG_ZERO { - sa.push("FIN"); - } - if *self & TCP_FLAG_SYN != TCP_FLAG_ZERO { - sa.push("SYN"); - } - if *self & TCP_FLAG_RST != TCP_FLAG_ZERO { - sa.push("RST"); - } - if *self & TCP_FLAG_PSH != TCP_FLAG_ZERO { - sa.push("PSH"); - } - if *self & TCP_FLAG_ACK != TCP_FLAG_ZERO { - sa.push("ACK"); - } - - write!(f, "{}", sa.join("-")) - } -} - -// Chunk represents a packet passed around in the vnet -pub trait Chunk: fmt::Display + fmt::Debug { - fn set_timestamp(&mut self) -> SystemTime; // used by router - fn get_timestamp(&self) -> SystemTime; // used by router - fn get_source_ip(&self) -> IpAddr; // used by routee - fn get_destination_ip(&self) -> IpAddr; // used by router - fn set_source_addr(&mut self, address: &str) -> Result<()>; // used by nat - fn set_destination_addr(&mut self, address: &str) -> Result<()>; // used by nat - - fn source_addr(&self) -> SocketAddr; - fn destination_addr(&self) -> SocketAddr; - fn user_data(&self) -> Vec; - fn tag(&self) -> String; - fn network(&self) -> String; // returns "udp" or "tcp" - fn clone_to(&self) -> Box; -} - -#[derive(PartialEq, Debug)] -pub(crate) struct ChunkIp { - pub(crate) timestamp: SystemTime, - pub(crate) source_ip: IpAddr, - pub(crate) destination_ip: IpAddr, - pub(crate) tag: String, -} - -impl ChunkIp { - fn set_timestamp(&mut self) -> SystemTime { - self.timestamp = SystemTime::now(); - self.timestamp - } - - fn get_timestamp(&self) -> SystemTime { - self.timestamp - } - - fn get_destination_ip(&self) -> IpAddr { - self.destination_ip - } - - fn get_source_ip(&self) -> IpAddr { - self.source_ip - } - - fn tag(&self) -> String { - self.tag.clone() - } -} - -#[derive(PartialEq, Debug)] -pub(crate) struct ChunkUdp { - pub(crate) chunk_ip: ChunkIp, - pub(crate) source_port: u16, - pub(crate) destination_port: u16, - pub(crate) user_data: Vec, -} - -impl fmt::Display for ChunkUdp { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "{} chunk {} {} => {}", - self.network(), - self.tag(), - self.source_addr(), - self.destination_addr(), - ) - } -} - 
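
A few spot checks for the `base36` helper removed earlier in this file, worked out by hand from that implementation (the output is lowercase and zero-padded to eight characters):

#[test]
fn base36_spot_checks() {
    assert_eq!(base36(0u64), "00000000"); // no digits produced, fully zero-padded
    assert_eq!(base36(35u64), "0000000z"); // 35 is the last single digit, 'z'
    assert_eq!(base36(36u64), "00000010"); // 36 is written "10" in base 36
}
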
-impl Chunk for ChunkUdp { - fn set_timestamp(&mut self) -> SystemTime { - self.chunk_ip.set_timestamp() - } - - fn get_timestamp(&self) -> SystemTime { - self.chunk_ip.get_timestamp() - } - - fn get_destination_ip(&self) -> IpAddr { - self.chunk_ip.get_destination_ip() - } - - fn get_source_ip(&self) -> IpAddr { - self.chunk_ip.get_source_ip() - } - - fn tag(&self) -> String { - self.chunk_ip.tag() - } - - fn source_addr(&self) -> SocketAddr { - SocketAddr::new(self.chunk_ip.source_ip, self.source_port) - } - - fn destination_addr(&self) -> SocketAddr { - SocketAddr::new(self.chunk_ip.destination_ip, self.destination_port) - } - - fn user_data(&self) -> Vec { - self.user_data.clone() - } - - fn clone_to(&self) -> Box { - Box::new(ChunkUdp { - chunk_ip: ChunkIp { - timestamp: self.chunk_ip.timestamp, - source_ip: self.chunk_ip.source_ip, - destination_ip: self.chunk_ip.destination_ip, - tag: self.chunk_ip.tag.clone(), - }, - source_port: self.source_port, - destination_port: self.destination_port, - user_data: self.user_data.clone(), - }) - } - - fn network(&self) -> String { - UDP_STR.to_owned() - } - - fn set_source_addr(&mut self, address: &str) -> Result<()> { - let addr = SocketAddr::from_str(address)?; - self.chunk_ip.source_ip = addr.ip(); - self.source_port = addr.port(); - Ok(()) - } - - fn set_destination_addr(&mut self, address: &str) -> Result<()> { - let addr = SocketAddr::from_str(address)?; - self.chunk_ip.destination_ip = addr.ip(); - self.destination_port = addr.port(); - Ok(()) - } -} - -impl ChunkUdp { - pub(crate) fn new(src_addr: SocketAddr, dst_addr: SocketAddr) -> Self { - ChunkUdp { - chunk_ip: ChunkIp { - timestamp: SystemTime::now(), - source_ip: src_addr.ip(), - destination_ip: dst_addr.ip(), - tag: assign_chunk_tag(), - }, - source_port: src_addr.port(), - destination_port: dst_addr.port(), - user_data: vec![], - } - } -} - -#[derive(PartialEq, Debug)] -pub(crate) struct ChunkTcp { - chunk_ip: ChunkIp, - source_port: u16, - destination_port: u16, - flags: TcpFlag, // control bits - user_data: Vec, // only with PSH flag - // seq :u32, // always starts with 0 - // ack :u32, // always starts with 0 -} - -impl fmt::Display for ChunkTcp { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "{} {} chunk {} {} => {}", - self.network(), - self.flags, - self.chunk_ip.tag, - self.source_addr(), - self.destination_addr(), - ) - } -} - -impl Chunk for ChunkTcp { - fn set_timestamp(&mut self) -> SystemTime { - self.chunk_ip.set_timestamp() - } - - fn get_timestamp(&self) -> SystemTime { - self.chunk_ip.get_timestamp() - } - - fn get_destination_ip(&self) -> IpAddr { - self.chunk_ip.get_destination_ip() - } - - fn get_source_ip(&self) -> IpAddr { - self.chunk_ip.get_source_ip() - } - - fn tag(&self) -> String { - self.chunk_ip.tag() - } - - fn source_addr(&self) -> SocketAddr { - SocketAddr::new(self.chunk_ip.source_ip, self.source_port) - } - - fn destination_addr(&self) -> SocketAddr { - SocketAddr::new(self.chunk_ip.destination_ip, self.destination_port) - } - - fn user_data(&self) -> Vec { - self.user_data.clone() - } - - fn clone_to(&self) -> Box { - Box::new(ChunkTcp { - chunk_ip: ChunkIp { - timestamp: self.chunk_ip.timestamp, - source_ip: self.chunk_ip.source_ip, - destination_ip: self.chunk_ip.destination_ip, - tag: self.chunk_ip.tag.clone(), - }, - source_port: self.source_port, - destination_port: self.destination_port, - flags: self.flags, - user_data: self.user_data.clone(), - }) - } - - fn network(&self) -> String { - "tcp".to_owned() - 
} - - fn set_source_addr(&mut self, address: &str) -> Result<()> { - let addr = SocketAddr::from_str(address)?; - self.chunk_ip.source_ip = addr.ip(); - self.source_port = addr.port(); - Ok(()) - } - - fn set_destination_addr(&mut self, address: &str) -> Result<()> { - let addr = SocketAddr::from_str(address)?; - self.chunk_ip.destination_ip = addr.ip(); - self.destination_port = addr.port(); - Ok(()) - } -} - -impl ChunkTcp { - pub(crate) fn new(src_addr: SocketAddr, dst_addr: SocketAddr, flags: TcpFlag) -> Self { - ChunkTcp { - chunk_ip: ChunkIp { - timestamp: SystemTime::now(), - source_ip: src_addr.ip(), - destination_ip: dst_addr.ip(), - tag: assign_chunk_tag(), - }, - source_port: src_addr.port(), - destination_port: dst_addr.port(), - flags, - user_data: vec![], - } - } -} diff --git a/util/src/vnet/chunk/chunk_test.rs b/util/src/vnet/chunk/chunk_test.rs deleted file mode 100644 index 92b5ae381..000000000 --- a/util/src/vnet/chunk/chunk_test.rs +++ /dev/null @@ -1,59 +0,0 @@ -use super::*; -use crate::error::Result; - -#[test] -fn test_tcp_frag_string() { - let f = TCP_FLAG_FIN; - assert_eq!(f.to_string(), "FIN", "should match"); - let f = TCP_FLAG_SYN; - assert_eq!(f.to_string(), "SYN", "should match"); - let f = TCP_FLAG_RST; - assert_eq!(f.to_string(), "RST", "should match"); - let f = TCP_FLAG_PSH; - assert_eq!(f.to_string(), "PSH", "should match"); - let f = TCP_FLAG_ACK; - assert_eq!(f.to_string(), "ACK", "should match"); - let f = TCP_FLAG_SYN | TCP_FLAG_ACK; - assert_eq!(f.to_string(), "SYN-ACK", "should match"); -} - -const DEMO_IP: &str = "1.2.3.4"; - -#[test] -fn test_chunk_udp() -> Result<()> { - let src = SocketAddr::from_str("192.168.0.2:1234")?; - let dst = SocketAddr::from_str(&(DEMO_IP.to_owned() + ":5678"))?; - - let mut c = ChunkUdp::new(src, dst); - let s = c.to_string(); - log::debug!("chunk: {}", s); - assert_eq!(c.network(), UDP_STR, "should match"); - assert!(s.contains(&src.to_string()), "should include address"); - assert!(s.contains(&dst.to_string()), "should include address"); - assert_eq!(c.get_source_ip(), src.ip(), "ip should match"); - assert_eq!(c.get_destination_ip(), dst.ip(), "ip should match"); - - // Test timestamp - let ts = c.set_timestamp(); - assert_eq!(ts, c.get_timestamp(), "timestamp should match"); - - c.user_data = "Hello".as_bytes().to_vec(); - - let cloned = c.clone_to(); - - // Test setSourceAddr - c.set_source_addr("2.3.4.5:4000")?; - assert_eq!(c.source_addr().to_string(), "2.3.4.5:4000"); - - // Test Tag() - assert!(!c.tag().is_empty(), "should not be empty"); - - // Verify cloned chunk was not affected by the changes to original chunk - c.user_data[0] = b'!'; // oroginal: "Hello" -> "Hell!" 
- assert_eq!(cloned.user_data(), "Hello".as_bytes(), "should match"); - assert_eq!(cloned.source_addr().to_string(), "192.168.0.2:1234"); - assert_eq!(cloned.get_source_ip(), src.ip(), "ip should match"); - assert_eq!(cloned.get_destination_ip(), dst.ip(), "ip should match"); - - Ok(()) -} diff --git a/util/src/vnet/chunk_queue.rs b/util/src/vnet/chunk_queue.rs deleted file mode 100644 index c90d342e3..000000000 --- a/util/src/vnet/chunk_queue.rs +++ /dev/null @@ -1,44 +0,0 @@ -#[cfg(test)] -mod chunk_queue_test; - -use std::collections::VecDeque; - -use tokio::sync::RwLock; - -use super::chunk::*; - -#[derive(Default)] -pub(crate) struct ChunkQueue { - chunks: RwLock>>, - max_size: usize, // 0 or negative value: unlimited -} - -impl ChunkQueue { - pub(crate) fn new(max_size: usize) -> Self { - ChunkQueue { - chunks: RwLock::new(VecDeque::new()), - max_size, - } - } - - pub(crate) async fn push(&self, c: Box) -> bool { - let mut chunks = self.chunks.write().await; - - if self.max_size > 0 && chunks.len() >= self.max_size { - false // dropped - } else { - chunks.push_back(c); - true - } - } - - pub(crate) async fn pop(&self) -> Option> { - let mut chunks = self.chunks.write().await; - chunks.pop_front() - } - - pub(crate) async fn peek(&self) -> Option> { - let chunks = self.chunks.read().await; - chunks.front().map(|chunk| chunk.clone_to()) - } -} diff --git a/util/src/vnet/chunk_queue/chunk_queue_test.rs b/util/src/vnet/chunk_queue/chunk_queue_test.rs deleted file mode 100644 index b865d3e01..000000000 --- a/util/src/vnet/chunk_queue/chunk_queue_test.rs +++ /dev/null @@ -1,47 +0,0 @@ -use std::net::SocketAddr; -use std::str::FromStr; - -use super::*; -use crate::error::Result; - -const DEMO_IP: &str = "1.2.3.4"; - -#[tokio::test] -async fn test_chunk_queue() -> Result<()> { - let c: Box = Box::new(ChunkUdp::new( - SocketAddr::from_str("192.188.0.2:1234")?, - SocketAddr::from_str(&(DEMO_IP.to_owned() + ":5678"))?, - )); - - let q = ChunkQueue::new(0); - - let d = q.peek().await; - assert!(d.is_none(), "should return none"); - - let ok = q.push(c.clone_to()).await; - assert!(ok, "should succeed"); - - let d = q.pop().await; - assert!(d.is_some(), "should succeed"); - if let Some(d) = d { - assert_eq!(c.to_string(), d.to_string(), "should be the same"); - } - - let d = q.pop().await; - assert!(d.is_none(), "should fail"); - - let q = ChunkQueue::new(1); - let ok = q.push(c.clone_to()).await; - assert!(ok, "should succeed"); - - let ok = q.push(c.clone_to()).await; - assert!(!ok, "should fail"); - - let d = q.peek().await; - assert!(d.is_some(), "should succeed"); - if let Some(d) = d { - assert_eq!(c.to_string(), d.to_string(), "should be the same"); - } - - Ok(()) -} diff --git a/util/src/vnet/conn.rs b/util/src/vnet/conn.rs deleted file mode 100644 index 676ca495d..000000000 --- a/util/src/vnet/conn.rs +++ /dev/null @@ -1,164 +0,0 @@ -#[cfg(test)] -mod conn_test; - -use std::net::{IpAddr, SocketAddr}; -use std::sync::atomic::Ordering; -use std::sync::Arc; - -use async_trait::async_trait; -use portable_atomic::AtomicBool; -use tokio::sync::{mpsc, Mutex}; - -use crate::conn::Conn; -use crate::error::*; -use crate::sync::RwLock; -use crate::vnet::chunk::{Chunk, ChunkUdp}; - -const MAX_READ_QUEUE_SIZE: usize = 1024; - -/// vNet implements this -#[async_trait] -pub(crate) trait ConnObserver { - async fn write(&self, c: Box) -> Result<()>; - async fn on_closed(&self, addr: SocketAddr); - fn determine_source_ip(&self, loc_ip: IpAddr, dst_ip: IpAddr) -> Option; -} - -pub(crate) type ChunkChTx = 
mpsc::Sender>; - -/// UDPConn is the implementation of the Conn and PacketConn interfaces for UDP network connections. -/// compatible with net.PacketConn and net.Conn -pub(crate) struct UdpConn { - loc_addr: SocketAddr, - rem_addr: RwLock>, - read_ch_tx: Arc>>, - read_ch_rx: Mutex>>, - closed: AtomicBool, - obs: Arc>, -} - -impl UdpConn { - pub(crate) fn new( - loc_addr: SocketAddr, - rem_addr: Option, - obs: Arc>, - ) -> Self { - let (read_ch_tx, read_ch_rx) = mpsc::channel(MAX_READ_QUEUE_SIZE); - - UdpConn { - loc_addr, - rem_addr: RwLock::new(rem_addr), - read_ch_tx: Arc::new(Mutex::new(Some(read_ch_tx))), - read_ch_rx: Mutex::new(read_ch_rx), - closed: AtomicBool::new(false), - obs, - } - } - - pub(crate) fn get_inbound_ch(&self) -> Arc>> { - Arc::clone(&self.read_ch_tx) - } -} - -#[async_trait] -impl Conn for UdpConn { - async fn connect(&self, addr: SocketAddr) -> Result<()> { - self.rem_addr.write().replace(addr); - - Ok(()) - } - async fn recv(&self, buf: &mut [u8]) -> Result { - let (n, _) = self.recv_from(buf).await?; - Ok(n) - } - - /// recv_from reads a packet from the connection, - /// copying the payload into p. It returns the number of - /// bytes copied into p and the return address that - /// was on the packet. - /// It returns the number of bytes read (0 <= n <= len(p)) - /// and any error encountered. Callers should always process - /// the n > 0 bytes returned before considering the error err. - async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> { - let mut read_ch = self.read_ch_rx.lock().await; - let rem_addr = *self.rem_addr.read(); - while let Some(chunk) = read_ch.recv().await { - let user_data = chunk.user_data(); - let n = std::cmp::min(buf.len(), user_data.len()); - buf[..n].copy_from_slice(&user_data[..n]); - let addr = chunk.source_addr(); - { - if let Some(rem_addr) = &rem_addr { - if &addr != rem_addr { - continue; // discard (shouldn't happen) - } - } - } - return Ok((n, addr)); - } - - Err(std::io::Error::new(std::io::ErrorKind::ConnectionAborted, "Connection Aborted").into()) - } - - async fn send(&self, buf: &[u8]) -> Result { - let rem_addr = *self.rem_addr.read(); - if let Some(rem_addr) = rem_addr { - self.send_to(buf, rem_addr).await - } else { - Err(Error::ErrNoRemAddr) - } - } - - /// send_to writes a packet with payload p to addr. - /// send_to can be made to time out and return - async fn send_to(&self, buf: &[u8], target: SocketAddr) -> Result { - let src_ip = { - let obs = self.obs.lock().await; - match obs.determine_source_ip(self.loc_addr.ip(), target.ip()) { - Some(ip) => ip, - None => return Err(Error::ErrLocAddr), - } - }; - - let src_addr = SocketAddr::new(src_ip, self.loc_addr.port()); - - let mut chunk = ChunkUdp::new(src_addr, target); - chunk.user_data = buf.to_vec(); - { - let c: Box = Box::new(chunk); - let obs = self.obs.lock().await; - obs.write(c).await? 
- } - - Ok(buf.len()) - } - - fn local_addr(&self) -> Result { - Ok(self.loc_addr) - } - - fn remote_addr(&self) -> Option { - *self.rem_addr.read() - } - - async fn close(&self) -> Result<()> { - if self.closed.load(Ordering::SeqCst) { - return Err(Error::ErrAlreadyClosed); - } - self.closed.store(true, Ordering::SeqCst); - { - let mut reach_ch = self.read_ch_tx.lock().await; - reach_ch.take(); - } - { - let obs = self.obs.lock().await; - obs.on_closed(self.loc_addr).await; - } - - Ok(()) - } - - fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) { - self - } -} diff --git a/util/src/vnet/conn/conn_test.rs b/util/src/vnet/conn/conn_test.rs deleted file mode 100644 index b8fe072e5..000000000 --- a/util/src/vnet/conn/conn_test.rs +++ /dev/null @@ -1,208 +0,0 @@ -use std::str::FromStr; - -use portable_atomic::AtomicUsize; - -use super::*; - -#[derive(Default)] -struct DummyObserver { - nclosed: Arc, - #[allow(clippy::type_complexity)] - read_ch_tx: Arc>>>>, -} - -#[async_trait] -impl ConnObserver for DummyObserver { - async fn write(&self, c: Box) -> Result<()> { - let mut chunk = ChunkUdp::new(c.destination_addr(), c.source_addr()); - chunk.user_data = c.user_data(); - - let read_ch_tx = self.read_ch_tx.lock().await; - if let Some(tx) = &*read_ch_tx { - tx.send(Box::new(chunk)) - .await - .map_err(|e| Error::Other(e.to_string()))?; - } - Ok(()) - } - - async fn on_closed(&self, _addr: SocketAddr) { - self.nclosed.fetch_add(1, Ordering::SeqCst); - } - - fn determine_source_ip(&self, loc_ip: IpAddr, _dst_ip: IpAddr) -> Option { - Some(loc_ip) - } -} - -//use std::io::Write; - -#[tokio::test] -async fn test_udp_conn_send_to_recv_from() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - let nclosed = Arc::new(AtomicUsize::new(0)); - let data = b"Hello".to_vec(); - let src_addr = SocketAddr::from_str("127.0.0.1:1234")?; - let dst_addr = SocketAddr::from_str("127.0.0.1:5678")?; - - let dummy_obs = Arc::new(Mutex::new(DummyObserver::default())); - let dummy_obs2 = Arc::clone(&dummy_obs); - let obs = dummy_obs2 as Arc>; - - let conn = Arc::new(UdpConn::new(src_addr, None, obs)); - { - let mut dummy = dummy_obs.lock().await; - dummy.nclosed = Arc::clone(&nclosed); - dummy.read_ch_tx = conn.get_inbound_ch(); - } - - let conn_rx = Arc::clone(&conn); - let data_rx = data.clone(); - - let (rcvd_ch_tx, mut rcvd_ch_rx) = mpsc::channel(1); - let (done_ch_tx, mut done_ch_rx) = mpsc::channel::<()>(1); - - tokio::spawn(async move { - let mut buf = vec![0u8; 1500]; - - loop { - let (n, addr) = match conn_rx.recv_from(&mut buf).await { - Ok((n, addr)) => (n, addr), - Err(err) => { - log::debug!("conn closed. exiting the read loop with err {}", err); - break; - } - }; - - log::debug!("read data"); - assert_eq!(data_rx.len(), n, "should match"); - assert_eq!(&data_rx, &buf[..n], "should match"); - log::debug!("dst_addr {} vs add {}", dst_addr, addr); - assert_eq!(dst_addr.to_string(), addr.to_string(), "should match"); - let _ = rcvd_ch_tx.send(()).await; - } - - drop(done_ch_tx); - }); - - let n = conn.send_to(&data, dst_addr).await.unwrap(); - assert_eq!(n, data.len(), "should match"); - - loop { - tokio::select! 
{ - result = rcvd_ch_rx.recv() =>{ - if result.is_some(){ - log::debug!("closing soon..."); - conn.close().await?; - } - } - _ = done_ch_rx.recv() => { - log::debug!("recv done_ch_rx..."); - break; - } - } - } - - assert_eq!(1, nclosed.load(Ordering::SeqCst), "should be closed once"); - - Ok(()) -} - -//use std::io::Write; - -#[tokio::test] -async fn test_udp_conn_send_recv() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - let nclosed = Arc::new(AtomicUsize::new(0)); - let data = b"Hello".to_vec(); - let src_addr = SocketAddr::from_str("127.0.0.1:1234")?; - let dst_addr = SocketAddr::from_str("127.0.0.1:5678")?; - - let dummy_obs = Arc::new(Mutex::new(DummyObserver::default())); - let dummy_obs2 = Arc::clone(&dummy_obs); - let obs = dummy_obs2 as Arc>; - - let conn = Arc::new(UdpConn::new(src_addr, Some(dst_addr), obs)); - { - let mut dummy = dummy_obs.lock().await; - dummy.nclosed = Arc::clone(&nclosed); - dummy.read_ch_tx = conn.get_inbound_ch(); - } - - let conn_rx = Arc::clone(&conn); - let data_rx = data.clone(); - - let (rcvd_ch_tx, mut rcvd_ch_rx) = mpsc::channel(1); - let (done_ch_tx, mut done_ch_rx) = mpsc::channel::<()>(1); - - tokio::spawn(async move { - let mut buf = vec![0u8; 1500]; - - loop { - let n = match conn_rx.recv(&mut buf).await { - Ok(n) => n, - Err(err) => { - log::debug!("conn closed. exiting the read loop with err {}", err); - break; - } - }; - - log::debug!("read data"); - assert_eq!(data_rx.len(), n, "should match"); - assert_eq!(&data_rx, &buf[..n], "should match"); - let _ = rcvd_ch_tx.send(()).await; - } - - drop(done_ch_tx); - }); - - let n = conn.send(&data).await.unwrap(); - assert_eq!(n, data.len(), "should match"); - - loop { - tokio::select! 
{ - result = rcvd_ch_rx.recv() =>{ - if result.is_some(){ - log::debug!("closing soon..."); - conn.close().await?; - } - } - _ = done_ch_rx.recv() => { - log::debug!("recv done_ch_rx..."); - break; - } - } - } - - assert_eq!(1, nclosed.load(Ordering::SeqCst), "should be closed once"); - - Ok(()) -} diff --git a/util/src/vnet/conn_map.rs b/util/src/vnet/conn_map.rs deleted file mode 100644 index 27dca8268..000000000 --- a/util/src/vnet/conn_map.rs +++ /dev/null @@ -1,120 +0,0 @@ -#[cfg(test)] -mod conn_map_test; - -use std::collections::HashMap; -use std::net::SocketAddr; -use std::sync::Arc; - -use tokio::sync::Mutex; - -use crate::error::*; -use crate::vnet::conn::UdpConn; -use crate::Conn; - -type PortMap = Mutex>>>; - -#[derive(Default)] -pub(crate) struct UdpConnMap { - port_map: PortMap, -} - -impl UdpConnMap { - pub(crate) fn new() -> Self { - UdpConnMap { - port_map: Mutex::new(HashMap::new()), - } - } - - pub(crate) async fn insert(&self, conn: Arc) -> Result<()> { - let addr = conn.local_addr()?; - - let mut port_map = self.port_map.lock().await; - if let Some(conns) = port_map.get(&addr.port()) { - if addr.ip().is_unspecified() { - return Err(Error::ErrAddressAlreadyInUse); - } - - for c in conns { - let laddr = c.local_addr()?; - if laddr.ip().is_unspecified() || laddr.ip() == addr.ip() { - return Err(Error::ErrAddressAlreadyInUse); - } - } - } - - if let Some(conns) = port_map.get_mut(&addr.port()) { - conns.push(conn); - } else { - port_map.insert(addr.port(), vec![conn]); - } - Ok(()) - } - - pub(crate) async fn find(&self, addr: &SocketAddr) -> Option> { - let port_map = self.port_map.lock().await; - if let Some(conns) = port_map.get(&addr.port()) { - if addr.ip().is_unspecified() { - // pick the first one appears in the iteration - if let Some(c) = conns.first() { - return Some(Arc::clone(c)); - } else { - return None; - } - } - - for c in conns { - let laddr = { - match c.local_addr() { - Ok(laddr) => laddr, - Err(_) => return None, - } - }; - if laddr.ip().is_unspecified() || laddr.ip() == addr.ip() { - return Some(Arc::clone(c)); - } - } - } - - None - } - - pub(crate) async fn delete(&self, addr: &SocketAddr) -> Result<()> { - let mut port_map = self.port_map.lock().await; - let mut new_conns = vec![]; - if let Some(conns) = port_map.get(&addr.port()) { - if !addr.ip().is_unspecified() { - for c in conns { - let laddr = c.local_addr()?; - if laddr.ip().is_unspecified() { - // This can't happen! 
- return Err(Error::ErrCannotRemoveUnspecifiedIp); - } - - if laddr.ip() == addr.ip() { - continue; - } - new_conns.push(Arc::clone(c)); - } - } - } else { - return Err(Error::ErrNoSuchUdpConn); - } - - if new_conns.is_empty() { - port_map.remove(&addr.port()); - } else { - port_map.insert(addr.port(), new_conns); - } - - Ok(()) - } - - pub(crate) async fn len(&self) -> usize { - let port_map = self.port_map.lock().await; - let mut n = 0; - for conns in port_map.values() { - n += conns.len(); - } - n - } -} diff --git a/util/src/vnet/conn_map/conn_map_test.rs b/util/src/vnet/conn_map/conn_map_test.rs deleted file mode 100644 index a3a6e7049..000000000 --- a/util/src/vnet/conn_map/conn_map_test.rs +++ /dev/null @@ -1,314 +0,0 @@ -use std::net::IpAddr; -use std::str::FromStr; - -use async_trait::async_trait; - -use super::*; -use crate::vnet::chunk::*; -use crate::vnet::conn::*; - -#[derive(Default)] -struct DummyObserver; - -#[async_trait] -impl ConnObserver for DummyObserver { - async fn write(&self, _c: Box) -> Result<()> { - Ok(()) - } - - async fn on_closed(&self, _addr: SocketAddr) {} - - fn determine_source_ip(&self, loc_ip: IpAddr, _dst_ip: IpAddr) -> Option { - Some(loc_ip) - } -} - -#[tokio::test] -async fn test_udp_conn_map_insert_remove() -> Result<()> { - let conn_map = UdpConnMap::new(); - - let obs: Arc> = Arc::new(Mutex::new(DummyObserver)); - - let conn_in = Arc::new(UdpConn::new( - SocketAddr::from_str("127.0.0.1:1234")?, - None, - obs, - )); - - conn_map.insert(Arc::clone(&conn_in)).await?; - - let conn_out = conn_map.find(&conn_in.local_addr()?).await; - assert!(conn_out.is_some(), "should succeed"); - if let Some(conn_out) = conn_out { - assert_eq!( - conn_in.local_addr()?, - conn_out.local_addr()?, - "should match" - ); - let port_map = conn_map.port_map.lock().await; - assert_eq!(port_map.len(), 1, "should match"); - } - - conn_map.delete(&conn_in.local_addr()?).await?; - { - let port_map = conn_map.port_map.lock().await; - assert_eq!(port_map.len(), 0, "should match"); - } - - let result = conn_map.delete(&conn_in.local_addr()?).await; - assert!(result.is_err(), "should fail"); - - Ok(()) -} - -#[tokio::test] -async fn test_udp_conn_map_insert_0_remove() -> Result<()> { - let conn_map = UdpConnMap::new(); - - let obs: Arc> = Arc::new(Mutex::new(DummyObserver)); - - let conn_in = Arc::new(UdpConn::new( - SocketAddr::from_str("0.0.0.0:1234")?, - None, - obs, - )); - - conn_map.insert(Arc::clone(&conn_in)).await?; - - let conn_out = conn_map.find(&conn_in.local_addr()?).await; - assert!(conn_out.is_some(), "should succeed"); - if let Some(conn_out) = conn_out { - assert_eq!( - conn_in.local_addr()?, - conn_out.local_addr()?, - "should match" - ); - let port_map = conn_map.port_map.lock().await; - assert_eq!(port_map.len(), 1, "should match"); - } - - conn_map.delete(&conn_in.local_addr()?).await?; - { - let port_map = conn_map.port_map.lock().await; - assert_eq!(port_map.len(), 0, "should match"); - } - - let result = conn_map.delete(&conn_in.local_addr()?).await; - assert!(result.is_err(), "should fail"); - - Ok(()) -} - -#[tokio::test] -async fn test_udp_conn_map_find_0() -> Result<()> { - let conn_map = UdpConnMap::new(); - - let obs: Arc> = Arc::new(Mutex::new(DummyObserver)); - - let conn_in = Arc::new(UdpConn::new( - SocketAddr::from_str("0.0.0.0:1234")?, - None, - obs, - )); - - conn_map.insert(Arc::clone(&conn_in)).await?; - - let addr = SocketAddr::from_str("192.168.0.1:1234")?; - let conn_out = conn_map.find(&addr).await; - assert!(conn_out.is_some(), "should 
succeed"); - if let Some(conn_out) = conn_out { - let addr_in = conn_in.local_addr()?; - let addr_out = conn_out.local_addr()?; - assert_eq!(addr_in, addr_out, "should match"); - let port_map = conn_map.port_map.lock().await; - assert_eq!(port_map.len(), 1, "should match"); - } - - Ok(()) -} - -#[tokio::test] -async fn test_udp_conn_map_insert_many_ips_with_same_port() -> Result<()> { - let conn_map = UdpConnMap::new(); - - let obs: Arc> = Arc::new(Mutex::new(DummyObserver)); - - let conn_in1 = Arc::new(UdpConn::new( - SocketAddr::from_str("10.1.2.1:5678")?, - None, - Arc::clone(&obs), - )); - - let conn_in2 = Arc::new(UdpConn::new( - SocketAddr::from_str("10.1.2.2:5678")?, - None, - Arc::clone(&obs), - )); - - conn_map.insert(Arc::clone(&conn_in1)).await?; - conn_map.insert(Arc::clone(&conn_in2)).await?; - - let addr1 = SocketAddr::from_str("10.1.2.1:5678")?; - let conn_out1 = conn_map.find(&addr1).await; - assert!(conn_out1.is_some(), "should succeed"); - if let Some(conn_out1) = conn_out1 { - let addr_in = conn_in1.local_addr()?; - let addr_out = conn_out1.local_addr()?; - assert_eq!(addr_in, addr_out, "should match"); - let port_map = conn_map.port_map.lock().await; - assert_eq!(port_map.len(), 1, "should match"); - } - - let addr2 = SocketAddr::from_str("10.1.2.2:5678")?; - let conn_out2 = conn_map.find(&addr2).await; - assert!(conn_out2.is_some(), "should succeed"); - if let Some(conn_out2) = conn_out2 { - let addr_in = conn_in2.local_addr()?; - let addr_out = conn_out2.local_addr()?; - assert_eq!(addr_in, addr_out, "should match"); - let port_map = conn_map.port_map.lock().await; - assert_eq!(port_map.len(), 1, "should match"); - } - - Ok(()) -} - -#[tokio::test] -async fn test_udp_conn_map_already_inuse_when_insert_0() -> Result<()> { - let conn_map = UdpConnMap::new(); - - let obs: Arc> = Arc::new(Mutex::new(DummyObserver)); - - let conn_in1 = Arc::new(UdpConn::new( - SocketAddr::from_str("10.1.2.1:5678")?, - None, - Arc::clone(&obs), - )); - let conn_in2 = Arc::new(UdpConn::new( - SocketAddr::from_str("0.0.0.0:5678")?, - None, - Arc::clone(&obs), - )); - - conn_map.insert(Arc::clone(&conn_in1)).await?; - let result = conn_map.insert(Arc::clone(&conn_in2)).await; - assert!(result.is_err(), "should fail"); - - Ok(()) -} - -#[tokio::test] -async fn test_udp_conn_map_already_inuse_when_insert_a_specified_ip() -> Result<()> { - let conn_map = UdpConnMap::new(); - - let obs: Arc> = Arc::new(Mutex::new(DummyObserver)); - - let conn_in1 = Arc::new(UdpConn::new( - SocketAddr::from_str("0.0.0.0:5678")?, - None, - Arc::clone(&obs), - )); - let conn_in2 = Arc::new(UdpConn::new( - SocketAddr::from_str("192.168.0.1:5678")?, - None, - Arc::clone(&obs), - )); - - conn_map.insert(Arc::clone(&conn_in1)).await?; - let result = conn_map.insert(Arc::clone(&conn_in2)).await; - assert!(result.is_err(), "should fail"); - - Ok(()) -} - -#[tokio::test] -async fn test_udp_conn_map_already_inuse_when_insert_same_specified_ip() -> Result<()> { - let conn_map = UdpConnMap::new(); - - let obs: Arc> = Arc::new(Mutex::new(DummyObserver)); - - let conn_in1 = Arc::new(UdpConn::new( - SocketAddr::from_str("192.168.0.1:5678")?, - None, - Arc::clone(&obs), - )); - let conn_in2 = Arc::new(UdpConn::new( - SocketAddr::from_str("192.168.0.1:5678")?, - None, - Arc::clone(&obs), - )); - - conn_map.insert(Arc::clone(&conn_in1)).await?; - let result = conn_map.insert(Arc::clone(&conn_in2)).await; - assert!(result.is_err(), "should fail"); - - Ok(()) -} - -#[tokio::test] -async fn test_udp_conn_map_find_failure_1() -> 
Result<()> { - let conn_map = UdpConnMap::new(); - - let obs: Arc> = Arc::new(Mutex::new(DummyObserver)); - - let conn_in = Arc::new(UdpConn::new( - SocketAddr::from_str("192.168.0.1:5678")?, - None, - obs, - )); - - conn_map.insert(Arc::clone(&conn_in)).await?; - - let addr = SocketAddr::from_str("192.168.0.2:5678")?; - let result = conn_map.find(&addr).await; - assert!(result.is_none(), "should be none"); - - Ok(()) -} - -#[tokio::test] -async fn test_udp_conn_map_find_failure_2() -> Result<()> { - let conn_map = UdpConnMap::new(); - - let obs: Arc> = Arc::new(Mutex::new(DummyObserver)); - - let conn_in = Arc::new(UdpConn::new( - SocketAddr::from_str("192.168.0.1:5678")?, - None, - obs, - )); - - conn_map.insert(Arc::clone(&conn_in)).await?; - - let addr = SocketAddr::from_str("192.168.0.1:1234")?; - let result = conn_map.find(&addr).await; - assert!(result.is_none(), "should be none"); - - Ok(()) -} - -#[tokio::test] -async fn test_udp_conn_map_insert_two_on_same_port_then_remove() -> Result<()> { - let conn_map = UdpConnMap::new(); - - let obs: Arc> = Arc::new(Mutex::new(DummyObserver)); - - let conn_in1 = Arc::new(UdpConn::new( - SocketAddr::from_str("192.168.0.1:5678")?, - None, - Arc::clone(&obs), - )); - let conn_in2 = Arc::new(UdpConn::new( - SocketAddr::from_str("192.168.0.2:5678")?, - None, - Arc::clone(&obs), - )); - - conn_map.insert(Arc::clone(&conn_in1)).await?; - conn_map.insert(Arc::clone(&conn_in2)).await?; - - conn_map.delete(&conn_in1.local_addr()?).await?; - conn_map.delete(&conn_in2.local_addr()?).await?; - - Ok(()) -} diff --git a/util/src/vnet/interface.rs b/util/src/vnet/interface.rs deleted file mode 100644 index 7aac9e7e3..000000000 --- a/util/src/vnet/interface.rs +++ /dev/null @@ -1,37 +0,0 @@ -use std::net::SocketAddr; - -use ipnet::*; - -use crate::error::*; - -#[derive(Debug, Clone, Default)] -pub struct Interface { - pub(crate) name: String, - pub(crate) addrs: Vec, -} - -impl Interface { - pub fn new(name: String, addrs: Vec) -> Self { - Interface { name, addrs } - } - - pub fn add_addr(&mut self, addr: IpNet) { - self.addrs.push(addr); - } - - pub fn name(&self) -> &str { - &self.name - } - pub fn addrs(&self) -> &[IpNet] { - &self.addrs - } - - pub fn convert(addr: SocketAddr, mask: Option) -> Result { - if let Some(mask) = mask { - Ok(IpNet::with_netmask(addr.ip(), mask.ip()).map_err(|_| Error::ErrInvalidMask)?) 
- } else { - Ok(IpNet::new(addr.ip(), if addr.is_ipv4() { 32 } else { 128 }) - .expect("ipv4 should always work with prefix 32 and ipv6 with prefix 128")) - } - } -} diff --git a/util/src/vnet/mod.rs b/util/src/vnet/mod.rs deleted file mode 100644 index 41b7c459f..000000000 --- a/util/src/vnet/mod.rs +++ /dev/null @@ -1,9 +0,0 @@ -pub mod chunk; -pub(crate) mod chunk_queue; -pub(crate) mod conn; -pub(crate) mod conn_map; -pub mod interface; -pub mod nat; -pub mod net; -pub(crate) mod resolver; -pub mod router; diff --git a/util/src/vnet/nat.rs b/util/src/vnet/nat.rs deleted file mode 100644 index c4905d3b3..000000000 --- a/util/src/vnet/nat.rs +++ /dev/null @@ -1,464 +0,0 @@ -#[cfg(test)] -mod nat_test; - -use std::collections::{HashMap, HashSet}; -use std::net::IpAddr; -use std::ops::Add; -use std::sync::atomic::Ordering; -use std::sync::Arc; -use std::time::SystemTime; - -use portable_atomic::AtomicU16; -use tokio::sync::Mutex; -use tokio::time::Duration; - -use crate::error::*; -use crate::vnet::chunk::Chunk; -use crate::vnet::net::UDP_STR; - -const DEFAULT_NAT_MAPPING_LIFE_TIME: Duration = Duration::from_secs(30); - -// EndpointDependencyType defines a type of behavioral dependency on the -// remote endpoint's IP address or port number. This is used for the two -// kinds of behaviors: -// - Port Mapping behavior -// - Filtering behavior -// See: https://tools.ietf.org/html/rfc4787 -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] -pub enum EndpointDependencyType { - // EndpointIndependent means the behavior is independent of the endpoint's address or port - #[default] - EndpointIndependent, - // EndpointAddrDependent means the behavior is dependent on the endpoint's address - EndpointAddrDependent, - // EndpointAddrPortDependent means the behavior is dependent on the endpoint's address and port - EndpointAddrPortDependent, -} - -// NATMode defines basic behavior of the NAT -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] -pub enum NatMode { - // NATModeNormal means the NAT behaves as a standard NAPT (RFC 2663). - #[default] - Normal, - // NATModeNAT1To1 exhibits 1:1 DNAT where the external IP address is statically mapped to - // a specific local IP address with port number is preserved always between them. - // When this mode is selected, mapping_behavior, filtering_behavior, port_preservation and - // mapping_life_time of NATType are ignored. - Nat1To1, -} - -// NATType has a set of parameters that define the behavior of NAT. 
-#[derive(Default, Debug, Copy, Clone)] -pub struct NatType { - pub mode: NatMode, - pub mapping_behavior: EndpointDependencyType, - pub filtering_behavior: EndpointDependencyType, - pub hair_pining: bool, // Not implemented yet - pub port_preservation: bool, // Not implemented yet - pub mapping_life_time: Duration, -} - -#[derive(Default, Debug, Clone)] -pub(crate) struct NatConfig { - pub(crate) name: String, - pub(crate) nat_type: NatType, - pub(crate) mapped_ips: Vec, // mapped IPv4 - pub(crate) local_ips: Vec, // local IPv4, required only when the mode is NATModeNAT1To1 -} - -#[derive(Debug, Clone)] -pub(crate) struct Mapping { - proto: String, // "udp" or "tcp" - local: String, // ":" - mapped: String, // ":" - bound: String, // key: "[[:]]" - filters: Arc>>, // key: "[[:]]" - expires: Arc>, // time to expire -} - -impl Default for Mapping { - fn default() -> Self { - Mapping { - proto: String::new(), // "udp" or "tcp" - local: String::new(), // ":" - mapped: String::new(), // ":" - bound: String::new(), // key: "[[:]]" - filters: Arc::new(Mutex::new(HashSet::new())), // key: "[[:]]" - expires: Arc::new(Mutex::new(SystemTime::now())), // time to expire - } - } -} - -#[derive(Default, Debug, Clone)] -pub(crate) struct NetworkAddressTranslator { - pub(crate) name: String, - pub(crate) nat_type: NatType, - pub(crate) mapped_ips: Vec, // mapped IPv4 - pub(crate) local_ips: Vec, // local IPv4, required only when the mode is NATModeNAT1To1 - pub(crate) outbound_map: Arc>>>, // key: "::[:remote-ip[:remote-port]] - pub(crate) inbound_map: Arc>>>, // key: "::" - pub(crate) udp_port_counter: Arc, -} - -impl NetworkAddressTranslator { - pub(crate) fn new(config: NatConfig) -> Result { - let mut nat_type = config.nat_type; - - if nat_type.mode == NatMode::Nat1To1 { - // 1:1 NAT behavior - nat_type.mapping_behavior = EndpointDependencyType::EndpointIndependent; - nat_type.filtering_behavior = EndpointDependencyType::EndpointIndependent; - nat_type.port_preservation = true; - nat_type.mapping_life_time = Duration::from_secs(0); - - if config.mapped_ips.is_empty() { - return Err(Error::ErrNatRequiresMapping); - } - if config.mapped_ips.len() != config.local_ips.len() { - return Err(Error::ErrMismatchLengthIp); - } - } else { - // Normal (NAPT) behavior - nat_type.mode = NatMode::Normal; - if nat_type.mapping_life_time == Duration::from_secs(0) { - nat_type.mapping_life_time = DEFAULT_NAT_MAPPING_LIFE_TIME; - } - } - - Ok(NetworkAddressTranslator { - name: config.name, - nat_type, - mapped_ips: config.mapped_ips, - local_ips: config.local_ips, - outbound_map: Arc::new(Mutex::new(HashMap::new())), - inbound_map: Arc::new(Mutex::new(HashMap::new())), - udp_port_counter: Arc::new(AtomicU16::new(0)), - }) - } - - pub(crate) fn get_paired_mapped_ip(&self, loc_ip: &IpAddr) -> Option<&IpAddr> { - for (i, ip) in self.local_ips.iter().enumerate() { - if ip == loc_ip { - return self.mapped_ips.get(i); - } - } - None - } - - pub(crate) fn get_paired_local_ip(&self, mapped_ip: &IpAddr) -> Option<&IpAddr> { - for (i, ip) in self.mapped_ips.iter().enumerate() { - if ip == mapped_ip { - return self.local_ips.get(i); - } - } - None - } - - pub(crate) async fn translate_outbound( - &self, - from: &(dyn Chunk + Send + Sync), - ) -> Result>> { - let mut to = from.clone_to(); - - if from.network() == UDP_STR { - if self.nat_type.mode == NatMode::Nat1To1 { - // 1:1 NAT behavior - let src_addr = from.source_addr(); - if let Some(src_ip) = self.get_paired_mapped_ip(&src_addr.ip()) { - to.set_source_addr(&format!("{}:{}", 
src_ip, src_addr.port()))?; - } else { - log::debug!( - "[{}] drop outbound chunk {} with not route", - self.name, - from - ); - return Ok(None); // silently discard - } - } else { - // Normal (NAPT) behavior - let bound = match self.nat_type.mapping_behavior { - EndpointDependencyType::EndpointIndependent => "".to_owned(), - EndpointDependencyType::EndpointAddrDependent => { - from.get_destination_ip().to_string() - } - EndpointDependencyType::EndpointAddrPortDependent => { - from.destination_addr().to_string() - } - }; - - let filter_key = match self.nat_type.filtering_behavior { - EndpointDependencyType::EndpointIndependent => "".to_owned(), - EndpointDependencyType::EndpointAddrDependent => { - from.get_destination_ip().to_string() - } - EndpointDependencyType::EndpointAddrPortDependent => { - from.destination_addr().to_string() - } - }; - - let o_key = format!("udp:{}:{}", from.source_addr(), bound); - let name = self.name.clone(); - - let m_mapped = if let Some(m) = self.find_outbound_mapping(&o_key).await { - let mut filters = m.filters.lock().await; - if !filters.contains(&filter_key) { - log::debug!( - "[{}] permit access from {} to {}", - name, - filter_key, - m.mapped - ); - filters.insert(filter_key); - } - m.mapped.clone() - } else { - // Create a new Mapping - let udp_port_counter = self.udp_port_counter.load(Ordering::SeqCst); - let mapped_port = 0xC000 + udp_port_counter; - if udp_port_counter == 0xFFFF - 0xC000 { - self.udp_port_counter.store(0, Ordering::SeqCst); - } else { - self.udp_port_counter.fetch_add(1, Ordering::SeqCst); - } - - let m = if let Some(mapped_ips_first) = self.mapped_ips.first() { - Mapping { - proto: "udp".to_owned(), - local: from.source_addr().to_string(), - bound, - mapped: format!("{mapped_ips_first}:{mapped_port}"), - filters: Arc::new(Mutex::new(HashSet::new())), - expires: Arc::new(Mutex::new( - SystemTime::now().add(self.nat_type.mapping_life_time), - )), - } - } else { - return Err(Error::ErrNatRequiresMapping); - }; - - { - let mut outbound_map = self.outbound_map.lock().await; - outbound_map.insert(o_key.clone(), Arc::new(m.clone())); - } - - let i_key = format!("udp:{}", m.mapped); - - log::debug!( - "[{}] created a new NAT binding oKey={} i_key={}", - self.name, - o_key, - i_key - ); - log::debug!( - "[{}] permit access from {} to {}", - self.name, - filter_key, - m.mapped - ); - - { - let mut filters = m.filters.lock().await; - filters.insert(filter_key); - } - - let m_mapped = m.mapped.clone(); - { - let mut inbound_map = self.inbound_map.lock().await; - inbound_map.insert(i_key, Arc::new(m)); - } - m_mapped - }; - - to.set_source_addr(&m_mapped)?; - } - - log::debug!( - "[{}] translate outbound chunk from {} to {}", - self.name, - from, - to - ); - - return Ok(Some(to)); - } - - Err(Error::ErrNonUdpTranslationNotSupported) - } - - pub(crate) async fn translate_inbound( - &self, - from: &(dyn Chunk + Send + Sync), - ) -> Result>> { - let mut to = from.clone_to(); - - if from.network() == UDP_STR { - if self.nat_type.mode == NatMode::Nat1To1 { - // 1:1 NAT behavior - let dst_addr = from.destination_addr(); - if let Some(dst_ip) = self.get_paired_local_ip(&dst_addr.ip()) { - let dst_port = from.destination_addr().port(); - to.set_destination_addr(&format!("{dst_ip}:{dst_port}"))?; - } else { - return Err(Error::Other(format!( - "drop {from} as {:?}", - Error::ErrNoAssociatedLocalAddress - ))); - } - } else { - // Normal (NAPT) behavior - let filter_key = match self.nat_type.filtering_behavior { - 
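In the NAPT branch above, how much of the remote endpoint gets folded into the outbound key is decided by the configured mapping behavior (RFC 4787): nothing for endpoint-independent, the destination IP for address-dependent, and the full destination address for address-and-port-dependent. A sketch of just that key derivation, with a local enum mirroring the deleted EndpointDependencyType:

use std::net::SocketAddr;

#[derive(Clone, Copy)]
enum Dependency {
    EndpointIndependent, // one external mapping reused for every destination
    AddrDependent,       // a new mapping per destination IP
    AddrPortDependent,   // a new mapping per destination IP:port
}

fn outbound_key(src: SocketAddr, dst: SocketAddr, behavior: Dependency) -> String {
    let bound = match behavior {
        Dependency::EndpointIndependent => String::new(),
        Dependency::AddrDependent => dst.ip().to_string(),
        Dependency::AddrPortDependent => dst.to_string(),
    };
    // Chunks that produce the same key reuse the same external port mapping.
    format!("udp:{src}:{bound}")
}

Endpoint-independent mapping is what the full-cone and restricted-cone tests below rely on, while the address- and port-dependent variants produce the extra mappings asserted in the symmetric-NAT tests.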
EndpointDependencyType::EndpointIndependent => "".to_owned(), - EndpointDependencyType::EndpointAddrDependent => { - from.get_source_ip().to_string() - } - EndpointDependencyType::EndpointAddrPortDependent => { - from.source_addr().to_string() - } - }; - - let i_key = format!("udp:{}", from.destination_addr()); - if let Some(m) = self.find_inbound_mapping(&i_key).await { - { - let filters = m.filters.lock().await; - if !filters.contains(&filter_key) { - return Err(Error::Other(format!( - "drop {} as the remote {} {:?}", - from, - filter_key, - Error::ErrHasNoPermission - ))); - } - } - - // See RFC 4847 Section 4.3. Mapping Refresh - // a) Inbound refresh may be useful for applications with no outgoing - // UDP traffic. However, allowing inbound refresh may allow an - // external attacker or misbehaving application to keep a Mapping - // alive indefinitely. This may be a security risk. Also, if the - // process is repeated with different ports, over time, it could - // use up all the ports on the NAT. - - to.set_destination_addr(&m.local)?; - } else { - return Err(Error::Other(format!( - "drop {} as {:?}", - from, - Error::ErrNoNatBindingFound - ))); - } - } - - log::debug!( - "[{}] translate inbound chunk from {} to {}", - self.name, - from, - to - ); - - return Ok(Some(to)); - } - - Err(Error::ErrNonUdpTranslationNotSupported) - } - - // caller must hold the mutex - pub(crate) async fn find_outbound_mapping(&self, o_key: &str) -> Option> { - let mapping_life_time = self.nat_type.mapping_life_time; - let mut expired = false; - let (in_key, out_key) = { - let outbound_map = self.outbound_map.lock().await; - if let Some(m) = outbound_map.get(o_key) { - let now = SystemTime::now(); - - { - let mut expires = m.expires.lock().await; - // check if this Mapping is expired - if now.duration_since(*expires).is_ok() { - expired = true; - } else { - *expires = now.add(mapping_life_time); - } - } - ( - NetworkAddressTranslator::get_inbound_map_key(m), - NetworkAddressTranslator::get_outbound_map_key(m), - ) - } else { - (String::new(), String::new()) - } - }; - - if expired { - { - let mut inbound_map = self.inbound_map.lock().await; - inbound_map.remove(&in_key); - } - { - let mut outbound_map = self.outbound_map.lock().await; - outbound_map.remove(&out_key); - } - } - - let outbound_map = self.outbound_map.lock().await; - outbound_map.get(o_key).cloned() - } - - // caller must hold the mutex - pub(crate) async fn find_inbound_mapping(&self, i_key: &str) -> Option> { - let mut expired = false; - let (in_key, out_key) = { - let inbound_map = self.inbound_map.lock().await; - if let Some(m) = inbound_map.get(i_key) { - let now = SystemTime::now(); - - { - let expires = m.expires.lock().await; - // check if this Mapping is expired - if now.duration_since(*expires).is_ok() { - expired = true; - } - } - ( - NetworkAddressTranslator::get_inbound_map_key(m), - NetworkAddressTranslator::get_outbound_map_key(m), - ) - } else { - (String::new(), String::new()) - } - }; - - if expired { - { - let mut inbound_map = self.inbound_map.lock().await; - inbound_map.remove(&in_key); - } - { - let mut outbound_map = self.outbound_map.lock().await; - outbound_map.remove(&out_key); - } - } - - let inbound_map = self.inbound_map.lock().await; - inbound_map.get(i_key).cloned() - } - - // caller must hold the mutex - fn get_outbound_map_key(m: &Mapping) -> String { - format!("{}:{}:{}", m.proto, m.local, m.bound) - } - - fn get_inbound_map_key(m: &Mapping) -> String { - format!("{}:{}", m.proto, m.mapped) - } - - async fn 
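The two lookup helpers above expire mappings lazily: each hit compares the stored deadline with the current time, removes the entry from both maps once it has passed, and, on the outbound path only, pushes the deadline out by the configured lifetime. The refresh check in isolation, with a hypothetical Mapping struct holding just the deadline:

use std::time::{Duration, SystemTime};

struct Mapping {
    expires: SystemTime,
}

/// Returns true if the mapping is still alive, refreshing its deadline;
/// returns false if it has expired and should be removed from both maps.
fn touch(m: &mut Mapping, lifetime: Duration) -> bool {
    let now = SystemTime::now();
    // duration_since succeeds only when `expires` is not in the future.
    if now.duration_since(m.expires).is_ok() {
        false
    } else {
        m.expires = now + lifetime;
        true
    }
}

The inbound path deliberately checks expiry without refreshing, matching the RFC note quoted above about inbound refresh keeping a mapping alive indefinitely.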
inbound_map_len(&self) -> usize { - let inbound_map = self.inbound_map.lock().await; - inbound_map.len() - } - - async fn outbound_map_len(&self) -> usize { - let outbound_map = self.outbound_map.lock().await; - outbound_map.len() - } -} diff --git a/util/src/vnet/nat/nat_test.rs b/util/src/vnet/nat/nat_test.rs deleted file mode 100644 index 461db9491..000000000 --- a/util/src/vnet/nat/nat_test.rs +++ /dev/null @@ -1,638 +0,0 @@ -use std::net::SocketAddr; -use std::str::FromStr; - -use super::*; -use crate::vnet::chunk::ChunkUdp; - -// oic: outbound internal chunk -// oec: outbound external chunk -// iic: inbound internal chunk -// iec: inbound external chunk - -const DEMO_IP: &str = "1.2.3.4"; - -#[test] -fn test_nat_type_default() -> Result<()> { - let nat = NetworkAddressTranslator::new(NatConfig { - mapped_ips: vec![IpAddr::from_str(DEMO_IP)?], - ..Default::default() - })?; - - assert_eq!( - nat.nat_type.mapping_behavior, - EndpointDependencyType::EndpointIndependent, - "should match" - ); - assert_eq!( - nat.nat_type.filtering_behavior, - EndpointDependencyType::EndpointIndependent, - "should match" - ); - assert!(!nat.nat_type.hair_pining, "should be false"); - assert!(!nat.nat_type.port_preservation, "should be false"); - assert_eq!( - nat.nat_type.mapping_life_time, DEFAULT_NAT_MAPPING_LIFE_TIME, - "should be false" - ); - - Ok(()) -} - -#[tokio::test] -async fn test_nat_mapping_behavior_full_cone_nat() -> Result<()> { - let nat = NetworkAddressTranslator::new(NatConfig { - nat_type: NatType { - mapping_behavior: EndpointDependencyType::EndpointIndependent, - filtering_behavior: EndpointDependencyType::EndpointIndependent, - hair_pining: false, - mapping_life_time: Duration::from_secs(30), - ..Default::default() - }, - mapped_ips: vec![IpAddr::from_str(DEMO_IP)?], - ..Default::default() - })?; - - let src = SocketAddr::from_str("192.168.0.2:1234")?; - let dst = SocketAddr::from_str("5.6.7.8:5678")?; - - let oic = ChunkUdp::new(src, dst); - - let oec = nat.translate_outbound(&oic).await?.unwrap(); - assert_eq!(nat.outbound_map_len().await, 1, "should match"); - assert_eq!(nat.inbound_map_len().await, 1, "should match"); - - log::debug!("o-original : {}", oic); - log::debug!("o-translated: {}", oec); - - let iec = ChunkUdp::new( - SocketAddr::new(dst.ip(), dst.port()), - SocketAddr::new(oec.source_addr().ip(), oec.source_addr().port()), - ); - - log::debug!("i-original : {}", iec); - - let iic = nat.translate_inbound(&iec).await?.unwrap(); - - log::debug!("i-translated: {}", iic); - - assert_eq!(oic.source_addr(), iic.destination_addr(), "should match"); - - // packet with dest addr that does not exist in the mapping table - // will be dropped - let iec = ChunkUdp::new( - SocketAddr::new(dst.ip(), dst.port()), - SocketAddr::new(oec.source_addr().ip(), oec.source_addr().port() + 1), - ); - - let result = nat.translate_inbound(&iec).await; - assert!(result.is_err(), "should fail (dropped)"); - - // packet from any addr will be accepted (full-cone) - let iec = ChunkUdp::new( - SocketAddr::new(dst.ip(), 7777), - SocketAddr::new(oec.source_addr().ip(), oec.source_addr().port()), - ); - - let result = nat.translate_inbound(&iec).await; - assert!(result.is_ok(), "should succeed"); - - Ok(()) -} - -#[tokio::test] -async fn test_nat_mapping_behavior_addr_restricted_cone_nat() -> Result<()> { - let nat = NetworkAddressTranslator::new(NatConfig { - nat_type: NatType { - mapping_behavior: EndpointDependencyType::EndpointIndependent, - filtering_behavior: 
EndpointDependencyType::EndpointAddrDependent, - hair_pining: false, - mapping_life_time: Duration::from_secs(30), - ..Default::default() - }, - mapped_ips: vec![IpAddr::from_str(DEMO_IP)?], - ..Default::default() - })?; - - let src = SocketAddr::from_str("192.168.0.2:1234")?; - let dst = SocketAddr::from_str("5.6.7.8:5678")?; - - let oic = ChunkUdp::new(src, dst); - log::debug!("o-original : {}", oic); - - let oec = nat.translate_outbound(&oic).await?.unwrap(); - assert_eq!(nat.outbound_map_len().await, 1, "should match"); - assert_eq!(nat.inbound_map_len().await, 1, "should match"); - log::debug!("o-translated: {}", oec); - - // sending different (IP: 5.6.7.9) won't create a new mapping - let oic2 = ChunkUdp::new( - SocketAddr::from_str("192.168.0.2:1234")?, - SocketAddr::from_str("5.6.7.9:9000")?, - ); - let oec2 = nat.translate_outbound(&oic2).await?.unwrap(); - assert_eq!(nat.outbound_map_len().await, 1, "should match"); - assert_eq!(nat.inbound_map_len().await, 1, "should match"); - log::debug!("o-translated: {}", oec2); - - let iec = ChunkUdp::new( - SocketAddr::new(dst.ip(), dst.port()), - SocketAddr::new(oec.source_addr().ip(), oec.source_addr().port()), - ); - - log::debug!("i-original : {}", iec); - - let iic = nat.translate_inbound(&iec).await?.unwrap(); - - log::debug!("i-translated: {}", iic); - - assert_eq!(oic.source_addr(), iic.destination_addr(), "should match"); - - // packet with dest addr that does not exist in the mapping table - // will be dropped - let iec = ChunkUdp::new( - SocketAddr::new(dst.ip(), dst.port()), - SocketAddr::new(oec.source_addr().ip(), oec.source_addr().port() + 1), - ); - - let result = nat.translate_inbound(&iec).await; - assert!(result.is_err(), "should fail (dropped)"); - - // packet from any port will be accepted (restricted-cone) - let iec = ChunkUdp::new( - SocketAddr::new(dst.ip(), 7777), - SocketAddr::new(oec.source_addr().ip(), oec.source_addr().port()), - ); - - let result = nat.translate_inbound(&iec).await; - assert!(result.is_ok(), "should succeed"); - - // packet from different addr will be dropped (restricted-cone) - let iec = ChunkUdp::new( - SocketAddr::from_str(&format!("{}:{}", "6.6.6.6", dst.port()))?, - SocketAddr::new(oec.source_addr().ip(), oec.source_addr().port()), - ); - - let result = nat.translate_inbound(&iec).await; - assert!(result.is_err(), "should fail (dropped)"); - - Ok(()) -} - -#[tokio::test] -async fn test_nat_mapping_behavior_port_restricted_cone_nat() -> Result<()> { - let nat = NetworkAddressTranslator::new(NatConfig { - nat_type: NatType { - mapping_behavior: EndpointDependencyType::EndpointIndependent, - filtering_behavior: EndpointDependencyType::EndpointAddrPortDependent, - hair_pining: false, - mapping_life_time: Duration::from_secs(30), - ..Default::default() - }, - mapped_ips: vec![IpAddr::from_str(DEMO_IP)?], - ..Default::default() - })?; - - let src = SocketAddr::from_str("192.168.0.2:1234")?; - let dst = SocketAddr::from_str("5.6.7.8:5678")?; - - let oic = ChunkUdp::new(src, dst); - log::debug!("o-original : {}", oic); - - let oec = nat.translate_outbound(&oic).await?.unwrap(); - assert_eq!(nat.outbound_map_len().await, 1, "should match"); - assert_eq!(nat.inbound_map_len().await, 1, "should match"); - log::debug!("o-translated: {}", oec); - - // sending different (IP: 5.6.7.9) won't create a new mapping - let oic2 = ChunkUdp::new( - SocketAddr::from_str("192.168.0.2:1234")?, - SocketAddr::from_str("5.6.7.9:9000")?, - ); - let oec2 = nat.translate_outbound(&oic2).await?.unwrap(); - 
assert_eq!(nat.outbound_map_len().await, 1, "should match"); - assert_eq!(nat.inbound_map_len().await, 1, "should match"); - log::debug!("o-translated: {}", oec2); - - let iec = ChunkUdp::new( - SocketAddr::new(dst.ip(), dst.port()), - SocketAddr::new(oec.source_addr().ip(), oec.source_addr().port()), - ); - - log::debug!("i-original : {}", iec); - - let iic = nat.translate_inbound(&iec).await?.unwrap(); - - log::debug!("i-translated: {}", iic); - - assert_eq!(oic.source_addr(), iic.destination_addr(), "should match"); - - // packet with dest addr that does not exist in the mapping table - // will be dropped - let iec = ChunkUdp::new( - SocketAddr::new(dst.ip(), dst.port()), - SocketAddr::new(oec.source_addr().ip(), oec.source_addr().port() + 1), - ); - - let result = nat.translate_inbound(&iec).await; - assert!(result.is_err(), "should fail (dropped)"); - - // packet from different port will be dropped (port-restricted-cone) - let iec = ChunkUdp::new( - SocketAddr::new(dst.ip(), 7777), - SocketAddr::new(oec.source_addr().ip(), oec.source_addr().port()), - ); - - let result = nat.translate_inbound(&iec).await; - assert!(result.is_err(), "should fail (dropped)"); - - // packet from different addr will be dropped (restricted-cone) - let iec = ChunkUdp::new( - SocketAddr::from_str(&format!("{}:{}", "6.6.6.6", dst.port()))?, - SocketAddr::new(oec.source_addr().ip(), oec.source_addr().port()), - ); - - let result = nat.translate_inbound(&iec).await; - assert!(result.is_err(), "should fail (dropped)"); - - Ok(()) -} - -#[tokio::test] -async fn test_nat_mapping_behavior_symmetric_nat_addr_dependent_mapping() -> Result<()> { - let nat = NetworkAddressTranslator::new(NatConfig { - nat_type: NatType { - mapping_behavior: EndpointDependencyType::EndpointAddrDependent, - filtering_behavior: EndpointDependencyType::EndpointAddrDependent, - hair_pining: false, - mapping_life_time: Duration::from_secs(30), - ..Default::default() - }, - mapped_ips: vec![IpAddr::from_str(DEMO_IP)?], - ..Default::default() - })?; - - let src = SocketAddr::from_str("192.168.0.2:1234")?; - let dst1 = SocketAddr::from_str("5.6.7.8:5678")?; - let dst2 = SocketAddr::from_str("5.6.7.100:5678")?; - let dst3 = SocketAddr::from_str("5.6.7.8:6000")?; - - let oic1 = ChunkUdp::new(src, dst1); - let oic2 = ChunkUdp::new(src, dst2); - let oic3 = ChunkUdp::new(src, dst3); - - log::debug!("o-original : {}", oic1); - log::debug!("o-original : {}", oic2); - log::debug!("o-original : {}", oic3); - - let oec1 = nat.translate_outbound(&oic1).await?.unwrap(); - let oec2 = nat.translate_outbound(&oic2).await?.unwrap(); - let oec3 = nat.translate_outbound(&oic3).await?.unwrap(); - - assert_eq!(nat.outbound_map_len().await, 2, "should match"); - assert_eq!(nat.inbound_map_len().await, 2, "should match"); - - log::debug!("o-translated: {}", oec1); - log::debug!("o-translated: {}", oec2); - log::debug!("o-translated: {}", oec3); - - assert_ne!( - oec1.source_addr().port(), - oec2.source_addr().port(), - "should not match" - ); - assert_eq!( - oec1.source_addr().port(), - oec3.source_addr().port(), - "should match" - ); - - Ok(()) -} - -#[tokio::test] -async fn test_nat_mapping_behavior_symmetric_nat_port_dependent_mapping() -> Result<()> { - let nat = NetworkAddressTranslator::new(NatConfig { - nat_type: NatType { - mapping_behavior: EndpointDependencyType::EndpointAddrPortDependent, - filtering_behavior: EndpointDependencyType::EndpointAddrPortDependent, - hair_pining: false, - mapping_life_time: Duration::from_secs(30), - ..Default::default() - }, 
- mapped_ips: vec![IpAddr::from_str(DEMO_IP)?], - ..Default::default() - })?; - - let src = SocketAddr::from_str("192.168.0.2:1234")?; - let dst1 = SocketAddr::from_str("5.6.7.8:5678")?; - let dst2 = SocketAddr::from_str("5.6.7.100:5678")?; - let dst3 = SocketAddr::from_str("5.6.7.8:6000")?; - - let oic1 = ChunkUdp::new(src, dst1); - let oic2 = ChunkUdp::new(src, dst2); - let oic3 = ChunkUdp::new(src, dst3); - - log::debug!("o-original : {}", oic1); - log::debug!("o-original : {}", oic2); - log::debug!("o-original : {}", oic3); - - let oec1 = nat.translate_outbound(&oic1).await?.unwrap(); - let oec2 = nat.translate_outbound(&oic2).await?.unwrap(); - let oec3 = nat.translate_outbound(&oic3).await?.unwrap(); - - assert_eq!(nat.outbound_map_len().await, 3, "should match"); - assert_eq!(nat.inbound_map_len().await, 3, "should match"); - - log::debug!("o-translated: {}", oec1); - log::debug!("o-translated: {}", oec2); - log::debug!("o-translated: {}", oec3); - - assert_ne!( - oec1.source_addr().port(), - oec2.source_addr().port(), - "should not match" - ); - assert_ne!( - oec1.source_addr().port(), - oec3.source_addr().port(), - "should match" - ); - - Ok(()) -} - -#[tokio::test] -async fn test_nat_mapping_timeout_refresh_on_outbound() -> Result<()> { - let nat = NetworkAddressTranslator::new(NatConfig { - nat_type: NatType { - mapping_behavior: EndpointDependencyType::EndpointIndependent, - filtering_behavior: EndpointDependencyType::EndpointIndependent, - hair_pining: false, - mapping_life_time: Duration::from_millis(200), - ..Default::default() - }, - mapped_ips: vec![IpAddr::from_str(DEMO_IP)?], - ..Default::default() - })?; - - let src = SocketAddr::from_str("192.168.0.2:1234")?; - let dst = SocketAddr::from_str("5.6.7.8:5678")?; - - let oic = ChunkUdp::new(src, dst); - - let oec = nat.translate_outbound(&oic).await?.unwrap(); - assert_eq!(nat.outbound_map_len().await, 1, "should match"); - assert_eq!(nat.inbound_map_len().await, 1, "should match"); - - log::debug!("o-original : {}", oic); - log::debug!("o-translated: {}", oec); - - // record mapped addr - let mapped = oec.source_addr().to_string(); - - tokio::time::sleep(Duration::from_millis(5)).await; - - // refresh - let oec = nat.translate_outbound(&oic).await?.unwrap(); - assert_eq!(nat.outbound_map_len().await, 1, "should match"); - assert_eq!(nat.inbound_map_len().await, 1, "should match"); - - log::debug!("o-original : {}", oic); - log::debug!("o-translated: {}", oec); - - assert_eq!( - mapped, - oec.source_addr().to_string(), - "mapped addr should match" - ); - - // sleep long enough for the mapping to expire - tokio::time::sleep(Duration::from_millis(225)).await; - - // refresh after expiration - let oec = nat.translate_outbound(&oic).await?.unwrap(); - assert_eq!(nat.outbound_map_len().await, 1, "should match"); - assert_eq!(nat.inbound_map_len().await, 1, "should match"); - - log::debug!("o-original : {}", oic); - log::debug!("o-translated: {}", oec); - - assert_ne!( - oec.source_addr().to_string(), - mapped, - "mapped addr should not match" - ); - - Ok(()) -} - -#[tokio::test] -async fn test_nat_mapping_timeout_outbound_detects_timeout() -> Result<()> { - let nat = NetworkAddressTranslator::new(NatConfig { - nat_type: NatType { - mapping_behavior: EndpointDependencyType::EndpointIndependent, - filtering_behavior: EndpointDependencyType::EndpointIndependent, - hair_pining: false, - mapping_life_time: Duration::from_millis(100), - ..Default::default() - }, - mapped_ips: vec![IpAddr::from_str(DEMO_IP)?], - ..Default::default() - 
})?; - - let src = SocketAddr::from_str("192.168.0.2:1234")?; - let dst = SocketAddr::from_str("5.6.7.8:5678")?; - - let oic = ChunkUdp::new(src, dst); - - let oec = nat.translate_outbound(&oic).await?.unwrap(); - assert_eq!(nat.outbound_map_len().await, 1, "should match"); - assert_eq!(nat.inbound_map_len().await, 1, "should match"); - - log::debug!("o-original : {}", oic); - log::debug!("o-translated: {}", oec); - - // sleep long enough for the mapping to expire - tokio::time::sleep(Duration::from_millis(125)).await; - - let iec = ChunkUdp::new( - SocketAddr::new(dst.ip(), dst.port()), - SocketAddr::new(oec.source_addr().ip(), oec.source_addr().port()), - ); - - log::debug!("i-original : {}", iec); - - let result = nat.translate_inbound(&iec).await; - assert!(result.is_err(), "should drop"); - assert_eq!(nat.outbound_map_len().await, 0, "should match"); - assert_eq!(nat.inbound_map_len().await, 0, "should match"); - - Ok(()) -} - -#[tokio::test] -async fn test_nat1to1_behavior_one_mapping() -> Result<()> { - let nat = NetworkAddressTranslator::new(NatConfig { - nat_type: NatType { - mode: NatMode::Nat1To1, - ..Default::default() - }, - mapped_ips: vec![IpAddr::from_str(DEMO_IP)?], - local_ips: vec![IpAddr::from_str("10.0.0.1")?], - ..Default::default() - })?; - - let src = SocketAddr::from_str("10.0.0.1:1234")?; - let dst = SocketAddr::from_str("5.6.7.8:5678")?; - - let oic = ChunkUdp::new(src, dst); - - let oec = nat.translate_outbound(&oic).await?.unwrap(); - assert_eq!(nat.outbound_map_len().await, 0, "should match"); - assert_eq!(nat.inbound_map_len().await, 0, "should match"); - - log::debug!("o-original : {}", oic); - log::debug!("o-translated: {}", oec); - - assert_eq!( - "1.2.3.4:1234", - oec.source_addr().to_string(), - "should match" - ); - - let iec = ChunkUdp::new( - SocketAddr::new(dst.ip(), dst.port()), - SocketAddr::new(oec.source_addr().ip(), oec.source_addr().port()), - ); - - log::debug!("i-original : {}", iec); - - let iic = nat.translate_inbound(&iec).await?.unwrap(); - - log::debug!("i-translated: {}", iic); - - assert_eq!(oic.source_addr(), iic.destination_addr(), "should match"); - - Ok(()) -} - -#[tokio::test] -async fn test_nat1to1_behavior_more_mapping() -> Result<()> { - let nat = NetworkAddressTranslator::new(NatConfig { - nat_type: NatType { - mode: NatMode::Nat1To1, - ..Default::default() - }, - mapped_ips: vec![IpAddr::from_str(DEMO_IP)?, IpAddr::from_str("1.2.3.5")?], - local_ips: vec![IpAddr::from_str("10.0.0.1")?, IpAddr::from_str("10.0.0.2")?], - ..Default::default() - })?; - - // outbound translation - - let before = ChunkUdp::new( - SocketAddr::from_str("10.0.0.1:1234")?, - SocketAddr::from_str("5.6.7.8:5678")?, - ); - - let after = nat.translate_outbound(&before).await?.unwrap(); - assert_eq!( - after.source_addr().to_string(), - "1.2.3.4:1234", - "should match" - ); - - let before = ChunkUdp::new( - SocketAddr::from_str("10.0.0.2:1234")?, - SocketAddr::from_str("5.6.7.8:5678")?, - ); - - let after = nat.translate_outbound(&before).await?.unwrap(); - assert_eq!( - after.source_addr().to_string(), - "1.2.3.5:1234", - "should match" - ); - - // inbound translation - - let before = ChunkUdp::new( - SocketAddr::from_str("5.6.7.8:5678")?, - SocketAddr::from_str(&format!("{}:{}", DEMO_IP, 2525))?, - ); - - let after = nat.translate_inbound(&before).await?.unwrap(); - assert_eq!( - after.destination_addr().to_string(), - "10.0.0.1:2525", - "should match" - ); - - let before = ChunkUdp::new( - SocketAddr::from_str("5.6.7.8:5678")?, - 
SocketAddr::from_str("1.2.3.5:9847")?, - ); - - let after = nat.translate_inbound(&before).await?.unwrap(); - assert_eq!( - after.destination_addr().to_string(), - "10.0.0.2:9847", - "should match" - ); - - Ok(()) -} - -#[tokio::test] -async fn test_nat1to1_behavior_failure() -> Result<()> { - // 1:1 NAT requires more than one mapping - let result = NetworkAddressTranslator::new(NatConfig { - nat_type: NatType { - mode: NatMode::Nat1To1, - ..Default::default() - }, - ..Default::default() - }); - assert!(result.is_err(), "should fail"); - - // 1:1 NAT requires the same number of mappedIPs and localIPs - let result = NetworkAddressTranslator::new(NatConfig { - nat_type: NatType { - mode: NatMode::Nat1To1, - ..Default::default() - }, - mapped_ips: vec![IpAddr::from_str(DEMO_IP)?, IpAddr::from_str("1.2.3.5")?], - local_ips: vec![IpAddr::from_str("10.0.0.1")?], - ..Default::default() - }); - assert!(result.is_err(), "should fail"); - - // drop outbound or inbound chunk with no route in 1:1 NAT - let nat = NetworkAddressTranslator::new(NatConfig { - nat_type: NatType { - mode: NatMode::Nat1To1, - ..Default::default() - }, - mapped_ips: vec![IpAddr::from_str(DEMO_IP)?], - local_ips: vec![IpAddr::from_str("10.0.0.1")?], - ..Default::default() - })?; - - let before = ChunkUdp::new( - SocketAddr::from_str("10.0.0.2:1234")?, // no external mapping for this - SocketAddr::from_str("5.6.7.8:5678")?, - ); - - let after = nat.translate_outbound(&before).await?; - assert!(after.is_none(), "should be nil"); - - let before = ChunkUdp::new( - SocketAddr::from_str("5.6.7.8:5678")?, - SocketAddr::from_str("10.0.0.2:1234")?, // no local mapping for this - ); - - let result = nat.translate_inbound(&before).await; - assert!(result.is_err(), "should fail"); - - Ok(()) -} diff --git a/util/src/vnet/net.rs b/util/src/vnet/net.rs deleted file mode 100644 index 5586e59b7..000000000 --- a/util/src/vnet/net.rs +++ /dev/null @@ -1,566 +0,0 @@ -#[cfg(test)] -mod net_test; - -use std::collections::HashMap; -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; -use std::str::FromStr; -use std::sync::atomic::Ordering; -use std::sync::Arc; - -use async_trait::async_trait; -use ipnet::IpNet; -use portable_atomic::AtomicU64; -use tokio::net::UdpSocket; -use tokio::sync::Mutex; - -use super::conn_map::*; -use super::interface::*; -use crate::error::*; -use crate::vnet::chunk::Chunk; -use crate::vnet::conn::{ConnObserver, UdpConn}; -use crate::vnet::router::*; -use crate::{conn, ifaces, Conn}; - -pub(crate) const LO0_STR: &str = "lo0"; -pub(crate) const UDP_STR: &str = "udp"; - -lazy_static! 
{ - pub static ref MAC_ADDR_COUNTER: AtomicU64 = AtomicU64::new(0xBEEFED910200); -} - -pub(crate) type HardwareAddr = Vec; - -pub(crate) fn new_mac_address() -> HardwareAddr { - let b = MAC_ADDR_COUNTER - .fetch_add(1, Ordering::SeqCst) - .to_be_bytes(); - b[2..].to_vec() -} - -#[derive(Default)] -pub(crate) struct VNetInternal { - pub(crate) interfaces: Vec, // read-only - pub(crate) router: Option>>, // read-only - pub(crate) udp_conns: UdpConnMap, // read-only -} - -impl VNetInternal { - fn get_interface(&self, ifc_name: &str) -> Option<&Interface> { - self.interfaces.iter().find(|ifc| ifc.name == ifc_name) - } -} - -#[async_trait] -impl ConnObserver for VNetInternal { - async fn write(&self, c: Box) -> Result<()> { - if c.network() == UDP_STR && c.get_destination_ip().is_loopback() { - if let Some(conn) = self.udp_conns.find(&c.destination_addr()).await { - let read_ch_tx = conn.get_inbound_ch(); - let ch_tx = read_ch_tx.lock().await; - if let Some(tx) = &*ch_tx { - let _ = tx.send(c).await; - } - } - return Ok(()); - } - - if let Some(r) = &self.router { - let p = r.lock().await; - p.push(c).await; - Ok(()) - } else { - Err(Error::ErrNoRouterLinked) - } - } - - async fn on_closed(&self, addr: SocketAddr) { - let _ = self.udp_conns.delete(&addr).await; - } - - // This method determines the srcIP based on the dstIP when locIP - // is any IP address ("0.0.0.0" or "::"). If locIP is a non-any addr, - // this method simply returns locIP. - // caller must hold the mutex - fn determine_source_ip(&self, loc_ip: IpAddr, dst_ip: IpAddr) -> Option { - if !loc_ip.is_unspecified() { - return Some(loc_ip); - } - - if dst_ip.is_loopback() { - let src_ip = if let Ok(src_ip) = IpAddr::from_str("127.0.0.1") { - Some(src_ip) - } else { - None - }; - return src_ip; - } - - if let Some(ifc) = self.get_interface("eth0") { - for ipnet in ifc.addrs() { - if (ipnet.addr().is_ipv4() && loc_ip.is_ipv4()) - || (ipnet.addr().is_ipv6() && loc_ip.is_ipv6()) - { - return Some(ipnet.addr()); - } - } - } - - None - } -} - -#[derive(Default)] -pub struct VNet { - pub(crate) interfaces: Vec, // read-only - pub(crate) static_ips: Vec, // read-only - pub(crate) vi: Arc>, -} - -#[async_trait] -impl Nic for VNet { - async fn get_interface(&self, ifc_name: &str) -> Option { - for ifc in &self.interfaces { - if ifc.name == ifc_name { - return Some(ifc.clone()); - } - } - None - } - - async fn add_addrs_to_interface(&mut self, ifc_name: &str, addrs: &[IpNet]) -> Result<()> { - { - let mut vi = self.vi.lock().await; - for ifc in &mut vi.interfaces { - if ifc.name == ifc_name { - for addr in addrs { - ifc.add_addr(*addr); - } - break; - } - } - } - - for ifc in &mut self.interfaces { - if ifc.name == ifc_name { - for addr in addrs { - ifc.add_addr(*addr); - } - return Ok(()); - } - } - - Err(Error::ErrNotFound) - } - - async fn set_router(&self, r: Arc>) -> Result<()> { - let mut vi = self.vi.lock().await; - vi.router = Some(r); - - Ok(()) - } - - async fn on_inbound_chunk(&self, c: Box) { - if c.network() == UDP_STR { - let vi = self.vi.lock().await; - if let Some(conn) = vi.udp_conns.find(&c.destination_addr()).await { - let read_ch_tx = conn.get_inbound_ch(); - let ch_tx = read_ch_tx.lock().await; - if let Some(tx) = &*ch_tx { - let _ = tx.send(c).await; - } - } - } - } - - async fn get_static_ips(&self) -> Vec { - self.static_ips.clone() - } -} - -impl VNet { - pub(crate) fn get_interfaces(&self) -> &[Interface] { - &self.interfaces - } - - // caller must hold the mutex - pub(crate) fn get_all_ipaddrs(&self, ipv6: bool) -> 
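new_mac_address above mints sequential hardware addresses by bumping a 64-bit counter and keeping its low six bytes. The same trick in isolation, using std's AtomicU64 (the crate uses portable_atomic, assumed equivalent for this purpose):

use std::sync::atomic::{AtomicU64, Ordering};

static MAC_COUNTER: AtomicU64 = AtomicU64::new(0xBEEF_ED91_0200);

/// Returns the next 6-byte MAC: a u64 is 8 bytes big-endian, so the address
/// is the low-order 6 bytes of the counter value.
fn next_mac() -> [u8; 6] {
    let b = MAC_COUNTER.fetch_add(1, Ordering::SeqCst).to_be_bytes();
    [b[2], b[3], b[4], b[5], b[6], b[7]]
}

fn main() {
    println!("{:02x?}", next_mac()); // first call prints [be, ef, ed, 91, 02, 00]
}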
Vec { - let mut ips = vec![]; - - for ifc in &self.interfaces { - for ipnet in ifc.addrs() { - if (ipv6 && ipnet.addr().is_ipv6()) || (!ipv6 && ipnet.addr().is_ipv4()) { - ips.push(ipnet.addr()); - } - } - } - - ips - } - - // caller must hold the mutex - pub(crate) fn has_ipaddr(&self, ip: IpAddr) -> bool { - for ifc in &self.interfaces { - for ipnet in ifc.addrs() { - let loc_ip = ipnet.addr(); - - match ip.to_string().as_str() { - "0.0.0.0" => { - if loc_ip.is_ipv4() { - return true; - } - } - "::" => { - if loc_ip.is_ipv6() { - return true; - } - } - _ => { - if loc_ip == ip { - return true; - } - } - } - } - } - - false - } - - // caller must hold the mutex - pub(crate) async fn allocate_local_addr(&self, ip: IpAddr, port: u16) -> Result<()> { - // gather local IP addresses to bind - let mut ips = vec![]; - if ip.is_unspecified() { - ips = self.get_all_ipaddrs(ip.is_ipv6()); - } else if self.has_ipaddr(ip) { - ips.push(ip); - } - - if ips.is_empty() { - return Err(Error::ErrBindFailed); - } - - // check if all these transport addresses are not in use - for ip2 in ips { - let addr = SocketAddr::new(ip2, port); - let vi = self.vi.lock().await; - if vi.udp_conns.find(&addr).await.is_some() { - return Err(Error::ErrAddressAlreadyInUse); - } - } - - Ok(()) - } - - // caller must hold the mutex - pub(crate) async fn assign_port(&self, ip: IpAddr, start: u16, end: u16) -> Result { - // choose randomly from the range between start and end (inclusive) - if end < start { - return Err(Error::ErrEndPortLessThanStart); - } - - let space = end + 1 - start; - let offset = rand::random::() % space; - for i in 0..space { - let port = ((offset + i) % space) + start; - let result = self.allocate_local_addr(ip, port).await; - if result.is_ok() { - return Ok(port); - } - } - - Err(Error::ErrPortSpaceExhausted) - } - - pub(crate) async fn resolve_addr(&self, use_ipv4: bool, address: &str) -> Result { - let v: Vec<&str> = address.splitn(2, ':').collect(); - if v.len() != 2 { - return Err(Error::ErrAddrNotUdpAddr); - } - let (host, port) = (v[0], v[1]); - - // Check if host is a domain name - let ip: IpAddr = match host.parse() { - Ok(ip) => ip, - Err(_) => { - let host = host.to_lowercase(); - if host == "localhost" { - if use_ipv4 { - Ipv4Addr::new(127, 0, 0, 1).into() - } else { - Ipv6Addr::from_str("::1")?.into() - } - } else { - // host is a domain name. resolve IP address by the name - let vi = self.vi.lock().await; - if let Some(router) = &vi.router { - let r = router.lock().await; - let resolver = r.resolver.lock().await; - if let Some(ip) = resolver.lookup(host).await { - ip - } else { - return Err(Error::ErrNotFound); - } - } else { - return Err(Error::ErrNoRouterLinked); - } - } - } - }; - - let port: u16 = port.parse()?; - - let remote_addr = SocketAddr::new(ip, port); - if (use_ipv4 && remote_addr.is_ipv4()) || (!use_ipv4 && remote_addr.is_ipv6()) { - Ok(remote_addr) - } else { - Err(Error::Other(format!( - "No available {} IP address found!", - if use_ipv4 { "ipv4" } else { "ipv6" }, - ))) - } - } - - // caller must hold the mutex - pub(crate) async fn bind( - &self, - mut local_addr: SocketAddr, - ) -> Result> { - // validate address. do we have that address? 
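assign_port above draws a random offset into the inclusive [start, end] range and probes each candidate exactly once, so an exhausted range fails cleanly instead of spinning. The scan on its own, with an is_free closure standing in for allocate_local_addr (the rand crate, already a dependency of the deleted code, is assumed):

/// Pick a free port in start..=end, beginning the scan at a random offset,
/// as the deleted VNet::assign_port does.
fn assign_port(start: u16, end: u16, is_free: impl Fn(u16) -> bool) -> Option<u16> {
    if end < start {
        return None;
    }
    // Widen to u32 so a full 0..=65535 range cannot overflow the arithmetic.
    let space = u32::from(end - start) + 1;
    let offset = rand::random::<u32>() % space;
    for i in 0..space {
        let port = start + ((offset + i) % space) as u16;
        if is_free(port) {
            return Some(port);
        }
    }
    None // every port in the range is already taken
}

bind below narrows this to the 5000..=5999 range whenever a caller asks for port 0.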
- if !self.has_ipaddr(local_addr.ip()) { - return Err(Error::ErrCantAssignRequestedAddr); - } - - if local_addr.port() == 0 { - // choose randomly from the range between 5000 and 5999 - local_addr.set_port(self.assign_port(local_addr.ip(), 5000, 5999).await?); - } else { - let vi = self.vi.lock().await; - if vi.udp_conns.find(&local_addr).await.is_some() { - return Err(Error::ErrAddressAlreadyInUse); - } - } - - let v = Arc::clone(&self.vi) as Arc>; - let conn = Arc::new(UdpConn::new(local_addr, None, v)); - - { - let vi = self.vi.lock().await; - vi.udp_conns.insert(Arc::clone(&conn)).await?; - } - - Ok(conn) - } - - pub(crate) async fn dail( - &self, - use_ipv4: bool, - remote_addr: &str, - ) -> Result> { - let rem_addr = self.resolve_addr(use_ipv4, remote_addr).await?; - - // Determine source address - let src_ip = { - let vi = self.vi.lock().await; - let any_ip = if use_ipv4 { - Ipv4Addr::new(0, 0, 0, 0).into() - } else { - Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0).into() - }; - if let Some(src_ip) = vi.determine_source_ip(any_ip, rem_addr.ip()) { - src_ip - } else { - any_ip - } - }; - - let loc_addr = SocketAddr::new(src_ip, 0); - - let conn = self.bind(loc_addr).await?; - conn.connect(rem_addr).await?; - - Ok(conn) - } -} - -// NetConfig is a bag of configuration parameters passed to NewNet(). -#[derive(Debug, Default)] -pub struct NetConfig { - // static_ips is an array of static IP addresses to be assigned for this Net. - // If no static IP address is given, the router will automatically assign - // an IP address. - pub static_ips: Vec, - - // static_ip is deprecated. Use static_ips. - pub static_ip: String, -} - -// Net represents a local network stack equivalent to a set of layers from NIC -// up to the transport (UDP / TCP) layer. -pub enum Net { - VNet(Arc>), - Ifs(Vec), -} - -impl Net { - // NewNet creates an instance of Net. - // If config is nil, the virtual network is disabled. (uses corresponding - // net.Xxxx() operations. - // By design, it always have lo0 and eth0 interfaces. - // The lo0 has the address 127.0.0.1 assigned by default. - // IP address for eth0 will be assigned when this Net is added to a router. 
- pub fn new(config: Option) -> Self { - if let Some(config) = config { - let mut lo0 = Interface::new(LO0_STR.to_owned(), vec![]); - if let Ok(ipnet) = Interface::convert( - SocketAddr::new(Ipv4Addr::new(127, 0, 0, 1).into(), 0), - Some(SocketAddr::new(Ipv4Addr::new(255, 0, 0, 0).into(), 0)), - ) { - lo0.add_addr(ipnet); - } - - let eth0 = Interface::new("eth0".to_owned(), vec![]); - - let mut static_ips = vec![]; - for ip_str in &config.static_ips { - if let Ok(ip) = IpAddr::from_str(ip_str) { - static_ips.push(ip); - } - } - if !config.static_ip.is_empty() { - if let Ok(ip) = IpAddr::from_str(&config.static_ip) { - static_ips.push(ip); - } - } - - let vnet = VNet { - interfaces: vec![lo0.clone(), eth0.clone()], - static_ips, - vi: Arc::new(Mutex::new(VNetInternal { - interfaces: vec![lo0, eth0], - router: None, - udp_conns: UdpConnMap::new(), - })), - }; - - Net::VNet(Arc::new(Mutex::new(vnet))) - } else { - let interfaces = match ifaces::ifaces() { - Ok(ifs) => ifs, - Err(_) => vec![], - }; - - let mut m: HashMap> = HashMap::new(); - for iface in interfaces { - if let Some(addrs) = m.get_mut(&iface.name) { - if let Some(addr) = iface.addr { - if let Ok(inet) = Interface::convert(addr, iface.mask) { - addrs.push(inet); - } - } - } else if let Some(addr) = iface.addr { - if let Ok(inet) = Interface::convert(addr, iface.mask) { - m.insert(iface.name, vec![inet]); - } - } - } - - let mut ifs = vec![]; - for (name, addrs) in m.into_iter() { - ifs.push(Interface::new(name, addrs)); - } - - Net::Ifs(ifs) - } - } - - // Interfaces returns a list of the system's network interfaces. - pub async fn get_interfaces(&self) -> Vec { - match self { - Net::VNet(vnet) => { - let net = vnet.lock().await; - net.get_interfaces().to_vec() - } - Net::Ifs(ifs) => ifs.clone(), - } - } - - // InterfaceByName returns the interface specified by name. - pub async fn get_interface(&self, ifc_name: &str) -> Option { - match self { - Net::VNet(vnet) => { - let net = vnet.lock().await; - net.get_interface(ifc_name).await - } - Net::Ifs(ifs) => { - for ifc in ifs { - if ifc.name == ifc_name { - return Some(ifc.clone()); - } - } - None - } - } - } - - // IsVirtual tests if the virtual network is enabled. 
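As the constructor above shows, passing None yields the native stack over the OS interfaces, while passing a NetConfig switches on the virtual stack with lo0 (127.0.0.1) and an as-yet unaddressed eth0. A usage sketch in the style of the deleted tests; the webrtc_util module path and tokio runtime setup are assumed:

use webrtc_util::vnet::net::{Net, NetConfig};

#[tokio::main]
async fn main() {
    // Native stack: real OS interfaces and real UDP sockets.
    let native = Net::new(None);
    assert!(!native.is_virtual());

    // Virtual stack: lo0 is preconfigured; eth0 gets an address once a Router adopts it.
    let vnet = Net::new(Some(NetConfig {
        static_ips: vec!["1.2.3.4".to_owned()],
        ..Default::default()
    }));
    assert!(vnet.is_virtual());
    assert_eq!(vnet.get_interfaces().await.len(), 2); // lo0 + eth0
}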
- pub fn is_virtual(&self) -> bool { - match self { - Net::VNet(_) => true, - Net::Ifs(_) => false, - } - } - - pub async fn resolve_addr(&self, use_ipv4: bool, address: &str) -> Result { - match self { - Net::VNet(vnet) => { - let net = vnet.lock().await; - net.resolve_addr(use_ipv4, address).await - } - Net::Ifs(_) => Ok(conn::lookup_host(use_ipv4, address).await?), - } - } - - pub async fn bind(&self, addr: SocketAddr) -> Result> { - match self { - Net::VNet(vnet) => { - let net = vnet.lock().await; - net.bind(addr).await - } - Net::Ifs(_) => Ok(Arc::new(UdpSocket::bind(addr).await?)), - } - } - - pub async fn dail( - &self, - use_ipv4: bool, - remote_addr: &str, - ) -> Result> { - match self { - Net::VNet(vnet) => { - let net = vnet.lock().await; - net.dail(use_ipv4, remote_addr).await - } - Net::Ifs(_) => { - let any_ip = if use_ipv4 { - Ipv4Addr::new(0, 0, 0, 0).into() - } else { - Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0).into() - }; - let local_addr = SocketAddr::new(any_ip, 0); - - let conn = UdpSocket::bind(local_addr).await?; - conn.connect(remote_addr).await?; - - Ok(Arc::new(conn)) - } - } - } - - pub fn get_nic(&self) -> Result>> { - match self { - Net::VNet(vnet) => Ok(Arc::clone(vnet) as Arc>), - Net::Ifs(_) => Err(Error::ErrVnetDisabled), - } - } -} diff --git a/util/src/vnet/net/net_test.rs b/util/src/vnet/net/net_test.rs deleted file mode 100644 index d24002290..000000000 --- a/util/src/vnet/net/net_test.rs +++ /dev/null @@ -1,903 +0,0 @@ -use tokio::sync::{broadcast, mpsc}; - -use super::*; -use crate::vnet::chunk::ChunkUdp; - -const DEMO_IP: &str = "1.2.3.4"; - -#[derive(Default)] -struct DummyObserver; - -#[async_trait] -impl ConnObserver for DummyObserver { - async fn write(&self, _c: Box) -> Result<()> { - Ok(()) - } - - async fn on_closed(&self, _addr: SocketAddr) {} - - fn determine_source_ip(&self, loc_ip: IpAddr, _dst_ip: IpAddr) -> Option { - Some(loc_ip) - } -} - -#[tokio::test] -async fn test_net_native_interfaces() -> Result<()> { - let nw = Net::new(None); - assert!(!nw.is_virtual(), "should be false"); - - let interfaces = nw.get_interfaces().await; - log::debug!("interfaces: {:?}", interfaces); - for ifc in interfaces { - let addrs = ifc.addrs(); - for addr in addrs { - log::debug!("{}", addr) - } - } - - Ok(()) -} - -#[tokio::test] -async fn test_net_native_resolve_addr() -> Result<()> { - let nw = Net::new(None); - assert!(!nw.is_virtual(), "should be false"); - - let udp_addr = nw.resolve_addr(true, "localhost:1234").await?; - assert_eq!(udp_addr.ip().to_string(), "127.0.0.1", "should match"); - assert_eq!(udp_addr.port(), 1234, "should match"); - - let result = nw.resolve_addr(false, "127.0.0.1:1234").await; - assert!(result.is_err(), "should not match"); - - Ok(()) -} - -#[tokio::test] -async fn test_net_native_bind() -> Result<()> { - let nw = Net::new(None); - assert!(!nw.is_virtual(), "should be false"); - - let conn = nw.bind(SocketAddr::from_str("127.0.0.1:0")?).await?; - let laddr = conn.local_addr()?; - assert_eq!( - laddr.ip().to_string(), - "127.0.0.1", - "local_addr ip should match 127.0.0.1" - ); - log::debug!("laddr: {}", laddr); - - Ok(()) -} - -#[tokio::test] -async fn test_net_native_dail() -> Result<()> { - let nw = Net::new(None); - assert!(!nw.is_virtual(), "should be false"); - - let conn = nw.dail(true, "127.0.0.1:1234").await?; - let laddr = conn.local_addr()?; - assert_eq!( - laddr.ip().to_string(), - "127.0.0.1", - "local_addr should match 127.0.0.1" - ); - assert_ne!(laddr.port(), 1234, "local_addr port should match 1234"); - 
log::debug!("laddr: {}", laddr); - - Ok(()) -} - -#[tokio::test] -async fn test_net_native_loopback() -> Result<()> { - let nw = Net::new(None); - assert!(!nw.is_virtual(), "should be false"); - - let conn = nw.bind(SocketAddr::from_str("127.0.0.1:0")?).await?; - let laddr = conn.local_addr()?; - - let msg = "PING!"; - let n = conn.send_to(msg.as_bytes(), laddr).await?; - assert_eq!(n, msg.len(), "should match msg size {}", msg.len()); - - let mut buf = vec![0u8; 1000]; - let (n, raddr) = conn.recv_from(&mut buf).await?; - assert_eq!(n, msg.len(), "should match msg size {}", msg.len()); - assert_eq!(&buf[..n], msg.as_bytes(), "should match msg content {msg}"); - assert_eq!(laddr, raddr, "should match addr {laddr}"); - - Ok(()) -} - -#[tokio::test] -async fn test_net_native_unexpected_operations() -> Result<()> { - let mut lo_name = String::new(); - let ifcs = ifaces::ifaces()?; - for ifc in &ifcs { - if let Some(addr) = ifc.addr { - if addr.ip().is_loopback() { - lo_name.clone_from(&ifc.name); - break; - } - } - } - - let nw = Net::new(None); - assert!(!nw.is_virtual(), "should be false"); - - if !lo_name.is_empty() { - if let Some(ifc) = nw.get_interface(&lo_name).await { - assert_eq!(ifc.name, lo_name, "should match ifc name"); - } else { - panic!("should succeed"); - } - } - - let result = nw.get_interface("foo0").await; - assert!(result.is_none(), "should be none"); - - //let ips = nw.get_static_ips(); - //assert!(ips.is_empty(), "should empty"); - - Ok(()) -} - -#[tokio::test] -async fn test_net_virtual_interfaces() -> Result<()> { - let nw = Net::new(Some(NetConfig::default())); - assert!(nw.is_virtual(), "should be true"); - - let interfaces = nw.get_interfaces().await; - assert_eq!(2, interfaces.len(), "should be one interface"); - - for ifc in interfaces { - match ifc.name.as_str() { - LO0_STR => { - let addrs = ifc.addrs(); - assert_eq!(addrs.len(), 1, "should be one address"); - } - "eth0" => { - let addrs = ifc.addrs(); - assert!(addrs.is_empty(), "should empty"); - } - _ => { - panic!("unknown interface: {}", ifc.name); - } - } - } - - Ok(()) -} - -#[tokio::test] -async fn test_net_virtual_interface_by_name() -> Result<()> { - let nw = Net::new(Some(NetConfig::default())); - assert!(nw.is_virtual(), "should be true"); - - let interfaces = nw.get_interfaces().await; - assert_eq!(2, interfaces.len(), "should be one interface"); - - let nic = nw.get_nic()?; - let nic = nic.lock().await; - if let Some(ifc) = nic.get_interface(LO0_STR).await { - assert_eq!(ifc.name.as_str(), LO0_STR, "should match"); - let addrs = ifc.addrs(); - assert_eq!(addrs.len(), 1, "should be one address"); - } else { - panic!("should got ifc"); - } - - if let Some(ifc) = nic.get_interface("eth0").await { - assert_eq!(ifc.name.as_str(), "eth0", "should match"); - let addrs = ifc.addrs(); - assert!(addrs.is_empty(), "should empty"); - } else { - panic!("should got ifc"); - } - - let result = nic.get_interface("foo0").await; - assert!(result.is_none(), "should fail"); - - Ok(()) -} - -#[tokio::test] -async fn test_net_virtual_has_ipaddr() -> Result<()> { - let nw = Net::new(Some(NetConfig::default())); - assert!(nw.is_virtual(), "should be true"); - - let interfaces = nw.get_interfaces().await; - assert_eq!(interfaces.len(), 2, "should be one interface"); - - { - let nic = nw.get_nic()?; - let mut nic = nic.lock().await; - let ipnet = IpNet::from_str("10.1.2.3/24")?; - nic.add_addrs_to_interface("eth0", &[ipnet]).await?; - - if let Some(ifc) = nic.get_interface("eth0").await { - let addrs = ifc.addrs(); - 
assert!(!addrs.is_empty(), "should not empty"); - } - } - - if let Net::VNet(vnet) = &nw { - let net = vnet.lock().await; - let ip = Ipv4Addr::from_str("127.0.0.1")?.into(); - assert!(net.has_ipaddr(ip), "the IP addr {ip} should exist"); - - let ip = Ipv4Addr::from_str("10.1.2.3")?.into(); - assert!(net.has_ipaddr(ip), "the IP addr {ip} should exist"); - - let ip = Ipv4Addr::from_str("192.168.1.1")?.into(); - assert!(!net.has_ipaddr(ip), "the IP addr {ip} should exist"); - } - Ok(()) -} - -#[tokio::test] -async fn test_net_virtual_get_all_ipaddrs() -> Result<()> { - let nw = Net::new(Some(NetConfig::default())); - assert!(nw.is_virtual(), "should be true"); - - let interfaces = nw.get_interfaces().await; - assert_eq!(interfaces.len(), 2, "should be one interface"); - - { - let nic = nw.get_nic()?; - let mut nic = nic.lock().await; - let ipnet = IpNet::from_str("10.1.2.3/24")?; - nic.add_addrs_to_interface("eth0", &[ipnet]).await?; - - if let Some(ifc) = nic.get_interface("eth0").await { - let addrs = ifc.addrs(); - assert!(!addrs.is_empty(), "should not empty"); - } - } - - if let Net::VNet(vnet) = &nw { - let net = vnet.lock().await; - let ips = net.get_all_ipaddrs(false); - assert_eq!(ips.len(), 2, "ips should match size {} == 2", ips.len()) - } - - Ok(()) -} - -#[tokio::test] -async fn test_net_virtual_assign_port() -> Result<()> { - let mut nw = Net::new(Some(NetConfig::default())); - assert!(nw.is_virtual(), "should be true"); - - let addr = DEMO_IP; - let start = 1000u16; - let end = 1002u16; - let space = end + 1 - start; - - let interfaces = nw.get_interfaces().await; - assert_eq!(interfaces.len(), 2, "should be one interface"); - - { - let nic = nw.get_nic()?; - let mut nic = nic.lock().await; - let ipnet = IpNet::from_str(&format!("{addr}/24"))?; - nic.add_addrs_to_interface("eth0", &[ipnet]).await?; - } - - if let Net::VNet(vnet) = &mut nw { - let vnet = vnet.lock().await; - // attempt to assign port with start > end should fail - let ip = IpAddr::from_str(addr)?; - let result = vnet.assign_port(ip, 3000, 2999).await; - assert!(result.is_err(), "assign_port should fail"); - - for i in 0..space { - let port = vnet.assign_port(ip, start, end).await?; - log::debug!("{} got port: {}", i, port); - - let obs: Arc> = - Arc::new(Mutex::new(DummyObserver)); - - let conn = Arc::new(UdpConn::new(SocketAddr::new(ip, port), None, obs)); - - let vi = vnet.vi.lock().await; - let _ = vi.udp_conns.insert(conn).await; - } - - { - let vi = vnet.vi.lock().await; - assert_eq!( - vi.udp_conns.len().await, - space as usize, - "udp_conns should match" - ); - } - - // attempt to assign again should fail - let result = vnet.assign_port(ip, start, end).await; - assert!(result.is_err(), "assign_port should fail"); - } - - Ok(()) -} - -#[tokio::test] -async fn test_net_virtual_determine_source_ip() -> Result<()> { - let mut nw = Net::new(Some(NetConfig::default())); - assert!(nw.is_virtual(), "should be true"); - - let interfaces = nw.get_interfaces().await; - assert_eq!(interfaces.len(), 2, "should be one interface"); - - { - let nic = nw.get_nic()?; - let mut nic = nic.lock().await; - let ipnet = IpNet::from_str(&format!("{DEMO_IP}/24"))?; - nic.add_addrs_to_interface("eth0", &[ipnet]).await?; - } - - // Any IP turned into non-loopback IP - let any_ip = IpAddr::from_str("0.0.0.0")?; - let dst_ip = IpAddr::from_str("27.1.7.135")?; - if let Net::VNet(vnet) = &mut nw { - let vnet = vnet.lock().await; - let vi = vnet.vi.lock().await; - let src_ip = vi.determine_source_ip(any_ip, dst_ip); - 
log::debug!("any_ip: {} => {:?}", any_ip, src_ip); - assert!(src_ip.is_some(), "shouldn't be none"); - if let Some(src_ip) = src_ip { - assert_eq!(src_ip.to_string().as_str(), DEMO_IP, "use non-loopback IP"); - } - } - - // Any IP turned into loopback IP - let any_ip = IpAddr::from_str("0.0.0.0")?; - let dst_ip = IpAddr::from_str("127.0.0.2")?; - if let Net::VNet(vnet) = &mut nw { - let vnet = vnet.lock().await; - let vi = vnet.vi.lock().await; - let src_ip = vi.determine_source_ip(any_ip, dst_ip); - log::debug!("any_ip: {} => {:?}", any_ip, src_ip); - assert!(src_ip.is_some(), "shouldn't be none"); - if let Some(src_ip) = src_ip { - assert_eq!(src_ip.to_string().as_str(), "127.0.0.1", "use loopback IP"); - } - } - - // Non any IP won't change - let any_ip = IpAddr::from_str(DEMO_IP)?; - let dst_ip = IpAddr::from_str("127.0.0.2")?; - if let Net::VNet(vnet) = &mut nw { - let vnet = vnet.lock().await; - let vi = vnet.vi.lock().await; - let src_ip = vi.determine_source_ip(any_ip, dst_ip); - log::debug!("any_ip: {} => {:?}", any_ip, src_ip); - assert!(src_ip.is_some(), "shouldn't be none"); - if let Some(src_ip) = src_ip { - assert_eq!(src_ip, any_ip, "IP change"); - } - } - - Ok(()) -} - -#[tokio::test] -async fn test_net_virtual_resolve_addr() -> Result<()> { - let nw = Net::new(Some(NetConfig::default())); - assert!(nw.is_virtual(), "should be true"); - - let udp_addr = nw.resolve_addr(true, "localhost:1234").await?; - assert_eq!( - udp_addr.ip().to_string().as_str(), - "127.0.0.1", - "udp addr {} should match 127.0.0.1", - udp_addr.ip(), - ); - assert_eq!( - udp_addr.port(), - 1234, - "udp addr {} should match 1234", - udp_addr.port() - ); - - Ok(()) -} - -#[tokio::test] -async fn test_net_virtual_loopback1() -> Result<()> { - let nw = Net::new(Some(NetConfig::default())); - assert!(nw.is_virtual(), "should be true"); - - let conn = nw.bind(SocketAddr::from_str("127.0.0.1:0")?).await?; - let laddr = conn.local_addr()?; - - let msg = "PING!"; - let n = conn.send_to(msg.as_bytes(), laddr).await?; - assert_eq!(n, msg.len(), "should match msg size {}", msg.len()); - - let mut buf = vec![0u8; 1000]; - let (n, raddr) = conn.recv_from(&mut buf).await?; - assert_eq!(n, msg.len(), "should match msg size {}", msg.len()); - assert_eq!(&buf[..n], msg.as_bytes(), "should match msg content {msg}"); - assert_eq!(laddr, raddr, "should match addr {laddr}"); - - Ok(()) -} - -#[tokio::test] -async fn test_net_virtual_bind_specific_port() -> Result<()> { - let nw = Net::new(Some(NetConfig::default())); - assert!(nw.is_virtual(), "should be true"); - - let conn = nw.bind(SocketAddr::from_str("127.0.0.1:50916")?).await?; - let laddr = conn.local_addr()?; - assert_eq!( - laddr.ip().to_string().as_str(), - "127.0.0.1", - "{} should match 127.0.0.1", - laddr.ip() - ); - assert_eq!(laddr.port(), 50916, "{} should match 50916", laddr.port()); - - Ok(()) -} - -#[tokio::test] -async fn test_net_virtual_dail_lo0() -> Result<()> { - let nw = Net::new(Some(NetConfig::default())); - assert!(nw.is_virtual(), "should be true"); - - let conn = nw.dail(true, "127.0.0.1:1234").await?; - let laddr = conn.local_addr()?; - assert_eq!( - laddr.ip().to_string().as_str(), - "127.0.0.1", - "{} should match 127.0.0.1", - laddr.ip() - ); - assert_ne!(laddr.port(), 1234, "{} should != 1234", laddr.port()); - - Ok(()) -} - -#[tokio::test] -async fn test_net_virtual_dail_eth0() -> Result<()> { - let wan = Arc::new(Mutex::new(Router::new(RouterConfig { - cidr: "1.2.3.0/24".to_string(), - ..Default::default() - })?)); - - let nw = 
Net::new(Some(NetConfig::default())); - - { - let nic = nw.get_nic()?; - - let mut w = wan.lock().await; - w.add_net(Arc::clone(&nic)).await?; - - let n = nic.lock().await; - n.set_router(Arc::clone(&wan)).await?; - }; - - let conn = nw.dail(true, "27.3.4.5:1234").await?; - let laddr = conn.local_addr()?; - assert_eq!( - laddr.ip().to_string().as_str(), - "1.2.3.1", - "{} should match 1.2.3.1", - laddr.ip() - ); - assert!(laddr.port() != 0, "{} should != 0", laddr.port()); - - Ok(()) -} - -#[tokio::test] -async fn test_net_virtual_resolver() -> Result<()> { - let wan = Arc::new(Mutex::new(Router::new(RouterConfig { - cidr: "1.2.3.0/24".to_string(), - ..Default::default() - })?)); - - let nw = Net::new(Some(NetConfig::default())); - - let remote_addr = nw.resolve_addr(true, "127.0.0.1:1234").await?; - assert_eq!(remote_addr.to_string(), "127.0.0.1:1234", "should match"); - - let result = nw.resolve_addr(false, "127.0.0.1:1234").await; - assert!(result.is_err(), "should not match"); - - { - let nic = nw.get_nic()?; - - let mut w = wan.lock().await; - w.add_net(Arc::clone(&nic)).await?; - w.add_host("test.webrtc.rs".to_owned(), "30.31.32.33".to_owned()) - .await?; - - let n = nic.lock().await; - n.set_router(Arc::clone(&wan)).await?; - } - - let (done_tx, mut done_rx) = mpsc::channel::<()>(1); - tokio::spawn(async move { - let (conn, raddr) = { - let raddr = nw.resolve_addr(true, "test.webrtc.rs:1234").await?; - (nw.dail(true, "test.webrtc.rs:1234").await?, raddr) - }; - - let laddr = conn.local_addr()?; - assert_eq!( - laddr.ip().to_string().as_str(), - "1.2.3.1", - "{} should match 1.2.3.1", - laddr.ip() - ); - - assert_eq!( - raddr.to_string(), - "30.31.32.33:1234", - "{raddr} should match 30.31.32.33:1234" - ); - - drop(done_tx); - - Result::<()>::Ok(()) - }); - - let _ = done_rx.recv().await; - - Ok(()) -} - -#[tokio::test] -async fn test_net_virtual_loopback2() -> Result<()> { - let nw = Net::new(Some(NetConfig::default())); - - let conn = nw.bind(SocketAddr::from_str("127.0.0.1:50916")?).await?; - let laddr = conn.local_addr()?; - assert_eq!( - laddr.to_string().as_str(), - "127.0.0.1:50916", - "{laddr} should match 127.0.0.1:50916" - ); - - let mut c = ChunkUdp::new( - SocketAddr::from_str("127.0.0.1:4000")?, - SocketAddr::from_str("127.0.0.1:50916")?, - ); - c.user_data = b"Hello!".to_vec(); - - let (recv_ch_tx, mut recv_ch_rx) = mpsc::channel(1); - let (done_ch_tx, mut done_ch_rx) = mpsc::channel::(1); - let (close_ch_tx, mut close_ch_rx) = mpsc::channel::(1); - let conn_rx = Arc::clone(&conn); - - tokio::spawn(async move { - let mut buf = vec![0u8; 1500]; - loop { - tokio::select! 
{ - result = conn_rx.recv_from(&mut buf) => { - let (n, addr) = match result { - Ok((n, addr)) => (n, addr), - Err(err) => { - log::debug!("ReadFrom returned: {}", err); - break; - } - }; - - assert_eq!(n, 6, "{n} should match 6"); - assert_eq!(addr.to_string(), "127.0.0.1:4000", "addr should match"); - assert_eq!(&buf[..n], b"Hello!", "buf should match"); - - let _ = recv_ch_tx.send(true).await; - } - _ = close_ch_rx.recv() => { - break; - } - } - } - - drop(done_ch_tx); - }); - - if let Net::VNet(vnet) = &nw { - let vnet = vnet.lock().await; - vnet.on_inbound_chunk(Box::new(c)).await; - } else { - panic!("must be virtual net"); - } - - let _ = recv_ch_rx.recv().await; - drop(close_ch_tx); - - let _ = done_ch_rx.recv().await; - - Ok(()) -} - -async fn get_ipaddr(nic: &Arc>) -> Result { - let n = nic.lock().await; - let eth0 = n.get_interface("eth0").await.ok_or(Error::ErrNoInterface)?; - let addrs = eth0.addrs(); - if addrs.is_empty() { - Err(Error::ErrNoAddressAssigned) - } else { - Ok(addrs[0].addr()) - } -} - -//use std::io::Write; - -#[tokio::test] -async fn test_net_virtual_end2end() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - let wan = Arc::new(Mutex::new(Router::new(RouterConfig { - cidr: "1.2.3.0/24".to_string(), - ..Default::default() - })?)); - - let net1 = Net::new(Some(NetConfig::default())); - let ip1 = { - let nic = net1.get_nic()?; - - let mut w = wan.lock().await; - w.add_net(Arc::clone(&nic)).await?; - - { - let n = nic.lock().await; - n.set_router(Arc::clone(&wan)).await?; - } - - get_ipaddr(&nic).await? - }; - - let net2 = Net::new(Some(NetConfig::default())); - let ip2 = { - let nic = net2.get_nic()?; - - let mut w = wan.lock().await; - w.add_net(Arc::clone(&nic)).await?; - - { - let n = nic.lock().await; - n.set_router(Arc::clone(&wan)).await?; - } - - get_ipaddr(&nic).await? - }; - - let conn1 = net1.bind(SocketAddr::new(ip1, 1234)).await?; - let conn2 = net2.bind(SocketAddr::new(ip2, 5678)).await?; - - { - let mut w = wan.lock().await; - w.start().await?; - } - - let (close_ch_tx, mut close_ch_rx1) = broadcast::channel::(1); - let (done_ch_tx, mut done_ch_rx) = mpsc::channel::(1); - let (conn1_recv_ch_tx, mut conn1_recv_ch_rx) = mpsc::channel(1); - let conn1_rx = Arc::clone(&conn1); - let conn2_tr = Arc::clone(&conn2); - let mut close_ch_rx2 = close_ch_tx.subscribe(); - - // conn1 - tokio::spawn(async move { - let mut buf = vec![0u8; 1500]; - loop { - log::debug!("conn1: wait for a message.."); - tokio::select! { - result = conn1_rx.recv_from(&mut buf) =>{ - let n = match result{ - Ok((n, _)) => n, - Err(err) => { - log::debug!("ReadFrom returned: {}", err); - break; - } - }; - - log::debug!("conn1 received {:?}", &buf[..n]); - let _ = conn1_recv_ch_tx.send(true).await; - } - _ = close_ch_rx1.recv() => { - log::debug!("conn1 received close_ch_rx1"); - break; - } - } - } - drop(done_ch_tx); - log::debug!("conn1 drop done_ch_tx, exit spawn"); - }); - - // conn2 - tokio::spawn(async move { - let mut buf = vec![0u8; 1500]; - loop { - log::debug!("conn2: wait for a message.."); - tokio::select! 
{ - result = conn2_tr.recv_from(&mut buf) =>{ - let (n, addr) = match result{ - Ok((n, addr)) => (n, addr), - Err(err) => { - log::debug!("ReadFrom returned: {}", err); - break; - } - }; - - log::debug!("conn2 received {:?}", &buf[..n]); - - // echo back to conn1 - let n = conn2_tr.send_to(b"Good-bye!", addr).await?; - assert_eq!( 9, n, "should match"); - } - _ = close_ch_rx2.recv() => { - log::debug!("conn1 received close_ch_rx2"); - break; - } - } - } - - log::debug!("conn2 exit spawn"); - - Result::<()>::Ok(()) - }); - - log::debug!("conn1: sending"); - let n = conn1.send_to(b"Hello!", conn2.local_addr()?).await?; - assert_eq!(n, 6, "should match"); - - let _ = conn1_recv_ch_rx.recv().await; - log::debug!("main recv conn1_recv_ch_rx"); - drop(close_ch_tx); - log::debug!("main drop close_ch_tx"); - let _ = done_ch_rx.recv().await; - log::debug!("main recv done_ch_rx"); - Ok(()) -} - -//use std::io::Write; - -#[tokio::test] -async fn test_net_virtual_two_ips_on_a_nic() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - let wan = Arc::new(Mutex::new(Router::new(RouterConfig { - cidr: "1.2.3.0/24".to_string(), - ..Default::default() - })?)); - - let net = Net::new(Some(NetConfig { - static_ips: vec![DEMO_IP.to_owned(), "1.2.3.5".to_owned()], - ..Default::default() - })); - { - let nic = net.get_nic()?; - - let mut w = wan.lock().await; - w.add_net(Arc::clone(&nic)).await?; - - let n = nic.lock().await; - n.set_router(Arc::clone(&wan)).await?; - } - - // start the router - { - let mut w = wan.lock().await; - w.start().await?; - } - - let (conn1, conn2) = ( - net.bind(SocketAddr::new(Ipv4Addr::from_str(DEMO_IP)?.into(), 1234)) - .await?, - net.bind(SocketAddr::new(Ipv4Addr::from_str("1.2.3.5")?.into(), 1234)) - .await?, - ); - - let (close_ch_tx, mut close_ch_rx1) = broadcast::channel::(1); - let (done_ch_tx, mut done_ch_rx) = mpsc::channel::(1); - let (conn1_recv_ch_tx, mut conn1_recv_ch_rx) = mpsc::channel(1); - let conn1_rx = Arc::clone(&conn1); - let conn2_tr = Arc::clone(&conn2); - let mut close_ch_rx2 = close_ch_tx.subscribe(); - - // conn1 - tokio::spawn(async move { - let mut buf = vec![0u8; 1500]; - loop { - log::debug!("conn1: wait for a message.."); - tokio::select! { - result = conn1_rx.recv_from(&mut buf) =>{ - let n = match result{ - Ok((n, _)) => n, - Err(err) => { - log::debug!("ReadFrom returned: {}", err); - break; - } - }; - - log::debug!("conn1 received {:?}", &buf[..n]); - let _ = conn1_recv_ch_tx.send(true).await; - } - _ = close_ch_rx1.recv() => { - log::debug!("conn1 received close_ch_rx1"); - break; - } - } - } - drop(done_ch_tx); - log::debug!("conn1 drop done_ch_tx, exit spawn"); - }); - - // conn2 - tokio::spawn(async move { - let mut buf = vec![0u8; 1500]; - loop { - log::debug!("conn2: wait for a message.."); - tokio::select! 
{ - result = conn2_tr.recv_from(&mut buf) =>{ - let (n, addr) = match result{ - Ok((n, addr)) => (n, addr), - Err(err) => { - log::debug!("ReadFrom returned: {}", err); - break; - } - }; - - log::debug!("conn2 received {:?}", &buf[..n]); - - // echo back to conn1 - let n = conn2_tr.send_to(b"Good-bye!", addr).await?; - assert_eq!(n, 9, "should match"); - } - _ = close_ch_rx2.recv() => { - log::debug!("conn1 received close_ch_rx2"); - break; - } - } - } - - log::debug!("conn2 exit spawn"); - - Result::<()>::Ok(()) - }); - - log::debug!("conn1: sending"); - let n = conn1.send_to(b"Hello!", conn2.local_addr()?).await?; - assert_eq!(n, 6, "should match"); - - let _ = conn1_recv_ch_rx.recv().await; - log::debug!("main recv conn1_recv_ch_rx"); - drop(close_ch_tx); - log::debug!("main drop close_ch_tx"); - let _ = done_ch_rx.recv().await; - log::debug!("main recv done_ch_rx"); - Ok(()) -} diff --git a/util/src/vnet/resolver.rs b/util/src/vnet/resolver.rs deleted file mode 100644 index 2972c112a..000000000 --- a/util/src/vnet/resolver.rs +++ /dev/null @@ -1,68 +0,0 @@ -#[cfg(test)] -mod resolver_test; - -use std::collections::HashMap; -use std::future::Future; -use std::net::IpAddr; -use std::pin::Pin; -use std::str::FromStr; -use std::sync::Arc; - -use tokio::sync::Mutex; - -use crate::error::*; - -#[derive(Default)] -pub(crate) struct Resolver { - parent: Option>>, - hosts: HashMap, -} - -impl Resolver { - pub(crate) fn new() -> Self { - let mut r = Resolver { - parent: None, - hosts: HashMap::new(), - }; - - if let Err(err) = r.add_host("localhost".to_owned(), "127.0.0.1".to_owned()) { - log::warn!("failed to add localhost to Resolver: {}", err); - } - r - } - - pub(crate) fn set_parent(&mut self, p: Arc>) { - self.parent = Some(p); - } - - pub(crate) fn add_host(&mut self, name: String, ip_addr: String) -> Result<()> { - if name.is_empty() { - return Err(Error::ErrHostnameEmpty); - } - let ip = IpAddr::from_str(&ip_addr)?; - self.hosts.insert(name, ip); - - Ok(()) - } - - pub(crate) fn lookup( - &self, - host_name: String, - ) -> Pin> + Send + 'static>> { - if let Some(ip) = self.hosts.get(&host_name) { - let ip2 = *ip; - return Box::pin(async move { Some(ip2) }); - } - - // mutex must be unlocked before calling into parent Resolver - if let Some(parent) = &self.parent { - let parent2 = Arc::clone(parent); - Box::pin(async move { - let p = parent2.lock().await; - p.lookup(host_name).await - }) - } else { - Box::pin(async move { None }) - } - } -} diff --git a/util/src/vnet/resolver/resolver_test.rs b/util/src/vnet/resolver/resolver_test.rs deleted file mode 100644 index 701fc3036..000000000 --- a/util/src/vnet/resolver/resolver_test.rs +++ /dev/null @@ -1,71 +0,0 @@ -use super::*; - -const DEMO_IP: &str = "1.2.3.4"; - -#[tokio::test] -async fn test_resolver_standalone() -> Result<()> { - let mut r = Resolver::new(); - - // should have localhost by default - let name = "localhost"; - let ip_addr = "127.0.0.1"; - let ip = IpAddr::from_str(ip_addr)?; - - if let Some(resolved) = r.lookup(name.to_owned()).await { - assert_eq!(resolved, ip, "should match"); - } else { - panic!("should Some, but got None"); - } - - let name = "abc.com"; - let ip_addr = DEMO_IP; - let ip = IpAddr::from_str(ip_addr)?; - log::debug!("adding {} {}", name, ip_addr); - - r.add_host(name.to_owned(), ip_addr.to_owned())?; - - if let Some(resolved) = r.lookup(name.to_owned()).await { - assert_eq!(resolved, ip, "should match"); - } else { - panic!("should Some, but got None"); - } - - Ok(()) -} - -#[tokio::test] -async fn 
test_resolver_cascaded() -> Result<()> { - let mut r0 = Resolver::new(); - - let name0 = "abc.com"; - let ip_addr0 = DEMO_IP; - let ip0 = IpAddr::from_str(ip_addr0)?; - r0.add_host(name0.to_owned(), ip_addr0.to_owned())?; - - let mut r1 = Resolver::new(); - - let name1 = "myserver.local"; - let ip_addr1 = "10.1.2.5"; - let ip1 = IpAddr::from_str(ip_addr1)?; - r1.add_host(name1.to_owned(), ip_addr1.to_owned())?; - - r1.set_parent(Arc::new(Mutex::new(r0))); - - if let Some(resolved) = r1.lookup(name0.to_owned()).await { - assert_eq!(resolved, ip0, "should match"); - } else { - panic!("should Some, but got None"); - } - - if let Some(resolved) = r1.lookup(name1.to_owned()).await { - assert_eq!(resolved, ip1, "should match"); - } else { - panic!("should Some, but got None"); - } - - // should fail if the name does not exist - let result = r1.lookup("bad.com".to_owned()).await; - assert!(result.is_none(), "should fail"); - - Ok(()) -} diff --git a/util/src/vnet/router.rs b/util/src/vnet/router.rs deleted file mode 100644 index c5ad4c0f2..000000000 --- a/util/src/vnet/router.rs +++ /dev/null @@ -1,586 +0,0 @@ -#[cfg(test)] -mod router_test; - -use std::collections::HashMap; -use std::future::Future; -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; -use std::ops::{Add, Sub}; -use std::pin::Pin; -use std::str::FromStr; -use std::sync::atomic::Ordering; -use std::sync::Arc; -use std::time::SystemTime; - -use async_trait::async_trait; -use ipnet::*; -use portable_atomic::AtomicU64; -use tokio::sync::{mpsc, Mutex}; -use tokio::time::Duration; - -use crate::error::*; -use crate::vnet::chunk::*; -use crate::vnet::chunk_queue::*; -use crate::vnet::interface::*; -use crate::vnet::nat::*; -use crate::vnet::net::*; -use crate::vnet::resolver::*; - -const DEFAULT_ROUTER_QUEUE_SIZE: usize = 0; // unlimited - -lazy_static! { - pub static ref ROUTER_ID_CTR: AtomicU64 = AtomicU64::new(0); -} - -// Generate a unique router name -fn assign_router_name() -> String { - let n = ROUTER_ID_CTR.fetch_add(1, Ordering::SeqCst); - format!("router{n}") -} - -// RouterConfig ... -#[derive(Default)] -pub struct RouterConfig { - // name of router. If not specified, a unique name will be assigned. - pub name: String, - // cidr notation, like "192.0.2.0/24" - pub cidr: String, - // static_ips is an array of static IP addresses to be assigned for this router. - // If no static IP address is given, the router will automatically assign - // an IP address. - // This will be ignored if this router is the root. - pub static_ips: Vec, - // static_ip is deprecated. Use static_ips. - pub static_ip: String, - // Internal queue size - pub queue_size: usize, - // Effective only when this router has a parent router - pub nat_type: Option, - // Minimum Delay - pub min_delay: Duration, - // Max Jitter - pub max_jitter: Duration, -} - -// NIC is a network interface controller that interfaces Router -#[async_trait] -pub trait Nic { - async fn get_interface(&self, ifc_name: &str) -> Option; - async fn add_addrs_to_interface(&mut self, ifc_name: &str, addrs: &[IpNet]) -> Result<()>; - async fn on_inbound_chunk(&self, c: Box); - async fn get_static_ips(&self) -> Vec; - async fn set_router(&self, r: Arc>) -> Result<()>; -} - -// ChunkFilter is a handler users can add to filter chunks. -// If the filter returns false, the packet will be dropped. 
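The filter contract described in the comment above can be illustrated with a short sketch modelled on the `make_filter_fn` helper in the removed router tests; the import paths and the already-constructed `Router` are assumptions of the sketch, and 1.2.3.4 is just a placeholder address.

```rust
// Sketch only: assumes `Chunk` and `Router` stay importable from
// `webrtc_util::vnet::{chunk, router}` and that `wan` is an already wired-up router.
use webrtc_util::vnet::chunk::Chunk;
use webrtc_util::vnet::router::Router;

async fn block_demo_ip(wan: &Router) {
    wan.add_chunk_filter(Box::new(|c: &(dyn Chunk + Send + Sync)| -> bool {
        // Returning false drops the chunk; true forwards it. 1.2.3.4 is a
        // placeholder destination used only for this sketch.
        c.get_destination_ip().to_string() != "1.2.3.4"
    }))
    .await;
}
```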
-pub type ChunkFilterFn = Box bool) + Send + Sync>; - -#[derive(Default)] -pub struct RouterInternal { - pub(crate) nat_type: Option, // read-only - pub(crate) ipv4net: IpNet, // read-only - pub(crate) parent: Option>>, // read-only - pub(crate) nat: NetworkAddressTranslator, // read-only - pub(crate) nics: HashMap>>, // read-only - pub(crate) chunk_filters: Vec, // requires mutex [x] - pub(crate) last_id: u8, // requires mutex [x], used to assign the last digit of IPv4 address -} - -// Router ... -#[derive(Default)] -pub struct Router { - name: String, // read-only - ipv4net: IpNet, // read-only - min_delay: Duration, // requires mutex [x] - max_jitter: Duration, // requires mutex [x] - queue: Arc, // read-only - interfaces: Vec, // read-only - static_ips: Vec, // read-only - static_local_ips: HashMap, // read-only, - children: Vec>>, // read-only - done: Option>, // requires mutex [x] - pub(crate) resolver: Arc>, // read-only - push_ch: Option>, // writer requires mutex - router_internal: Arc>, -} - -#[async_trait] -impl Nic for Router { - async fn get_interface(&self, ifc_name: &str) -> Option { - for ifc in &self.interfaces { - if ifc.name == ifc_name { - return Some(ifc.clone()); - } - } - None - } - - async fn add_addrs_to_interface(&mut self, ifc_name: &str, addrs: &[IpNet]) -> Result<()> { - for ifc in &mut self.interfaces { - if ifc.name == ifc_name { - for addr in addrs { - ifc.add_addr(*addr); - } - return Ok(()); - } - } - - Err(Error::ErrNotFound) - } - - async fn on_inbound_chunk(&self, c: Box) { - let from_parent: Box = { - let router_internal = self.router_internal.lock().await; - match router_internal.nat.translate_inbound(&*c).await { - Ok(from) => { - if let Some(from) = from { - from - } else { - return; - } - } - Err(err) => { - log::warn!("[{}] {}", self.name, err); - return; - } - } - }; - - self.push(from_parent).await; - } - - async fn get_static_ips(&self) -> Vec { - self.static_ips.clone() - } - - // caller must hold the mutex - async fn set_router(&self, parent: Arc>) -> Result<()> { - { - let mut router_internal = self.router_internal.lock().await; - router_internal.parent = Some(Arc::clone(&parent)); - } - - let parent_resolver = { - let p = parent.lock().await; - Arc::clone(&p.resolver) - }; - { - let mut resolver = self.resolver.lock().await; - resolver.set_parent(parent_resolver); - } - - let mut mapped_ips = vec![]; - let mut local_ips = vec![]; - - // when this method is called, one or more IP address has already been assigned by - // the parent router. 
- if let Some(ifc) = self.get_interface("eth0").await { - for ifc_addr in ifc.addrs() { - let ip = ifc_addr.addr(); - mapped_ips.push(ip); - - if let Some(loc_ip) = self.static_local_ips.get(&ip.to_string()) { - local_ips.push(*loc_ip); - } - } - } else { - return Err(Error::ErrNoIpaddrEth0); - } - - // Set up NAT here - { - let mut router_internal = self.router_internal.lock().await; - if router_internal.nat_type.is_none() { - router_internal.nat_type = Some(NatType { - mapping_behavior: EndpointDependencyType::EndpointIndependent, - filtering_behavior: EndpointDependencyType::EndpointAddrPortDependent, - hair_pining: false, - port_preservation: false, - mapping_life_time: Duration::from_secs(30), - ..Default::default() - }); - } - - router_internal.nat = NetworkAddressTranslator::new(NatConfig { - name: self.name.clone(), - nat_type: router_internal.nat_type.unwrap(), - mapped_ips, - local_ips, - })?; - } - - Ok(()) - } -} - -impl Router { - pub fn new(config: RouterConfig) -> Result { - let ipv4net: IpNet = config.cidr.parse()?; - - let queue_size = if config.queue_size > 0 { - config.queue_size - } else { - DEFAULT_ROUTER_QUEUE_SIZE - }; - - // set up network interface, lo0 - let mut lo0 = Interface::new(LO0_STR.to_owned(), vec![]); - if let Ok(ipnet) = Interface::convert( - SocketAddr::new(Ipv4Addr::new(127, 0, 0, 1).into(), 0), - Some(SocketAddr::new(Ipv4Addr::new(255, 0, 0, 0).into(), 0)), - ) { - lo0.add_addr(ipnet); - } - - // set up network interface, eth0 - let eth0 = Interface::new("eth0".to_owned(), vec![]); - - // local host name resolver - let resolver = Arc::new(Mutex::new(Resolver::new())); - - let name = if config.name.is_empty() { - assign_router_name() - } else { - config.name.clone() - }; - - let mut static_ips = vec![]; - let mut static_local_ips = HashMap::new(); - for ip_str in &config.static_ips { - let ip_pair: Vec<&str> = ip_str.split('/').collect(); - if let Ok(ip) = IpAddr::from_str(ip_pair[0]) { - if ip_pair.len() > 1 { - let loc_ip = IpAddr::from_str(ip_pair[1])?; - if !ipv4net.contains(&loc_ip) { - return Err(Error::ErrLocalIpBeyondStaticIpsSubset); - } - static_local_ips.insert(ip.to_string(), loc_ip); - } - static_ips.push(ip); - } - } - if !config.static_ip.is_empty() { - log::warn!("static_ip is deprecated. Use static_ips instead"); - if let Ok(ip) = IpAddr::from_str(&config.static_ip) { - static_ips.push(ip); - } - } - - let n_static_local = static_local_ips.len(); - if n_static_local > 0 && n_static_local != static_ips.len() { - return Err(Error::ErrLocalIpNoStaticsIpsAssociated); - } - - let router_internal = RouterInternal { - nat_type: config.nat_type, - ipv4net, - nics: HashMap::new(), - ..Default::default() - }; - - Ok(Router { - name, - ipv4net, - interfaces: vec![lo0, eth0], - static_ips, - static_local_ips, - resolver, - router_internal: Arc::new(Mutex::new(router_internal)), - queue: Arc::new(ChunkQueue::new(queue_size)), - min_delay: config.min_delay, - max_jitter: config.max_jitter, - ..Default::default() - }) - } - - // caller must hold the mutex - pub(crate) fn get_interfaces(&self) -> &[Interface] { - &self.interfaces - } - - // Start ... 
- pub fn start(&mut self) -> Pin>>> { - if self.done.is_some() { - return Box::pin(async move { Err(Error::ErrRouterAlreadyStarted) }); - } - - let (done_tx, mut done_rx) = mpsc::channel(1); - let (push_ch_tx, mut push_ch_rx) = mpsc::channel(1); - self.done = Some(done_tx); - self.push_ch = Some(push_ch_tx); - - let router_internal = Arc::clone(&self.router_internal); - let queue = Arc::clone(&self.queue); - let max_jitter = self.max_jitter; - let min_delay = self.min_delay; - let name = self.name.clone(); - let ipv4net = self.ipv4net; - - tokio::spawn(async move { - while let Ok(d) = Router::process_chunks( - &name, - ipv4net, - max_jitter, - min_delay, - &queue, - &router_internal, - ) - .await - { - if d == Duration::from_secs(0) { - tokio::select! { - _ = push_ch_rx.recv() =>{}, - _ = done_rx.recv() => break, - } - } else { - let t = tokio::time::sleep(d); - tokio::pin!(t); - - tokio::select! { - _ = t.as_mut() => {}, - _ = done_rx.recv() => break, - } - } - } - }); - - let children = self.children.clone(); - Box::pin(async move { Router::start_children(children).await }) - } - - // Stop ... - pub fn stop(&mut self) -> Pin>>> { - if self.done.is_none() { - return Box::pin(async move { Err(Error::ErrRouterAlreadyStopped) }); - } - self.push_ch.take(); - self.done.take(); - - let children = self.children.clone(); - Box::pin(async move { Router::stop_children(children).await }) - } - - async fn start_children(children: Vec>>) -> Result<()> { - for child in children { - let mut c = child.lock().await; - c.start().await?; - } - - Ok(()) - } - - async fn stop_children(children: Vec>>) -> Result<()> { - for child in children { - let mut c = child.lock().await; - c.stop().await?; - } - - Ok(()) - } - - // AddRouter adds a chile Router. - // after parent.add_router(child), also call child.set_router(parent) to set child's parent router - pub async fn add_router(&mut self, child: Arc>) -> Result<()> { - // Router is a NIC. Add it as a NIC so that packets are routed to this child - // router. - let nic = Arc::clone(&child) as Arc>; - self.children.push(child); - self.add_net(nic).await - } - - // AddNet ... - // after router.add_net(nic), also call nic.set_router(router) to set nic's router - pub async fn add_net(&mut self, nic: Arc>) -> Result<()> { - let mut router_internal = self.router_internal.lock().await; - router_internal.add_nic(nic).await - } - - // AddHost adds a mapping of hostname and an IP address to the local resolver. - pub async fn add_host(&mut self, host_name: String, ip_addr: String) -> Result<()> { - let mut resolver = self.resolver.lock().await; - resolver.add_host(host_name, ip_addr) - } - - // AddChunkFilter adds a filter for chunks traversing this router. - // You may add more than one filter. The filters are called in the order of this method call. - // If a chunk is dropped by a filter, subsequent filter will not receive the chunk. - pub async fn add_chunk_filter(&self, filter: ChunkFilterFn) { - let mut router_internal = self.router_internal.lock().await; - router_internal.chunk_filters.push(filter); - } - - pub(crate) async fn push(&self, mut c: Box) { - log::debug!("[{}] route {}", self.name, c); - if self.done.is_some() { - c.set_timestamp(); - - if self.queue.push(c).await { - if let Some(push_ch) = &self.push_ch { - let _ = push_ch.try_send(()); - } - } else { - log::warn!("[{}] queue was full. 
dropped a chunk", self.name); - } - } else { - log::warn!("router is done"); - } - } - - async fn process_chunks( - name: &str, - ipv4net: IpNet, - max_jitter: Duration, - min_delay: Duration, - queue: &Arc, - router_internal: &Arc>, - ) -> Result { - // Introduce jitter by delaying the processing of chunks. - let mj = max_jitter.as_nanos() as u64; - if mj > 0 { - let jitter = Duration::from_nanos(rand::random::() % mj); - tokio::time::sleep(jitter).await; - } - - // cut_off - // v min delay - // |<--->| - // +------------:-- - // |OOOOOOXXXXX : --> time - // +------------:-- - // |<--->| now - // due - - let entered_at = SystemTime::now(); - let cut_off = entered_at.sub(min_delay); - - // the next sleep duration - let mut d; - - loop { - d = Duration::from_secs(0); - - if let Some(c) = queue.peek().await { - // check timestamp to find if the chunk is due - if c.get_timestamp().duration_since(cut_off).is_ok() { - // There is one or more chunk in the queue but none of them are due. - // Calculate the next sleep duration here. - let next_expire = c.get_timestamp().add(min_delay); - if let Ok(diff) = next_expire.duration_since(entered_at) { - d = diff; - break; - } - } - } else { - break; // no more chunk in the queue - } - - if let Some(c) = queue.pop().await { - let ri = router_internal.lock().await; - let mut blocked = false; - for filter in &ri.chunk_filters { - if !filter(&*c) { - blocked = true; - break; - } - } - if blocked { - continue; // discard - } - - let dst_ip = c.get_destination_ip(); - - // check if the destination is in our subnet - if ipv4net.contains(&dst_ip) { - // search for the destination NIC - if let Some(nic) = ri.nics.get(&dst_ip.to_string()) { - // found the NIC, forward the chunk to the NIC. - // call to NIC must unlock mutex - let ni = nic.lock().await; - ni.on_inbound_chunk(c).await; - } else { - // NIC not found. drop it. - log::debug!("[{}] {} unreachable", name, c); - } - } else { - // the destination is outside of this subnet - // is this WAN? - if let Some(parent) = &ri.parent { - // Pass it to the parent via NAT - if let Some(to_parent) = ri.nat.translate_outbound(&*c).await? { - // call to parent router mutex unlock mutex - let p = parent.lock().await; - p.push(to_parent).await; - } - } else { - // this WAN. 
No route for this chunk - log::debug!("[{}] no route found for {}", name, c); - } - } - } else { - break; // no more chunk in the queue - } - } - - Ok(d) - } -} - -impl RouterInternal { - // caller must hold the mutex - pub(crate) async fn add_nic(&mut self, nic: Arc>) -> Result<()> { - let mut ips = { - let ni = nic.lock().await; - ni.get_static_ips().await - }; - - if ips.is_empty() { - // assign an IP address - let ip = self.assign_ip_address()?; - log::debug!("assign_ip_address: {}", ip); - ips.push(ip); - } - - let mut ipnets = vec![]; - for ip in &ips { - if !self.ipv4net.contains(ip) { - return Err(Error::ErrStaticIpIsBeyondSubnet); - } - self.nics.insert(ip.to_string(), Arc::clone(&nic)); - ipnets.push(IpNet::from_str(&format!( - "{}/{}", - ip, - self.ipv4net.prefix_len() - ))?); - } - - { - let mut ni = nic.lock().await; - let _ = ni.add_addrs_to_interface("eth0", &ipnets).await; - } - - Ok(()) - } - - // caller should hold the mutex - fn assign_ip_address(&mut self) -> Result { - // See: https://stackoverflow.com/questions/14915188/ip-address-ending-with-zero - - if self.last_id == 0xfe { - return Err(Error::ErrAddressSpaceExhausted); - } - - self.last_id += 1; - match self.ipv4net.addr() { - IpAddr::V4(ipv4) => { - let mut ip = ipv4.octets(); - ip[3] = self.last_id; - Ok(IpAddr::V4(Ipv4Addr::from(ip))) - } - IpAddr::V6(ipv6) => { - let mut ip = ipv6.octets(); - ip[15] += self.last_id; - Ok(IpAddr::V6(Ipv6Addr::from(ip))) - } - } - } -} diff --git a/util/src/vnet/router/router_test.rs b/util/src/vnet/router/router_test.rs deleted file mode 100644 index e5694af83..000000000 --- a/util/src/vnet/router/router_test.rs +++ /dev/null @@ -1,810 +0,0 @@ -use portable_atomic::{AtomicI32, AtomicUsize}; - -use super::*; - -const MARGIN: Duration = Duration::from_millis(18); -const DEMO_IP: &str = "1.2.3.4"; - -struct DummyNic { - net: Net, - on_inbound_chunk_handler: u16, - cbs0: AtomicI32, - done_ch_tx: Arc>>>, - delay_res: Arc>>, - npkts: i32, -} - -impl Default for DummyNic { - fn default() -> Self { - DummyNic { - net: Net::Ifs(vec![]), - on_inbound_chunk_handler: 0, - cbs0: AtomicI32::new(0), - done_ch_tx: Arc::new(Mutex::new(None)), - delay_res: Arc::new(Mutex::new(vec![])), - npkts: 0, - } - } -} - -#[async_trait] -impl Nic for DummyNic { - async fn get_interface(&self, ifc_name: &str) -> Option { - self.net.get_interface(ifc_name).await - } - - async fn add_addrs_to_interface(&mut self, ifc_name: &str, addrs: &[IpNet]) -> Result<()> { - let nic = self.net.get_nic()?; - let mut net = nic.lock().await; - net.add_addrs_to_interface(ifc_name, addrs).await - } - - async fn set_router(&self, r: Arc>) -> Result<()> { - let nic = self.net.get_nic()?; - let net = nic.lock().await; - net.set_router(r).await - } - - async fn on_inbound_chunk(&self, c: Box) { - log::debug!("received: {}", c); - match self.on_inbound_chunk_handler { - 0 => { - self.cbs0.fetch_add(1, Ordering::SeqCst); - } - 1 => { - let mut done_ch_tx = self.done_ch_tx.lock().await; - done_ch_tx.take(); - } - 2 => { - let delay = SystemTime::now() - .duration_since(c.get_timestamp()) - .unwrap_or(Duration::from_secs(0)); - { - let mut delay_res = self.delay_res.lock().await; - delay_res.push(delay); - } - - let n = self.cbs0.fetch_add(1, Ordering::SeqCst); - if n >= self.npkts - 1 { - let mut done_ch_tx = self.done_ch_tx.lock().await; - done_ch_tx.take(); - } - } - 3 => { - // echo the chunk - let mut echo = c.clone_to(); - let result = echo.set_source_addr(&c.destination_addr().to_string()); - assert!(result.is_ok(), "should 
succeed"); - let result = echo.set_destination_addr(&c.source_addr().to_string()); - assert!(result.is_ok(), "should succeed"); - - log::debug!("wan.push being called.."); - if let Net::VNet(vnet) = &self.net { - let net = vnet.lock().await; - let vi = net.vi.lock().await; - if let Some(r) = &vi.router { - let wan = r.lock().await; - wan.push(echo).await; - } - } - log::debug!("wan.push called!"); - } - _ => {} - }; - } - - async fn get_static_ips(&self) -> Vec { - let nic = match self.net.get_nic() { - Ok(nic) => nic, - Err(_) => return vec![], - }; - let net = nic.lock().await; - net.get_static_ips().await - } -} - -async fn get_ipaddr(nic: &Arc>) -> Result { - let n = nic.lock().await; - let eth0 = n.get_interface("eth0").await.ok_or(Error::ErrNoInterface)?; - let addrs = eth0.addrs(); - if addrs.is_empty() { - Err(Error::ErrNoAddressAssigned) - } else { - Ok(addrs[0].addr()) - } -} - -#[test] -fn test_router_standalone_cidr_parsing() -> Result<()> { - let r = Router::new(RouterConfig { - cidr: "1.2.3.0/24".to_string(), - ..Default::default() - })?; - - assert_eq!(r.ipv4net.addr().to_string(), "1.2.3.0", "ip should match"); - assert_eq!( - r.ipv4net.netmask().to_string(), - "255.255.255.0", - "mask should match" - ); - - Ok(()) -} - -#[tokio::test] -async fn test_router_standalone_assign_ip_address() -> Result<()> { - let r = Router::new(RouterConfig { - cidr: "1.2.3.0/24".to_string(), - ..Default::default() - })?; - - let mut ri = r.router_internal.lock().await; - for i in 1..255 { - let ip = match ri.assign_ip_address()? { - IpAddr::V4(ip) => ip.octets().to_vec(), - IpAddr::V6(ip) => ip.octets().to_vec(), - }; - assert_eq!(ip[0], 1_u8, "should match"); - assert_eq!(ip[1], 2_u8, "should match"); - assert_eq!(ip[2], 3_u8, "should match"); - assert_eq!(ip[3], i as u8, "should match"); - } - - let result = ri.assign_ip_address(); - assert!(result.is_err(), "assign_ip_address should fail"); - - Ok(()) -} - -#[tokio::test] -async fn test_router_standalone_add_net() -> Result<()> { - let wan = Arc::new(Mutex::new(Router::new(RouterConfig { - cidr: "1.2.3.0/24".to_string(), - ..Default::default() - })?)); - - let net = Net::new(Some(NetConfig::default())); - - let nic = net.get_nic()?; - - { - let mut w = wan.lock().await; - w.add_net(Arc::clone(&nic)).await?; - } - - let n = nic.lock().await; - n.set_router(Arc::clone(&wan)).await?; - - let eth0 = n.get_interface("eth0").await; - assert!(eth0.is_some(), "should succeed"); - if let Some(eth0) = eth0 { - let addrs = eth0.addrs(); - assert_eq!(addrs.len(), 1, "should match"); - assert_eq!(addrs[0].to_string(), "1.2.3.1/24", "should match"); - assert_eq!(addrs[0].addr().to_string(), "1.2.3.1", "should match"); - } - - Ok(()) -} - -#[tokio::test] -async fn test_router_standalone_routing() -> Result<()> { - let wan = Arc::new(Mutex::new(Router::new(RouterConfig { - cidr: "1.2.3.0/24".to_string(), - ..Default::default() - })?)); - - let (done_ch_tx, mut done_ch_rx) = mpsc::channel(1); - let mut done_ch_tx = Some(done_ch_tx); - - let mut nics = vec![]; - let mut ips = vec![]; - for i in 0..2 { - let dn = DummyNic { - net: Net::new(Some(NetConfig::default())), - on_inbound_chunk_handler: i, - ..Default::default() - }; - if i == 1 { - let mut done_ch = dn.done_ch_tx.lock().await; - *done_ch = done_ch_tx.take(); - } - let nic = Arc::new(Mutex::new(dn)); - - { - let n = Arc::clone(&nic) as Arc>; - let mut w = wan.lock().await; - w.add_net(n).await?; - } - { - let n = nic.lock().await; - n.set_router(Arc::clone(&wan)).await?; - } - - { - // Now, eth0 
must have one address assigned - let n = nic.lock().await; - if let Some(eth0) = n.get_interface("eth0").await { - let addrs = eth0.addrs(); - assert_eq!(addrs.len(), 1, "should match"); - ips.push(SocketAddr::new(addrs[0].addr(), 1111 * (i + 1))); - } - } - - nics.push(nic); - } - - { - let c = Box::new(ChunkUdp::new(ips[0], ips[1])); - - let mut r = wan.lock().await; - r.start().await?; - r.push(c).await; - } - - let _ = done_ch_rx.recv().await; - - { - let mut r = wan.lock().await; - r.stop().await?; - } - - { - let n = nics[0].lock().await; - assert_eq!(n.cbs0.load(Ordering::SeqCst), 0, "should be zero"); - } - - Ok(()) -} - -//use std::io::Write; - -#[tokio::test] -async fn test_router_standalone_add_chunk_filter() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - let wan = Arc::new(Mutex::new(Router::new(RouterConfig { - cidr: "1.2.3.0/24".to_string(), - ..Default::default() - })?)); - - let mut nics = vec![]; - let mut ips = vec![]; - for i in 0..2 { - let dn = DummyNic { - net: Net::new(Some(NetConfig::default())), - on_inbound_chunk_handler: 0, - ..Default::default() - }; - let nic = Arc::new(Mutex::new(dn)); - - { - let n = Arc::clone(&nic) as Arc>; - let mut w = wan.lock().await; - w.add_net(n).await?; - } - { - let n = nic.lock().await; - n.set_router(Arc::clone(&wan)).await?; - } - - { - // Now, eth0 must have one address assigned - let n = nic.lock().await; - if let Some(eth0) = n.get_interface("eth0").await { - let addrs = eth0.addrs(); - assert_eq!(addrs.len(), 1, "should match"); - ips.push(SocketAddr::new(addrs[0].addr(), 1111 * (i + 1))); - } - } - - nics.push(nic); - } - - // this creates a filter that block the first chunk - let make_filter_fn = |name: String| { - let n = AtomicUsize::new(0); - Box::new(move |c: &(dyn Chunk + Send + Sync)| -> bool { - let m = n.fetch_add(1, Ordering::SeqCst); - let pass = m > 0; - if pass { - log::debug!("{}: {} passed {}", m, name, c); - } else { - log::debug!("{}: {} blocked {}", m, name, c); - } - pass - }) - }; - - { - let mut r = wan.lock().await; - r.add_chunk_filter(make_filter_fn("filter1".to_owned())) - .await; - r.add_chunk_filter(make_filter_fn("filter2".to_owned())) - .await; - r.start().await?; - - // send 3 packets - for i in 0..3u8 { - let mut c = ChunkUdp::new(ips[0], ips[1]); - c.user_data = vec![i]; // 1-byte seq num - r.push(Box::new(c)).await; - } - } - - tokio::time::sleep(Duration::from_millis(50)).await; - - { - let mut r = wan.lock().await; - r.stop().await?; - } - - { - let n = nics[0].lock().await; - assert_eq!(n.cbs0.load(Ordering::SeqCst), 0, "should be zero"); - } - - { - let n = nics[1].lock().await; - assert_eq!(n.cbs0.load(Ordering::SeqCst), 1, "should be one"); - } - - Ok(()) -} - -async fn delay_sub_test(title: String, min_delay: Duration, max_jitter: Duration) -> Result<()> { - let wan = Arc::new(Mutex::new(Router::new(RouterConfig { - cidr: "1.2.3.0/24".to_string(), - min_delay, - max_jitter, - ..Default::default() - })?)); - - let npkts = 1; - let (done_ch_tx, mut done_ch_rx) = mpsc::channel(1); - let mut done_ch_tx = Some(done_ch_tx); - - let mut nics = vec![]; - let mut ips = vec![]; - for i in 0..2 { - let mut dn = DummyNic { - net: Net::new(Some(NetConfig::default())), - on_inbound_chunk_handler: 0, - 
..Default::default() - }; - if i == 1 { - dn.on_inbound_chunk_handler = 2; - dn.npkts = npkts; - - let mut done_ch = dn.done_ch_tx.lock().await; - *done_ch = done_ch_tx.take(); - } - let nic = Arc::new(Mutex::new(dn)); - - { - let n = Arc::clone(&nic) as Arc>; - let mut w = wan.lock().await; - w.add_net(n).await?; - } - { - let n = nic.lock().await; - n.set_router(Arc::clone(&wan)).await?; - } - - { - // Now, eth0 must have one address assigned - let n = nic.lock().await; - if let Some(eth0) = n.get_interface("eth0").await { - let addrs = eth0.addrs(); - assert_eq!(addrs.len(), 1, "should match"); - ips.push(SocketAddr::new(addrs[0].addr(), 1111 * (i + 1))); - } - } - - nics.push(nic); - } - - { - let mut r = wan.lock().await; - r.start().await?; - - for _ in 0..npkts { - let c = Box::new(ChunkUdp::new(ips[0], ips[1])); - r.push(c).await; - tokio::time::sleep(Duration::from_millis(50)).await; - } - } - - let _ = done_ch_rx.recv().await; - - { - let mut r = wan.lock().await; - r.stop().await?; - } - - // Validate the amount of delays - { - let n = nics[1].lock().await; - let delay_res = n.delay_res.lock().await; - for d in &*delay_res { - log::info!("min delay : {:?}", min_delay); - log::info!("max jitter: {:?}", max_jitter); - log::info!("actual delay: {:?}", d); - assert!(*d >= min_delay, "{title} should delay {d:?} >= 20ms"); - assert!( - *d <= (min_delay + max_jitter + MARGIN), - "{title} should delay {d:?} <= minDelay + maxJitter", - ); - // Note: actual delay should be within 30ms but giving a 8ms - // MARGIN for possible extra delay - // (e.g. wakeup delay, debug logs, etc) - } - } - - Ok(()) -} - -//use std::io::Write; -#[cfg(target_os = "linux")] -#[tokio::test] -async fn test_router_delay() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - delay_sub_test( - "Delay only".to_owned(), - Duration::from_millis(20), - Duration::from_millis(0), - ) - .await?; - delay_sub_test( - "Jitter only".to_owned(), - Duration::from_millis(0), - Duration::from_millis(10), - ) - .await?; - delay_sub_test( - "Delay and Jitter".to_owned(), - Duration::from_millis(20), - Duration::from_millis(10), - ) - .await?; - - Ok(()) -} - -//use std::io::Write; - -#[tokio::test] -async fn test_router_one_child() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, log::LevelFilter::Trace) - .init();*/ - - let (done_ch_tx, mut done_ch_rx) = mpsc::channel(1); - let mut done_ch_tx = Some(done_ch_tx); - - let mut rs = vec![]; - let mut nics = vec![]; - let mut ips = vec![]; - for i in 0..2 { - let r = Arc::new(Mutex::new(Router::new(RouterConfig { - cidr: if i == 0 { - "1.2.3.0/24".to_owned() - } else { - "192.168.0.0/24".to_owned() - }, - ..Default::default() - })?)); - - let mut dn = DummyNic { - net: Net::new(Some(NetConfig::default())), - on_inbound_chunk_handler: i, - ..Default::default() - }; - if i == 1 { - let mut done_ch = dn.done_ch_tx.lock().await; - *done_ch = done_ch_tx.take(); - } else { - dn.on_inbound_chunk_handler = 3; - } - let nic = Arc::new(Mutex::new(dn)); - - { - let n = Arc::clone(&nic) 
as Arc>; - let mut w = r.lock().await; - w.add_net(n).await?; - } - { - let n = nic.lock().await; - n.set_router(Arc::clone(&r)).await?; - } - - { - let n = Arc::clone(&nic) as Arc>; - let ip = get_ipaddr(&n).await?; - ips.push(ip); - } - - nics.push(nic); - rs.push(r); - } - - { - let child = Arc::clone(&rs[1]); - let mut wan = rs[0].lock().await; - wan.add_router(child).await?; - } - { - let parent = Arc::clone(&rs[0]); - let lan = rs[1].lock().await; - lan.set_router(parent).await?; - } - - { - let mut wan = rs[0].lock().await; - wan.start().await?; - } - - { - let c = Box::new(ChunkUdp::new( - SocketAddr::new(ips[1], 1234), //lanIP - SocketAddr::new(ips[0], 5678), //wanIP - )); - log::debug!("sending {}", c); - let lan = rs[1].lock().await; - lan.push(c).await; - } - - log::debug!("waiting done_ch_rx"); - let _ = done_ch_rx.recv().await; - - { - let mut wan = rs[0].lock().await; - wan.stop().await?; - } - - Ok(()) -} - -#[test] -fn test_router_static_ips_more_than_one() -> Result<()> { - let lan = Router::new(RouterConfig { - cidr: "192.168.0.0/24".to_owned(), - static_ips: vec![ - "1.2.3.1".to_owned(), - "1.2.3.2".to_owned(), - "1.2.3.3".to_owned(), - ], - ..Default::default() - })?; - - assert_eq!(lan.static_ips.len(), 3, "should be 3"); - assert_eq!(lan.static_ips[0].to_string(), "1.2.3.1", "should match"); - assert_eq!(lan.static_ips[1].to_string(), "1.2.3.2", "should match"); - assert_eq!(lan.static_ips[2].to_string(), "1.2.3.3", "should match"); - - Ok(()) -} - -#[test] -fn test_router_static_ips_static_ip_local_ip_mapping() -> Result<()> { - let lan = Router::new(RouterConfig { - cidr: "192.168.0.0/24".to_owned(), - static_ips: vec![ - "1.2.3.1/192.168.0.1".to_owned(), - "1.2.3.2/192.168.0.2".to_owned(), - "1.2.3.3/192.168.0.3".to_owned(), - ], - ..Default::default() - })?; - - assert_eq!(lan.static_ips.len(), 3, "should be 3"); - assert_eq!(lan.static_ips[0].to_string(), "1.2.3.1", "should match"); - assert_eq!(lan.static_ips[1].to_string(), "1.2.3.2", "should match"); - assert_eq!(lan.static_ips[2].to_string(), "1.2.3.3", "should match"); - - assert_eq!(3, lan.static_local_ips.len(), "should be 3"); - let local_ips = ["192.168.0.1", "192.168.0.2", "192.168.0.3"]; - let ips = ["1.2.3.1", "1.2.3.2", "1.2.3.3"]; - for i in 0..3 { - let ext_ipstr = ips[i]; - if let Some(loc_ip) = lan.static_local_ips.get(ext_ipstr) { - assert_eq!(local_ips[i], loc_ip.to_string(), "should match"); - } else { - panic!("should have the external IP"); - } - } - - // bad local IP - let result = Router::new(RouterConfig { - cidr: "192.168.0.0/24".to_owned(), - static_ips: vec![ - "1.2.3.1/192.168.0.1".to_owned(), - "1.2.3.2/bad".to_owned(), // <-- invalid local IP - ], - ..Default::default() - }); - assert!(result.is_err(), "should fail"); - - // local IP out of CIDR - let result = Router::new(RouterConfig { - cidr: "192.168.0.0/24".to_owned(), - static_ips: vec![ - "1.2.3.1/192.168.0.1".to_owned(), - "1.2.3.2/172.16.1.2".to_owned(), // <-- out of CIDR - ], - ..Default::default() - }); - assert!(result.is_err(), "should fail"); - - // num of local IPs mismatch - let result = Router::new(RouterConfig { - cidr: "192.168.0.0/24".to_owned(), - static_ips: vec![ - "1.2.3.1/192.168.0.1".to_owned(), - "1.2.3.2".to_owned(), // <-- lack of local IP - ], - ..Default::default() - }); - assert!(result.is_err(), "should fail"); - - Ok(()) -} - -#[tokio::test] -async fn test_router_static_ips_1to1_nat() -> Result<()> { - let wan = Arc::new(Mutex::new(Router::new(RouterConfig { - cidr: "0.0.0.0/0".to_owned(), - 
..Default::default() - })?)); - - let lan = Arc::new(Mutex::new(Router::new(RouterConfig { - cidr: "192.168.0.0/24".to_owned(), - static_ips: vec![ - "1.2.3.1/192.168.0.1".to_owned(), - "1.2.3.2/192.168.0.2".to_owned(), - "1.2.3.3/192.168.0.3".to_owned(), - ], - nat_type: Some(NatType { - mode: NatMode::Nat1To1, - ..Default::default() - }), - ..Default::default() - })?)); - - { - let mut w = wan.lock().await; - w.add_router(Arc::clone(&lan)).await?; - } - { - let n = lan.lock().await; - n.set_router(Arc::clone(&wan)).await?; - } - - { - let l = lan.lock().await; - let ri = l.router_internal.lock().await; - - assert_eq!(ri.nat.mapped_ips.len(), 3, "should be 3"); - assert_eq!(ri.nat.mapped_ips[0].to_string(), "1.2.3.1", "should match"); - assert_eq!(ri.nat.mapped_ips[1].to_string(), "1.2.3.2", "should match"); - assert_eq!(ri.nat.mapped_ips[2].to_string(), "1.2.3.3", "should match"); - - assert_eq!(3, ri.nat.local_ips.len(), "should be 3"); - assert_eq!( - ri.nat.local_ips[0].to_string(), - "192.168.0.1", - "should match" - ); - assert_eq!( - ri.nat.local_ips[1].to_string(), - "192.168.0.2", - "should match" - ); - assert_eq!( - ri.nat.local_ips[2].to_string(), - "192.168.0.3", - "should match" - ); - } - - Ok(()) -} - -#[tokio::test] -async fn test_router_failures_stop() -> Result<()> { - let mut r = Router::new(RouterConfig { - cidr: "1.2.3.0/24".to_owned(), - ..Default::default() - })?; - - let result = r.stop().await; - assert!(result.is_err(), "should fail"); - - Ok(()) -} - -#[tokio::test] -async fn test_router_failures_add_net() -> Result<()> { - let wan = Arc::new(Mutex::new(Router::new(RouterConfig { - cidr: "1.2.3.0/24".to_owned(), - ..Default::default() - })?)); - - let net = Net::new(Some(NetConfig { - static_ips: vec![ - "5.6.7.8".to_owned(), // out of parent router'c CIDR - ], - ..Default::default() - })); - - { - let nic = net.get_nic()?; - let mut w = wan.lock().await; - let result = w.add_net(nic).await; - assert!(result.is_err(), "should fail"); - } - - Ok(()) -} - -#[tokio::test] -async fn test_router_failures_add_router() -> Result<()> { - let r1 = Arc::new(Mutex::new(Router::new(RouterConfig { - cidr: "1.2.3.0/24".to_owned(), - ..Default::default() - })?)); - - let r2 = Arc::new(Mutex::new(Router::new(RouterConfig { - cidr: "192.168.0.0/24".to_owned(), - static_ips: vec![ - "5.6.7.8".to_owned(), // out of parent router'c CIDR - ], - ..Default::default() - })?)); - - { - let mut r = r1.lock().await; - let result = r.add_router(Arc::clone(&r2)).await; - assert!(result.is_err(), "should fail"); - } - - Ok(()) -} diff --git a/webrtc/CHANGELOG.md b/webrtc/CHANGELOG.md deleted file mode 100644 index 2947f5afd..000000000 --- a/webrtc/CHANGELOG.md +++ /dev/null @@ -1,236 +0,0 @@ -# webrtc-rs changelog - -## Unreleased - -## v0.7.0 - -* Added support for insecure/deprecated signature verification algorithms, opt in via `SettingsEngine::allow_insecure_verification_algorithm` [#342](https://github.com/webrtc-rs/webrtc/pull/342). -* Make RTCRtpCodecCapability::payloader_for_codec public API [#349](https://github.com/webrtc-rs/webrtc/pull/349). -* Fixed a panic in `calculate_rtt_ms` [#350](https://github.com/webrtc-rs/webrtc/pull/350). -* Fixed `TrackRemote` missing at least the first, sometimes more, RTP packet during probing. [#387](https://github.com/webrtc-rs/webrtc/pull/387) - -### Breaking changes - -* Change `RTCPeerConnection::on_track` callback signature to `|track: Arc, receiver: Arc, transceiver: Arc|` [#355](https://github.com/webrtc-rs/webrtc/pull/355). 
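
For illustration, a minimal sketch of a handler built against the three-argument `on_track` closure from #355; the helper name `register_on_track`, the imports, and the SSRC logging in the body are assumptions made for this example, not anything prescribed by the changelog entry itself.

```rust
use std::sync::Arc;

use webrtc::peer_connection::RTCPeerConnection;
use webrtc::rtp_transceiver::rtp_receiver::RTCRtpReceiver;
use webrtc::rtp_transceiver::RTCRtpTransceiver;
use webrtc::track::track_remote::TrackRemote;

/// Registers an `on_track` handler using the reworked three-argument closure.
/// The body is a placeholder that only logs the incoming track's SSRC.
fn register_on_track(pc: &RTCPeerConnection) {
    pc.on_track(Box::new(
        move |track: Arc<TrackRemote>,
              _receiver: Arc<RTCRtpReceiver>,
              _transceiver: Arc<RTCRtpTransceiver>| {
            Box::pin(async move {
                // Illustrative only: real handlers would read RTP from `track` here.
                log::info!("received remote track, ssrc={}", track.ssrc());
            })
        },
    ));
}
```

The receiver and transceiver arguments are simply ignored here; the point of the sketch is the closure shape, which now receives all three objects directly instead of optional values.
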
- -* Change `RTCRtpSender::new` signature to `|receive_mtu: usize, track: Option>, transport: Arc, media_engine: Arc, interceptor: Arc, start_paused: bool,|` [#377](https://github.com/webrtc-rs/webrtc/pull/377). - -* Change `API::new_rtp_sender` signature to `|&self, track: Option>, transport: Arc, interceptor: Arc,|` [#377](https://github.com/webrtc-rs/webrtc/pull/377). - -* Change `RTCRtpTransceiver::sender` signature to `|&self| -> Arc` [#377](https://github.com/webrtc-rs/webrtc/pull/377). - -* Change `RTCRtpTransceiver::set_sender_track` signature to `|self: &Arc, sender: Arc, track: Option>,|` [#377](https://github.com/webrtc-rs/webrtc/pull/377). - -* Change `RTCRtpTransceiver::set_sender` signature to `|self: &Arc, s: Arc|` [#377](https://github.com/webrtc-rs/webrtc/pull/377). - -* Change `RTCRtpTransceiver::receiver` signature to `|&self| -> Arc` [#377](https://github.com/webrtc-rs/webrtc/pull/377). - -* Change `RTCRtpTransceiver::set_receiver` signature to `|&self, r: Arc|` [#377](https://github.com/webrtc-rs/webrtc/pull/377). - -* Change `RTCPeerConnection::add_transceiver_from_kind` signature to `|&self, kind: RTPCodecType, init: Option,|`, `RTCRtpTransceiver::RTCRtpSender` ัreated without a track [#377](https://github.com/webrtc-rs/webrtc/pull/377). - -* Change `RTCPeerConnection::add_transceiver_from_track` signature to `|&self, track: Arc, init: Option,|` [#377](https://github.com/webrtc-rs/webrtc/pull/377). - -* Change `RTCPeerConnection::mid` return signature to `Option` [#375](https://github.com/webrtc-rs/webrtc/pull/375). - -* Make functions non-async [#402](https://github.com/webrtc-rs/webrtc/pull/402): - - `MediaEngine`: - - `get_codecs_by_kind`; - - `get_rtp_parameters_by_kind`. - - `RTCRtpTransceiver`: - - `sender`; - - `set_sender`; - - `receiver`. - - `RTPReceiverInternal`: - - `set_transceiver_codecs`; - - `get_codecs`. - - `RTCRtpSender`: - - `set_rtp_transceiver`; - - `has_sent`. - - `TrackRemote`: - - `id`; - - `set_id`; - - `stream_id`; - - `set_stream_id`; - - `msid`; - - `codec`; - - `set_codec`; - - `params`; - - `set_params`; - - `onmute`; - - `onunmute`. - -* Change `RTPReader::read` signature to `|&self, buf: &mut [u8], attributes: &Attributes| -> Result<(rtp::packet::Packet, Attributes)>` [#450](https://github.com/webrtc-rs/webrtc/pull/450). - -* Change `RTCPReader::read` signature to `|&self, buf: &mut [u8], attributes: &Attributes| -> Result<(Vec>, Attributes)>` [#450](https://github.com/webrtc-rs/webrtc/pull/450). - -## v0.6.0 - -* Added more stats to `RemoteInboundRTPStats` and `RemoteOutboundRTPStats` [#282](https://github.com/webrtc-rs/webrtc/pull/282) by [@k0nserv](https://github.com/k0nserv). -* Don't register `video/rtx` codecs in `MediaEngine::register_default_codecs`. These weren't actually support and prevented RTX in the existing RTP stream from being used. Long term we should support RTX via this method, this is tracked in [#295](https://github.com/webrtc-rs/webrtc/issues/295). [#294 Remove video/rtx codecs](https://github.com/webrtc-rs/webrtc/pull/294) contributed by [k0nserv](https://github.com/k0nserv) -* Add IP filter to WebRTC `SettingEngine` [#306](https://github.com/webrtc-rs/webrtc/pull/306) -* Stop sequence numbers from increasing in `TrackLocalStaticSample` while the bound `RTCRtpSender` have -directions that should not send. 
[#316](https://github.com/webrtc-rs/webrtc/pull/316) -* Add support for a mime type "audio/telephone-event" (rfc4733) [#322](https://github.com/webrtc-rs/webrtc/pull/322) -* Fixed a panic that would sometimes happen when collecting stats. [#327](https://github.com/webrtc-rs/webrtc/pull/327) by [@k0nserv](https://github.com/k0nserv). -* Added new extension marshaller/unmarshaller for VideoOrientation, and made marshallers serializable via serde [#331](https://github.com/webrtc-rs/webrtc/pull/331) [#332](https://github.com/webrtc-rs/webrtc/pull/332) -* Updated minimum rust version to `1.60.0` -* Added a new `write_rtp_with_extensions` method to `TrackLocalStaticSample` and `TrackLocalStaticRTP`. [#336](https://github.com/webrtc-rs/webrtc/pull/336) by [@k0nserv](https://github.com/k0nserv). -* Added a new `sample_writer` helper to `TrackLocalStaticSample`. [#336](https://github.com/webrtc-rs/webrtc/pull/336) by [@k0nserv](https://github.com/k0nserv). -* Increased minimum versions for sub-dependencies: - * `webrtc-data` version to `0.6.0`. - * `webrtc-ice` version to `0.9.0`. - * `webrtc-media` version to `0.5.0`. - * `webrtc-sctp` version to `0.7.0`. - * `webrtc-util` version to `0.7.0`. - -### Breaking changes - -* Allowed one single direction for extmap matching. [#321](https://github.com/webrtc-rs/webrtc/pull/321). API change for `MediaEngine::register_header_extension`. -* Removed support for Plan-B. All major implementations of WebRTC now support unified and continuing support for plan-b is an undue maintenance burden when unified can be used. See [โ€œUnified Planโ€ Transition Guide (JavaScript)](https://docs.google.com/document/d/1-ZfikoUtoJa9k-GZG1daN0BU3IjIanQ_JSscHxQesvU/) for an overview of the changes required to migrate. [#320](https://github.com/webrtc-rs/webrtc/pull/320) by [@algesten](https://github.com/algesten). -* Removed 2nd argument from `RTCCertificate::from_pem` and guard it with `pem` feature [#333] -* Renamed `RTCCertificate::pem` to `serialize_pem` and guard it with `pem` feature [#333] -* Removed `RTCCertificate::expires` [#333] -* `RTCCertificate::get_fingerprints` no longer returns `Result` [#333] -* Make functions non-async [#338](https://github.com/webrtc-rs/webrtc/pull/338): - - `RTCDataChannel`: - - `on_open`; - - `on_close`; - - `on_message`; - - `on_error`. - - `RTCDtlsTransport::on_state_change`; - - `RTCIceCandidate::to_json`; - - `RTCIceGatherer`: - - `on_local_candidate`; - - `on_state_change`; - - `on_gathering_complete`. - - `RTCIceTransport`: - - `get_selected_candidate_pair`; - - `on_selected_candidate_pair_change`; - - `on_connection_state_change`. - - `RTCPeerConnection`: - - `on_signaling_state_change`; - - `on_data_channel`; - - `on_negotiation_needed`; - - `on_ice_candidate`; - - `on_ice_gathering_state_change`; - - `on_track`; - - `on_ice_connection_state_change`; - - `on_peer_connection_state_change`. - - `RTCSctpTransport`: - - `on_error`; - - `on_data_channel`; - - `on_data_channel_opened`. - -[#333]: https://github.com/webrtc-rs/webrtc/pull/333 - -## v0.5.1 - -* Promote agent lock in ice_gather.rs create_agent() to top level of the function to avoid a race condition. [#290 Promote create_agent lock to top of function, to avoid race condition](https://github.com/webrtc-rs/webrtc/pull/290) contributed by [efer-ms](https://github.com/efer-ms) - -## v0.5.0 - -### Changes - -#### Breaking changes - -* The serialized format for `RTCIceCandidateInit` has changed to match what the specification i.e. keys are camelCase. 
[#153 Make RTCIceCandidateInit conform to WebRTC spec](https://github.com/webrtc-rs/webrtc/pull/153) contributed by [jmatss](https://github.com/jmatss). -* Improved robustness when proposing RTP extension IDs and handling of collisions in these. This change is only breaking if you have assumed anything about the nature of these extension IDs. [#154 Fix RTP extension id collision](https://github.com/webrtc-rs/webrtc/pull/154) contributed by [k0nserv](https://github.com/k0nserv) -* Transceivers will now not stop when either or both directions are disabled. That is, applying and SDP with `a=inactive` will not stop the transceiver, instead attached senders and receivers will pause. A transceiver can be resurrected by setting direction back to e.g. `a=sendrecv`. The desired direction can be controlled with the newly introduced public method `set_direction` on `RTCRtpTransceiver`. - * [#201 Handle inactive transceivers more correctly](https://github.com/webrtc-rs/webrtc/pull/201) contributed by [k0nserv](https://github.com/k0nserv) - * [#210 Rework transceiver direction support further](https://github.com/webrtc-rs/webrtc/pull/210) contributed by [k0nserv](https://github.com/k0nserv) - * [#214 set_direction add missing Send + Sync bound](https://github.com/webrtc-rs/webrtc/pull/214) contributed by [algesten](https://github.com/algesten) - * [#213 set_direction add missing Sync bound](https://github.com/webrtc-rs/webrtc/pull/213) contributed by [algesten](https://github.com/algesten) - * [#212 Public RTCRtpTransceiver::set_direction](https://github.com/webrtc-rs/webrtc/pull/212) contributed by [algesten](https://github.com/algesten) - * [#268 Fix current direction update when applying answer](https://github.com/webrtc-rs/webrtc/pull/268) contributed by [k0nserv](https://github.com/k0nserv) - * [#236 Pause RTP writing if direction indicates it](https://github.com/webrtc-rs/webrtc/pull/236) contributed by [algesten](https://github.com/algesten) -* Generated the `a=msid` line for `m=` line sections according to the specification. This might be break remote peers that relied on the previous, incorrect, behaviour. This also fixes a bug where an endless negotiation loop could happen. [#217 Correct msid handling for RtpSender](https://github.com/webrtc-rs/webrtc/pull/217) contributed by [k0nserv](https://github.com/k0nserv) -* Improve data channel id negotiation. We've slightly adjust the public interface for creating pre-negotiated data channels. Instead of a separate `negotiated: Option` and `id: Option` in `RTCDataChannelInit` there's now a more idiomatic `negotiated: Option`. If you have a pre-negotiated data channel simply set `negotiated: Some(id)` when creating the data channel. - * [#237 Fix datachannel id setting for 0.5.0 release](https://github.com/webrtc-rs/webrtc/pull/237) contributed by [stuqdog](https://github.com/stuqdog) - * [#229 Revert "base id updating on whether it's been negotiated, not on its โ€ฆ](https://github.com/webrtc-rs/webrtc/pull/229) contributed by [melekes](https://github.com/melekes) - - * [#226 base id updating on whether it's been finalized, not on its value](https://github.com/webrtc-rs/webrtc/pull/226) contributed by [stuqdog](https://github.com/stuqdog) - - -#### Other improvememnts - -We made various improvements and fixes since 0.4.0, including merging all subcrates into a single git repo. The old crate repos are archived and all development will now happen in https://github.com/webrtc-rs/webrtc/. 
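
A minimal sketch of the pre-negotiated data channel creation described in the data channel entry above, using the single `negotiated: Some(id)` field from #237. The helper name, the `"control"` label, and the id `5` are hypothetical, and the remaining `RTCDataChannelInit` fields are assumed to be defaultable.

```rust
use webrtc::data_channel::data_channel_init::RTCDataChannelInit;
use webrtc::peer_connection::RTCPeerConnection;

/// Creates a pre-negotiated data channel; both peers must create a channel
/// with the same id (agreed out of band) for the pair to match up.
async fn create_negotiated_channel(pc: &RTCPeerConnection) -> Result<(), webrtc::Error> {
    let init = RTCDataChannelInit {
        negotiated: Some(5), // hypothetical id negotiated out of band
        ..Default::default()
    };
    let _dc = pc.create_data_channel("control", Some(init)).await?;
    Ok(())
}
```

Compared with the earlier `negotiated: Option<bool>` plus `id: Option<u16>` pair, the single optional id makes the "pre-negotiated" and "which id" questions impossible to answer inconsistently.
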
- -* We now provide stats reporting via the standardized `RTCPeerConnection::get_stats` method. - * [#277 Implement Remote Inbound Stats](https://github.com/webrtc-rs/webrtc/pull/277) contributed by [k0nserv](https://github.com/k0nserv) - * [#220 Make stats types pub so they can be used directly](https://github.com/webrtc-rs/webrtc/pull/220) contributed by [k0nserv](https://github.com/k0nserv) - * [#225 Add RTP Stats to stats report](https://github.com/webrtc-rs/webrtc/pull/225) contributed by [k0nserv](https://github.com/k0nserv) - * [#189 Serialize stats](https://github.com/webrtc-rs/webrtc/pull/189) contributed by [sax](https://github.com/sax) - * [#180 Get stats from peer connection](https://github.com/webrtc-rs/webrtc/pull/180) contributed by [sax](https://github.com/sax) - -* [#278 Fix async-global-executor](https://github.com/webrtc-rs/webrtc/pull/278) contributed by [k0nserv](https://github.com/k0nserv) -* [#276 relax regex version requirement](https://github.com/webrtc-rs/webrtc/pull/276) contributed by [melekes](https://github.com/melekes) -* [#244 Update README.md instructions after monorepo merge](https://github.com/webrtc-rs/webrtc/pull/244) contributed by [k0nserv](https://github.com/k0nserv) -* [#241 move profile to workspace](https://github.com/webrtc-rs/webrtc/pull/241) contributed by [xnorpx](https://github.com/xnorpx) -* [#240 Increase timeout to "fix" test breaking](https://github.com/webrtc-rs/webrtc/pull/240) contributed by [algesten](https://github.com/algesten) -* [#239 One repo (again)](https://github.com/webrtc-rs/webrtc/pull/239) contributed by [algesten](https://github.com/algesten) -* [#234 Fix recent clippy lints](https://github.com/webrtc-rs/webrtc/pull/234) contributed by [k0nserv](https://github.com/k0nserv) -* [#224 update call to DataChannel::accept as per data pr #14](https://github.com/webrtc-rs/webrtc/pull/224) contributed by [melekes](https://github.com/melekes) -* [#223 dtls_transport: always set remote certificate](https://github.com/webrtc-rs/webrtc/pull/223) contributed by [melekes](https://github.com/melekes) -* [#216 Lower case mime types for comparison in fmpt lines](https://github.com/webrtc-rs/webrtc/pull/216) contributed by [k0nserv](https://github.com/k0nserv) -* [#211 Helper to trigger negotiation_needed](https://github.com/webrtc-rs/webrtc/pull/211) contributed by [algesten](https://github.com/algesten) -* [#209 MID generator feature](https://github.com/webrtc-rs/webrtc/pull/209) contributed by [algesten](https://github.com/algesten) -* [#208 update deps + loosen some requirements](https://github.com/webrtc-rs/webrtc/pull/208) contributed by [melekes](https://github.com/melekes) -* [#205 data_channel: handle stream EOF](https://github.com/webrtc-rs/webrtc/pull/205) contributed by [melekes](https://github.com/melekes) -* [#204 [peer_connection] allow persistent certificates](https://github.com/webrtc-rs/webrtc/pull/204) contributed by [melekes](https://github.com/melekes) -* [#202 bugfix-Udp connection not close (reopen #174) #195](https://github.com/webrtc-rs/webrtc/pull/202) contributed by [shiqifeng2000](https://github.com/shiqifeng2000) -* [#199 Upgrade ICE to 0.7.0](https://github.com/webrtc-rs/webrtc/pull/199) contributed by [k0nserv](https://github.com/k0nserv) -* [#194 Add AV1 MimeType and RtpCodecParameters](https://github.com/webrtc-rs/webrtc/pull/194) contributed by [billylindeman](https://github.com/billylindeman) -* [#188 Improve operations debuggability](https://github.com/webrtc-rs/webrtc/pull/188) contributed by 
[k0nserv](https://github.com/k0nserv) -* [#187 Fix SDP for rejected tracks to conform to RFC](https://github.com/webrtc-rs/webrtc/pull/187) contributed by [k0nserv](https://github.com/k0nserv) -* [#185 Adding some debug and display traits](https://github.com/webrtc-rs/webrtc/pull/185) contributed by [sevensidedmarble](https://github.com/sevensidedmarble) -* [#179 Fix example names in README](https://github.com/webrtc-rs/webrtc/pull/179) contributed by [ethagnawl](https://github.com/ethagnawl) -* [#176 Time overflow armv7 workaround](https://github.com/webrtc-rs/webrtc/pull/176) contributed by [frjol](https://github.com/frjol) -* [#171 close DTLS conn upon err](https://github.com/webrtc-rs/webrtc/pull/171) contributed by [melekes](https://github.com/melekes) -* [#170 always start sctp](https://github.com/webrtc-rs/webrtc/pull/170) contributed by [melekes](https://github.com/melekes) -* [#167 Add offer/answer/pranswer constructors for RTCSessionDescription](https://github.com/webrtc-rs/webrtc/pull/167) contributed by [sax](https://github.com/sax) - -#### Subcrate updates - -The various sub-crates have been updated as follows: - -* util: 0.5.3 => 0.6.0 -* sdp: 0.5.1 => 0.5.2 -* mdns: 0.4.2 => 0.5.0 -* stun: 0.4.2 => 0.4.3 -* turn: 0.5.3 => 0.6.0 -* ice: 0.6.4 => 0.8.0 -* dtls: 0.5.2 => 0.6.0 -* rtcp: 0.6.5 => 0.7.0 -* rtp: 0.6.5 => 0.6.7 -* srtp: 0.8.9 => 0.9.0 -* scpt: 0.4.3 => 0.6.1 -* data: 0.3.3 => 0.5.0 -* interceptor: 0.7.6 => 0.8.0 -* media: 0.4.5 => 0.4.7 - -Their respective change logs are found in the old, now archived, repositories and within their respective `CHANGELOG.md` files in the monorepo. - -### Contributors - -A big thanks to all the contributors that have made this release happen: - -* [morajabi](https://github.com/morajabi) -* [sax](https://github.com/sax) -* [ethagnawl](https://github.com/ethagnawl) -* [xnorpx](https://github.com/xnorpx) -* [frjol](https://github.com/frjol) -* [algesten](https://github.com/algesten) -* [shiqifeng2000](https://github.com/shiqifeng2000) -* [billylindeman](https://github.com/billylindeman) -* [sevensidedmarble](https://github.com/sevensidedmarble) -* [k0nserv](https://github.com/k0nserv) -* [stuqdog](https://github.com/stuqdog) -* [neonphog](https://github.com/neonphog) -* [melekes](https://github.com/melekes) -* [jmatss](https://github.com/jmatss) - - -## Prior to 0.5.0 - -Before 0.5.0 there was no changelog, previous changes are sometimes, but not always, available in the [GitHub Releases](https://github.com/webrtc-rs/webrtc/releases). 
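
A small usage sketch for the `RTCPeerConnection::get_stats` method referenced in the stats entries above; the `reports` field name and the serde-based JSON serialization (per #189) are recalled assumptions about this crate version and may differ slightly between releases.

```rust
use webrtc::peer_connection::RTCPeerConnection;

/// Fetches the standardized stats report and dumps it as pretty-printed JSON.
/// Assumes the report exposes a serializable map of per-entity stats.
async fn dump_stats(pc: &RTCPeerConnection) -> Result<(), Box<dyn std::error::Error>> {
    let report = pc.get_stats().await;
    let json = serde_json::to_string_pretty(&report.reports)?;
    println!("{json}");
    Ok(())
}
```
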
diff --git a/webrtc/Cargo.toml b/webrtc/Cargo.toml deleted file mode 100644 index 25cb56210..000000000 --- a/webrtc/Cargo.toml +++ /dev/null @@ -1,71 +0,0 @@ -[package] -name = "webrtc" -version = "0.11.0" -authors = ["Rain Liu "] -edition = "2021" -description = "A pure Rust implementation of WebRTC API" -license = "MIT OR Apache-2.0" -documentation = "https://docs.rs/webrtc" -homepage = "https://webrtc.rs" -repository = "https://github.com/webrtc-rs/webrtc" -readme = "../README.md" - -[dependencies] -data = { version = "0.9.0", path = "../data", package = "webrtc-data" } -dtls = { version = "0.10.0", path = "../dtls", package = "webrtc-dtls" } -ice = { version = "0.11.0", path = "../ice", package = "webrtc-ice" } -interceptor = { version = "0.12.0", path = "../interceptor" } -mdns = { version = "0.7.0", path = "../mdns", package = "webrtc-mdns" } -media = { version = "0.8.0", path = "../media", package = "webrtc-media" } -rtcp = { version = "0.11.0", path = "../rtcp" } -rtp = { version = "0.11.0", path = "../rtp" } -sctp = { version = "0.10.0", path = "../sctp", package = "webrtc-sctp" } -sdp = { version = "0.6.2", path = "../sdp" } -srtp = { version = "0.13.0", path = "../srtp", package = "webrtc-srtp" } -stun = { version = "0.6.0", path = "../stun" } -turn = { version = "0.8.0", path = "../turn" } -util = { version = "0.9.0", path = "../util", package = "webrtc-util" } - -arc-swap = "1" -tokio = { version = "1.32.0", features = [ - "fs", - "io-util", - "io-std", - "macros", - "net", - "parking_lot", - "rt", - "rt-multi-thread", - "sync", - "time", -] } -log = "0.4" -async-trait = "0.1" -serde = { version = "1", features = ["derive"] } -serde_json = "1" -rand = "0.8" -bytes = "1" -thiserror = "1" -waitgroup = "0.1" -regex = "1.9.5" -smol_str = { version = "0.2", features = ["serde"] } -url = "2" -rustls = { version = "0.23", default-features = false, features = ["std", "ring"] } -rcgen = { version = "0.13", features = ["pem", "x509-parser"]} -ring = "0.17" -sha2 = "0.10" -lazy_static = "1.4" -hex = "0.4" -pem = { version = "3", optional = true } -time = "0.3" -cfg-if = "1" -portable-atomic = "1.6" - -[dev-dependencies] -tokio-test = "0.4" -env_logger = "0.10" - -[features] -pem = ["dep:pem", "dtls/pem"] -openssl = ["srtp/openssl"] -vendored-openssl = ["srtp/vendored-openssl"] diff --git a/webrtc/src/api/api_test.rs b/webrtc/src/api/api_test.rs deleted file mode 100644 index c6b1aa54a..000000000 --- a/webrtc/src/api/api_test.rs +++ /dev/null @@ -1,25 +0,0 @@ -use super::*; - -#[test] -fn test_new_api() -> Result<()> { - let mut s = SettingEngine::default(); - s.detach_data_channels(); - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - - let api = APIBuilder::new() - .with_setting_engine(s) - .with_media_engine(m) - .build(); - - assert!( - api.setting_engine.detach.data_channels, - "Failed to set settings engine" - ); - assert!( - !api.media_engine.audio_codecs.is_empty(), - "Failed to set media engine" - ); - - Ok(()) -} diff --git a/webrtc/src/api/interceptor_registry/interceptor_registry_test.rs b/webrtc/src/api/interceptor_registry/interceptor_registry_test.rs deleted file mode 100644 index 277bce5b6..000000000 --- a/webrtc/src/api/interceptor_registry/interceptor_registry_test.rs +++ /dev/null @@ -1,278 +0,0 @@ -/*TODO: -use super::*; -use crate::api::APIBuilder; -use crate::peer_connection::configuration::RTCConfiguration; - -use bytes::Bytes; -use std::future::Future; -use std::pin::Pin; -use std::sync::Arc; - -use interceptor::mock::mock_builder::MockBuilder; 
-use interceptor::mock::mock_interceptor::MockInterceptor; -use interceptor::stream_info::StreamInfo; -use interceptor::{Attributes, Interceptor, RTPWriter, RTPWriterFn}; - -// E2E test of the features of Interceptors -// * Assert an extension can be set on an outbound packet -// * Assert an extension can be read on an outbound packet -// * Assert that attributes set by an interceptor are returned to the Reader -#[tokio::test] -async fn test_peer_connection_interceptor() -> Result<()> { - let create_pc = || async { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - - let mut ir = Registry::new(); - - let BindLocalStreamFn = |info: &StreamInfo, - writer: Arc| - -> Pin< - Box> + Send + Sync>, - > { - let writer2 = Arc::clone(&writer); - Box::pin(async move { - Arc::new(RTPWriterFn(Box::new( - move |in_pkt: &rtp::packet::Packet, - attributes: &Attributes| - -> Pin< - Box< - dyn Future> - + Send - + Sync, - >, - > { - let writer3 = Arc::clone(&writer2); - let a = attributes.clone(); - // set extension on outgoing packet - let mut out_pkt = in_pkt.clone(); - out_pkt.header.extension = true; - out_pkt.header.extension_profile = 0xBEDE; - - Box::pin(async move { - out_pkt - .header - .set_extension(2, Bytes::from_static(b"foo"))?; - //writer3.write(&out_pkt, &a).await - Ok(0) - }) - }, - ))) as Arc - }) - }; - - BindRemoteStreamFn: func(_ *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader { - return interceptor.RTPReaderFunc(func(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) { - if a == nil { - a = interceptor.Attributes{} - } - - a.Set("attribute", "value") - return reader.Read(b, a) - }) - }, - let mock_builder = Box::new(MockBuilder { - build: - Box::new( - |_: &str| -> std::result::Result< - Arc, - interceptor::Error, - > { - Ok(Arc::new(MockInterceptor { - ..Default::default() - })) - }, - ), - }); - let mock_builder = MockBuilder::new( - |_: &str| -> std::result::Result< - Arc, - interceptor::Error, - > { - Ok(Arc::new(MockInterceptor { - ..Default::default() - })) - }, - ); - ir.add(Box::new(mock_builder)); - - let api = APIBuilder::new() - .with_media_engine(m) - .with_interceptor_registry(ir) - .build(); - api.new_peer_connection(RTCConfiguration::default()).await - }; - - let offerer = create_pc().await?; - let answerer = create_pc().await?; - - track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion") - assert.NoError(t, err) - - _, err = offerer.AddTrack(track) - assert.NoError(t, err) - - seenRTP, seenRTPCancel := context.WithCancel(context.Background()) - answerer.OnTrack(func(track *TrackRemote, receiver *RTPReceiver) { - p, attributes, readErr := track.ReadRTP() - assert.NoError(t, readErr) - - assert.Equal(t, p.Extension, true) - assert.Equal(t, "foo", string(p.GetExtension(2))) - assert.Equal(t, "value", attributes.Get("attribute")) - - seenRTPCancel() - }) - - assert.NoError(t, signalPair(offerer, answerer)) - - func() { - ticker := time.NewTicker(time.Millisecond * 20) - for { - select { - case <-seenRTP.Done(): - return - case <-ticker.C: - assert.NoError(t, track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second})) - } - } - }() - - closePairNow(t, offerer, answerer) - - Ok(()) -} - -func Test_Interceptor_BindUnbind(t *testing.T) { - lim := test.TimeOut(time.Second * 10) - defer lim.Stop() - - report := test.CheckRoutines(t) - defer report() - - m := &MediaEngine{} - assert.NoError(t, m.RegisterDefaultCodecs()) - - var ( - cntBindRTCPReader 
uint32 - cntBindRTCPWriter uint32 - cntBindLocalStream uint32 - cntUnbindLocalStream uint32 - cntBindRemoteStream uint32 - cntUnbindRemoteStream uint32 - cntClose uint32 - ) - mockInterceptor := &mock_interceptor.Interceptor{ - BindRTCPReaderFn: func(reader interceptor.RTCPReader) interceptor.RTCPReader { - atomic.AddUint32(&cntBindRTCPReader, 1) - return reader - }, - BindRTCPWriterFn: func(writer interceptor.RTCPWriter) interceptor.RTCPWriter { - atomic.AddUint32(&cntBindRTCPWriter, 1) - return writer - }, - BindLocalStreamFn: func(i *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter { - atomic.AddUint32(&cntBindLocalStream, 1) - return writer - }, - UnbindLocalStreamFn: func(i *interceptor.StreamInfo) { - atomic.AddUint32(&cntUnbindLocalStream, 1) - }, - BindRemoteStreamFn: func(i *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader { - atomic.AddUint32(&cntBindRemoteStream, 1) - return reader - }, - UnbindRemoteStreamFn: func(i *interceptor.StreamInfo) { - atomic.AddUint32(&cntUnbindRemoteStream, 1) - }, - CloseFn: func() error { - atomic.AddUint32(&cntClose, 1) - return nil - }, - } - ir := &interceptor.Registry{} - ir.Add(&mock_interceptor.Factory{ - NewInterceptorFn: func(_ string) (interceptor.Interceptor, error) { return mockInterceptor, nil }, - }) - - sender, receiver, err := NewAPI(WithMediaEngine(m), WithInterceptorRegistry(ir)).newPair(Configuration{}) - assert.NoError(t, err) - - track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion") - assert.NoError(t, err) - - _, err = sender.AddTrack(track) - assert.NoError(t, err) - - receiverReady, receiverReadyFn := context.WithCancel(context.Background()) - receiver.OnTrack(func(track *TrackRemote, _ *RTPReceiver) { - _, _, readErr := track.ReadRTP() - assert.NoError(t, readErr) - receiverReadyFn() - }) - - assert.NoError(t, signalPair(sender, receiver)) - - ticker := time.NewTicker(time.Millisecond * 20) - defer ticker.Stop() - func() { - for { - select { - case <-receiverReady.Done(): - return - case <-ticker.C: - // Send packet to make receiver track actual creates RTPReceiver. - assert.NoError(t, track.WriteSample(media.Sample{Data: []byte{0xAA}, Duration: time.Second})) - } - } - }() - - closePairNow(t, sender, receiver) - - // Bind/UnbindLocal/RemoteStream should be called from one side. - if cnt := atomic.LoadUint32(&cntBindLocalStream); cnt != 1 { - t.Errorf("BindLocalStreamFn is expected to be called once, but called %d times", cnt) - } - if cnt := atomic.LoadUint32(&cntUnbindLocalStream); cnt != 1 { - t.Errorf("UnbindLocalStreamFn is expected to be called once, but called %d times", cnt) - } - if cnt := atomic.LoadUint32(&cntBindRemoteStream); cnt != 1 { - t.Errorf("BindRemoteStreamFn is expected to be called once, but called %d times", cnt) - } - if cnt := atomic.LoadUint32(&cntUnbindRemoteStream); cnt != 1 { - t.Errorf("UnbindRemoteStreamFn is expected to be called once, but called %d times", cnt) - } - - // BindRTCPWriter/Reader and Close should be called from both side. 
- if cnt := atomic.LoadUint32(&cntBindRTCPWriter); cnt != 2 { - t.Errorf("BindRTCPWriterFn is expected to be called twice, but called %d times", cnt) - } - if cnt := atomic.LoadUint32(&cntBindRTCPReader); cnt != 2 { - t.Errorf("BindRTCPReaderFn is expected to be called twice, but called %d times", cnt) - } - if cnt := atomic.LoadUint32(&cntClose); cnt != 2 { - t.Errorf("CloseFn is expected to be called twice, but called %d times", cnt) - } -} - -func Test_InterceptorRegistry_Build(t *testing.T) { - registryBuildCount := 0 - - ir := &interceptor.Registry{} - ir.Add(&mock_interceptor.Factory{ - NewInterceptorFn: func(_ string) (interceptor.Interceptor, error) { - registryBuildCount++ - return &interceptor.NoOp{}, nil - }, - }) - - peerConnectionA, err := NewAPI(WithInterceptorRegistry(ir)).NewPeerConnection(Configuration{}) - assert.NoError(t, err) - - peerConnectionB, err := NewAPI(WithInterceptorRegistry(ir)).NewPeerConnection(Configuration{}) - assert.NoError(t, err) - - assert.Equal(t, 2, registryBuildCount) - closePairNow(t, peerConnectionA, peerConnectionB) -} -*/ diff --git a/webrtc/src/api/interceptor_registry/mod.rs b/webrtc/src/api/interceptor_registry/mod.rs deleted file mode 100644 index ca4f54904..000000000 --- a/webrtc/src/api/interceptor_registry/mod.rs +++ /dev/null @@ -1,171 +0,0 @@ -#[cfg(test)] -mod interceptor_registry_test; - -use interceptor::nack::generator::Generator; -use interceptor::nack::responder::Responder; -use interceptor::registry::Registry; -use interceptor::report::receiver::ReceiverReport; -use interceptor::report::sender::SenderReport; -use interceptor::twcc::receiver::Receiver; -use interceptor::twcc::sender::Sender; - -use crate::api::media_engine::MediaEngine; -use crate::error::Result; -use crate::rtp_transceiver::rtp_codec::{RTCRtpHeaderExtensionCapability, RTPCodecType}; -use crate::rtp_transceiver::{RTCPFeedback, TYPE_RTCP_FB_TRANSPORT_CC}; - -/// register_default_interceptors will register some useful interceptors. -/// If you want to customize which interceptors are loaded, you should copy the -/// code from this method and remove unwanted interceptors. -pub fn register_default_interceptors( - mut registry: Registry, - media_engine: &mut MediaEngine, -) -> Result { - registry = configure_nack(registry, media_engine); - - registry = configure_rtcp_reports(registry); - - registry = configure_twcc_receiver_only(registry, media_engine)?; - - Ok(registry) -} - -/// configure_rtcp_reports will setup everything necessary for generating Sender and Receiver Reports -pub fn configure_rtcp_reports(mut registry: Registry) -> Registry { - let receiver = Box::new(ReceiverReport::builder()); - let sender = Box::new(SenderReport::builder()); - registry.add(receiver); - registry.add(sender); - registry -} - -/// configure_nack will setup everything necessary for handling generating/responding to nack messages. 
-pub fn configure_nack(mut registry: Registry, media_engine: &mut MediaEngine) -> Registry { - media_engine.register_feedback( - RTCPFeedback { - typ: "nack".to_owned(), - parameter: "".to_owned(), - }, - RTPCodecType::Video, - ); - media_engine.register_feedback( - RTCPFeedback { - typ: "nack".to_owned(), - parameter: "pli".to_owned(), - }, - RTPCodecType::Video, - ); - - let generator = Box::new(Generator::builder()); - let responder = Box::new(Responder::builder()); - registry.add(responder); - registry.add(generator); - registry -} - -/// configure_twcc will setup everything necessary for adding -/// a TWCC header extension to outgoing RTP packets and generating TWCC reports. -pub fn configure_twcc(mut registry: Registry, media_engine: &mut MediaEngine) -> Result { - media_engine.register_feedback( - RTCPFeedback { - typ: TYPE_RTCP_FB_TRANSPORT_CC.to_owned(), - ..Default::default() - }, - RTPCodecType::Video, - ); - media_engine.register_header_extension( - RTCRtpHeaderExtensionCapability { - uri: sdp::extmap::TRANSPORT_CC_URI.to_owned(), - }, - RTPCodecType::Video, - None, - )?; - - media_engine.register_feedback( - RTCPFeedback { - typ: TYPE_RTCP_FB_TRANSPORT_CC.to_owned(), - ..Default::default() - }, - RTPCodecType::Audio, - ); - media_engine.register_header_extension( - RTCRtpHeaderExtensionCapability { - uri: sdp::extmap::TRANSPORT_CC_URI.to_owned(), - }, - RTPCodecType::Audio, - None, - )?; - - let sender = Box::new(Sender::builder()); - let receiver = Box::new(Receiver::builder()); - registry.add(sender); - registry.add(receiver); - Ok(registry) -} - -/// configure_twcc_sender will setup everything necessary for adding -/// a TWCC header extension to outgoing RTP packets. This will allow the remote peer to generate TWCC reports. -pub fn configure_twcc_sender_only( - mut registry: Registry, - media_engine: &mut MediaEngine, -) -> Result { - media_engine.register_header_extension( - RTCRtpHeaderExtensionCapability { - uri: sdp::extmap::TRANSPORT_CC_URI.to_owned(), - }, - RTPCodecType::Video, - None, - )?; - - media_engine.register_header_extension( - RTCRtpHeaderExtensionCapability { - uri: sdp::extmap::TRANSPORT_CC_URI.to_owned(), - }, - RTPCodecType::Audio, - None, - )?; - - let sender = Box::new(Sender::builder()); - registry.add(sender); - Ok(registry) -} - -/// configure_twcc_receiver will setup everything necessary for generating TWCC reports. 
-pub fn configure_twcc_receiver_only( - mut registry: Registry, - media_engine: &mut MediaEngine, -) -> Result { - media_engine.register_feedback( - RTCPFeedback { - typ: TYPE_RTCP_FB_TRANSPORT_CC.to_owned(), - ..Default::default() - }, - RTPCodecType::Video, - ); - media_engine.register_header_extension( - RTCRtpHeaderExtensionCapability { - uri: sdp::extmap::TRANSPORT_CC_URI.to_owned(), - }, - RTPCodecType::Video, - None, - )?; - - media_engine.register_feedback( - RTCPFeedback { - typ: TYPE_RTCP_FB_TRANSPORT_CC.to_owned(), - ..Default::default() - }, - RTPCodecType::Audio, - ); - media_engine.register_header_extension( - RTCRtpHeaderExtensionCapability { - uri: sdp::extmap::TRANSPORT_CC_URI.to_owned(), - }, - RTPCodecType::Audio, - None, - )?; - - let receiver = Box::new(Receiver::builder()); - registry.add(receiver); - Ok(registry) -} diff --git a/webrtc/src/api/media_engine/media_engine_test.rs b/webrtc/src/api/media_engine/media_engine_test.rs deleted file mode 100644 index cbf3b503b..000000000 --- a/webrtc/src/api/media_engine/media_engine_test.rs +++ /dev/null @@ -1,780 +0,0 @@ -use std::io::Cursor; - -use regex::Regex; - -use super::*; -use crate::api::media_engine::MIME_TYPE_OPUS; -use crate::api::APIBuilder; -use crate::peer_connection::configuration::RTCConfiguration; - -#[tokio::test] -async fn test_opus_case() -> Result<()> { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - let pc = api.new_peer_connection(RTCConfiguration::default()).await?; - pc.add_transceiver_from_kind(RTPCodecType::Audio, None) - .await?; - - let offer = pc.create_offer(None).await?; - - let re = Regex::new(r"(?m)^a=rtpmap:\d+ opus/48000/2").unwrap(); - assert!(re.is_match(offer.sdp.as_str())); - - pc.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_video_case() -> Result<()> { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - let pc = api.new_peer_connection(RTCConfiguration::default()).await?; - pc.add_transceiver_from_kind(RTPCodecType::Video, None) - .await?; - - let offer = pc.create_offer(None).await?; - - let re = Regex::new(r"(?m)^a=rtpmap:\d+ H264/90000").unwrap(); - assert!(re.is_match(offer.sdp.as_str())); - let re = Regex::new(r"(?m)^a=rtpmap:\d+ VP8/90000").unwrap(); - assert!(re.is_match(offer.sdp.as_str())); - let re = Regex::new(r"(?m)^a=rtpmap:\d+ VP9/90000").unwrap(); - assert!(re.is_match(offer.sdp.as_str())); - - pc.close().await?; - - Ok(()) -} - -#[tokio::test] -async fn test_media_engine_remote_description() -> Result<()> { - let must_parse = |raw: &str| -> Result { - let mut reader = Cursor::new(raw.as_bytes()); - Ok(SessionDescription::unmarshal(&mut reader)?) - }; - - //"No Media" - { - const NO_MEDIA: &str = "v=0 -o=- 4596489990601351948 2 IN IP4 127.0.0.1 -s=- -t=0 0 -"; - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - m.update_from_remote_description(&must_parse(NO_MEDIA)?) - .await?; - - assert!(!m.negotiated_video.load(Ordering::SeqCst)); - assert!(!m.negotiated_audio.load(Ordering::SeqCst)); - } - - //"Enable Opus" - { - const OPUS_SAME_PAYLOAD: &str = "v=0 -o=- 4596489990601351948 2 IN IP4 127.0.0.1 -s=- -t=0 0 -m=audio 9 UDP/TLS/RTP/SAVPF 111 -a=rtpmap:111 opus/48000/2 -a=fmtp:111 minptime=10; useinbandfec=1 -"; - - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - m.update_from_remote_description(&must_parse(OPUS_SAME_PAYLOAD)?) 
- .await?; - - assert!(!m.negotiated_video.load(Ordering::SeqCst)); - assert!(m.negotiated_audio.load(Ordering::SeqCst)); - - let (opus_codec, _) = m.get_codec_by_payload(111).await?; - assert_eq!(opus_codec.capability.mime_type, MIME_TYPE_OPUS); - } - - //"Change Payload Type" - { - const OPUS_SAME_PAYLOAD: &str = "v=0 -o=- 4596489990601351948 2 IN IP4 127.0.0.1 -s=- -t=0 0 -m=audio 9 UDP/TLS/RTP/SAVPF 112 -a=rtpmap:112 opus/48000/2 -a=fmtp:112 minptime=10; useinbandfec=1 -"; - - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - m.update_from_remote_description(&must_parse(OPUS_SAME_PAYLOAD)?) - .await?; - - assert!(!m.negotiated_video.load(Ordering::SeqCst)); - assert!(m.negotiated_audio.load(Ordering::SeqCst)); - - let result = m.get_codec_by_payload(111).await; - assert!(result.is_err()); - - let (opus_codec, _) = m.get_codec_by_payload(112).await?; - assert_eq!(opus_codec.capability.mime_type, MIME_TYPE_OPUS); - } - - //"Case Insensitive" - { - const OPUS_UPCASE: &str = "v=0 -o=- 4596489990601351948 2 IN IP4 127.0.0.1 -s=- -t=0 0 -m=audio 9 UDP/TLS/RTP/SAVPF 111 -a=rtpmap:111 OPUS/48000/2 -a=fmtp:111 minptime=10; useinbandfec=1 -"; - - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - m.update_from_remote_description(&must_parse(OPUS_UPCASE)?) - .await?; - - assert!(!m.negotiated_video.load(Ordering::SeqCst)); - assert!(m.negotiated_audio.load(Ordering::SeqCst)); - - let (opus_codec, _) = m.get_codec_by_payload(111).await?; - assert_eq!(opus_codec.capability.mime_type, "audio/OPUS"); - } - - //"Handle different fmtp" - { - const OPUS_NO_FMTP: &str = "v=0 -o=- 4596489990601351948 2 IN IP4 127.0.0.1 -s=- -t=0 0 -m=audio 9 UDP/TLS/RTP/SAVPF 111 -a=rtpmap:111 opus/48000/2 -"; - - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - m.update_from_remote_description(&must_parse(OPUS_NO_FMTP)?) - .await?; - - assert!(!m.negotiated_video.load(Ordering::SeqCst)); - assert!(m.negotiated_audio.load(Ordering::SeqCst)); - - let (opus_codec, _) = m.get_codec_by_payload(111).await?; - assert_eq!(opus_codec.capability.mime_type, MIME_TYPE_OPUS); - } - - //"Header Extensions" - { - const HEADER_EXTENSIONS: &str = "v=0 -o=- 4596489990601351948 2 IN IP4 127.0.0.1 -s=- -t=0 0 -m=audio 9 UDP/TLS/RTP/SAVPF 111 -a=extmap:7 urn:ietf:params:rtp-hdrext:sdes:mid -a=extmap:5 urn:ietf:params:rtp-hdrext:sdes:rtp-stream-id -a=rtpmap:111 opus/48000/2 -"; - - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - for extension in [ - "urn:ietf:params:rtp-hdrext:sdes:mid", - "urn:ietf:params:rtp-hdrext:sdes:repaired-rtp-stream-id", - ] { - m.register_header_extension( - RTCRtpHeaderExtensionCapability { - uri: extension.to_owned(), - }, - RTPCodecType::Audio, - None, - )?; - } - - m.update_from_remote_description(&must_parse(HEADER_EXTENSIONS)?) 
- .await?; - - assert!(!m.negotiated_video.load(Ordering::SeqCst)); - assert!(m.negotiated_audio.load(Ordering::SeqCst)); - - let (abs_id, abs_audio_enabled, abs_video_enabled) = m - .get_header_extension_id(RTCRtpHeaderExtensionCapability { - uri: sdp::extmap::ABS_SEND_TIME_URI.to_owned(), - }) - .await; - assert_eq!(abs_id, 0); - assert!(!abs_audio_enabled); - assert!(!abs_video_enabled); - - let (mid_id, mid_audio_enabled, mid_video_enabled) = m - .get_header_extension_id(RTCRtpHeaderExtensionCapability { - uri: sdp::extmap::SDES_MID_URI.to_owned(), - }) - .await; - assert_eq!(mid_id, 7); - assert!(mid_audio_enabled); - assert!(!mid_video_enabled); - } - - //"Prefers exact codec matches" - { - const PROFILE_LEVELS: &str = "v=0 -o=- 4596489990601351948 2 IN IP4 127.0.0.1 -s=- -t=0 0 -m=video 60323 UDP/TLS/RTP/SAVPF 96 98 -a=rtpmap:96 H264/90000 -a=fmtp:96 level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=640c1f -a=rtpmap:98 H264/90000 -a=fmtp:98 level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f -"; - let mut m = MediaEngine::default(); - m.register_codec( - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_H264.to_owned(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: - "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f" - .to_string(), - rtcp_feedback: vec![], - }, - payload_type: 127, - ..Default::default() - }, - RTPCodecType::Video, - )?; - m.update_from_remote_description(&must_parse(PROFILE_LEVELS)?) - .await?; - - assert!(m.negotiated_video.load(Ordering::SeqCst)); - assert!(!m.negotiated_audio.load(Ordering::SeqCst)); - - let (supported_h264, _) = m.get_codec_by_payload(98).await?; - assert_eq!(supported_h264.capability.mime_type, MIME_TYPE_H264); - - assert!(m.get_codec_by_payload(96).await.is_err()); - } - - //"Does not match when fmtpline is set and does not match" - { - const PROFILE_LEVELS: &str = "v=0 -o=- 4596489990601351948 2 IN IP4 127.0.0.1 -s=- -t=0 0 -m=video 60323 UDP/TLS/RTP/SAVPF 96 98 -a=rtpmap:96 H264/90000 -a=fmtp:96 level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=640c1f -"; - let mut m = MediaEngine::default(); - m.register_codec( - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_H264.to_owned(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: - "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f" - .to_string(), - rtcp_feedback: vec![], - }, - payload_type: 127, - ..Default::default() - }, - RTPCodecType::Video, - )?; - assert!(m - .update_from_remote_description(&must_parse(PROFILE_LEVELS)?) - .await - .is_err()); - - assert!(m.get_codec_by_payload(96).await.is_err()); - } - - //"Matches when fmtpline is not set in offer, but exists in mediaengine" - { - const PROFILE_LEVELS: &str = "v=0 -o=- 4596489990601351948 2 IN IP4 127.0.0.1 -s=- -t=0 0 -m=video 60323 UDP/TLS/RTP/SAVPF 96 -a=rtpmap:96 VP9/90000 -"; - let mut m = MediaEngine::default(); - m.register_codec( - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP9.to_owned(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: "profile-id=0".to_string(), - rtcp_feedback: vec![], - }, - payload_type: 98, - ..Default::default() - }, - RTPCodecType::Video, - )?; - - m.update_from_remote_description(&must_parse(PROFILE_LEVELS)?) 
- .await?; - - assert!(m.negotiated_video.load(Ordering::SeqCst)); - - m.get_codec_by_payload(96).await?; - } - - //"Matches when fmtpline exists in neither" - { - const PROFILE_LEVELS: &str = "v=0 -o=- 4596489990601351948 2 IN IP4 127.0.0.1 -s=- -t=0 0 -m=video 60323 UDP/TLS/RTP/SAVPF 96 -a=rtpmap:96 VP8/90000 -"; - let mut m = MediaEngine::default(); - m.register_codec( - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: "".to_string(), - rtcp_feedback: vec![], - }, - payload_type: 96, - ..Default::default() - }, - RTPCodecType::Video, - )?; - - m.update_from_remote_description(&must_parse(PROFILE_LEVELS)?) - .await?; - - assert!(m.negotiated_video.load(Ordering::SeqCst)); - - m.get_codec_by_payload(96).await?; - } - - //"Matches when rtx apt for exact match codec" - { - const PROFILE_LEVELS: &str = "v=0 -o=- 4596489990601351948 2 IN IP4 127.0.0.1 -s=- -t=0 0 -m=video 60323 UDP/TLS/RTP/SAVPF 94 96 97 -a=rtpmap:94 VP8/90000 -a=rtpmap:96 VP9/90000 -a=fmtp:96 profile-id=2 -a=rtpmap:97 rtx/90000 -a=fmtp:97 apt=96 -"; - let mut m = MediaEngine::default(); - m.register_codec( - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: "".to_string(), - rtcp_feedback: vec![], - }, - payload_type: 94, - ..Default::default() - }, - RTPCodecType::Video, - )?; - - m.register_codec( - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP9.to_owned(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: "profile-id=2".to_string(), - rtcp_feedback: vec![], - }, - payload_type: 96, - ..Default::default() - }, - RTPCodecType::Video, - )?; - - m.register_codec( - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: "video/rtx".to_owned(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: "apt=96".to_string(), - rtcp_feedback: vec![], - }, - payload_type: 97, - ..Default::default() - }, - RTPCodecType::Video, - )?; - - m.update_from_remote_description(&must_parse(PROFILE_LEVELS)?) 
- .await?; - - assert!(m.negotiated_video.load(Ordering::SeqCst)); - - m.get_codec_by_payload(97).await?; - } - - //"Matches when rtx apt for partial match codec" - { - const PROFILE_LEVELS: &str = "v=0 -o=- 4596489990601351948 2 IN IP4 127.0.0.1 -s=- -t=0 0 -m=video 60323 UDP/TLS/RTP/SAVPF 94 96 97 -a=rtpmap:94 VP8/90000 -a=rtpmap:96 VP9/90000 -a=fmtp:96 profile-id=2 -a=rtpmap:97 rtx/90000 -a=fmtp:97 apt=96 -"; - let mut m = MediaEngine::default(); - m.register_codec( - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: "".to_string(), - rtcp_feedback: vec![], - }, - payload_type: 94, - ..Default::default() - }, - RTPCodecType::Video, - )?; - - m.register_codec( - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP9.to_owned(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: "profile-id=1".to_string(), - rtcp_feedback: vec![], - }, - payload_type: 96, - ..Default::default() - }, - RTPCodecType::Video, - )?; - - m.register_codec( - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: "video/rtx".to_owned(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: "apt=96".to_string(), - rtcp_feedback: vec![], - }, - payload_type: 97, - ..Default::default() - }, - RTPCodecType::Video, - )?; - - m.update_from_remote_description(&must_parse(PROFILE_LEVELS)?) - .await?; - - assert!(m.negotiated_video.load(Ordering::SeqCst)); - - if let Err(err) = m.get_codec_by_payload(97).await { - assert_eq!(err, Error::ErrCodecNotFound); - } else { - panic!(); - } - } - - Ok(()) -} - -#[tokio::test] -async fn test_media_engine_header_extension_direction() -> Result<()> { - let register_codec = |m: &mut MediaEngine| -> Result<()> { - m.register_codec( - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_OPUS.to_owned(), - clock_rate: 48000, - channels: 0, - sdp_fmtp_line: "".to_string(), - rtcp_feedback: vec![], - }, - payload_type: 111, - ..Default::default() - }, - RTPCodecType::Audio, - ) - }; - - //"No Direction" - { - let mut m = MediaEngine::default(); - register_codec(&mut m)?; - m.register_header_extension( - RTCRtpHeaderExtensionCapability { - uri: "webrtc-header-test".to_owned(), - }, - RTPCodecType::Audio, - None, - )?; - - let params = - m.get_rtp_parameters_by_kind(RTPCodecType::Audio, RTCRtpTransceiverDirection::Recvonly); - - assert_eq!(params.header_extensions.len(), 1); - } - - //"Same Direction" - { - let mut m = MediaEngine::default(); - register_codec(&mut m)?; - m.register_header_extension( - RTCRtpHeaderExtensionCapability { - uri: "webrtc-header-test".to_owned(), - }, - RTPCodecType::Audio, - Some(RTCRtpTransceiverDirection::Recvonly), - )?; - - let params = - m.get_rtp_parameters_by_kind(RTPCodecType::Audio, RTCRtpTransceiverDirection::Recvonly); - - assert_eq!(params.header_extensions.len(), 1); - } - - //"Different Direction" - { - let mut m = MediaEngine::default(); - register_codec(&mut m)?; - m.register_header_extension( - RTCRtpHeaderExtensionCapability { - uri: "webrtc-header-test".to_owned(), - }, - RTPCodecType::Audio, - Some(RTCRtpTransceiverDirection::Sendonly), - )?; - - let params = - m.get_rtp_parameters_by_kind(RTPCodecType::Audio, RTCRtpTransceiverDirection::Recvonly); - - assert_eq!(params.header_extensions.len(), 0); - } - - //"No direction and inactive" - { - let mut m = MediaEngine::default(); - register_codec(&mut m)?; - m.register_header_extension( - 
RTCRtpHeaderExtensionCapability { - uri: "webrtc-header-test".to_owned(), - }, - RTPCodecType::Audio, - None, - )?; - - let params = - m.get_rtp_parameters_by_kind(RTPCodecType::Audio, RTCRtpTransceiverDirection::Inactive); - - assert_eq!(params.header_extensions.len(), 1); - } - - Ok(()) -} - -/// If a user attempts to register a codec twice we should just discard duplicate calls -#[tokio::test] -async fn test_media_engine_double_register() -> Result<()> { - let mut m = MediaEngine::default(); - - m.register_codec( - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_OPUS.to_owned(), - clock_rate: 48000, - channels: 0, - sdp_fmtp_line: "".to_string(), - rtcp_feedback: vec![], - }, - payload_type: 111, - ..Default::default() - }, - RTPCodecType::Audio, - )?; - - m.register_codec( - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_OPUS.to_owned(), - clock_rate: 48000, - channels: 0, - sdp_fmtp_line: "".to_string(), - rtcp_feedback: vec![], - }, - payload_type: 111, - ..Default::default() - }, - RTPCodecType::Audio, - )?; - - assert_eq!(m.audio_codecs.len(), 1); - Ok(()) -} - -async fn validate(m: &MediaEngine) -> Result<()> { - m.update_header_extension(2, "test-extension", RTPCodecType::Audio) - .await?; - - let (id, audio_negotiated, video_negotiated) = m - .get_header_extension_id(RTCRtpHeaderExtensionCapability { - uri: "test-extension".to_owned(), - }) - .await; - assert_eq!(id, 2); - assert!(audio_negotiated); - assert!(!video_negotiated); - - Ok(()) -} - -/// The cloned MediaEngine instance should be able to update negotiated header extensions. -#[tokio::test] -async fn test_update_header_extension_to_cloned_media_engine() -> Result<()> { - let mut m = MediaEngine::default(); - - m.register_codec( - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_OPUS.to_owned(), - clock_rate: 48000, - channels: 0, - sdp_fmtp_line: "".to_string(), - rtcp_feedback: vec![], - }, - payload_type: 111, - ..Default::default() - }, - RTPCodecType::Audio, - )?; - - m.register_header_extension( - RTCRtpHeaderExtensionCapability { - uri: "test-extension".to_owned(), - }, - RTPCodecType::Audio, - None, - )?; - - validate(&m).await?; - validate(&m.clone_to()).await?; - - Ok(()) -} - -#[tokio::test] -async fn test_extension_id_collision() -> Result<()> { - let must_parse = |raw: &str| -> Result { - let mut reader = Cursor::new(raw.as_bytes()); - Ok(SessionDescription::unmarshal(&mut reader)?) - }; - - const HEADER_EXTENSIONS: &str = "v=0 -o=- 4596489990601351948 2 IN IP4 127.0.0.1 -s=- -t=0 0 -m=audio 9 UDP/TLS/RTP/SAVPF 111 -a=extmap:7 urn:ietf:params:rtp-hdrext:sdes:mid -a=extmap:1 urn:ietf:params:rtp-hdrext:ssrc-audio-level -a=extmap:5 urn:ietf:params:rtp-hdrext:sdes:rtp-stream-id -a=rtpmap:111 opus/48000/2 -"; - - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - { - let extension = "urn:3gpp:video-orientation"; - m.register_header_extension( - RTCRtpHeaderExtensionCapability { - uri: extension.to_owned(), - }, - RTPCodecType::Video, - None, - )?; - } - for extension in [ - "urn:ietf:params:rtp-hdrext:ssrc-audio-level", - "urn:ietf:params:rtp-hdrext:sdes:mid", - "urn:ietf:params:rtp-hdrext:sdes:repaired-rtp-stream-id", - ] { - m.register_header_extension( - RTCRtpHeaderExtensionCapability { - uri: extension.to_owned(), - }, - RTPCodecType::Audio, - None, - )?; - } - - m.update_from_remote_description(&must_parse(HEADER_EXTENSIONS)?) 
- .await?; - - assert!(!m.negotiated_video.load(Ordering::SeqCst)); - assert!(m.negotiated_audio.load(Ordering::SeqCst)); - - let (abs_id, abs_audio_enabled, abs_video_enabled) = m - .get_header_extension_id(RTCRtpHeaderExtensionCapability { - uri: sdp::extmap::ABS_SEND_TIME_URI.to_owned(), - }) - .await; - assert_eq!(abs_id, 0); - assert!(!abs_audio_enabled); - assert!(!abs_video_enabled); - - let (mid_id, mid_audio_enabled, mid_video_enabled) = m - .get_header_extension_id(RTCRtpHeaderExtensionCapability { - uri: sdp::extmap::SDES_MID_URI.to_owned(), - }) - .await; - assert_eq!(mid_id, 7); - assert!(mid_audio_enabled); - assert!(!mid_video_enabled); - - let (mid_id, mid_audio_enabled, mid_video_enabled) = m - .get_header_extension_id(RTCRtpHeaderExtensionCapability { - uri: sdp::extmap::AUDIO_LEVEL_URI.to_owned(), - }) - .await; - assert_eq!(mid_id, 1); - assert!(mid_audio_enabled); - assert!(!mid_video_enabled); - - let params = - m.get_rtp_parameters_by_kind(RTPCodecType::Video, RTCRtpTransceiverDirection::Sendonly); - //dbg!(¶ms); - - let orientation = params - .header_extensions - .iter() - .find(|ext| ext.uri == "urn:3gpp:video-orientation") - .unwrap(); - assert_ne!(orientation.id, 1); - assert_ne!(orientation.id, 7); - assert_ne!(orientation.id, 5); - - Ok(()) -} diff --git a/webrtc/src/api/media_engine/mod.rs b/webrtc/src/api/media_engine/mod.rs deleted file mode 100644 index 82955534e..000000000 --- a/webrtc/src/api/media_engine/mod.rs +++ /dev/null @@ -1,819 +0,0 @@ -#[cfg(test)] -mod media_engine_test; - -use std::collections::HashMap; -use std::ops::Range; -use std::sync::atomic::Ordering; -use std::time::{SystemTime, UNIX_EPOCH}; - -use portable_atomic::AtomicBool; -use sdp::description::session::SessionDescription; -use util::sync::Mutex as SyncMutex; - -use crate::error::{Error, Result}; -use crate::peer_connection::sdp::{ - codecs_from_media_description, rtp_extensions_from_media_description, -}; -use crate::rtp_transceiver::rtp_codec::{ - codec_parameters_fuzzy_search, CodecMatch, RTCRtpCodecCapability, RTCRtpCodecParameters, - RTCRtpHeaderExtensionCapability, RTCRtpHeaderExtensionParameters, RTCRtpParameters, - RTPCodecType, -}; -use crate::rtp_transceiver::rtp_transceiver_direction::RTCRtpTransceiverDirection; -use crate::rtp_transceiver::{fmtp, PayloadType, RTCPFeedback}; -use crate::stats::stats_collector::StatsCollector; -use crate::stats::CodecStats; -use crate::stats::StatsReportType::Codec; - -/// MIME_TYPE_H264 H264 MIME type. -/// Note: Matching should be case insensitive. -pub const MIME_TYPE_H264: &str = "video/H264"; -/// MIME_TYPE_HEVC HEVC MIME type. -/// Note: Matching should be case insensitive. -pub const MIME_TYPE_HEVC: &str = "video/HEVC"; -/// MIME_TYPE_OPUS Opus MIME type -/// Note: Matching should be case insensitive. -pub const MIME_TYPE_OPUS: &str = "audio/opus"; -/// MIME_TYPE_VP8 VP8 MIME type -/// Note: Matching should be case insensitive. -pub const MIME_TYPE_VP8: &str = "video/VP8"; -/// MIME_TYPE_VP9 VP9 MIME type -/// Note: Matching should be case insensitive. -pub const MIME_TYPE_VP9: &str = "video/VP9"; -/// MIME_TYPE_AV1 AV1 MIME type -/// Note: Matching should be case insensitive. -pub const MIME_TYPE_AV1: &str = "video/AV1"; -/// MIME_TYPE_G722 G722 MIME type -/// Note: Matching should be case insensitive. -pub const MIME_TYPE_G722: &str = "audio/G722"; -/// MIME_TYPE_PCMU PCMU MIME type -/// Note: Matching should be case insensitive. 
-pub const MIME_TYPE_PCMU: &str = "audio/PCMU"; -/// MIME_TYPE_PCMA PCMA MIME type -/// Note: Matching should be case insensitive. -pub const MIME_TYPE_PCMA: &str = "audio/PCMA"; -/// MIME_TYPE_TELEPHONE_EVENT telephone-event MIME type -/// Note: Matching should be case insensitive. -pub const MIME_TYPE_TELEPHONE_EVENT: &str = "audio/telephone-event"; - -const VALID_EXT_IDS: Range = 1..15; - -#[derive(Default, Clone)] -pub(crate) struct MediaEngineHeaderExtension { - pub(crate) uri: String, - pub(crate) is_audio: bool, - pub(crate) is_video: bool, - pub(crate) allowed_direction: Option, -} - -impl MediaEngineHeaderExtension { - pub fn is_matching_direction(&self, dir: RTCRtpTransceiverDirection) -> bool { - if let Some(allowed_direction) = self.allowed_direction { - use RTCRtpTransceiverDirection::*; - allowed_direction == Inactive && dir == Inactive - || allowed_direction.has_send() && dir.has_send() - || allowed_direction.has_recv() && dir.has_recv() - } else { - // None means all directions matches. - true - } - } -} - -/// A MediaEngine defines the codecs supported by a PeerConnection, and the -/// configuration of those codecs. A MediaEngine must not be shared between -/// PeerConnections. -#[derive(Default)] -pub struct MediaEngine { - // If we have attempted to negotiate a codec type yet. - pub(crate) negotiated_video: AtomicBool, - pub(crate) negotiated_audio: AtomicBool, - - pub(crate) video_codecs: Vec, - pub(crate) audio_codecs: Vec, - pub(crate) negotiated_video_codecs: SyncMutex>, - pub(crate) negotiated_audio_codecs: SyncMutex>, - - header_extensions: Vec, - proposed_header_extensions: SyncMutex>, - pub(crate) negotiated_header_extensions: SyncMutex>, -} - -impl MediaEngine { - /// register_default_codecs registers the default codecs supported by Pion WebRTC. - /// register_default_codecs is not safe for concurrent use. 
- pub fn register_default_codecs(&mut self) -> Result<()> { - // Default Audio Codecs - for codec in vec![ - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_OPUS.to_owned(), - clock_rate: 48000, - channels: 2, - sdp_fmtp_line: "minptime=10;useinbandfec=1".to_owned(), - rtcp_feedback: vec![], - }, - payload_type: 111, - ..Default::default() - }, - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_G722.to_owned(), - clock_rate: 8000, - channels: 0, - sdp_fmtp_line: "".to_owned(), - rtcp_feedback: vec![], - }, - payload_type: 9, - ..Default::default() - }, - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_PCMU.to_owned(), - clock_rate: 8000, - channels: 0, - sdp_fmtp_line: "".to_owned(), - rtcp_feedback: vec![], - }, - payload_type: 0, - ..Default::default() - }, - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_PCMA.to_owned(), - clock_rate: 8000, - channels: 0, - sdp_fmtp_line: "".to_owned(), - rtcp_feedback: vec![], - }, - payload_type: 8, - ..Default::default() - }, - ] { - self.register_codec(codec, RTPCodecType::Audio)?; - } - - let video_rtcp_feedback = vec![ - RTCPFeedback { - typ: "goog-remb".to_owned(), - parameter: "".to_owned(), - }, - RTCPFeedback { - typ: "ccm".to_owned(), - parameter: "fir".to_owned(), - }, - RTCPFeedback { - typ: "nack".to_owned(), - parameter: "".to_owned(), - }, - RTCPFeedback { - typ: "nack".to_owned(), - parameter: "pli".to_owned(), - }, - ]; - for codec in vec![ - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: "".to_owned(), - rtcp_feedback: video_rtcp_feedback.clone(), - }, - payload_type: 96, - ..Default::default() - }, - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP9.to_owned(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: "profile-id=0".to_owned(), - rtcp_feedback: video_rtcp_feedback.clone(), - }, - payload_type: 98, - ..Default::default() - }, - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP9.to_owned(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: "profile-id=1".to_owned(), - rtcp_feedback: video_rtcp_feedback.clone(), - }, - payload_type: 100, - ..Default::default() - }, - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_H264.to_owned(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: - "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42001f" - .to_owned(), - rtcp_feedback: video_rtcp_feedback.clone(), - }, - payload_type: 102, - ..Default::default() - }, - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_H264.to_owned(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: - "level-asymmetry-allowed=1;packetization-mode=0;profile-level-id=42001f" - .to_owned(), - rtcp_feedback: video_rtcp_feedback.clone(), - }, - payload_type: 127, - ..Default::default() - }, - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_H264.to_owned(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: - "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f" - .to_owned(), - rtcp_feedback: video_rtcp_feedback.clone(), - }, - payload_type: 125, - ..Default::default() - }, - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_H264.to_owned(), - 
clock_rate: 90000, - channels: 0, - sdp_fmtp_line: - "level-asymmetry-allowed=1;packetization-mode=0;profile-level-id=42e01f" - .to_owned(), - rtcp_feedback: video_rtcp_feedback.clone(), - }, - payload_type: 108, - ..Default::default() - }, - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_H264.to_owned(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: - "level-asymmetry-allowed=1;packetization-mode=0;profile-level-id=42001f" - .to_owned(), - rtcp_feedback: video_rtcp_feedback.clone(), - }, - payload_type: 127, - ..Default::default() - }, - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_H264.to_owned(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: - "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=640032" - .to_owned(), - rtcp_feedback: video_rtcp_feedback.clone(), - }, - payload_type: 123, - ..Default::default() - }, - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_AV1.to_owned(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: "profile-id=0".to_owned(), - rtcp_feedback: video_rtcp_feedback.clone(), - }, - payload_type: 41, - ..Default::default() - }, - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_HEVC.to_owned(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: "".to_owned(), - rtcp_feedback: video_rtcp_feedback, - }, - payload_type: 126, - ..Default::default() - }, - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: "video/ulpfec".to_owned(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: "".to_owned(), - rtcp_feedback: vec![], - }, - payload_type: 116, - ..Default::default() - }, - ] { - self.register_codec(codec, RTPCodecType::Video)?; - } - - Ok(()) - } - - /// add_codec will append codec if it not exists - fn add_codec(codecs: &mut Vec, codec: RTCRtpCodecParameters) { - for c in codecs.iter() { - if c.capability.mime_type == codec.capability.mime_type - && c.payload_type == codec.payload_type - { - return; - } - } - codecs.push(codec); - } - - /// register_codec adds codec to the MediaEngine - /// These are the list of codecs supported by this PeerConnection. - /// register_codec is not safe for concurrent use. - pub fn register_codec( - &mut self, - mut codec: RTCRtpCodecParameters, - typ: RTPCodecType, - ) -> Result<()> { - codec.stats_id = format!( - "RTPCodec-{}", - SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_nanos() - ); - match typ { - RTPCodecType::Audio => { - MediaEngine::add_codec(&mut self.audio_codecs, codec); - Ok(()) - } - RTPCodecType::Video => { - MediaEngine::add_codec(&mut self.video_codecs, codec); - Ok(()) - } - _ => Err(Error::ErrUnknownType), - } - } - - /// Adds a header extension to the MediaEngine - /// To determine the negotiated value use [`MediaEngine::get_header_extension_id`] after signaling is complete. - /// - /// The `allowed_direction` controls for which transceiver directions the extension matches. If - /// set to `None` it matches all directions. The `SendRecv` direction would match all transceiver - /// directions apart from `Inactive`. Inactive only matches inactive. 
- pub fn register_header_extension( - &mut self, - extension: RTCRtpHeaderExtensionCapability, - typ: RTPCodecType, - allowed_direction: Option, - ) -> Result<()> { - let ext = { - match self - .header_extensions - .iter_mut() - .find(|ext| ext.uri == extension.uri) - { - Some(ext) => ext, - None => { - // We have registered too many extensions - if self.header_extensions.len() > VALID_EXT_IDS.end as usize { - return Err(Error::ErrRegisterHeaderExtensionNoFreeID); - } - self.header_extensions.push(MediaEngineHeaderExtension { - allowed_direction, - ..Default::default() - }); - - // Unwrap is fine because we just pushed - self.header_extensions.last_mut().unwrap() - } - } - }; - - if typ == RTPCodecType::Audio { - ext.is_audio = true; - } else if typ == RTPCodecType::Video { - ext.is_video = true; - } - - ext.uri = extension.uri; - - if ext.allowed_direction != allowed_direction { - return Err(Error::ErrRegisterHeaderExtensionInvalidDirection); - } - - Ok(()) - } - - /// register_feedback adds feedback mechanism to already registered codecs. - pub fn register_feedback(&mut self, feedback: RTCPFeedback, typ: RTPCodecType) { - match typ { - RTPCodecType::Video => { - for v in &mut self.video_codecs { - v.capability.rtcp_feedback.push(feedback.clone()); - } - } - RTPCodecType::Audio => { - for a in &mut self.audio_codecs { - a.capability.rtcp_feedback.push(feedback.clone()); - } - } - _ => {} - } - } - - /// get_header_extension_id returns the negotiated ID for a header extension. - /// If the Header Extension isn't enabled ok will be false - pub async fn get_header_extension_id( - &self, - extension: RTCRtpHeaderExtensionCapability, - ) -> (isize, bool, bool) { - let negotiated_header_extensions = self.negotiated_header_extensions.lock(); - if negotiated_header_extensions.is_empty() { - return (0, false, false); - } - - for (id, h) in &*negotiated_header_extensions { - if extension.uri == h.uri { - return (*id, h.is_audio, h.is_video); - } - } - - (0, false, false) - } - - /// clone_to copies any user modifiable state of the MediaEngine - /// all internal state is reset - pub(crate) fn clone_to(&self) -> Self { - MediaEngine { - video_codecs: self.video_codecs.clone(), - audio_codecs: self.audio_codecs.clone(), - header_extensions: self.header_extensions.clone(), - ..Default::default() - } - } - - pub(crate) async fn get_codec_by_payload( - &self, - payload_type: PayloadType, - ) -> Result<(RTCRtpCodecParameters, RTPCodecType)> { - { - let negotiated_video_codecs = self.negotiated_video_codecs.lock(); - for codec in &*negotiated_video_codecs { - if codec.payload_type == payload_type { - return Ok((codec.clone(), RTPCodecType::Video)); - } - } - } - { - let negotiated_audio_codecs = self.negotiated_audio_codecs.lock(); - for codec in &*negotiated_audio_codecs { - if codec.payload_type == payload_type { - return Ok((codec.clone(), RTPCodecType::Audio)); - } - } - } - - Err(Error::ErrCodecNotFound) - } - - pub(crate) async fn collect_stats(&self, collector: &StatsCollector) { - let mut reports = HashMap::new(); - - for codec in &self.video_codecs { - reports.insert(codec.stats_id.clone(), Codec(CodecStats::from(codec))); - } - - for codec in &self.audio_codecs { - reports.insert(codec.stats_id.clone(), Codec(CodecStats::from(codec))); - } - - collector.merge(reports); - } - - /// Look up a codec and enable if it exists - pub(crate) fn match_remote_codec( - &self, - remote_codec: &RTCRtpCodecParameters, - typ: RTPCodecType, - exact_matches: &[RTCRtpCodecParameters], - partial_matches: 
&[RTCRtpCodecParameters], - ) -> Result { - let codecs = if typ == RTPCodecType::Audio { - &self.audio_codecs - } else { - &self.video_codecs - }; - - let remote_fmtp = fmtp::parse( - &remote_codec.capability.mime_type, - remote_codec.capability.sdp_fmtp_line.as_str(), - ); - if let Some(apt) = remote_fmtp.parameter("apt") { - let payload_type = apt.parse::()?; - - let mut apt_match = CodecMatch::None; - for codec in exact_matches { - if codec.payload_type == payload_type { - apt_match = CodecMatch::Exact; - break; - } - } - - if apt_match == CodecMatch::None { - for codec in partial_matches { - if codec.payload_type == payload_type { - apt_match = CodecMatch::Partial; - break; - } - } - } - - if apt_match == CodecMatch::None { - return Ok(CodecMatch::None); // not an error, we just ignore this codec we don't support - } - - // if apt's media codec is partial match, then apt codec must be partial match too - let (_, mut match_type) = codec_parameters_fuzzy_search(remote_codec, codecs); - if match_type == CodecMatch::Exact && apt_match == CodecMatch::Partial { - match_type = CodecMatch::Partial; - } - return Ok(match_type); - } - - let (_, match_type) = codec_parameters_fuzzy_search(remote_codec, codecs); - Ok(match_type) - } - - /// Look up a header extension and enable if it exists - pub(crate) async fn update_header_extension( - &self, - id: isize, - extension: &str, - typ: RTPCodecType, - ) -> Result<()> { - let mut negotiated_header_extensions = self.negotiated_header_extensions.lock(); - let mut proposed_header_extensions = self.proposed_header_extensions.lock(); - - for local_extension in &self.header_extensions { - if local_extension.uri != extension { - continue; - } - - let negotiated_ext = negotiated_header_extensions - .iter_mut() - .find(|(_, ext)| ext.uri == extension); - - if let Some(n_ext) = negotiated_ext { - if *n_ext.0 == id { - n_ext.1.is_video |= typ == RTPCodecType::Video; - n_ext.1.is_audio |= typ == RTPCodecType::Audio; - } else { - let nid = n_ext.0; - log::warn!("Invalid ext id mapping in update_header_extension. 
{} was negotiated as {}, but was {} in call", extension, nid, id); - } - } else { - // We either only have a proposal or we have neither proposal nor a negotiated id - // Accept whatevers the peer suggests - - if let Some(prev_ext) = negotiated_header_extensions.get(&id) { - let prev_uri = &prev_ext.uri; - log::warn!("Assigning {} to {} would override previous assignment to {}, no action taken", id, extension, prev_uri); - } else { - let h = MediaEngineHeaderExtension { - uri: extension.to_owned(), - is_audio: local_extension.is_audio && typ == RTPCodecType::Audio, - is_video: local_extension.is_video && typ == RTPCodecType::Video, - allowed_direction: local_extension.allowed_direction, - }; - negotiated_header_extensions.insert(id, h); - } - } - - // Clear any proposals we had for this id - proposed_header_extensions.remove(&id); - } - Ok(()) - } - - pub(crate) async fn push_codecs(&self, codecs: Vec, typ: RTPCodecType) { - for codec in codecs { - if typ == RTPCodecType::Audio { - let mut negotiated_audio_codecs = self.negotiated_audio_codecs.lock(); - MediaEngine::add_codec(&mut negotiated_audio_codecs, codec); - } else if typ == RTPCodecType::Video { - let mut negotiated_video_codecs = self.negotiated_video_codecs.lock(); - MediaEngine::add_codec(&mut negotiated_video_codecs, codec); - } - } - } - - /// Update the MediaEngine from a remote description - pub(crate) async fn update_from_remote_description( - &self, - desc: &SessionDescription, - ) -> Result<()> { - for media in &desc.media_descriptions { - let typ = if !self.negotiated_audio.load(Ordering::SeqCst) - && media.media_name.media.to_lowercase() == "audio" - { - self.negotiated_audio.store(true, Ordering::SeqCst); - RTPCodecType::Audio - } else if !self.negotiated_video.load(Ordering::SeqCst) - && media.media_name.media.to_lowercase() == "video" - { - self.negotiated_video.store(true, Ordering::SeqCst); - RTPCodecType::Video - } else { - continue; - }; - - let codecs = codecs_from_media_description(media)?; - - let mut exact_matches = vec![]; //make([]RTPCodecParameters, 0, len(codecs)) - let mut partial_matches = vec![]; //make([]RTPCodecParameters, 0, len(codecs)) - - for codec in codecs { - let match_type = - self.match_remote_codec(&codec, typ, &exact_matches, &partial_matches)?; - - if match_type == CodecMatch::Exact { - exact_matches.push(codec); - } else if match_type == CodecMatch::Partial { - partial_matches.push(codec); - } - } - - // use exact matches when they exist, otherwise fall back to partial - if !exact_matches.is_empty() { - self.push_codecs(exact_matches, typ).await; - } else if !partial_matches.is_empty() { - self.push_codecs(partial_matches, typ).await; - } else { - // no match, not negotiated - continue; - } - - let extensions = rtp_extensions_from_media_description(media)?; - - for (extension, id) in extensions { - self.update_header_extension(id, &extension, typ).await?; - } - } - - Ok(()) - } - - pub(crate) fn get_codecs_by_kind(&self, typ: RTPCodecType) -> Vec { - if typ == RTPCodecType::Video { - if self.negotiated_video.load(Ordering::SeqCst) { - let negotiated_video_codecs = self.negotiated_video_codecs.lock(); - negotiated_video_codecs.clone() - } else { - self.video_codecs.clone() - } - } else if typ == RTPCodecType::Audio { - if self.negotiated_audio.load(Ordering::SeqCst) { - let negotiated_audio_codecs = self.negotiated_audio_codecs.lock(); - negotiated_audio_codecs.clone() - } else { - self.audio_codecs.clone() - } - } else { - vec![] - } - } - - pub(crate) fn get_rtp_parameters_by_kind( - 
&self, - typ: RTPCodecType, - direction: RTCRtpTransceiverDirection, - ) -> RTCRtpParameters { - let mut header_extensions = vec![]; - - if self.negotiated_video.load(Ordering::SeqCst) && typ == RTPCodecType::Video - || self.negotiated_audio.load(Ordering::SeqCst) && typ == RTPCodecType::Audio - { - let negotiated_header_extensions = self.negotiated_header_extensions.lock(); - for (id, e) in &*negotiated_header_extensions { - if e.is_matching_direction(direction) - && (e.is_audio && typ == RTPCodecType::Audio - || e.is_video && typ == RTPCodecType::Video) - { - header_extensions.push(RTCRtpHeaderExtensionParameters { - id: *id, - uri: e.uri.clone(), - }); - } - } - } else { - let mut proposed_header_extensions = self.proposed_header_extensions.lock(); - let mut negotiated_header_extensions = self.negotiated_header_extensions.lock(); - - for local_extension in &self.header_extensions { - let relevant = local_extension.is_matching_direction(direction) - && (local_extension.is_audio && typ == RTPCodecType::Audio - || local_extension.is_video && typ == RTPCodecType::Video); - - if !relevant { - continue; - } - - if let Some((id, negotiated_extension)) = negotiated_header_extensions - .iter_mut() - .find(|(_, e)| e.uri == local_extension.uri) - { - // We have previously negotiated this extension, make sure to record it as - // active for the current type - negotiated_extension.is_audio |= typ == RTPCodecType::Audio; - negotiated_extension.is_video |= typ == RTPCodecType::Video; - - header_extensions.push(RTCRtpHeaderExtensionParameters { - id: *id, - uri: negotiated_extension.uri.clone(), - }); - - continue; - } - - if let Some((id, negotiated_extension)) = proposed_header_extensions - .iter_mut() - .find(|(_, e)| e.uri == local_extension.uri) - { - // We have previously proposed this extension, re-use it - header_extensions.push(RTCRtpHeaderExtensionParameters { - id: *id, - uri: negotiated_extension.uri.clone(), - }); - - continue; - } - - // Figure out which (unused id) to propose. 
- let id = VALID_EXT_IDS.clone().find(|id| { - !negotiated_header_extensions.keys().any(|nid| nid == id) - && !proposed_header_extensions.keys().any(|pid| pid == id) - }); - - if let Some(id) = id { - proposed_header_extensions.insert( - id, - MediaEngineHeaderExtension { - uri: local_extension.uri.clone(), - is_audio: local_extension.is_audio, - is_video: local_extension.is_video, - allowed_direction: local_extension.allowed_direction, - }, - ); - - header_extensions.push(RTCRtpHeaderExtensionParameters { - id, - uri: local_extension.uri.clone(), - }); - } else { - log::warn!("No available RTP extension ID for {}", local_extension.uri); - } - } - } - - RTCRtpParameters { - header_extensions, - codecs: self.get_codecs_by_kind(typ), - } - } - - pub(crate) async fn get_rtp_parameters_by_payload_type( - &self, - payload_type: PayloadType, - ) -> Result { - let (codec, typ) = self.get_codec_by_payload(payload_type).await?; - - let mut header_extensions = vec![]; - { - let negotiated_header_extensions = self.negotiated_header_extensions.lock(); - for (id, e) in &*negotiated_header_extensions { - if e.is_audio && typ == RTPCodecType::Audio - || e.is_video && typ == RTPCodecType::Video - { - header_extensions.push(RTCRtpHeaderExtensionParameters { - uri: e.uri.clone(), - id: *id, - }); - } - } - } - - Ok(RTCRtpParameters { - header_extensions, - codecs: vec![codec], - }) - } -} diff --git a/webrtc/src/api/mod.rs b/webrtc/src/api/mod.rs deleted file mode 100644 index 2252dfbb2..000000000 --- a/webrtc/src/api/mod.rs +++ /dev/null @@ -1,238 +0,0 @@ -#[cfg(test)] -mod api_test; - -pub mod interceptor_registry; -pub mod media_engine; -pub mod setting_engine; - -use std::sync::Arc; -use std::time::SystemTime; - -use interceptor::registry::Registry; -use interceptor::Interceptor; -use media_engine::*; -use rcgen::KeyPair; -use setting_engine::*; - -use crate::data_channel::data_channel_parameters::DataChannelParameters; -use crate::data_channel::RTCDataChannel; -use crate::dtls_transport::RTCDtlsTransport; -use crate::error::{Error, Result}; -use crate::ice_transport::ice_gatherer::{RTCIceGatherOptions, RTCIceGatherer}; -use crate::ice_transport::RTCIceTransport; -use crate::peer_connection::certificate::RTCCertificate; -use crate::peer_connection::configuration::RTCConfiguration; -use crate::peer_connection::RTCPeerConnection; -use crate::rtp_transceiver::rtp_codec::RTPCodecType; -use crate::rtp_transceiver::rtp_receiver::RTCRtpReceiver; -use crate::rtp_transceiver::rtp_sender::RTCRtpSender; -use crate::sctp_transport::RTCSctpTransport; -use crate::track::track_local::TrackLocal; - -/// API bundles the global functions of the WebRTC and ORTC API. -/// Some of these functions are also exported globally using the -/// defaultAPI object. Note that the global version of the API -/// may be phased out in the future. -pub struct API { - pub(crate) setting_engine: Arc, - pub(crate) media_engine: Arc, - pub(crate) interceptor_registry: Registry, -} - -impl API { - /// new_peer_connection creates a new PeerConnection with the provided configuration against the received API object - pub async fn new_peer_connection( - &self, - configuration: RTCConfiguration, - ) -> Result { - RTCPeerConnection::new(self, configuration).await - } - - /// new_ice_gatherer creates a new ice gatherer. - /// This constructor is part of the ORTC API. It is not - /// meant to be used together with the basic WebRTC API. 
- pub fn new_ice_gatherer(&self, opts: RTCIceGatherOptions) -> Result { - let mut validated_servers = vec![]; - if !opts.ice_servers.is_empty() { - for server in &opts.ice_servers { - let url = server.urls()?; - validated_servers.extend(url); - } - } - - Ok(RTCIceGatherer::new( - validated_servers, - opts.ice_gather_policy, - Arc::clone(&self.setting_engine), - )) - } - - /// new_ice_transport creates a new ice transport. - /// This constructor is part of the ORTC API. It is not - /// meant to be used together with the basic WebRTC API. - pub fn new_ice_transport(&self, gatherer: Arc) -> RTCIceTransport { - RTCIceTransport::new(gatherer) - } - - /// new_dtls_transport creates a new dtls_transport transport. - /// This constructor is part of the ORTC API. It is not - /// meant to be used together with the basic WebRTC API. - pub fn new_dtls_transport( - &self, - ice_transport: Arc, - mut certificates: Vec, - ) -> Result { - if !certificates.is_empty() { - let now = SystemTime::now(); - for cert in &certificates { - cert.expires - .duration_since(now) - .map_err(|_| Error::ErrCertificateExpired)?; - } - } else { - let kp = KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256)?; - let cert = RTCCertificate::from_key_pair(kp)?; - certificates = vec![cert]; - }; - - Ok(RTCDtlsTransport::new( - ice_transport, - certificates, - Arc::clone(&self.setting_engine), - )) - } - - /// new_sctp_transport creates a new SCTPTransport. - /// This constructor is part of the ORTC API. It is not - /// meant to be used together with the basic WebRTC API. - pub fn new_sctp_transport( - &self, - dtls_transport: Arc, - ) -> Result { - Ok(RTCSctpTransport::new( - dtls_transport, - Arc::clone(&self.setting_engine), - )) - } - - /// new_data_channel creates a new DataChannel. - /// This constructor is part of the ORTC API. It is not - /// meant to be used together with the basic WebRTC API. - pub async fn new_data_channel( - &self, - sctp_transport: Arc, - params: DataChannelParameters, - ) -> Result { - // https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #5) - if params.label.len() > 65535 { - return Err(Error::ErrStringSizeLimit); - } - - let d = RTCDataChannel::new(params, Arc::clone(&self.setting_engine)); - d.open(sctp_transport).await?; - - Ok(d) - } - - /// new_rtp_receiver constructs a new RTPReceiver - pub fn new_rtp_receiver( - &self, - kind: RTPCodecType, - transport: Arc, - interceptor: Arc, - ) -> RTCRtpReceiver { - RTCRtpReceiver::new( - self.setting_engine.get_receive_mtu(), - kind, - transport, - Arc::clone(&self.media_engine), - interceptor, - ) - } - - /// new_rtp_sender constructs a new RTPSender - pub async fn new_rtp_sender( - &self, - track: Option>, - transport: Arc, - interceptor: Arc, - ) -> RTCRtpSender { - let kind = track.as_ref().map(|t| t.kind()).unwrap_or_default(); - RTCRtpSender::new( - self.setting_engine.get_receive_mtu(), - track, - kind, - transport, - Arc::clone(&self.media_engine), - interceptor, - false, - ) - .await - } - - /// Returns the internal [`SettingEngine`]. - pub fn setting_engine(&self) -> Arc { - Arc::clone(&self.setting_engine) - } - - /// Returns the internal [`MediaEngine`]. 
- pub fn media_engine(&self) -> Arc { - Arc::clone(&self.media_engine) - } -} - -#[derive(Default)] -pub struct APIBuilder { - setting_engine: Option>, - media_engine: Option>, - interceptor_registry: Option, -} - -impl APIBuilder { - pub fn new() -> Self { - APIBuilder::default() - } - - pub fn build(mut self) -> API { - API { - setting_engine: if let Some(setting_engine) = self.setting_engine.take() { - setting_engine - } else { - Arc::new(SettingEngine::default()) - }, - media_engine: if let Some(media_engine) = self.media_engine.take() { - media_engine - } else { - Arc::new(MediaEngine::default()) - }, - interceptor_registry: if let Some(interceptor_registry) = - self.interceptor_registry.take() - { - interceptor_registry - } else { - Registry::new() - }, - } - } - - /// WithSettingEngine allows providing a SettingEngine to the API. - /// Settings should not be changed after passing the engine to an API. - pub fn with_setting_engine(mut self, setting_engine: SettingEngine) -> Self { - self.setting_engine = Some(Arc::new(setting_engine)); - self - } - - /// WithMediaEngine allows providing a MediaEngine to the API. - /// Settings can be changed after passing the engine to an API. - pub fn with_media_engine(mut self, media_engine: MediaEngine) -> Self { - self.media_engine = Some(Arc::new(media_engine)); - self - } - - /// with_interceptor_registry allows providing Interceptors to the API. - /// Settings should not be changed after passing the registry to an API. - pub fn with_interceptor_registry(mut self, interceptor_registry: Registry) -> Self { - self.interceptor_registry = Some(interceptor_registry); - self - } -} diff --git a/webrtc/src/api/setting_engine/mod.rs b/webrtc/src/api/setting_engine/mod.rs deleted file mode 100644 index 6d909d3ae..000000000 --- a/webrtc/src/api/setting_engine/mod.rs +++ /dev/null @@ -1,327 +0,0 @@ -#[cfg(test)] -mod setting_engine_test; - -use std::sync::Arc; - -use dtls::extension::extension_use_srtp::SrtpProtectionProfile; -use ice::agent::agent_config::{InterfaceFilterFn, IpFilterFn}; -use ice::mdns::MulticastDnsMode; -use ice::network_type::NetworkType; -use ice::udp_network::UDPNetwork; -use tokio::time::Duration; -use util::vnet::net::*; - -use crate::dtls_transport::dtls_role::DTLSRole; -use crate::error::{Error, Result}; -use crate::ice_transport::ice_candidate_type::RTCIceCandidateType; -use crate::RECEIVE_MTU; - -#[derive(Default, Clone)] -pub struct Detach { - pub data_channels: bool, -} - -#[derive(Default, Clone)] -pub struct Timeout { - pub ice_disconnected_timeout: Option, - pub ice_failed_timeout: Option, - pub ice_keepalive_interval: Option, - pub ice_host_acceptance_min_wait: Option, - pub ice_srflx_acceptance_min_wait: Option, - pub ice_prflx_acceptance_min_wait: Option, - pub ice_relay_acceptance_min_wait: Option, -} - -#[derive(Default, Clone)] -pub struct Candidates { - pub ice_lite: bool, - pub ice_network_types: Vec, - pub interface_filter: Arc>, - pub ip_filter: Arc>, - pub nat_1to1_ips: Vec, - pub nat_1to1_ip_candidate_type: RTCIceCandidateType, - pub multicast_dns_mode: MulticastDnsMode, - pub multicast_dns_host_name: String, - pub username_fragment: String, - pub password: String, -} - -#[derive(Default, Clone)] -pub struct ReplayProtection { - pub dtls: usize, - pub srtp: usize, - pub srtcp: usize, -} - -/// SettingEngine allows influencing behavior in ways that are not -/// supported by the WebRTC API. This allows us to support additional -/// use-cases without deviating from the WebRTC API elsewhere. 
-#[derive(Default, Clone)] -pub struct SettingEngine { - pub(crate) detach: Detach, - pub(crate) timeout: Timeout, - pub(crate) candidates: Candidates, - pub(crate) replay_protection: ReplayProtection, - pub(crate) sdp_media_level_fingerprints: bool, - pub(crate) answering_dtls_role: DTLSRole, - pub(crate) disable_certificate_fingerprint_verification: bool, - pub(crate) allow_insecure_verification_algorithm: bool, - pub(crate) disable_srtp_replay_protection: bool, - pub(crate) disable_srtcp_replay_protection: bool, - pub(crate) vnet: Option>, - //BufferFactory :func(packetType packetio.BufferPacketType, ssrc uint32) io.ReadWriteCloser, - //iceTCPMux :ice.TCPMux,? - //iceProxyDialer :proxy.Dialer,? - pub(crate) udp_network: UDPNetwork, - pub(crate) disable_media_engine_copy: bool, - pub(crate) srtp_protection_profiles: Vec, - pub(crate) receive_mtu: usize, - pub(crate) mid_generator: Option String + Send + Sync>>, -} - -impl SettingEngine { - /// get_receive_mtu returns the configured MTU. If SettingEngine's MTU is configured to 0 it returns the default - pub(crate) fn get_receive_mtu(&self) -> usize { - if self.receive_mtu != 0 { - self.receive_mtu - } else { - RECEIVE_MTU - } - } - /// detach_data_channels enables detaching data channels. When enabled - /// data channels have to be detached in the OnOpen callback using the - /// DataChannel.Detach method. - pub fn detach_data_channels(&mut self) { - self.detach.data_channels = true; - } - - /// set_srtp_protection_profiles allows the user to override the default srtp Protection Profiles - /// The default srtp protection profiles are provided by the function `defaultSrtpProtectionProfiles` - pub fn set_srtp_protection_profiles(&mut self, profiles: Vec) { - self.srtp_protection_profiles = profiles - } - - /// set_ice_timeouts sets the behavior around ICE Timeouts - /// * disconnected_timeout is the duration without network activity before a Agent is considered disconnected. Default is 5 Seconds - /// * failed_timeout is the duration without network activity before a Agent is considered failed after disconnected. Default is 25 Seconds - /// * keep_alive_interval is how often the ICE Agent sends extra traffic if there is no activity, if media is flowing no traffic will be sent. Default is 2 seconds - pub fn set_ice_timeouts( - &mut self, - disconnected_timeout: Option, - failed_timeout: Option, - keep_alive_interval: Option, - ) { - self.timeout.ice_disconnected_timeout = disconnected_timeout; - self.timeout.ice_failed_timeout = failed_timeout; - self.timeout.ice_keepalive_interval = keep_alive_interval; - } - - /// set_host_acceptance_min_wait sets the icehost_acceptance_min_wait - pub fn set_host_acceptance_min_wait(&mut self, t: Option) { - self.timeout.ice_host_acceptance_min_wait = t; - } - - /// set_srflx_acceptance_min_wait sets the icesrflx_acceptance_min_wait - pub fn set_srflx_acceptance_min_wait(&mut self, t: Option) { - self.timeout.ice_srflx_acceptance_min_wait = t; - } - - /// set_prflx_acceptance_min_wait sets the iceprflx_acceptance_min_wait - pub fn set_prflx_acceptance_min_wait(&mut self, t: Option) { - self.timeout.ice_prflx_acceptance_min_wait = t; - } - - /// set_relay_acceptance_min_wait sets the icerelay_acceptance_min_wait - pub fn set_relay_acceptance_min_wait(&mut self, t: Option) { - self.timeout.ice_relay_acceptance_min_wait = t; - } - - /// set_udp_network allows ICE traffic to come through Ephemeral or UDPMux. - /// UDPMux drastically simplifying deployments where ports will need to be opened/forwarded. 
- /// UDPMux should be started prior to creating PeerConnections. - pub fn set_udp_network(&mut self, udp_network: UDPNetwork) { - self.udp_network = udp_network; - } - - /// set_lite configures whether or not the ice agent should be a lite agent - pub fn set_lite(&mut self, lite: bool) { - self.candidates.ice_lite = lite; - } - - /// set_network_types configures what types of candidate networks are supported - /// during local and server reflexive gathering. - pub fn set_network_types(&mut self, candidate_types: Vec) { - self.candidates.ice_network_types = candidate_types; - } - - /// set_interface_filter sets the filtering functions when gathering ICE candidates - /// This can be used to exclude certain network interfaces from ICE. Which may be - /// useful if you know a certain interface will never succeed, or if you wish to reduce - /// the amount of information you wish to expose to the remote peer - pub fn set_interface_filter(&mut self, filter: InterfaceFilterFn) { - self.candidates.interface_filter = Arc::new(Some(filter)); - } - - /// set_ip_filter sets the filtering functions when gathering ICE candidates - /// This can be used to exclude certain ip from ICE. Which may be - /// useful if you know a certain ip will never succeed, or if you wish to reduce - /// the amount of information you wish to expose to the remote peer - pub fn set_ip_filter(&mut self, filter: IpFilterFn) { - self.candidates.ip_filter = Arc::new(Some(filter)); - } - - /// set_nat_1to1_ips sets a list of external IP addresses of 1:1 (D)NAT - /// and a candidate type for which the external IP address is used. - /// This is useful when you are host a server using Pion on an AWS EC2 instance - /// which has a private address, behind a 1:1 DNAT with a public IP (e.g. - /// Elastic IP). In this case, you can give the public IP address so that - /// Pion will use the public IP address in its candidate instead of the private - /// IP address. The second argument, candidate_type, is used to tell Pion which - /// type of candidate should use the given public IP address. - /// Two types of candidates are supported: - /// - /// ICECandidateTypeHost: - /// The public IP address will be used for the host candidate in the SDP. - /// ICECandidateTypeSrflx: - /// A server reflexive candidate with the given public IP address will be added - /// to the SDP. - /// - /// Please note that if you choose ICECandidateTypeHost, then the private IP address - /// won't be advertised with the peer. Also, this option cannot be used along with mDNS. - /// - /// If you choose ICECandidateTypeSrflx, it simply adds a server reflexive candidate - /// with the public IP. The host candidate is still available along with mDNS - /// capabilities unaffected. Also, you cannot give STUN server URL at the same time. - /// It will result in an error otherwise. - pub fn set_nat_1to1_ips(&mut self, ips: Vec, candidate_type: RTCIceCandidateType) { - self.candidates.nat_1to1_ips = ips; - self.candidates.nat_1to1_ip_candidate_type = candidate_type; - } - - /// set_answering_dtls_role sets the dtls_transport role that is selected when offering - /// The dtls_transport role controls if the WebRTC Client as a client or server. This - /// may be useful when interacting with non-compliant clients or debugging issues. 
- /// - /// DTLSRoleActive: - /// Act as dtls_transport Client, send the ClientHello and starts the handshake - /// DTLSRolePassive: - /// Act as dtls_transport Server, wait for ClientHello - pub fn set_answering_dtls_role(&mut self, role: DTLSRole) -> Result<()> { - if role != DTLSRole::Client && role != DTLSRole::Server { - return Err(Error::ErrSettingEngineSetAnsweringDTLSRole); - } - - self.answering_dtls_role = role; - Ok(()) - } - - /// set_vnet sets the VNet instance that is passed to ice - /// VNet is a virtual network layer, allowing users to simulate - /// different topologies, latency, loss and jitter. This can be useful for - /// learning WebRTC concepts or testing your application in a lab environment - pub fn set_vnet(&mut self, vnet: Option>) { - self.vnet = vnet; - } - - /// set_ice_multicast_dns_mode controls if ice queries and generates mDNS ICE Candidates - pub fn set_ice_multicast_dns_mode(&mut self, multicast_dns_mode: ice::mdns::MulticastDnsMode) { - self.candidates.multicast_dns_mode = multicast_dns_mode - } - - /// set_multicast_dns_host_name sets a static HostName to be used by ice instead of generating one on startup - /// This should only be used for a single PeerConnection. Having multiple PeerConnections with the same HostName will cause - /// undefined behavior - pub fn set_multicast_dns_host_name(&mut self, host_name: String) { - self.candidates.multicast_dns_host_name = host_name; - } - - /// set_ice_credentials sets a staic uFrag/uPwd to be used by ice - /// This is useful if you want to do signalless WebRTC session, or having a reproducible environment with static credentials - pub fn set_ice_credentials(&mut self, username_fragment: String, password: String) { - self.candidates.username_fragment = username_fragment; - self.candidates.password = password; - } - - /// disable_certificate_fingerprint_verification disables fingerprint verification after dtls_transport Handshake has finished - pub fn disable_certificate_fingerprint_verification(&mut self, is_disabled: bool) { - self.disable_certificate_fingerprint_verification = is_disabled; - } - - /// allow_insecure_verification_algorithm allows the usage of certain signature verification - /// algorithm that are known to be vulnerable or deprecated. - pub fn allow_insecure_verification_algorithm(&mut self, is_allowed: bool) { - self.allow_insecure_verification_algorithm = is_allowed; - } - /// set_dtls_replay_protection_window sets a replay attack protection window size of dtls_transport connection. - pub fn set_dtls_replay_protection_window(&mut self, n: usize) { - self.replay_protection.dtls = n; - } - - /// set_srtp_replay_protection_window sets a replay attack protection window size of srtp session. - pub fn set_srtp_replay_protection_window(&mut self, n: usize) { - self.disable_srtp_replay_protection = false; - self.replay_protection.srtp = n; - } - - /// set_srtcp_replay_protection_window sets a replay attack protection window size of srtcp session. - pub fn set_srtcp_replay_protection_window(&mut self, n: usize) { - self.disable_srtcp_replay_protection = false; - self.replay_protection.srtcp = n; - } - - /// disable_srtp_replay_protection disables srtp replay protection. - pub fn disable_srtp_replay_protection(&mut self, is_disabled: bool) { - self.disable_srtp_replay_protection = is_disabled; - } - - /// disable_srtcp_replay_protection disables srtcp replay protection. 
- pub fn disable_srtcp_replay_protection(&mut self, is_disabled: bool) { - self.disable_srtcp_replay_protection = is_disabled; - } - - /// set_sdp_media_level_fingerprints configures the logic for dtls_transport Fingerprint insertion - /// If true, fingerprints will be inserted in the sdp at the fingerprint - /// level, instead of the session level. This helps with compatibility with - /// some webrtc implementations. - pub fn set_sdp_media_level_fingerprints(&mut self, sdp_media_level_fingerprints: bool) { - self.sdp_media_level_fingerprints = sdp_media_level_fingerprints; - } - - // SetICETCPMux enables ICE-TCP when set to a non-nil value. Make sure that - // NetworkTypeTCP4 or NetworkTypeTCP6 is enabled as well. - //pub fn SetICETCPMux(&mut self, tcpMux ice.TCPMux) { - // self.iceTCPMux = tcpMux - //} - - // SetICEProxyDialer sets the proxy dialer interface based on golang.org/x/net/proxy. - //pub fn SetICEProxyDialer(&mut self, d proxy.Dialer) { - // self.iceProxyDialer = d - //} - - /// disable_media_engine_copy stops the MediaEngine from being copied. This allows a user to modify - /// the MediaEngine after the PeerConnection has been constructed. This is useful if you wish to - /// modify codecs after signaling. Make sure not to share MediaEngines between PeerConnections. - pub fn disable_media_engine_copy(&mut self, is_disabled: bool) { - self.disable_media_engine_copy = is_disabled; - } - - /// set_receive_mtu sets the size of read buffer that copies incoming packets. This is optional. - /// Leave this 0 for the default receive_mtu - pub fn set_receive_mtu(&mut self, receive_mtu: usize) { - self.receive_mtu = receive_mtu; - } - - /// Sets a callback used to generate mid for transceivers created by this side of the RTCPeerconnection. - /// By having separate "naming schemes" for mids generated by either side of a connection, it's - /// possible to reduce complexity when handling SDP offers/answers clashing. - /// - /// The `isize` argument is currently greatest seen _numeric_ mid. Since mids don't need to be numeric - /// this doesn't necessarily indicating anything. 
- /// - /// Note that the spec says: All MID values MUST be generated in a fashion that does not leak user - /// information, e.g., randomly or using a per-PeerConnection counter, and SHOULD be 3 bytes or less, - /// to allow them to efficiently fit into the RTP header extension - pub fn set_mid_generator(&mut self, f: impl Fn(isize) -> String + Send + Sync + 'static) { - self.mid_generator = Some(Arc::new(f)); - } -} diff --git a/webrtc/src/api/setting_engine/setting_engine_test.rs b/webrtc/src/api/setting_engine/setting_engine_test.rs deleted file mode 100644 index cb5433f58..000000000 --- a/webrtc/src/api/setting_engine/setting_engine_test.rs +++ /dev/null @@ -1,271 +0,0 @@ -use std::sync::atomic::Ordering; - -use super::*; -use crate::api::media_engine::MediaEngine; -use crate::api::APIBuilder; -use crate::peer_connection::peer_connection_test::*; -use crate::rtp_transceiver::rtp_codec::RTPCodecType; - -#[test] -fn test_set_connection_timeout() -> Result<()> { - let mut s = SettingEngine::default(); - - assert_eq!(s.timeout.ice_disconnected_timeout, None); - assert_eq!(s.timeout.ice_failed_timeout, None); - assert_eq!(s.timeout.ice_keepalive_interval, None); - - s.set_ice_timeouts( - Some(Duration::from_secs(1)), - Some(Duration::from_secs(2)), - Some(Duration::from_secs(3)), - ); - assert_eq!( - s.timeout.ice_disconnected_timeout, - Some(Duration::from_secs(1)) - ); - assert_eq!(s.timeout.ice_failed_timeout, Some(Duration::from_secs(2))); - assert_eq!( - s.timeout.ice_keepalive_interval, - Some(Duration::from_secs(3)) - ); - - Ok(()) -} - -#[test] -fn test_detach_data_channels() -> Result<()> { - let mut s = SettingEngine::default(); - - assert!( - !s.detach.data_channels, - "SettingEngine defaults aren't as expected." - ); - - s.detach_data_channels(); - - assert!( - s.detach.data_channels, - "Failed to enable detached data channels." - ); - - Ok(()) -} - -#[test] -fn test_set_nat_1to1_ips() -> Result<()> { - let mut s = SettingEngine::default(); - - assert!( - s.candidates.nat_1to1_ips.is_empty(), - "Invalid default value" - ); - assert!( - s.candidates.nat_1to1_ip_candidate_type == RTCIceCandidateType::Unspecified, - "Invalid default value" - ); - - let ips = vec!["1.2.3.4".to_owned()]; - let typ = RTCIceCandidateType::Host; - s.set_nat_1to1_ips(ips, typ); - assert!( - !(s.candidates.nat_1to1_ips.len() != 1 || s.candidates.nat_1to1_ips[0] != "1.2.3.4"), - "Failed to set NAT1To1IPs" - ); - assert!( - s.candidates.nat_1to1_ip_candidate_type == typ, - "Failed to set NAT1To1IPCandidateType" - ); - - Ok(()) -} - -#[test] -fn test_set_answering_dtls_role() -> Result<()> { - let mut s = SettingEngine::default(); - assert!( - s.set_answering_dtls_role(DTLSRole::Auto).is_err(), - "SetAnsweringDTLSRole can only be called with DTLSRoleClient or DTLSRoleServer" - ); - assert!( - s.set_answering_dtls_role(DTLSRole::Unspecified).is_err(), - "SetAnsweringDTLSRole can only be called with DTLSRoleClient or DTLSRoleServer" - ); - - Ok(()) -} - -#[test] -fn test_set_replay_protection() -> Result<()> { - let mut s = SettingEngine::default(); - - assert!( - !(s.replay_protection.dtls != 0 - || s.replay_protection.srtp != 0 - || s.replay_protection.srtcp != 0), - "SettingEngine defaults aren't as expected." 
- ); - - s.set_dtls_replay_protection_window(128); - s.set_srtp_replay_protection_window(64); - s.set_srtcp_replay_protection_window(32); - - assert!( - !(s.replay_protection.dtls == 0 || s.replay_protection.dtls != 128), - "Failed to set DTLS replay protection window" - ); - assert!( - !(s.replay_protection.srtp == 0 || s.replay_protection.srtp != 64), - "Failed to set SRTP replay protection window" - ); - assert!( - !(s.replay_protection.srtcp == 0 || s.replay_protection.srtcp != 32), - "Failed to set SRTCP replay protection window" - ); - - Ok(()) -} - -/*TODO:#[test] fn test_setting_engine_set_ice_tcp_mux() ->Result<()> { - - listener, err := net.ListenTCP("tcp", &net.TCPAddr{}) - if err != nil { - panic(err) - } - - defer func() { - _ = listener.Close() - }() - - tcpMux := NewICETCPMux(nil, listener, 8) - - defer func() { - _ = tcpMux.Close() - }() - - let mut s = SettingEngine::default(); - settingEngine.SetICETCPMux(tcpMux) - - assert.Equal(t, tcpMux, settingEngine.iceTCPMux) - - Ok(()) -} -*/ - -#[tokio::test] -async fn test_setting_engine_set_disable_media_engine_copy() -> Result<()> { - //"Copy" - { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - let (mut offerer, mut answerer) = new_pair(&api).await?; - - offerer - .add_transceiver_from_kind(RTPCodecType::Video, None) - .await?; - - signal_pair(&mut offerer, &mut answerer).await?; - - // Assert that the MediaEngine the user created isn't modified - assert!(!api.media_engine.negotiated_video.load(Ordering::SeqCst)); - { - let negotiated_video_codecs = api.media_engine.negotiated_video_codecs.lock(); - assert!(negotiated_video_codecs.is_empty()); - } - - // Assert that the internal MediaEngine is modified - assert!(offerer - .internal - .media_engine - .negotiated_video - .load(Ordering::SeqCst)); - { - let negotiated_video_codecs = - offerer.internal.media_engine.negotiated_video_codecs.lock(); - assert!(!negotiated_video_codecs.is_empty()); - } - - close_pair_now(&offerer, &answerer).await; - - let (new_offerer, new_answerer) = new_pair(&api).await?; - - // Assert that the first internal MediaEngine hasn't been cleared - assert!(offerer - .internal - .media_engine - .negotiated_video - .load(Ordering::SeqCst)); - { - let negotiated_video_codecs = - offerer.internal.media_engine.negotiated_video_codecs.lock(); - assert!(!negotiated_video_codecs.is_empty()); - } - - // Assert that the new internal MediaEngine isn't modified - assert!(!new_offerer - .internal - .media_engine - .negotiated_video - .load(Ordering::SeqCst)); - { - let negotiated_video_codecs = new_offerer - .internal - .media_engine - .negotiated_video_codecs - .lock(); - assert!(negotiated_video_codecs.is_empty()); - } - - close_pair_now(&new_offerer, &new_answerer).await; - } - - //"No Copy" - { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - - let mut s = SettingEngine::default(); - s.disable_media_engine_copy(true); - - let api = APIBuilder::new() - .with_media_engine(m) - .with_setting_engine(s) - .build(); - - let (mut offerer, mut answerer) = new_pair(&api).await?; - - offerer - .add_transceiver_from_kind(RTPCodecType::Video, None) - .await?; - - signal_pair(&mut offerer, &mut answerer).await?; - - // Assert that the user MediaEngine was modified, so no copy happened - assert!(api.media_engine.negotiated_video.load(Ordering::SeqCst)); - { - let negotiated_video_codecs = api.media_engine.negotiated_video_codecs.lock(); - 
assert!(!negotiated_video_codecs.is_empty()); - } - - close_pair_now(&offerer, &answerer).await; - - let (offerer, answerer) = new_pair(&api).await?; - - // Assert that the new internal MediaEngine was modified, so no copy happened - assert!(offerer - .internal - .media_engine - .negotiated_video - .load(Ordering::SeqCst)); - { - let negotiated_video_codecs = - offerer.internal.media_engine.negotiated_video_codecs.lock(); - assert!(!negotiated_video_codecs.is_empty()); - } - - close_pair_now(&offerer, &answerer).await; - } - - Ok(()) -} diff --git a/webrtc/src/data_channel/data_channel_init.rs b/webrtc/src/data_channel/data_channel_init.rs deleted file mode 100644 index 5adbdb721..000000000 --- a/webrtc/src/data_channel/data_channel_init.rs +++ /dev/null @@ -1,29 +0,0 @@ -/// DataChannelConfig can be used to configure properties of the underlying -/// channel such as data reliability. -#[derive(Default, Debug, Clone)] -pub struct RTCDataChannelInit { - /// ordered indicates if data is allowed to be delivered out of order. The - /// default value of true, guarantees that data will be delivered in order. - pub ordered: Option, - - /// max_packet_life_time limits the time (in milliseconds) during which the - /// channel will transmit or retransmit data if not acknowledged. This value - /// may be clamped if it exceeds the maximum value supported. - pub max_packet_life_time: Option, - - /// max_retransmits limits the number of times a channel will retransmit data - /// if not successfully delivered. This value may be clamped if it exceeds - /// the maximum value supported. - pub max_retransmits: Option, - - /// protocol describes the subprotocol name used for this channel. - pub protocol: Option, - - /// negotiated describes if the data channel is created by the local peer or - /// the remote peer. The default value of None tells the user agent to - /// announce the channel in-band and instruct the other peer to dispatch a - /// corresponding DataChannel. If set to Some(id), it is up to the application - /// to negotiate the channel and create an DataChannel with the same id - /// at the other peer. - pub negotiated: Option, -} diff --git a/webrtc/src/data_channel/data_channel_message.rs b/webrtc/src/data_channel/data_channel_message.rs deleted file mode 100644 index a781ec431..000000000 --- a/webrtc/src/data_channel/data_channel_message.rs +++ /dev/null @@ -1,11 +0,0 @@ -use bytes::Bytes; - -/// DataChannelMessage represents a message received from the -/// data channel. IsString will be set to true if the incoming -/// message is of the string type. Otherwise the message is of -/// a binary type. -#[derive(Default, Debug, Clone)] -pub struct DataChannelMessage { - pub is_string: bool, - pub data: Bytes, -} diff --git a/webrtc/src/data_channel/data_channel_parameters.rs b/webrtc/src/data_channel/data_channel_parameters.rs deleted file mode 100644 index 88f116ed3..000000000 --- a/webrtc/src/data_channel/data_channel_parameters.rs +++ /dev/null @@ -1,12 +0,0 @@ -use serde::{Deserialize, Serialize}; - -/// DataChannelParameters describes the configuration of the DataChannel. 
-#[derive(Default, Debug, Clone, Serialize, Deserialize)] -pub struct DataChannelParameters { - pub label: String, - pub protocol: String, - pub ordered: bool, - pub max_packet_life_time: u16, - pub max_retransmits: u16, - pub negotiated: Option, -} diff --git a/webrtc/src/data_channel/data_channel_state.rs b/webrtc/src/data_channel/data_channel_state.rs deleted file mode 100644 index 38ea04cb0..000000000 --- a/webrtc/src/data_channel/data_channel_state.rs +++ /dev/null @@ -1,113 +0,0 @@ -use std::fmt; - -use serde::{Deserialize, Serialize}; - -/// DataChannelState indicates the state of a data channel. -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub enum RTCDataChannelState { - #[serde(rename = "unspecified")] - #[default] - Unspecified = 0, - - /// DataChannelStateConnecting indicates that the data channel is being - /// established. This is the initial state of DataChannel, whether created - /// with create_data_channel, or dispatched as a part of an DataChannelEvent. - #[serde(rename = "connecting")] - Connecting, - - /// DataChannelStateOpen indicates that the underlying data transport is - /// established and communication is possible. - #[serde(rename = "open")] - Open, - - /// DataChannelStateClosing indicates that the procedure to close down the - /// underlying data transport has started. - #[serde(rename = "closing")] - Closing, - - /// DataChannelStateClosed indicates that the underlying data transport - /// has been closed or could not be established. - #[serde(rename = "closed")] - Closed, -} - -const DATA_CHANNEL_STATE_CONNECTING_STR: &str = "connecting"; -const DATA_CHANNEL_STATE_OPEN_STR: &str = "open"; -const DATA_CHANNEL_STATE_CLOSING_STR: &str = "closing"; -const DATA_CHANNEL_STATE_CLOSED_STR: &str = "closed"; - -impl From for RTCDataChannelState { - fn from(v: u8) -> Self { - match v { - 1 => RTCDataChannelState::Connecting, - 2 => RTCDataChannelState::Open, - 3 => RTCDataChannelState::Closing, - 4 => RTCDataChannelState::Closed, - _ => RTCDataChannelState::Unspecified, - } - } -} - -impl From<&str> for RTCDataChannelState { - fn from(raw: &str) -> Self { - match raw { - DATA_CHANNEL_STATE_CONNECTING_STR => RTCDataChannelState::Connecting, - DATA_CHANNEL_STATE_OPEN_STR => RTCDataChannelState::Open, - DATA_CHANNEL_STATE_CLOSING_STR => RTCDataChannelState::Closing, - DATA_CHANNEL_STATE_CLOSED_STR => RTCDataChannelState::Closed, - _ => RTCDataChannelState::Unspecified, - } - } -} - -impl fmt::Display for RTCDataChannelState { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - RTCDataChannelState::Connecting => DATA_CHANNEL_STATE_CONNECTING_STR, - RTCDataChannelState::Open => DATA_CHANNEL_STATE_OPEN_STR, - RTCDataChannelState::Closing => DATA_CHANNEL_STATE_CLOSING_STR, - RTCDataChannelState::Closed => DATA_CHANNEL_STATE_CLOSED_STR, - RTCDataChannelState::Unspecified => crate::UNSPECIFIED_STR, - }; - write!(f, "{s}") - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_new_data_channel_state() { - let tests = vec![ - (crate::UNSPECIFIED_STR, RTCDataChannelState::Unspecified), - ("connecting", RTCDataChannelState::Connecting), - ("open", RTCDataChannelState::Open), - ("closing", RTCDataChannelState::Closing), - ("closed", RTCDataChannelState::Closed), - ]; - - for (state_string, expected_state) in tests { - assert_eq!( - RTCDataChannelState::from(state_string), - expected_state, - "testCase: {expected_state}", - ); - } - } - - #[test] - fn test_data_channel_state_string() { - let 
tests = vec![ - (RTCDataChannelState::Unspecified, crate::UNSPECIFIED_STR), - (RTCDataChannelState::Connecting, "connecting"), - (RTCDataChannelState::Open, "open"), - (RTCDataChannelState::Closing, "closing"), - (RTCDataChannelState::Closed, "closed"), - ]; - - for (state, expected_string) in tests { - assert_eq!(state.to_string(), expected_string) - } - } -} diff --git a/webrtc/src/data_channel/data_channel_test.rs b/webrtc/src/data_channel/data_channel_test.rs deleted file mode 100644 index 65e8eb9f1..000000000 --- a/webrtc/src/data_channel/data_channel_test.rs +++ /dev/null @@ -1,1504 +0,0 @@ -// Silence warning on `for i in 0..vec.len() { โ€ฆ }`: -#![allow(clippy::needless_range_loop)] - -use regex::Regex; -use tokio::sync::mpsc; -use tokio::time::Duration; -use waitgroup::WaitGroup; - -use super::*; -use crate::api::media_engine::MediaEngine; -use crate::api::{APIBuilder, API}; -use crate::data_channel::data_channel_init::RTCDataChannelInit; -//use log::LevelFilter; -//use std::io::Write; -use crate::dtls_transport::dtls_parameters::DTLSParameters; -use crate::dtls_transport::RTCDtlsTransport; -use crate::error::flatten_errs; -use crate::ice_transport::ice_candidate::RTCIceCandidate; -use crate::ice_transport::ice_connection_state::RTCIceConnectionState; -use crate::ice_transport::ice_gatherer::{RTCIceGatherOptions, RTCIceGatherer}; -use crate::ice_transport::ice_parameters::RTCIceParameters; -use crate::ice_transport::ice_role::RTCIceRole; -use crate::ice_transport::RTCIceTransport; -use crate::peer_connection::configuration::RTCConfiguration; -use crate::peer_connection::peer_connection_test::*; -use crate::peer_connection::RTCPeerConnection; -use crate::sctp_transport::sctp_transport_capabilities::SCTPTransportCapabilities; - -// EXPECTED_LABEL represents the label of the data channel we are trying to test. -// Some other channels may have been created during initialization (in the Wasm -// bindings this is a requirement). 
-const EXPECTED_LABEL: &str = "data"; - -async fn set_up_data_channel_parameters_test( - api: &API, - options: Option, -) -> Result<( - RTCPeerConnection, - RTCPeerConnection, - Arc, - mpsc::Sender<()>, - mpsc::Receiver<()>, -)> { - let (offer_pc, answer_pc) = new_pair(api).await?; - let (done_tx, done_rx) = mpsc::channel(1); - - let dc = offer_pc - .create_data_channel(EXPECTED_LABEL, options) - .await?; - Ok((offer_pc, answer_pc, dc, done_tx, done_rx)) -} - -async fn close_reliability_param_test( - pc1: &mut RTCPeerConnection, - pc2: &mut RTCPeerConnection, - done_rx: mpsc::Receiver<()>, -) -> Result<()> { - signal_pair(pc1, pc2).await?; - - close_pair(pc1, pc2, done_rx).await; - - Ok(()) -} - -/* -TODO: #[tokio::test] async fnBenchmarkDataChannelSend2(b *testing.B) { benchmarkDataChannelSend(b, 2) } -#[tokio::test] async fnBenchmarkDataChannelSend4(b *testing.B) { benchmarkDataChannelSend(b, 4) } -#[tokio::test] async fnBenchmarkDataChannelSend8(b *testing.B) { benchmarkDataChannelSend(b, 8) } -#[tokio::test] async fnBenchmarkDataChannelSend16(b *testing.B) { benchmarkDataChannelSend(b, 16) } -#[tokio::test] async fnBenchmarkDataChannelSend32(b *testing.B) { benchmarkDataChannelSend(b, 32) } - -// See https://github.com/pion/webrtc/issues/1516 -#[tokio::test] async fnbenchmarkDataChannelSend(b *testing.B, numChannels int) { - offerPC, answerPC, err := newPair() - if err != nil { - b.Fatalf("Failed to create a PC pair for testing") - } - - open := make(map[string]chan bool) - answerPC.OnDataChannel(func(d *DataChannel) { - if _, ok := open[d.Label()]; !ok { - // Ignore anything unknown channel label. - return - } - d.OnOpen(func() { open[d.Label()] <- true }) - }) - - var wg sync.WaitGroup - for i := 0; i < numChannels; i++ { - label := fmt.Sprintf("dc-%d", i) - open[label] = make(chan bool) - wg.Add(1) - dc, err := offerPC.CreateDataChannel(label, nil) - assert.NoError(b, err) - - dc.OnOpen(func() { - <-open[label] - for n := 0; n < b.N/numChannels; n++ { - if err := dc.SendText("Ping"); err != nil { - b.Fatalf("Unexpected error sending data (label=%q): %v", label, err) - } - } - wg.Done() - }) - } - - assert.NoError(b, signalPair(offerPC, answerPC)) - wg.Wait() - close_pair_now(b, offerPC, answerPC) -} -*/ - -#[tokio::test] -async fn test_data_channel_open() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, LevelFilter::Trace) - .init();*/ - - //"handler should be called once" - { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - let (mut offer_pc, mut answer_pc) = new_pair(&api).await?; - - let (done_tx, done_rx) = mpsc::channel(1); - let (open_calls_tx, mut open_calls_rx) = mpsc::channel(2); - - let open_calls_tx = Arc::new(open_calls_tx); - let done_tx = Arc::new(done_tx); - answer_pc.on_data_channel(Box::new(move |d: Arc| { - if d.label() == EXPECTED_LABEL { - let open_calls_tx2 = Arc::clone(&open_calls_tx); - let done_tx2 = Arc::clone(&done_tx); - Box::pin(async move { - d.on_open(Box::new(move || { - Box::pin(async move { - let _ = open_calls_tx2.send(()).await; - }) - })); - d.on_message(Box::new(move |_: DataChannelMessage| { - let done_tx3 = Arc::clone(&done_tx2); - tokio::spawn(async move { - // Wait a little bit to ensure all messages are processed. 
- tokio::time::sleep(Duration::from_millis(100)).await; - let _ = done_tx3.send(()).await; - }); - Box::pin(async {}) - })); - }) - } else { - Box::pin(async {}) - } - })); - - let dc = offer_pc.create_data_channel(EXPECTED_LABEL, None).await?; - - let dc2 = Arc::clone(&dc); - dc.on_open(Box::new(move || { - Box::pin(async move { - let result = dc2.send_text("Ping".to_owned()).await; - assert!(result.is_ok(), "Failed to send string on data channel"); - }) - })); - - signal_pair(&mut offer_pc, &mut answer_pc).await?; - - close_pair(&offer_pc, &answer_pc, done_rx).await; - - let _ = open_calls_rx.recv().await; - } - - Ok(()) -} - -#[tokio::test] -async fn test_data_channel_send_before_signaling() -> Result<()> { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - //"before signaling" - - let (mut offer_pc, mut answer_pc) = new_pair(&api).await?; - - answer_pc.on_data_channel(Box::new(move |d: Arc| { - // Make sure this is the data channel we were looking for. (Not the one - // created in signalPair). - if d.label() != EXPECTED_LABEL { - return Box::pin(async {}); - } - Box::pin(async move { - let d2 = Arc::clone(&d); - d.on_message(Box::new(move |_: DataChannelMessage| { - let d3 = Arc::clone(&d2); - Box::pin(async move { - let result = d3.send(&Bytes::from(b"Pong".to_vec())).await; - assert!(result.is_ok(), "Failed to send string on data channel"); - }) - })); - assert!(d.ordered(), "Ordered should be set to true"); - }) - })); - - let dc = offer_pc.create_data_channel(EXPECTED_LABEL, None).await?; - - assert!(dc.ordered(), "Ordered should be set to true"); - - let dc2 = Arc::clone(&dc); - dc.on_open(Box::new(move || { - let dc3 = Arc::clone(&dc2); - Box::pin(async move { - let result = dc3.send_text("Ping".to_owned()).await; - assert!(result.is_ok(), "Failed to send string on data channel"); - }) - })); - - let (done_tx, done_rx) = mpsc::channel::<()>(1); - let done_tx = Arc::new(Mutex::new(Some(done_tx))); - dc.on_message(Box::new(move |_: DataChannelMessage| { - let done_tx2 = Arc::clone(&done_tx); - Box::pin(async move { - let mut done = done_tx2.lock().await; - done.take(); - }) - })); - - signal_pair(&mut offer_pc, &mut answer_pc).await?; - - close_pair(&offer_pc, &answer_pc, done_rx).await; - Ok(()) -} - -#[tokio::test] -async fn test_data_channel_send_after_connected() -> Result<()> { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - let (mut offer_pc, mut answer_pc) = new_pair(&api).await?; - - answer_pc.on_data_channel(Box::new(move |d: Arc| { - // Make sure this is the data channel we were looking for. (Not the one - // created in signalPair). 
- if d.label() != EXPECTED_LABEL { - return Box::pin(async {}); - } - Box::pin(async move { - let d2 = Arc::clone(&d); - d.on_message(Box::new(move |_: DataChannelMessage| { - let d3 = Arc::clone(&d2); - - Box::pin(async move { - let result = d3.send(&Bytes::from(b"Pong".to_vec())).await; - assert!(result.is_ok(), "Failed to send string on data channel"); - }) - })); - assert!(d.ordered(), "Ordered should be set to true"); - }) - })); - - let dc = offer_pc - .create_data_channel(EXPECTED_LABEL, None) - .await - .expect("Failed to create a PC pair for testing"); - - let (done_tx, done_rx) = mpsc::channel::<()>(1); - let done_tx = Arc::new(Mutex::new(Some(done_tx))); - - //once := &sync.Once{} - offer_pc.on_ice_connection_state_change(Box::new(move |state: RTCIceConnectionState| { - let done_tx1 = Arc::clone(&done_tx); - let dc1 = Arc::clone(&dc); - Box::pin(async move { - if state == RTCIceConnectionState::Connected - || state == RTCIceConnectionState::Completed - { - // wasm fires completed state multiple times - /*once.Do(func()*/ - { - assert!(dc1.ordered(), "Ordered should be set to true"); - - dc1.on_message(Box::new(move |_: DataChannelMessage| { - let done_tx2 = Arc::clone(&done_tx1); - Box::pin(async move { - let mut done = done_tx2.lock().await; - done.take(); - }) - })); - - if dc1.send_text("Ping".to_owned()).await.is_err() { - // wasm binding doesn't fire OnOpen (we probably already missed it) - let dc2 = Arc::clone(&dc1); - dc1.on_open(Box::new(move || { - let dc3 = Arc::clone(&dc2); - Box::pin(async move { - let result = dc3.send_text("Ping".to_owned()).await; - assert!(result.is_ok(), "Failed to send string on data channel"); - }) - })); - } - } - } - }) - })); - - signal_pair(&mut offer_pc, &mut answer_pc).await?; - - close_pair(&offer_pc, &answer_pc, done_rx).await; - - Ok(()) -} - -#[tokio::test] -async fn test_data_channel_close() -> Result<()> { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - // "Close after PeerConnection Closed" - { - let (offer_pc, answer_pc) = new_pair(&api).await?; - - let dc = offer_pc.create_data_channel(EXPECTED_LABEL, None).await?; - - close_pair_now(&offer_pc, &answer_pc).await; - dc.close().await?; - } - - // "Close before connected" - { - let (offer_pc, answer_pc) = new_pair(&api).await?; - - let dc = offer_pc.create_data_channel(EXPECTED_LABEL, None).await?; - - dc.close().await?; - close_pair_now(&offer_pc, &answer_pc).await; - } - - Ok(()) -} - -#[tokio::test] -async fn test_data_channel_parameters_max_packet_life_time_exchange() -> Result<()> { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - let ordered = true; - let max_packet_life_time = 3u16; - let options = RTCDataChannelInit { - ordered: Some(ordered), - max_packet_life_time: Some(max_packet_life_time), - ..Default::default() - }; - - let (mut offer_pc, mut answer_pc, dc, done_tx, done_rx) = - set_up_data_channel_parameters_test(&api, Some(options)).await?; - - // Check if parameters are correctly set - assert_eq!( - dc.ordered(), - ordered, - "Ordered should be same value as set in DataChannelInit" - ); - assert_eq!( - dc.max_packet_lifetime(), - max_packet_life_time, - "should match" - ); - - let done_tx = Arc::new(Mutex::new(Some(done_tx))); - answer_pc.on_data_channel(Box::new(move |d: Arc| { - if d.label() != EXPECTED_LABEL { - return Box::pin(async {}); - } - // Check if parameters are correctly set - 
assert_eq!( - d.ordered(), - ordered, - "Ordered should be same value as set in DataChannelInit" - ); - assert_eq!( - d.max_packet_lifetime(), - max_packet_life_time, - "should match" - ); - let done_tx2 = Arc::clone(&done_tx); - Box::pin(async move { - let mut done = done_tx2.lock().await; - done.take(); - }) - })); - - close_reliability_param_test(&mut offer_pc, &mut answer_pc, done_rx).await?; - - Ok(()) -} - -#[tokio::test] -async fn test_data_channel_parameters_max_retransmits_exchange() -> Result<()> { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - let ordered = false; - let max_retransmits = 3000u16; - let options = RTCDataChannelInit { - ordered: Some(ordered), - max_retransmits: Some(max_retransmits), - ..Default::default() - }; - - let (mut offer_pc, mut answer_pc, dc, done_tx, done_rx) = - set_up_data_channel_parameters_test(&api, Some(options)).await?; - - // Check if parameters are correctly set - assert!(!dc.ordered(), "Ordered should be set to false"); - assert_eq!(dc.max_retransmits(), max_retransmits, "should match"); - - let done_tx = Arc::new(Mutex::new(Some(done_tx))); - answer_pc.on_data_channel(Box::new(move |d: Arc| { - // Make sure this is the data channel we were looking for. (Not the one - // created in signalPair). - if d.label() != EXPECTED_LABEL { - return Box::pin(async {}); - } - - // Check if parameters are correctly set - assert!(!d.ordered(), "Ordered should be set to false"); - assert_eq!(max_retransmits, d.max_retransmits(), "should match"); - let done_tx2 = Arc::clone(&done_tx); - Box::pin(async move { - let mut done = done_tx2.lock().await; - done.take(); - }) - })); - - close_reliability_param_test(&mut offer_pc, &mut answer_pc, done_rx).await?; - - Ok(()) -} - -#[tokio::test] -async fn test_data_channel_parameters_protocol_exchange() -> Result<()> { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - let protocol = "json".to_owned(); - let options = RTCDataChannelInit { - protocol: Some(protocol.clone()), - ..Default::default() - }; - - let (mut offer_pc, mut answer_pc, dc, done_tx, done_rx) = - set_up_data_channel_parameters_test(&api, Some(options)).await?; - - // Check if parameters are correctly set - assert_eq!( - protocol, - dc.protocol(), - "Protocol should match DataChannelConfig" - ); - - let done_tx = Arc::new(Mutex::new(Some(done_tx))); - answer_pc.on_data_channel(Box::new(move |d: Arc| { - // Make sure this is the data channel we were looking for. (Not the one - // created in signalPair). 
- if d.label() != EXPECTED_LABEL { - return Box::pin(async {}); - } - // Check if parameters are correctly set - assert_eq!( - protocol, - d.protocol(), - "Protocol should match what channel creator declared" - ); - - let done_tx2 = Arc::clone(&done_tx); - Box::pin(async move { - let mut done = done_tx2.lock().await; - done.take(); - }) - })); - - close_reliability_param_test(&mut offer_pc, &mut answer_pc, done_rx).await?; - - Ok(()) -} - -#[tokio::test] -async fn test_data_channel_parameters_negotiated_exchange() -> Result<()> { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - const EXPECTED_MESSAGE: &str = "Hello World"; - - let id = 500u16; - let options = RTCDataChannelInit { - negotiated: Some(id), - ..Default::default() - }; - - let (mut offer_pc, mut answer_pc, offer_datachannel, done_tx, done_rx) = - set_up_data_channel_parameters_test(&api, Some(options.clone())).await?; - - let answer_datachannel = answer_pc - .create_data_channel(EXPECTED_LABEL, Some(options)) - .await?; - - answer_pc.on_data_channel(Box::new(move |d: Arc| { - // Ignore our default channel, exists to force ICE candidates. See signalPair for more info - if d.label() == "initial_data_channel" { - return Box::pin(async {}); - } - panic!("OnDataChannel must not be fired when negotiated == true"); - })); - - offer_pc.on_data_channel(Box::new(move |_d: Arc| { - panic!("OnDataChannel must not be fired when negotiated == true"); - })); - - let seen_answer_message = Arc::new(AtomicBool::new(false)); - let seen_offer_message = Arc::new(AtomicBool::new(false)); - - let seen_answer_message2 = Arc::clone(&seen_answer_message); - answer_datachannel.on_message(Box::new(move |msg: DataChannelMessage| { - if msg.is_string && msg.data == EXPECTED_MESSAGE { - seen_answer_message2.store(true, Ordering::SeqCst); - } - - Box::pin(async {}) - })); - - let seen_offer_message2 = Arc::clone(&seen_offer_message); - offer_datachannel.on_message(Box::new(move |msg: DataChannelMessage| { - if msg.is_string && msg.data == EXPECTED_MESSAGE { - seen_offer_message2.store(true, Ordering::SeqCst); - } - Box::pin(async {}) - })); - - let done_tx = Arc::new(Mutex::new(Some(done_tx))); - tokio::spawn(async move { - loop { - if seen_answer_message.load(Ordering::SeqCst) - && seen_offer_message.load(Ordering::SeqCst) - { - break; - } - - if offer_datachannel.ready_state() == RTCDataChannelState::Open { - offer_datachannel - .send_text(EXPECTED_MESSAGE.to_owned()) - .await?; - } - if answer_datachannel.ready_state() == RTCDataChannelState::Open { - answer_datachannel - .send_text(EXPECTED_MESSAGE.to_owned()) - .await?; - } - - tokio::time::sleep(Duration::from_millis(50)).await; - } - - let mut done = done_tx.lock().await; - done.take(); - - Result::<()>::Ok(()) - }); - - close_reliability_param_test(&mut offer_pc, &mut answer_pc, done_rx).await?; - - Ok(()) -} - -#[tokio::test] -async fn test_data_channel_event_handlers() -> Result<()> { - let api = APIBuilder::new().build(); - - let dc = RTCDataChannel { - setting_engine: Arc::clone(&api.setting_engine), - ..Default::default() - }; - - let (on_open_called_tx, mut on_open_called_rx) = mpsc::channel::<()>(1); - let (on_message_called_tx, mut on_message_called_rx) = mpsc::channel::<()>(1); - - // Verify that the noop case works - dc.do_open(); - - let on_open_called_tx = Arc::new(Mutex::new(Some(on_open_called_tx))); - dc.on_open(Box::new(move || { - let on_open_called_tx2 = Arc::clone(&on_open_called_tx); - 
Box::pin(async move { - let mut done = on_open_called_tx2.lock().await; - done.take(); - }) - })); - - let on_message_called_tx = Arc::new(Mutex::new(Some(on_message_called_tx))); - dc.on_message(Box::new(move |_: DataChannelMessage| { - let on_message_called_tx2 = Arc::clone(&on_message_called_tx); - Box::pin(async move { - let mut done = on_message_called_tx2.lock().await; - done.take(); - }) - })); - - // Verify that the set handlers are called - dc.do_open(); - dc.do_message(DataChannelMessage { - is_string: false, - data: Bytes::from_static(b"o hai"), - }) - .await; - - // Wait for all handlers to be called - let _ = on_open_called_rx.recv().await; - let _ = on_message_called_rx.recv().await; - - Ok(()) -} - -#[tokio::test] -async fn test_data_channel_messages_are_ordered() -> Result<()> { - let api = APIBuilder::new().build(); - - let dc = RTCDataChannel { - setting_engine: Arc::clone(&api.setting_engine), - ..Default::default() - }; - - let m = 16u64; - let (out_tx, mut out_rx) = mpsc::channel::(m as usize); - - let out_tx = Arc::new(out_tx); - - let out_tx1 = Arc::clone(&out_tx); - dc.on_message(Box::new(move |msg: DataChannelMessage| { - let out_tx2 = Arc::clone(&out_tx1); - - Box::pin(async move { - // randomly sleep - let r = rand::random::() % m; - tokio::time::sleep(Duration::from_millis(r)).await; - - let mut buf = [0u8; 8]; - for i in 0..8 { - buf[i] = msg.data[i]; - } - let s = u64::from_be_bytes(buf); - - let _ = out_tx2.send(s).await; - }) - })); - - tokio::spawn(async move { - for j in 1..=m { - let buf = j.to_be_bytes().to_vec(); - - dc.do_message(DataChannelMessage { - is_string: false, - data: Bytes::from(buf), - }) - .await; - // Change the registered handler a couple of times to make sure - // that everything continues to work, we don't lose messages, etc. - if j % 2 == 0 { - let out_tx1 = Arc::clone(&out_tx); - dc.on_message(Box::new(move |msg: DataChannelMessage| { - let out_tx2 = Arc::clone(&out_tx1); - - Box::pin(async move { - // randomly sleep - let r = rand::random::() % m; - tokio::time::sleep(Duration::from_millis(r)).await; - - let mut buf = [0u8; 8]; - for i in 0..8 { - buf[i] = msg.data[i]; - } - let s = u64::from_be_bytes(buf); - - let _ = out_tx2.send(s).await; - }) - })); - } - } - }); - - let mut values = vec![]; - for _ in 1..=m { - if let Some(v) = out_rx.recv().await { - values.push(v); - } else { - break; - } - } - - let mut expected = vec![0u64; m as usize]; - for i in 1..=m as usize { - expected[i - 1] = i as u64; - } - assert_eq!(values, expected); - - Ok(()) -} - -#[tokio::test] -async fn test_data_channel_parameters_go() -> Result<()> { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - //"MaxPacketLifeTime exchange" - { - let ordered = true; - let max_packet_life_time = 3u16; - let options = RTCDataChannelInit { - ordered: Some(ordered), - max_packet_life_time: Some(max_packet_life_time), - ..Default::default() - }; - - let (mut offer_pc, mut answer_pc, dc, done_tx, done_rx) = - set_up_data_channel_parameters_test(&api, Some(options)).await?; - - // Check if parameters are correctly set - assert!(dc.ordered(), "Ordered should be set to true"); - assert_eq!( - max_packet_life_time, - dc.max_packet_lifetime(), - "should match" - ); - - let done_tx = Arc::new(Mutex::new(Some(done_tx))); - answer_pc.on_data_channel(Box::new(move |d: Arc| { - // Make sure this is the data channel we were looking for. (Not the one - // created in signalPair). 
- if d.label() != EXPECTED_LABEL { - return Box::pin(async {}); - } - - // Check if parameters are correctly set - assert!(d.ordered, "Ordered should be set to true"); - assert_eq!( - max_packet_life_time, - d.max_packet_lifetime(), - "should match" - ); - - let done_tx2 = Arc::clone(&done_tx); - Box::pin(async move { - let mut done = done_tx2.lock().await; - done.take(); - }) - })); - - close_reliability_param_test(&mut offer_pc, &mut answer_pc, done_rx).await?; - } - - //"All other property methods" - { - let id = 123u16; - let dc = RTCDataChannel { - id: AtomicU16::new(id), - label: "mylabel".to_owned(), - protocol: "myprotocol".to_owned(), - negotiated: true, - ..Default::default() - }; - - assert_eq!(dc.id.load(Ordering::SeqCst), dc.id(), "should match"); - assert_eq!(dc.label, dc.label(), "should match"); - assert_eq!(dc.protocol, dc.protocol(), "should match"); - assert_eq!(dc.negotiated, dc.negotiated(), "should match"); - assert_eq!(0, dc.buffered_amount().await, "should match"); - dc.set_buffered_amount_low_threshold(1500).await; - assert_eq!( - 1500, - dc.buffered_amount_low_threshold().await, - "should match" - ); - } - - Ok(()) -} - -//use log::LevelFilter; -//use std::io::Write; - -#[tokio::test] -async fn test_data_channel_buffered_amount_set_before_open() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, LevelFilter::Trace) - .init();*/ - - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - let n_cbs = Arc::new(AtomicU16::new(0)); - let buf = Bytes::from_static(&[0u8; 1000]); - - let (mut offer_pc, mut answer_pc) = new_pair(&api).await?; - - let (done_tx, done_rx) = mpsc::channel::<()>(1); - - let done_tx = Arc::new(Mutex::new(Some(done_tx))); - let n_packets_received = Arc::new(AtomicU16::new(0)); - answer_pc.on_data_channel(Box::new(move |d: Arc| { - // Make sure this is the data channel we were looking for. (Not the one - // created in signalPair). 
- if d.label() != EXPECTED_LABEL { - return Box::pin(async {}); - } - - let done_tx2 = Arc::clone(&done_tx); - let n_packets_received2 = Arc::clone(&n_packets_received); - Box::pin(async move { - d.on_message(Box::new(move |_msg: DataChannelMessage| { - let n = n_packets_received2.fetch_add(1, Ordering::SeqCst); - if n == 9 { - let done_tx3 = Arc::clone(&done_tx2); - tokio::spawn(async move { - tokio::time::sleep(Duration::from_millis(10)).await; - let mut done = done_tx3.lock().await; - done.take(); - }); - } - - Box::pin(async {}) - })); - - assert!(d.ordered(), "Ordered should be set to true"); - }) - })); - - let dc = offer_pc.create_data_channel(EXPECTED_LABEL, None).await?; - - assert!(dc.ordered(), "Ordered should be set to true"); - - let dc2 = Arc::clone(&dc); - dc.on_open(Box::new(move || { - let dc3 = Arc::clone(&dc2); - Box::pin(async move { - for _ in 0..10 { - assert!( - dc3.send(&buf).await.is_ok(), - "Failed to send string on data channel" - ); - assert_eq!( - 1500, - dc3.buffered_amount_low_threshold().await, - "value mismatch" - ); - } - }) - })); - - dc.on_message(Box::new(|_msg: DataChannelMessage| Box::pin(async {}))); - - // The value is temporarily stored in the dc object - // until the dc gets opened - dc.set_buffered_amount_low_threshold(1500).await; - // The callback function is temporarily stored in the dc object - // until the dc gets opened - let n_cbs2 = Arc::clone(&n_cbs); - dc.on_buffered_amount_low(Box::new(move || { - n_cbs2.fetch_add(1, Ordering::SeqCst); - Box::pin(async {}) - })) - .await; - - signal_pair(&mut offer_pc, &mut answer_pc).await?; - - close_pair(&offer_pc, &answer_pc, done_rx).await; - - assert!( - n_cbs.load(Ordering::SeqCst) > 0, - "callback should be made at least once" - ); - - Ok(()) -} - -#[tokio::test] -async fn test_data_channel_buffered_amount_set_after_open() -> Result<()> { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - let n_cbs = Arc::new(AtomicU16::new(0)); - let buf = Bytes::from_static(&[0u8; 1000]); - - let (mut offer_pc, mut answer_pc) = new_pair(&api).await?; - - let (done_tx, done_rx) = mpsc::channel::<()>(1); - - let done_tx = Arc::new(Mutex::new(Some(done_tx))); - let n_packets_received = Arc::new(AtomicU16::new(0)); - answer_pc.on_data_channel(Box::new(move |d: Arc| { - // Make sure this is the data channel we were looking for. (Not the one - // created in signalPair). 
- if d.label() != EXPECTED_LABEL { - return Box::pin(async {}); - } - - let done_tx2 = Arc::clone(&done_tx); - let n_packets_received2 = Arc::clone(&n_packets_received); - Box::pin(async move { - d.on_message(Box::new(move |_msg: DataChannelMessage| { - let n = n_packets_received2.fetch_add(1, Ordering::SeqCst); - if n == 9 { - let done_tx3 = Arc::clone(&done_tx2); - tokio::spawn(async move { - tokio::time::sleep(Duration::from_millis(10)).await; - let mut done = done_tx3.lock().await; - done.take(); - }); - } - - Box::pin(async {}) - })); - - assert!(d.ordered(), "Ordered should be set to true"); - }) - })); - - let dc = offer_pc.create_data_channel(EXPECTED_LABEL, None).await?; - - assert!(dc.ordered(), "Ordered should be set to true"); - - let dc2 = Arc::clone(&dc); - let n_cbs2 = Arc::clone(&n_cbs); - dc.on_open(Box::new(move || { - let dc3 = Arc::clone(&dc2); - Box::pin(async move { - // The value should directly be passed to sctp - dc3.set_buffered_amount_low_threshold(1500).await; - // The callback function should directly be passed to sctp - dc3.on_buffered_amount_low(Box::new(move || { - n_cbs2.fetch_add(1, Ordering::SeqCst); - Box::pin(async {}) - })) - .await; - - for _ in 0..10 { - assert!( - dc3.send(&buf).await.is_ok(), - "Failed to send string on data channel" - ); - assert_eq!( - 1500, - dc3.buffered_amount_low_threshold().await, - "value mismatch" - ); - } - }) - })); - - dc.on_message(Box::new(|_msg: DataChannelMessage| Box::pin(async {}))); - - signal_pair(&mut offer_pc, &mut answer_pc).await?; - - close_pair(&offer_pc, &answer_pc, done_rx).await; - - assert!( - n_cbs.load(Ordering::SeqCst) > 0, - "callback should be made at least once" - ); - - Ok(()) -} - -#[tokio::test] -async fn test_eof_detach() -> Result<()> { - let label: &str = "test-channel"; - let test_data: &'static str = "this is some test data"; - - // Use Detach data channels mode - let mut s = SettingEngine::default(); - s.detach_data_channels(); - let api = APIBuilder::new().with_setting_engine(s).build(); - - // Set up two peer connections. - let mut pca = api.new_peer_connection(RTCConfiguration::default()).await?; - let mut pcb = api.new_peer_connection(RTCConfiguration::default()).await?; - - let wg = WaitGroup::new(); - - let (dc_chan_tx, mut dc_chan_rx) = mpsc::channel(1); - let dc_chan_tx = Arc::new(dc_chan_tx); - pcb.on_data_channel(Box::new(move |dc: Arc| { - if dc.label() != label { - return Box::pin(async {}); - } - log::debug!("OnDataChannel was called"); - let dc_chan_tx2 = Arc::clone(&dc_chan_tx); - let dc2 = Arc::clone(&dc); - Box::pin(async move { - let dc3 = Arc::clone(&dc2); - dc2.on_open(Box::new(move || { - let dc_chan_tx3 = Arc::clone(&dc_chan_tx2); - let dc4 = Arc::clone(&dc3); - Box::pin(async move { - let detached = match dc4.detach().await { - Ok(detached) => detached, - Err(err) => { - log::debug!("Detach failed: {}", err); - panic!(); - } - }; - - let _ = dc_chan_tx3.send(detached).await; - }) - })); - }) - })); - - let w = wg.worker(); - tokio::spawn(async move { - let _d = w; - - log::debug!("Waiting for OnDataChannel"); - let dc = dc_chan_rx.recv().await.unwrap(); - log::debug!("data channel opened"); - - log::debug!("Waiting for ping..."); - let mut msg = vec![0u8; 256]; - let n = dc.read(&mut msg).await?; - log::debug!("Received ping! 
{:?}\n", &msg[..n]); - - assert_eq!(test_data.as_bytes(), &msg[..n]); - log::debug!("Received ping successfully!"); - - dc.close().await?; - - Result::<()>::Ok(()) - }); - - signal_pair(&mut pca, &mut pcb).await?; - - let attached = pca.create_data_channel(label, None).await?; - - log::debug!("Waiting for data channel to open"); - let (open_tx, mut open_rx) = mpsc::channel::<()>(1); - let open_tx = Arc::new(open_tx); - attached.on_open(Box::new(move || { - let open_tx2 = Arc::clone(&open_tx); - Box::pin(async move { - let _ = open_tx2.send(()).await; - }) - })); - - let _ = open_rx.recv().await; - log::debug!("data channel opened"); - - let dc = attached.detach().await?; - - let w = wg.worker(); - tokio::spawn(async move { - let _d = w; - log::debug!("Sending ping..."); - dc.write(&Bytes::from_static(test_data.as_bytes())).await?; - log::debug!("Sent ping"); - - dc.close().await?; - - log::debug!("Waiting for EOF"); - let mut buf = vec![0u8; 256]; - let n = dc.read(&mut buf).await?; - assert_eq!(0, n, "should be empty"); - - Result::<()>::Ok(()) - }); - - wg.wait().await; - - close_pair_now(&pca, &pcb).await; - - Ok(()) -} - -#[tokio::test] -async fn test_eof_no_detach() -> Result<()> { - let label: &str = "test-channel"; - let test_data: &'static [u8] = b"this is some test data"; - - let api = APIBuilder::new().build(); - - // Set up two peer connections. - let mut pca = api.new_peer_connection(RTCConfiguration::default()).await?; - let mut pcb = api.new_peer_connection(RTCConfiguration::default()).await?; - - let (dca_closed_ch_tx, mut dca_closed_ch_rx) = mpsc::channel::<()>(1); - let (dcb_closed_ch_tx, mut dcb_closed_ch_rx) = mpsc::channel::<()>(1); - - let dcb_closed_ch_tx = Arc::new(dcb_closed_ch_tx); - pcb.on_data_channel(Box::new(move |dc: Arc| { - if dc.label() != label { - return Box::pin(async {}); - } - - log::debug!("pcb: new datachannel: {}", dc.label()); - - let dcb_closed_ch_tx2 = Arc::clone(&dcb_closed_ch_tx); - Box::pin(async move { - // Register channel opening handling - dc.on_open(Box::new(move || { - log::debug!("pcb: datachannel opened"); - Box::pin(async {}) - })); - - dc.on_close(Box::new(move || { - // (2) - log::debug!("pcb: data channel closed"); - let dcb_closed_ch_tx3 = Arc::clone(&dcb_closed_ch_tx2); - Box::pin(async move { - let _ = dcb_closed_ch_tx3.send(()).await; - }) - })); - - // Register the OnMessage to handle incoming messages - log::debug!("pcb: registering onMessage callback"); - dc.on_message(Box::new(|dc_msg: DataChannelMessage| { - let test_data: &'static [u8] = b"this is some test data"; - log::debug!("pcb: received ping: {:?}", dc_msg.data); - assert_eq!(&dc_msg.data[..], test_data, "data mismatch"); - Box::pin(async {}) - })); - }) - })); - - let dca = pca.create_data_channel(label, None).await?; - let dca2 = Arc::clone(&dca); - dca.on_open(Box::new(move || { - log::debug!("pca: data channel opened"); - log::debug!("pca: sending {:?}", test_data); - let dca3 = Arc::clone(&dca2); - Box::pin(async move { - let _ = dca3.send(&Bytes::from_static(test_data)).await; - log::debug!("pca: sent ping"); - assert!(dca3.close().await.is_ok(), "should succeed"); // <-- dca closes - }) - })); - - let dca_closed_ch_tx = Arc::new(dca_closed_ch_tx); - dca.on_close(Box::new(move || { - // (1) - log::debug!("pca: data channel closed"); - let dca_closed_ch_tx2 = Arc::clone(&dca_closed_ch_tx); - Box::pin(async move { - let _ = dca_closed_ch_tx2.send(()).await; - }) - })); - - // Register the OnMessage to handle incoming messages - log::debug!("pca: registering 
onMessage callback"); - dca.on_message(Box::new(move |dc_msg: DataChannelMessage| { - log::debug!("pca: received pong: {:?}", &dc_msg.data[..]); - assert_eq!(&dc_msg.data[..], test_data, "data mismatch"); - Box::pin(async {}) - })); - - signal_pair(&mut pca, &mut pcb).await?; - - // When dca closes the channel, - // (1) dca.Onclose() will fire immediately, then - // (2) dcb.OnClose will also fire - let _ = dca_closed_ch_rx.recv().await; // (1) - let _ = dcb_closed_ch_rx.recv().await; // (2) - - close_pair_now(&pca, &pcb).await; - - Ok(()) -} - -// Assert that a Session Description that doesn't follow -// draft-ietf-mmusic-sctp-sdp is still accepted -#[tokio::test] -async fn test_data_channel_non_standard_session_description() -> Result<()> { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - let (offer_pc, answer_pc) = new_pair(&api).await?; - - let _ = offer_pc.create_data_channel("foo", None).await?; - - let (on_data_channel_called_tx, mut on_data_channel_called_rx) = mpsc::channel::<()>(1); - let on_data_channel_called_tx = Arc::new(on_data_channel_called_tx); - answer_pc.on_data_channel(Box::new(move |_: Arc| { - let on_data_channel_called_tx2 = Arc::clone(&on_data_channel_called_tx); - Box::pin(async move { - let _ = on_data_channel_called_tx2.send(()).await; - }) - })); - - let offer = offer_pc.create_offer(None).await?; - - let mut offer_gathering_complete = offer_pc.gathering_complete_promise().await; - offer_pc.set_local_description(offer).await?; - let _ = offer_gathering_complete.recv().await; - - let mut offer = offer_pc.local_description().await.unwrap(); - - // Replace with old values - const OLD_APPLICATION: &str = "m=application 63743 DTLS/SCTP 5000\r"; - const OLD_ATTRIBUTE: &str = "a=sctpmap:5000 webrtc-datachannel 256\r"; - - let re = Regex::new(r"m=application (.*?)\r").unwrap(); - offer.sdp = re - .replace_all(offer.sdp.as_str(), OLD_APPLICATION) - .to_string(); - let re = Regex::new(r"a=sctp-port(.*?)\r").unwrap(); - offer.sdp = re - .replace_all(offer.sdp.as_str(), OLD_ATTRIBUTE) - .to_string(); - - // Assert that replace worked - assert!(offer.sdp.contains(OLD_APPLICATION)); - assert!(offer.sdp.contains(OLD_ATTRIBUTE)); - - answer_pc.set_remote_description(offer).await?; - - let answer = answer_pc.create_answer(None).await?; - - let mut answer_gathering_complete = answer_pc.gathering_complete_promise().await; - answer_pc.set_local_description(answer).await?; - let _ = answer_gathering_complete.recv().await; - - let answer = answer_pc.local_description().await.unwrap(); - offer_pc.set_remote_description(answer).await?; - - let _ = on_data_channel_called_rx.recv().await; - - close_pair_now(&offer_pc, &answer_pc).await; - - Ok(()) -} - -struct TestOrtcStack { - //api *API - gatherer: Arc, - ice: Arc, - dtls: Arc, - sctp: Arc, -} - -struct TestOrtcSignal { - ice_candidates: Vec, //`json:"iceCandidates"` - ice_parameters: RTCIceParameters, //`json:"iceParameters"` - dtls_parameters: DTLSParameters, //`json:"dtlsParameters"` - sctp_capabilities: SCTPTransportCapabilities, //`json:"sctpCapabilities"` -} - -impl TestOrtcStack { - async fn new(api: &API) -> Result { - // Create the ICE gatherer - let gatherer = Arc::new(api.new_ice_gatherer(RTCIceGatherOptions::default())?); - - // Construct the ICE transport - let ice = Arc::new(api.new_ice_transport(Arc::clone(&gatherer))); - - // Construct the DTLS transport - let dtls = Arc::new(api.new_dtls_transport(Arc::clone(&ice), vec![])?); - - // 
Construct the SCTP transport - let sctp = Arc::new(api.new_sctp_transport(Arc::clone(&dtls))?); - - Ok(TestOrtcStack { - gatherer, - ice, - dtls, - sctp, - }) - } - - async fn set_signal(&self, sig: &TestOrtcSignal, is_offer: bool) -> Result<()> { - let ice_role = if is_offer { - RTCIceRole::Controlling - } else { - RTCIceRole::Controlled - }; - - self.ice.set_remote_candidates(&sig.ice_candidates).await?; - - // Start the ICE transport - self.ice.start(&sig.ice_parameters, Some(ice_role)).await?; - - // Start the DTLS transport - self.dtls.start(sig.dtls_parameters.clone()).await?; - - // Start the SCTP transport - self.sctp.start(sig.sctp_capabilities).await?; - - Ok(()) - } - - async fn get_signal(&self) -> Result { - let (gather_finished_tx, mut gather_finished_rx) = mpsc::channel::<()>(1); - let gather_finished_tx = Arc::new(gather_finished_tx); - self.gatherer - .on_local_candidate(Box::new(move |i: Option| { - let gather_finished_tx2 = Arc::clone(&gather_finished_tx); - Box::pin(async move { - if i.is_none() { - let _ = gather_finished_tx2.send(()).await; - } - }) - })); - - self.gatherer.gather().await?; - - let _ = gather_finished_rx.recv().await; - - let ice_candidates = self.gatherer.get_local_candidates().await?; - - let ice_parameters = self.gatherer.get_local_parameters().await?; - - let dtls_parameters = self.dtls.get_local_parameters()?; - - let sctp_capabilities = self.sctp.get_capabilities(); - - Ok(TestOrtcSignal { - ice_candidates, - ice_parameters, - dtls_parameters, - sctp_capabilities, - }) - } - - async fn close(&self) -> Result<()> { - let mut close_errs = vec![]; - - if let Err(err) = self.sctp.stop().await { - close_errs.push(err); - } - - if let Err(err) = self.ice.stop().await { - close_errs.push(err); - } - - flatten_errs(close_errs) - } -} - -async fn new_ortc_pair(api: &API) -> Result<(Arc, Arc)> { - let sa = Arc::new(TestOrtcStack::new(api).await?); - let sb = Arc::new(TestOrtcStack::new(api).await?); - Ok((sa, sb)) -} - -async fn signal_ortc_pair(stack_a: Arc, stack_b: Arc) -> Result<()> { - let sig_a = stack_a.get_signal().await?; - let sig_b = stack_b.get_signal().await?; - - let (a_tx, mut a_rx) = mpsc::channel(1); - let (b_tx, mut b_rx) = mpsc::channel(1); - - tokio::spawn(async move { - let _ = a_tx.send(stack_b.set_signal(&sig_a, false).await).await; - }); - - tokio::spawn(async move { - let _ = b_tx.send(stack_a.set_signal(&sig_b, true).await).await; - }); - - let err_a = a_rx.recv().await.unwrap(); - let err_b = b_rx.recv().await.unwrap(); - - let mut close_errs = vec![]; - if let Err(err) = err_a { - close_errs.push(err); - } - if let Err(err) = err_b { - close_errs.push(err); - } - - flatten_errs(close_errs) -} - -#[tokio::test] -async fn test_data_channel_ortc_e2e() -> Result<()> { - let api = APIBuilder::new().build(); - - let (stack_a, stack_b) = new_ortc_pair(&api).await?; - - let (await_setup_tx, mut await_setup_rx) = mpsc::channel::<()>(1); - let (await_string_tx, mut await_string_rx) = mpsc::channel::<()>(1); - let (await_binary_tx, mut await_binary_rx) = mpsc::channel::<()>(1); - - let await_setup_tx = Arc::new(await_setup_tx); - let await_string_tx = Arc::new(await_string_tx); - let await_binary_tx = Arc::new(await_binary_tx); - stack_b - .sctp - .on_data_channel(Box::new(move |d: Arc| { - let await_setup_tx2 = Arc::clone(&await_setup_tx); - let await_string_tx2 = Arc::clone(&await_string_tx); - let await_binary_tx2 = Arc::clone(&await_binary_tx); - Box::pin(async move { - let _ = await_setup_tx2.send(()).await; - - 
d.on_message(Box::new(move |msg: DataChannelMessage| { - let await_string_tx3 = Arc::clone(&await_string_tx2); - let await_binary_tx3 = Arc::clone(&await_binary_tx2); - Box::pin(async move { - if msg.is_string { - let _ = await_string_tx3.send(()).await; - } else { - let _ = await_binary_tx3.send(()).await; - } - }) - })); - }) - })); - - signal_ortc_pair(Arc::clone(&stack_a), Arc::clone(&stack_b)).await?; - - let dc_params = DataChannelParameters { - label: "Foo".to_owned(), - negotiated: None, - ..Default::default() - }; - - let channel_a = api - .new_data_channel(Arc::clone(&stack_a.sctp), dc_params) - .await?; - - let _ = await_setup_rx.recv().await; - - channel_a.send_text("ABC".to_owned()).await?; - channel_a.send(&Bytes::from_static(b"ABC")).await?; - - let _ = await_string_rx.recv().await; - let _ = await_binary_rx.recv().await; - - stack_a.close().await?; - stack_b.close().await?; - - // attempt to send when channel is closed - let result = channel_a.send(&Bytes::from_static(b"ABC")).await; - if let Err(err) = result { - assert_eq!( - Error::ErrClosedPipe, - err, - "expected ErrClosedPipe, but got {err}" - ); - } else { - panic!(); - } - - let result = channel_a.send_text("test".to_owned()).await; - if let Err(err) = result { - assert_eq!( - Error::ErrClosedPipe, - err, - "expected ErrClosedPipe, but got {err}" - ); - } else { - panic!(); - } - - let result = channel_a.ensure_open(); - if let Err(err) = result { - assert_eq!( - Error::ErrClosedPipe, - err, - "expected ErrClosedPipe, but got {err}" - ); - } else { - panic!(); - } - - Ok(()) -} diff --git a/webrtc/src/data_channel/mod.rs b/webrtc/src/data_channel/mod.rs deleted file mode 100644 index bdf05af2f..000000000 --- a/webrtc/src/data_channel/mod.rs +++ /dev/null @@ -1,556 +0,0 @@ -#[cfg(test)] -mod data_channel_test; - -pub mod data_channel_init; -pub mod data_channel_message; -pub mod data_channel_parameters; -pub mod data_channel_state; - -use std::future::Future; -use std::pin::Pin; -use std::sync::atomic::Ordering; -use std::sync::{Arc, Weak}; -use std::time::SystemTime; - -use arc_swap::ArcSwapOption; -use bytes::Bytes; -use data::message::message_channel_open::ChannelType; -use data_channel_message::*; -use data_channel_parameters::*; -use data_channel_state::RTCDataChannelState; -use portable_atomic::{AtomicBool, AtomicU16, AtomicU8, AtomicUsize}; -use sctp::stream::OnBufferedAmountLowFn; -use tokio::sync::{Mutex, Notify}; -use util::sync::Mutex as SyncMutex; - -use crate::api::setting_engine::SettingEngine; -use crate::error::{Error, OnErrorHdlrFn, Result}; -use crate::sctp_transport::RTCSctpTransport; -use crate::stats::stats_collector::StatsCollector; -use crate::stats::{DataChannelStats, StatsReportType}; - -/// message size limit for Chromium -const DATA_CHANNEL_BUFFER_SIZE: u16 = u16::MAX; - -pub type OnMessageHdlrFn = Box< - dyn (FnMut(DataChannelMessage) -> Pin + Send + 'static>>) - + Send - + Sync, ->; - -pub type OnOpenHdlrFn = - Box Pin + Send + 'static>>) + Send + Sync>; - -pub type OnCloseHdlrFn = - Box Pin + Send + 'static>>) + Send + Sync>; - -/// DataChannel represents a WebRTC DataChannel -/// The DataChannel interface represents a network channel -/// which can be used for bidirectional peer-to-peer transfers of arbitrary data -#[derive(Default)] -pub struct RTCDataChannel { - pub(crate) stats_id: String, - pub(crate) label: String, - pub(crate) ordered: bool, - pub(crate) max_packet_lifetime: u16, - pub(crate) max_retransmits: u16, - pub(crate) protocol: String, - pub(crate) negotiated: bool, - 
pub(crate) id: AtomicU16, - pub(crate) ready_state: Arc, // DataChannelState - pub(crate) buffered_amount_low_threshold: AtomicUsize, - pub(crate) detach_called: Arc, - - // The binaryType represents attribute MUST, on getting, return the value to - // which it was last set. On setting, if the new value is either the string - // "blob" or the string "arraybuffer", then set the IDL attribute to this - // new value. Otherwise, throw a SyntaxError. When an DataChannel object - // is created, the binaryType attribute MUST be initialized to the string - // "blob". This attribute controls how binary data is exposed to scripts. - // binaryType string - pub(crate) on_message_handler: Arc>>, - pub(crate) on_open_handler: SyncMutex>, - pub(crate) on_close_handler: Arc>>, - pub(crate) on_error_handler: Arc>>, - - pub(crate) on_buffered_amount_low: Mutex>, - - pub(crate) sctp_transport: Mutex>>, - pub(crate) data_channel: Mutex>>, - - pub(crate) notify_tx: Arc, - - // A reference to the associated api object used by this datachannel - pub(crate) setting_engine: Arc, -} - -impl RTCDataChannel { - // create the DataChannel object before the networking is set up. - pub(crate) fn new(params: DataChannelParameters, setting_engine: Arc) -> Self { - // the id value if non-negotiated doesn't matter, since it will be overwritten - // on opening - let id = params.negotiated.unwrap_or(0); - RTCDataChannel { - stats_id: format!( - "DataChannel-{}", - SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .map_or(0, |d| d.as_nanos()) - ), - label: params.label, - protocol: params.protocol, - negotiated: params.negotiated.is_some(), - id: AtomicU16::new(id), - ordered: params.ordered, - max_packet_lifetime: params.max_packet_life_time, - max_retransmits: params.max_retransmits, - ready_state: Arc::new(AtomicU8::new(RTCDataChannelState::Connecting as u8)), - detach_called: Arc::new(AtomicBool::new(false)), - - notify_tx: Arc::new(Notify::new()), - - setting_engine, - ..Default::default() - } - } - - /// open opens the datachannel over the sctp transport - pub(crate) async fn open(&self, sctp_transport: Arc) -> Result<()> { - if let Some(association) = sctp_transport.association().await { - { - let mut st = self.sctp_transport.lock().await; - if st.is_none() { - *st = Some(Arc::downgrade(&sctp_transport)); - } else { - return Ok(()); - } - } - - let channel_type; - let reliability_parameter; - - if self.max_packet_lifetime == 0 && self.max_retransmits == 0 { - reliability_parameter = 0u32; - if self.ordered { - channel_type = ChannelType::Reliable; - } else { - channel_type = ChannelType::ReliableUnordered; - } - } else if self.max_retransmits != 0 { - reliability_parameter = self.max_retransmits as u32; - if self.ordered { - channel_type = ChannelType::PartialReliableRexmit; - } else { - channel_type = ChannelType::PartialReliableRexmitUnordered; - } - } else { - reliability_parameter = self.max_packet_lifetime as u32; - if self.ordered { - channel_type = ChannelType::PartialReliableTimed; - } else { - channel_type = ChannelType::PartialReliableTimedUnordered; - } - } - - let cfg = data::data_channel::Config { - channel_type, - priority: data::message::message_channel_open::CHANNEL_PRIORITY_NORMAL, - reliability_parameter, - label: self.label.clone(), - protocol: self.protocol.clone(), - negotiated: self.negotiated, - }; - - if !self.negotiated { - self.id.store( - sctp_transport - .generate_and_set_data_channel_id( - sctp_transport.dtls_transport.role().await, - ) - .await?, - Ordering::SeqCst, - ); - } - - 
let dc = data::data_channel::DataChannel::dial(&association, self.id(), cfg).await?; - - // buffered_amount_low_threshold and on_buffered_amount_low might be set earlier - dc.set_buffered_amount_low_threshold( - self.buffered_amount_low_threshold.load(Ordering::SeqCst), - ); - { - let mut on_buffered_amount_low = self.on_buffered_amount_low.lock().await; - if let Some(f) = on_buffered_amount_low.take() { - dc.on_buffered_amount_low(f); - } - } - - self.handle_open(Arc::new(dc)).await; - - Ok(()) - } else { - Err(Error::ErrSCTPNotEstablished) - } - } - - /// transport returns the SCTPTransport instance the DataChannel is sending over. - pub async fn transport(&self) -> Option> { - let sctp_transport = self.sctp_transport.lock().await; - sctp_transport.clone() - } - - /// on_open sets an event handler which is invoked when - /// the underlying data transport has been established (or re-established). - pub fn on_open(&self, f: OnOpenHdlrFn) { - let _ = self.on_open_handler.lock().replace(f); - - if self.ready_state() == RTCDataChannelState::Open { - self.do_open(); - } - } - - fn do_open(&self) { - let on_open_handler = self.on_open_handler.lock().take(); - if on_open_handler.is_none() { - return; - } - - let detach_data_channels = self.setting_engine.detach.data_channels; - let detach_called = Arc::clone(&self.detach_called); - tokio::spawn(async move { - if let Some(f) = on_open_handler { - f().await; - - // self.check_detach_after_open(); - // After onOpen is complete check that the user called detach - // and provide an error message if the call was missed - if detach_data_channels && !detach_called.load(Ordering::SeqCst) { - log::warn!( - "webrtc.DetachDataChannels() enabled but didn't Detach, call Detach from OnOpen" - ); - } - } - }); - } - - /// on_close sets an event handler which is invoked when - /// the underlying data transport has been closed. - pub fn on_close(&self, f: OnCloseHdlrFn) { - self.on_close_handler.store(Some(Arc::new(Mutex::new(f)))); - } - - /// on_message sets an event handler which is invoked on a binary - /// message arrival over the sctp transport from a remote peer. - /// OnMessage can currently receive messages up to 16384 bytes - /// in size. Check out the detach API if you want to use larger - /// message sizes. Note that browser support for larger messages - /// is also limited. - pub fn on_message(&self, f: OnMessageHdlrFn) { - self.on_message_handler.store(Some(Arc::new(Mutex::new(f)))); - } - - async fn do_message(&self, msg: DataChannelMessage) { - if let Some(handler) = &*self.on_message_handler.load() { - let mut f = handler.lock().await; - f(msg).await; - } - } - - pub(crate) async fn handle_open(&self, dc: Arc) { - { - let mut data_channel = self.data_channel.lock().await; - *data_channel = Some(Arc::clone(&dc)); - } - self.set_ready_state(RTCDataChannelState::Open); - - self.do_open(); - - if !self.setting_engine.detach.data_channels { - let ready_state = Arc::clone(&self.ready_state); - let on_message_handler = Arc::clone(&self.on_message_handler); - let on_close_handler = Arc::clone(&self.on_close_handler); - let on_error_handler = Arc::clone(&self.on_error_handler); - let notify_rx = self.notify_tx.clone(); - tokio::spawn(async move { - RTCDataChannel::read_loop( - notify_rx, - dc, - ready_state, - on_message_handler, - on_close_handler, - on_error_handler, - ) - .await; - }); - } - } - - /// on_error sets an event handler which is invoked when - /// the underlying data transport cannot be read. 
- pub fn on_error(&self, f: OnErrorHdlrFn) { - self.on_error_handler.store(Some(Arc::new(Mutex::new(f)))); - } - - async fn read_loop( - notify_rx: Arc, - data_channel: Arc, - ready_state: Arc, - on_message_handler: Arc>>, - on_close_handler: Arc>>, - on_error_handler: Arc>>, - ) { - let mut buffer = vec![0u8; DATA_CHANNEL_BUFFER_SIZE as usize]; - loop { - let (n, is_string) = tokio::select! { - _ = notify_rx.notified() => break, - result = data_channel.read_data_channel(&mut buffer) => { - match result{ - // EOF (`data_channel` was either closed or the underlying stream got - // reset by the remote) => close and run `on_close` handler. - Ok((0, _)) => - { - ready_state.store(RTCDataChannelState::Closed as u8, Ordering::SeqCst); - - let on_close_handler2 = Arc::clone(&on_close_handler); - tokio::spawn(async move { - if let Some(handler) = &*on_close_handler2.load() { - let mut f = handler.lock().await; - f().await; - } - }); - - break; - } - Ok((n, is_string)) => (n, is_string), - Err(err) => { - ready_state.store(RTCDataChannelState::Closed as u8, Ordering::SeqCst); - - let on_error_handler2 = Arc::clone(&on_error_handler); - tokio::spawn(async move { - if let Some(handler) = &*on_error_handler2.load() { - let mut f = handler.lock().await; - f(err.into()).await; - } - }); - - let on_close_handler2 = Arc::clone(&on_close_handler); - tokio::spawn(async move { - if let Some(handler) = &*on_close_handler2.load() { - let mut f = handler.lock().await; - f().await; - } - }); - - break; - } - } - } - }; - - if let Some(handler) = &*on_message_handler.load() { - let mut f = handler.lock().await; - f(DataChannelMessage { - is_string, - data: Bytes::from(buffer[..n].to_vec()), - }) - .await; - } - } - } - - /// send sends the binary message to the DataChannel peer - pub async fn send(&self, data: &Bytes) -> Result { - self.ensure_open()?; - - let data_channel = self.data_channel.lock().await; - if let Some(dc) = &*data_channel { - Ok(dc.write_data_channel(data, false).await?) - } else { - Err(Error::ErrClosedPipe) - } - } - - /// send_text sends the text message to the DataChannel peer - pub async fn send_text(&self, s: impl Into) -> Result { - self.ensure_open()?; - - let data_channel = self.data_channel.lock().await; - if let Some(dc) = &*data_channel { - Ok(dc.write_data_channel(&Bytes::from(s.into()), true).await?) - } else { - Err(Error::ErrClosedPipe) - } - } - - fn ensure_open(&self) -> Result<()> { - if self.ready_state() != RTCDataChannelState::Open { - Err(Error::ErrClosedPipe) - } else { - Ok(()) - } - } - - /// detach allows you to detach the underlying datachannel. This provides - /// an idiomatic API to work with, however it disables the OnMessage callback. - /// Before calling Detach you have to enable this behavior by calling - /// webrtc.DetachDataChannels(). Combining detached and normal data channels - /// is not supported. - /// Please refer to the data-channels-detach example and the - /// pion/datachannel documentation for the correct way to handle the - /// resulting DataChannel object. - pub async fn detach(&self) -> Result> { - if !self.setting_engine.detach.data_channels { - return Err(Error::ErrDetachNotEnabled); - } - - let data_channel = self.data_channel.lock().await; - if let Some(dc) = &*data_channel { - self.detach_called.store(true, Ordering::SeqCst); - - Ok(Arc::clone(dc)) - } else { - Err(Error::ErrDetachBeforeOpened) - } - } - - /// Close Closes the DataChannel. 
It may be called regardless of whether - /// the DataChannel object was created by this peer or the remote peer. - pub async fn close(&self) -> Result<()> { - if self.ready_state() == RTCDataChannelState::Closed { - return Ok(()); - } - - self.set_ready_state(RTCDataChannelState::Closing); - self.notify_tx.notify_waiters(); - - let data_channel = self.data_channel.lock().await; - if let Some(dc) = &*data_channel { - Ok(dc.close().await?) - } else { - Ok(()) - } - } - - /// label represents a label that can be used to distinguish this - /// DataChannel object from other DataChannel objects. Scripts are - /// allowed to create multiple DataChannel objects with the same label. - pub fn label(&self) -> &str { - self.label.as_str() - } - - /// Ordered returns true if the DataChannel is ordered, and false if - /// out-of-order delivery is allowed. - pub fn ordered(&self) -> bool { - self.ordered - } - - /// max_packet_lifetime represents the length of the time window (msec) during - /// which transmissions and retransmissions may occur in unreliable mode. - pub fn max_packet_lifetime(&self) -> u16 { - self.max_packet_lifetime - } - - /// max_retransmits represents the maximum number of retransmissions that are - /// attempted in unreliable mode. - pub fn max_retransmits(&self) -> u16 { - self.max_retransmits - } - - /// protocol represents the name of the sub-protocol used with this - /// DataChannel. - pub fn protocol(&self) -> &str { - self.protocol.as_str() - } - - /// negotiated represents whether this DataChannel was negotiated by the - /// application (true), or not (false). - pub fn negotiated(&self) -> bool { - self.negotiated - } - - /// ID represents the ID for this DataChannel. The value is initially - /// null, which is what will be returned if the ID was not provided at - /// channel creation time, and the DTLS role of the SCTP transport has not - /// yet been negotiated. Otherwise, it will return the ID that was either - /// selected by the script or generated. After the ID is set to a non-null - /// value, it will not change. - pub fn id(&self) -> u16 { - self.id.load(Ordering::SeqCst) - } - - /// ready_state represents the state of the DataChannel object. - pub fn ready_state(&self) -> RTCDataChannelState { - self.ready_state.load(Ordering::SeqCst).into() - } - - /// buffered_amount represents the number of bytes of application data - /// (UTF-8 text and binary data) that have been queued using send(). Even - /// though the data transmission can occur in parallel, the returned value - /// MUST NOT be decreased before the current task yielded back to the event - /// loop to prevent race conditions. The value does not include framing - /// overhead incurred by the protocol, or buffering done by the operating - /// system or network hardware. The value of buffered_amount slot will only - /// increase with each call to the send() method as long as the ready_state is - /// open; however, buffered_amount does not reset to zero once the channel - /// closes. - pub async fn buffered_amount(&self) -> usize { - let data_channel = self.data_channel.lock().await; - if let Some(dc) = &*data_channel { - dc.buffered_amount() - } else { - 0 - } - } - - /// buffered_amount_low_threshold represents the threshold at which the - /// bufferedAmount is considered to be low. When the bufferedAmount decreases - /// from above this threshold to equal or below it, the bufferedamountlow - /// event fires. 
buffered_amount_low_threshold is initially zero on each new - /// DataChannel, but the application may change its value at any time. - /// The threshold is set to 0 by default. - pub async fn buffered_amount_low_threshold(&self) -> usize { - let data_channel = self.data_channel.lock().await; - if let Some(dc) = &*data_channel { - dc.buffered_amount_low_threshold() - } else { - self.buffered_amount_low_threshold.load(Ordering::SeqCst) - } - } - - /// set_buffered_amount_low_threshold is used to update the threshold. - /// See buffered_amount_low_threshold(). - pub async fn set_buffered_amount_low_threshold(&self, th: usize) { - self.buffered_amount_low_threshold - .store(th, Ordering::SeqCst); - let data_channel = self.data_channel.lock().await; - if let Some(dc) = &*data_channel { - dc.set_buffered_amount_low_threshold(th); - } - } - - /// on_buffered_amount_low sets an event handler which is invoked when - /// the number of bytes of outgoing data becomes lower than the - /// buffered_amount_low_threshold. - pub async fn on_buffered_amount_low(&self, f: OnBufferedAmountLowFn) { - let data_channel = self.data_channel.lock().await; - if let Some(dc) = &*data_channel { - dc.on_buffered_amount_low(f); - } else { - let mut on_buffered_amount_low = self.on_buffered_amount_low.lock().await; - *on_buffered_amount_low = Some(f); - } - } - - pub(crate) fn get_stats_id(&self) -> &str { - self.stats_id.as_str() - } - - pub(crate) async fn collect_stats(&self, collector: &StatsCollector) { - let stats = DataChannelStats::from(self).await; - collector.insert(self.stats_id.clone(), StatsReportType::DataChannel(stats)); - } - - pub(crate) fn set_ready_state(&self, r: RTCDataChannelState) { - self.ready_state.store(r as u8, Ordering::SeqCst); - } -} diff --git a/webrtc/src/dtls_transport/dtls_fingerprint.rs b/webrtc/src/dtls_transport/dtls_fingerprint.rs deleted file mode 100644 index eaea36763..000000000 --- a/webrtc/src/dtls_transport/dtls_fingerprint.rs +++ /dev/null @@ -1,15 +0,0 @@ -use serde::{Deserialize, Serialize}; - -/// DTLSFingerprint specifies the hash function algorithm and certificate -/// fingerprint as described in . -#[derive(Default, Debug, Clone, Serialize, Deserialize)] -pub struct RTCDtlsFingerprint { - /// Algorithm specifies one of the the hash function algorithms defined in - /// the 'Hash function Textual Names' registry. - pub algorithm: String, - - /// Value specifies the value of the certificate fingerprint in lowercase - /// hex string as expressed utilizing the syntax of 'fingerprint' in - /// . - pub value: String, -} diff --git a/webrtc/src/dtls_transport/dtls_parameters.rs b/webrtc/src/dtls_transport/dtls_parameters.rs deleted file mode 100644 index 6aa14477f..000000000 --- a/webrtc/src/dtls_transport/dtls_parameters.rs +++ /dev/null @@ -1,11 +0,0 @@ -use serde::{Deserialize, Serialize}; - -use super::dtls_fingerprint::*; -use super::dtls_role::*; - -/// DTLSParameters holds information relating to DTLS configuration. -#[derive(Default, Debug, Clone, Serialize, Deserialize)] -pub struct DTLSParameters { - pub role: DTLSRole, - pub fingerprints: Vec, -} diff --git a/webrtc/src/dtls_transport/dtls_role.rs b/webrtc/src/dtls_transport/dtls_role.rs deleted file mode 100644 index 88730c631..000000000 --- a/webrtc/src/dtls_transport/dtls_role.rs +++ /dev/null @@ -1,170 +0,0 @@ -use std::fmt; - -use sdp::description::session::SessionDescription; -use sdp::util::ConnectionRole; -use serde::{Deserialize, Serialize}; - -/// DtlsRole indicates the role of the DTLS transport. 
-#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub enum DTLSRole { - #[default] - Unspecified = 0, - - /// DTLSRoleAuto defines the DTLS role is determined based on - /// the resolved ICE role: the ICE controlled role acts as the DTLS - /// client and the ICE controlling role acts as the DTLS server. - #[serde(rename = "auto")] - Auto = 1, - - /// DTLSRoleClient defines the DTLS client role. - #[serde(rename = "client")] - Client = 2, - - /// DTLSRoleServer defines the DTLS server role. - #[serde(rename = "server")] - Server = 3, -} - -/// -/// The answerer MUST use either a -/// setup attribute value of setup:active or setup:passive. Note that -/// if the answerer uses setup:passive, then the DTLS handshake will -/// not begin until the answerer is received, which adds additional -/// latency. setup:active allows the answer and the DTLS handshake to -/// occur in parallel. Thus, setup:active is RECOMMENDED. -pub(crate) const DEFAULT_DTLS_ROLE_ANSWER: DTLSRole = DTLSRole::Client; - -/// The endpoint that is the offerer MUST use the setup attribute -/// value of setup:actpass and be prepared to receive a client_hello -/// before it receives the answer. -pub(crate) const DEFAULT_DTLS_ROLE_OFFER: DTLSRole = DTLSRole::Auto; - -impl fmt::Display for DTLSRole { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self { - DTLSRole::Auto => write!(f, "auto"), - DTLSRole::Client => write!(f, "client"), - DTLSRole::Server => write!(f, "server"), - _ => write!(f, "{}", crate::UNSPECIFIED_STR), - } - } -} - -/// Iterate a SessionDescription from a remote to determine if an explicit -/// role can been determined from it. The decision is made from the first role we we parse. -/// If no role can be found we return DTLSRoleAuto -impl From<&SessionDescription> for DTLSRole { - fn from(session_description: &SessionDescription) -> Self { - for media_section in &session_description.media_descriptions { - for attribute in &media_section.attributes { - if attribute.key == "setup" { - if let Some(value) = &attribute.value { - match value.as_str() { - "active" => return DTLSRole::Client, - "passive" => return DTLSRole::Server, - _ => return DTLSRole::Auto, - }; - } else { - return DTLSRole::Auto; - } - } - } - } - - DTLSRole::Auto - } -} - -impl DTLSRole { - pub(crate) fn to_connection_role(self) -> ConnectionRole { - match self { - DTLSRole::Client => ConnectionRole::Active, - DTLSRole::Server => ConnectionRole::Passive, - DTLSRole::Auto => ConnectionRole::Actpass, - _ => ConnectionRole::Unspecified, - } - } -} - -#[cfg(test)] -mod test { - use std::io::Cursor; - - use super::*; - use crate::error::Result; - - #[test] - fn test_dtls_role_string() { - let tests = vec![ - (DTLSRole::Unspecified, "Unspecified"), - (DTLSRole::Auto, "auto"), - (DTLSRole::Client, "client"), - (DTLSRole::Server, "server"), - ]; - - for (role, expected_string) in tests { - assert_eq!(role.to_string(), expected_string) - } - } - - #[test] - fn test_dtls_role_from_remote_sdp() -> Result<()> { - const NO_MEDIA: &str = "v=0 -o=- 4596489990601351948 2 IN IP4 127.0.0.1 -s=- -t=0 0 -"; - - const MEDIA_NO_SETUP: &str = "v=0 -o=- 4596489990601351948 2 IN IP4 127.0.0.1 -s=- -t=0 0 -m=application 47299 DTLS/SCTP 5000 -c=IN IP4 192.168.20.129 -"; - - const MEDIA_SETUP_DECLARED: &str = "v=0 -o=- 4596489990601351948 2 IN IP4 127.0.0.1 -s=- -t=0 0 -m=application 47299 DTLS/SCTP 5000 -c=IN IP4 192.168.20.129 -a=setup:"; - - let tests = vec![ - ("No MediaDescriptions", NO_MEDIA.to_owned(), 
DTLSRole::Auto), - ( - "MediaDescription, no setup", - MEDIA_NO_SETUP.to_owned(), - DTLSRole::Auto, - ), - ( - "MediaDescription, setup:actpass", - format!("{}{}\n", MEDIA_SETUP_DECLARED, "actpass"), - DTLSRole::Auto, - ), - ( - "MediaDescription, setup:passive", - format!("{}{}\n", MEDIA_SETUP_DECLARED, "passive"), - DTLSRole::Server, - ), - ( - "MediaDescription, setup:active", - format!("{}{}\n", MEDIA_SETUP_DECLARED, "active"), - DTLSRole::Client, - ), - ]; - - for (name, session_description_str, expected_role) in tests { - let mut reader = Cursor::new(session_description_str.as_bytes()); - let session_description = SessionDescription::unmarshal(&mut reader)?; - assert_eq!( - DTLSRole::from(&session_description), - expected_role, - "{name} failed" - ); - } - - Ok(()) - } -} diff --git a/webrtc/src/dtls_transport/dtls_transport_state.rs b/webrtc/src/dtls_transport/dtls_transport_state.rs deleted file mode 100644 index 3700d8b81..000000000 --- a/webrtc/src/dtls_transport/dtls_transport_state.rs +++ /dev/null @@ -1,117 +0,0 @@ -use std::fmt; - -/// DTLSTransportState indicates the DTLS transport establishment state. -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] -pub enum RTCDtlsTransportState { - #[default] - Unspecified = 0, - - /// DTLSTransportStateNew indicates that DTLS has not started negotiating - /// yet. - New = 1, - - /// DTLSTransportStateConnecting indicates that DTLS is in the process of - /// negotiating a secure connection and verifying the remote fingerprint. - Connecting = 2, - - /// DTLSTransportStateConnected indicates that DTLS has completed - /// negotiation of a secure connection and verified the remote fingerprint. - Connected = 3, - - /// DTLSTransportStateClosed indicates that the transport has been closed - /// intentionally as the result of receipt of a close_notify alert, or - /// calling close(). - Closed = 4, - - /// DTLSTransportStateFailed indicates that the transport has failed as - /// the result of an error (such as receipt of an error alert or failure to - /// validate the remote fingerprint). 
- Failed = 5, -} - -const DTLS_TRANSPORT_STATE_NEW_STR: &str = "new"; -const DTLS_TRANSPORT_STATE_CONNECTING_STR: &str = "connecting"; -const DTLS_TRANSPORT_STATE_CONNECTED_STR: &str = "connected"; -const DTLS_TRANSPORT_STATE_CLOSED_STR: &str = "closed"; -const DTLS_TRANSPORT_STATE_FAILED_STR: &str = "failed"; - -impl From<&str> for RTCDtlsTransportState { - fn from(raw: &str) -> Self { - match raw { - DTLS_TRANSPORT_STATE_NEW_STR => RTCDtlsTransportState::New, - DTLS_TRANSPORT_STATE_CONNECTING_STR => RTCDtlsTransportState::Connecting, - DTLS_TRANSPORT_STATE_CONNECTED_STR => RTCDtlsTransportState::Connected, - DTLS_TRANSPORT_STATE_CLOSED_STR => RTCDtlsTransportState::Closed, - DTLS_TRANSPORT_STATE_FAILED_STR => RTCDtlsTransportState::Failed, - _ => RTCDtlsTransportState::Unspecified, - } - } -} - -impl From for RTCDtlsTransportState { - fn from(v: u8) -> Self { - match v { - 1 => RTCDtlsTransportState::New, - 2 => RTCDtlsTransportState::Connecting, - 3 => RTCDtlsTransportState::Connected, - 4 => RTCDtlsTransportState::Closed, - 5 => RTCDtlsTransportState::Failed, - _ => RTCDtlsTransportState::Unspecified, - } - } -} - -impl fmt::Display for RTCDtlsTransportState { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - RTCDtlsTransportState::New => DTLS_TRANSPORT_STATE_NEW_STR, - RTCDtlsTransportState::Connecting => DTLS_TRANSPORT_STATE_CONNECTING_STR, - RTCDtlsTransportState::Connected => DTLS_TRANSPORT_STATE_CONNECTED_STR, - RTCDtlsTransportState::Closed => DTLS_TRANSPORT_STATE_CLOSED_STR, - RTCDtlsTransportState::Failed => DTLS_TRANSPORT_STATE_FAILED_STR, - RTCDtlsTransportState::Unspecified => crate::UNSPECIFIED_STR, - }; - write!(f, "{s}") - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_new_dtls_transport_state() { - let tests = vec![ - (crate::UNSPECIFIED_STR, RTCDtlsTransportState::Unspecified), - ("new", RTCDtlsTransportState::New), - ("connecting", RTCDtlsTransportState::Connecting), - ("connected", RTCDtlsTransportState::Connected), - ("closed", RTCDtlsTransportState::Closed), - ("failed", RTCDtlsTransportState::Failed), - ]; - - for (state_string, expected_state) in tests { - assert_eq!( - RTCDtlsTransportState::from(state_string), - expected_state, - "testCase: {expected_state}", - ); - } - } - - #[test] - fn test_dtls_transport_state_string() { - let tests = vec![ - (RTCDtlsTransportState::Unspecified, crate::UNSPECIFIED_STR), - (RTCDtlsTransportState::New, "new"), - (RTCDtlsTransportState::Connecting, "connecting"), - (RTCDtlsTransportState::Connected, "connected"), - (RTCDtlsTransportState::Closed, "closed"), - (RTCDtlsTransportState::Failed, "failed"), - ]; - - for (state, expected_string) in tests { - assert_eq!(state.to_string(), expected_string) - } - } -} diff --git a/webrtc/src/dtls_transport/dtls_transport_test.rs b/webrtc/src/dtls_transport/dtls_transport_test.rs deleted file mode 100644 index a9e71aca0..000000000 --- a/webrtc/src/dtls_transport/dtls_transport_test.rs +++ /dev/null @@ -1,204 +0,0 @@ -use ice::mdns::MulticastDnsMode; -use ice::network_type::NetworkType; -use regex::Regex; -use tokio::time::Duration; -use waitgroup::WaitGroup; - -use super::*; -use crate::api::media_engine::MediaEngine; -use crate::api::APIBuilder; -use crate::data_channel::RTCDataChannel; -use crate::ice_transport::ice_candidate::RTCIceCandidate; -use crate::peer_connection::configuration::RTCConfiguration; -use crate::peer_connection::peer_connection_state::RTCPeerConnectionState; -use 
crate::peer_connection::peer_connection_test::{ - close_pair_now, new_pair, signal_pair, until_connection_state, -}; - -//use log::LevelFilter; -//use std::io::Write; - -// An invalid fingerprint MUST cause PeerConnectionState to go to PeerConnectionStateFailed -#[tokio::test] -async fn test_invalid_fingerprint_causes_failed() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, LevelFilter::Trace) - .init();*/ - - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - let (mut pc_offer, mut pc_answer) = new_pair(&api).await?; - - pc_answer.on_data_channel(Box::new(|_: Arc| { - panic!("A DataChannel must not be created when Fingerprint verification fails"); - })); - - let (offer_chan_tx, mut offer_chan_rx) = mpsc::channel::<()>(1); - - let offer_chan_tx = Arc::new(offer_chan_tx); - pc_offer.on_ice_candidate(Box::new(move |candidate: Option| { - let offer_chan_tx2 = Arc::clone(&offer_chan_tx); - Box::pin(async move { - if candidate.is_none() { - let _ = offer_chan_tx2.send(()).await; - } - }) - })); - - let offer_connection_has_failed = WaitGroup::new(); - until_connection_state( - &mut pc_offer, - &offer_connection_has_failed, - RTCPeerConnectionState::Failed, - ) - .await; - let answer_connection_has_failed = WaitGroup::new(); - until_connection_state( - &mut pc_answer, - &answer_connection_has_failed, - RTCPeerConnectionState::Failed, - ) - .await; - - let _ = pc_offer - .create_data_channel("unusedDataChannel", None) - .await?; - - let offer = pc_offer.create_offer(None).await?; - pc_offer.set_local_description(offer).await?; - - let timeout = tokio::time::sleep(Duration::from_secs(1)); - tokio::pin!(timeout); - - tokio::select! 
{ - _ = offer_chan_rx.recv() =>{ - let mut offer = pc_offer.pending_local_description().await.unwrap(); - - log::trace!("receiving pending local desc: {:?}", offer); - - // Replace with invalid fingerprint - let re = Regex::new(r"sha-256 (.*?)\r").unwrap(); - offer.sdp = re.replace_all(offer.sdp.as_str(), "sha-256 AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA\r").to_string(); - - pc_answer.set_remote_description(offer).await?; - - let mut answer = pc_answer.create_answer(None).await?; - - pc_answer.set_local_description(answer.clone()).await?; - - answer.sdp = re.replace_all(answer.sdp.as_str(), "sha-256 AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA\r").to_string(); - - pc_offer.set_remote_description(answer).await?; - } - _ = timeout.as_mut() =>{ - panic!("timed out waiting to receive offer"); - } - } - - log::trace!("offer_connection_has_failed wait begin"); - - offer_connection_has_failed.wait().await; - answer_connection_has_failed.wait().await; - - log::trace!("offer_connection_has_failed wait end"); - { - let transport = pc_offer.sctp().transport(); - assert_eq!(transport.state(), RTCDtlsTransportState::Failed); - assert!(transport.conn().await.is_none()); - } - - { - let transport = pc_answer.sctp().transport(); - assert_eq!(transport.state(), RTCDtlsTransportState::Failed); - assert!(transport.conn().await.is_none()); - } - - close_pair_now(&pc_offer, &pc_answer).await; - - Ok(()) -} - -async fn run_test(r: DTLSRole) -> Result<()> { - let mut offer_s = SettingEngine::default(); - offer_s.set_answering_dtls_role(r)?; - offer_s.set_ice_multicast_dns_mode(MulticastDnsMode::Disabled); - offer_s.set_network_types(vec![NetworkType::Udp4]); - let mut offer_pc = APIBuilder::new() - .with_setting_engine(offer_s) - .build() - .new_peer_connection(RTCConfiguration::default()) - .await?; - - let mut answer_s = SettingEngine::default(); - answer_s.set_answering_dtls_role(r)?; - answer_s.set_ice_multicast_dns_mode(MulticastDnsMode::Disabled); - answer_s.set_network_types(vec![NetworkType::Udp4]); - let mut answer_pc = APIBuilder::new() - .with_setting_engine(answer_s) - .build() - .new_peer_connection(RTCConfiguration::default()) - .await?; - - signal_pair(&mut offer_pc, &mut answer_pc).await?; - - let wg = WaitGroup::new(); - until_connection_state(&mut answer_pc, &wg, RTCPeerConnectionState::Connected).await; - wg.wait().await; - - close_pair_now(&offer_pc, &answer_pc).await; - - Ok(()) -} - -#[tokio::test] -async fn test_peer_connection_dtls_role_setting_engine_server() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, LevelFilter::Trace) - .init();*/ - - run_test(DTLSRole::Server).await -} - -#[tokio::test] -async fn test_peer_connection_dtls_role_setting_engine_client() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, LevelFilter::Trace) - .init();*/ - - run_test(DTLSRole::Client).await -} diff --git a/webrtc/src/dtls_transport/mod.rs b/webrtc/src/dtls_transport/mod.rs deleted file mode 100644 index 7e07f6316..000000000 
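
A minimal, self-contained sketch of the fingerprint substitution exercised by test_invalid_fingerprint_causes_failed above, assuming only the `regex` crate the test already imports; the SDP line and fingerprint value here are made up for illustration, not taken from the patch:

```rust
use regex::Regex;

fn main() {
    // A representative (made-up) fingerprint attribute as it appears in SDP.
    let sdp = "a=fingerprint:sha-256 4A:AD:B9:B1:3F:82:18:3B:54:02:12:DF:3E:5D:49:6B:19:E5:7C:AB:3E:4B:65:2E:7D:46:3F:54:42:CD:54:F1\r\n";

    // Same pattern the test uses: lazily capture everything between
    // "sha-256 " and the trailing carriage return.
    let re = Regex::new(r"sha-256 (.*?)\r").unwrap();

    // Swap in a value that cannot match the DTLS certificate, which is what
    // drives both peer connections to the Failed state in the test above.
    let mangled = re.replace_all(
        sdp,
        "sha-256 AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA\r",
    );

    assert!(mangled.contains("sha-256 AA:AA"));
    println!("{mangled}");
}
```
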
--- a/webrtc/src/dtls_transport/mod.rs +++ /dev/null @@ -1,617 +0,0 @@ -use std::collections::HashMap; -use std::future::Future; -use std::pin::Pin; -use std::sync::atomic::Ordering; -use std::sync::Arc; - -use arc_swap::ArcSwapOption; -use bytes::Bytes; -use dtls::config::ClientAuthType; -use dtls::conn::DTLSConn; -use dtls::extension::extension_use_srtp::SrtpProtectionProfile; -use dtls_role::*; -use interceptor::stream_info::StreamInfo; -use interceptor::{Interceptor, RTCPReader, RTPReader}; -use portable_atomic::{AtomicBool, AtomicU8}; -use sha2::{Digest, Sha256}; -use srtp::protection_profile::ProtectionProfile; -use srtp::session::Session; -use srtp::stream::Stream; -use tokio::sync::{mpsc, Mutex}; -use util::Conn; - -use crate::api::setting_engine::SettingEngine; -use crate::dtls_transport::dtls_parameters::DTLSParameters; -use crate::dtls_transport::dtls_transport_state::RTCDtlsTransportState; -use crate::error::{flatten_errs, Error, Result}; -use crate::ice_transport::ice_role::RTCIceRole; -use crate::ice_transport::ice_transport_state::RTCIceTransportState; -use crate::ice_transport::RTCIceTransport; -use crate::mux::endpoint::Endpoint; -use crate::mux::mux_func::{match_dtls, match_srtcp, match_srtp, MatchFunc}; -use crate::peer_connection::certificate::RTCCertificate; -use crate::rtp_transceiver::SSRC; -use crate::stats::stats_collector::StatsCollector; - -#[cfg(test)] -mod dtls_transport_test; - -pub mod dtls_fingerprint; -pub mod dtls_parameters; -pub mod dtls_role; -pub mod dtls_transport_state; - -pub(crate) fn default_srtp_protection_profiles() -> Vec { - vec![ - SrtpProtectionProfile::Srtp_Aead_Aes_128_Gcm, - SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_80, - ] -} - -pub type OnDTLSTransportStateChangeHdlrFn = Box< - dyn (FnMut(RTCDtlsTransportState) -> Pin + Send + 'static>>) - + Send - + Sync, ->; - -/// DTLSTransport allows an application access to information about the DTLS -/// transport over which RTP and RTCP packets are sent and received by -/// RTPSender and RTPReceiver, as well other data such as SCTP packets sent -/// and received by data channels. 
-#[derive(Default)] -pub struct RTCDtlsTransport { - pub(crate) ice_transport: Arc, - pub(crate) certificates: Vec, - pub(crate) setting_engine: Arc, - - pub(crate) remote_parameters: Mutex, - pub(crate) remote_certificate: Mutex, - pub(crate) state: AtomicU8, //DTLSTransportState, - pub(crate) srtp_protection_profile: Mutex, - pub(crate) on_state_change_handler: ArcSwapOption>, - pub(crate) conn: Mutex>>, - - pub(crate) srtp_session: Mutex>>, - pub(crate) srtcp_session: Mutex>>, - pub(crate) srtp_endpoint: Mutex>>, - pub(crate) srtcp_endpoint: Mutex>>, - - pub(crate) simulcast_streams: Mutex>>, - - pub(crate) srtp_ready_signal: Arc, - pub(crate) srtp_ready_tx: Mutex>>, - pub(crate) srtp_ready_rx: Mutex>>, - - pub(crate) dtls_matcher: Option, -} - -impl RTCDtlsTransport { - pub(crate) fn new( - ice_transport: Arc, - certificates: Vec, - setting_engine: Arc, - ) -> Self { - let (srtp_ready_tx, srtp_ready_rx) = mpsc::channel(1); - RTCDtlsTransport { - ice_transport, - certificates, - setting_engine, - srtp_ready_signal: Arc::new(AtomicBool::new(false)), - srtp_ready_tx: Mutex::new(Some(srtp_ready_tx)), - srtp_ready_rx: Mutex::new(Some(srtp_ready_rx)), - state: AtomicU8::new(RTCDtlsTransportState::New as u8), - dtls_matcher: Some(Box::new(match_dtls)), - ..Default::default() - } - } - - pub(crate) async fn conn(&self) -> Option> { - let conn = self.conn.lock().await; - conn.clone() - } - - /// returns the currently-configured ICETransport or None - /// if one has not been configured - pub fn ice_transport(&self) -> &RTCIceTransport { - &self.ice_transport - } - - /// state_change requires the caller holds the lock - async fn state_change(&self, state: RTCDtlsTransportState) { - self.state.store(state as u8, Ordering::SeqCst); - if let Some(handler) = &*self.on_state_change_handler.load() { - let mut f = handler.lock().await; - f(state).await; - } - } - - /// on_state_change sets a handler that is fired when the DTLS - /// connection state changes. - pub fn on_state_change(&self, f: OnDTLSTransportStateChangeHdlrFn) { - self.on_state_change_handler - .store(Some(Arc::new(Mutex::new(f)))); - } - - /// state returns the current dtls_transport transport state. - pub fn state(&self) -> RTCDtlsTransportState { - self.state.load(Ordering::SeqCst).into() - } - - /// write_rtcp sends a user provided RTCP packet to the connected peer. If no peer is connected the - /// packet is discarded. - pub async fn write_rtcp( - &self, - pkts: &[Box], - ) -> Result { - let srtcp_session = self.srtcp_session.lock().await; - if let Some(srtcp_session) = &*srtcp_session { - let raw = rtcp::packet::marshal(pkts)?; - Ok(srtcp_session.write(&raw, false).await?) - } else { - Ok(0) - } - } - - /// get_local_parameters returns the DTLS parameters of the local DTLSTransport upon construction. 
- pub fn get_local_parameters(&self) -> Result { - let mut fingerprints = vec![]; - - for c in &self.certificates { - fingerprints.extend(c.get_fingerprints()); - } - - Ok(DTLSParameters { - role: DTLSRole::Auto, // always returns the default role - fingerprints, - }) - } - - /// get_remote_certificate returns the certificate chain in use by the remote side - /// returns an empty list prior to selection of the remote certificate - pub async fn get_remote_certificate(&self) -> Bytes { - let remote_certificate = self.remote_certificate.lock().await; - remote_certificate.clone() - } - - pub(crate) async fn start_srtp(&self) -> Result<()> { - let profile = { - let srtp_protection_profile = self.srtp_protection_profile.lock().await; - *srtp_protection_profile - }; - - let mut srtp_config = srtp::config::Config { - profile, - ..Default::default() - }; - - if self.setting_engine.replay_protection.srtp != 0 { - srtp_config.remote_rtp_options = Some(srtp::option::srtp_replay_protection( - self.setting_engine.replay_protection.srtp, - )); - } else if self.setting_engine.disable_srtp_replay_protection { - srtp_config.remote_rtp_options = Some(srtp::option::srtp_no_replay_protection()); - } - - if let Some(conn) = self.conn().await { - let conn_state = conn.connection_state().await; - srtp_config - .extract_session_keys_from_dtls(conn_state, self.role().await == DTLSRole::Client) - .await?; - } else { - return Err(Error::ErrDtlsTransportNotStarted); - } - - { - let mut srtp_session = self.srtp_session.lock().await; - *srtp_session = { - let se = self.srtp_endpoint.lock().await; - if let Some(srtp_endpoint) = &*se { - Some(Arc::new( - Session::new( - Arc::clone(srtp_endpoint) as Arc, - srtp_config, - true, - ) - .await?, - )) - } else { - None - } - }; - } - - let mut srtcp_config = srtp::config::Config { - profile, - ..Default::default() - }; - if self.setting_engine.replay_protection.srtcp != 0 { - srtcp_config.remote_rtcp_options = Some(srtp::option::srtcp_replay_protection( - self.setting_engine.replay_protection.srtcp, - )); - } else if self.setting_engine.disable_srtcp_replay_protection { - srtcp_config.remote_rtcp_options = Some(srtp::option::srtcp_no_replay_protection()); - } - - if let Some(conn) = self.conn().await { - let conn_state = conn.connection_state().await; - srtcp_config - .extract_session_keys_from_dtls(conn_state, self.role().await == DTLSRole::Client) - .await?; - } else { - return Err(Error::ErrDtlsTransportNotStarted); - } - - { - let mut srtcp_session = self.srtcp_session.lock().await; - *srtcp_session = { - let se = self.srtcp_endpoint.lock().await; - if let Some(srtcp_endpoint) = &*se { - Some(Arc::new( - Session::new( - Arc::clone(srtcp_endpoint) as Arc, - srtcp_config, - false, - ) - .await?, - )) - } else { - None - } - }; - } - - { - let mut srtp_ready_tx = self.srtp_ready_tx.lock().await; - srtp_ready_tx.take(); - if srtp_ready_tx.is_none() { - self.srtp_ready_signal.store(true, Ordering::SeqCst); - } - } - - Ok(()) - } - - pub(crate) async fn get_srtp_session(&self) -> Option> { - let srtp_session = self.srtp_session.lock().await; - srtp_session.clone() - } - - pub(crate) async fn get_srtcp_session(&self) -> Option> { - let srtcp_session = self.srtcp_session.lock().await; - srtcp_session.clone() - } - - pub(crate) async fn role(&self) -> DTLSRole { - // If remote has an explicit role use the inverse - { - let remote_parameters = self.remote_parameters.lock().await; - match remote_parameters.role { - DTLSRole::Client => return DTLSRole::Server, - DTLSRole::Server => 
return DTLSRole::Client, - _ => {} - }; - } - - // If SettingEngine has an explicit role - match self.setting_engine.answering_dtls_role { - DTLSRole::Server => return DTLSRole::Server, - DTLSRole::Client => return DTLSRole::Client, - _ => {} - }; - - // Remote was auto and no explicit role was configured via SettingEngine - if self.ice_transport.role().await == RTCIceRole::Controlling { - return DTLSRole::Server; - } - - DEFAULT_DTLS_ROLE_ANSWER - } - - pub(crate) async fn collect_stats(&self, collector: &StatsCollector) { - for cert in &self.certificates { - cert.collect_stats(collector).await; - } - } - - async fn prepare_transport( - &self, - remote_parameters: DTLSParameters, - ) -> Result<(DTLSRole, dtls::config::Config)> { - self.ensure_ice_conn()?; - - if self.state() != RTCDtlsTransportState::New { - return Err(Error::ErrInvalidDTLSStart); - } - - { - let mut srtp_endpoint = self.srtp_endpoint.lock().await; - *srtp_endpoint = self.ice_transport.new_endpoint(Box::new(match_srtp)).await; - } - { - let mut srtcp_endpoint = self.srtcp_endpoint.lock().await; - *srtcp_endpoint = self.ice_transport.new_endpoint(Box::new(match_srtcp)).await; - } - { - let mut rp = self.remote_parameters.lock().await; - *rp = remote_parameters; - } - - let certificate = if let Some(cert) = self.certificates.first() { - cert.dtls_certificate.clone() - } else { - return Err(Error::ErrNonCertificate); - }; - self.state_change(RTCDtlsTransportState::Connecting).await; - - Ok(( - self.role().await, - dtls::config::Config { - certificates: vec![certificate], - srtp_protection_profiles: if !self - .setting_engine - .srtp_protection_profiles - .is_empty() - { - self.setting_engine.srtp_protection_profiles.clone() - } else { - default_srtp_protection_profiles() - }, - client_auth: ClientAuthType::RequireAnyClientCert, - insecure_skip_verify: true, - insecure_verification: self.setting_engine.allow_insecure_verification_algorithm, - ..Default::default() - }, - )) - } - - /// start DTLS transport negotiation with the parameters of the remote DTLS transport - pub async fn start(&self, remote_parameters: DTLSParameters) -> Result<()> { - let dtls_conn_result = if let Some(dtls_endpoint) = - self.ice_transport.new_endpoint(Box::new(match_dtls)).await - { - let (role, mut dtls_config) = self.prepare_transport(remote_parameters).await?; - if self.setting_engine.replay_protection.dtls != 0 { - dtls_config.replay_protection_window = self.setting_engine.replay_protection.dtls; - } - - // Connect as DTLS Client/Server, function is blocking and we - // must not hold the DTLSTransport lock - if role == DTLSRole::Client { - dtls::conn::DTLSConn::new( - dtls_endpoint as Arc, - dtls_config, - true, - None, - ) - .await - } else { - dtls::conn::DTLSConn::new( - dtls_endpoint as Arc, - dtls_config, - false, - None, - ) - .await - } - } else { - Err(dtls::Error::Other( - "ice_transport.new_endpoint failed".to_owned(), - )) - }; - - let dtls_conn = match dtls_conn_result { - Ok(dtls_conn) => dtls_conn, - Err(err) => { - self.state_change(RTCDtlsTransportState::Failed).await; - return Err(err.into()); - } - }; - - let srtp_profile = dtls_conn.selected_srtpprotection_profile(); - { - let mut srtp_protection_profile = self.srtp_protection_profile.lock().await; - *srtp_protection_profile = match srtp_profile { - dtls::extension::extension_use_srtp::SrtpProtectionProfile::Srtp_Aead_Aes_128_Gcm => { - srtp::protection_profile::ProtectionProfile::AeadAes128Gcm - } - 
dtls::extension::extension_use_srtp::SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_80 => { - srtp::protection_profile::ProtectionProfile::Aes128CmHmacSha1_80 - } - _ => { - if let Err(err) = dtls_conn.close().await { - log::error!("{}", err); - } - - self.state_change(RTCDtlsTransportState::Failed).await; - return Err(Error::ErrNoSRTPProtectionProfile); - } - }; - } - - // Check the fingerprint if a certificate was exchanged - let remote_certs = &dtls_conn.connection_state().await.peer_certificates; - if remote_certs.is_empty() { - if let Err(err) = dtls_conn.close().await { - log::error!("{}", err); - } - - self.state_change(RTCDtlsTransportState::Failed).await; - return Err(Error::ErrNoRemoteCertificate); - } - - { - let mut remote_certificate = self.remote_certificate.lock().await; - *remote_certificate = Bytes::from(remote_certs[0].clone()); - } - - if !self - .setting_engine - .disable_certificate_fingerprint_verification - { - if let Err(err) = self.validate_fingerprint(&remote_certs[0]).await { - if let Err(close_err) = dtls_conn.close().await { - log::error!("{}", close_err); - } - - self.state_change(RTCDtlsTransportState::Failed).await; - return Err(err); - } - } - - { - let mut conn = self.conn.lock().await; - *conn = Some(Arc::new(dtls_conn)); - } - self.state_change(RTCDtlsTransportState::Connected).await; - - self.start_srtp().await - } - - /// stops and closes the DTLSTransport object. - pub async fn stop(&self) -> Result<()> { - // Try closing everything and collect the errors - let mut close_errs: Vec = vec![]; - { - let srtp_session = { - let mut srtp_session = self.srtp_session.lock().await; - srtp_session.take() - }; - if let Some(srtp_session) = srtp_session { - match srtp_session.close().await { - Ok(_) => {} - Err(err) => { - close_errs.push(err.into()); - } - }; - } - } - - { - let srtcp_session = { - let mut srtcp_session = self.srtcp_session.lock().await; - srtcp_session.take() - }; - if let Some(srtcp_session) = srtcp_session { - match srtcp_session.close().await { - Ok(_) => {} - Err(err) => { - close_errs.push(err.into()); - } - }; - } - } - - { - let simulcast_streams: Vec> = { - let mut simulcast_streams = self.simulcast_streams.lock().await; - simulcast_streams.drain().map(|(_, v)| v).collect() - }; - for ss in simulcast_streams { - match ss.close().await { - Ok(_) => {} - Err(err) => { - close_errs.push(Error::new(format!( - "simulcast_streams ssrc={}: {}", - ss.get_ssrc(), - err - ))); - } - }; - } - } - - if let Some(conn) = self.conn().await { - // dtls_transport connection may be closed on sctp close. - match conn.close().await { - Ok(_) => {} - Err(err) => { - if err.to_string() != dtls::Error::ErrConnClosed.to_string() { - close_errs.push(err.into()); - } - } - } - } - - self.state_change(RTCDtlsTransportState::Closed).await; - - flatten_errs(close_errs) - } - - pub(crate) async fn validate_fingerprint(&self, remote_cert: &[u8]) -> Result<()> { - let remote_parameters = self.remote_parameters.lock().await; - for fp in &remote_parameters.fingerprints { - if fp.algorithm != "sha-256" { - return Err(Error::ErrUnsupportedFingerprintAlgorithm); - } - - let mut h = Sha256::new(); - h.update(remote_cert); - let hashed = h.finalize(); - let values: Vec = hashed.iter().map(|x| format! 
{"{x:02x}"}).collect(); - let remote_value = values.join(":").to_lowercase(); - - if remote_value == fp.value.to_lowercase() { - return Ok(()); - } - } - - Err(Error::ErrNoMatchingCertificateFingerprint) - } - - pub(crate) fn ensure_ice_conn(&self) -> Result<()> { - if self.ice_transport.state() == RTCIceTransportState::New { - Err(Error::ErrICEConnectionNotStarted) - } else { - Ok(()) - } - } - - pub(crate) async fn store_simulcast_stream(&self, ssrc: SSRC, stream: Arc) { - let mut simulcast_streams = self.simulcast_streams.lock().await; - simulcast_streams.insert(ssrc, stream); - } - - pub(crate) async fn remove_simulcast_stream(&self, ssrc: SSRC) { - let mut simulcast_streams = self.simulcast_streams.lock().await; - simulcast_streams.remove(&ssrc); - } - - pub(crate) async fn streams_for_ssrc( - &self, - ssrc: SSRC, - stream_info: &StreamInfo, - interceptor: &Arc, - ) -> Result<( - Arc, - Arc, - Arc, - Arc, - )> { - let srtp_session = self - .get_srtp_session() - .await - .ok_or(Error::ErrDtlsTransportNotStarted)?; - //log::debug!("streams_for_ssrc: srtp_session.listen ssrc={}", ssrc); - let rtp_read_stream = srtp_session.open(ssrc).await; - let rtp_stream_reader = Arc::clone(&rtp_read_stream) as Arc; - let rtp_interceptor = interceptor - .bind_remote_stream(stream_info, rtp_stream_reader) - .await; - - let srtcp_session = self - .get_srtcp_session() - .await - .ok_or(Error::ErrDtlsTransportNotStarted)?; - //log::debug!("streams_for_ssrc: srtcp_session.listen ssrc={}", ssrc); - let rtcp_read_stream = srtcp_session.open(ssrc).await; - let rtcp_stream_reader = Arc::clone(&rtcp_read_stream) as Arc; - let rtcp_interceptor = interceptor.bind_rtcp_reader(rtcp_stream_reader).await; - - Ok(( - rtp_read_stream, - rtp_interceptor, - rtcp_read_stream, - rtcp_interceptor, - )) - } -} diff --git a/webrtc/src/error.rs b/webrtc/src/error.rs deleted file mode 100644 index 9a9ff4403..000000000 --- a/webrtc/src/error.rs +++ /dev/null @@ -1,482 +0,0 @@ -use std::future::Future; -use std::num::ParseIntError; -use std::pin::Pin; -use std::string::FromUtf8Error; - -use thiserror::Error; -use tokio::sync::mpsc::error::SendError as MpscSendError; - -use crate::peer_connection::sdp::sdp_type::RTCSdpType; -use crate::peer_connection::signaling_state::RTCSignalingState; -use crate::rtp_transceiver::rtp_receiver; -#[cfg(doc)] -use crate::rtp_transceiver::rtp_sender; - -pub type Result = std::result::Result; - -#[derive(Error, Debug, PartialEq)] -#[non_exhaustive] -pub enum Error { - /// ErrUnknownType indicates an error with Unknown info. - #[error("unknown")] - ErrUnknownType, - - /// ErrConnectionClosed indicates an operation executed after connection - /// has already been closed. - #[error("connection closed")] - ErrConnectionClosed, - - /// ErrDataChannelNotOpen indicates an operation executed when the data - /// channel is not (yet) open. - #[error("data channel not open")] - ErrDataChannelNotOpen, - - /// ErrCertificateExpired indicates that an x509 certificate has expired. - #[error("x509Cert expired")] - ErrCertificateExpired, - - /// ErrNoTurnCredentials indicates that a TURN server URL was provided - /// without required credentials. - #[error("turn server credentials required")] - ErrNoTurnCredentials, - - /// ErrTurnCredentials indicates that provided TURN credentials are partial - /// or malformed. - #[error("invalid turn server credentials")] - ErrTurnCredentials, - - /// ErrExistingTrack indicates that a track already exists. 
- #[error("track already exists")] - ErrExistingTrack, - - /// ErrPrivateKeyType indicates that a particular private key encryption - /// chosen to generate a certificate is not supported. - #[error("private key type not supported")] - ErrPrivateKeyType, - - /// ErrModifyingPeerIdentity indicates that an attempt to modify - /// PeerIdentity was made after PeerConnection has been initialized. - #[error("peerIdentity cannot be modified")] - ErrModifyingPeerIdentity, - - /// ErrModifyingCertificates indicates that an attempt to modify - /// Certificates was made after PeerConnection has been initialized. - #[error("certificates cannot be modified")] - ErrModifyingCertificates, - - /// ErrNonCertificate indicates that there is no certificate - #[error("no certificate")] - ErrNonCertificate, - - /// ErrModifyingBundlePolicy indicates that an attempt to modify - /// BundlePolicy was made after PeerConnection has been initialized. - #[error("bundle policy cannot be modified")] - ErrModifyingBundlePolicy, - - /// ErrModifyingRTCPMuxPolicy indicates that an attempt to modify - /// RTCPMuxPolicy was made after PeerConnection has been initialized. - #[error("rtcp mux policy cannot be modified")] - ErrModifyingRTCPMuxPolicy, - - /// ErrModifyingICECandidatePoolSize indicates that an attempt to modify - /// ICECandidatePoolSize was made after PeerConnection has been initialized. - #[error("ice candidate pool size cannot be modified")] - ErrModifyingICECandidatePoolSize, - - /// ErrStringSizeLimit indicates that the character size limit of string is - /// exceeded. The limit is hardcoded to 65535 according to specifications. - #[error("data channel label exceeds size limit")] - ErrStringSizeLimit, - - /// ErrMaxDataChannelID indicates that the maximum number ID that could be - /// specified for a data channel has been exceeded. - #[error("maximum number ID for datachannel specified")] - ErrMaxDataChannelID, - - /// ErrNegotiatedWithoutID indicates that an attempt to create a data channel - /// was made while setting the negotiated option to true without providing - /// the negotiated channel ID. - #[error("negotiated set without channel id")] - ErrNegotiatedWithoutID, - - /// ErrRetransmitsOrPacketLifeTime indicates that an attempt to create a data - /// channel was made with both options max_packet_life_time and max_retransmits - /// set together. Such configuration is not supported by the specification - /// and is mutually exclusive. 
- #[error("both max_packet_life_time and max_retransmits was set")] - ErrRetransmitsOrPacketLifeTime, - - /// ErrCodecNotFound is returned when a codec search to the Media Engine fails - #[error("codec not found")] - ErrCodecNotFound, - - /// ErrNoRemoteDescription indicates that an operation was rejected because - /// the remote description is not set - #[error("remote description is not set")] - ErrNoRemoteDescription, - - /// ErrIncorrectSDPSemantics indicates that the PeerConnection was configured to - /// generate SDP Answers with different SDP Semantics than the received Offer - #[error("offer SDP semantics does not match configuration")] - ErrIncorrectSDPSemantics, - - /// ErrIncorrectSignalingState indicates that the signaling state of PeerConnection is not correct - #[error("operation can not be run in current signaling state")] - ErrIncorrectSignalingState, - - /// ErrProtocolTooLarge indicates that value given for a DataChannelInit protocol is - /// longer then 65535 bytes - #[error("protocol is larger then 65535 bytes")] - ErrProtocolTooLarge, - - /// ErrSenderNotCreatedByConnection indicates remove_track was called with a - /// [`rtp_sender::RTCRtpSender`] not created by this PeerConnection - #[error("RtpSender not created by this PeerConnection")] - ErrSenderNotCreatedByConnection, - - /// ErrSenderInitialTrackIdAlreadySet indicates a second call to - /// `RTCRtpSender::set_initial_track_id` which is not allowed. Purely internal error, should not happen in practice. - #[error("RtpSender's initial_track_id has already been set")] - ErrSenderInitialTrackIdAlreadySet, - - /// ErrSessionDescriptionNoFingerprint indicates set_remote_description was called with a SessionDescription that has no - /// fingerprint - #[error("set_remote_description called with no fingerprint")] - ErrSessionDescriptionNoFingerprint, - - /// ErrSessionDescriptionInvalidFingerprint indicates set_remote_description was called with a SessionDescription that - /// has an invalid fingerprint - #[error("set_remote_description called with an invalid fingerprint")] - ErrSessionDescriptionInvalidFingerprint, - - /// ErrSessionDescriptionConflictingFingerprints indicates set_remote_description was called with a SessionDescription that - /// has an conflicting fingerprints - #[error("set_remote_description called with multiple conflicting fingerprint")] - ErrSessionDescriptionConflictingFingerprints, - - /// ErrSessionDescriptionMissingIceUfrag indicates set_remote_description was called with a SessionDescription that - /// is missing an ice-ufrag value - #[error("set_remote_description called with no ice-ufrag")] - ErrSessionDescriptionMissingIceUfrag, - - /// ErrSessionDescriptionMissingIcePwd indicates set_remote_description was called with a SessionDescription that - /// is missing an ice-pwd value - #[error("set_remote_description called with no ice-pwd")] - ErrSessionDescriptionMissingIcePwd, - - /// ErrSessionDescriptionConflictingIceUfrag indicates set_remote_description was called with a SessionDescription that - /// contains multiple conflicting ice-ufrag values - #[error("set_remote_description called with multiple conflicting ice-ufrag values")] - ErrSessionDescriptionConflictingIceUfrag, - - /// ErrSessionDescriptionConflictingIcePwd indicates set_remote_description was called with a SessionDescription that - /// contains multiple conflicting ice-pwd values - #[error("set_remote_description called with multiple conflicting ice-pwd values")] - ErrSessionDescriptionConflictingIcePwd, - - /// 
ErrNoSRTPProtectionProfile indicates that the DTLS handshake completed and no SRTP Protection Profile was chosen - #[error("DTLS Handshake completed and no SRTP Protection Profile was chosen")] - ErrNoSRTPProtectionProfile, - - /// ErrFailedToGenerateCertificateFingerprint indicates that we failed to generate the fingerprint used for comparing certificates - #[error("failed to generate certificate fingerprint")] - ErrFailedToGenerateCertificateFingerprint, - - /// ErrNoCodecsAvailable indicates that operation isn't possible because the MediaEngine has no codecs available - #[error("operation failed no codecs are available")] - ErrNoCodecsAvailable, - - /// ErrUnsupportedCodec indicates the remote peer doesn't support the requested codec - #[error("unable to start track, codec is not supported by remote")] - ErrUnsupportedCodec, - - /// ErrSenderWithNoCodecs indicates that a RTPSender was created without any codecs. To send media the MediaEngine needs at - /// least one configured codec. - #[error("unable to populate media section, RTPSender created with no codecs")] - ErrSenderWithNoCodecs, - - /// ErrRTPSenderNewTrackHasIncorrectKind indicates that the new track is of a different kind than the previous/original - #[error("new track must be of the same kind as previous")] - ErrRTPSenderNewTrackHasIncorrectKind, - - /// ErrRTPSenderNewTrackHasIncorrectEnvelope indicates that the new track has a different envelope than the previous/original - #[error("new track must have the same envelope as previous")] - ErrRTPSenderNewTrackHasIncorrectEnvelope, - - /// ErrRTPSenderDataSent indicates that the sequence number transformer tries to be enabled after the data sending began - #[error("Sequence number transformer must be enabled before sending data")] - ErrRTPSenderDataSent, - - /// ErrRTPSenderSeqTransEnabled indicates that the sequence number transformer has been already enabled - #[error("Sequence number transformer has been already enabled")] - ErrRTPSenderSeqTransEnabled, - - /// ErrUnbindFailed indicates that a TrackLocal was not able to be unbind - #[error("failed to unbind TrackLocal from PeerConnection")] - ErrUnbindFailed, - - /// ErrNoPayloaderForCodec indicates that the requested codec does not have a payloader - #[error("the requested codec does not have a payloader")] - ErrNoPayloaderForCodec, - - /// ErrRegisterHeaderExtensionInvalidDirection indicates that a extension was registered with different - /// directions for two different calls. - #[error("a header extension must be registered with the same direction each time")] - ErrRegisterHeaderExtensionInvalidDirection, - - /// ErrRegisterHeaderExtensionNoFreeID indicates that there was no extension ID available which - /// in turn means that all 15 available id(1 through 14) have been used. 
- #[error("no header extension ID was free to use(this means the maximum of 15 extensions have been registered)")] - ErrRegisterHeaderExtensionNoFreeID, - - /// ErrSimulcastProbeOverflow indicates that too many Simulcast probe streams are in flight and the requested SSRC was ignored - #[error("simulcast probe limit has been reached, new SSRC has been discarded")] - ErrSimulcastProbeOverflow, - - #[error("enable detaching by calling webrtc.DetachDataChannels()")] - ErrDetachNotEnabled, - #[error("datachannel not opened yet, try calling Detach from OnOpen")] - ErrDetachBeforeOpened, - #[error("the DTLS transport has not started yet")] - ErrDtlsTransportNotStarted, - #[error("failed extracting keys from DTLS for SRTP")] - ErrDtlsKeyExtractionFailed, - #[error("failed to start SRTP")] - ErrFailedToStartSRTP, - #[error("failed to start SRTCP")] - ErrFailedToStartSRTCP, - #[error("attempted to start DTLSTransport that is not in new state")] - ErrInvalidDTLSStart, - #[error("peer didn't provide certificate via DTLS")] - ErrNoRemoteCertificate, - #[error("identity provider is not implemented")] - ErrIdentityProviderNotImplemented, - #[error("remote certificate does not match any fingerprint")] - ErrNoMatchingCertificateFingerprint, - #[error("unsupported fingerprint algorithm")] - ErrUnsupportedFingerprintAlgorithm, - #[error("ICE connection not started")] - ErrICEConnectionNotStarted, - #[error("unknown candidate type")] - ErrICECandidateTypeUnknown, - #[error("cannot convert ice.CandidateType into webrtc.ICECandidateType, invalid type")] - ErrICEInvalidConvertCandidateType, - #[error("ICEAgent does not exist")] - ErrICEAgentNotExist, - #[error("unable to convert ICE candidates to ICECandidates")] - ErrICECandidatesConversionFailed, - #[error("unknown ICE Role")] - ErrICERoleUnknown, - #[error("unknown protocol")] - ErrICEProtocolUnknown, - #[error("gatherer not started")] - ErrICEGathererNotStarted, - #[error("unknown network type")] - ErrNetworkTypeUnknown, - #[error("new sdp does not match previous offer")] - ErrSDPDoesNotMatchOffer, - #[error("new sdp does not match previous answer")] - ErrSDPDoesNotMatchAnswer, - #[error("provided value is not a valid enum value of type SDPType")] - ErrPeerConnSDPTypeInvalidValue, - #[error("invalid state change op")] - ErrPeerConnStateChangeInvalid, - #[error("unhandled state change op")] - ErrPeerConnStateChangeUnhandled, - #[error("invalid SDP type supplied to SetLocalDescription()")] - ErrPeerConnSDPTypeInvalidValueSetLocalDescription, - #[error("remoteDescription contained media section without mid value")] - ErrPeerConnRemoteDescriptionWithoutMidValue, - #[error("remoteDescription has not been set yet")] - ErrPeerConnRemoteDescriptionNil, - #[error("single media section has an explicit SSRC")] - ErrPeerConnSingleMediaSectionHasExplicitSSRC, - #[error("could not add transceiver for remote SSRC")] - ErrPeerConnRemoteSSRCAddTransceiver, - #[error("mid RTP Extensions required for Simulcast")] - ErrPeerConnSimulcastMidRTPExtensionRequired, - #[error("stream id RTP Extensions required for Simulcast")] - ErrPeerConnSimulcastStreamIDRTPExtensionRequired, - #[error("incoming SSRC failed Simulcast probing")] - ErrPeerConnSimulcastIncomingSSRCFailed, - #[error("failed collecting stats")] - ErrPeerConnStatsCollectionFailed, - #[error("add_transceiver_from_kind only accepts one RTPTransceiverInit")] - ErrPeerConnAddTransceiverFromKindOnlyAcceptsOne, - #[error("add_transceiver_from_track only accepts one RTPTransceiverInit")] - 
ErrPeerConnAddTransceiverFromTrackOnlyAcceptsOne, - #[error("add_transceiver_from_kind currently only supports recvonly")] - ErrPeerConnAddTransceiverFromKindSupport, - #[error("add_transceiver_from_track currently only supports sendonly and sendrecv")] - ErrPeerConnAddTransceiverFromTrackSupport, - #[error("TODO set_identity_provider")] - ErrPeerConnSetIdentityProviderNotImplemented, - #[error("write_rtcp failed to open write_stream")] - ErrPeerConnWriteRTCPOpenWriteStream, - #[error("cannot find transceiver with mid")] - ErrPeerConnTransceiverMidNil, - #[error("DTLSTransport must not be nil")] - ErrRTPReceiverDTLSTransportNil, - #[error("Receive has already been called")] - ErrRTPReceiverReceiveAlreadyCalled, - #[error("unable to find stream for Track with SSRC")] - ErrRTPReceiverWithSSRCTrackStreamNotFound, - #[error("no trackStreams found for SSRC")] - ErrRTPReceiverForSSRCTrackStreamNotFound, - #[error("no trackStreams found for RID")] - ErrRTPReceiverForRIDTrackStreamNotFound, - #[error("invalid RTP Receiver transition from {from} to {to}")] - ErrRTPReceiverStateChangeInvalid { - from: rtp_receiver::State, - to: rtp_receiver::State, - }, - #[error("Track must not be nil")] - ErrRTPSenderTrackNil, - #[error("Sender has already been stopped")] - ErrRTPSenderStopped, - #[error("Sender Track has been removed or replaced to nil")] - ErrRTPSenderTrackRemoved, - #[error("Sender cannot add encoding as rid is empty")] - ErrRTPSenderRidNil, - #[error("Sender cannot add encoding as there is no base track")] - ErrRTPSenderNoBaseEncoding, - #[error("Sender cannot add encoding as provided track does not match base track")] - ErrRTPSenderBaseEncodingMismatch, - #[error("Sender cannot encoding due to RID collision")] - ErrRTPSenderRIDCollision, - #[error("Sender does not have track for RID")] - ErrRTPSenderNoTrackForRID, - #[error("RTPSender must not be nil")] - ErrRTPSenderNil, - #[error("RTPReceiver must not be nil")] - ErrRTPReceiverNil, - #[error("DTLSTransport must not be nil")] - ErrRTPSenderDTLSTransportNil, - #[error("Send has already been called")] - ErrRTPSenderSendAlreadyCalled, - #[error("errRTPSenderTrackNil")] - ErrRTPTransceiverCannotChangeMid, - #[error("invalid state change in RTPTransceiver.setSending")] - ErrRTPTransceiverSetSendingInvalidState, - #[error("unsupported codec type by this transceiver")] - ErrRTPTransceiverCodecUnsupported, - #[error("DTLS not established")] - ErrSCTPTransportDTLS, - #[error("add_transceiver_sdp() called with 0 transceivers")] - ErrSDPZeroTransceivers, - #[error("invalid Media Section. Media + DataChannel both enabled")] - ErrSDPMediaSectionMediaDataChanInvalid, - #[error( - "invalid Media Section. 
Can not have multiple tracks in one MediaSection in UnifiedPlan" - )] - ErrSDPMediaSectionMultipleTrackInvalid, - #[error("set_answering_dtlsrole must DTLSRoleClient or DTLSRoleServer")] - ErrSettingEngineSetAnsweringDTLSRole, - #[error("can't rollback from stable state")] - ErrSignalingStateCannotRollback, - #[error( - "invalid proposed signaling state transition from {} applying {} {}", - from, - if *is_local { "local" } else { "remote" }, - applying - )] - ErrSignalingStateProposedTransitionInvalid { - from: RTCSignalingState, - applying: RTCSdpType, - is_local: bool, - }, - #[error("cannot convert to StatsICECandidatePairStateSucceeded invalid ice candidate state")] - ErrStatsICECandidateStateInvalid, - #[error("ICETransport can only be called in ICETransportStateNew")] - ErrICETransportNotInNew, - #[error("bad Certificate PEM format")] - ErrCertificatePEMFormatError, - #[error("SCTP is not established")] - ErrSCTPNotEstablished, - - #[error("DataChannel is not opened")] - ErrClosedPipe, - #[error("Interceptor is not bind")] - ErrInterceptorNotBind, - #[error("excessive retries in CreateOffer")] - ErrExcessiveRetries, - - #[error("not long enough to be a RTP Packet")] - ErrRTPTooShort, - - #[error("{0}")] - Util(#[from] util::Error), - #[error("{0}")] - Ice(#[from] ice::Error), - #[error("{0}")] - Srtp(#[from] srtp::Error), - #[error("{0}")] - Dtls(#[from] dtls::Error), - #[error("{0}")] - Data(#[from] data::Error), - #[error("{0}")] - Sctp(#[from] sctp::Error), - #[error("{0}")] - Sdp(#[from] sdp::Error), - #[error("{0}")] - Interceptor(#[from] interceptor::Error), - #[error("{0}")] - Rtcp(#[from] rtcp::Error), - #[error("{0}")] - Rtp(#[from] rtp::Error), - - #[error("utf-8 error: {0}")] - Utf8(#[from] FromUtf8Error), - #[error("{0}")] - RcGen(#[from] rcgen::Error), - #[error("mpsc send: {0}")] - MpscSend(String), - #[error("parse int: {0}")] - ParseInt(#[from] ParseIntError), - #[error("parse url: {0}")] - ParseUrl(#[from] url::ParseError), - - /// Error parsing a given PEM string. - #[error("invalid PEM: {0}")] - InvalidPEM(String), - - #[allow(non_camel_case_types)] - #[error("{0}")] - new(String), -} - -pub type OnErrorHdlrFn = - Box Pin + Send + 'static>>) + Send + Sync>; - -// Because Tokio SendError is parameterized, we sadly lose the backtrace. 
-impl From> for Error { - fn from(e: MpscSendError) -> Self { - Error::MpscSend(e.to_string()) - } -} - -impl From for interceptor::Error { - fn from(e: Error) -> Self { - // this is a bit lol, but we do preserve the stack trace - interceptor::Error::Util(util::Error::from_std(e)) - } -} - -impl PartialEq for Error { - fn eq(&self, other: &ice::Error) -> bool { - if let Error::Ice(e) = self { - return e == other; - } - false - } -} - -/// flatten_errs flattens multiple errors into one -pub fn flatten_errs(errs: Vec>) -> Result<()> { - if errs.is_empty() { - Ok(()) - } else { - let errs_strs: Vec = errs.into_iter().map(|e| e.into().to_string()).collect(); - Err(Error::new(errs_strs.join("\n"))) - } -} diff --git a/webrtc/src/ice_transport/ice_candidate.rs b/webrtc/src/ice_transport/ice_candidate.rs deleted file mode 100644 index 9eab74dae..000000000 --- a/webrtc/src/ice_transport/ice_candidate.rs +++ /dev/null @@ -1,194 +0,0 @@ -use std::fmt; -use std::sync::Arc; - -use ice::candidate::candidate_base::CandidateBaseConfig; -use ice::candidate::candidate_host::CandidateHostConfig; -use ice::candidate::candidate_peer_reflexive::CandidatePeerReflexiveConfig; -use ice::candidate::candidate_relay::CandidateRelayConfig; -use ice::candidate::candidate_server_reflexive::CandidateServerReflexiveConfig; -use ice::candidate::Candidate; -use serde::{Deserialize, Serialize}; - -use crate::error::{Error, Result}; -use crate::ice_transport::ice_candidate_type::RTCIceCandidateType; -use crate::ice_transport::ice_protocol::RTCIceProtocol; - -/// ICECandidate represents a ice candidate -#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct RTCIceCandidate { - pub stats_id: String, - pub foundation: String, - pub priority: u32, - pub address: String, - pub protocol: RTCIceProtocol, - pub port: u16, - pub typ: RTCIceCandidateType, - pub component: u16, - pub related_address: String, - pub related_port: u16, - pub tcp_type: String, -} - -/// Conversion for ice_candidates -pub(crate) fn rtc_ice_candidates_from_ice_candidates( - ice_candidates: &[Arc], -) -> Vec { - ice_candidates.iter().map(|c| c.into()).collect() -} - -impl From<&Arc> for RTCIceCandidate { - fn from(c: &Arc) -> Self { - let typ: RTCIceCandidateType = c.candidate_type().into(); - let protocol = RTCIceProtocol::from(c.network_type().network_short().as_str()); - let (related_address, related_port) = if let Some(ra) = c.related_address() { - (ra.address, ra.port) - } else { - (String::new(), 0) - }; - - RTCIceCandidate { - stats_id: c.id(), - foundation: c.foundation(), - priority: c.priority(), - address: c.address(), - protocol, - port: c.port(), - component: c.component(), - typ, - tcp_type: c.tcp_type().to_string(), - related_address, - related_port, - } - } -} - -impl RTCIceCandidate { - pub(crate) fn to_ice(&self) -> Result { - let candidate_id = self.stats_id.clone(); - let base_config = CandidateBaseConfig { - candidate_id, - network: self.protocol.to_string(), - address: self.address.clone(), - port: self.port, - component: self.component, - //tcp_type: ice.NewTCPType(c.TCPType), - foundation: self.foundation.clone(), - priority: self.priority, - ..Default::default() - }; - - let c = match self.typ { - RTCIceCandidateType::Host => { - let config = CandidateHostConfig { - base_config, - ..Default::default() - }; - config.new_candidate_host()? 
- } - RTCIceCandidateType::Srflx => { - let config = CandidateServerReflexiveConfig { - base_config, - rel_addr: self.related_address.clone(), - rel_port: self.related_port, - }; - config.new_candidate_server_reflexive()? - } - RTCIceCandidateType::Prflx => { - let config = CandidatePeerReflexiveConfig { - base_config, - rel_addr: self.related_address.clone(), - rel_port: self.related_port, - }; - config.new_candidate_peer_reflexive()? - } - RTCIceCandidateType::Relay => { - let config = CandidateRelayConfig { - base_config, - rel_addr: self.related_address.clone(), - rel_port: self.related_port, - relay_client: None, //TODO? - }; - config.new_candidate_relay()? - } - _ => return Err(Error::ErrICECandidateTypeUnknown), - }; - - Ok(c) - } - - /// to_json returns an ICECandidateInit - /// as indicated by the spec - pub fn to_json(&self) -> Result { - let candidate = self.to_ice()?; - - Ok(RTCIceCandidateInit { - candidate: format!("candidate:{}", candidate.marshal()), - sdp_mid: Some("".to_owned()), - sdp_mline_index: Some(0u16), - username_fragment: None, - }) - } -} - -impl fmt::Display for RTCIceCandidate { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "{} {} {}:{}{}", - self.protocol, self.typ, self.address, self.port, self.related_address, - ) - } -} - -/// ICECandidateInit is used to serialize ice candidates -#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct RTCIceCandidateInit { - pub candidate: String, - pub sdp_mid: Option, - #[serde(rename = "sdpMLineIndex")] - pub sdp_mline_index: Option, - pub username_fragment: Option, -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_ice_candidate_serialization() { - let tests = vec![ - ( - RTCIceCandidateInit { - candidate: "candidate:abc123".to_string(), - sdp_mid: Some("0".to_string()), - sdp_mline_index: Some(0), - username_fragment: Some("def".to_string()), - }, - r#"{"candidate":"candidate:abc123","sdpMid":"0","sdpMLineIndex":0,"usernameFragment":"def"}"#, - ), - ( - RTCIceCandidateInit { - candidate: "candidate:abc123".to_string(), - sdp_mid: None, - sdp_mline_index: None, - username_fragment: None, - }, - r#"{"candidate":"candidate:abc123","sdpMid":null,"sdpMLineIndex":null,"usernameFragment":null}"#, - ), - ]; - - for (candidate_init, expected_string) in tests { - let result = serde_json::to_string(&candidate_init); - assert!(result.is_ok(), "testCase: marshal err: {result:?}"); - let candidate_data = result.unwrap(); - assert_eq!(candidate_data, expected_string, "string is not expected"); - - let result = serde_json::from_str::(&candidate_data); - assert!(result.is_ok(), "testCase: unmarshal err: {result:?}"); - if let Ok(actual_candidate_init) = result { - assert_eq!(actual_candidate_init, candidate_init); - } - } - } -} diff --git a/webrtc/src/ice_transport/ice_candidate_pair.rs b/webrtc/src/ice_transport/ice_candidate_pair.rs deleted file mode 100644 index 8c0684c0d..000000000 --- a/webrtc/src/ice_transport/ice_candidate_pair.rs +++ /dev/null @@ -1,34 +0,0 @@ -use std::fmt; - -use crate::ice_transport::ice_candidate::*; - -/// ICECandidatePair represents an ICE Candidate pair -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub struct RTCIceCandidatePair { - stats_id: String, - local: RTCIceCandidate, - remote: RTCIceCandidate, -} - -impl fmt::Display for RTCIceCandidatePair { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "(local) {} <-> (remote) {}", self.local, self.remote) - } -} - 
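
A minimal sketch of how the candidate-pair types above fit together, written against the API of these files as they stood before this deletion; the import paths and field values are assumptions for illustration only:

```rust
use webrtc::ice_transport::ice_candidate::RTCIceCandidate;
use webrtc::ice_transport::ice_candidate_pair::RTCIceCandidatePair;
use webrtc::ice_transport::ice_candidate_type::RTCIceCandidateType;

fn main() {
    // Two made-up candidates; the remaining fields fall back to Default.
    let local = RTCIceCandidate {
        address: "192.168.1.10".to_owned(),
        port: 50000,
        typ: RTCIceCandidateType::Host,
        ..Default::default()
    };
    let remote = RTCIceCandidate {
        address: "203.0.113.5".to_owned(),
        port: 3478,
        typ: RTCIceCandidateType::Srflx,
        ..Default::default()
    };

    // `new` derives the pair's stats_id from the two candidates' stats ids
    // and stores both candidates.
    let pair = RTCIceCandidatePair::new(local, remote);

    // Uses the Display impl above: "(local) ... <-> (remote) ...".
    println!("{pair}");
}
```
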
-impl RTCIceCandidatePair { - fn stats_id(local_id: &str, remote_id: &str) -> String { - format!("{local_id}-{remote_id}") - } - - /// returns an initialized ICECandidatePair - /// for the given pair of ICECandidate instances - pub fn new(local: RTCIceCandidate, remote: RTCIceCandidate) -> Self { - let stats_id = Self::stats_id(&local.stats_id, &remote.stats_id); - RTCIceCandidatePair { - stats_id, - local, - remote, - } - } -} diff --git a/webrtc/src/ice_transport/ice_candidate_type.rs b/webrtc/src/ice_transport/ice_candidate_type.rs deleted file mode 100644 index ec44328f7..000000000 --- a/webrtc/src/ice_transport/ice_candidate_type.rs +++ /dev/null @@ -1,119 +0,0 @@ -use std::fmt; - -use ice::candidate::CandidateType; -use serde::{Deserialize, Serialize}; - -/// ICECandidateType represents the type of the ICE candidate used. -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub enum RTCIceCandidateType { - #[default] - Unspecified, - - /// ICECandidateTypeHost indicates that the candidate is of Host type as - /// described in . A - /// candidate obtained by binding to a specific port from an IP address on - /// the host. This includes IP addresses on physical interfaces and logical - /// ones, such as ones obtained through VPNs. - #[serde(rename = "host")] - Host, - - /// ICECandidateTypeSrflx indicates the the candidate is of Server - /// Reflexive type as described - /// . A candidate type - /// whose IP address and port are a binding allocated by a NAT for an ICE - /// agent after it sends a packet through the NAT to a server, such as a - /// STUN server. - #[serde(rename = "srflx")] - Srflx, - - /// ICECandidateTypePrflx indicates that the candidate is of Peer - /// Reflexive type. A candidate type whose IP address and port are a binding - /// allocated by a NAT for an ICE agent after it sends a packet through the - /// NAT to its peer. - #[serde(rename = "prflx")] - Prflx, - - /// ICECandidateTypeRelay indicates the the candidate is of Relay type as - /// described in . A - /// candidate type obtained from a relay server, such as a TURN server. 
- #[serde(rename = "relay")] - Relay, -} - -const ICE_CANDIDATE_TYPE_HOST_STR: &str = "host"; -const ICE_CANDIDATE_TYPE_SRFLX_STR: &str = "srflx"; -const ICE_CANDIDATE_TYPE_PRFLX_STR: &str = "prflx"; -const ICE_CANDIDATE_TYPE_RELAY_STR: &str = "relay"; - -/// takes a string and converts it into ICECandidateType -impl From<&str> for RTCIceCandidateType { - fn from(raw: &str) -> Self { - match raw { - ICE_CANDIDATE_TYPE_HOST_STR => RTCIceCandidateType::Host, - ICE_CANDIDATE_TYPE_SRFLX_STR => RTCIceCandidateType::Srflx, - ICE_CANDIDATE_TYPE_PRFLX_STR => RTCIceCandidateType::Prflx, - ICE_CANDIDATE_TYPE_RELAY_STR => RTCIceCandidateType::Relay, - _ => RTCIceCandidateType::Unspecified, - } - } -} - -impl From for RTCIceCandidateType { - fn from(candidate_type: CandidateType) -> Self { - match candidate_type { - CandidateType::Host => RTCIceCandidateType::Host, - CandidateType::ServerReflexive => RTCIceCandidateType::Srflx, - CandidateType::PeerReflexive => RTCIceCandidateType::Prflx, - CandidateType::Relay => RTCIceCandidateType::Relay, - _ => RTCIceCandidateType::Unspecified, - } - } -} - -impl fmt::Display for RTCIceCandidateType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self { - RTCIceCandidateType::Host => write!(f, "{ICE_CANDIDATE_TYPE_HOST_STR}"), - RTCIceCandidateType::Srflx => write!(f, "{ICE_CANDIDATE_TYPE_SRFLX_STR}"), - RTCIceCandidateType::Prflx => write!(f, "{ICE_CANDIDATE_TYPE_PRFLX_STR}"), - RTCIceCandidateType::Relay => write!(f, "{ICE_CANDIDATE_TYPE_RELAY_STR}"), - _ => write!(f, "{}", crate::UNSPECIFIED_STR), - } - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_ice_candidate_type() { - let tests = vec![ - ("Unspecified", RTCIceCandidateType::Unspecified), - ("host", RTCIceCandidateType::Host), - ("srflx", RTCIceCandidateType::Srflx), - ("prflx", RTCIceCandidateType::Prflx), - ("relay", RTCIceCandidateType::Relay), - ]; - - for (type_string, expected_type) in tests { - let actual = RTCIceCandidateType::from(type_string); - assert_eq!(actual, expected_type); - } - } - - #[test] - fn test_ice_candidate_type_string() { - let tests = vec![ - (RTCIceCandidateType::Unspecified, "Unspecified"), - (RTCIceCandidateType::Host, "host"), - (RTCIceCandidateType::Srflx, "srflx"), - (RTCIceCandidateType::Prflx, "prflx"), - (RTCIceCandidateType::Relay, "relay"), - ]; - - for (ctype, expected_string) in tests { - assert_eq!(ctype.to_string(), expected_string); - } - } -} diff --git a/webrtc/src/ice_transport/ice_connection_state.rs b/webrtc/src/ice_transport/ice_connection_state.rs deleted file mode 100644 index d019854ae..000000000 --- a/webrtc/src/ice_transport/ice_connection_state.rs +++ /dev/null @@ -1,142 +0,0 @@ -use std::fmt; - -/// RTCIceConnectionState indicates signaling state of the ICE Connection. -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] -pub enum RTCIceConnectionState { - #[default] - Unspecified, - - /// ICEConnectionStateNew indicates that any of the ICETransports are - /// in the "new" state and none of them are in the "checking", "disconnected" - /// or "failed" state, or all ICETransports are in the "closed" state, or - /// there are no transports. - New, - - /// ICEConnectionStateChecking indicates that any of the ICETransports - /// are in the "checking" state and none of them are in the "disconnected" - /// or "failed" state. 
- Checking, - - /// ICEConnectionStateConnected indicates that all ICETransports are - /// in the "connected", "completed" or "closed" state and at least one of - /// them is in the "connected" state. - Connected, - - /// ICEConnectionStateCompleted indicates that all ICETransports are - /// in the "completed" or "closed" state and at least one of them is in the - /// "completed" state. - Completed, - - /// ICEConnectionStateDisconnected indicates that any of the - /// ICETransports are in the "disconnected" state and none of them are - /// in the "failed" state. - Disconnected, - - /// ICEConnectionStateFailed indicates that any of the ICETransports - /// are in the "failed" state. - Failed, - - /// ICEConnectionStateClosed indicates that the PeerConnection's - /// isClosed is true. - Closed, -} - -const ICE_CONNECTION_STATE_NEW_STR: &str = "new"; -const ICE_CONNECTION_STATE_CHECKING_STR: &str = "checking"; -const ICE_CONNECTION_STATE_CONNECTED_STR: &str = "connected"; -const ICE_CONNECTION_STATE_COMPLETED_STR: &str = "completed"; -const ICE_CONNECTION_STATE_DISCONNECTED_STR: &str = "disconnected"; -const ICE_CONNECTION_STATE_FAILED_STR: &str = "failed"; -const ICE_CONNECTION_STATE_CLOSED_STR: &str = "closed"; - -/// takes a string and converts it to iceconnection_state -impl From<&str> for RTCIceConnectionState { - fn from(raw: &str) -> Self { - match raw { - ICE_CONNECTION_STATE_NEW_STR => RTCIceConnectionState::New, - ICE_CONNECTION_STATE_CHECKING_STR => RTCIceConnectionState::Checking, - ICE_CONNECTION_STATE_CONNECTED_STR => RTCIceConnectionState::Connected, - ICE_CONNECTION_STATE_COMPLETED_STR => RTCIceConnectionState::Completed, - ICE_CONNECTION_STATE_DISCONNECTED_STR => RTCIceConnectionState::Disconnected, - ICE_CONNECTION_STATE_FAILED_STR => RTCIceConnectionState::Failed, - ICE_CONNECTION_STATE_CLOSED_STR => RTCIceConnectionState::Closed, - _ => RTCIceConnectionState::Unspecified, - } - } -} - -impl From for RTCIceConnectionState { - fn from(v: u8) -> Self { - match v { - 1 => RTCIceConnectionState::New, - 2 => RTCIceConnectionState::Checking, - 3 => RTCIceConnectionState::Connected, - 4 => RTCIceConnectionState::Completed, - 5 => RTCIceConnectionState::Disconnected, - 6 => RTCIceConnectionState::Failed, - 7 => RTCIceConnectionState::Closed, - _ => RTCIceConnectionState::Unspecified, - } - } -} - -impl fmt::Display for RTCIceConnectionState { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - RTCIceConnectionState::New => ICE_CONNECTION_STATE_NEW_STR, - RTCIceConnectionState::Checking => ICE_CONNECTION_STATE_CHECKING_STR, - RTCIceConnectionState::Connected => ICE_CONNECTION_STATE_CONNECTED_STR, - RTCIceConnectionState::Completed => ICE_CONNECTION_STATE_COMPLETED_STR, - RTCIceConnectionState::Disconnected => ICE_CONNECTION_STATE_DISCONNECTED_STR, - RTCIceConnectionState::Failed => ICE_CONNECTION_STATE_FAILED_STR, - RTCIceConnectionState::Closed => ICE_CONNECTION_STATE_CLOSED_STR, - RTCIceConnectionState::Unspecified => crate::UNSPECIFIED_STR, - }; - write!(f, "{s}") - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_new_ice_connection_state() { - let tests = vec![ - (crate::UNSPECIFIED_STR, RTCIceConnectionState::Unspecified), - ("new", RTCIceConnectionState::New), - ("checking", RTCIceConnectionState::Checking), - ("connected", RTCIceConnectionState::Connected), - ("completed", RTCIceConnectionState::Completed), - ("disconnected", RTCIceConnectionState::Disconnected), - ("failed", RTCIceConnectionState::Failed), - 
("closed", RTCIceConnectionState::Closed), - ]; - - for (state_string, expected_state) in tests { - assert_eq!( - RTCIceConnectionState::from(state_string), - expected_state, - "testCase: {expected_state}", - ); - } - } - - #[test] - fn test_ice_connection_state_string() { - let tests = vec![ - (RTCIceConnectionState::Unspecified, crate::UNSPECIFIED_STR), - (RTCIceConnectionState::New, "new"), - (RTCIceConnectionState::Checking, "checking"), - (RTCIceConnectionState::Connected, "connected"), - (RTCIceConnectionState::Completed, "completed"), - (RTCIceConnectionState::Disconnected, "disconnected"), - (RTCIceConnectionState::Failed, "failed"), - (RTCIceConnectionState::Closed, "closed"), - ]; - - for (state, expected_string) in tests { - assert_eq!(state.to_string(), expected_string) - } - } -} diff --git a/webrtc/src/ice_transport/ice_credential_type.rs b/webrtc/src/ice_transport/ice_credential_type.rs deleted file mode 100644 index 5db6c6cd2..000000000 --- a/webrtc/src/ice_transport/ice_credential_type.rs +++ /dev/null @@ -1,74 +0,0 @@ -use std::fmt; - -use serde::{Deserialize, Serialize}; - -/// ICECredentialType indicates the type of credentials used to connect to -/// an ICE server. -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] -pub enum RTCIceCredentialType { - #[default] - Unspecified, - - /// ICECredential::Password describes username and password based - /// credentials as described in . - Password, - - /// ICECredential::Oauth describes token based credential as described - /// in . - /// Not supported in WebRTC 1.0 spec - Oauth, -} - -const ICE_CREDENTIAL_TYPE_PASSWORD_STR: &str = "password"; -const ICE_CREDENTIAL_TYPE_OAUTH_STR: &str = "oauth"; - -impl From<&str> for RTCIceCredentialType { - fn from(raw: &str) -> Self { - match raw { - ICE_CREDENTIAL_TYPE_PASSWORD_STR => RTCIceCredentialType::Password, - ICE_CREDENTIAL_TYPE_OAUTH_STR => RTCIceCredentialType::Oauth, - _ => RTCIceCredentialType::Unspecified, - } - } -} - -impl fmt::Display for RTCIceCredentialType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self { - RTCIceCredentialType::Password => write!(f, "{ICE_CREDENTIAL_TYPE_PASSWORD_STR}"), - RTCIceCredentialType::Oauth => write!(f, "{ICE_CREDENTIAL_TYPE_OAUTH_STR}"), - _ => write!(f, "{}", crate::UNSPECIFIED_STR), - } - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_new_ice_credential_type() { - let tests = vec![ - ("Unspecified", RTCIceCredentialType::Unspecified), - ("password", RTCIceCredentialType::Password), - ("oauth", RTCIceCredentialType::Oauth), - ]; - - for (ct_str, expected_ct) in tests { - assert_eq!(RTCIceCredentialType::from(ct_str), expected_ct); - } - } - - #[test] - fn test_ice_credential_type_string() { - let tests = vec![ - (RTCIceCredentialType::Unspecified, "Unspecified"), - (RTCIceCredentialType::Password, "password"), - (RTCIceCredentialType::Oauth, "oauth"), - ]; - - for (ct, expected_string) in tests { - assert_eq!(ct.to_string(), expected_string); - } - } -} diff --git a/webrtc/src/ice_transport/ice_gatherer.rs b/webrtc/src/ice_transport/ice_gatherer.rs deleted file mode 100644 index d247e3bd1..000000000 --- a/webrtc/src/ice_transport/ice_gatherer.rs +++ /dev/null @@ -1,410 +0,0 @@ -use std::collections::HashMap; -use std::future::Future; -use std::pin::Pin; -use std::sync::atomic::Ordering; -use std::sync::Arc; - -use arc_swap::ArcSwapOption; -use ice::agent::Agent; -use ice::candidate::{Candidate, CandidateType}; -use ice::url::Url; -use 
portable_atomic::AtomicU8; -use tokio::sync::Mutex; - -use crate::api::setting_engine::SettingEngine; -use crate::error::{Error, Result}; -use crate::ice_transport::ice_candidate::*; -use crate::ice_transport::ice_candidate_type::RTCIceCandidateType; -use crate::ice_transport::ice_gatherer_state::RTCIceGathererState; -use crate::ice_transport::ice_parameters::RTCIceParameters; -use crate::ice_transport::ice_server::RTCIceServer; -use crate::peer_connection::policy::ice_transport_policy::RTCIceTransportPolicy; -use crate::stats::stats_collector::StatsCollector; -use crate::stats::SourceStatsType::*; -use crate::stats::{ICECandidatePairStats, StatsReportType}; - -/// ICEGatherOptions provides options relating to the gathering of ICE candidates. -#[derive(Default, Debug, Clone)] -pub struct RTCIceGatherOptions { - pub ice_servers: Vec, - pub ice_gather_policy: RTCIceTransportPolicy, -} - -pub type OnLocalCandidateHdlrFn = Box< - dyn (FnMut(Option) -> Pin + Send + 'static>>) - + Send - + Sync, ->; - -pub type OnICEGathererStateChangeHdlrFn = Box< - dyn (FnMut(RTCIceGathererState) -> Pin + Send + 'static>>) - + Send - + Sync, ->; - -pub type OnGatheringCompleteHdlrFn = - Box Pin + Send + 'static>>) + Send + Sync>; - -/// ICEGatherer gathers local host, server reflexive and relay -/// candidates, as well as enabling the retrieval of local Interactive -/// Connectivity Establishment (ICE) parameters which can be -/// exchanged in signaling. -#[derive(Default)] -pub struct RTCIceGatherer { - pub(crate) validated_servers: Vec, - pub(crate) gather_policy: RTCIceTransportPolicy, - pub(crate) setting_engine: Arc, - - pub(crate) state: Arc, //ICEGathererState, - pub(crate) agent: Mutex>>, - - pub(crate) on_local_candidate_handler: Arc>>, - pub(crate) on_state_change_handler: Arc>>, - - // Used for gathering_complete_promise - pub(crate) on_gathering_complete_handler: Arc>>, -} - -impl RTCIceGatherer { - pub(crate) fn new( - validated_servers: Vec, - gather_policy: RTCIceTransportPolicy, - setting_engine: Arc, - ) -> Self { - RTCIceGatherer { - gather_policy, - validated_servers, - setting_engine, - state: Arc::new(AtomicU8::new(RTCIceGathererState::New as u8)), - ..Default::default() - } - } - - pub(crate) async fn create_agent(&self) -> Result<()> { - // NOTE: A lock is held for the duration of this function in order to - // avoid potential double-agent creations. Care should be taken to - // ensure we do not do anything expensive other than the actual agent - // creation in this function. 
- let mut agent = self.agent.lock().await; - - if agent.is_some() || self.state() != RTCIceGathererState::New { - return Ok(()); - } - - let mut candidate_types = vec![]; - if self.setting_engine.candidates.ice_lite { - candidate_types.push(ice::candidate::CandidateType::Host); - } else if self.gather_policy == RTCIceTransportPolicy::Relay { - candidate_types.push(ice::candidate::CandidateType::Relay); - } - - let nat_1to1_cand_type = match self.setting_engine.candidates.nat_1to1_ip_candidate_type { - RTCIceCandidateType::Host => CandidateType::Host, - RTCIceCandidateType::Srflx => CandidateType::ServerReflexive, - _ => CandidateType::Unspecified, - }; - - let mdns_mode = self.setting_engine.candidates.multicast_dns_mode; - - let mut config = ice::agent::agent_config::AgentConfig { - udp_network: self.setting_engine.udp_network.clone(), - lite: self.setting_engine.candidates.ice_lite, - urls: self.validated_servers.clone(), - disconnected_timeout: self.setting_engine.timeout.ice_disconnected_timeout, - failed_timeout: self.setting_engine.timeout.ice_failed_timeout, - keepalive_interval: self.setting_engine.timeout.ice_keepalive_interval, - candidate_types, - host_acceptance_min_wait: self.setting_engine.timeout.ice_host_acceptance_min_wait, - srflx_acceptance_min_wait: self.setting_engine.timeout.ice_srflx_acceptance_min_wait, - prflx_acceptance_min_wait: self.setting_engine.timeout.ice_prflx_acceptance_min_wait, - relay_acceptance_min_wait: self.setting_engine.timeout.ice_relay_acceptance_min_wait, - interface_filter: self.setting_engine.candidates.interface_filter.clone(), - ip_filter: self.setting_engine.candidates.ip_filter.clone(), - nat_1to1_ips: self.setting_engine.candidates.nat_1to1_ips.clone(), - nat_1to1_ip_candidate_type: nat_1to1_cand_type, - net: self.setting_engine.vnet.clone(), - multicast_dns_mode: mdns_mode, - multicast_dns_host_name: self - .setting_engine - .candidates - .multicast_dns_host_name - .clone(), - local_ufrag: self.setting_engine.candidates.username_fragment.clone(), - local_pwd: self.setting_engine.candidates.password.clone(), - //TODO: TCPMux: self.setting_engine.iceTCPMux, - //TODO: ProxyDialer: self.setting_engine.iceProxyDialer, - ..Default::default() - }; - - let requested_network_types = if self.setting_engine.candidates.ice_network_types.is_empty() - { - ice::network_type::supported_network_types() - } else { - self.setting_engine.candidates.ice_network_types.clone() - }; - - config.network_types.extend(requested_network_types); - - *agent = Some(Arc::new(ice::agent::Agent::new(config).await?)); - - Ok(()) - } - - /// Gather ICE candidates. 
- pub async fn gather(&self) -> Result<()> { - self.create_agent().await?; - self.set_state(RTCIceGathererState::Gathering).await; - - if let Some(agent) = self.get_agent().await { - let state = Arc::clone(&self.state); - let on_local_candidate_handler = Arc::clone(&self.on_local_candidate_handler); - let on_state_change_handler = Arc::clone(&self.on_state_change_handler); - let on_gathering_complete_handler = Arc::clone(&self.on_gathering_complete_handler); - - agent.on_candidate(Box::new( - move |candidate: Option>| { - let state_clone = Arc::clone(&state); - let on_local_candidate_handler_clone = Arc::clone(&on_local_candidate_handler); - let on_state_change_handler_clone = Arc::clone(&on_state_change_handler); - let on_gathering_complete_handler_clone = - Arc::clone(&on_gathering_complete_handler); - - Box::pin(async move { - if let Some(cand) = candidate { - if let Some(handler) = &*on_local_candidate_handler_clone.load() { - let mut f = handler.lock().await; - f(Some(RTCIceCandidate::from(&cand))).await; - } - } else { - state_clone - .store(RTCIceGathererState::Complete as u8, Ordering::SeqCst); - - if let Some(handler) = &*on_state_change_handler_clone.load() { - let mut f = handler.lock().await; - f(RTCIceGathererState::Complete).await; - } - - if let Some(handler) = &*on_gathering_complete_handler_clone.load() { - let mut f = handler.lock().await; - f().await; - } - - if let Some(handler) = &*on_local_candidate_handler_clone.load() { - let mut f = handler.lock().await; - f(None).await; - } - } - }) - }, - )); - - agent.gather_candidates()?; - } - - Ok(()) - } - - /// Close prunes all local candidates, and closes the ports. - pub async fn close(&self) -> Result<()> { - self.set_state(RTCIceGathererState::Closed).await; - - let agent = { - let mut agent_opt = self.agent.lock().await; - agent_opt.take() - }; - - if let Some(agent) = agent { - agent.close().await?; - } - - Ok(()) - } - - /// get_local_parameters returns the ICE parameters of the ICEGatherer. - pub async fn get_local_parameters(&self) -> Result { - self.create_agent().await?; - - let (frag, pwd) = if let Some(agent) = self.get_agent().await { - agent.get_local_user_credentials().await - } else { - return Err(Error::ErrICEAgentNotExist); - }; - - Ok(RTCIceParameters { - username_fragment: frag, - password: pwd, - ice_lite: false, - }) - } - - /// get_local_candidates returns the sequence of valid local candidates associated with the ICEGatherer. - pub async fn get_local_candidates(&self) -> Result> { - self.create_agent().await?; - - let ice_candidates = if let Some(agent) = self.get_agent().await { - agent.get_local_candidates().await? - } else { - return Err(Error::ErrICEAgentNotExist); - }; - - Ok(rtc_ice_candidates_from_ice_candidates(&ice_candidates)) - } - - /// on_local_candidate sets an event handler which fires when a new local ICE candidate is available - /// Take note that the handler is gonna be called with a nil pointer when gathering is finished. 
- pub fn on_local_candidate(&self, f: OnLocalCandidateHdlrFn) { - self.on_local_candidate_handler - .store(Some(Arc::new(Mutex::new(f)))); - } - - /// on_state_change sets an event handler which fires any time the ICEGatherer changes - pub fn on_state_change(&self, f: OnICEGathererStateChangeHdlrFn) { - self.on_state_change_handler - .store(Some(Arc::new(Mutex::new(f)))); - } - - /// on_gathering_complete sets an event handler which fires any time the ICEGatherer changes - pub fn on_gathering_complete(&self, f: OnGatheringCompleteHdlrFn) { - self.on_gathering_complete_handler - .store(Some(Arc::new(Mutex::new(f)))); - } - - /// State indicates the current state of the ICE gatherer. - pub fn state(&self) -> RTCIceGathererState { - self.state.load(Ordering::SeqCst).into() - } - - pub async fn set_state(&self, s: RTCIceGathererState) { - self.state.store(s as u8, Ordering::SeqCst); - - if let Some(handler) = &*self.on_state_change_handler.load() { - let mut f = handler.lock().await; - f(s).await; - } - } - - pub(crate) async fn get_agent(&self) -> Option> { - let agent = self.agent.lock().await; - agent.clone() - } - - pub(crate) async fn collect_stats(&self, collector: &StatsCollector) { - if let Some(agent) = self.get_agent().await { - let mut reports = HashMap::new(); - - for stats in agent.get_candidate_pairs_stats().await { - let stats: ICECandidatePairStats = stats.into(); - reports.insert(stats.id.clone(), StatsReportType::CandidatePair(stats)); - } - - for stats in agent.get_local_candidates_stats().await { - reports.insert( - stats.id.clone(), - StatsReportType::from(LocalCandidate(stats)), - ); - } - - for stats in agent.get_remote_candidates_stats().await { - reports.insert( - stats.id.clone(), - StatsReportType::from(RemoteCandidate(stats)), - ); - } - - collector.merge(reports); - } - } -} - -#[cfg(test)] -mod test { - use tokio::sync::mpsc; - - use super::*; - use crate::api::APIBuilder; - use crate::ice_transport::ice_gatherer::RTCIceGatherOptions; - use crate::ice_transport::ice_server::RTCIceServer; - - #[tokio::test] - async fn test_new_ice_gatherer_success() -> Result<()> { - let opts = RTCIceGatherOptions { - ice_servers: vec![RTCIceServer { - urls: vec!["stun:stun.l.google.com:19302".to_owned()], - ..Default::default() - }], - ..Default::default() - }; - - let gatherer = APIBuilder::new().build().new_ice_gatherer(opts)?; - - assert_eq!( - gatherer.state(), - RTCIceGathererState::New, - "Expected gathering state new" - ); - - let (gather_finished_tx, mut gather_finished_rx) = mpsc::channel::<()>(1); - let gather_finished_tx = Arc::new(Mutex::new(Some(gather_finished_tx))); - gatherer.on_local_candidate(Box::new(move |c: Option| { - let gather_finished_tx_clone = Arc::clone(&gather_finished_tx); - Box::pin(async move { - if c.is_none() { - let mut tx = gather_finished_tx_clone.lock().await; - tx.take(); - } - }) - })); - - gatherer.gather().await?; - - let _ = gather_finished_rx.recv().await; - - let params = gatherer.get_local_parameters().await?; - - assert!( - !params.username_fragment.is_empty() && !params.password.is_empty(), - "Empty local username or password frag" - ); - - let candidates = gatherer.get_local_candidates().await?; - - assert!(!candidates.is_empty(), "No candidates gathered"); - - gatherer.close().await?; - - Ok(()) - } - - #[tokio::test] - async fn test_ice_gather_mdns_candidate_gathering() -> Result<()> { - let mut s = SettingEngine::default(); - s.set_ice_multicast_dns_mode(ice::mdns::MulticastDnsMode::QueryAndGather); - - let gatherer = 
APIBuilder::new() - .with_setting_engine(s) - .build() - .new_ice_gatherer(RTCIceGatherOptions::default())?; - - let (done_tx, mut done_rx) = mpsc::channel::<()>(1); - let done_tx = Arc::new(Mutex::new(Some(done_tx))); - gatherer.on_local_candidate(Box::new(move |c: Option| { - let done_tx_clone = Arc::clone(&done_tx); - Box::pin(async move { - if let Some(c) = c { - if c.address.ends_with(".local") { - let mut tx = done_tx_clone.lock().await; - tx.take(); - } - } - }) - })); - - gatherer.gather().await?; - - let _ = done_rx.recv().await; - - gatherer.close().await?; - - Ok(()) - } -} diff --git a/webrtc/src/ice_transport/ice_gatherer_state.rs b/webrtc/src/ice_transport/ice_gatherer_state.rs deleted file mode 100644 index 7b24e9968..000000000 --- a/webrtc/src/ice_transport/ice_gatherer_state.rs +++ /dev/null @@ -1,88 +0,0 @@ -use std::fmt; - -/// ICEGathererState represents the current state of the ICE gatherer. -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] -pub enum RTCIceGathererState { - #[default] - Unspecified, - - /// ICEGathererStateNew indicates object has been created but - /// gather() has not been called. - New, - - /// ICEGathererStateGathering indicates gather() has been called, - /// and the ICEGatherer is in the process of gathering candidates. - Gathering, - - /// ICEGathererStateComplete indicates the ICEGatherer has completed gathering. - Complete, - - /// ICEGathererStateClosed indicates the closed state can only be entered - /// when the ICEGatherer has been closed intentionally by calling close(). - Closed, -} - -const ICE_GATHERED_STATE_NEW_STR: &str = "new"; -const ICE_GATHERED_STATE_GATHERING_STR: &str = "gathering"; -const ICE_GATHERED_STATE_COMPLETE_STR: &str = "complete"; -const ICE_GATHERED_STATE_CLOSED_STR: &str = "closed"; - -impl From<&str> for RTCIceGathererState { - fn from(raw: &str) -> Self { - match raw { - ICE_GATHERED_STATE_NEW_STR => RTCIceGathererState::New, - ICE_GATHERED_STATE_GATHERING_STR => RTCIceGathererState::Gathering, - ICE_GATHERED_STATE_COMPLETE_STR => RTCIceGathererState::Complete, - ICE_GATHERED_STATE_CLOSED_STR => RTCIceGathererState::Closed, - _ => RTCIceGathererState::Unspecified, - } - } -} - -impl From for RTCIceGathererState { - fn from(v: u8) -> Self { - match v { - 1 => RTCIceGathererState::New, - 2 => RTCIceGathererState::Gathering, - 3 => RTCIceGathererState::Complete, - 4 => RTCIceGathererState::Closed, - _ => RTCIceGathererState::Unspecified, - } - } -} - -impl fmt::Display for RTCIceGathererState { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self { - RTCIceGathererState::New => write!(f, "{ICE_GATHERED_STATE_NEW_STR}"), - RTCIceGathererState::Gathering => write!(f, "{ICE_GATHERED_STATE_GATHERING_STR}"), - RTCIceGathererState::Complete => { - write!(f, "{ICE_GATHERED_STATE_COMPLETE_STR}") - } - RTCIceGathererState::Closed => { - write!(f, "{ICE_GATHERED_STATE_CLOSED_STR}") - } - _ => write!(f, "{}", crate::UNSPECIFIED_STR), - } - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_ice_gatherer_state_string() { - let tests = vec![ - (RTCIceGathererState::Unspecified, "Unspecified"), - (RTCIceGathererState::New, "new"), - (RTCIceGathererState::Gathering, "gathering"), - (RTCIceGathererState::Complete, "complete"), - (RTCIceGathererState::Closed, "closed"), - ]; - - for (state, expected_string) in tests { - assert_eq!(state.to_string(), expected_string); - } - } -} diff --git a/webrtc/src/ice_transport/ice_gathering_state.rs b/webrtc/src/ice_transport/ice_gathering_state.rs 
deleted file mode 100644 index fa043312d..000000000 --- a/webrtc/src/ice_transport/ice_gathering_state.rs +++ /dev/null @@ -1,83 +0,0 @@ -use std::fmt; - -/// ICEGatheringState describes the state of the candidate gathering process. -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] -pub enum RTCIceGatheringState { - #[default] - Unspecified, - - /// ICEGatheringStateNew indicates that any of the ICETransports are - /// in the "new" gathering state and none of the transports are in the - /// "gathering" state, or there are no transports. - New, - - /// ICEGatheringStateGathering indicates that any of the ICETransports - /// are in the "gathering" state. - Gathering, - - /// ICEGatheringStateComplete indicates that at least one ICETransport - /// exists, and all ICETransports are in the "completed" gathering state. - Complete, -} - -const ICE_GATHERING_STATE_NEW_STR: &str = "new"; -const ICE_GATHERING_STATE_GATHERING_STR: &str = "gathering"; -const ICE_GATHERING_STATE_COMPLETE_STR: &str = "complete"; - -/// takes a string and converts it to ICEGatheringState -impl From<&str> for RTCIceGatheringState { - fn from(raw: &str) -> Self { - match raw { - ICE_GATHERING_STATE_NEW_STR => RTCIceGatheringState::New, - ICE_GATHERING_STATE_GATHERING_STR => RTCIceGatheringState::Gathering, - ICE_GATHERING_STATE_COMPLETE_STR => RTCIceGatheringState::Complete, - _ => RTCIceGatheringState::Unspecified, - } - } -} - -impl fmt::Display for RTCIceGatheringState { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self { - RTCIceGatheringState::New => write!(f, "{ICE_GATHERING_STATE_NEW_STR}"), - RTCIceGatheringState::Gathering => write!(f, "{ICE_GATHERING_STATE_GATHERING_STR}"), - RTCIceGatheringState::Complete => { - write!(f, "{ICE_GATHERING_STATE_COMPLETE_STR}") - } - _ => write!(f, "{}", crate::UNSPECIFIED_STR), - } - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_new_ice_gathering_state() { - let tests = vec![ - ("Unspecified", RTCIceGatheringState::Unspecified), - ("new", RTCIceGatheringState::New), - ("gathering", RTCIceGatheringState::Gathering), - ("complete", RTCIceGatheringState::Complete), - ]; - - for (state_string, expected_state) in tests { - assert_eq!(RTCIceGatheringState::from(state_string), expected_state); - } - } - - #[test] - fn test_ice_gathering_state_string() { - let tests = vec![ - (RTCIceGatheringState::Unspecified, "Unspecified"), - (RTCIceGatheringState::New, "new"), - (RTCIceGatheringState::Gathering, "gathering"), - (RTCIceGatheringState::Complete, "complete"), - ]; - - for (state, expected_string) in tests { - assert_eq!(state.to_string(), expected_string); - } - } -} diff --git a/webrtc/src/ice_transport/ice_parameters.rs b/webrtc/src/ice_transport/ice_parameters.rs deleted file mode 100644 index 048e359b9..000000000 --- a/webrtc/src/ice_transport/ice_parameters.rs +++ /dev/null @@ -1,10 +0,0 @@ -use serde::{Deserialize, Serialize}; - -/// ICEParameters includes the ICE username fragment -/// and password and other ICE-related parameters. 
-#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct RTCIceParameters { - pub username_fragment: String, - pub password: String, - pub ice_lite: bool, -} diff --git a/webrtc/src/ice_transport/ice_protocol.rs b/webrtc/src/ice_transport/ice_protocol.rs deleted file mode 100644 index 308912505..000000000 --- a/webrtc/src/ice_transport/ice_protocol.rs +++ /dev/null @@ -1,79 +0,0 @@ -use std::fmt; - -use serde::{Deserialize, Serialize}; - -/// ICEProtocol indicates the transport protocol type that is used in the -/// ice.URL structure. -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub enum RTCIceProtocol { - #[default] - Unspecified, - - /// UDP indicates the URL uses a UDP transport. - #[serde(rename = "udp")] - Udp, - - /// TCP indicates the URL uses a TCP transport. - #[serde(rename = "tcp")] - Tcp, -} - -const ICE_PROTOCOL_UDP_STR: &str = "udp"; -const ICE_PROTOCOL_TCP_STR: &str = "tcp"; - -/// takes a string and converts it to ICEProtocol -impl From<&str> for RTCIceProtocol { - fn from(raw: &str) -> Self { - if raw.to_uppercase() == ICE_PROTOCOL_UDP_STR.to_uppercase() { - RTCIceProtocol::Udp - } else if raw.to_uppercase() == ICE_PROTOCOL_TCP_STR.to_uppercase() { - RTCIceProtocol::Tcp - } else { - RTCIceProtocol::Unspecified - } - } -} - -impl fmt::Display for RTCIceProtocol { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self { - RTCIceProtocol::Udp => write!(f, "{ICE_PROTOCOL_UDP_STR}"), - RTCIceProtocol::Tcp => write!(f, "{ICE_PROTOCOL_TCP_STR}"), - _ => write!(f, "{}", crate::UNSPECIFIED_STR), - } - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_new_ice_protocol() { - let tests = vec![ - ("Unspecified", RTCIceProtocol::Unspecified), - ("udp", RTCIceProtocol::Udp), - ("tcp", RTCIceProtocol::Tcp), - ("UDP", RTCIceProtocol::Udp), - ("TCP", RTCIceProtocol::Tcp), - ]; - - for (proto_string, expected_proto) in tests { - let actual = RTCIceProtocol::from(proto_string); - assert_eq!(actual, expected_proto); - } - } - - #[test] - fn test_ice_protocol_string() { - let tests = vec![ - (RTCIceProtocol::Unspecified, "Unspecified"), - (RTCIceProtocol::Udp, "udp"), - (RTCIceProtocol::Tcp, "tcp"), - ]; - - for (proto, expected_string) in tests { - assert_eq!(proto.to_string(), expected_string); - } - } -} diff --git a/webrtc/src/ice_transport/ice_role.rs b/webrtc/src/ice_transport/ice_role.rs deleted file mode 100644 index 699dd4b45..000000000 --- a/webrtc/src/ice_transport/ice_role.rs +++ /dev/null @@ -1,73 +0,0 @@ -use std::fmt; - -/// ICERole describes the role ice.Agent is playing in selecting the -/// preferred the candidate pair. -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] -pub enum RTCIceRole { - #[default] - Unspecified, - - /// ICERoleControlling indicates that the ICE agent that is responsible - /// for selecting the final choice of candidate pairs and signaling them - /// through STUN and an updated offer, if needed. In any session, one agent - /// is always controlling. The other is the controlled agent. - Controlling, - - /// ICERoleControlled indicates that an ICE agent that waits for the - /// controlling agent to select the final choice of candidate pairs. 
- Controlled, -} - -const ICE_ROLE_CONTROLLING_STR: &str = "controlling"; -const ICE_ROLE_CONTROLLED_STR: &str = "controlled"; - -impl From<&str> for RTCIceRole { - fn from(raw: &str) -> Self { - match raw { - ICE_ROLE_CONTROLLING_STR => RTCIceRole::Controlling, - ICE_ROLE_CONTROLLED_STR => RTCIceRole::Controlled, - _ => RTCIceRole::Unspecified, - } - } -} - -impl fmt::Display for RTCIceRole { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self { - RTCIceRole::Controlling => write!(f, "{ICE_ROLE_CONTROLLING_STR}"), - RTCIceRole::Controlled => write!(f, "{ICE_ROLE_CONTROLLED_STR}"), - _ => write!(f, "{}", crate::UNSPECIFIED_STR), - } - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_new_ice_role() { - let tests = vec![ - ("Unspecified", RTCIceRole::Unspecified), - ("controlling", RTCIceRole::Controlling), - ("controlled", RTCIceRole::Controlled), - ]; - - for (role_string, expected_role) in tests { - assert_eq!(RTCIceRole::from(role_string), expected_role); - } - } - - #[test] - fn test_ice_role_string() { - let tests = vec![ - (RTCIceRole::Unspecified, "Unspecified"), - (RTCIceRole::Controlling, "controlling"), - (RTCIceRole::Controlled, "controlled"), - ]; - - for (proto, expected_string) in tests { - assert_eq!(proto.to_string(), expected_string); - } - } -} diff --git a/webrtc/src/ice_transport/ice_server.rs b/webrtc/src/ice_transport/ice_server.rs deleted file mode 100644 index c43e2d2b8..000000000 --- a/webrtc/src/ice_transport/ice_server.rs +++ /dev/null @@ -1,173 +0,0 @@ -use serde::{Deserialize, Serialize}; - -use crate::error::{Error, Result}; -use crate::ice_transport::ice_credential_type::RTCIceCredentialType; - -/// ICEServer describes a single STUN and TURN server that can be used by -/// the ICEAgent to establish a connection with a peer. -#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize, Hash)] -pub struct RTCIceServer { - pub urls: Vec, - pub username: String, - pub credential: String, - pub credential_type: RTCIceCredentialType, -} - -impl RTCIceServer { - pub(crate) fn parse_url(&self, url_str: &str) -> Result { - Ok(ice::url::Url::parse_url(url_str)?) 
- } - - pub(crate) fn validate(&self) -> Result<()> { - self.urls()?; - Ok(()) - } - - pub(crate) fn urls(&self) -> Result> { - let mut urls = vec![]; - - for url_str in &self.urls { - let mut url = self.parse_url(url_str)?; - if url.scheme == ice::url::SchemeType::Turn || url.scheme == ice::url::SchemeType::Turns - { - // https://www.w3.org/TR/webrtc/#set-the-configuration (step #11.3.2) - if self.username.is_empty() || self.credential.is_empty() { - return Err(Error::ErrNoTurnCredentials); - } - url.username.clone_from(&self.username); - - match self.credential_type { - RTCIceCredentialType::Password => { - // https://www.w3.org/TR/webrtc/#set-the-configuration (step #11.3.3) - url.password.clone_from(&self.credential); - } - RTCIceCredentialType::Oauth => { - // https://www.w3.org/TR/webrtc/#set-the-configuration (step #11.3.4) - /*if _, ok: = s.Credential.(OAuthCredential); !ok { - return nil, - &rtcerr.InvalidAccessError{Err: ErrTurnCredentials - } - }*/ - } - _ => return Err(Error::ErrTurnCredentials), - }; - } - - urls.push(url); - } - - Ok(urls) - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_ice_server_validate_success() { - let tests = vec![ - ( - RTCIceServer { - urls: vec!["turn:192.158.29.39?transport=udp".to_owned()], - username: "unittest".to_owned(), - credential: "placeholder".to_owned(), - credential_type: RTCIceCredentialType::Password, - }, - true, - ), - ( - RTCIceServer { - urls: vec!["turn:[2001:db8:1234:5678::1]?transport=udp".to_owned()], - username: "unittest".to_owned(), - credential: "placeholder".to_owned(), - credential_type: RTCIceCredentialType::Password, - }, - true, - ), - /*TODO:(ICEServer{ - URLs: []string{"turn:192.158.29.39?transport=udp"}, - Username: "unittest".to_owned(), - Credential: OAuthCredential{ - MACKey: "WmtzanB3ZW9peFhtdm42NzUzNG0=", - AccessToken: "AAwg3kPHWPfvk9bDFL936wYvkoctMADzQ5VhNDgeMR3+ZlZ35byg972fW8QjpEl7bx91YLBPFsIhsxloWcXPhA==", - }, - CredentialType: ICECredentialTypeOauth, - }, true),*/ - ]; - - for (ice_server, expected_validate) in tests { - let result = ice_server.urls(); - assert_eq!(result.is_ok(), expected_validate); - } - } - - #[test] - fn test_ice_server_validate_failure() { - let tests = vec![ - ( - RTCIceServer { - urls: vec!["turn:192.158.29.39?transport=udp".to_owned()], - ..Default::default() - }, - Error::ErrNoTurnCredentials, - ), - ( - RTCIceServer { - urls: vec!["turn:192.158.29.39?transport=udp".to_owned()], - username: "unittest".to_owned(), - credential: String::new(), - credential_type: RTCIceCredentialType::Password, - }, - Error::ErrNoTurnCredentials, - ), - ( - RTCIceServer { - urls: vec!["turn:192.158.29.39?transport=udp".to_owned()], - username: "unittest".to_owned(), - credential: String::new(), - credential_type: RTCIceCredentialType::Oauth, - }, - Error::ErrNoTurnCredentials, - ), - ( - RTCIceServer { - urls: vec!["turn:192.158.29.39?transport=udp".to_owned()], - username: "unittest".to_owned(), - credential: String::new(), - credential_type: RTCIceCredentialType::Unspecified, - }, - Error::ErrNoTurnCredentials, - ), - ]; - - for (ice_server, expected_err) in tests { - if let Err(err) = ice_server.urls() { - assert_eq!(err, expected_err, "{ice_server:?} with err {err:?}"); - } else { - panic!("expected error, but got ok"); - } - } - } - - #[test] - fn test_ice_server_validate_failure_err_stun_query() { - let tests = vec![( - RTCIceServer { - urls: vec!["stun:google.de?transport=udp".to_owned()], - username: "unittest".to_owned(), - credential: String::new(), - 
credential_type: RTCIceCredentialType::Oauth, - }, - ice::Error::ErrStunQuery, - )]; - - for (ice_server, expected_err) in tests { - if let Err(err) = ice_server.urls() { - assert_eq!(err, expected_err, "{ice_server:?} with err {err:?}"); - } else { - panic!("expected error, but got ok"); - } - } - } -} diff --git a/webrtc/src/ice_transport/ice_transport_state.rs b/webrtc/src/ice_transport/ice_transport_state.rs deleted file mode 100644 index 2abdedef4..000000000 --- a/webrtc/src/ice_transport/ice_transport_state.rs +++ /dev/null @@ -1,185 +0,0 @@ -use std::fmt; - -use ice::state::ConnectionState; - -/// ICETransportState represents the current state of the ICE transport. -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] -pub enum RTCIceTransportState { - #[default] - Unspecified, - - /// ICETransportStateNew indicates the ICETransport is waiting - /// for remote candidates to be supplied. - New, - - /// ICETransportStateChecking indicates the ICETransport has - /// received at least one remote candidate, and a local and remote - /// ICECandidateComplete dictionary was not added as the last candidate. - Checking, - - /// ICETransportStateConnected indicates the ICETransport has - /// received a response to an outgoing connectivity check, or has - /// received incoming DTLS/media after a successful response to an - /// incoming connectivity check, but is still checking other candidate - /// pairs to see if there is a better connection. - Connected, - - /// ICETransportStateCompleted indicates the ICETransport tested - /// all appropriate candidate pairs and at least one functioning - /// candidate pair has been found. - Completed, - - /// ICETransportStateFailed indicates the ICETransport the last - /// candidate was added and all appropriate candidate pairs have either - /// failed connectivity checks or have lost consent. - Failed, - - /// ICETransportStateDisconnected indicates the ICETransport has received - /// at least one local and remote candidate, but the final candidate was - /// received yet and all appropriate candidate pairs thus far have been - /// tested and failed. - Disconnected, - - /// ICETransportStateClosed indicates the ICETransport has shut down - /// and is no longer responding to STUN requests. 
- Closed, -} - -const ICE_TRANSPORT_STATE_NEW_STR: &str = "new"; -const ICE_TRANSPORT_STATE_CHECKING_STR: &str = "checking"; -const ICE_TRANSPORT_STATE_CONNECTED_STR: &str = "connected"; -const ICE_TRANSPORT_STATE_COMPLETED_STR: &str = "completed"; -const ICE_TRANSPORT_STATE_FAILED_STR: &str = "failed"; -const ICE_TRANSPORT_STATE_DISCONNECTED_STR: &str = "disconnected"; -const ICE_TRANSPORT_STATE_CLOSED_STR: &str = "closed"; - -impl From<&str> for RTCIceTransportState { - fn from(raw: &str) -> Self { - match raw { - ICE_TRANSPORT_STATE_NEW_STR => RTCIceTransportState::New, - ICE_TRANSPORT_STATE_CHECKING_STR => RTCIceTransportState::Checking, - ICE_TRANSPORT_STATE_CONNECTED_STR => RTCIceTransportState::Connected, - ICE_TRANSPORT_STATE_COMPLETED_STR => RTCIceTransportState::Completed, - ICE_TRANSPORT_STATE_FAILED_STR => RTCIceTransportState::Failed, - ICE_TRANSPORT_STATE_DISCONNECTED_STR => RTCIceTransportState::Disconnected, - ICE_TRANSPORT_STATE_CLOSED_STR => RTCIceTransportState::Closed, - _ => RTCIceTransportState::Unspecified, - } - } -} - -impl From for RTCIceTransportState { - fn from(v: u8) -> Self { - match v { - 1 => Self::New, - 2 => Self::Checking, - 3 => Self::Connected, - 4 => Self::Completed, - 5 => Self::Failed, - 6 => Self::Disconnected, - 7 => Self::Closed, - _ => Self::Unspecified, - } - } -} - -impl fmt::Display for RTCIceTransportState { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self { - RTCIceTransportState::New => write!(f, "{ICE_TRANSPORT_STATE_NEW_STR}"), - RTCIceTransportState::Checking => write!(f, "{ICE_TRANSPORT_STATE_CHECKING_STR}"), - RTCIceTransportState::Connected => { - write!(f, "{ICE_TRANSPORT_STATE_CONNECTED_STR}") - } - RTCIceTransportState::Completed => write!(f, "{ICE_TRANSPORT_STATE_COMPLETED_STR}"), - RTCIceTransportState::Failed => { - write!(f, "{ICE_TRANSPORT_STATE_FAILED_STR}") - } - RTCIceTransportState::Disconnected => { - write!(f, "{ICE_TRANSPORT_STATE_DISCONNECTED_STR}") - } - RTCIceTransportState::Closed => { - write!(f, "{ICE_TRANSPORT_STATE_CLOSED_STR}") - } - _ => write!(f, "{}", crate::UNSPECIFIED_STR), - } - } -} - -impl From for RTCIceTransportState { - fn from(raw: ConnectionState) -> Self { - match raw { - ConnectionState::New => RTCIceTransportState::New, - ConnectionState::Checking => RTCIceTransportState::Checking, - ConnectionState::Connected => RTCIceTransportState::Connected, - ConnectionState::Completed => RTCIceTransportState::Completed, - ConnectionState::Failed => RTCIceTransportState::Failed, - ConnectionState::Disconnected => RTCIceTransportState::Disconnected, - ConnectionState::Closed => RTCIceTransportState::Closed, - _ => RTCIceTransportState::Unspecified, - } - } -} - -impl RTCIceTransportState { - pub(crate) fn to_ice(self) -> ConnectionState { - match self { - RTCIceTransportState::New => ConnectionState::New, - RTCIceTransportState::Checking => ConnectionState::Checking, - RTCIceTransportState::Connected => ConnectionState::Connected, - RTCIceTransportState::Completed => ConnectionState::Completed, - RTCIceTransportState::Failed => ConnectionState::Failed, - RTCIceTransportState::Disconnected => ConnectionState::Disconnected, - RTCIceTransportState::Closed => ConnectionState::Closed, - _ => ConnectionState::Unspecified, - } - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_ice_transport_state_string() { - let tests = vec![ - (RTCIceTransportState::Unspecified, "Unspecified"), - (RTCIceTransportState::New, "new"), - (RTCIceTransportState::Checking, "checking"), - 
(RTCIceTransportState::Connected, "connected"), - (RTCIceTransportState::Completed, "completed"), - (RTCIceTransportState::Failed, "failed"), - (RTCIceTransportState::Disconnected, "disconnected"), - (RTCIceTransportState::Closed, "closed"), - ]; - - for (state, expected_string) in tests { - assert_eq!(state.to_string(), expected_string); - } - } - - #[test] - fn test_ice_transport_state_convert() { - let tests = vec![ - ( - RTCIceTransportState::Unspecified, - ConnectionState::Unspecified, - ), - (RTCIceTransportState::New, ConnectionState::New), - (RTCIceTransportState::Checking, ConnectionState::Checking), - (RTCIceTransportState::Connected, ConnectionState::Connected), - (RTCIceTransportState::Completed, ConnectionState::Completed), - (RTCIceTransportState::Failed, ConnectionState::Failed), - ( - RTCIceTransportState::Disconnected, - ConnectionState::Disconnected, - ), - (RTCIceTransportState::Closed, ConnectionState::Closed), - ]; - - for (native, ice_state) in tests { - assert_eq!(native.to_ice(), ice_state); - assert_eq!(native, RTCIceTransportState::from(ice_state)); - } - } -} diff --git a/webrtc/src/ice_transport/ice_transport_test.rs b/webrtc/src/ice_transport/ice_transport_test.rs deleted file mode 100644 index 7f6511204..000000000 --- a/webrtc/src/ice_transport/ice_transport_test.rs +++ /dev/null @@ -1,122 +0,0 @@ -use portable_atomic::AtomicU32; -use tokio::time::Duration; -use waitgroup::WaitGroup; - -use super::*; -use crate::api::media_engine::MediaEngine; -use crate::api::APIBuilder; -use crate::error::Result; -use crate::ice_transport::ice_connection_state::RTCIceConnectionState; -use crate::peer_connection::peer_connection_state::RTCPeerConnectionState; -use crate::peer_connection::peer_connection_test::{ - close_pair_now, new_pair, signal_pair, until_connection_state, -}; - -#[tokio::test] -async fn test_ice_transport_on_selected_candidate_pair_change() -> Result<()> { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - let (mut pc_offer, mut pc_answer) = new_pair(&api).await?; - - let (ice_complete_tx, mut ice_complete_rx) = mpsc::channel::<()>(1); - let ice_complete_tx = Arc::new(Mutex::new(Some(ice_complete_tx))); - pc_answer.on_ice_connection_state_change(Box::new(move |ice_state: RTCIceConnectionState| { - let ice_complete_tx2 = Arc::clone(&ice_complete_tx); - Box::pin(async move { - if ice_state == RTCIceConnectionState::Connected { - tokio::time::sleep(Duration::from_secs(1)).await; - let mut done = ice_complete_tx2.lock().await; - done.take(); - } - }) - })); - - let sender_called_candidate_change = Arc::new(AtomicU32::new(0)); - let sender_called_candidate_change2 = Arc::clone(&sender_called_candidate_change); - pc_offer - .sctp() - .transport() - .ice_transport() - .on_selected_candidate_pair_change(Box::new(move |_: RTCIceCandidatePair| { - sender_called_candidate_change2.store(1, Ordering::SeqCst); - Box::pin(async {}) - })); - - signal_pair(&mut pc_offer, &mut pc_answer).await?; - - let _ = ice_complete_rx.recv().await; - assert_eq!( - sender_called_candidate_change.load(Ordering::SeqCst), - 1, - "Sender ICETransport OnSelectedCandidateChange was never called" - ); - - close_pair_now(&pc_offer, &pc_answer).await; - - Ok(()) -} - -#[tokio::test] -async fn test_ice_transport_get_selected_candidate_pair() -> Result<()> { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - let (mut offerer, mut 
answerer) = new_pair(&api).await?; - - let peer_connection_connected = WaitGroup::new(); - until_connection_state( - &mut offerer, - &peer_connection_connected, - RTCPeerConnectionState::Connected, - ) - .await; - until_connection_state( - &mut answerer, - &peer_connection_connected, - RTCPeerConnectionState::Connected, - ) - .await; - - let offerer_selected_pair = offerer - .sctp() - .transport() - .ice_transport() - .get_selected_candidate_pair() - .await; - assert!(offerer_selected_pair.is_none()); - - let answerer_selected_pair = answerer - .sctp() - .transport() - .ice_transport() - .get_selected_candidate_pair() - .await; - assert!(answerer_selected_pair.is_none()); - - signal_pair(&mut offerer, &mut answerer).await?; - - peer_connection_connected.wait().await; - - let offerer_selected_pair = offerer - .sctp() - .transport() - .ice_transport() - .get_selected_candidate_pair() - .await; - assert!(offerer_selected_pair.is_some()); - - let answerer_selected_pair = answerer - .sctp() - .transport() - .ice_transport() - .get_selected_candidate_pair() - .await; - assert!(answerer_selected_pair.is_some()); - - close_pair_now(&offerer, &answerer).await; - - Ok(()) -} diff --git a/webrtc/src/ice_transport/mod.rs b/webrtc/src/ice_transport/mod.rs deleted file mode 100644 index e2b96da53..000000000 --- a/webrtc/src/ice_transport/mod.rs +++ /dev/null @@ -1,356 +0,0 @@ -use std::future::Future; -use std::pin::Pin; -use std::sync::atomic::Ordering; -use std::sync::Arc; - -use arc_swap::ArcSwapOption; -use ice::candidate::Candidate; -use ice::state::ConnectionState; -use ice_candidate::RTCIceCandidate; -use ice_candidate_pair::RTCIceCandidatePair; -use ice_gatherer::RTCIceGatherer; -use ice_role::RTCIceRole; -use portable_atomic::AtomicU8; -use tokio::sync::{mpsc, Mutex}; -use util::Conn; - -use crate::error::{flatten_errs, Error, Result}; -use crate::ice_transport::ice_parameters::RTCIceParameters; -use crate::ice_transport::ice_transport_state::RTCIceTransportState; -use crate::mux::endpoint::Endpoint; -use crate::mux::mux_func::MatchFunc; -use crate::mux::{Config, Mux}; -use crate::stats::stats_collector::StatsCollector; -use crate::stats::ICETransportStats; -use crate::stats::StatsReportType::Transport; - -#[cfg(test)] -mod ice_transport_test; - -pub mod ice_candidate; -pub mod ice_candidate_pair; -pub mod ice_candidate_type; -pub mod ice_connection_state; -pub mod ice_credential_type; -pub mod ice_gatherer; -pub mod ice_gatherer_state; -pub mod ice_gathering_state; -pub mod ice_parameters; -pub mod ice_protocol; -pub mod ice_role; -pub mod ice_server; -pub mod ice_transport_state; - -pub type OnConnectionStateChangeHdlrFn = Box< - dyn (FnMut(RTCIceTransportState) -> Pin + Send + 'static>>) - + Send - + Sync, ->; - -pub type OnSelectedCandidatePairChangeHdlrFn = Box< - dyn (FnMut(RTCIceCandidatePair) -> Pin + Send + 'static>>) - + Send - + Sync, ->; - -#[derive(Default)] -struct ICETransportInternal { - role: RTCIceRole, - conn: Option>, //AgentConn - mux: Option, - cancel_tx: Option>, -} - -/// ICETransport allows an application access to information about the ICE -/// transport over which packets are sent and received. -#[derive(Default)] -pub struct RTCIceTransport { - pub(crate) gatherer: Arc, - on_connection_state_change_handler: Arc>>, - on_selected_candidate_pair_change_handler: - Arc>>, - state: Arc, // ICETransportState - internal: Mutex, -} - -impl RTCIceTransport { - /// creates a new new_icetransport. 
- pub(crate) fn new(gatherer: Arc) -> Self { - RTCIceTransport { - state: Arc::new(AtomicU8::new(RTCIceTransportState::New as u8)), - gatherer, - ..Default::default() - } - } - - /// get_selected_candidate_pair returns the selected candidate pair on which packets are sent - /// if there is no selected pair nil is returned - pub async fn get_selected_candidate_pair(&self) -> Option { - if let Some(agent) = self.gatherer.get_agent().await { - if let Some(ice_pair) = agent.get_selected_candidate_pair() { - let local = RTCIceCandidate::from(&ice_pair.local); - let remote = RTCIceCandidate::from(&ice_pair.remote); - return Some(RTCIceCandidatePair::new(local, remote)); - } - } - None - } - - /// Start incoming connectivity checks based on its configured role. - pub async fn start(&self, params: &RTCIceParameters, role: Option) -> Result<()> { - if self.state() != RTCIceTransportState::New { - return Err(Error::ErrICETransportNotInNew); - } - - self.ensure_gatherer().await?; - - if let Some(agent) = self.gatherer.get_agent().await { - let state = Arc::clone(&self.state); - - let on_connection_state_change_handler = - Arc::clone(&self.on_connection_state_change_handler); - agent.on_connection_state_change(Box::new(move |ice_state: ConnectionState| { - let s = RTCIceTransportState::from(ice_state); - let on_connection_state_change_handler_clone = - Arc::clone(&on_connection_state_change_handler); - state.store(s as u8, Ordering::SeqCst); - Box::pin(async move { - if let Some(handler) = &*on_connection_state_change_handler_clone.load() { - let mut f = handler.lock().await; - f(s).await; - } - }) - })); - - let on_selected_candidate_pair_change_handler = - Arc::clone(&self.on_selected_candidate_pair_change_handler); - agent.on_selected_candidate_pair_change(Box::new( - move |local: &Arc, - remote: &Arc| { - let on_selected_candidate_pair_change_handler_clone = - Arc::clone(&on_selected_candidate_pair_change_handler); - let local = RTCIceCandidate::from(local); - let remote = RTCIceCandidate::from(remote); - Box::pin(async move { - if let Some(handler) = - &*on_selected_candidate_pair_change_handler_clone.load() - { - let mut f = handler.lock().await; - f(RTCIceCandidatePair::new(local, remote)).await; - } - }) - }, - )); - - let role = if let Some(role) = role { - role - } else { - RTCIceRole::Controlled - }; - - let (cancel_tx, cancel_rx) = mpsc::channel(1); - { - let mut internal = self.internal.lock().await; - internal.role = role; - internal.cancel_tx = Some(cancel_tx); - } - - let conn: Arc = match role { - RTCIceRole::Controlling => { - agent - .dial( - cancel_rx, - params.username_fragment.clone(), - params.password.clone(), - ) - .await? - } - - RTCIceRole::Controlled => { - agent - .accept( - cancel_rx, - params.username_fragment.clone(), - params.password.clone(), - ) - .await? 
- } - - _ => return Err(Error::ErrICERoleUnknown), - }; - - let config = Config { - conn: Arc::clone(&conn), - buffer_size: self.gatherer.setting_engine.get_receive_mtu(), - }; - - { - let mut internal = self.internal.lock().await; - internal.conn = Some(conn); - internal.mux = Some(Mux::new(config)); - } - - Ok(()) - } else { - Err(Error::ErrICEAgentNotExist) - } - } - - /// restart is not exposed currently because ORTC has users create a whole new ICETransport - /// so for now lets keep it private so we don't cause ORTC users to depend on non-standard APIs - pub(crate) async fn restart(&self) -> Result<()> { - if let Some(agent) = self.gatherer.get_agent().await { - agent - .restart( - self.gatherer - .setting_engine - .candidates - .username_fragment - .clone(), - self.gatherer.setting_engine.candidates.password.clone(), - ) - .await?; - } else { - return Err(Error::ErrICEAgentNotExist); - } - self.gatherer.gather().await - } - - /// Stop irreversibly stops the ICETransport. - pub async fn stop(&self) -> Result<()> { - self.set_state(RTCIceTransportState::Closed); - - let mut errs: Vec = vec![]; - { - let mut internal = self.internal.lock().await; - internal.cancel_tx.take(); - if let Some(mut mux) = internal.mux.take() { - mux.close().await; - } - if let Some(conn) = internal.conn.take() { - if let Err(err) = conn.close().await { - errs.push(err.into()); - } - } - } - - if let Err(err) = self.gatherer.close().await { - errs.push(err); - } - - flatten_errs(errs) - } - - /// on_selected_candidate_pair_change sets a handler that is invoked when a new - /// ICE candidate pair is selected - pub fn on_selected_candidate_pair_change(&self, f: OnSelectedCandidatePairChangeHdlrFn) { - self.on_selected_candidate_pair_change_handler - .store(Some(Arc::new(Mutex::new(f)))); - } - - /// on_connection_state_change sets a handler that is fired when the ICE - /// connection state changes. - pub fn on_connection_state_change(&self, f: OnConnectionStateChangeHdlrFn) { - self.on_connection_state_change_handler - .store(Some(Arc::new(Mutex::new(f)))); - } - - /// Role indicates the current role of the ICE transport. - pub async fn role(&self) -> RTCIceRole { - let internal = self.internal.lock().await; - internal.role - } - - /// set_remote_candidates sets the sequence of candidates associated with the remote ICETransport. - pub async fn set_remote_candidates(&self, remote_candidates: &[RTCIceCandidate]) -> Result<()> { - self.ensure_gatherer().await?; - - if let Some(agent) = self.gatherer.get_agent().await { - for rc in remote_candidates { - let c: Arc = Arc::new(rc.to_ice()?); - agent.add_remote_candidate(&c)?; - } - Ok(()) - } else { - Err(Error::ErrICEAgentNotExist) - } - } - - /// adds a candidate associated with the remote ICETransport. - pub async fn add_remote_candidate( - &self, - remote_candidate: Option, - ) -> Result<()> { - self.ensure_gatherer().await?; - - if let Some(agent) = self.gatherer.get_agent().await { - if let Some(r) = remote_candidate { - let c: Arc = Arc::new(r.to_ice()?); - agent.add_remote_candidate(&c)?; - } - - Ok(()) - } else { - Err(Error::ErrICEAgentNotExist) - } - } - - /// State returns the current ice transport state. 
- pub fn state(&self) -> RTCIceTransportState { - RTCIceTransportState::from(self.state.load(Ordering::SeqCst)) - } - - pub(crate) fn set_state(&self, s: RTCIceTransportState) { - self.state.store(s as u8, Ordering::SeqCst) - } - - pub(crate) async fn new_endpoint(&self, f: MatchFunc) -> Option> { - let internal = self.internal.lock().await; - if let Some(mux) = &internal.mux { - Some(mux.new_endpoint(f).await) - } else { - None - } - } - - pub(crate) async fn ensure_gatherer(&self) -> Result<()> { - if self.gatherer.get_agent().await.is_none() { - self.gatherer.create_agent().await - } else { - Ok(()) - } - } - - pub(crate) async fn collect_stats(&self, collector: &StatsCollector) { - if let Some(agent) = self.gatherer.get_agent().await { - let stats = ICETransportStats::new("ice_transport".to_string(), agent); - - collector.insert("ice_transport".to_string(), Transport(stats)); - } - } - - pub(crate) async fn have_remote_credentials_change( - &self, - new_ufrag: &str, - new_pwd: &str, - ) -> bool { - if let Some(agent) = self.gatherer.get_agent().await { - let (ufrag, upwd) = agent.get_remote_user_credentials().await; - ufrag != new_ufrag || upwd != new_pwd - } else { - false - } - } - - pub(crate) async fn set_remote_credentials( - &self, - new_ufrag: String, - new_pwd: String, - ) -> Result<()> { - if let Some(agent) = self.gatherer.get_agent().await { - Ok(agent.set_remote_credentials(new_ufrag, new_pwd).await?) - } else { - Err(Error::ErrICEAgentNotExist) - } - } -} diff --git a/webrtc/src/lib.rs b/webrtc/src/lib.rs deleted file mode 100644 index a2decd674..000000000 --- a/webrtc/src/lib.rs +++ /dev/null @@ -1,44 +0,0 @@ -#![warn(rust_2018_idioms)] -#![allow(dead_code)] - -pub use {data, dtls, ice, interceptor, mdns, media, rtcp, rtp, sctp, sdp, srtp, stun, turn, util}; - -/// [`peer_connection::RTCPeerConnection`] allows to establish connection between two peers given RTC configuration. Its API is similar to one in JavaScript. -pub mod peer_connection; - -/// The utilities defining transport between peers. Contains [`ice_transport::ice_server::RTCIceServer`] struct which describes how peer does ICE (Interactive Connectivity Establishment). -pub mod ice_transport; - -/// WebRTC DataChannel can be used for peer-to-peer transmitting arbitrary binary data. -pub mod data_channel; - -/// Module responsible for multiplexing data streams of different protocols on one socket. Custom [`mux::endpoint::Endpoint`] with [`mux::mux_func::MatchFunc`] can be used for parsing your application-specific byte stream. -pub mod mux; // TODO: why is this public? does someone really extend WebRTC stack? - -/// Measuring connection statistics, such as amount of data transmitted or round trip time. -pub mod stats; - -/// [`Error`] enumerates WebRTC problems, [`error::OnErrorHdlrFn`] defines type for callback-logger. -pub mod error; - -/// Set of constructors for WebRTC primitives. Subject to deprecation in future. 
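The `state()`/`set_state()` pair at the top of this hunk keeps the transport state lock-free by storing the enum as a `u8` in an atomic and converting back on read. A minimal sketch of the same pattern, using a stand-in `State` enum rather than the crate's `RTCIceTransportState`:

```rust
use std::sync::atomic::{AtomicU8, Ordering};

#[derive(Clone, Copy, Debug, PartialEq)]
#[repr(u8)]
enum State {
    New = 0,
    Checking = 1,
    Connected = 2,
    Closed = 3,
}

impl From<u8> for State {
    fn from(v: u8) -> Self {
        match v {
            1 => State::Checking,
            2 => State::Connected,
            3 => State::Closed,
            _ => State::New,
        }
    }
}

struct Transport {
    state: AtomicU8,
}

impl Transport {
    fn state(&self) -> State {
        State::from(self.state.load(Ordering::SeqCst))
    }
    fn set_state(&self, s: State) {
        self.state.store(s as u8, Ordering::SeqCst);
    }
}

fn main() {
    let t = Transport { state: AtomicU8::new(State::New as u8) };
    t.set_state(State::Connected);
    assert_eq!(t.state(), State::Connected);
}
```

Storing the discriminant rather than the enum avoids taking a lock on every state read in the handlers wired up during `start()`.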
-pub mod api; - -pub mod dtls_transport; -pub mod rtp_transceiver; -pub mod sctp_transport; -pub mod track; - -pub use error::Error; - -#[macro_use] -extern crate lazy_static; - -pub(crate) const UNSPECIFIED_STR: &str = "Unspecified"; - -/// Equal to UDP MTU -pub(crate) const RECEIVE_MTU: usize = 1460; - -pub(crate) const SDP_ATTRIBUTE_RID: &str = "rid"; -pub(crate) const SDP_ATTRIBUTE_SIMULCAST: &str = "simulcast"; -pub(crate) const GENERATED_CERTIFICATE_ORIGIN: &str = "WebRTC"; diff --git a/webrtc/src/mux/endpoint.rs b/webrtc/src/mux/endpoint.rs deleted file mode 100644 index ac01f4027..000000000 --- a/webrtc/src/mux/endpoint.rs +++ /dev/null @@ -1,77 +0,0 @@ -use std::collections::HashMap; -use std::io; -use std::net::SocketAddr; -use std::sync::Arc; - -use async_trait::async_trait; -use tokio::sync::Mutex; -use util::{Buffer, Conn}; - -use crate::mux::mux_func::MatchFunc; - -/// Endpoint implements net.Conn. It is used to read muxed packets. -pub struct Endpoint { - pub(crate) id: usize, - pub(crate) buffer: Buffer, - pub(crate) match_fn: MatchFunc, - pub(crate) next_conn: Arc, - pub(crate) endpoints: Arc>>>, -} - -impl Endpoint { - /// Close unregisters the endpoint from the Mux - pub async fn close(&self) -> Result<()> { - self.buffer.close().await; - - let mut endpoints = self.endpoints.lock().await; - endpoints.remove(&self.id); - - Ok(()) - } -} - -type Result = std::result::Result; - -#[async_trait] -impl Conn for Endpoint { - async fn connect(&self, _addr: SocketAddr) -> Result<()> { - Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) - } - - /// reads a packet of len(p) bytes from the underlying conn - /// that are matched by the associated MuxFunc - async fn recv(&self, buf: &mut [u8]) -> Result { - match self.buffer.read(buf, None).await { - Ok(n) => Ok(n), - Err(err) => Err(io::Error::new(io::ErrorKind::Other, err.to_string()).into()), - } - } - async fn recv_from(&self, _buf: &mut [u8]) -> Result<(usize, SocketAddr)> { - Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) - } - - /// writes bytes to the underlying conn - async fn send(&self, buf: &[u8]) -> Result { - self.next_conn.send(buf).await - } - - async fn send_to(&self, _buf: &[u8], _target: SocketAddr) -> Result { - Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) - } - - fn local_addr(&self) -> Result { - self.next_conn.local_addr() - } - - fn remote_addr(&self) -> Option { - self.next_conn.remote_addr() - } - - async fn close(&self) -> Result<()> { - self.next_conn.close().await - } - - fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) { - self - } -} diff --git a/webrtc/src/mux/mod.rs b/webrtc/src/mux/mod.rs deleted file mode 100644 index 876150900..000000000 --- a/webrtc/src/mux/mod.rs +++ /dev/null @@ -1,157 +0,0 @@ -#[cfg(test)] -mod mux_test; - -pub mod endpoint; -pub mod mux_func; - -use std::collections::HashMap; -use std::sync::atomic::Ordering; -use std::sync::Arc; - -use portable_atomic::AtomicUsize; -use tokio::sync::{mpsc, Mutex}; -use util::{Buffer, Conn}; - -use crate::error::Result; -use crate::mux::endpoint::Endpoint; -use crate::mux::mux_func::MatchFunc; -use crate::util::Error; - -/// mux multiplexes packets on a single socket (RFC7983) - -/// The maximum amount of data that can be buffered before returning errors. 
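The `Mux` defined just below registers endpoints keyed by a `MatchFunc` and, in `Mux::dispatch`, hands each incoming packet to the first endpoint whose match function accepts it; unmatched packets are only logged. A hedged, synchronous sketch of that first-match routing, with illustrative names and plain functions in place of the crate's buffered endpoints:

```rust
type MatchFn = Box<dyn Fn(&[u8]) -> bool + Send + Sync>;

struct EndpointSketch {
    name: &'static str,
    match_fn: MatchFn,
}

fn dispatch<'a>(buf: &[u8], endpoints: &'a [EndpointSketch]) -> Option<&'a str> {
    // First matching endpoint wins, mirroring the loop in Mux::dispatch.
    endpoints.iter().find(|ep| (ep.match_fn)(buf)).map(|ep| ep.name)
}

fn main() {
    let endpoints = vec![
        EndpointSketch {
            name: "dtls",
            match_fn: Box::new(|b: &[u8]| !b.is_empty() && (20..=63).contains(&b[0])),
        },
        EndpointSketch {
            name: "rtp/rtcp",
            match_fn: Box::new(|b: &[u8]| !b.is_empty() && (128..=191).contains(&b[0])),
        },
    ];
    assert_eq!(dispatch(&[22, 0, 0], &endpoints), Some("dtls"));
    assert_eq!(dispatch(&[128, 1, 2], &endpoints), Some("rtp/rtcp"));
    assert_eq!(dispatch(&[0, 1], &endpoints), None); // the real Mux logs and drops these
}
```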
-const MAX_BUFFER_SIZE: usize = 1000 * 1000; // 1MB - -/// Config collects the arguments to mux.Mux construction into -/// a single structure -pub struct Config { - pub conn: Arc, - pub buffer_size: usize, -} - -/// Mux allows multiplexing -#[derive(Clone)] -pub struct Mux { - id: Arc, - next_conn: Arc, - endpoints: Arc>>>, - buffer_size: usize, - closed_ch_tx: Option>, -} - -impl Mux { - pub fn new(config: Config) -> Self { - let (closed_ch_tx, closed_ch_rx) = mpsc::channel(1); - let m = Mux { - id: Arc::new(AtomicUsize::new(0)), - next_conn: Arc::clone(&config.conn), - endpoints: Arc::new(Mutex::new(HashMap::new())), - buffer_size: config.buffer_size, - closed_ch_tx: Some(closed_ch_tx), - }; - - let buffer_size = m.buffer_size; - let next_conn = Arc::clone(&m.next_conn); - let endpoints = Arc::clone(&m.endpoints); - tokio::spawn(async move { - Mux::read_loop(buffer_size, next_conn, closed_ch_rx, endpoints).await; - }); - - m - } - - /// creates a new Endpoint - pub async fn new_endpoint(&self, f: MatchFunc) -> Arc { - let mut endpoints = self.endpoints.lock().await; - - let id = self.id.fetch_add(1, Ordering::SeqCst); - // Set a maximum size of the buffer in bytes. - let e = Arc::new(Endpoint { - id, - buffer: Buffer::new(0, MAX_BUFFER_SIZE), - match_fn: f, - next_conn: Arc::clone(&self.next_conn), - endpoints: Arc::clone(&self.endpoints), - }); - - endpoints.insert(e.id, Arc::clone(&e)); - - e - } - - /// remove_endpoint removes an endpoint from the Mux - pub async fn remove_endpoint(&mut self, e: &Endpoint) { - let mut endpoints = self.endpoints.lock().await; - endpoints.remove(&e.id); - } - - /// Close closes the Mux and all associated Endpoints. - pub async fn close(&mut self) { - self.closed_ch_tx.take(); - - let mut endpoints = self.endpoints.lock().await; - endpoints.clear(); - } - - async fn read_loop( - buffer_size: usize, - next_conn: Arc, - mut closed_ch_rx: mpsc::Receiver<()>, - endpoints: Arc>>>, - ) { - let mut buf = vec![0u8; buffer_size]; - let mut n = 0usize; - loop { - tokio::select! 
{ - _ = closed_ch_rx.recv() => break, - result = next_conn.recv(&mut buf) => { - if let Ok(m) = result{ - n = m; - } - } - }; - - if let Err(err) = Mux::dispatch(&buf[..n], &endpoints).await { - log::error!("mux: ending readLoop dispatch error {:?}", err); - break; - } - } - } - - async fn dispatch( - buf: &[u8], - endpoints: &Arc>>>, - ) -> Result<()> { - let mut endpoint = None; - - { - let eps = endpoints.lock().await; - for ep in eps.values() { - if (ep.match_fn)(buf) { - endpoint = Some(Arc::clone(ep)); - break; - } - } - } - - if let Some(ep) = endpoint { - match ep.buffer.write(buf).await { - // Expected when bytes are received faster than the endpoint can process them - Err(Error::ErrBufferFull) => { - log::info!("mux: endpoint buffer is full, dropping packet") - } - Ok(_) => (), - Err(e) => return Err(crate::Error::Util(e)), - } - } else if !buf.is_empty() { - log::warn!( - "Warning: mux: no endpoint for packet starting with {}", - buf[0] - ); - } else { - log::warn!("Warning: mux: no endpoint for zero length packet"); - } - - Ok(()) - } -} diff --git a/webrtc/src/mux/mux_func.rs b/webrtc/src/mux/mux_func.rs deleted file mode 100644 index dfc30eefc..000000000 --- a/webrtc/src/mux/mux_func.rs +++ /dev/null @@ -1,63 +0,0 @@ -/// MatchFunc allows custom logic for mapping packets to an Endpoint -pub type MatchFunc = Box bool) + Send + Sync>; - -/// match_all always returns true -pub fn match_all(_b: &[u8]) -> bool { - true -} - -/// match_range is a MatchFunc that accepts packets with the first byte in [lower..upper] -pub fn match_range(lower: u8, upper: u8) -> MatchFunc { - Box::new(move |buf: &[u8]| -> bool { - if buf.is_empty() { - return false; - } - let b = buf[0]; - b >= lower && b <= upper - }) -} - -/// MatchFuncs as described in RFC7983 -/// -/// +----------------+ -/// | [0..3] -+--> forward to STUN -/// | | -/// | [16..19] -+--> forward to ZRTP -/// | | -/// packet --> | [20..63] -+--> forward to DTLS -/// | | -/// | [64..79] -+--> forward to TURN Channel -/// | | -/// | [128..191] -+--> forward to RTP/RTCP -/// +----------------+ -/// match_dtls is a MatchFunc that accepts packets with the first byte in [20..63] -/// as defined in RFC7983 -pub fn match_dtls(b: &[u8]) -> bool { - match_range(20, 63)(b) -} - -// match_srtp_or_srtcp is a MatchFunc that accepts packets with the first byte in [128..191] -// as defined in RFC7983 -pub fn match_srtp_or_srtcp(b: &[u8]) -> bool { - match_range(128, 191)(b) -} - -pub(crate) fn is_rtcp(buf: &[u8]) -> bool { - // Not long enough to determine RTP/RTCP - if buf.len() < 4 { - return false; - } - - let rtcp_packet_type = buf[1]; - (192..=223).contains(&rtcp_packet_type) -} - -/// match_srtp is a MatchFunc that only matches SRTP and not SRTCP -pub fn match_srtp(buf: &[u8]) -> bool { - match_srtp_or_srtcp(buf) && !is_rtcp(buf) -} - -/// match_srtcp is a MatchFunc that only matches SRTCP and not SRTP -pub fn match_srtcp(buf: &[u8]) -> bool { - match_srtp_or_srtcp(buf) && is_rtcp(buf) -} diff --git a/webrtc/src/mux/mux_test.rs b/webrtc/src/mux/mux_test.rs deleted file mode 100644 index e9a8b5e0a..000000000 --- a/webrtc/src/mux/mux_test.rs +++ /dev/null @@ -1,142 +0,0 @@ -use std::io; -use std::net::SocketAddr; -use std::sync::atomic::Ordering; - -use async_trait::async_trait; -use portable_atomic::AtomicUsize; -use util::conn::conn_pipe::pipe; - -use super::*; -use crate::mux::mux_func::{match_all, match_srtp}; - -const TEST_PIPE_BUFFER_SIZE: usize = 8192; - -#[tokio::test] -async fn test_no_endpoints() -> crate::error::Result<()> { - // In 
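The `mux_func.rs` hunk above encodes the RFC 7983 demultiplexing table as individual `MatchFunc`s (`match_dtls`, `match_srtp_or_srtcp`, and the `is_rtcp` payload-type check). The following classifier is illustrative only, not the crate's API; it folds the same first-byte ranges into one enum to make the table easier to read:

```rust
#[derive(Debug, PartialEq)]
enum PacketKind {
    Stun,        // first byte in 0..=3
    Zrtp,        // 16..=19
    Dtls,        // 20..=63
    TurnChannel, // 64..=79
    RtpOrRtcp,   // 128..=191
    Unknown,
}

fn classify(buf: &[u8]) -> PacketKind {
    match buf.first().copied() {
        Some(0..=3) => PacketKind::Stun,
        Some(16..=19) => PacketKind::Zrtp,
        Some(20..=63) => PacketKind::Dtls,
        Some(64..=79) => PacketKind::TurnChannel,
        Some(128..=191) => PacketKind::RtpOrRtcp,
        _ => PacketKind::Unknown,
    }
}

// RTP and RTCP share the 128..=191 range; RTCP is told apart by the
// payload-type byte, exactly as is_rtcp does above.
fn is_rtcp(buf: &[u8]) -> bool {
    buf.len() >= 4 && (192..=223).contains(&buf[1])
}

fn main() {
    assert_eq!(classify(&[0, 1]), PacketKind::Stun);
    assert_eq!(classify(&[22, 254]), PacketKind::Dtls); // DTLS handshake record
    assert_eq!(classify(&[128, 96, 0, 0]), PacketKind::RtpOrRtcp);
    assert!(!is_rtcp(&[128, 96, 0, 0])); // payload type 96 -> RTP
    assert!(is_rtcp(&[128, 200, 0, 0])); // payload type 200 (SR) -> RTCP
}
```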
memory pipe - let (ca, _) = pipe(); - - let mut m = Mux::new(Config { - conn: Arc::new(ca), - buffer_size: TEST_PIPE_BUFFER_SIZE, - }); - - Mux::dispatch(&[0], &m.endpoints).await?; - m.close().await; - - Ok(()) -} - -struct MuxErrorConn { - idx: AtomicUsize, - data: Vec>, -} - -type Result = std::result::Result; - -#[async_trait] -impl Conn for MuxErrorConn { - async fn connect(&self, _addr: SocketAddr) -> Result<()> { - Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) - } - - async fn recv(&self, buf: &mut [u8]) -> Result { - let idx = self.idx.fetch_add(1, Ordering::SeqCst); - if idx < self.data.len() { - let n = std::cmp::min(buf.len(), self.data[idx].len()); - buf[..n].copy_from_slice(&self.data[idx][..n]); - Ok(n) - } else { - Err(io::Error::new( - io::ErrorKind::Other, - format!("idx {} >= data.len {}", idx, self.data.len()), - ) - .into()) - } - } - - async fn recv_from(&self, _buf: &mut [u8]) -> Result<(usize, SocketAddr)> { - Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) - } - - async fn send(&self, _buf: &[u8]) -> Result { - Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) - } - - async fn send_to(&self, _buf: &[u8], _target: SocketAddr) -> Result { - Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) - } - - fn local_addr(&self) -> Result { - Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) - } - - fn remote_addr(&self) -> Option { - None - } - - async fn close(&self) -> Result<()> { - Ok(()) - } - - fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) { - self - } -} - -#[tokio::test] -async fn test_non_fatal_read() -> Result<()> { - let expected_data = b"expected_data".to_vec(); - - let conn = Arc::new(MuxErrorConn { - idx: AtomicUsize::new(0), - data: vec![ - expected_data.clone(), - expected_data.clone(), - expected_data.clone(), - ], - }); - - let mut m = Mux::new(Config { - conn, - buffer_size: TEST_PIPE_BUFFER_SIZE, - }); - - let e = m.new_endpoint(Box::new(match_all)).await; - let mut buff = vec![0u8; TEST_PIPE_BUFFER_SIZE]; - - let n = e.recv(&mut buff).await?; - assert_eq!(&buff[..n], expected_data); - - let n = e.recv(&mut buff).await?; - assert_eq!(&buff[..n], expected_data); - - let n = e.recv(&mut buff).await?; - assert_eq!(&buff[..n], expected_data); - - m.close().await; - - Ok(()) -} - -#[tokio::test] -async fn test_non_fatal_dispatch() -> Result<()> { - let (ca, cb) = pipe(); - - let mut m = Mux::new(Config { - conn: Arc::new(ca), - buffer_size: TEST_PIPE_BUFFER_SIZE, - }); - - let e = m.new_endpoint(Box::new(match_srtp)).await; - e.buffer.set_limit_size(1).await; - - for _ in 0..25 { - let srtp_packet = [128, 1, 2, 3, 4].to_vec(); - cb.send(&srtp_packet).await?; - } - - m.close().await; - - Ok(()) -} diff --git a/webrtc/src/peer_connection/certificate.rs b/webrtc/src/peer_connection/certificate.rs deleted file mode 100644 index 5f9b962a9..000000000 --- a/webrtc/src/peer_connection/certificate.rs +++ /dev/null @@ -1,291 +0,0 @@ -use std::ops::Add; -use std::time::{Duration, SystemTime, UNIX_EPOCH}; - -use dtls::crypto::{CryptoPrivateKey, CryptoPrivateKeyKind}; -use rcgen::{CertificateParams, KeyPair}; -use ring::rand::SystemRandom; -use ring::rsa; -use ring::signature::{EcdsaKeyPair, Ed25519KeyPair}; -use sha2::{Digest, Sha256}; - -use crate::dtls_transport::dtls_fingerprint::RTCDtlsFingerprint; -use crate::error::{Error, Result}; -use crate::peer_connection::math_rand_alpha; -use crate::stats::stats_collector::StatsCollector; -use crate::stats::{CertificateStats, 
StatsReportType}; - -/// Certificate represents a X.509 certificate used to authenticate WebRTC communications. -#[derive(Clone, Debug)] -pub struct RTCCertificate { - /// DTLS certificate. - pub(crate) dtls_certificate: dtls::crypto::Certificate, - /// Timestamp after which this certificate is no longer valid. - pub(crate) expires: SystemTime, - /// Certificate's ID used for statistics. - /// - /// Example: "certificate-1667202302853538793" - /// - /// See [`CertificateStats`]. - pub(crate) stats_id: String, -} - -impl PartialEq for RTCCertificate { - fn eq(&self, other: &Self) -> bool { - self.dtls_certificate == other.dtls_certificate - } -} - -impl RTCCertificate { - /// Generates a new certificate from the given parameters. - /// - /// See [`rcgen::Certificate::from_params`]. - fn from_params(params: CertificateParams, key_pair: KeyPair) -> Result { - let not_after = params.not_after; - - let x509_cert = params.self_signed(&key_pair).unwrap(); - let serialized_der = key_pair.serialize_der(); - - let private_key = if key_pair.is_compatible(&rcgen::PKCS_ED25519) { - CryptoPrivateKey { - kind: CryptoPrivateKeyKind::Ed25519( - Ed25519KeyPair::from_pkcs8(&serialized_der) - .map_err(|e| Error::new(e.to_string()))?, - ), - serialized_der, - } - } else if key_pair.is_compatible(&rcgen::PKCS_ECDSA_P256_SHA256) { - CryptoPrivateKey { - kind: CryptoPrivateKeyKind::Ecdsa256( - EcdsaKeyPair::from_pkcs8( - &ring::signature::ECDSA_P256_SHA256_ASN1_SIGNING, - &serialized_der, - &SystemRandom::new(), - ) - .map_err(|e| Error::new(e.to_string()))?, - ), - serialized_der, - } - } else if key_pair.is_compatible(&rcgen::PKCS_RSA_SHA256) { - CryptoPrivateKey { - kind: CryptoPrivateKeyKind::Rsa256( - rsa::KeyPair::from_pkcs8(&serialized_der) - .map_err(|e| Error::new(e.to_string()))?, - ), - serialized_der, - } - } else { - return Err(Error::new("Unsupported key_pair".to_owned())); - }; - - let expires = if cfg!(target_arch = "arm") { - // Workaround for issue overflow when adding duration to instant on armv7 - // https://github.com/webrtc-rs/examples/issues/5 https://github.com/chronotope/chrono/issues/343 - SystemTime::now().add(Duration::from_secs(172800)) //60*60*48 or 2 days - } else { - not_after.into() - }; - - Ok(Self { - dtls_certificate: dtls::crypto::Certificate { - certificate: vec![x509_cert.der().to_owned()], - private_key, - }, - expires, - stats_id: gen_stats_id(), - }) - } - - /// Generates a new certificate with default [`CertificateParams`] using the given keypair. - pub fn from_key_pair(key_pair: KeyPair) -> Result { - if !(key_pair.is_compatible(&rcgen::PKCS_ED25519) - || key_pair.is_compatible(&rcgen::PKCS_ECDSA_P256_SHA256) - || key_pair.is_compatible(&rcgen::PKCS_RSA_SHA256)) - { - return Err(Error::new("Unsupported key_pair".to_owned())); - } - - RTCCertificate::from_params( - CertificateParams::new(vec![math_rand_alpha(16)]).unwrap(), - key_pair, - ) - } - - /// Parses a certificate from the ASCII PEM format. 
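`from_params` above chooses the certificate expiry in two ways: normally it converts rcgen's `not_after`, but on 32-bit ARM it pins validity to two days from now to dodge the time-conversion overflow referenced in webrtc-rs/examples#5. A std-only sketch of that selection, with `pick_expiry` as an illustrative name:

```rust
use std::time::{Duration, SystemTime};

const TWO_DAYS: Duration = Duration::from_secs(60 * 60 * 48);

fn pick_expiry(not_after: SystemTime) -> SystemTime {
    if cfg!(target_arch = "arm") {
        // Workaround for the armv7 overflow described in
        // https://github.com/webrtc-rs/examples/issues/5
        SystemTime::now() + TWO_DAYS
    } else {
        not_after
    }
}

fn main() {
    let not_after = SystemTime::now() + Duration::from_secs(365 * 24 * 3600);
    let expires = pick_expiry(not_after);
    assert!(expires > SystemTime::now());
}
```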
- #[cfg(feature = "pem")] - pub fn from_pem(pem_str: &str) -> Result { - let mut pem_blocks = pem_str.split("\n\n"); - let first_block = if let Some(b) = pem_blocks.next() { - b - } else { - return Err(Error::InvalidPEM("empty PEM".into())); - }; - let expires_pem = - pem::parse(first_block).map_err(|e| Error::new(format!("can't parse PEM: {e}")))?; - if expires_pem.tag() != "EXPIRES" { - return Err(Error::InvalidPEM(format!( - "invalid tag (expected: 'EXPIRES', got '{}')", - expires_pem.tag() - ))); - } - let mut bytes = [0u8; 8]; - bytes.copy_from_slice(&expires_pem.contents()[..8]); - let expires = if let Some(e) = - SystemTime::UNIX_EPOCH.checked_add(Duration::from_secs(u64::from_le_bytes(bytes))) - { - e - } else { - return Err(Error::InvalidPEM("failed to calculate SystemTime".into())); - }; - let dtls_certificate = - dtls::crypto::Certificate::from_pem(&pem_blocks.collect::>().join("\n\n"))?; - Ok(RTCCertificate::from_existing(dtls_certificate, expires)) - } - - /// Builds a [`RTCCertificate`] using the existing DTLS certificate. - /// - /// Use this method when you have a persistent certificate (i.e. you don't want to generate a - /// new one for each DTLS connection). - /// - /// NOTE: ID used for statistics will be different as it's neither derived from the given - /// certificate nor persisted along it when using [`RTCCertificate::serialize_pem`]. - pub fn from_existing(dtls_certificate: dtls::crypto::Certificate, expires: SystemTime) -> Self { - Self { - dtls_certificate, - expires, - // TODO: figure out if it needs to be persisted - stats_id: gen_stats_id(), - } - } - - /// Serializes the certificate (including the private key) in PKCS#8 format in PEM. - #[cfg(any(doc, feature = "pem"))] - pub fn serialize_pem(&self) -> String { - // Encode `expires` as a PEM block. - // - // TODO: serialize as nanos when https://github.com/rust-lang/rust/issues/103332 is fixed. - let expires_pem = pem::Pem::new( - "EXPIRES".to_string(), - self.expires - .duration_since(SystemTime::UNIX_EPOCH) - .expect("expires to be valid") - .as_secs() - .to_le_bytes() - .to_vec(), - ); - format!( - "{}\n{}", - pem::encode(&expires_pem), - self.dtls_certificate.serialize_pem() - ) - } - - /// get_fingerprints returns a SHA-256 fingerprint of this certificate. - /// - /// TODO: return a fingerprint computed with the digest algorithm used in the certificate - /// signature. - pub fn get_fingerprints(&self) -> Vec { - let mut fingerprints = Vec::new(); - - for c in &self.dtls_certificate.certificate { - let mut h = Sha256::new(); - h.update(c.as_ref()); - let hashed = h.finalize(); - let values: Vec = hashed.iter().map(|x| format! 
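`serialize_pem` and `from_pem` above persist the expiry as a PEM block tagged `EXPIRES` whose contents are the UNIX-epoch seconds encoded as 8 little-endian bytes, placed ahead of the DTLS certificate PEM. A hedged sketch of just that byte encoding (the `pem` crate handles the actual armor in the real code):

```rust
use std::time::{Duration, SystemTime, UNIX_EPOCH};

fn encode_expires(expires: SystemTime) -> [u8; 8] {
    expires
        .duration_since(UNIX_EPOCH)
        .expect("expires to be after the epoch")
        .as_secs()
        .to_le_bytes()
}

fn decode_expires(bytes: &[u8]) -> Option<SystemTime> {
    let mut buf = [0u8; 8];
    buf.copy_from_slice(bytes.get(..8)?);
    UNIX_EPOCH.checked_add(Duration::from_secs(u64::from_le_bytes(buf)))
}

fn main() {
    let expires = UNIX_EPOCH + Duration::from_secs(1_700_000_000);
    let bytes = encode_expires(expires);
    assert_eq!(decode_expires(&bytes), Some(expires));
}
```

Seconds rather than nanoseconds are stored deliberately; the TODO in `serialize_pem` notes that nanosecond precision is deferred until the referenced rust-lang issue is resolved.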
{"{x:02x}"}).collect(); - - fingerprints.push(RTCDtlsFingerprint { - algorithm: "sha-256".to_owned(), - value: values.join(":"), - }); - } - - fingerprints - } - - pub(crate) async fn collect_stats(&self, collector: &StatsCollector) { - if let Some(fingerprint) = self.get_fingerprints().into_iter().next() { - let stats = CertificateStats::new(self, fingerprint); - collector.insert( - self.stats_id.clone(), - StatsReportType::CertificateStats(stats), - ); - } - } -} - -fn gen_stats_id() -> String { - format!( - "certificate-{}", - SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_nanos() as u64 - ) -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_generate_certificate_rsa() -> Result<()> { - let key_pair = KeyPair::generate_for(&rcgen::PKCS_RSA_SHA256); - assert!(key_pair.is_err(), "RcgenError::KeyGenerationUnavailable"); - - Ok(()) - } - - #[test] - fn test_generate_certificate_ecdsa() -> Result<()> { - let kp = KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256)?; - let _cert = RTCCertificate::from_key_pair(kp)?; - - Ok(()) - } - - #[test] - fn test_generate_certificate_eddsa() -> Result<()> { - let kp = KeyPair::generate_for(&rcgen::PKCS_ED25519)?; - let _cert = RTCCertificate::from_key_pair(kp)?; - - Ok(()) - } - - #[test] - fn test_certificate_equal() -> Result<()> { - let kp1 = KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256)?; - let cert1 = RTCCertificate::from_key_pair(kp1)?; - - let kp2 = KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256)?; - let cert2 = RTCCertificate::from_key_pair(kp2)?; - - assert_ne!(cert1, cert2); - - Ok(()) - } - - #[test] - fn test_generate_certificate_expires_and_stats_id() -> Result<()> { - let kp = KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256)?; - let cert = RTCCertificate::from_key_pair(kp)?; - - let now = SystemTime::now(); - assert!(cert.expires.duration_since(now).is_ok()); - assert!(cert.stats_id.contains("certificate")); - - Ok(()) - } - - #[cfg(feature = "pem")] - #[test] - fn test_certificate_serialize_pem_and_from_pem() -> Result<()> { - let kp = KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256)?; - let cert = RTCCertificate::from_key_pair(kp)?; - - let pem = cert.serialize_pem(); - let loaded_cert = RTCCertificate::from_pem(&pem)?; - - assert_eq!(loaded_cert, cert); - - Ok(()) - } -} diff --git a/webrtc/src/peer_connection/configuration.rs b/webrtc/src/peer_connection/configuration.rs deleted file mode 100644 index 58793fb53..000000000 --- a/webrtc/src/peer_connection/configuration.rs +++ /dev/null @@ -1,148 +0,0 @@ -use crate::ice_transport::ice_server::RTCIceServer; -use crate::peer_connection::certificate::RTCCertificate; -use crate::peer_connection::policy::bundle_policy::RTCBundlePolicy; -use crate::peer_connection::policy::ice_transport_policy::RTCIceTransportPolicy; -use crate::peer_connection::policy::rtcp_mux_policy::RTCRtcpMuxPolicy; - -/// A Configuration defines how peer-to-peer communication via PeerConnection -/// is established or re-established. -/// Configurations may be set up once and reused across multiple connections. -/// Configurations are treated as readonly. As long as they are unmodified, -/// they are safe for concurrent use. -#[derive(Default, Clone)] -pub struct RTCConfiguration { - /// iceservers defines a slice describing servers available to be used by - /// ICE, such as STUN and TURN servers. - pub ice_servers: Vec, - - /// icetransport_policy indicates which candidates the ICEAgent is allowed - /// to use. 
- pub ice_transport_policy: RTCIceTransportPolicy, - - /// bundle_policy indicates which media-bundling policy to use when gathering - /// ICE candidates. - pub bundle_policy: RTCBundlePolicy, - - /// rtcp_mux_policy indicates which rtcp-mux policy to use when gathering ICE - /// candidates. - pub rtcp_mux_policy: RTCRtcpMuxPolicy, - - /// peer_identity sets the target peer identity for the PeerConnection. - /// The PeerConnection will not establish a connection to a remote peer - /// unless it can be successfully authenticated with the provided name. - pub peer_identity: String, - - /// Certificates describes a set of certificates that the PeerConnection - /// uses to authenticate. Valid values for this parameter are created - /// through calls to the generate_certificate function. Although any given - /// DTLS connection will use only one certificate, this attribute allows the - /// caller to provide multiple certificates that support different - /// algorithms. The final certificate will be selected based on the DTLS - /// handshake, which establishes which certificates are allowed. The - /// PeerConnection implementation selects which of the certificates is - /// used for a given connection; how certificates are selected is outside - /// the scope of this specification. If this value is absent, then a default - /// set of certificates is generated for each PeerConnection instance. - pub certificates: Vec, - - /// icecandidate_pool_size describes the size of the prefetched ICE pool. - pub ice_candidate_pool_size: u8, -} - -impl RTCConfiguration { - /// get_iceservers side-steps the strict parsing mode of the ice package - /// (as defined in https://tools.ietf.org/html/rfc7064) by copying and then - /// stripping any erroneous queries from "stun(s):" URLs before parsing. 
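The query stripping described in the `get_ice_servers` doc comment is implemented just below: any `stun:`/`stuns:` URL has everything after the first `?` removed before the strict ICE URL parser sees it. A standalone sketch of that normalization, using the same Twilio example as the test further down:

```rust
fn strip_stun_query(raw_url: &str) -> String {
    if raw_url.starts_with("stun") {
        // Drop an erroneous "?transport=..." query that RFC 7064 disallows.
        raw_url.split('?').next().unwrap_or(raw_url).to_owned()
    } else {
        raw_url.to_owned()
    }
}

fn main() {
    assert_eq!(
        strip_stun_query("stun:global.stun.twilio.com:3478?transport=udp"),
        "stun:global.stun.twilio.com:3478"
    );
    // Only stun(s) URLs are rewritten; turn URLs pass through untouched.
    assert_eq!(
        strip_stun_query("turn:turn.example.org?transport=udp"),
        "turn:turn.example.org?transport=udp"
    );
}
```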
- #[allow(clippy::assigning_clones)] - pub(crate) fn get_ice_servers(&self) -> Vec { - let mut ice_servers = self.ice_servers.clone(); - - for ice_server in &mut ice_servers { - for raw_url in &mut ice_server.urls { - if raw_url.starts_with("stun") { - // strip the query from "stun(s):" if present - let parts: Vec<&str> = raw_url.split('?').collect(); - *raw_url = parts[0].to_owned(); - } - } - } - - ice_servers - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_configuration_get_iceservers() { - { - let expected_server_str = "stun:stun.l.google.com:19302"; - let cfg = RTCConfiguration { - ice_servers: vec![RTCIceServer { - urls: vec![expected_server_str.to_owned()], - ..Default::default() - }], - ..Default::default() - }; - - let parsed_urls = cfg.get_ice_servers(); - assert_eq!(parsed_urls[0].urls[0], expected_server_str); - } - - { - // ignore the fact that stun URLs shouldn't have a query - let server_str = "stun:global.stun.twilio.com:3478?transport=udp"; - let expected_server_str = "stun:global.stun.twilio.com:3478"; - let cfg = RTCConfiguration { - ice_servers: vec![RTCIceServer { - urls: vec![server_str.to_owned()], - ..Default::default() - }], - ..Default::default() - }; - - let parsed_urls = cfg.get_ice_servers(); - assert_eq!(parsed_urls[0].urls[0], expected_server_str); - } - } - - /*TODO:#[test] fn test_configuration_json() { - - let j = r#" - { - "iceServers": [{"URLs": ["turn:turn.example.org"], - "username": "jch", - "credential": "topsecret" - }], - "iceTransportPolicy": "relay", - "bundlePolicy": "balanced", - "rtcpMuxPolicy": "require" - }"#; - - conf := Configuration{ - ICEServers: []ICEServer{ - { - URLs: []string{"turn:turn.example.org"}, - Username: "jch", - Credential: "topsecret", - }, - }, - ICETransportPolicy: ICETransportPolicyRelay, - BundlePolicy: BundlePolicyBalanced, - RTCPMuxPolicy: RTCPMuxPolicyRequire, - } - - var conf2 Configuration - assert.NoError(t, json.Unmarshal([]byte(j), &conf2)) - assert.Equal(t, conf, conf2) - - j2, err := json.Marshal(conf2) - assert.NoError(t, err) - - var conf3 Configuration - assert.NoError(t, json.Unmarshal(j2, &conf3)) - assert.Equal(t, conf2, conf3) - }*/ -} diff --git a/webrtc/src/peer_connection/mod.rs b/webrtc/src/peer_connection/mod.rs deleted file mode 100644 index e5cec0894..000000000 --- a/webrtc/src/peer_connection/mod.rs +++ /dev/null @@ -1,2118 +0,0 @@ -#[cfg(test)] -pub(crate) mod peer_connection_test; - -/// Custom media-related options, such as `voice_activity_detection`, which are negotiated while establishing connection. -pub mod offer_answer_options; - -/// [`RTCSessionDescription`] - wrapper for SDP text and negotiations stage ([`RTCSdpType`]: offer - pranswer - answer - rollback). 
-pub mod sdp; - -pub mod certificate; -pub mod configuration; -pub(crate) mod operation; -mod peer_connection_internal; -pub mod peer_connection_state; -pub mod policy; -pub mod signaling_state; - -use std::future::Future; -use std::pin::Pin; -use std::sync::atomic::Ordering; -use std::sync::Arc; -use std::time::{SystemTime, UNIX_EPOCH}; - -use ::ice::candidate::candidate_base::unmarshal_candidate; -use ::ice::candidate::Candidate; -use ::sdp::description::session::*; -use ::sdp::util::ConnectionRole; -use arc_swap::ArcSwapOption; -use async_trait::async_trait; -use interceptor::{stats, Attributes, Interceptor, RTCPWriter}; -use peer_connection_internal::*; -use portable_atomic::{AtomicBool, AtomicU64, AtomicU8}; -use rand::{thread_rng, Rng}; -use rcgen::KeyPair; -use smol_str::SmolStr; -use srtp::stream::Stream; -use tokio::sync::{mpsc, Mutex}; - -use crate::api::media_engine::MediaEngine; -use crate::api::setting_engine::SettingEngine; -use crate::api::API; -use crate::data_channel::data_channel_init::RTCDataChannelInit; -use crate::data_channel::data_channel_parameters::DataChannelParameters; -use crate::data_channel::data_channel_state::RTCDataChannelState; -use crate::data_channel::RTCDataChannel; -use crate::dtls_transport::dtls_fingerprint::RTCDtlsFingerprint; -use crate::dtls_transport::dtls_parameters::DTLSParameters; -use crate::dtls_transport::dtls_role::{ - DTLSRole, DEFAULT_DTLS_ROLE_ANSWER, DEFAULT_DTLS_ROLE_OFFER, -}; -use crate::dtls_transport::dtls_transport_state::RTCDtlsTransportState; -use crate::dtls_transport::RTCDtlsTransport; -use crate::error::{flatten_errs, Error, Result}; -use crate::ice_transport::ice_candidate::{RTCIceCandidate, RTCIceCandidateInit}; -use crate::ice_transport::ice_connection_state::RTCIceConnectionState; -use crate::ice_transport::ice_gatherer::{ - OnGatheringCompleteHdlrFn, OnICEGathererStateChangeHdlrFn, OnLocalCandidateHdlrFn, - RTCIceGatherOptions, RTCIceGatherer, -}; -use crate::ice_transport::ice_gatherer_state::RTCIceGathererState; -use crate::ice_transport::ice_gathering_state::RTCIceGatheringState; -use crate::ice_transport::ice_parameters::RTCIceParameters; -use crate::ice_transport::ice_role::RTCIceRole; -use crate::ice_transport::ice_transport_state::RTCIceTransportState; -use crate::ice_transport::RTCIceTransport; -use crate::peer_connection::certificate::RTCCertificate; -use crate::peer_connection::configuration::RTCConfiguration; -use crate::peer_connection::offer_answer_options::{RTCAnswerOptions, RTCOfferOptions}; -use crate::peer_connection::operation::{Operation, Operations}; -use crate::peer_connection::peer_connection_state::{ - NegotiationNeededState, RTCPeerConnectionState, -}; -use crate::peer_connection::sdp::sdp_type::RTCSdpType; -use crate::peer_connection::sdp::session_description::RTCSessionDescription; -use crate::peer_connection::sdp::*; -use crate::peer_connection::signaling_state::{ - check_next_signaling_state, RTCSignalingState, StateChangeOp, -}; -use crate::rtp_transceiver::rtp_codec::{RTCRtpHeaderExtensionCapability, RTPCodecType}; -use crate::rtp_transceiver::rtp_receiver::RTCRtpReceiver; -use crate::rtp_transceiver::rtp_sender::RTCRtpSender; -use crate::rtp_transceiver::rtp_transceiver_direction::RTCRtpTransceiverDirection; -use crate::rtp_transceiver::{ - find_by_mid, handle_unknown_rtp_packet, satisfy_type_and_direction, RTCRtpTransceiver, - RTCRtpTransceiverInit, SSRC, -}; -use crate::sctp_transport::sctp_transport_capabilities::SCTPTransportCapabilities; -use 
crate::sctp_transport::sctp_transport_state::RTCSctpTransportState; -use crate::sctp_transport::RTCSctpTransport; -use crate::stats::StatsReport; -use crate::track::track_local::TrackLocal; -use crate::track::track_remote::TrackRemote; - -/// SIMULCAST_PROBE_COUNT is the amount of RTP Packets -/// that handleUndeclaredSSRC will read and try to dispatch from -/// mid and rid values -pub(crate) const SIMULCAST_PROBE_COUNT: usize = 10; - -/// SIMULCAST_MAX_PROBE_ROUTINES is how many active routines can be used to probe -/// If the total amount of incoming SSRCes exceeds this new requests will be ignored -pub(crate) const SIMULCAST_MAX_PROBE_ROUTINES: u64 = 25; - -pub(crate) const MEDIA_SECTION_APPLICATION: &str = "application"; - -const RUNES_ALPHA: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; - -/// math_rand_alpha generates a mathematical random alphabet sequence of the requested length. -pub fn math_rand_alpha(n: usize) -> String { - let mut rng = thread_rng(); - - let rand_string: String = (0..n) - .map(|_| { - let idx = rng.gen_range(0..RUNES_ALPHA.len()); - RUNES_ALPHA[idx] as char - }) - .collect(); - - rand_string -} - -pub type OnSignalingStateChangeHdlrFn = Box< - dyn (FnMut(RTCSignalingState) -> Pin + Send + 'static>>) - + Send - + Sync, ->; - -pub type OnICEConnectionStateChangeHdlrFn = Box< - dyn (FnMut(RTCIceConnectionState) -> Pin + Send + 'static>>) - + Send - + Sync, ->; - -pub type OnPeerConnectionStateChangeHdlrFn = Box< - dyn (FnMut(RTCPeerConnectionState) -> Pin + Send + 'static>>) - + Send - + Sync, ->; - -pub type OnDataChannelHdlrFn = Box< - dyn (FnMut(Arc) -> Pin + Send + 'static>>) - + Send - + Sync, ->; - -pub type OnTrackHdlrFn = Box< - dyn (FnMut( - Arc, - Arc, - Arc, - ) -> Pin + Send + 'static>>) - + Send - + Sync, ->; - -pub type OnNegotiationNeededHdlrFn = - Box Pin + Send + 'static>>) + Send + Sync>; - -#[derive(Clone)] -struct StartTransportsParams { - ice_transport: Arc, - dtls_transport: Arc, - on_peer_connection_state_change_handler: Arc>>, - is_closed: Arc, - peer_connection_state: Arc, - ice_connection_state: Arc, -} - -#[derive(Clone)] -struct CheckNegotiationNeededParams { - sctp_transport: Arc, - rtp_transceivers: Arc>>>, - current_local_description: Arc>>, - current_remote_description: Arc>>, -} - -#[derive(Clone)] -struct NegotiationNeededParams { - on_negotiation_needed_handler: Arc>>, - is_closed: Arc, - ops: Arc, - negotiation_needed_state: Arc, - is_negotiation_needed: Arc, - signaling_state: Arc, - check_negotiation_needed_params: CheckNegotiationNeededParams, -} - -/// PeerConnection represents a WebRTC connection that establishes a -/// peer-to-peer communications with another PeerConnection instance in a -/// browser, or to another endpoint implementing the required protocols. 
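The `On*HdlrFn` aliases above all share one shape: a boxed `FnMut` that returns a pinned, boxed future, so callers can hand in an async closure and the connection can await it later. A minimal sketch of registering and firing such a handler, assuming tokio as the runtime; `OnStateChangeFn`, `Emitter`, and the `u8` payload are illustrative stand-ins, not crate types:

```rust
use std::future::Future;
use std::pin::Pin;

type OnStateChangeFn =
    Box<dyn FnMut(u8) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>;

struct Emitter {
    handler: Option<OnStateChangeFn>,
}

impl Emitter {
    fn on_state_change(&mut self, f: OnStateChangeFn) {
        self.handler = Some(f);
    }

    async fn fire(&mut self, state: u8) {
        if let Some(f) = self.handler.as_mut() {
            f(state).await;
        }
    }
}

#[tokio::main]
async fn main() {
    let mut emitter = Emitter { handler: None };
    emitter.on_state_change(Box::new(|state: u8| {
        Box::pin(async move {
            println!("state changed to {state}");
        })
    }));
    emitter.fire(1).await;
    emitter.fire(2).await;
}
```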
-pub struct RTCPeerConnection { - stats_id: String, - idp_login_url: Option, - - configuration: Mutex, - - interceptor_rtcp_writer: Arc, - - interceptor: Arc, - - pub(crate) internal: Arc, -} - -impl std::fmt::Debug for RTCPeerConnection { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("RTCPeerConnection") - .field("stats_id", &self.stats_id) - .field("idp_login_url", &self.idp_login_url) - .field("signaling_state", &self.signaling_state()) - .field("ice_connection_state", &self.ice_connection_state()) - .finish() - } -} - -impl std::fmt::Display for RTCPeerConnection { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "(RTCPeerConnection {})", self.stats_id) - } -} - -impl RTCPeerConnection { - /// creates a PeerConnection with the default codecs and - /// interceptors. See register_default_codecs and register_default_interceptors. - /// - /// If you wish to customize the set of available codecs or the set of - /// active interceptors, create a MediaEngine and call api.new_peer_connection - /// instead of this function. - pub(crate) async fn new(api: &API, mut configuration: RTCConfiguration) -> Result { - RTCPeerConnection::init_configuration(&mut configuration)?; - - let (interceptor, stats_interceptor): (Arc, _) = { - let mut chain = api.interceptor_registry.build_chain("")?; - let stats_interceptor = stats::make_stats_interceptor(""); - chain.add(stats_interceptor.clone()); - - (Arc::new(chain), stats_interceptor) - }; - - let weak_interceptor = Arc::downgrade(&interceptor); - let (internal, configuration) = - PeerConnectionInternal::new(api, weak_interceptor, stats_interceptor, configuration) - .await?; - let internal_rtcp_writer = Arc::clone(&internal) as Arc; - let interceptor_rtcp_writer = interceptor.bind_rtcp_writer(internal_rtcp_writer).await; - - // (Step #2) - // Some variables defined explicitly despite their implicit zero values to - // allow better readability to understand what is happening. - Ok(RTCPeerConnection { - stats_id: format!( - "PeerConnection-{}", - SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_nanos() - ), - interceptor, - interceptor_rtcp_writer, - internal, - configuration: Mutex::new(configuration), - idp_login_url: None, - }) - } - - /// init_configuration defines validation of the specified Configuration and - /// its assignment to the internal configuration variable. This function differs - /// from its set_configuration counterpart because most of the checks do not - /// include verification statements related to the existing state. Thus the - /// function describes only minor verification of some the struct variables. 
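The constructor above derives its stats id from the current UNIX time in nanoseconds ("PeerConnection-{nanos}"), the same scheme `gen_stats_id` uses for certificates earlier in this diff. A trivial sketch of that id generation:

```rust
use std::time::{SystemTime, UNIX_EPOCH};

fn stats_id(prefix: &str) -> String {
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock before UNIX epoch")
        .as_nanos();
    format!("{prefix}-{nanos}")
}

fn main() {
    let id = stats_id("PeerConnection");
    assert!(id.starts_with("PeerConnection-"));
    println!("{id}");
}
```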
- fn init_configuration(configuration: &mut RTCConfiguration) -> Result<()> { - let sanitized_ice_servers = configuration.get_ice_servers(); - if !sanitized_ice_servers.is_empty() { - for server in &sanitized_ice_servers { - server.validate()?; - } - } - - // (step #3) - if !configuration.certificates.is_empty() { - let now = SystemTime::now(); - for cert in &configuration.certificates { - cert.expires - .duration_since(now) - .map_err(|_| Error::ErrCertificateExpired)?; - } - } else { - let kp = KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256)?; - let cert = RTCCertificate::from_key_pair(kp)?; - configuration.certificates = vec![cert]; - }; - - Ok(()) - } - - /// on_signaling_state_change sets an event handler which is invoked when the - /// peer connection's signaling state changes - pub fn on_signaling_state_change(&self, f: OnSignalingStateChangeHdlrFn) { - self.internal - .on_signaling_state_change_handler - .store(Some(Arc::new(Mutex::new(f)))) - } - - async fn do_signaling_state_change(&self, new_state: RTCSignalingState) { - log::info!("signaling state changed to {}", new_state); - if let Some(handler) = &*self.internal.on_signaling_state_change_handler.load() { - let mut f = handler.lock().await; - f(new_state).await; - } - } - - /// on_data_channel sets an event handler which is invoked when a data - /// channel message arrives from a remote peer. - pub fn on_data_channel(&self, f: OnDataChannelHdlrFn) { - self.internal - .on_data_channel_handler - .store(Some(Arc::new(Mutex::new(f)))); - } - - /// on_negotiation_needed sets an event handler which is invoked when - /// a change has occurred which requires session negotiation - pub fn on_negotiation_needed(&self, f: OnNegotiationNeededHdlrFn) { - self.internal - .on_negotiation_needed_handler - .store(Some(Arc::new(Mutex::new(f)))); - } - - fn do_negotiation_needed_inner(params: &NegotiationNeededParams) -> bool { - // https://w3c.github.io/webrtc-pc/#updating-the-negotiation-needed-flag - // non-canon step 1 - let state: NegotiationNeededState = params - .negotiation_needed_state - .load(Ordering::SeqCst) - .into(); - if state == NegotiationNeededState::Run { - params - .negotiation_needed_state - .store(NegotiationNeededState::Queue as u8, Ordering::SeqCst); - false - } else if state == NegotiationNeededState::Queue { - false - } else { - params - .negotiation_needed_state - .store(NegotiationNeededState::Run as u8, Ordering::SeqCst); - true - } - } - /// do_negotiation_needed enqueues negotiation_needed_op if necessary - /// caller of this method should hold `pc.mu` lock - async fn do_negotiation_needed(params: NegotiationNeededParams) { - if !RTCPeerConnection::do_negotiation_needed_inner(¶ms) { - return; - } - - let params2 = params.clone(); - let _ = params - .ops - .enqueue(Operation::new( - move || { - let params3 = params2.clone(); - Box::pin(async move { RTCPeerConnection::negotiation_needed_op(params3).await }) - }, - "do_negotiation_needed", - )) - .await; - } - - async fn after_negotiation_needed_op(params: NegotiationNeededParams) -> bool { - let old_negotiation_needed_state = params.negotiation_needed_state.load(Ordering::SeqCst); - - params - .negotiation_needed_state - .store(NegotiationNeededState::Empty as u8, Ordering::SeqCst); - - if old_negotiation_needed_state == NegotiationNeededState::Queue as u8 { - RTCPeerConnection::do_negotiation_needed_inner(¶ms) - } else { - false - } - } - - async fn negotiation_needed_op(params: NegotiationNeededParams) -> bool { - // Don't run NegotiatedNeeded checks if 
on_negotiation_needed is not set - let handler = &*params.on_negotiation_needed_handler.load(); - if handler.is_none() { - return false; - } - - // https://www.w3.org/TR/webrtc/#updating-the-negotiation-needed-flag - // Step 2.1 - if params.is_closed.load(Ordering::SeqCst) { - return false; - } - // non-canon step 2.2 - if !params.ops.is_empty().await { - //enqueue negotiation_needed_op again by return true - return true; - } - - // non-canon, run again if there was a request - // starting defer(after_do_negotiation_needed(params).await); - - // Step 2.3 - if params.signaling_state.load(Ordering::SeqCst) != RTCSignalingState::Stable as u8 { - return RTCPeerConnection::after_negotiation_needed_op(params).await; - } - - // Step 2.4 - if !RTCPeerConnection::check_negotiation_needed(¶ms.check_negotiation_needed_params) - .await - { - params.is_negotiation_needed.store(false, Ordering::SeqCst); - return RTCPeerConnection::after_negotiation_needed_op(params).await; - } - - // Step 2.5 - if params.is_negotiation_needed.load(Ordering::SeqCst) { - return RTCPeerConnection::after_negotiation_needed_op(params).await; - } - - // Step 2.6 - params.is_negotiation_needed.store(true, Ordering::SeqCst); - - // Step 2.7 - if let Some(handler) = handler { - let mut f = handler.lock().await; - f().await; - } - - RTCPeerConnection::after_negotiation_needed_op(params).await - } - - async fn check_negotiation_needed(params: &CheckNegotiationNeededParams) -> bool { - // To check if negotiation is needed for connection, perform the following checks: - // Skip 1, 2 steps - // Step 3 - let current_local_description = { - let current_local_description = params.current_local_description.lock().await; - current_local_description.clone() - }; - let current_remote_description = { - let current_remote_description = params.current_remote_description.lock().await; - current_remote_description.clone() - }; - - if let Some(local_desc) = ¤t_local_description { - let len_data_channel = { - let data_channels = params.sctp_transport.data_channels.lock().await; - data_channels.len() - }; - - if len_data_channel != 0 && have_data_channel(local_desc).is_none() { - return true; - } - - let transceivers = params.rtp_transceivers.lock().await; - for t in &*transceivers { - // https://www.w3.org/TR/webrtc/#dfn-update-the-negotiation-needed-flag - // Step 5.1 - // if t.stopping && !t.stopped { - // return true - // } - let mid = t.mid(); - let m = mid - .as_ref() - .and_then(|mid| get_by_mid(mid.as_str(), local_desc)); - // Step 5.2 - if !t.stopped.load(Ordering::SeqCst) { - if m.is_none() { - return true; - } - - if let Some(m) = m { - // Step 5.3.1 - if t.direction().has_send() { - let dmsid = match m.attribute(ATTR_KEY_MSID).and_then(|o| o) { - Some(m) => m, - None => return true, // doesn't contain a single a=msid line - }; - - let sender = t.sender().await; - // (...)or the number of MSIDs from the a=msid lines in this m= section, - // or the MSID values themselves, differ from what is in - // transceiver.sender.[[AssociatedMediaStreamIds]], return true. - - // TODO: This check should be robuster by storing all streams in the - // local description so we can compare all of them. For no we only - // consider the first one. 
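`do_negotiation_needed_inner` and `after_negotiation_needed_op` above implement a small Empty/Queue/Run debounce: a trigger that arrives while a check is running is queued, repeated triggers collapse into that single queue entry, and finishing a run converts a queued trigger into exactly one re-run. A std-only sketch of that state machine, with free functions standing in for the methods:

```rust
use std::sync::atomic::{AtomicU8, Ordering};

const EMPTY: u8 = 0;
const QUEUE: u8 = 1;
const RUN: u8 = 2;

/// Returns true when the caller should actually run the negotiation check now.
fn should_run(state: &AtomicU8) -> bool {
    match state.load(Ordering::SeqCst) {
        RUN => {
            // A check is in flight; remember that another one was requested.
            state.store(QUEUE, Ordering::SeqCst);
            false
        }
        QUEUE => false, // already queued, nothing more to record
        _ => {
            state.store(RUN, Ordering::SeqCst);
            true
        }
    }
}

/// Called when a check finishes; returns true if a queued trigger needs a re-run.
fn after_run(state: &AtomicU8) -> bool {
    let old = state.swap(EMPTY, Ordering::SeqCst);
    if old == QUEUE {
        should_run(state)
    } else {
        false
    }
}

fn main() {
    let state = AtomicU8::new(EMPTY);
    assert!(should_run(&state));  // first trigger runs immediately
    assert!(!should_run(&state)); // trigger arriving mid-run is queued
    assert!(!should_run(&state)); // further triggers collapse into that queue entry
    assert!(after_run(&state));   // finishing the run schedules exactly one re-run
    assert!(!after_run(&state));  // the re-run finished with nothing queued
}
```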
- - let stream_ids = sender.associated_media_stream_ids(); - // Different number of lines, 1 vs 0 - if stream_ids.is_empty() { - return true; - } - - // different stream id - if dmsid.split_whitespace().next() != Some(&stream_ids[0]) { - return true; - } - } - match local_desc.sdp_type { - RTCSdpType::Offer => { - // Step 5.3.2 - if let Some(remote_desc) = ¤t_remote_description { - if let Some(rm) = t - .mid() - .and_then(|mid| get_by_mid(mid.as_str(), remote_desc)) - { - if get_peer_direction(m) != t.direction() - && get_peer_direction(rm) != t.direction().reverse() - { - return true; - } - } else { - return true; - } - } - } - RTCSdpType::Answer => { - let remote_desc = match ¤t_remote_description { - Some(d) => d, - None => return true, - }; - let offered_direction = match t - .mid() - .and_then(|mid| get_by_mid(mid.as_str(), remote_desc)) - { - Some(d) => { - let dir = get_peer_direction(d); - if dir == RTCRtpTransceiverDirection::Unspecified { - RTCRtpTransceiverDirection::Inactive - } else { - dir - } - } - None => RTCRtpTransceiverDirection::Inactive, - }; - - let current_direction = get_peer_direction(m); - // Step 5.3.3 - if current_direction - != t.direction().intersect(offered_direction.reverse()) - { - return true; - } - } - _ => {} - }; - } - } - // Step 5.4 - if t.stopped.load(Ordering::SeqCst) { - let search_mid = match t.mid() { - Some(mid) => mid, - None => return false, - }; - - if let Some(remote_desc) = &*params.current_remote_description.lock().await { - return get_by_mid(search_mid.as_str(), local_desc).is_some() - || get_by_mid(search_mid.as_str(), remote_desc).is_some(); - } - } - } - // Step 6 - false - } else { - true - } - } - - /// on_ice_candidate sets an event handler which is invoked when a new ICE - /// candidate is found. - /// Take note that the handler is gonna be called with a nil pointer when - /// gathering is finished. - pub fn on_ice_candidate(&self, f: OnLocalCandidateHdlrFn) { - self.internal.ice_gatherer.on_local_candidate(f) - } - - /// on_ice_gathering_state_change sets an event handler which is invoked when the - /// ICE candidate gathering state has changed. - pub fn on_ice_gathering_state_change(&self, f: OnICEGathererStateChangeHdlrFn) { - self.internal.ice_gatherer.on_state_change(f) - } - - /// on_track sets an event handler which is called when remote track - /// arrives from a remote peer. - pub fn on_track(&self, f: OnTrackHdlrFn) { - self.internal - .on_track_handler - .store(Some(Arc::new(Mutex::new(f)))); - } - - fn do_track( - on_track_handler: Arc>>, - track: Arc, - receiver: Arc, - transceiver: Arc, - ) { - log::debug!("got new track: {:?}", track); - - tokio::spawn(async move { - if let Some(handler) = &*on_track_handler.load() { - let mut f = handler.lock().await; - f(track, receiver, transceiver).await; - } else { - log::warn!("on_track unset, unable to handle incoming media streams"); - } - }); - } - - /// on_ice_connection_state_change sets an event handler which is called - /// when an ICE connection state is changed. 
- pub fn on_ice_connection_state_change(&self, f: OnICEConnectionStateChangeHdlrFn) { - self.internal - .on_ice_connection_state_change_handler - .store(Some(Arc::new(Mutex::new(f)))); - } - - async fn do_ice_connection_state_change( - handler: &Arc>>, - ice_connection_state: &Arc, - cs: RTCIceConnectionState, - ) { - ice_connection_state.store(cs as u8, Ordering::SeqCst); - - log::info!("ICE connection state changed: {}", cs); - if let Some(handler) = &*handler.load() { - let mut f = handler.lock().await; - f(cs).await; - } - } - - /// on_peer_connection_state_change sets an event handler which is called - /// when the PeerConnectionState has changed - pub fn on_peer_connection_state_change(&self, f: OnPeerConnectionStateChangeHdlrFn) { - self.internal - .on_peer_connection_state_change_handler - .store(Some(Arc::new(Mutex::new(f)))); - } - - async fn do_peer_connection_state_change( - handler: &Arc>>, - cs: RTCPeerConnectionState, - ) { - if let Some(handler) = &*handler.load() { - let mut f = handler.lock().await; - f(cs).await; - } - } - - // set_configuration updates the configuration of this PeerConnection object. - pub async fn set_configuration(&self, configuration: RTCConfiguration) -> Result<()> { - // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-setconfiguration (step #2) - let mut config_lock = self.configuration.lock().await; - - if self.internal.is_closed.load(Ordering::SeqCst) { - return Err(Error::ErrConnectionClosed); - } - - // https://www.w3.org/TR/webrtc/#set-the-configuration (step #3) - if !configuration.peer_identity.is_empty() { - if configuration.peer_identity != config_lock.peer_identity { - return Err(Error::ErrModifyingPeerIdentity); - } - config_lock.peer_identity = configuration.peer_identity; - } - - // https://www.w3.org/TR/webrtc/#set-the-configuration (step #4) - if !configuration.certificates.is_empty() { - if configuration.certificates.len() != config_lock.certificates.len() { - return Err(Error::ErrModifyingCertificates); - } - - config_lock.certificates = configuration.certificates; - } - - // https://www.w3.org/TR/webrtc/#set-the-configuration (step #5) - - if configuration.bundle_policy != config_lock.bundle_policy { - return Err(Error::ErrModifyingBundlePolicy); - } - config_lock.bundle_policy = configuration.bundle_policy; - - // https://www.w3.org/TR/webrtc/#set-the-configuration (step #6) - if configuration.rtcp_mux_policy != config_lock.rtcp_mux_policy { - return Err(Error::ErrModifyingRTCPMuxPolicy); - } - config_lock.rtcp_mux_policy = configuration.rtcp_mux_policy; - - // https://www.w3.org/TR/webrtc/#set-the-configuration (step #7) - if configuration.ice_candidate_pool_size != 0 { - if config_lock.ice_candidate_pool_size != configuration.ice_candidate_pool_size - && self.local_description().await.is_some() - { - return Err(Error::ErrModifyingICECandidatePoolSize); - } - config_lock.ice_candidate_pool_size = configuration.ice_candidate_pool_size; - } - - // https://www.w3.org/TR/webrtc/#set-the-configuration (step #8) - - config_lock.ice_transport_policy = configuration.ice_transport_policy; - - // https://www.w3.org/TR/webrtc/#set-the-configuration (step #11) - if !configuration.ice_servers.is_empty() { - // https://www.w3.org/TR/webrtc/#set-the-configuration (step #11.3) - for server in &configuration.ice_servers { - server.validate()?; - } - config_lock.ice_servers = configuration.ice_servers - } - Ok(()) - } - - /// get_configuration returns a Configuration object representing the current - /// configuration of this PeerConnection 
object. The returned object is a - /// copy and direct mutation on it will not take affect until set_configuration - /// has been called with Configuration passed as its only argument. - /// - pub async fn get_configuration(&self) -> RTCConfiguration { - let configuration = self.configuration.lock().await; - configuration.clone() - } - - pub fn get_stats_id(&self) -> &str { - self.stats_id.as_str() - } - - /// create_offer starts the PeerConnection and generates the localDescription - /// - pub async fn create_offer( - &self, - options: Option, - ) -> Result { - let use_identity = self.idp_login_url.is_some(); - if use_identity { - return Err(Error::ErrIdentityProviderNotImplemented); - } else if self.internal.is_closed.load(Ordering::SeqCst) { - return Err(Error::ErrConnectionClosed); - } - - if let Some(options) = options { - if options.ice_restart { - self.internal.ice_transport.restart().await?; - } - } - - // This may be necessary to recompute if, for example, createOffer was called when only an - // audio RTCRtpTransceiver was added to connection, but while performing the in-parallel - // steps to create an offer, a video RTCRtpTransceiver was added, requiring additional - // inspection of video system resources. - let mut count = 0; - let mut offer; - - loop { - // We cache current transceivers to ensure they aren't - // mutated during offer generation. We later check if they have - // been mutated and recompute the offer if necessary. - let current_transceivers = { - let rtp_transceivers = self.internal.rtp_transceivers.lock().await; - rtp_transceivers.clone() - }; - - // include unmatched local transceivers - // update the greater mid if the remote description provides a greater one - { - let current_remote_description = - self.internal.current_remote_description.lock().await; - if let Some(d) = &*current_remote_description { - if let Some(parsed) = &d.parsed { - for media in &parsed.media_descriptions { - if let Some(mid) = get_mid_value(media) { - if mid.is_empty() { - continue; - } - let numeric_mid = match mid.parse::() { - Ok(n) => n, - Err(_) => continue, - }; - if numeric_mid > self.internal.greater_mid.load(Ordering::SeqCst) { - self.internal - .greater_mid - .store(numeric_mid, Ordering::SeqCst); - } - } - } - } - } - } - for t in ¤t_transceivers { - if t.mid().is_some() { - continue; - } - - if let Some(gen) = &self.internal.setting_engine.mid_generator { - let current_greatest = self.internal.greater_mid.load(Ordering::SeqCst); - let mid = (gen)(current_greatest); - - // If it's possible to parse the returned mid as numeric, we will update the greater_mid field. - if let Ok(numeric_mid) = mid.parse::() { - if numeric_mid > self.internal.greater_mid.load(Ordering::SeqCst) { - self.internal - .greater_mid - .store(numeric_mid, Ordering::SeqCst); - } - } - - t.set_mid(SmolStr::from(mid))?; - } else { - let greater_mid = self.internal.greater_mid.fetch_add(1, Ordering::SeqCst); - t.set_mid(SmolStr::from(format!("{}", greater_mid + 1)))?; - } - } - - let current_remote_description_is_none = { - let current_remote_description = - self.internal.current_remote_description.lock().await; - current_remote_description.is_none() - }; - - let mut d = if current_remote_description_is_none { - self.internal - .generate_unmatched_sdp(current_transceivers, use_identity) - .await? - } else { - self.internal - .generate_matched_sdp( - current_transceivers, - use_identity, - true, /*includeUnmatched */ - DEFAULT_DTLS_ROLE_OFFER.to_connection_role(), - ) - .await? 
- }; - - { - let mut sdp_origin = self.internal.sdp_origin.lock().await; - update_sdp_origin(&mut sdp_origin, &mut d); - } - let sdp = d.marshal(); - - offer = RTCSessionDescription { - sdp_type: RTCSdpType::Offer, - sdp, - parsed: Some(d), - }; - - // Verify local media hasn't changed during offer - // generation. Recompute if necessary - if !self.internal.has_local_description_changed(&offer).await { - break; - } - count += 1; - if count >= 128 { - return Err(Error::ErrExcessiveRetries); - } - } - - { - let mut last_offer = self.internal.last_offer.lock().await; - last_offer.clone_from(&offer.sdp); - } - Ok(offer) - } - - /// Update the PeerConnectionState given the state of relevant transports - /// - async fn update_connection_state( - on_peer_connection_state_change_handler: &Arc< - ArcSwapOption>, - >, - is_closed: &Arc, - peer_connection_state: &Arc, - ice_connection_state: RTCIceConnectionState, - dtls_transport_state: RTCDtlsTransportState, - ) { - let connection_state = - // The RTCPeerConnection object's [[IsClosed]] slot is true. - if is_closed.load(Ordering::SeqCst) { - RTCPeerConnectionState::Closed - } else if ice_connection_state == RTCIceConnectionState::Failed || dtls_transport_state == RTCDtlsTransportState::Failed { - // Any of the RTCIceTransports or RTCDtlsTransports are in a "failed" state. - RTCPeerConnectionState::Failed - } else if ice_connection_state == RTCIceConnectionState::Disconnected { - // Any of the RTCIceTransports or RTCDtlsTransports are in the "disconnected" - // state and none of them are in the "failed" or "connecting" or "checking" state. - RTCPeerConnectionState::Disconnected - } else if ice_connection_state == RTCIceConnectionState::Connected && dtls_transport_state == RTCDtlsTransportState::Connected { - // All RTCIceTransports and RTCDtlsTransports are in the "connected", "completed" or "closed" - // state and at least one of them is in the "connected" or "completed" state. - RTCPeerConnectionState::Connected - } else if ice_connection_state == RTCIceConnectionState::Checking && dtls_transport_state == RTCDtlsTransportState::Connecting { - // Any of the RTCIceTransports or RTCDtlsTransports are in the "connecting" or - // "checking" state and none of them is in the "failed" state. 
- RTCPeerConnectionState::Connecting - } else { - RTCPeerConnectionState::New - }; - - if peer_connection_state.load(Ordering::SeqCst) == connection_state as u8 { - return; - } - - log::info!("peer connection state changed: {}", connection_state); - peer_connection_state.store(connection_state as u8, Ordering::SeqCst); - - RTCPeerConnection::do_peer_connection_state_change( - on_peer_connection_state_change_handler, - connection_state, - ) - .await; - } - - /// create_answer starts the PeerConnection and generates the localDescription - pub async fn create_answer( - &self, - _options: Option, - ) -> Result { - let use_identity = self.idp_login_url.is_some(); - let remote_desc = self.remote_description().await; - let remote_description: RTCSessionDescription; - if let Some(desc) = remote_desc { - remote_description = desc; - } else { - return Err(Error::ErrNoRemoteDescription); - } - if use_identity { - return Err(Error::ErrIdentityProviderNotImplemented); - } else if self.internal.is_closed.load(Ordering::SeqCst) { - return Err(Error::ErrConnectionClosed); - } else if self.signaling_state() != RTCSignalingState::HaveRemoteOffer - && self.signaling_state() != RTCSignalingState::HaveLocalPranswer - { - return Err(Error::ErrIncorrectSignalingState); - } - - let mut connection_role = self - .internal - .setting_engine - .answering_dtls_role - .to_connection_role(); - if connection_role == ConnectionRole::Unspecified { - connection_role = DEFAULT_DTLS_ROLE_ANSWER.to_connection_role(); - if let Some(parsed) = remote_description.parsed { - if Self::is_lite_set(&parsed) && !self.internal.setting_engine.candidates.ice_lite { - connection_role = DTLSRole::Server.to_connection_role(); - } - } - } - - let local_transceivers = self.get_transceivers().await; - let mut d = self - .internal - .generate_matched_sdp( - local_transceivers, - use_identity, - false, /*includeUnmatched */ - connection_role, - ) - .await?; - - { - let mut sdp_origin = self.internal.sdp_origin.lock().await; - update_sdp_origin(&mut sdp_origin, &mut d); - } - let sdp = d.marshal(); - - let answer = RTCSessionDescription { - sdp_type: RTCSdpType::Answer, - sdp, - parsed: Some(d), - }; - - { - let mut last_answer = self.internal.last_answer.lock().await; - last_answer.clone_from(&answer.sdp); - } - Ok(answer) - } - - // 4.4.1.6 Set the SessionDescription - pub(crate) async fn set_description( - &self, - sd: &RTCSessionDescription, - op: StateChangeOp, - ) -> Result<()> { - if self.internal.is_closed.load(Ordering::SeqCst) { - return Err(Error::ErrConnectionClosed); - } else if sd.sdp_type == RTCSdpType::Unspecified { - return Err(Error::ErrPeerConnSDPTypeInvalidValue); - } - - let next_state = { - let cur = self.signaling_state(); - let new_sdpdoes_not_match_offer = Error::ErrSDPDoesNotMatchOffer; - let new_sdpdoes_not_match_answer = Error::ErrSDPDoesNotMatchAnswer; - - match op { - StateChangeOp::SetLocal => { - match sd.sdp_type { - // stable->SetLocal(offer)->have-local-offer - RTCSdpType::Offer => { - let check = { - let last_offer = self.internal.last_offer.lock().await; - sd.sdp != *last_offer - }; - if check { - Err(new_sdpdoes_not_match_offer) - } else { - let next_state = check_next_signaling_state( - cur, - RTCSignalingState::HaveLocalOffer, - StateChangeOp::SetLocal, - sd.sdp_type, - ); - if next_state.is_ok() { - let mut pending_local_description = - self.internal.pending_local_description.lock().await; - *pending_local_description = Some(sd.clone()); - } - next_state - } - } - // 
have-remote-offer->SetLocal(answer)->stable - // have-local-pranswer->SetLocal(answer)->stable - RTCSdpType::Answer => { - let check = { - let last_answer = self.internal.last_answer.lock().await; - sd.sdp != *last_answer - }; - if check { - Err(new_sdpdoes_not_match_answer) - } else { - let next_state = check_next_signaling_state( - cur, - RTCSignalingState::Stable, - StateChangeOp::SetLocal, - sd.sdp_type, - ); - if next_state.is_ok() { - let pending_remote_description = { - let mut pending_remote_description = - self.internal.pending_remote_description.lock().await; - pending_remote_description.take() - }; - let _pending_local_description = { - let mut pending_local_description = - self.internal.pending_local_description.lock().await; - pending_local_description.take() - }; - - { - let mut current_local_description = - self.internal.current_local_description.lock().await; - *current_local_description = Some(sd.clone()); - } - { - let mut current_remote_description = - self.internal.current_remote_description.lock().await; - *current_remote_description = pending_remote_description; - } - } - next_state - } - } - RTCSdpType::Rollback => { - let next_state = check_next_signaling_state( - cur, - RTCSignalingState::Stable, - StateChangeOp::SetLocal, - sd.sdp_type, - ); - if next_state.is_ok() { - let mut pending_local_description = - self.internal.pending_local_description.lock().await; - *pending_local_description = None; - } - next_state - } - // have-remote-offer->SetLocal(pranswer)->have-local-pranswer - RTCSdpType::Pranswer => { - let check = { - let last_answer = self.internal.last_answer.lock().await; - sd.sdp != *last_answer - }; - if check { - Err(new_sdpdoes_not_match_answer) - } else { - let next_state = check_next_signaling_state( - cur, - RTCSignalingState::HaveLocalPranswer, - StateChangeOp::SetLocal, - sd.sdp_type, - ); - if next_state.is_ok() { - let mut pending_local_description = - self.internal.pending_local_description.lock().await; - *pending_local_description = Some(sd.clone()); - } - next_state - } - } - _ => Err(Error::ErrPeerConnStateChangeInvalid), - } - } - StateChangeOp::SetRemote => { - match sd.sdp_type { - // stable->SetRemote(offer)->have-remote-offer - RTCSdpType::Offer => { - let next_state = check_next_signaling_state( - cur, - RTCSignalingState::HaveRemoteOffer, - StateChangeOp::SetRemote, - sd.sdp_type, - ); - if next_state.is_ok() { - let mut pending_remote_description = - self.internal.pending_remote_description.lock().await; - *pending_remote_description = Some(sd.clone()); - } - next_state - } - // have-local-offer->SetRemote(answer)->stable - // have-remote-pranswer->SetRemote(answer)->stable - RTCSdpType::Answer => { - let next_state = check_next_signaling_state( - cur, - RTCSignalingState::Stable, - StateChangeOp::SetRemote, - sd.sdp_type, - ); - if next_state.is_ok() { - let pending_local_description = { - let mut pending_local_description = - self.internal.pending_local_description.lock().await; - pending_local_description.take() - }; - - let _pending_remote_description = { - let mut pending_remote_description = - self.internal.pending_remote_description.lock().await; - pending_remote_description.take() - }; - - { - let mut current_remote_description = - self.internal.current_remote_description.lock().await; - *current_remote_description = Some(sd.clone()); - } - { - let mut current_local_description = - self.internal.current_local_description.lock().await; - *current_local_description = pending_local_description; - } - } - next_state - } - 
RTCSdpType::Rollback => { - let next_state = check_next_signaling_state( - cur, - RTCSignalingState::Stable, - StateChangeOp::SetRemote, - sd.sdp_type, - ); - if next_state.is_ok() { - let mut pending_remote_description = - self.internal.pending_remote_description.lock().await; - *pending_remote_description = None; - } - next_state - } - // have-local-offer->SetRemote(pranswer)->have-remote-pranswer - RTCSdpType::Pranswer => { - let next_state = check_next_signaling_state( - cur, - RTCSignalingState::HaveRemotePranswer, - StateChangeOp::SetRemote, - sd.sdp_type, - ); - if next_state.is_ok() { - let mut pending_remote_description = - self.internal.pending_remote_description.lock().await; - *pending_remote_description = Some(sd.clone()); - } - next_state - } - _ => Err(Error::ErrPeerConnStateChangeInvalid), - } - } //_ => Err(Error::ErrPeerConnStateChangeUnhandled.into()), - } - }; - - match next_state { - Ok(next_state) => { - self.internal - .signaling_state - .store(next_state as u8, Ordering::SeqCst); - if self.signaling_state() == RTCSignalingState::Stable { - self.internal - .is_negotiation_needed - .store(false, Ordering::SeqCst); - self.internal.trigger_negotiation_needed().await; - } - self.do_signaling_state_change(next_state).await; - Ok(()) - } - Err(err) => Err(err), - } - } - - /// set_local_description sets the SessionDescription of the local peer - pub async fn set_local_description(&self, mut desc: RTCSessionDescription) -> Result<()> { - if self.internal.is_closed.load(Ordering::SeqCst) { - return Err(Error::ErrConnectionClosed); - } - - let have_local_description = { - let current_local_description = self.internal.current_local_description.lock().await; - current_local_description.is_some() - }; - - // JSEP 5.4 - if desc.sdp.is_empty() { - match desc.sdp_type { - RTCSdpType::Answer | RTCSdpType::Pranswer => { - let last_answer = self.internal.last_answer.lock().await; - desc.sdp.clone_from(&last_answer); - } - RTCSdpType::Offer => { - let last_offer = self.internal.last_offer.lock().await; - desc.sdp.clone_from(&last_offer); - } - _ => return Err(Error::ErrPeerConnSDPTypeInvalidValueSetLocalDescription), - } - } - - desc.parsed = Some(desc.unmarshal()?); - self.set_description(&desc, StateChangeOp::SetLocal).await?; - - let we_answer = desc.sdp_type == RTCSdpType::Answer; - let remote_description = self.remote_description().await; - let mut local_transceivers = self.get_transceivers().await; - if we_answer { - if let Some(parsed) = desc.parsed { - // WebRTC Spec 1.0 https://www.w3.org/TR/webrtc/ - // Section 4.4.1.5 - for media in &parsed.media_descriptions { - if media.media_name.media == MEDIA_SECTION_APPLICATION { - continue; - } - - let kind = RTPCodecType::from(media.media_name.media.as_str()); - let direction = get_peer_direction(media); - if kind == RTPCodecType::Unspecified - || direction == RTCRtpTransceiverDirection::Unspecified - { - continue; - } - - let mid_value = match get_mid_value(media) { - Some(mid) if !mid.is_empty() => mid, - _ => continue, - }; - - let t = match find_by_mid(mid_value, &mut local_transceivers).await { - Some(t) => t, - None => continue, - }; - let previous_direction = t.current_direction(); - // 4.9.1.7.3 applying a local answer or pranswer - // Set transceiver.[[CurrentDirection]] and transceiver.[[FiredDirection]] to direction. - - // TODO: Also set FiredDirection here. 
- t.set_current_direction(direction); - t.process_new_current_direction(previous_direction).await?; - } - } - - if let Some(remote_desc) = remote_description { - self.start_rtp_senders().await?; - - let pci = Arc::clone(&self.internal); - let remote_desc = Arc::new(remote_desc); - self.internal - .ops - .enqueue(Operation::new( - move || { - let pc = Arc::clone(&pci); - let rd = Arc::clone(&remote_desc); - Box::pin(async move { - let _ = pc.start_rtp(have_local_description, rd).await; - false - }) - }, - "set_local_description", - )) - .await?; - } - } - - if self.internal.ice_gatherer.state() == RTCIceGathererState::New { - self.internal.ice_gatherer.gather().await - } else { - Ok(()) - } - } - - /// local_description returns PendingLocalDescription if it is not null and - /// otherwise it returns CurrentLocalDescription. This property is used to - /// determine if set_local_description has already been called. - /// - pub async fn local_description(&self) -> Option { - if let Some(pending_local_description) = self.pending_local_description().await { - return Some(pending_local_description); - } - self.current_local_description().await - } - - pub fn is_lite_set(desc: &SessionDescription) -> bool { - for a in &desc.attributes { - if a.key.trim() == ATTR_KEY_ICELITE { - return true; - } - } - false - } - - /// set_remote_description sets the SessionDescription of the remote peer - pub async fn set_remote_description(&self, mut desc: RTCSessionDescription) -> Result<()> { - if self.internal.is_closed.load(Ordering::SeqCst) { - return Err(Error::ErrConnectionClosed); - } - - let is_renegotiation = { - let current_remote_description = self.internal.current_remote_description.lock().await; - current_remote_description.is_some() - }; - - desc.parsed = Some(desc.unmarshal()?); - self.set_description(&desc, StateChangeOp::SetRemote) - .await?; - - if let Some(parsed) = &desc.parsed { - self.internal - .media_engine - .update_from_remote_description(parsed) - .await?; - - let mut local_transceivers = self.get_transceivers().await; - let remote_description = self.remote_description().await; - let we_offer = desc.sdp_type == RTCSdpType::Answer; - - if !we_offer { - if let Some(parsed) = remote_description.as_ref().and_then(|r| r.parsed.as_ref()) { - for media in &parsed.media_descriptions { - let mid_value = match get_mid_value(media) { - Some(m) => { - if m.is_empty() { - return Err(Error::ErrPeerConnRemoteDescriptionWithoutMidValue); - } else { - m - } - } - None => continue, - }; - - if media.media_name.media == MEDIA_SECTION_APPLICATION { - continue; - } - - let kind = RTPCodecType::from(media.media_name.media.as_str()); - let direction = get_peer_direction(media); - if kind == RTPCodecType::Unspecified - || direction == RTCRtpTransceiverDirection::Unspecified - { - continue; - } - - let t = if let Some(t) = - find_by_mid(mid_value, &mut local_transceivers).await - { - Some(t) - } else { - satisfy_type_and_direction(kind, direction, &mut local_transceivers) - .await - }; - - if let Some(t) = t { - if t.mid().is_none() { - t.set_mid(SmolStr::from(mid_value))?; - } - } else { - let local_direction = - if direction == RTCRtpTransceiverDirection::Recvonly { - RTCRtpTransceiverDirection::Sendonly - } else { - RTCRtpTransceiverDirection::Recvonly - }; - - let receive_mtu = self.internal.setting_engine.get_receive_mtu(); - - let receiver = Arc::new(RTCRtpReceiver::new( - receive_mtu, - kind, - Arc::clone(&self.internal.dtls_transport), - Arc::clone(&self.internal.media_engine), - 
Arc::clone(&self.interceptor), - )); - - let sender = Arc::new( - RTCRtpSender::new( - receive_mtu, - None, - kind, - Arc::clone(&self.internal.dtls_transport), - Arc::clone(&self.internal.media_engine), - Arc::clone(&self.interceptor), - false, - ) - .await, - ); - - let t = RTCRtpTransceiver::new( - receiver, - sender, - local_direction, - kind, - vec![], - Arc::clone(&self.internal.media_engine), - Some(Box::new(self.internal.make_negotiation_needed_trigger())), - ) - .await; - - self.internal.add_rtp_transceiver(Arc::clone(&t)).await; - - if t.mid().is_none() { - t.set_mid(SmolStr::from(mid_value))?; - } - } - } - } - } - - if we_offer { - // WebRTC Spec 1.0 https://www.w3.org/TR/webrtc/ - // 4.5.9.2 - // This is an answer from the remote. - if let Some(parsed) = remote_description.as_ref().and_then(|r| r.parsed.as_ref()) { - for media in &parsed.media_descriptions { - let mid_value = match get_mid_value(media) { - Some(m) => { - if m.is_empty() { - return Err(Error::ErrPeerConnRemoteDescriptionWithoutMidValue); - } else { - m - } - } - None => continue, - }; - - if media.media_name.media == MEDIA_SECTION_APPLICATION { - continue; - } - let kind = RTPCodecType::from(media.media_name.media.as_str()); - let direction = get_peer_direction(media); - if kind == RTPCodecType::Unspecified - || direction == RTCRtpTransceiverDirection::Unspecified - { - continue; - } - - if let Some(t) = find_by_mid(mid_value, &mut local_transceivers).await { - let previous_direction = t.current_direction(); - - // 4.5.9.2.9 - // Let direction be an RTCRtpTransceiverDirection value representing the direction - // from the media description, but with the send and receive directions reversed to - // represent this peer's point of view. If the media description is rejected, - // set direction to "inactive". - let reversed_direction = direction.reverse(); - - // 4.5.9.2.13.2 - // Set transceiver.[[CurrentDirection]] and transceiver.[[Direction]]s to direction. - t.set_current_direction(reversed_direction); - // TODO: According to the specification we should set - // transceiver.[[Direction]] here, however libWebrtc doesn't do this. - // NOTE: After raising this it seems like the specification might - // change to remove the setting of transceiver.[[Direction]]. 
- // See https://github.com/w3c/webrtc-pc/issues/2751#issuecomment-1185901962 - // t.set_direction_internal(reversed_direction); - t.process_new_current_direction(previous_direction).await?; - } - } - } - } - - let (remote_ufrag, remote_pwd, candidates) = extract_ice_details(parsed).await?; - - if is_renegotiation - && self - .internal - .ice_transport - .have_remote_credentials_change(&remote_ufrag, &remote_pwd) - .await - { - // An ICE Restart only happens implicitly for a set_remote_description of type offer - if !we_offer { - self.internal.ice_transport.restart().await?; - } - - self.internal - .ice_transport - .set_remote_credentials(remote_ufrag.clone(), remote_pwd.clone()) - .await?; - } - - for candidate in candidates { - self.internal - .ice_transport - .add_remote_candidate(Some(candidate)) - .await?; - } - - if is_renegotiation { - if we_offer { - self.start_rtp_senders().await?; - - let pci = Arc::clone(&self.internal); - let remote_desc = Arc::new(desc); - self.internal - .ops - .enqueue(Operation::new( - move || { - let pc = Arc::clone(&pci); - let rd = Arc::clone(&remote_desc); - Box::pin(async move { - let _ = pc.start_rtp(true, rd).await; - false - }) - }, - "set_remote_description renegotiation", - )) - .await?; - } - return Ok(()); - } - - let remote_is_lite = Self::is_lite_set(parsed); - - let (fingerprint, fingerprint_hash) = extract_fingerprint(parsed)?; - - // If one of the agents is lite and the other one is not, the lite agent must be the controlling agent. - // If both or neither agents are lite the offering agent is controlling. - // RFC 8445 S6.1.1 - let ice_role = if (we_offer - && remote_is_lite == self.internal.setting_engine.candidates.ice_lite) - || (remote_is_lite && !self.internal.setting_engine.candidates.ice_lite) - { - RTCIceRole::Controlling - } else { - RTCIceRole::Controlled - }; - - // Start the networking in a new routine since it will block until - // the connection is actually established. - if we_offer { - self.start_rtp_senders().await?; - } - - //log::trace!("start_transports: parsed={:?}", parsed); - - let pci = Arc::clone(&self.internal); - let dtls_role = DTLSRole::from(parsed); - let remote_desc = Arc::new(desc); - self.internal - .ops - .enqueue(Operation::new( - move || { - let pc = Arc::clone(&pci); - let rd = Arc::clone(&remote_desc); - let ru = remote_ufrag.clone(); - let rp = remote_pwd.clone(); - let fp = fingerprint.clone(); - let fp_hash = fingerprint_hash.clone(); - Box::pin(async move { - log::trace!( - "start_transports: ice_role={}, dtls_role={}", - ice_role, - dtls_role, - ); - pc.start_transports(ice_role, dtls_role, ru, rp, fp, fp_hash) - .await; - - if we_offer { - let _ = pc.start_rtp(false, rd).await; - } - false - }) - }, - "set_remote_description", - )) - .await?; - } - - Ok(()) - } - - /// start_rtp_senders starts all outbound RTP streams - pub(crate) async fn start_rtp_senders(&self) -> Result<()> { - let current_transceivers = self.internal.rtp_transceivers.lock().await; - for transceiver in &*current_transceivers { - let sender = transceiver.sender().await; - if !sender.track_encodings.lock().await.is_empty() - && sender.is_negotiated() - && !sender.has_sent() - { - sender.send(&sender.get_parameters().await).await?; - } - } - - Ok(()) - } - - /// remote_description returns pending_remote_description if it is not null and - /// otherwise it returns current_remote_description. This property is used to - /// determine if setRemoteDescription has already been called. 
- /// - pub async fn remote_description(&self) -> Option { - self.internal.remote_description().await - } - - /// add_ice_candidate accepts an ICE candidate string and adds it - /// to the existing set of candidates. - pub async fn add_ice_candidate(&self, candidate: RTCIceCandidateInit) -> Result<()> { - if self.remote_description().await.is_none() { - return Err(Error::ErrNoRemoteDescription); - } - - let candidate_value = match candidate.candidate.strip_prefix("candidate:") { - Some(s) => s, - None => candidate.candidate.as_str(), - }; - - let ice_candidate = if !candidate_value.is_empty() { - let candidate: Arc = - Arc::new(unmarshal_candidate(candidate_value)?); - - Some(RTCIceCandidate::from(&candidate)) - } else { - None - }; - - self.internal - .ice_transport - .add_remote_candidate(ice_candidate) - .await - } - - /// ice_connection_state returns the ICE connection state of the - /// PeerConnection instance. - pub fn ice_connection_state(&self) -> RTCIceConnectionState { - self.internal - .ice_connection_state - .load(Ordering::SeqCst) - .into() - } - - /// get_senders returns the RTPSender that are currently attached to this PeerConnection - pub async fn get_senders(&self) -> Vec> { - let mut senders = vec![]; - let rtp_transceivers = self.internal.rtp_transceivers.lock().await; - for transceiver in &*rtp_transceivers { - let sender = transceiver.sender().await; - senders.push(sender); - } - senders - } - - /// get_receivers returns the RTPReceivers that are currently attached to this PeerConnection - pub async fn get_receivers(&self) -> Vec> { - let mut receivers = vec![]; - let rtp_transceivers = self.internal.rtp_transceivers.lock().await; - for transceiver in &*rtp_transceivers { - receivers.push(transceiver.receiver().await); - } - receivers - } - - /// get_transceivers returns the RtpTransceiver that are currently attached to this PeerConnection - pub async fn get_transceivers(&self) -> Vec> { - let rtp_transceivers = self.internal.rtp_transceivers.lock().await; - rtp_transceivers.clone() - } - - /// add_track adds a Track to the PeerConnection - pub async fn add_track( - &self, - track: Arc, - ) -> Result> { - if self.internal.is_closed.load(Ordering::SeqCst) { - return Err(Error::ErrConnectionClosed); - } - - { - let rtp_transceivers = self.internal.rtp_transceivers.lock().await; - for t in &*rtp_transceivers { - if !t.stopped.load(Ordering::SeqCst) - && t.kind == track.kind() - && t.sender() - .await - .initial_track_id() - .is_some_and(|id| id == track.id()) - { - let sender = t.sender().await; - if sender.track().await.is_none() { - if let Err(err) = sender.replace_track(Some(track)).await { - let _ = sender.stop().await; - return Err(err); - } - - t.set_direction_internal(RTCRtpTransceiverDirection::from_send_recv( - true, - t.direction().has_recv(), - )); - - self.internal.trigger_negotiation_needed().await; - return Ok(sender); - } - } - } - } - - let transceiver = self - .internal - .new_transceiver_from_track(RTCRtpTransceiverDirection::Sendrecv, track) - .await?; - self.internal - .add_rtp_transceiver(Arc::clone(&transceiver)) - .await; - - Ok(transceiver.sender().await) - } - - /// remove_track removes a Track from the PeerConnection - pub async fn remove_track(&self, sender: &Arc) -> Result<()> { - if self.internal.is_closed.load(Ordering::SeqCst) { - return Err(Error::ErrConnectionClosed); - } - - let mut transceiver = None; - { - let rtp_transceivers = self.internal.rtp_transceivers.lock().await; - for t in &*rtp_transceivers { - if t.sender().await.id == 
sender.id { - if sender.track().await.is_none() { - return Ok(()); - } - transceiver = Some(t.clone()); - break; - } - } - } - - let t = transceiver.ok_or(Error::ErrSenderNotCreatedByConnection)?; - - // This also happens in `set_sending_track` but we need to make sure we do this - // before we call sender.stop to avoid a race condition when removing tracks and - // generating offers. - t.set_direction_internal(RTCRtpTransceiverDirection::from_send_recv( - false, - t.direction().has_recv(), - )); - // Stop the sender - let sender_result = sender.stop().await; - // This also updates direction - let sending_track_result = t.set_sending_track(None).await; - - if sender_result.is_ok() && sending_track_result.is_ok() { - self.internal.trigger_negotiation_needed().await; - } - Ok(()) - } - - /// add_transceiver_from_kind Create a new RtpTransceiver and adds it to the set of transceivers. - pub async fn add_transceiver_from_kind( - &self, - kind: RTPCodecType, - init: Option, - ) -> Result> { - self.internal.add_transceiver_from_kind(kind, init).await - } - - /// add_transceiver_from_track Create a new RtpTransceiver(SendRecv or SendOnly) and add it to the set of transceivers. - pub async fn add_transceiver_from_track( - &self, - track: Arc, - init: Option, - ) -> Result> { - if self.internal.is_closed.load(Ordering::SeqCst) { - return Err(Error::ErrConnectionClosed); - } - - let direction = init - .map(|init| init.direction) - .unwrap_or(RTCRtpTransceiverDirection::Sendrecv); - - let t = self - .internal - .new_transceiver_from_track(direction, track) - .await?; - - self.internal.add_rtp_transceiver(Arc::clone(&t)).await; - - Ok(t) - } - - /// create_data_channel creates a new DataChannel object with the given label - /// and optional DataChannelInit used to configure properties of the - /// underlying channel such as data reliability. - pub async fn create_data_channel( - &self, - label: &str, - options: Option, - ) -> Result> { - // https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #2) - if self.internal.is_closed.load(Ordering::SeqCst) { - return Err(Error::ErrConnectionClosed); - } - - let mut params = DataChannelParameters { - label: label.to_owned(), - ordered: true, - ..Default::default() - }; - - // https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #19) - if let Some(options) = options { - // Ordered indicates if data is allowed to be delivered out of order. The - // default value of true, guarantees that data will be delivered in order. 
- // https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #9) - if let Some(ordered) = options.ordered { - params.ordered = ordered; - } - - // https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #7) - if let Some(max_packet_life_time) = options.max_packet_life_time { - params.max_packet_life_time = max_packet_life_time; - } - - // https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #8) - if let Some(max_retransmits) = options.max_retransmits { - params.max_retransmits = max_retransmits; - } - - // https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #10) - if let Some(protocol) = options.protocol { - params.protocol = protocol; - } - - // https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #11) - if params.protocol.len() > 65535 { - return Err(Error::ErrProtocolTooLarge); - } - - // https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #12) - params.negotiated = options.negotiated; - } - - let d = Arc::new(RTCDataChannel::new( - params, - Arc::clone(&self.internal.setting_engine), - )); - - // https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #16) - if d.max_packet_lifetime != 0 && d.max_retransmits != 0 { - return Err(Error::ErrRetransmitsOrPacketLifeTime); - } - - { - let mut data_channels = self.internal.sctp_transport.data_channels.lock().await; - data_channels.push(Arc::clone(&d)); - } - self.internal - .sctp_transport - .data_channels_requested - .fetch_add(1, Ordering::SeqCst); - - // If SCTP already connected open all the channels - if self.internal.sctp_transport.state() == RTCSctpTransportState::Connected { - d.open(Arc::clone(&self.internal.sctp_transport)).await?; - } - - self.internal.trigger_negotiation_needed().await; - - Ok(d) - } - - /// set_identity_provider is used to configure an identity provider to generate identity assertions - pub fn set_identity_provider(&self, _provider: &str) -> Result<()> { - Err(Error::ErrPeerConnSetIdentityProviderNotImplemented) - } - - /// write_rtcp sends a user provided RTCP packet to the connected peer. If no peer is connected the - /// packet is discarded. It also runs any configured interceptors. - pub async fn write_rtcp( - &self, - pkts: &[Box], - ) -> Result { - let a = Attributes::new(); - Ok(self.interceptor_rtcp_writer.write(pkts, &a).await?) - } - - /// close ends the PeerConnection - pub async fn close(&self) -> Result<()> { - // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #1) - if self.internal.is_closed.load(Ordering::SeqCst) { - return Ok(()); - } - - // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #2) - self.internal.is_closed.store(true, Ordering::SeqCst); - - // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #3) - self.internal - .signaling_state - .store(RTCSignalingState::Closed as u8, Ordering::SeqCst); - - // Try closing everything and collect the errors - // Shutdown strategy: - // 1. All Conn close by closing their underlying Conn. - // 2. A Mux stops this chain. It won't close the underlying - // Conn if one of the endpoints is closed down. To - // continue the chain the Mux has to be closed. 
-        let mut close_errs = vec![];
-
-        if let Err(err) = self.interceptor.close().await {
-            close_errs.push(Error::new(format!("interceptor: {err}")));
-        }
-
-        // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #4)
-        {
-            let mut rtp_transceivers = self.internal.rtp_transceivers.lock().await;
-            for t in &*rtp_transceivers {
-                if let Err(err) = t.stop().await {
-                    close_errs.push(Error::new(format!("rtp_transceivers: {err}")));
-                }
-            }
-            rtp_transceivers.clear();
-        }
-
-        // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #5)
-        {
-            let mut data_channels = self.internal.sctp_transport.data_channels.lock().await;
-            for d in &*data_channels {
-                if let Err(err) = d.close().await {
-                    close_errs.push(Error::new(format!("data_channels: {err}")));
-                }
-            }
-            data_channels.clear();
-        }
-
-        // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #6)
-        if let Err(err) = self.internal.sctp_transport.stop().await {
-            close_errs.push(Error::new(format!("sctp_transport: {err}")));
-        }
-
-        // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #7)
-        if let Err(err) = self.internal.dtls_transport.stop().await {
-            close_errs.push(Error::new(format!("dtls_transport: {err}")));
-        }
-
-        // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #8, #9, #10)
-        if let Err(err) = self.internal.ice_transport.stop().await {
-            close_errs.push(Error::new(format!("ice_transport: {err}")));
-        }
-
-        // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #11)
-        RTCPeerConnection::update_connection_state(
-            &self.internal.on_peer_connection_state_change_handler,
-            &self.internal.is_closed,
-            &self.internal.peer_connection_state,
-            self.ice_connection_state(),
-            self.internal.dtls_transport.state(),
-        )
-        .await;
-
-        if let Err(err) = self.internal.ops.close().await {
-            close_errs.push(Error::new(format!("ops: {err}")));
-        }
-
-        flatten_errs(close_errs)
-    }
-
-    /// CurrentLocalDescription represents the local description that was
-    /// successfully negotiated the last time the PeerConnection transitioned
-    /// into the stable state plus any local candidates that have been generated
-    /// by the ICEAgent since the offer or answer was created.
-    pub async fn current_local_description(&self) -> Option<RTCSessionDescription> {
-        let local_description = {
-            let current_local_description = self.internal.current_local_description.lock().await;
-            current_local_description.clone()
-        };
-        let ice_gather = Some(&self.internal.ice_gatherer);
-        let ice_gathering_state = self.ice_gathering_state();
-
-        populate_local_candidates(local_description.as_ref(), ice_gather, ice_gathering_state).await
-    }
-
-    /// PendingLocalDescription represents a local description that is in the
-    /// process of being negotiated plus any local candidates that have been
-    /// generated by the ICEAgent since the offer or answer was created. If the
-    /// PeerConnection is in the stable state, the value is null.
-    pub async fn pending_local_description(&self) -> Option<RTCSessionDescription> {
-        let local_description = {
-            let pending_local_description = self.internal.pending_local_description.lock().await;
-            pending_local_description.clone()
-        };
-        let ice_gather = Some(&self.internal.ice_gatherer);
-        let ice_gathering_state = self.ice_gathering_state();
-
-        populate_local_candidates(local_description.as_ref(), ice_gather, ice_gathering_state).await
-    }
-
-    /// current_remote_description represents the last remote description that was
-    /// successfully negotiated the last time the PeerConnection transitioned
-    /// into the stable state plus any remote candidates that have been supplied
-    /// via add_icecandidate() since the offer or answer was created.
-    pub async fn current_remote_description(&self) -> Option<RTCSessionDescription> {
-        let current_remote_description = self.internal.current_remote_description.lock().await;
-        current_remote_description.clone()
-    }
-
-    /// pending_remote_description represents a remote description that is in the
-    /// process of being negotiated, complete with any remote candidates that
-    /// have been supplied via add_icecandidate() since the offer or answer was
-    /// created. If the PeerConnection is in the stable state, the value is
-    /// null.
-    pub async fn pending_remote_description(&self) -> Option<RTCSessionDescription> {
-        let pending_remote_description = self.internal.pending_remote_description.lock().await;
-        pending_remote_description.clone()
-    }
-
-    /// signaling_state attribute returns the signaling state of the
-    /// PeerConnection instance.
-    pub fn signaling_state(&self) -> RTCSignalingState {
-        self.internal.signaling_state.load(Ordering::SeqCst).into()
-    }
-
-    /// icegathering_state attribute returns the ICE gathering state of the
-    /// PeerConnection instance.
-    pub fn ice_gathering_state(&self) -> RTCIceGatheringState {
-        self.internal.ice_gathering_state()
-    }
-
-    /// connection_state attribute returns the connection state of the
-    /// PeerConnection instance.
-    pub fn connection_state(&self) -> RTCPeerConnectionState {
-        self.internal
-            .peer_connection_state
-            .load(Ordering::SeqCst)
-            .into()
-    }
-
-    pub async fn get_stats(&self) -> StatsReport {
-        self.internal
-            .get_stats(self.get_stats_id().to_owned())
-            .await
-            .into()
-    }
-
-    /// sctp returns the SCTPTransport for this PeerConnection
-    ///
-    /// The SCTP transport over which SCTP data is sent and received. If SCTP has not been negotiated, the value is nil.
-    ///
-    pub fn sctp(&self) -> Arc<RTCSctpTransport> {
-        Arc::clone(&self.internal.sctp_transport)
-    }
-
-    /// gathering_complete_promise is a Pion specific helper function that returns a channel that is closed when gathering is complete.
-    /// This function may be helpful in cases where you are unable to trickle your ICE Candidates.
-    ///
-    /// It is better to not use this function, and instead trickle candidates. If you use this function you will see longer connection startup times.
-    /// When the call is connected you will see no impact however.
-    pub async fn gathering_complete_promise(&self) -> mpsc::Receiver<()> {
-        let (gathering_complete_tx, gathering_complete_rx) = mpsc::channel(1);
-
-        // It's possible to miss the GatherComplete event since setGatherCompleteHandler is an atomic operation and the
-        // promise might have been created after the gathering is finished. Therefore, we need to check if the ICE gathering
-        // state has changed to complete so that we don't block the caller forever.
- let done = Arc::new(Mutex::new(Some(gathering_complete_tx))); - let done2 = Arc::clone(&done); - self.internal.set_gather_complete_handler(Box::new(move || { - log::trace!("setGatherCompleteHandler"); - let done3 = Arc::clone(&done2); - Box::pin(async move { - let mut d = done3.lock().await; - d.take(); - }) - })); - - if self.ice_gathering_state() == RTCIceGatheringState::Complete { - log::trace!("ICEGatheringState::Complete"); - let mut d = done.lock().await; - d.take(); - } - - gathering_complete_rx - } - - /// Returns the internal [`RTCDtlsTransport`]. - pub fn dtls_transport(&self) -> Arc { - Arc::clone(&self.internal.dtls_transport) - } - - /// Adds the specified [`RTCRtpTransceiver`] to this [`RTCPeerConnection`]. - pub async fn add_transceiver(&self, t: Arc) { - self.internal.add_rtp_transceiver(t).await - } -} diff --git a/webrtc/src/peer_connection/offer_answer_options.rs b/webrtc/src/peer_connection/offer_answer_options.rs deleted file mode 100644 index a7be6490b..000000000 --- a/webrtc/src/peer_connection/offer_answer_options.rs +++ /dev/null @@ -1,22 +0,0 @@ -/// AnswerOptions structure describes the options used to control the answer -/// creation process. -#[derive(Default, Debug, PartialEq, Eq, Copy, Clone)] -pub struct RTCAnswerOptions { - /// voice_activity_detection allows the application to provide information - /// about whether it wishes voice detection feature to be enabled or disabled. - pub voice_activity_detection: bool, -} - -/// OfferOptions structure describes the options used to control the offer -/// creation process -#[derive(Default, Debug, PartialEq, Eq, Copy, Clone)] -pub struct RTCOfferOptions { - /// voice_activity_detection allows the application to provide information - /// about whether it wishes voice detection feature to be enabled or disabled. - pub voice_activity_detection: bool, - - /// ice_restart forces the underlying ice gathering process to be restarted. - /// When this value is true, the generated description will have ICE - /// credentials that are different from the current credentials - pub ice_restart: bool, -} diff --git a/webrtc/src/peer_connection/operation/mod.rs b/webrtc/src/peer_connection/operation/mod.rs deleted file mode 100644 index f8556e3ab..000000000 --- a/webrtc/src/peer_connection/operation/mod.rs +++ /dev/null @@ -1,140 +0,0 @@ -#[cfg(test)] -mod operation_test; - -use std::fmt; -use std::future::Future; -use std::pin::Pin; -use std::sync::atomic::Ordering; -use std::sync::Arc; - -use portable_atomic::AtomicUsize; -use tokio::sync::mpsc; -use waitgroup::WaitGroup; - -use crate::error::Result; - -/// Operation is a function -pub struct Operation( - pub Box Pin + Send + 'static>>) + Send + Sync>, - pub &'static str, -); - -impl Operation { - pub(crate) fn new( - op: impl FnMut() -> Pin + Send + 'static>> + Send + Sync + 'static, - description: &'static str, - ) -> Self { - Self(Box::new(op), description) - } -} - -impl fmt::Debug for Operation { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("Operation") - .field(&"_") - .field(&self.1) - .finish() - } -} - -/// Operations is a task executor. 
-#[derive(Default)] -pub(crate) struct Operations { - length: Arc, - ops_tx: Option>>, - close_tx: Option>, -} - -impl Operations { - pub(crate) fn new() -> Self { - let length = Arc::new(AtomicUsize::new(0)); - let (ops_tx, ops_rx) = mpsc::unbounded_channel(); - let (close_tx, close_rx) = mpsc::channel(1); - let l = Arc::clone(&length); - let ops_tx = Arc::new(ops_tx); - let ops_tx2 = Arc::clone(&ops_tx); - tokio::spawn(async move { - Operations::start(l, ops_tx, ops_rx, close_rx).await; - }); - - Operations { - length, - ops_tx: Some(ops_tx2), - close_tx: Some(close_tx), - } - } - - /// enqueue adds a new action to be executed. If there are no actions scheduled, - /// the execution will start immediately in a new goroutine. - pub(crate) async fn enqueue(&self, op: Operation) -> Result<()> { - if let Some(ops_tx) = &self.ops_tx { - return Operations::enqueue_inner(op, ops_tx, &self.length); - } - - Ok(()) - } - - fn enqueue_inner( - op: Operation, - ops_tx: &Arc>, - length: &Arc, - ) -> Result<()> { - length.fetch_add(1, Ordering::SeqCst); - ops_tx.send(op)?; - - Ok(()) - } - - /// is_empty checks if there are tasks in the queue - pub(crate) async fn is_empty(&self) -> bool { - self.length.load(Ordering::SeqCst) == 0 - } - - /// Done blocks until all currently enqueued operations are finished executing. - /// For more complex synchronization, use Enqueue directly. - pub(crate) async fn done(&self) { - let wg = WaitGroup::new(); - let mut w = Some(wg.worker()); - let _ = self - .enqueue(Operation::new( - move || { - let _d = w.take(); - Box::pin(async { false }) - }, - "Operation::done", - )) - .await; - wg.wait().await; - } - - pub(crate) async fn start( - length: Arc, - ops_tx: Arc>, - mut ops_rx: mpsc::UnboundedReceiver, - mut close_rx: mpsc::Receiver<()>, - ) { - loop { - tokio::select! 
{ - _ = close_rx.recv() => { - break; - } - result = ops_rx.recv() => { - if let Some(mut f) = result { - length.fetch_sub(1, Ordering::SeqCst); - if f.0().await { - // Requeue this operation - let _ = Operations::enqueue_inner(f, &ops_tx, &length); - } - } - } - } - } - } - - pub(crate) async fn close(&self) -> Result<()> { - if let Some(close_tx) = &self.close_tx { - close_tx.send(()).await?; - } - Ok(()) - } -} diff --git a/webrtc/src/peer_connection/operation/operation_test.rs b/webrtc/src/peer_connection/operation/operation_test.rs deleted file mode 100644 index 0ecc344ca..000000000 --- a/webrtc/src/peer_connection/operation/operation_test.rs +++ /dev/null @@ -1,47 +0,0 @@ -use tokio::sync::Mutex; - -use super::*; -use crate::error::Result; - -#[tokio::test] -async fn test_operations_enqueue() -> Result<()> { - let ops = Operations::new(); - for _ in 0..100 { - let results = Arc::new(Mutex::new(vec![0; 16])); - for k in 0..16 { - let r = Arc::clone(&results); - ops.enqueue(Operation::new( - move || { - let r2 = Arc::clone(&r); - Box::pin(async move { - let mut r3 = r2.lock().await; - r3[k] += k * k; - r3[k] == 225 - }) - }, - "test_operations_enqueue", - )) - .await?; - } - - ops.done().await; - let expected = vec![ - 0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144, 169, 196, 450, - ]; - { - let r = results.lock().await; - assert_eq!(r.len(), expected.len()); - assert_eq!(&*r, &expected); - } - } - - Ok(()) -} - -#[tokio::test] -async fn test_operations_done() -> Result<()> { - let ops = Operations::new(); - ops.done().await; - - Ok(()) -} diff --git a/webrtc/src/peer_connection/peer_connection_internal.rs b/webrtc/src/peer_connection/peer_connection_internal.rs deleted file mode 100644 index 50c2d8d89..000000000 --- a/webrtc/src/peer_connection/peer_connection_internal.rs +++ /dev/null @@ -1,1522 +0,0 @@ -use std::collections::VecDeque; -use std::sync::Weak; - -use arc_swap::ArcSwapOption; -use portable_atomic::AtomicIsize; -use smol_str::SmolStr; -use tokio::time::Instant; -use util::Unmarshal; - -use super::*; -use crate::rtp_transceiver::create_stream_info; -use crate::stats::stats_collector::StatsCollector; -use crate::stats::{ - InboundRTPStats, OutboundRTPStats, RTCStatsType, RemoteInboundRTPStats, RemoteOutboundRTPStats, - StatsReportType, -}; -use crate::track::track_local::track_local_static_sample::TrackLocalStaticSample; -use crate::track::TrackStream; -use crate::SDP_ATTRIBUTE_RID; - -pub(crate) struct PeerConnectionInternal { - /// a value containing the last known greater mid value - /// we internally generate mids as numbers. Needed since JSEP - /// requires that when reusing a media section a new unique mid - /// should be defined (see JSEP 3.4.1). - pub(super) greater_mid: AtomicIsize, - pub(super) sdp_origin: Mutex<::sdp::description::session::Origin>, - pub(super) last_offer: Mutex, - pub(super) last_answer: Mutex, - - pub(super) on_negotiation_needed_handler: Arc>>, - pub(super) is_closed: Arc, - - /// ops is an operations queue which will ensure the enqueued actions are - /// executed in order. 
It is used for asynchronously, but serially processing - /// remote and local descriptions - pub(crate) ops: Arc, - pub(super) negotiation_needed_state: Arc, - pub(super) is_negotiation_needed: Arc, - pub(super) signaling_state: Arc, - - pub(super) ice_transport: Arc, - pub(super) dtls_transport: Arc, - pub(super) on_peer_connection_state_change_handler: - Arc>>, - pub(super) peer_connection_state: Arc, - pub(super) ice_connection_state: Arc, - - pub(super) sctp_transport: Arc, - pub(super) rtp_transceivers: Arc>>>, - - pub(super) on_track_handler: Arc>>, - pub(super) on_signaling_state_change_handler: - ArcSwapOption>, - pub(super) on_ice_connection_state_change_handler: - Arc>>, - pub(super) on_data_channel_handler: Arc>>, - - pub(super) ice_gatherer: Arc, - - pub(super) current_local_description: Arc>>, - pub(super) current_remote_description: Arc>>, - pub(super) pending_local_description: Arc>>, - pub(super) pending_remote_description: Arc>>, - - // A reference to the associated API state used by this connection - pub(super) setting_engine: Arc, - pub(crate) media_engine: Arc, - pub(super) interceptor: Weak, - stats_interceptor: Arc, -} - -impl PeerConnectionInternal { - pub(super) async fn new( - api: &API, - interceptor: Weak, - stats_interceptor: Arc, - mut configuration: RTCConfiguration, - ) -> Result<(Arc, RTCConfiguration)> { - let mut pc = PeerConnectionInternal { - greater_mid: AtomicIsize::new(-1), - sdp_origin: Mutex::new(Default::default()), - last_offer: Mutex::new("".to_owned()), - last_answer: Mutex::new("".to_owned()), - - on_negotiation_needed_handler: Arc::new(ArcSwapOption::empty()), - ops: Arc::new(Operations::new()), - is_closed: Arc::new(AtomicBool::new(false)), - is_negotiation_needed: Arc::new(AtomicBool::new(false)), - negotiation_needed_state: Arc::new(AtomicU8::new(NegotiationNeededState::Empty as u8)), - signaling_state: Arc::new(AtomicU8::new(RTCSignalingState::Stable as u8)), - ice_transport: Arc::new(Default::default()), - dtls_transport: Arc::new(Default::default()), - ice_connection_state: Arc::new(AtomicU8::new(RTCIceConnectionState::New as u8)), - sctp_transport: Arc::new(Default::default()), - rtp_transceivers: Arc::new(Default::default()), - on_track_handler: Arc::new(ArcSwapOption::empty()), - on_signaling_state_change_handler: ArcSwapOption::empty(), - on_ice_connection_state_change_handler: Arc::new(ArcSwapOption::empty()), - on_data_channel_handler: Arc::new(Default::default()), - ice_gatherer: Arc::new(Default::default()), - current_local_description: Arc::new(Default::default()), - current_remote_description: Arc::new(Default::default()), - pending_local_description: Arc::new(Default::default()), - peer_connection_state: Arc::new(AtomicU8::new(RTCPeerConnectionState::New as u8)), - - setting_engine: Arc::clone(&api.setting_engine), - media_engine: if !api.setting_engine.disable_media_engine_copy { - Arc::new(api.media_engine.clone_to()) - } else { - Arc::clone(&api.media_engine) - }, - interceptor, - stats_interceptor, - on_peer_connection_state_change_handler: Arc::new(ArcSwapOption::empty()), - pending_remote_description: Arc::new(Default::default()), - }; - - // Create the ice gatherer - pc.ice_gatherer = Arc::new(api.new_ice_gatherer(RTCIceGatherOptions { - ice_servers: configuration.get_ice_servers(), - ice_gather_policy: configuration.ice_transport_policy, - })?); - - // Create the ice transport - pc.ice_transport = pc.create_ice_transport(api).await; - - // Create the DTLS transport - let certificates = 
configuration.certificates.drain(..).collect(); - pc.dtls_transport = - Arc::new(api.new_dtls_transport(Arc::clone(&pc.ice_transport), certificates)?); - - // Create the SCTP transport - pc.sctp_transport = Arc::new(api.new_sctp_transport(Arc::clone(&pc.dtls_transport))?); - - // Wire up the on datachannel handler - let on_data_channel_handler = Arc::clone(&pc.on_data_channel_handler); - pc.sctp_transport - .on_data_channel(Box::new(move |d: Arc| { - let on_data_channel_handler2 = Arc::clone(&on_data_channel_handler); - Box::pin(async move { - if let Some(handler) = &*on_data_channel_handler2.load() { - let mut f = handler.lock().await; - f(d).await; - } - }) - })); - - Ok((Arc::new(pc), configuration)) - } - - pub(super) async fn start_rtp( - self: &Arc, - is_renegotiation: bool, - remote_desc: Arc, - ) -> Result<()> { - let mut track_details = if let Some(parsed) = &remote_desc.parsed { - track_details_from_sdp(parsed, false) - } else { - vec![] - }; - - let current_transceivers = { - let current_transceivers = self.rtp_transceivers.lock().await; - current_transceivers.clone() - }; - - if !is_renegotiation { - self.undeclared_media_processor(); - } else { - for t in ¤t_transceivers { - let receiver = t.receiver().await; - let tracks = receiver.tracks().await; - if tracks.is_empty() { - continue; - } - - let mut receiver_needs_stopped = false; - - for t in tracks { - if !t.rid().is_empty() { - if let Some(details) = - track_details_for_rid(&track_details, SmolStr::from(t.rid())) - { - t.set_id(details.id.clone()); - t.set_stream_id(details.stream_id.clone()); - continue; - } - } else if t.ssrc() != 0 { - if let Some(details) = track_details_for_ssrc(&track_details, t.ssrc()) { - t.set_id(details.id.clone()); - t.set_stream_id(details.stream_id.clone()); - continue; - } - } - - receiver_needs_stopped = true; - } - - if !receiver_needs_stopped { - continue; - } - - log::info!("Stopping receiver {:?}", receiver); - if let Err(err) = receiver.stop().await { - log::warn!("Failed to stop RtpReceiver: {}", err); - continue; - } - - let interceptor = self - .interceptor - .upgrade() - .ok_or(Error::ErrInterceptorNotBind)?; - - let receiver = Arc::new(RTCRtpReceiver::new( - self.setting_engine.get_receive_mtu(), - receiver.kind(), - Arc::clone(&self.dtls_transport), - Arc::clone(&self.media_engine), - interceptor, - )); - t.set_receiver(receiver).await; - } - } - - self.start_rtp_receivers(&mut track_details, ¤t_transceivers) - .await?; - if let Some(parsed) = &remote_desc.parsed { - if have_application_media_section(parsed) { - self.start_sctp().await; - } - } - - Ok(()) - } - - /// undeclared_media_processor handles RTP/RTCP packets that don't match any a:ssrc lines - fn undeclared_media_processor(self: &Arc) { - let dtls_transport = Arc::clone(&self.dtls_transport); - let is_closed = Arc::clone(&self.is_closed); - let pci = Arc::clone(self); - - // SRTP acceptor - tokio::spawn(async move { - let simulcast_routine_count = Arc::new(AtomicU64::new(0)); - loop { - let srtp_session = match dtls_transport.get_srtp_session().await { - Some(s) => s, - None => { - log::warn!("undeclared_media_processor failed to open SrtpSession"); - return; - } - }; - - let stream = match srtp_session.accept().await { - Ok(stream) => stream, - Err(err) => { - log::warn!("Failed to accept RTP {}", err); - return; - } - }; - - if is_closed.load(Ordering::SeqCst) { - if let Err(err) = stream.close().await { - log::warn!("Failed to close RTP stream {}", err); - } - continue; - } - - if simulcast_routine_count.fetch_add(1, 
Ordering::SeqCst) + 1 - >= SIMULCAST_MAX_PROBE_ROUTINES - { - simulcast_routine_count.fetch_sub(1, Ordering::SeqCst); - log::warn!("{:?}", Error::ErrSimulcastProbeOverflow); - continue; - } - - { - let dtls_transport = Arc::clone(&dtls_transport); - let simulcast_routine_count = Arc::clone(&simulcast_routine_count); - let pci = Arc::clone(&pci); - tokio::spawn(async move { - let ssrc = stream.get_ssrc(); - - dtls_transport - .store_simulcast_stream(ssrc, Arc::clone(&stream)) - .await; - - if let Err(err) = pci.handle_incoming_ssrc(stream, ssrc).await { - log::error!( - "Incoming unhandled RTP ssrc({}), on_track will not be fired. {}", - ssrc, - err - ); - } - - simulcast_routine_count.fetch_sub(1, Ordering::SeqCst); - }); - } - } - }); - - // SRTCP acceptor - { - let dtls_transport = Arc::clone(&self.dtls_transport); - tokio::spawn(async move { - loop { - let srtcp_session = match dtls_transport.get_srtcp_session().await { - Some(s) => s, - None => { - log::warn!("undeclared_media_processor failed to open SrtcpSession"); - return; - } - }; - - let stream = match srtcp_session.accept().await { - Ok(stream) => stream, - Err(err) => { - log::warn!("Failed to accept RTCP {}", err); - return; - } - }; - log::warn!( - "Incoming unhandled RTCP ssrc({}), on_track will not be fired", - stream.get_ssrc() - ); - } - }); - } - } - - /// start_rtp_receivers opens knows inbound SRTP streams from the remote_description - async fn start_rtp_receivers( - self: &Arc, - incoming_tracks: &mut Vec, - local_transceivers: &[Arc], - ) -> Result<()> { - // Ensure we haven't already started a transceiver for this ssrc - let mut filtered_tracks = incoming_tracks.clone(); - for incoming_track in incoming_tracks { - // If we already have a TrackRemote for a given SSRC don't handle it again - for t in local_transceivers { - let receiver = t.receiver().await; - for track in receiver.tracks().await { - for ssrc in &incoming_track.ssrcs { - if *ssrc == track.ssrc() { - filter_track_with_ssrc(&mut filtered_tracks, track.ssrc()); - } - } - } - } - } - - let mut unhandled_tracks = vec![]; // filtered_tracks[:0] - for incoming_track in filtered_tracks.iter() { - let mut track_handled = false; - for t in local_transceivers { - if t.mid().as_ref() != Some(&incoming_track.mid) { - continue; - } - - if (incoming_track.kind != t.kind()) - || (t.direction() != RTCRtpTransceiverDirection::Recvonly - && t.direction() != RTCRtpTransceiverDirection::Sendrecv) - { - continue; - } - - let receiver = t.receiver().await; - if receiver.have_received().await { - continue; - } - PeerConnectionInternal::start_receiver( - self.setting_engine.get_receive_mtu(), - incoming_track, - receiver, - Arc::clone(t), - Arc::clone(&self.on_track_handler), - ) - .await; - track_handled = true; - } - - if !track_handled { - unhandled_tracks.push(incoming_track); - } - } - - Ok(()) - } - - /// Start SCTP subsystem - async fn start_sctp(&self) { - // Start sctp - if let Err(err) = self - .sctp_transport - .start(SCTPTransportCapabilities { - max_message_size: 0, - }) - .await - { - log::warn!("Failed to start SCTP: {}", err); - if let Err(err) = self.sctp_transport.stop().await { - log::warn!("Failed to stop SCTPTransport: {}", err); - } - - return; - } - - // DataChannels that need to be opened now that SCTP is available - // make a copy we may have incoming DataChannels mutating this while we open - let data_channels = { - let data_channels = self.sctp_transport.data_channels.lock().await; - data_channels.clone() - }; - - let mut opened_dc_count = 0; - for d 
in data_channels { - if d.ready_state() == RTCDataChannelState::Connecting { - if let Err(err) = d.open(Arc::clone(&self.sctp_transport)).await { - log::warn!("failed to open data channel: {}", err); - continue; - } - opened_dc_count += 1; - } - } - - self.sctp_transport - .data_channels_opened - .fetch_add(opened_dc_count, Ordering::SeqCst); - } - - pub(super) async fn add_transceiver_from_kind( - &self, - kind: RTPCodecType, - init: Option, - ) -> Result> { - if self.is_closed.load(Ordering::SeqCst) { - return Err(Error::ErrConnectionClosed); - } - - let direction = init - .map(|value| value.direction) - .unwrap_or(RTCRtpTransceiverDirection::Sendrecv); - - let t = match direction { - RTCRtpTransceiverDirection::Sendonly | RTCRtpTransceiverDirection::Sendrecv => { - let codec = self - .media_engine - .get_codecs_by_kind(kind) - .first() - .map(|c| c.capability.clone()) - .ok_or(Error::ErrNoCodecsAvailable)?; - let track = Arc::new(TrackLocalStaticSample::new( - codec, - math_rand_alpha(16), - math_rand_alpha(16), - )); - self.new_transceiver_from_track(direction, track).await? - } - RTCRtpTransceiverDirection::Recvonly => { - let interceptor = self - .interceptor - .upgrade() - .ok_or(Error::ErrInterceptorNotBind)?; - let receiver = Arc::new(RTCRtpReceiver::new( - self.setting_engine.get_receive_mtu(), - kind, - Arc::clone(&self.dtls_transport), - Arc::clone(&self.media_engine), - Arc::clone(&interceptor), - )); - - let sender = Arc::new( - RTCRtpSender::new( - self.setting_engine.get_receive_mtu(), - None, - kind, - Arc::clone(&self.dtls_transport), - Arc::clone(&self.media_engine), - interceptor, - false, - ) - .await, - ); - - RTCRtpTransceiver::new( - receiver, - sender, - direction, - kind, - vec![], - Arc::clone(&self.media_engine), - Some(Box::new(self.make_negotiation_needed_trigger())), - ) - .await - } - _ => return Err(Error::ErrPeerConnAddTransceiverFromKindSupport), - }; - - self.add_rtp_transceiver(Arc::clone(&t)).await; - - Ok(t) - } - - pub(super) async fn new_transceiver_from_track( - &self, - direction: RTCRtpTransceiverDirection, - track: Arc, - ) -> Result> { - let interceptor = self - .interceptor - .upgrade() - .ok_or(Error::ErrInterceptorNotBind)?; - - if direction == RTCRtpTransceiverDirection::Unspecified { - return Err(Error::ErrPeerConnAddTransceiverFromTrackSupport); - } - - let r = Arc::new(RTCRtpReceiver::new( - self.setting_engine.get_receive_mtu(), - track.kind(), - Arc::clone(&self.dtls_transport), - Arc::clone(&self.media_engine), - Arc::clone(&interceptor), - )); - - let s = Arc::new( - RTCRtpSender::new( - self.setting_engine.get_receive_mtu(), - Some(Arc::clone(&track)), - track.kind(), - Arc::clone(&self.dtls_transport), - Arc::clone(&self.media_engine), - Arc::clone(&interceptor), - false, - ) - .await, - ); - - Ok(RTCRtpTransceiver::new( - r, - s, - direction, - track.kind(), - vec![], - Arc::clone(&self.media_engine), - Some(Box::new(self.make_negotiation_needed_trigger())), - ) - .await) - } - - /// add_rtp_transceiver appends t into rtp_transceivers - /// and fires onNegotiationNeeded; - /// caller of this method should hold `self.mu` lock - pub(super) async fn add_rtp_transceiver(&self, t: Arc) { - { - let mut rtp_transceivers = self.rtp_transceivers.lock().await; - rtp_transceivers.push(t); - } - self.trigger_negotiation_needed().await; - } - - /// Helper to trigger a negotiation needed. 
- pub(crate) async fn trigger_negotiation_needed(&self) { - RTCPeerConnection::do_negotiation_needed(self.create_negotiation_needed_params()).await; - } - - /// Creates the parameters needed to trigger a negotiation needed. - fn create_negotiation_needed_params(&self) -> NegotiationNeededParams { - NegotiationNeededParams { - on_negotiation_needed_handler: Arc::clone(&self.on_negotiation_needed_handler), - is_closed: Arc::clone(&self.is_closed), - ops: Arc::clone(&self.ops), - negotiation_needed_state: Arc::clone(&self.negotiation_needed_state), - is_negotiation_needed: Arc::clone(&self.is_negotiation_needed), - signaling_state: Arc::clone(&self.signaling_state), - check_negotiation_needed_params: CheckNegotiationNeededParams { - sctp_transport: Arc::clone(&self.sctp_transport), - rtp_transceivers: Arc::clone(&self.rtp_transceivers), - current_local_description: Arc::clone(&self.current_local_description), - current_remote_description: Arc::clone(&self.current_remote_description), - }, - } - } - - pub(crate) fn make_negotiation_needed_trigger( - &self, - ) -> impl Fn() -> Pin + Send + Sync>> + Send + Sync { - let params = self.create_negotiation_needed_params(); - move || { - let params = params.clone(); - Box::pin(async move { - let params = params.clone(); - RTCPeerConnection::do_negotiation_needed(params).await; - }) - } - } - - pub(super) async fn remote_description(&self) -> Option { - let pending_remote_description = self.pending_remote_description.lock().await; - if pending_remote_description.is_some() { - pending_remote_description.clone() - } else { - let current_remote_description = self.current_remote_description.lock().await; - current_remote_description.clone() - } - } - - pub(super) fn set_gather_complete_handler(&self, f: OnGatheringCompleteHdlrFn) { - self.ice_gatherer.on_gathering_complete(f); - } - - /// Start all transports. 
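// make_negotiation_needed_trigger above packages the negotiation-needed parameters into a
// cloneable async callback. A self-contained sketch of that closure-returning-a-boxed-future
// pattern; the String parameter stands in for NegotiationNeededParams and is purely
// illustrative.
use std::future::Future;
use std::pin::Pin;

fn make_trigger(
    params: String,
) -> impl Fn() -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync {
    move || {
        // Clone per invocation so each returned future owns its own copy of the state.
        let params = params.clone();
        Box::pin(async move {
            println!("negotiation needed: {params}");
        })
    }
}

#[tokio::main]
async fn main() {
    let trigger = make_trigger("pc-1".to_owned());
    trigger().await;
    trigger().await; // the trigger can be fired repeatedly
}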
PeerConnection now has enough state - pub(super) async fn start_transports( - self: &Arc, - ice_role: RTCIceRole, - dtls_role: DTLSRole, - remote_ufrag: String, - remote_pwd: String, - fingerprint: String, - fingerprint_hash: String, - ) { - // Start the ice transport - if let Err(err) = self - .ice_transport - .start( - &RTCIceParameters { - username_fragment: remote_ufrag, - password: remote_pwd, - ice_lite: false, - }, - Some(ice_role), - ) - .await - { - log::warn!("Failed to start manager ice: {}", err); - return; - } - - // Start the dtls_transport transport - let result = self - .dtls_transport - .start(DTLSParameters { - role: dtls_role, - fingerprints: vec![RTCDtlsFingerprint { - algorithm: fingerprint_hash, - value: fingerprint, - }], - }) - .await; - RTCPeerConnection::update_connection_state( - &self.on_peer_connection_state_change_handler, - &self.is_closed, - &self.peer_connection_state, - self.ice_connection_state.load(Ordering::SeqCst).into(), - self.dtls_transport.state(), - ) - .await; - if let Err(err) = result { - log::warn!("Failed to start manager dtls: {}", err); - } - } - - /// generate_unmatched_sdp generates an SDP that doesn't take remote state into account - /// This is used for the initial call for CreateOffer - pub(super) async fn generate_unmatched_sdp( - &self, - local_transceivers: Vec>, - use_identity: bool, - ) -> Result { - let d = SessionDescription::new_jsep_session_description(use_identity); - - let ice_params = self.ice_gatherer.get_local_parameters().await?; - - let candidates = self.ice_gatherer.get_local_candidates().await?; - - let mut media_sections = vec![]; - - for t in &local_transceivers { - if t.stopped.load(Ordering::SeqCst) { - // An "m=" section is generated for each - // RtpTransceiver that has been added to the PeerConnection, excluding - // any stopped RtpTransceivers; - continue; - } - - // TODO: This is dubious because of rollbacks. 
- t.sender().await.set_negotiated(); - media_sections.push(MediaSection { - id: t.mid().unwrap().to_string(), - transceivers: vec![Arc::clone(t)], - ..Default::default() - }); - } - - if self - .sctp_transport - .data_channels_requested - .load(Ordering::SeqCst) - != 0 - { - media_sections.push(MediaSection { - id: format!("{}", media_sections.len()), - data: true, - ..Default::default() - }); - } - - let dtls_fingerprints = if let Some(cert) = self.dtls_transport.certificates.first() { - cert.get_fingerprints() - } else { - return Err(Error::ErrNonCertificate); - }; - - let params = PopulateSdpParams { - media_description_fingerprint: self.setting_engine.sdp_media_level_fingerprints, - is_icelite: self.setting_engine.candidates.ice_lite, - connection_role: DEFAULT_DTLS_ROLE_OFFER.to_connection_role(), - ice_gathering_state: self.ice_gathering_state(), - match_bundle_group: None, - }; - populate_sdp( - d, - &dtls_fingerprints, - &self.media_engine, - &candidates, - &ice_params, - &media_sections, - params, - ) - .await - } - - /// generate_matched_sdp generates a SDP and takes the remote state into account - /// this is used everytime we have a remote_description - pub(super) async fn generate_matched_sdp( - &self, - mut local_transceivers: Vec>, - use_identity: bool, - include_unmatched: bool, - connection_role: ConnectionRole, - ) -> Result { - let d = SessionDescription::new_jsep_session_description(use_identity); - - let ice_params = self.ice_gatherer.get_local_parameters().await?; - let candidates = self.ice_gatherer.get_local_candidates().await?; - - let remote_description = self.remote_description().await; - let mut media_sections = vec![]; - let mut already_have_application_media_section = false; - if let Some(remote_description) = remote_description.as_ref() { - if let Some(parsed) = &remote_description.parsed { - for media in &parsed.media_descriptions { - if let Some(mid_value) = get_mid_value(media) { - if mid_value.is_empty() { - return Err(Error::ErrPeerConnRemoteDescriptionWithoutMidValue); - } - - if media.media_name.media == MEDIA_SECTION_APPLICATION { - media_sections.push(MediaSection { - id: mid_value.to_owned(), - data: true, - ..Default::default() - }); - already_have_application_media_section = true; - continue; - } - - let kind = RTPCodecType::from(media.media_name.media.as_str()); - let direction = get_peer_direction(media); - if kind == RTPCodecType::Unspecified - || direction == RTCRtpTransceiverDirection::Unspecified - { - continue; - } - - if let Some(t) = find_by_mid(mid_value, &mut local_transceivers).await { - t.sender().await.set_negotiated(); - let media_transceivers = vec![t]; - - // NB: The below could use `then_some`, but with our current MSRV - // it's not possible to actually do this. The clippy version that - // ships with 1.64.0 complains about this so we disable it for now. 
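// The NB comment above: on a newer toolchain the `(!include_unmatched).then(|| direction)`
// expression below could use `bool::then_some` instead. A minimal sketch showing the two
// forms are equivalent for Copy values, which is all the MSRV workaround relies on:
fn then_some_equivalence(include_unmatched: bool, direction: u8) {
    let lazy = (!include_unmatched).then(|| direction); // what the code uses
    let eager = (!include_unmatched).then_some(direction); // what clippy suggests
    assert_eq!(lazy, eager);
}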
- #[allow(clippy::unnecessary_lazy_evaluations)] - media_sections.push(MediaSection { - id: mid_value.to_owned(), - transceivers: media_transceivers, - rid_map: get_rids(media), - offered_direction: (!include_unmatched).then(|| direction), - ..Default::default() - }); - } else { - return Err(Error::ErrPeerConnTransceiverMidNil); - } - } - } - } - } - - // If we are offering also include unmatched local transceivers - let match_bundle_group = if include_unmatched { - for t in &local_transceivers { - t.sender().await.set_negotiated(); - media_sections.push(MediaSection { - id: t.mid().unwrap().to_string(), - transceivers: vec![Arc::clone(t)], - ..Default::default() - }); - } - - if self - .sctp_transport - .data_channels_requested - .load(Ordering::SeqCst) - != 0 - && !already_have_application_media_section - { - media_sections.push(MediaSection { - id: format!("{}", media_sections.len()), - data: true, - ..Default::default() - }); - } - None - } else { - remote_description - .as_ref() - .and_then(|d| d.parsed.as_ref()) - .and_then(|d| d.attribute(ATTR_KEY_GROUP)) - .map(ToOwned::to_owned) - }; - - let dtls_fingerprints = if let Some(cert) = self.dtls_transport.certificates.first() { - cert.get_fingerprints() - } else { - return Err(Error::ErrNonCertificate); - }; - - let params = PopulateSdpParams { - media_description_fingerprint: self.setting_engine.sdp_media_level_fingerprints, - is_icelite: self.setting_engine.candidates.ice_lite, - connection_role, - ice_gathering_state: self.ice_gathering_state(), - match_bundle_group, - }; - populate_sdp( - d, - &dtls_fingerprints, - &self.media_engine, - &candidates, - &ice_params, - &media_sections, - params, - ) - .await - } - - pub(super) fn ice_gathering_state(&self) -> RTCIceGatheringState { - match self.ice_gatherer.state() { - RTCIceGathererState::New => RTCIceGatheringState::New, - RTCIceGathererState::Gathering => RTCIceGatheringState::Gathering, - _ => RTCIceGatheringState::Complete, - } - } - - async fn handle_undeclared_ssrc( - self: &Arc, - ssrc: SSRC, - remote_description: &SessionDescription, - ) -> Result { - if remote_description.media_descriptions.len() != 1 { - return Ok(false); - } - - let only_media_section = &remote_description.media_descriptions[0]; - let mut stream_id = ""; - let mut id = ""; - let mut has_rid = false; - let mut has_ssrc = false; - - for a in &only_media_section.attributes { - match a.key.as_str() { - ATTR_KEY_MSID => { - if let Some(value) = &a.value { - let split: Vec<&str> = value.split(' ').collect(); - if split.len() == 2 { - stream_id = split[0]; - id = split[1]; - } - } - } - ATTR_KEY_SSRC => has_ssrc = true, - SDP_ATTRIBUTE_RID => has_rid = true, - _ => {} - }; - } - - if has_rid { - return Ok(false); - } else if has_ssrc { - return Err(Error::ErrPeerConnSingleMediaSectionHasExplicitSSRC); - } - - let mut incoming = TrackDetails { - ssrcs: vec![ssrc], - kind: RTPCodecType::Video, - stream_id: stream_id.to_owned(), - id: id.to_owned(), - ..Default::default() - }; - if only_media_section.media_name.media == RTPCodecType::Audio.to_string() { - incoming.kind = RTPCodecType::Audio; - } - - let t = self - .add_transceiver_from_kind( - incoming.kind, - Some(RTCRtpTransceiverInit { - direction: RTCRtpTransceiverDirection::Sendrecv, - send_encodings: vec![], - }), - ) - .await?; - - let receiver = t.receiver().await; - PeerConnectionInternal::start_receiver( - self.setting_engine.get_receive_mtu(), - &incoming, - receiver, - t, - Arc::clone(&self.on_track_handler), - ) - .await; - Ok(true) - } - - async fn 
handle_incoming_ssrc( - self: &Arc, - rtp_stream: Arc, - ssrc: SSRC, - ) -> Result<()> { - let parsed = match self.remote_description().await.and_then(|rd| rd.parsed) { - Some(r) => r, - None => return Err(Error::ErrPeerConnRemoteDescriptionNil), - }; - // If the remote SDP was only one media section the ssrc doesn't have to be explicitly declared - let handled = self.handle_undeclared_ssrc(ssrc, &parsed).await?; - if handled { - return Ok(()); - } - - // Get MID extension ID - let (mid_extension_id, audio_supported, video_supported) = self - .media_engine - .get_header_extension_id(RTCRtpHeaderExtensionCapability { - uri: ::sdp::extmap::SDES_MID_URI.to_owned(), - }) - .await; - if !audio_supported && !video_supported { - return Err(Error::ErrPeerConnSimulcastMidRTPExtensionRequired); - } - - // Get RID extension ID - let (sid_extension_id, audio_supported, video_supported) = self - .media_engine - .get_header_extension_id(RTCRtpHeaderExtensionCapability { - uri: ::sdp::extmap::SDES_RTP_STREAM_ID_URI.to_owned(), - }) - .await; - if !audio_supported && !video_supported { - return Err(Error::ErrPeerConnSimulcastStreamIDRTPExtensionRequired); - } - - let (rsid_extension_id, _, _) = self - .media_engine - .get_header_extension_id(RTCRtpHeaderExtensionCapability { - uri: ::sdp::extmap::SDES_REPAIR_RTP_STREAM_ID_URI.to_owned(), - }) - .await; - - let mut buf = vec![0u8; self.setting_engine.get_receive_mtu()]; - // Packets that we read as part of simulcast probing that we need to make available - // if we do find a track later. - let mut buffered_packets: VecDeque<(rtp::packet::Packet, Attributes)> = VecDeque::default(); - - let n = rtp_stream.read(&mut buf).await?; - - let (mut mid, mut rid, mut rsid, payload_type) = handle_unknown_rtp_packet( - &buf[..n], - mid_extension_id as u8, - sid_extension_id as u8, - rsid_extension_id as u8, - )?; - - let packet = rtp::packet::Packet::unmarshal(&mut buf.as_slice()).unwrap(); - - // TODO: Can we have attributes on the first packets? 
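// handle_incoming_ssrc above refuses to demultiplex simulcast unless the MID and RTP stream
// ID header extensions were negotiated (ErrPeerConnSimulcastMidRTPExtensionRequired /
// ErrPeerConnSimulcastStreamIDRTPExtensionRequired). A sketch of the MediaEngine setup that
// satisfies it, mirroring test_peer_connection_simulcast_no_data_channel further down in
// this diff; the types and `::sdp::extmap` constants are the ones that test imports.
fn register_simulcast_extensions(m: &mut MediaEngine) -> Result<()> {
    for uri in [
        ::sdp::extmap::SDES_MID_URI,
        ::sdp::extmap::SDES_RTP_STREAM_ID_URI,
        // Optional, but lets RTX repair streams be matched via receive_for_rtx as well.
        ::sdp::extmap::SDES_REPAIR_RTP_STREAM_ID_URI,
    ] {
        m.register_header_extension(
            RTCRtpHeaderExtensionCapability { uri: uri.to_owned() },
            RTPCodecType::Video,
            None,
        )?;
    }
    m.register_default_codecs()
}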
- buffered_packets.push_back((packet, Attributes::new())); - - let params = self - .media_engine - .get_rtp_parameters_by_payload_type(payload_type) - .await?; - - let icpr = match self.interceptor.upgrade() { - Some(i) => i, - None => return Err(Error::ErrInterceptorNotBind), - }; - - let stream_info = create_stream_info( - "".to_owned(), - ssrc, - params.codecs[0].payload_type, - params.codecs[0].capability.clone(), - ¶ms.header_extensions, - ); - let (rtp_read_stream, rtp_interceptor, rtcp_read_stream, rtcp_interceptor) = self - .dtls_transport - .streams_for_ssrc(ssrc, &stream_info, &icpr) - .await?; - - let a = Attributes::new(); - for _ in 0..=SIMULCAST_PROBE_COUNT { - if mid.is_empty() || (rid.is_empty() && rsid.is_empty()) { - let (pkt, _) = rtp_interceptor.read(&mut buf, &a).await?; - let (m, r, rs, _) = handle_unknown_rtp_packet( - &buf[..n], - mid_extension_id as u8, - sid_extension_id as u8, - rsid_extension_id as u8, - )?; - mid = m; - rid = r; - rsid = rs; - - buffered_packets.push_back((pkt, a.clone())); - continue; - } - - let transceivers = self.rtp_transceivers.lock().await; - for t in &*transceivers { - if t.mid().as_ref() != Some(&SmolStr::from(&mid)) { - continue; - } - - let receiver = t.receiver().await; - - if !rsid.is_empty() { - return receiver - .receive_for_rtx( - 0, - rsid, - TrackStream { - stream_info: Some(stream_info.clone()), - rtp_read_stream: Some(rtp_read_stream), - rtp_interceptor: Some(rtp_interceptor), - rtcp_read_stream: Some(rtcp_read_stream), - rtcp_interceptor: Some(rtcp_interceptor), - }, - ) - .await; - } - - let track = receiver - .receive_for_rid( - SmolStr::from(rid), - params, - TrackStream { - stream_info: Some(stream_info.clone()), - rtp_read_stream: Some(rtp_read_stream), - rtp_interceptor: Some(rtp_interceptor), - rtcp_read_stream: Some(rtcp_read_stream), - rtcp_interceptor: Some(rtcp_interceptor), - }, - ) - .await?; - track.prepopulate_peeked_data(buffered_packets).await; - - RTCPeerConnection::do_track( - Arc::clone(&self.on_track_handler), - track, - receiver, - Arc::clone(t), - ); - return Ok(()); - } - } - - let _ = rtp_read_stream.close().await; - let _ = rtcp_read_stream.close().await; - icpr.unbind_remote_stream(&stream_info).await; - self.dtls_transport.remove_simulcast_stream(ssrc).await; - - Err(Error::ErrPeerConnSimulcastIncomingSSRCFailed) - } - - async fn start_receiver( - receive_mtu: usize, - incoming: &TrackDetails, - receiver: Arc, - transceiver: Arc, - on_track_handler: Arc>>, - ) { - receiver.start(incoming).await; - for track in receiver.tracks().await { - if track.ssrc() == 0 { - return; - } - - let receiver = Arc::clone(&receiver); - let transceiver = Arc::clone(&transceiver); - let on_track_handler = Arc::clone(&on_track_handler); - tokio::spawn(async move { - let mut b = vec![0u8; receive_mtu]; - let pkt = match track.peek(&mut b).await { - Ok((pkt, _)) => pkt, - Err(err) => { - log::warn!( - "Could not determine PayloadType for SSRC {} ({})", - track.ssrc(), - err - ); - return; - } - }; - - if let Err(err) = track.check_and_update_track(&pkt).await { - log::warn!( - "Failed to set codec settings for track SSRC {} ({})", - track.ssrc(), - err - ); - return; - } - - RTCPeerConnection::do_track(on_track_handler, track, receiver, transceiver); - }); - } - } - - pub(super) async fn create_ice_transport(&self, api: &API) -> Arc { - let ice_transport = Arc::new(api.new_ice_transport(Arc::clone(&self.ice_gatherer))); - - let ice_connection_state = Arc::clone(&self.ice_connection_state); - let peer_connection_state = 
Arc::clone(&self.peer_connection_state); - let is_closed = Arc::clone(&self.is_closed); - let dtls_transport = Arc::clone(&self.dtls_transport); - let on_ice_connection_state_change_handler = - Arc::clone(&self.on_ice_connection_state_change_handler); - let on_peer_connection_state_change_handler = - Arc::clone(&self.on_peer_connection_state_change_handler); - - ice_transport.on_connection_state_change(Box::new(move |state: RTCIceTransportState| { - let cs = match state { - RTCIceTransportState::New => RTCIceConnectionState::New, - RTCIceTransportState::Checking => RTCIceConnectionState::Checking, - RTCIceTransportState::Connected => RTCIceConnectionState::Connected, - RTCIceTransportState::Completed => RTCIceConnectionState::Completed, - RTCIceTransportState::Failed => RTCIceConnectionState::Failed, - RTCIceTransportState::Disconnected => RTCIceConnectionState::Disconnected, - RTCIceTransportState::Closed => RTCIceConnectionState::Closed, - _ => { - log::warn!("on_connection_state_change: unhandled ICE state: {}", state); - return Box::pin(async {}); - } - }; - - let ice_connection_state2 = Arc::clone(&ice_connection_state); - let on_ice_connection_state_change_handler2 = - Arc::clone(&on_ice_connection_state_change_handler); - let on_peer_connection_state_change_handler2 = - Arc::clone(&on_peer_connection_state_change_handler); - let is_closed2 = Arc::clone(&is_closed); - let dtls_transport_state = dtls_transport.state(); - let peer_connection_state2 = Arc::clone(&peer_connection_state); - Box::pin(async move { - RTCPeerConnection::do_ice_connection_state_change( - &on_ice_connection_state_change_handler2, - &ice_connection_state2, - cs, - ) - .await; - - RTCPeerConnection::update_connection_state( - &on_peer_connection_state_change_handler2, - &is_closed2, - &peer_connection_state2, - cs, - dtls_transport_state, - ) - .await; - }) - })); - - ice_transport - } - - /// has_local_description_changed returns whether local media (rtp_transceivers) has changed - /// caller of this method should hold `pc.mu` lock - pub(super) async fn has_local_description_changed(&self, desc: &RTCSessionDescription) -> bool { - let rtp_transceivers = self.rtp_transceivers.lock().await; - for t in &*rtp_transceivers { - let m = match t.mid().and_then(|mid| get_by_mid(mid.as_str(), desc)) { - Some(m) => m, - None => return true, - }; - - if get_peer_direction(m) != t.direction() { - return true; - } - } - false - } - - pub(super) async fn get_stats(&self, stats_id: String) -> StatsCollector { - let collector = StatsCollector::new(); - let transceivers = { self.rtp_transceivers.lock().await.clone() }; - - tokio::join!( - self.ice_gatherer.collect_stats(&collector), - self.ice_transport.collect_stats(&collector), - self.sctp_transport.collect_stats(&collector, stats_id), - self.dtls_transport.collect_stats(&collector), - self.media_engine.collect_stats(&collector), - self.collect_inbound_stats(&collector, transceivers.clone()), - self.collect_outbound_stats(&collector, transceivers) - ); - - collector - } - - async fn collect_inbound_stats( - &self, - collector: &StatsCollector, - transceivers: Vec>, - ) { - // TODO: There's a lot of await points here that could run concurrently with `futures::join_all`. 
- struct TrackInfo { - ssrc: SSRC, - mid: SmolStr, - track_id: String, - kind: &'static str, - } - let mut track_infos = vec![]; - for transeiver in transceivers { - let receiver = transeiver.receiver().await; - - if let Some(mid) = transeiver.mid() { - let tracks = receiver.tracks().await; - - for track in tracks { - let track_id = track.id(); - let kind = match track.kind() { - RTPCodecType::Unspecified => continue, - RTPCodecType::Audio => "audio", - RTPCodecType::Video => "video", - }; - - track_infos.push(TrackInfo { - ssrc: track.ssrc(), - mid: mid.clone(), - track_id, - kind, - }); - } - } - } - - let stream_stats = self - .stats_interceptor - .fetch_inbound_stats(track_infos.iter().map(|t| t.ssrc).collect()) - .await; - - for (stats, info) in - (stream_stats.into_iter().zip(track_infos)).filter_map(|(s, i)| s.map(|s| (s, i))) - { - let ssrc = info.ssrc; - let kind = info.kind; - - let id = format!("RTCInboundRTP{}Stream_{}", capitalize(kind), ssrc); - let ( - packets_received, - header_bytes_received, - bytes_received, - last_packet_received_timestamp, - nack_count, - remote_packets_sent, - remote_bytes_sent, - remote_reports_sent, - remote_round_trip_time, - remote_total_round_trip_time, - remote_round_trip_time_measurements, - ) = ( - stats.packets_received(), - stats.header_bytes_received(), - stats.payload_bytes_received(), - stats.last_packet_received_timestamp(), - stats.nacks_sent(), - stats.remote_packets_sent(), - stats.remote_bytes_sent(), - stats.remote_reports_sent(), - stats.remote_round_trip_time(), - stats.remote_total_round_trip_time(), - stats.remote_round_trip_time_measurements(), - ); - - collector.insert( - id.clone(), - crate::stats::StatsReportType::InboundRTP(InboundRTPStats { - timestamp: Instant::now(), - stats_type: RTCStatsType::InboundRTP, - id: id.clone(), - ssrc, - kind: kind.to_owned(), - packets_received, - track_identifier: info.track_id, - mid: info.mid, - last_packet_received_timestamp, - header_bytes_received, - bytes_received, - nack_count, - - fir_count: (info.kind == "video").then(|| stats.firs_sent()), - pli_count: (info.kind == "video").then(|| stats.plis_sent()), - }), - ); - - let local_id = id; - let id = format!( - "RTCRemoteOutboundRTP{}Stream_{}", - capitalize(info.kind), - info.ssrc - ); - collector.insert( - id.clone(), - crate::stats::StatsReportType::RemoteOutboundRTP(RemoteOutboundRTPStats { - timestamp: Instant::now(), - stats_type: RTCStatsType::RemoteOutboundRTP, - id, - - ssrc, - kind: kind.to_owned(), - - packets_sent: remote_packets_sent as u64, - bytes_sent: remote_bytes_sent as u64, - local_id, - reports_sent: remote_reports_sent, - round_trip_time: remote_round_trip_time, - total_round_trip_time: remote_total_round_trip_time, - round_trip_time_measurements: remote_round_trip_time_measurements, - }), - ); - } - } - - async fn collect_outbound_stats( - &self, - collector: &StatsCollector, - transceivers: Vec>, - ) { - // TODO: There's a lot of await points here that could run concurrently with `futures::join_all`. 
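// The TODO above notes that the per-track awaits in collect_inbound_stats /
// collect_outbound_stats run sequentially. A minimal, self-contained sketch of the batching
// it hints at with `futures::join_all`, assuming the per-item lookups are independent; the
// `fetch_one` function and `Stat` type are placeholders, not crate APIs.
use futures::future::join_all;

#[derive(Debug)]
struct Stat(u32);

async fn fetch_one(ssrc: u32) -> Stat {
    // Stand-in for a per-SSRC stats lookup; purely illustrative.
    Stat(ssrc)
}

async fn fetch_all(ssrcs: Vec<u32>) -> Vec<Stat> {
    // Drive all lookups concurrently instead of awaiting them one by one.
    join_all(ssrcs.into_iter().map(fetch_one)).await
}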
- struct TrackInfo { - track_id: String, - ssrc: SSRC, - mid: SmolStr, - rid: Option, - kind: &'static str, - } - let mut track_infos = vec![]; - for transceiver in transceivers { - let mid = match transceiver.mid() { - Some(mid) => mid, - None => continue, - }; - - let sender = transceiver.sender().await; - let track_encodings = sender.track_encodings.lock().await; - for encoding in track_encodings.iter() { - let track_id = encoding.track.id().to_string(); - let kind = match encoding.track.kind() { - RTPCodecType::Unspecified => continue, - RTPCodecType::Audio => "audio", - RTPCodecType::Video => "video", - }; - - track_infos.push(TrackInfo { - track_id, - ssrc: encoding.ssrc, - mid: mid.to_owned(), - rid: encoding.track.rid().map(Into::into), - kind, - }); - } - } - - let stream_stats = self - .stats_interceptor - .fetch_outbound_stats(track_infos.iter().map(|t| t.ssrc).collect()) - .await; - - for (stats, info) in stream_stats - .into_iter() - .zip(track_infos) - .filter_map(|(s, i)| s.map(|s| (s, i))) - { - // RTCOutboundRtpStreamStats - let id = format!( - "RTCOutboundRTP{}Stream_{}", - capitalize(info.kind), - info.ssrc - ); - let ( - packets_sent, - bytes_sent, - header_bytes_sent, - nack_count, - remote_inbound_packets_received, - remote_inbound_packets_lost, - remote_rtt_ms, - remote_total_rtt_ms, - remote_rtt_measurements, - remote_fraction_lost, - ) = ( - stats.packets_sent(), - stats.payload_bytes_sent(), - stats.header_bytes_sent(), - stats.nacks_received(), - stats.remote_packets_received(), - stats.remote_total_lost(), - stats.remote_round_trip_time(), - stats.remote_total_round_trip_time(), - stats.remote_round_trip_time_measurements(), - stats.remote_fraction_lost(), - ); - - let TrackInfo { - mid, - ssrc, - rid, - kind, - track_id: track_identifier, - } = info; - - collector.insert( - id.clone(), - crate::stats::StatsReportType::OutboundRTP(OutboundRTPStats { - timestamp: Instant::now(), - stats_type: RTCStatsType::OutboundRTP, - track_identifier, - id: id.clone(), - ssrc, - kind: kind.to_owned(), - packets_sent, - mid, - rid, - header_bytes_sent, - bytes_sent, - nack_count, - - fir_count: (info.kind == "video").then(|| stats.firs_received()), - pli_count: (info.kind == "video").then(|| stats.plis_received()), - }), - ); - - let local_id = id; - let id = format!( - "RTCRemoteInboundRTP{}Stream_{}", - capitalize(info.kind), - info.ssrc - ); - - collector.insert( - id.clone(), - StatsReportType::RemoteInboundRTP(RemoteInboundRTPStats { - timestamp: Instant::now(), - stats_type: RTCStatsType::RemoteInboundRTP, - id, - ssrc, - kind: kind.to_owned(), - - packets_received: remote_inbound_packets_received, - packets_lost: remote_inbound_packets_lost as i64, - - local_id, - - round_trip_time: remote_rtt_ms, - total_round_trip_time: remote_total_rtt_ms, - fraction_lost: remote_fraction_lost.unwrap_or(0.0), - round_trip_time_measurements: remote_rtt_measurements, - }), - ); - } - } -} - -type IResult = std::result::Result; - -#[async_trait] -impl RTCPWriter for PeerConnectionInternal { - async fn write( - &self, - pkts: &[Box], - _a: &Attributes, - ) -> IResult { - Ok(self.dtls_transport.write_rtcp(pkts).await?) 
- } -} - -fn capitalize(s: &str) -> String { - let first = s - .chars() - .next() - .expect("Must have at least one character to uppercase") - .to_uppercase(); - let mut result = String::new(); - - result.extend(first); - result.extend(s.chars().skip(1)); - - result -} diff --git a/webrtc/src/peer_connection/peer_connection_state.rs b/webrtc/src/peer_connection/peer_connection_state.rs deleted file mode 100644 index 905e26769..000000000 --- a/webrtc/src/peer_connection/peer_connection_state.rs +++ /dev/null @@ -1,151 +0,0 @@ -use std::fmt; - -/// PeerConnectionState indicates the state of the PeerConnection. -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] -pub enum RTCPeerConnectionState { - #[default] - Unspecified, - - /// PeerConnectionStateNew indicates that any of the ICETransports or - /// DTLSTransports are in the "new" state and none of the transports are - /// in the "connecting", "checking", "failed" or "disconnected" state, or - /// all transports are in the "closed" state, or there are no transports. - New, - - /// PeerConnectionStateConnecting indicates that any of the - /// ICETransports or DTLSTransports are in the "connecting" or - /// "checking" state and none of them is in the "failed" state. - Connecting, - - /// PeerConnectionStateConnected indicates that all ICETransports and - /// DTLSTransports are in the "connected", "completed" or "closed" state - /// and at least one of them is in the "connected" or "completed" state. - Connected, - - /// PeerConnectionStateDisconnected indicates that any of the - /// ICETransports or DTLSTransports are in the "disconnected" state - /// and none of them are in the "failed" or "connecting" or "checking" state. - Disconnected, - - /// PeerConnectionStateFailed indicates that any of the ICETransports - /// or DTLSTransports are in a "failed" state. - Failed, - - /// PeerConnectionStateClosed indicates the peer connection is closed - /// and the isClosed member variable of PeerConnection is true. 
- Closed, -} - -const PEER_CONNECTION_STATE_NEW_STR: &str = "new"; -const PEER_CONNECTION_STATE_CONNECTING_STR: &str = "connecting"; -const PEER_CONNECTION_STATE_CONNECTED_STR: &str = "connected"; -const PEER_CONNECTION_STATE_DISCONNECTED_STR: &str = "disconnected"; -const PEER_CONNECTION_STATE_FAILED_STR: &str = "failed"; -const PEER_CONNECTION_STATE_CLOSED_STR: &str = "closed"; - -impl From<&str> for RTCPeerConnectionState { - fn from(raw: &str) -> Self { - match raw { - PEER_CONNECTION_STATE_NEW_STR => RTCPeerConnectionState::New, - PEER_CONNECTION_STATE_CONNECTING_STR => RTCPeerConnectionState::Connecting, - PEER_CONNECTION_STATE_CONNECTED_STR => RTCPeerConnectionState::Connected, - PEER_CONNECTION_STATE_DISCONNECTED_STR => RTCPeerConnectionState::Disconnected, - PEER_CONNECTION_STATE_FAILED_STR => RTCPeerConnectionState::Failed, - PEER_CONNECTION_STATE_CLOSED_STR => RTCPeerConnectionState::Closed, - _ => RTCPeerConnectionState::Unspecified, - } - } -} - -impl From for RTCPeerConnectionState { - fn from(v: u8) -> Self { - match v { - 1 => RTCPeerConnectionState::New, - 2 => RTCPeerConnectionState::Connecting, - 3 => RTCPeerConnectionState::Connected, - 4 => RTCPeerConnectionState::Disconnected, - 5 => RTCPeerConnectionState::Failed, - 6 => RTCPeerConnectionState::Closed, - _ => RTCPeerConnectionState::Unspecified, - } - } -} - -impl fmt::Display for RTCPeerConnectionState { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - RTCPeerConnectionState::New => PEER_CONNECTION_STATE_NEW_STR, - RTCPeerConnectionState::Connecting => PEER_CONNECTION_STATE_CONNECTING_STR, - RTCPeerConnectionState::Connected => PEER_CONNECTION_STATE_CONNECTED_STR, - RTCPeerConnectionState::Disconnected => PEER_CONNECTION_STATE_DISCONNECTED_STR, - RTCPeerConnectionState::Failed => PEER_CONNECTION_STATE_FAILED_STR, - RTCPeerConnectionState::Closed => PEER_CONNECTION_STATE_CLOSED_STR, - RTCPeerConnectionState::Unspecified => crate::UNSPECIFIED_STR, - }; - write!(f, "{s}") - } -} - -#[derive(Default, Debug, Copy, Clone, PartialEq)] -pub(crate) enum NegotiationNeededState { - /// NegotiationNeededStateEmpty not running and queue is empty - #[default] - Empty, - /// NegotiationNeededStateEmpty running and queue is empty - Run, - /// NegotiationNeededStateEmpty running and queue - Queue, -} - -impl From for NegotiationNeededState { - fn from(v: u8) -> Self { - match v { - 1 => NegotiationNeededState::Run, - 2 => NegotiationNeededState::Queue, - _ => NegotiationNeededState::Empty, - } - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_new_peer_connection_state() { - let tests = vec![ - (crate::UNSPECIFIED_STR, RTCPeerConnectionState::Unspecified), - ("new", RTCPeerConnectionState::New), - ("connecting", RTCPeerConnectionState::Connecting), - ("connected", RTCPeerConnectionState::Connected), - ("disconnected", RTCPeerConnectionState::Disconnected), - ("failed", RTCPeerConnectionState::Failed), - ("closed", RTCPeerConnectionState::Closed), - ]; - - for (state_string, expected_state) in tests { - assert_eq!( - RTCPeerConnectionState::from(state_string), - expected_state, - "testCase: {expected_state}", - ); - } - } - - #[test] - fn test_peer_connection_state_string() { - let tests = vec![ - (RTCPeerConnectionState::Unspecified, crate::UNSPECIFIED_STR), - (RTCPeerConnectionState::New, "new"), - (RTCPeerConnectionState::Connecting, "connecting"), - (RTCPeerConnectionState::Connected, "connected"), - (RTCPeerConnectionState::Disconnected, "disconnected"), - 
(RTCPeerConnectionState::Failed, "failed"), - (RTCPeerConnectionState::Closed, "closed"), - ]; - - for (state, expected_string) in tests { - assert_eq!(state.to_string(), expected_string) - } - } -} diff --git a/webrtc/src/peer_connection/peer_connection_test.rs b/webrtc/src/peer_connection/peer_connection_test.rs deleted file mode 100644 index 2c9676037..000000000 --- a/webrtc/src/peer_connection/peer_connection_test.rs +++ /dev/null @@ -1,640 +0,0 @@ -use std::sync::Arc; - -use bytes::Bytes; -use interceptor::registry::Registry; -use media::Sample; -use portable_atomic::AtomicU32; -use tokio::time::Duration; -use util::vnet::net::{Net, NetConfig}; -use util::vnet::router::{Router, RouterConfig}; -use waitgroup::WaitGroup; - -use super::*; -use crate::api::interceptor_registry::register_default_interceptors; -use crate::api::media_engine::{MediaEngine, MIME_TYPE_VP8}; -use crate::api::APIBuilder; -use crate::ice_transport::ice_candidate_pair::RTCIceCandidatePair; -use crate::ice_transport::ice_credential_type::RTCIceCredentialType; -use crate::ice_transport::ice_server::RTCIceServer; -use crate::peer_connection::configuration::RTCConfiguration; -use crate::rtp_transceiver::rtp_codec::RTCRtpCodecCapability; -use crate::stats::StatsReportType; -use crate::track::track_local::track_local_static_rtp::TrackLocalStaticRTP; -use crate::track::track_local::track_local_static_sample::TrackLocalStaticSample; -use crate::Error; - -pub(crate) async fn create_vnet_pair( -) -> Result<(RTCPeerConnection, RTCPeerConnection, Arc>)> { - // Create a root router - let wan = Arc::new(Mutex::new(Router::new(RouterConfig { - cidr: "1.2.3.0/24".to_owned(), - ..Default::default() - })?)); - - // Create a network interface for offerer - let offer_vnet = Arc::new(Net::new(Some(NetConfig { - static_ips: vec!["1.2.3.4".to_owned()], - ..Default::default() - }))); - - // Add the network interface to the router - let nic = offer_vnet.get_nic()?; - { - let mut w = wan.lock().await; - w.add_net(Arc::clone(&nic)).await?; - } - { - let n = nic.lock().await; - n.set_router(Arc::clone(&wan)).await?; - } - - let mut offer_setting_engine = SettingEngine::default(); - offer_setting_engine.set_vnet(Some(offer_vnet)); - offer_setting_engine.set_ice_timeouts( - Some(Duration::from_secs(1)), - Some(Duration::from_secs(1)), - Some(Duration::from_millis(200)), - ); - - // Create a network interface for answerer - let answer_vnet = Arc::new(Net::new(Some(NetConfig { - static_ips: vec!["1.2.3.5".to_owned()], - ..Default::default() - }))); - - // Add the network interface to the router - let nic = answer_vnet.get_nic()?; - { - let mut w = wan.lock().await; - w.add_net(Arc::clone(&nic)).await?; - } - { - let n = nic.lock().await; - n.set_router(Arc::clone(&wan)).await?; - } - - let mut answer_setting_engine = SettingEngine::default(); - answer_setting_engine.set_vnet(Some(answer_vnet)); - answer_setting_engine.set_ice_timeouts( - Some(Duration::from_secs(1)), - Some(Duration::from_secs(1)), - Some(Duration::from_millis(200)), - ); - - // Start the virtual network by calling Start() on the root router - { - let mut w = wan.lock().await; - w.start().await?; - } - - let mut offer_media_engine = MediaEngine::default(); - offer_media_engine.register_default_codecs()?; - let offer_peer_connection = APIBuilder::new() - .with_setting_engine(offer_setting_engine) - .with_media_engine(offer_media_engine) - .build() - .new_peer_connection(RTCConfiguration::default()) - .await?; - - let mut answer_media_engine = MediaEngine::default(); - 
answer_media_engine.register_default_codecs()?; - let answer_peer_connection = APIBuilder::new() - .with_setting_engine(answer_setting_engine) - .with_media_engine(answer_media_engine) - .build() - .new_peer_connection(RTCConfiguration::default()) - .await?; - - Ok((offer_peer_connection, answer_peer_connection, wan)) -} - -/// new_pair creates two new peer connections (an offerer and an answerer) -/// *without* using an api (i.e. using the default settings). -pub(crate) async fn new_pair(api: &API) -> Result<(RTCPeerConnection, RTCPeerConnection)> { - let pca = api.new_peer_connection(RTCConfiguration::default()).await?; - let pcb = api.new_peer_connection(RTCConfiguration::default()).await?; - - Ok((pca, pcb)) -} - -pub(crate) async fn signal_pair( - pc_offer: &mut RTCPeerConnection, - pc_answer: &mut RTCPeerConnection, -) -> Result<()> { - // Note(albrow): We need to create a data channel in order to trigger ICE - // candidate gathering in the background for the JavaScript/Wasm bindings. If - // we don't do this, the complete offer including ICE candidates will never be - // generated. - pc_offer - .create_data_channel("initial_data_channel", None) - .await?; - - let offer = pc_offer.create_offer(None).await?; - - let mut offer_gathering_complete = pc_offer.gathering_complete_promise().await; - pc_offer.set_local_description(offer).await?; - - let _ = offer_gathering_complete.recv().await; - - pc_answer - .set_remote_description( - pc_offer - .local_description() - .await - .ok_or(Error::new("non local description".to_owned()))?, - ) - .await?; - - let answer = pc_answer.create_answer(None).await?; - - let mut answer_gathering_complete = pc_answer.gathering_complete_promise().await; - pc_answer.set_local_description(answer).await?; - - let _ = answer_gathering_complete.recv().await; - - pc_offer - .set_remote_description( - pc_answer - .local_description() - .await - .ok_or(Error::new("non local description".to_owned()))?, - ) - .await -} - -pub(crate) async fn close_pair_now(pc1: &RTCPeerConnection, pc2: &RTCPeerConnection) { - let mut fail = false; - if let Err(err) = pc1.close().await { - log::error!("Failed to close PeerConnection: {}", err); - fail = true; - } - if let Err(err) = pc2.close().await { - log::error!("Failed to close PeerConnection: {}", err); - fail = true; - } - - assert!(!fail); -} - -pub(crate) async fn close_pair( - pc1: &RTCPeerConnection, - pc2: &RTCPeerConnection, - mut done_rx: mpsc::Receiver<()>, -) { - let timeout = tokio::time::sleep(Duration::from_secs(10)); - tokio::pin!(timeout); - - tokio::select! { - _ = timeout.as_mut() =>{ - panic!("close_pair timed out waiting for done signal"); - } - _ = done_rx.recv() =>{ - close_pair_now(pc1, pc2).await; - } - } -} - -/* -func offerMediaHasDirection(offer SessionDescription, kind RTPCodecType, direction RTPTransceiverDirection) bool { - parsed := &sdp.SessionDescription{} - if err := parsed.Unmarshal([]byte(offer.SDP)); err != nil { - return false - } - - for _, media := range parsed.MediaDescriptions { - if media.MediaName.Media == kind.String() { - _, exists := media.Attribute(direction.String()) - return exists - } - } - return false -}*/ - -pub(crate) async fn send_video_until_done( - mut done_rx: mpsc::Receiver<()>, - tracks: Vec>, - data: Bytes, - max_sends: Option, -) -> bool { - let mut sends = 0; - - loop { - let timeout = tokio::time::sleep(Duration::from_millis(20)); - tokio::pin!(timeout); - - tokio::select! 
{ - biased; - - _ = done_rx.recv() =>{ - log::debug!("sendVideoUntilDone received done"); - return false; - } - - _ = timeout.as_mut() =>{ - if max_sends.map(|s| sends >= s).unwrap_or(false) { - continue; - } - - log::debug!("sendVideoUntilDone timeout"); - for track in &tracks { - log::debug!("sendVideoUntilDone track.WriteSample"); - let result = track.write_sample(&Sample{ - data: data.clone(), - duration: Duration::from_secs(1), - ..Default::default() - }).await; - assert!(result.is_ok()); - sends += 1; - } - } - } - } -} - -pub(crate) async fn until_connection_state( - pc: &mut RTCPeerConnection, - wg: &WaitGroup, - state: RTCPeerConnectionState, -) { - let w = Arc::new(Mutex::new(Some(wg.worker()))); - pc.on_peer_connection_state_change(Box::new(move |pcs: RTCPeerConnectionState| { - let w2 = Arc::clone(&w); - Box::pin(async move { - if pcs == state { - let mut worker = w2.lock().await; - worker.take(); - } - }) - })); -} - -#[tokio::test] -async fn test_get_stats() -> Result<()> { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - let (mut pc_offer, mut pc_answer) = new_pair(&api).await?; - - let (ice_complete_tx, mut ice_complete_rx) = mpsc::channel::<()>(1); - let ice_complete_tx = Arc::new(Mutex::new(Some(ice_complete_tx))); - pc_answer.on_ice_connection_state_change(Box::new(move |ice_state: RTCIceConnectionState| { - let ice_complete_tx2 = Arc::clone(&ice_complete_tx); - Box::pin(async move { - if ice_state == RTCIceConnectionState::Connected { - tokio::time::sleep(Duration::from_secs(1)).await; - let mut done = ice_complete_tx2.lock().await; - done.take(); - } - }) - })); - - let sender_called_candidate_change = Arc::new(AtomicU32::new(0)); - let sender_called_candidate_change2 = Arc::clone(&sender_called_candidate_change); - pc_offer - .sctp() - .transport() - .ice_transport() - .on_selected_candidate_pair_change(Box::new(move |_: RTCIceCandidatePair| { - sender_called_candidate_change2.store(1, Ordering::SeqCst); - Box::pin(async {}) - })); - let track = Arc::new(TrackLocalStaticSample::new( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - "video".to_owned(), - "webrtc-rs".to_owned(), - )); - pc_offer - .add_track(track.clone()) - .await - .expect("Failed to add track"); - let (packet_tx, packet_rx) = mpsc::channel(1); - - pc_answer.on_track(Box::new(move |track, _, _| { - let packet_tx = packet_tx.clone(); - tokio::spawn(async move { - while let Ok((pkt, _)) = track.read_rtp().await { - dbg!(&pkt); - let last = pkt.payload[pkt.payload.len() - 1]; - - if last == 0xAA { - let _ = packet_tx.send(()).await; - break; - } - } - }); - - Box::pin(async move {}) - })); - - signal_pair(&mut pc_offer, &mut pc_answer).await?; - - let _ = ice_complete_rx.recv().await; - send_video_until_done( - packet_rx, - vec![track], - Bytes::from_static(b"\xDE\xAD\xBE\xEF\xAA"), - Some(1), - ) - .await; - - let offer_stats = pc_offer.get_stats().await; - assert!(!offer_stats.reports.is_empty()); - - match offer_stats.reports.get("ice_transport") { - Some(StatsReportType::Transport(ice_transport_stats)) => { - assert!(ice_transport_stats.bytes_received > 0); - assert!(ice_transport_stats.bytes_sent > 0); - } - Some(_other) => panic!("found the wrong type"), - None => panic!("missed it"), - } - let outbound_stats = offer_stats - .reports - .values() - .find_map(|v| match v { - StatsReportType::OutboundRTP(d) => Some(d), - _ => None, - }) - .expect("Should have produced an 
RTP Outbound stat"); - assert_eq!(outbound_stats.packets_sent, 1); - assert_eq!(outbound_stats.kind, "video"); - assert_eq!(outbound_stats.bytes_sent, 8); - assert_eq!(outbound_stats.header_bytes_sent, 12); - - let answer_stats = pc_answer.get_stats().await; - let inbound_stats = answer_stats - .reports - .values() - .find_map(|v| match v { - StatsReportType::InboundRTP(d) => Some(d), - _ => None, - }) - .expect("Should have produced an RTP inbound stat"); - assert_eq!(inbound_stats.packets_received, 1); - assert_eq!(inbound_stats.kind, "video"); - assert_eq!(inbound_stats.bytes_received, 8); - assert_eq!(inbound_stats.header_bytes_received, 12); - - close_pair_now(&pc_offer, &pc_answer).await; - - Ok(()) -} - -#[tokio::test] -async fn test_peer_connection_close_is_send() -> Result<()> { - let handle = tokio::spawn(async move { peer().await }); - tokio::join!(handle).0.unwrap() -} - -#[tokio::test] -async fn test_set_get_configuration() { - // initialize MediaEngine and InterceptorRegistry - let media_engine = MediaEngine::default(); - let registry = Registry::default(); - - // create API instance - let api = APIBuilder::new() - .with_media_engine(media_engine) - .with_interceptor_registry(registry) - .build(); - - // create configuration - let initial_config = RTCConfiguration { - ice_servers: vec![RTCIceServer { - urls: vec!["stun:stun.l.google.com:19302".to_string()], - username: "".to_string(), - credential: "".to_string(), - credential_type: RTCIceCredentialType::Unspecified, - }], - ..Default::default() - }; - - // create RTCPeerConnection instance - let peer = Arc::new( - api.new_peer_connection(initial_config.clone()) - .await - .expect("Failed to create RTCPeerConnection"), - ); - - // get configuration and println - let config_before = peer.get_configuration().await; - println!("Initial ICE Servers: {:?}", config_before.ice_servers); - println!( - "Initial ICE Transport Policy: {:?}", - config_before.ice_transport_policy - ); - println!("Initial Bundle Policy: {:?}", config_before.bundle_policy); - println!( - "Initial RTCP Mux Policy: {:?}", - config_before.rtcp_mux_policy - ); - println!("Initial Peer Identity: {:?}", config_before.peer_identity); - println!("Initial Certificates: {:?}", config_before.certificates); - println!( - "Initial ICE Candidate Pool Size: {:?}", - config_before.ice_candidate_pool_size - ); - - // create new configuration - let new_config = RTCConfiguration { - ice_servers: vec![RTCIceServer { - urls: vec![ - "turn:turn.22333.fun".to_string(), - "turn:cn.22333.fun".to_string(), - ], - username: "live777".to_string(), - credential: "live777".to_string(), - credential_type: RTCIceCredentialType::Password, - }], - ..Default::default() - }; - - // set new configuration - peer.set_configuration(new_config.clone()) - .await - .expect("Failed to set configuration"); - - // get new configuration and println - let updated_config = peer.get_configuration().await; - println!("Updated ICE Servers: {:?}", updated_config.ice_servers); - println!( - "Updated ICE Transport Policy: {:?}", - updated_config.ice_transport_policy - ); - println!("Updated Bundle Policy: {:?}", updated_config.bundle_policy); - println!( - "Updated RTCP Mux Policy: {:?}", - updated_config.rtcp_mux_policy - ); - println!("Updated Peer Identity: {:?}", updated_config.peer_identity); - println!("Updated Certificates: {:?}", updated_config.certificates); - println!( - "Updated ICE Candidate Pool Size: {:?}", - updated_config.ice_candidate_pool_size - ); - - // verify - 
assert_eq!(updated_config.ice_servers, new_config.ice_servers); -} - -async fn peer() -> Result<()> { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let mut registry = Registry::new(); - registry = register_default_interceptors(registry, &mut m)?; - let api = APIBuilder::new() - .with_media_engine(m) - .with_interceptor_registry(registry) - .build(); - - let config = RTCConfiguration { - ice_servers: vec![RTCIceServer { - urls: vec!["stun:stun.l.google.com:19302".to_owned()], - ..Default::default() - }], - ..Default::default() - }; - - let peer_connection = Arc::new(api.new_peer_connection(config).await?); - - let offer = peer_connection.create_offer(None).await?; - let mut gather_complete = peer_connection.gathering_complete_promise().await; - peer_connection.set_local_description(offer).await?; - let _ = gather_complete.recv().await; - - if peer_connection.local_description().await.is_some() { - //TODO? - } - - peer_connection.close().await?; - - Ok(()) -} - -pub(crate) fn on_connected() -> (OnPeerConnectionStateChangeHdlrFn, mpsc::Receiver<()>) { - let (done_tx, done_rx) = mpsc::channel::<()>(1); - let done_tx = Arc::new(Mutex::new(Some(done_tx))); - let hdlr_fn: OnPeerConnectionStateChangeHdlrFn = - Box::new(move |state: RTCPeerConnectionState| { - let done_tx_clone = Arc::clone(&done_tx); - Box::pin(async move { - if state == RTCPeerConnectionState::Connected { - let mut tx = done_tx_clone.lock().await; - tx.take(); - } - }) - }); - (hdlr_fn, done_rx) -} - -// Everytime we receive a new SSRC we probe it and try to determine the proper way to handle it. -// In most cases a Track explicitly declares a SSRC and a OnTrack is fired. In two cases we don't -// know the SSRC ahead of time -// * Undeclared SSRC in a single media section -// * Simulcast -// -// The Undeclared SSRC processing code would run before Simulcast. If a Simulcast Offer/Answer only -// contained one Media Section we would never fire the OnTrack. We would assume it was a failed -// Undeclared SSRC processing. This test asserts that we properly handled this. 
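// The comment below describes the ordering hazard: the undeclared-SSRC path runs before
// simulcast probing, so a one-media-section simulcast offer must not be swallowed by it.
// The decision handle_undeclared_ssrc (earlier in this diff) makes can be summarised as a
// small pure function; `DeclaredAs` is an illustrative enum, not a crate type.
enum DeclaredAs {
    Simulcast,      // a=rid present: leave the stream to the simulcast probe
    ExplicitSsrc,   // a=ssrc present: nothing is undeclared, treat as an error
    UndeclaredSsrc, // neither: synthesize a receive transceiver for the lone stream
}

fn classify_single_section(has_rid: bool, has_ssrc: bool) -> DeclaredAs {
    if has_rid {
        DeclaredAs::Simulcast
    } else if has_ssrc {
        DeclaredAs::ExplicitSsrc
    } else {
        DeclaredAs::UndeclaredSsrc
    }
}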
-#[tokio::test] -async fn test_peer_connection_simulcast_no_data_channel() -> Result<()> { - let mut m = MediaEngine::default(); - for ext in [ - ::sdp::extmap::SDES_MID_URI, - ::sdp::extmap::SDES_RTP_STREAM_ID_URI, - ] { - m.register_header_extension( - RTCRtpHeaderExtensionCapability { - uri: ext.to_owned(), - }, - RTPCodecType::Video, - None, - )?; - } - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - let (mut pc_send, mut pc_recv) = new_pair(&api).await?; - let (send_notifier, mut send_connected) = on_connected(); - let (recv_notifier, mut recv_connected) = on_connected(); - pc_send.on_peer_connection_state_change(send_notifier); - pc_recv.on_peer_connection_state_change(recv_notifier); - let (track_tx, mut track_rx) = mpsc::unbounded_channel(); - pc_recv.on_track(Box::new(move |t, _, _| { - let rid = t.rid().to_owned(); - let _ = track_tx.send(rid); - Box::pin(async move {}) - })); - - let id = "video"; - let stream_id = "webrtc-rs"; - let track = Arc::new(TrackLocalStaticRTP::new_with_rid( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - id.to_owned(), - "a".to_owned(), - stream_id.to_owned(), - )); - let track_a = Arc::clone(&track); - let transceiver = pc_send.add_transceiver_from_track(track, None).await?; - let sender = transceiver.sender().await; - - let track = Arc::new(TrackLocalStaticRTP::new_with_rid( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - id.to_owned(), - "b".to_owned(), - stream_id.to_owned(), - )); - let track_b = Arc::clone(&track); - sender.add_encoding(track).await?; - - let track = Arc::new(TrackLocalStaticRTP::new_with_rid( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - id.to_owned(), - "c".to_owned(), - stream_id.to_owned(), - )); - let track_c = Arc::clone(&track); - sender.add_encoding(track).await?; - - // signaling - signal_pair(&mut pc_send, &mut pc_recv).await?; - let _ = send_connected.recv().await; - let _ = recv_connected.recv().await; - - for sequence_number in [0; 100] { - let pkt = rtp::packet::Packet { - header: rtp::header::Header { - version: 2, - sequence_number, - payload_type: 96, - ..Default::default() - }, - payload: Bytes::from_static(&[0; 2]), - }; - - track_a.write_rtp_with_extensions(&pkt, &[]).await?; - track_b.write_rtp_with_extensions(&pkt, &[]).await?; - track_c.write_rtp_with_extensions(&pkt, &[]).await?; - } - - assert_eq!(track_rx.recv().await.unwrap(), "a".to_owned()); - assert_eq!(track_rx.recv().await.unwrap(), "b".to_owned()); - assert_eq!(track_rx.recv().await.unwrap(), "c".to_owned()); - - close_pair_now(&pc_send, &pc_recv).await; - - Ok(()) -} diff --git a/webrtc/src/peer_connection/policy/bundle_policy.rs b/webrtc/src/peer_connection/policy/bundle_policy.rs deleted file mode 100644 index 040228f26..000000000 --- a/webrtc/src/peer_connection/policy/bundle_policy.rs +++ /dev/null @@ -1,94 +0,0 @@ -use std::fmt; - -use serde::{Deserialize, Serialize}; - -/// BundlePolicy affects which media tracks are negotiated if the remote -/// endpoint is not bundle-aware, and what ICE candidates are gathered. If the -/// remote endpoint is bundle-aware, all media tracks and data channels are -/// bundled onto the same transport. 
-#[derive(Default, Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)] -pub enum RTCBundlePolicy { - #[default] - Unspecified = 0, - - /// BundlePolicyBalanced indicates to gather ICE candidates for each - /// media type in use (audio, video, and data). If the remote endpoint is - /// not bundle-aware, negotiate only one audio and video track on separate - /// transports. - #[serde(rename = "balanced")] - Balanced = 1, - - /// BundlePolicyMaxCompat indicates to gather ICE candidates for each - /// track. If the remote endpoint is not bundle-aware, negotiate all media - /// tracks on separate transports. - #[serde(rename = "max-compat")] - MaxCompat = 2, - - /// BundlePolicyMaxBundle indicates to gather ICE candidates for only - /// one track. If the remote endpoint is not bundle-aware, negotiate only - /// one media track. - #[serde(rename = "max-bundle")] - MaxBundle = 3, -} - -/// This is done this way because of a linter. -const BUNDLE_POLICY_BALANCED_STR: &str = "balanced"; -const BUNDLE_POLICY_MAX_COMPAT_STR: &str = "max-compat"; -const BUNDLE_POLICY_MAX_BUNDLE_STR: &str = "max-bundle"; - -impl From<&str> for RTCBundlePolicy { - /// NewSchemeType defines a procedure for creating a new SchemeType from a raw - /// string naming the scheme type. - fn from(raw: &str) -> Self { - match raw { - BUNDLE_POLICY_BALANCED_STR => RTCBundlePolicy::Balanced, - BUNDLE_POLICY_MAX_COMPAT_STR => RTCBundlePolicy::MaxCompat, - BUNDLE_POLICY_MAX_BUNDLE_STR => RTCBundlePolicy::MaxBundle, - _ => RTCBundlePolicy::Unspecified, - } - } -} - -impl fmt::Display for RTCBundlePolicy { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self { - RTCBundlePolicy::Balanced => write!(f, "{BUNDLE_POLICY_BALANCED_STR}"), - RTCBundlePolicy::MaxCompat => write!(f, "{BUNDLE_POLICY_MAX_COMPAT_STR}"), - RTCBundlePolicy::MaxBundle => write!(f, "{BUNDLE_POLICY_MAX_BUNDLE_STR}"), - _ => write!(f, "{}", crate::UNSPECIFIED_STR), - } - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_new_bundle_policy() { - let tests = vec![ - ("Unspecified", RTCBundlePolicy::Unspecified), - ("balanced", RTCBundlePolicy::Balanced), - ("max-compat", RTCBundlePolicy::MaxCompat), - ("max-bundle", RTCBundlePolicy::MaxBundle), - ]; - - for (policy_string, expected_policy) in tests { - assert_eq!(RTCBundlePolicy::from(policy_string), expected_policy); - } - } - - #[test] - fn test_bundle_policy_string() { - let tests = vec![ - (RTCBundlePolicy::Unspecified, "Unspecified"), - (RTCBundlePolicy::Balanced, "balanced"), - (RTCBundlePolicy::MaxCompat, "max-compat"), - (RTCBundlePolicy::MaxBundle, "max-bundle"), - ]; - - for (policy, expected_string) in tests { - assert_eq!(policy.to_string(), expected_string); - } - } -} diff --git a/webrtc/src/peer_connection/policy/ice_transport_policy.rs b/webrtc/src/peer_connection/policy/ice_transport_policy.rs deleted file mode 100644 index 331ebe60b..000000000 --- a/webrtc/src/peer_connection/policy/ice_transport_policy.rs +++ /dev/null @@ -1,77 +0,0 @@ -use std::fmt; - -use serde::{Deserialize, Serialize}; - -/// ICETransportPolicy defines the ICE candidate policy surface the -/// permitted candidates. Only these candidates are used for connectivity checks. -#[derive(Default, Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)] -pub enum RTCIceTransportPolicy { - #[default] - Unspecified = 0, - - /// ICETransportPolicyAll indicates any type of candidate is used. 
- #[serde(rename = "all")] - All = 1, - - /// ICETransportPolicyRelay indicates only media relay candidates such - /// as candidates passing through a TURN server are used. - #[serde(rename = "relay")] - Relay = 2, -} - -/// ICEGatherPolicy is the ORTC equivalent of ICETransportPolicy -pub type ICEGatherPolicy = RTCIceTransportPolicy; - -const ICE_TRANSPORT_POLICY_RELAY_STR: &str = "relay"; -const ICE_TRANSPORT_POLICY_ALL_STR: &str = "all"; - -/// takes a string and converts it to ICETransportPolicy -impl From<&str> for RTCIceTransportPolicy { - fn from(raw: &str) -> Self { - match raw { - ICE_TRANSPORT_POLICY_RELAY_STR => RTCIceTransportPolicy::Relay, - ICE_TRANSPORT_POLICY_ALL_STR => RTCIceTransportPolicy::All, - _ => RTCIceTransportPolicy::Unspecified, - } - } -} - -impl fmt::Display for RTCIceTransportPolicy { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - RTCIceTransportPolicy::Relay => ICE_TRANSPORT_POLICY_RELAY_STR, - RTCIceTransportPolicy::All => ICE_TRANSPORT_POLICY_ALL_STR, - RTCIceTransportPolicy::Unspecified => crate::UNSPECIFIED_STR, - }; - write!(f, "{s}") - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_new_ice_transport_policy() { - let tests = vec![ - ("relay", RTCIceTransportPolicy::Relay), - ("all", RTCIceTransportPolicy::All), - ]; - - for (policy_string, expected_policy) in tests { - assert_eq!(RTCIceTransportPolicy::from(policy_string), expected_policy); - } - } - - #[test] - fn test_ice_transport_policy_string() { - let tests = vec![ - (RTCIceTransportPolicy::Relay, "relay"), - (RTCIceTransportPolicy::All, "all"), - ]; - - for (policy, expected_string) in tests { - assert_eq!(policy.to_string(), expected_string); - } - } -} diff --git a/webrtc/src/peer_connection/policy/mod.rs b/webrtc/src/peer_connection/policy/mod.rs deleted file mode 100644 index 82036c518..000000000 --- a/webrtc/src/peer_connection/policy/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -pub mod bundle_policy; -pub mod ice_transport_policy; -pub mod rtcp_mux_policy; -pub mod sdp_semantics; diff --git a/webrtc/src/peer_connection/policy/rtcp_mux_policy.rs b/webrtc/src/peer_connection/policy/rtcp_mux_policy.rs deleted file mode 100644 index 35ecccaae..000000000 --- a/webrtc/src/peer_connection/policy/rtcp_mux_policy.rs +++ /dev/null @@ -1,79 +0,0 @@ -use std::fmt; - -use serde::{Deserialize, Serialize}; - -/// RTCPMuxPolicy affects what ICE candidates are gathered to support -/// non-multiplexed RTCP. -#[derive(Default, Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)] -pub enum RTCRtcpMuxPolicy { - #[default] - Unspecified = 0, - - /// RTCPMuxPolicyNegotiate indicates to gather ICE candidates for both - /// RTP and RTCP candidates. If the remote-endpoint is capable of - /// multiplexing RTCP, multiplex RTCP on the RTP candidates. If it is not, - /// use both the RTP and RTCP candidates separately. - #[serde(rename = "negotiate")] - Negotiate = 1, - - /// RTCPMuxPolicyRequire indicates to gather ICE candidates only for - /// RTP and multiplex RTCP on the RTP candidates. If the remote endpoint is - /// not capable of rtcp-mux, session negotiation will fail. 
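// These policy enums are consumed through RTCConfiguration (the get/set_configuration test
// earlier in this diff prints the same fields). A hedged wiring sketch; the field names come
// from that test, the chosen values are illustrative, and the `webrtc::...` paths mirror the
// file layout deleted here.
use webrtc::peer_connection::configuration::RTCConfiguration;
use webrtc::peer_connection::policy::bundle_policy::RTCBundlePolicy;
use webrtc::peer_connection::policy::ice_transport_policy::RTCIceTransportPolicy;
use webrtc::peer_connection::policy::rtcp_mux_policy::RTCRtcpMuxPolicy;

fn relay_only_config() -> RTCConfiguration {
    RTCConfiguration {
        // Only gather relay (TURN) candidates.
        ice_transport_policy: RTCIceTransportPolicy::Relay,
        // Bundle all media and data onto one transport where possible.
        bundle_policy: RTCBundlePolicy::MaxBundle,
        // Require rtcp-mux; negotiation fails against peers that cannot mux.
        rtcp_mux_policy: RTCRtcpMuxPolicy::Require,
        ..Default::default()
    }
}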
- #[serde(rename = "require")] - Require = 2, -} - -const RTCP_MUX_POLICY_NEGOTIATE_STR: &str = "negotiate"; -const RTCP_MUX_POLICY_REQUIRE_STR: &str = "require"; - -impl From<&str> for RTCRtcpMuxPolicy { - fn from(raw: &str) -> Self { - match raw { - RTCP_MUX_POLICY_NEGOTIATE_STR => RTCRtcpMuxPolicy::Negotiate, - RTCP_MUX_POLICY_REQUIRE_STR => RTCRtcpMuxPolicy::Require, - _ => RTCRtcpMuxPolicy::Unspecified, - } - } -} - -impl fmt::Display for RTCRtcpMuxPolicy { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - RTCRtcpMuxPolicy::Negotiate => RTCP_MUX_POLICY_NEGOTIATE_STR, - RTCRtcpMuxPolicy::Require => RTCP_MUX_POLICY_REQUIRE_STR, - RTCRtcpMuxPolicy::Unspecified => crate::UNSPECIFIED_STR, - }; - write!(f, "{s}") - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_new_rtcp_mux_policy() { - let tests = vec![ - ("Unspecified", RTCRtcpMuxPolicy::Unspecified), - ("negotiate", RTCRtcpMuxPolicy::Negotiate), - ("require", RTCRtcpMuxPolicy::Require), - ]; - - for (policy_string, expected_policy) in tests { - assert_eq!(RTCRtcpMuxPolicy::from(policy_string), expected_policy); - } - } - - #[test] - fn test_rtcp_mux_policy_string() { - let tests = vec![ - (RTCRtcpMuxPolicy::Unspecified, "Unspecified"), - (RTCRtcpMuxPolicy::Negotiate, "negotiate"), - (RTCRtcpMuxPolicy::Require, "require"), - ]; - - for (policy, expected_string) in tests { - assert_eq!(policy.to_string(), expected_string); - } - } -} diff --git a/webrtc/src/peer_connection/policy/sdp_semantics.rs b/webrtc/src/peer_connection/policy/sdp_semantics.rs deleted file mode 100644 index 4fe510f11..000000000 --- a/webrtc/src/peer_connection/policy/sdp_semantics.rs +++ /dev/null @@ -1,112 +0,0 @@ -use std::fmt; - -use serde::{Deserialize, Serialize}; - -/// SDPSemantics determines which style of SDP offers and answers -/// can be used. -/// -/// This is unused, we only support UnifiedPlan. 
-#[derive(Default, Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)] -pub enum RTCSdpSemantics { - Unspecified = 0, - - /// UnifiedPlan uses unified-plan offers and answers - /// (the default in Chrome since M72) - /// - #[serde(rename = "unified-plan")] - #[default] - UnifiedPlan = 1, - - /// PlanB uses plan-b offers and answers - /// NB: This format should be considered deprecated - /// - #[serde(rename = "plan-b")] - PlanB = 2, - - /// UnifiedPlanWithFallback prefers unified-plan - /// offers and answers, but will respond to a plan-b offer - /// with a plan-b answer - #[serde(rename = "unified-plan-with-fallback")] - UnifiedPlanWithFallback = 3, -} - -const SDP_SEMANTICS_UNIFIED_PLAN_WITH_FALLBACK: &str = "unified-plan-with-fallback"; -const SDP_SEMANTICS_UNIFIED_PLAN: &str = "unified-plan"; -const SDP_SEMANTICS_PLAN_B: &str = "plan-b"; - -impl From<&str> for RTCSdpSemantics { - fn from(raw: &str) -> Self { - match raw { - SDP_SEMANTICS_UNIFIED_PLAN_WITH_FALLBACK => RTCSdpSemantics::UnifiedPlanWithFallback, - SDP_SEMANTICS_UNIFIED_PLAN => RTCSdpSemantics::UnifiedPlan, - SDP_SEMANTICS_PLAN_B => RTCSdpSemantics::PlanB, - _ => RTCSdpSemantics::Unspecified, - } - } -} - -impl fmt::Display for RTCSdpSemantics { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - RTCSdpSemantics::UnifiedPlanWithFallback => SDP_SEMANTICS_UNIFIED_PLAN_WITH_FALLBACK, - RTCSdpSemantics::UnifiedPlan => SDP_SEMANTICS_UNIFIED_PLAN, - RTCSdpSemantics::PlanB => SDP_SEMANTICS_PLAN_B, - RTCSdpSemantics::Unspecified => crate::UNSPECIFIED_STR, - }; - write!(f, "{s}") - } -} - -#[cfg(test)] -mod test { - use std::collections::HashSet; - - use sdp::description::media::MediaDescription; - use sdp::description::session::{SessionDescription, ATTR_KEY_SSRC}; - - use super::*; - - #[test] - fn test_sdp_semantics_string() { - let tests = vec![ - (RTCSdpSemantics::Unspecified, "Unspecified"), - ( - RTCSdpSemantics::UnifiedPlanWithFallback, - "unified-plan-with-fallback", - ), - (RTCSdpSemantics::PlanB, "plan-b"), - (RTCSdpSemantics::UnifiedPlan, "unified-plan"), - ]; - - for (value, expected_string) in tests { - assert_eq!(value.to_string(), expected_string); - } - } - - // The following tests are for non-standard SDP semantics - // (i.e. 
not unified-unified) - fn get_md_names(sdp: &SessionDescription) -> Vec { - sdp.media_descriptions - .iter() - .map(|md| md.media_name.media.clone()) - .collect() - } - - fn extract_ssrc_list(md: &MediaDescription) -> Vec { - let mut ssrcs = HashSet::new(); - for attr in &md.attributes { - if attr.key == ATTR_KEY_SSRC { - if let Some(value) = &attr.value { - let fields: Vec<&str> = value.split_whitespace().collect(); - if let Some(ssrc) = fields.first() { - ssrcs.insert(*ssrc); - } - } - } - } - ssrcs - .into_iter() - .map(|ssrc| ssrc.to_owned()) - .collect::>() - } -} diff --git a/webrtc/src/peer_connection/sdp/mod.rs b/webrtc/src/peer_connection/sdp/mod.rs deleted file mode 100644 index 3dff6d151..000000000 --- a/webrtc/src/peer_connection/sdp/mod.rs +++ /dev/null @@ -1,1129 +0,0 @@ -#[cfg(test)] -mod sdp_test; - -use crate::api::media_engine::MediaEngine; -use crate::dtls_transport::dtls_fingerprint::RTCDtlsFingerprint; -use crate::error::{Error, Result}; -use crate::ice_transport::ice_candidate::RTCIceCandidate; -use crate::ice_transport::ice_gatherer::RTCIceGatherer; -use crate::ice_transport::ice_gathering_state::RTCIceGatheringState; -use crate::ice_transport::ice_parameters::RTCIceParameters; -use crate::rtp_transceiver::rtp_codec::{ - RTCRtpCodecCapability, RTCRtpCodecParameters, RTPCodecType, -}; -use crate::rtp_transceiver::rtp_transceiver_direction::RTCRtpTransceiverDirection; -use crate::rtp_transceiver::{PayloadType, RTCPFeedback, RTCRtpTransceiver, SSRC}; - -pub mod sdp_type; -pub mod session_description; - -use std::collections::HashMap; -use std::convert::From; -use std::io::BufReader; -use std::sync::Arc; - -use ice::candidate::candidate_base::unmarshal_candidate; -use ice::candidate::Candidate; -use sdp::description::common::{Address, ConnectionInformation}; -use sdp::description::media::{MediaDescription, MediaName, RangedPort}; -use sdp::description::session::*; -use sdp::extmap::ExtMap; -use sdp::util::ConnectionRole; -use smol_str::SmolStr; -use url::Url; - -use crate::peer_connection::MEDIA_SECTION_APPLICATION; -use crate::{SDP_ATTRIBUTE_RID, SDP_ATTRIBUTE_SIMULCAST}; - -/// TrackDetails represents any media source that can be represented in a SDP -/// This isn't keyed by SSRC because it also needs to support rid based sources -#[derive(Default, Debug, Clone)] -pub(crate) struct TrackDetails { - pub(crate) mid: SmolStr, - pub(crate) kind: RTPCodecType, - pub(crate) stream_id: String, - pub(crate) id: String, - pub(crate) ssrcs: Vec, - pub(crate) repair_ssrc: SSRC, - pub(crate) rids: Vec, -} - -pub(crate) fn track_details_for_ssrc( - track_details: &[TrackDetails], - ssrc: SSRC, -) -> Option<&TrackDetails> { - track_details.iter().find(|x| x.ssrcs.contains(&ssrc)) -} - -pub(crate) fn track_details_for_rid( - track_details: &[TrackDetails], - rid: SmolStr, -) -> Option<&TrackDetails> { - track_details.iter().find(|x| x.rids.contains(&rid)) -} - -pub(crate) fn filter_track_with_ssrc(incoming_tracks: &mut Vec, ssrc: SSRC) { - incoming_tracks.retain(|x| !x.ssrcs.contains(&ssrc)); -} - -/// extract all TrackDetails from an SDP. 
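// A short worked example for the extraction below, assuming a media section such as
//
//   m=video 9 UDP/TLS/RTP/SAVPF 96
//   a=mid:2
//   a=sendrecv
//   a=ssrc-group:FID 3000 4000
//   a=ssrc:3000 msid:video_stream video_track
//   a=ssrc:4000 msid:video_stream video_track
//
// The result is a single TrackDetails with mid "2", kind Video, stream_id
// "video_stream", id "video_track", ssrcs == vec![3000] and repair_ssrc == 4000;
// the RTX flow 4000 is recorded as a repair SSRC rather than as a separate track.
// A minimal call sketch:
fn track_details_sketch(parsed: &SessionDescription) {
    let tracks = track_details_from_sdp(parsed, true);
    if let Some(video) = track_details_for_ssrc(&tracks, 3000) {
        assert_eq!(video.repair_ssrc, 4000);
        assert_eq!(video.stream_id, "video_stream");
    }
}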
-pub(crate) fn track_details_from_sdp( - s: &SessionDescription, - exclude_inactive: bool, -) -> Vec { - let mut incoming_tracks = vec![]; - - for media in &s.media_descriptions { - let mut tracks_in_media_section = vec![]; - let mut rtx_repair_flows = HashMap::new(); - - let mut stream_id = ""; - let mut track_id = ""; - - // If media section is recvonly or inactive skip - if media.attribute(ATTR_KEY_RECV_ONLY).is_some() - || (exclude_inactive && media.attribute(ATTR_KEY_INACTIVE).is_some()) - { - continue; - } - - let mid_value = match get_mid_value(media) { - Some(mid_value) => mid_value, - None => continue, - }; - - let codec_type = RTPCodecType::from(media.media_name.media.as_str()); - if codec_type == RTPCodecType::Unspecified { - continue; - } - - for attr in &media.attributes { - match attr.key.as_str() { - ATTR_KEY_SSRCGROUP => { - if let Some(value) = &attr.value { - let split: Vec<&str> = value.split(' ').collect(); - if split[0] == SEMANTIC_TOKEN_FLOW_IDENTIFICATION { - // Add rtx ssrcs to blacklist, to avoid adding them as tracks - // Essentially lines like `a=ssrc-group:FID 2231627014 632943048` are processed by this section - // as this declares that the second SSRC (632943048) is a rtx repair flow (RFC4588) for the first - // (2231627014) as specified in RFC5576 - if split.len() == 3 { - let base_ssrc = match split[1].parse::() { - Ok(ssrc) => ssrc, - Err(err) => { - log::warn!("Failed to parse SSRC: {}", err); - continue; - } - }; - let rtx_repair_flow = match split[2].parse::() { - Ok(n) => n, - Err(err) => { - log::warn!("Failed to parse SSRC: {}", err); - continue; - } - }; - rtx_repair_flows.insert(rtx_repair_flow, base_ssrc); - // Remove if rtx was added as track before - filter_track_with_ssrc( - &mut tracks_in_media_section, - rtx_repair_flow as SSRC, - ); - } - } - } - } - - // Handle `a=msid: ` The first value is the same as MediaStream.id - // in the browser and can be used to figure out which tracks belong to the same stream. The browser should - // figure this out automatically when an ontrack event is emitted on RTCPeerConnection. - ATTR_KEY_MSID => { - if let Some(value) = &attr.value { - let mut split = value.split(' '); - - if let (Some(sid), Some(tid), None) = - (split.next(), split.next(), split.next()) - { - stream_id = sid; - track_id = tid; - } - } - } - - ATTR_KEY_SSRC => { - if let Some(value) = &attr.value { - let split: Vec<&str> = value.split(' ').collect(); - let ssrc = match split[0].parse::() { - Ok(ssrc) => ssrc, - Err(err) => { - log::warn!("Failed to parse SSRC: {}", err); - continue; - } - }; - - if rtx_repair_flows.contains_key(&ssrc) { - continue; // This ssrc is a RTX repair flow, ignore - } - - if split.len() == 3 && split[1].starts_with("msid:") { - stream_id = &split[1]["msid:".len()..]; - track_id = split[2]; - } - - let mut track_idx = tracks_in_media_section.len(); - - for (i, t) in tracks_in_media_section.iter().enumerate() { - if t.ssrcs.contains(&ssrc) { - track_idx = i; - //TODO: no break? - } - } - - let mut repair_ssrc = 0; - for (repair, base) in &rtx_repair_flows { - if *base == ssrc { - repair_ssrc = *repair; - //TODO: no break? 
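// Note on the pairing above (a reading of the code, not a guarantee):
// rtx_repair_flows maps repair SSRC -> base SSRC, so this scan looks up the
// repair flow, if any, for the SSRC currently being processed; for
// `a=ssrc-group:FID 2231627014 632943048` that means repair_ssrc becomes
// 632943048 while handling `a=ssrc:2231627014 ...`. Since an FID group normally
// pairs one repair flow per base SSRC, breaking out of this loop and the
// track_idx search on the first hit would likely behave the same, which seems
// to be what the TODO comments are asking.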
- } - } - - if track_idx < tracks_in_media_section.len() { - tracks_in_media_section[track_idx].mid = SmolStr::from(mid_value); - tracks_in_media_section[track_idx].kind = codec_type; - stream_id.clone_into(&mut tracks_in_media_section[track_idx].stream_id); - track_id.clone_into(&mut tracks_in_media_section[track_idx].id); - tracks_in_media_section[track_idx].ssrcs = vec![ssrc]; - tracks_in_media_section[track_idx].repair_ssrc = repair_ssrc; - } else { - let track_details = TrackDetails { - mid: SmolStr::from(mid_value), - kind: codec_type, - stream_id: stream_id.to_owned(), - id: track_id.to_owned(), - ssrcs: vec![ssrc], - repair_ssrc, - ..Default::default() - }; - tracks_in_media_section.push(track_details); - } - } - } - _ => {} - }; - } - - // If media line is using RTP Stream Identifier Source Description per RFC8851 - // we will need to override tracks, and remove ssrcs. - // This is in particular important for Firefox, as it uses both 'rid', 'simulcast' - // and 'a=ssrc' lines. - let rids = get_rids(media); - if !rids.is_empty() && !track_id.is_empty() && !stream_id.is_empty() { - tracks_in_media_section = vec![TrackDetails { - mid: SmolStr::from(mid_value), - kind: codec_type, - stream_id: stream_id.to_owned(), - id: track_id.to_owned(), - rids: rids.iter().map(|r| SmolStr::from(&r.id)).collect(), - ..Default::default() - }]; - } - - incoming_tracks.extend(tracks_in_media_section); - } - - incoming_tracks -} - -pub(crate) fn get_rids(media: &MediaDescription) -> Vec { - let mut rids = vec![]; - let mut simulcast_attr: Option = None; - for attr in &media.attributes { - if attr.key.as_str() == SDP_ATTRIBUTE_RID { - if let Err(err) = attr - .value - .as_ref() - .ok_or(SimulcastRidParseError::SyntaxIdDirSplit) - .and_then(SimulcastRid::try_from) - .map(|rid| rids.push(rid)) - { - log::warn!("Failed to parse RID: {}", err); - } - } else if attr.key.as_str() == SDP_ATTRIBUTE_SIMULCAST { - simulcast_attr.clone_from(&attr.value); - } - } - - if let Some(attr) = simulcast_attr { - let mut split = attr.split(' '); - loop { - let _dir = split.next(); - let sc_str_list = split.next(); - if let Some(list) = sc_str_list { - let sc_list: Vec<&str> = list.split(';').flat_map(|alt| alt.split(',')).collect(); - for sc_id in sc_list { - let (sc_id, paused) = if let Some(sc_id) = sc_id.strip_prefix('~') { - (sc_id, true) - } else { - (sc_id, false) - }; - - if let Some(rid) = rids.iter_mut().find(|f| f.id == sc_id) { - rid.paused = paused; - } - } - } else { - break; - } - } - } - - rids -} - -pub(crate) async fn add_candidates_to_media_descriptions( - candidates: &[RTCIceCandidate], - mut m: MediaDescription, - ice_gathering_state: RTCIceGatheringState, -) -> Result { - let append_candidate_if_new = |c: &dyn Candidate, m: MediaDescription| -> MediaDescription { - let marshaled = c.marshal(); - for a in &m.attributes { - if let Some(value) = &a.value { - if &marshaled == value { - return m; - } - } - } - - m.with_value_attribute("candidate".to_owned(), marshaled) - }; - - for c in candidates { - let candidate = c.to_ice()?; - - candidate.set_component(1); - m = append_candidate_if_new(&candidate, m); - - candidate.set_component(2); - m = append_candidate_if_new(&candidate, m); - } - - if ice_gathering_state != RTCIceGatheringState::Complete { - return Ok(m); - } - for a in &m.attributes { - if &a.key == "end-of-candidates" { - return Ok(m); - } - } - - Ok(m.with_property_attribute("end-of-candidates".to_owned())) -} - -pub(crate) struct AddDataMediaSectionParams { - should_add_candidates: bool, - 
mid_value: String, - ice_params: RTCIceParameters, - dtls_role: ConnectionRole, - ice_gathering_state: RTCIceGatheringState, -} - -pub(crate) async fn add_data_media_section( - d: SessionDescription, - dtls_fingerprints: &[RTCDtlsFingerprint], - candidates: &[RTCIceCandidate], - params: AddDataMediaSectionParams, -) -> Result { - let mut media = MediaDescription { - media_name: MediaName { - media: MEDIA_SECTION_APPLICATION.to_owned(), - port: RangedPort { - value: 9, - range: None, - }, - protos: vec!["UDP".to_owned(), "DTLS".to_owned(), "SCTP".to_owned()], - formats: vec!["webrtc-datachannel".to_owned()], - }, - media_title: None, - connection_information: Some(ConnectionInformation { - network_type: "IN".to_owned(), - address_type: "IP4".to_owned(), - address: Some(Address { - address: "0.0.0.0".to_owned(), - ttl: None, - range: None, - }), - }), - bandwidth: vec![], - encryption_key: None, - attributes: vec![], - } - .with_value_attribute( - ATTR_KEY_CONNECTION_SETUP.to_owned(), - params.dtls_role.to_string(), - ) - .with_value_attribute(ATTR_KEY_MID.to_owned(), params.mid_value) - .with_property_attribute(RTCRtpTransceiverDirection::Sendrecv.to_string()) - .with_property_attribute("sctp-port:5000".to_owned()) - .with_ice_credentials( - params.ice_params.username_fragment, - params.ice_params.password, - ); - - for f in dtls_fingerprints { - media = media.with_fingerprint(f.algorithm.clone(), f.value.to_uppercase()); - } - - if params.should_add_candidates { - media = add_candidates_to_media_descriptions(candidates, media, params.ice_gathering_state) - .await?; - } - - Ok(d.with_media(media)) -} - -pub(crate) async fn populate_local_candidates( - session_description: Option<&session_description::RTCSessionDescription>, - ice_gatherer: Option<&Arc>, - ice_gathering_state: RTCIceGatheringState, -) -> Option { - if session_description.is_none() || ice_gatherer.is_none() { - return session_description.cloned(); - } - - if let (Some(sd), Some(ice)) = (session_description, ice_gatherer) { - let candidates = match ice.get_local_candidates().await { - Ok(candidates) => candidates, - Err(_) => return Some(sd.clone()), - }; - - let mut parsed = match sd.unmarshal() { - Ok(parsed) => parsed, - Err(_) => return Some(sd.clone()), - }; - - if !parsed.media_descriptions.is_empty() { - let mut m = parsed.media_descriptions.remove(0); - m = match add_candidates_to_media_descriptions(&candidates, m, ice_gathering_state) - .await - { - Ok(m) => m, - Err(_) => return Some(sd.clone()), - }; - parsed.media_descriptions.insert(0, m); - } - - Some(session_description::RTCSessionDescription { - sdp_type: sd.sdp_type, - sdp: parsed.marshal(), - parsed: Some(parsed), - }) - } else { - None - } -} - -pub(crate) struct AddTransceiverSdpParams { - should_add_candidates: bool, - mid_value: String, - dtls_role: ConnectionRole, - ice_gathering_state: RTCIceGatheringState, - offered_direction: Option, -} - -pub(crate) async fn add_transceiver_sdp( - mut d: SessionDescription, - dtls_fingerprints: &[RTCDtlsFingerprint], - media_engine: &Arc, - ice_params: &RTCIceParameters, - candidates: &[RTCIceCandidate], - media_section: &MediaSection, - params: AddTransceiverSdpParams, -) -> Result<(SessionDescription, bool)> { - if media_section.transceivers.is_empty() { - return Err(Error::ErrSDPZeroTransceivers); - } - let (should_add_candidates, mid_value, dtls_role, ice_gathering_state) = ( - params.should_add_candidates, - params.mid_value, - params.dtls_role, - params.ice_gathering_state, - ); - - let transceivers = 
&media_section.transceivers; - // Use the first transceiver to generate the section attributes - let t = &transceivers[0]; - let mut media = MediaDescription::new_jsep_media_description(t.kind.to_string(), vec![]) - .with_value_attribute(ATTR_KEY_CONNECTION_SETUP.to_owned(), dtls_role.to_string()) - .with_value_attribute(ATTR_KEY_MID.to_owned(), mid_value.clone()) - .with_ice_credentials( - ice_params.username_fragment.clone(), - ice_params.password.clone(), - ) - .with_property_attribute(ATTR_KEY_RTCPMUX.to_owned()) - .with_property_attribute(ATTR_KEY_RTCPRSIZE.to_owned()); - - let codecs = t.get_codecs().await; - for codec in &codecs { - let name = codec - .capability - .mime_type - .trim_start_matches("audio/") - .trim_start_matches("video/") - .to_owned(); - media = media.with_codec( - codec.payload_type, - name, - codec.capability.clock_rate, - codec.capability.channels, - codec.capability.sdp_fmtp_line.clone(), - ); - - for feedback in &codec.capability.rtcp_feedback { - media = media.with_value_attribute( - "rtcp-fb".to_owned(), - format!( - "{} {} {}", - codec.payload_type, feedback.typ, feedback.parameter - ), - ); - } - } - if codecs.is_empty() { - // If we are sender and we have no codecs throw an error early - if t.sender().await.track().await.is_some() { - return Err(Error::ErrSenderWithNoCodecs); - } - - // Explicitly reject track if we don't have the codec - d = d.with_media(MediaDescription { - media_name: sdp::description::media::MediaName { - media: t.kind.to_string(), - port: RangedPort { - value: 0, - range: None, - }, - protos: vec![ - "UDP".to_owned(), - "TLS".to_owned(), - "RTP".to_owned(), - "SAVPF".to_owned(), - ], - formats: vec!["0".to_owned()], - }, - media_title: None, - // We need to include connection information even if we're rejecting a track, otherwise Firefox will fail to - // parse the SDP with an error like: - // SIPCC Failed to parse SDP: SDP Parse Error on line 50: c= connection line not specified for every media level, validation failed. 
- // In addition this makes our SDP compliant with RFC 4566 Section 5.7: https://datatracker.ietf.org/doc/html/rfc4566#section-5.7 - connection_information: Some(ConnectionInformation { - network_type: "IN".to_owned(), - address_type: "IP4".to_owned(), - address: Some(Address { - address: "0.0.0.0".to_owned(), - ttl: None, - range: None, - }), - }), - bandwidth: vec![], - encryption_key: None, - attributes: vec![], - }); - return Ok((d, false)); - } - - let parameters = media_engine.get_rtp_parameters_by_kind(t.kind, t.direction()); - for rtp_extension in ¶meters.header_extensions { - let ext_url = Url::parse(rtp_extension.uri.as_str())?; - media = media.with_extmap(sdp::extmap::ExtMap { - value: rtp_extension.id, - uri: Some(ext_url), - ..Default::default() - }); - } - - if !media_section.rid_map.is_empty() { - let mut recv_sc_list: Vec = vec![]; - let mut send_sc_list: Vec = vec![]; - - for rid in &media_section.rid_map { - let rid_syntax = match rid.direction { - SimulcastDirection::Send => { - // If Send rid, then reply with a recv rid - if rid.paused { - recv_sc_list.push(format!("~{}", rid.id)); - } else { - recv_sc_list.push(rid.id.to_owned()); - } - format!("{} recv", rid.id) - } - SimulcastDirection::Recv => { - // If Recv rid, then reply with a send rid - if rid.paused { - send_sc_list.push(format!("~{}", rid.id)); - } else { - send_sc_list.push(rid.id.to_owned()); - } - format!("{} send", rid.id) - } - }; - media = media.with_value_attribute(SDP_ATTRIBUTE_RID.to_owned(), rid_syntax); - } - - // Simulcast - let mut sc_attr = String::new(); - if !recv_sc_list.is_empty() { - sc_attr.push_str(&format!("recv {}", recv_sc_list.join(";"))); - } - if !send_sc_list.is_empty() { - sc_attr.push_str(&format!("send {}", send_sc_list.join(";"))); - } - media = media.with_value_attribute(SDP_ATTRIBUTE_SIMULCAST.to_owned(), sc_attr); - } - - for mt in transceivers { - let sender = mt.sender().await; - if let Some(track) = sender.track().await { - let send_parameters = sender.get_parameters().await; - for encoding in &send_parameters.encodings { - media = media.with_media_source( - encoding.ssrc, - track.stream_id().to_owned(), /* cname */ - track.stream_id().to_owned(), /* streamLabel */ - track.id().to_owned(), - ); - } - - if send_parameters.encodings.len() > 1 { - let mut send_rids = Vec::with_capacity(send_parameters.encodings.len()); - - for encoding in &send_parameters.encodings { - media = media.with_value_attribute( - SDP_ATTRIBUTE_RID.to_owned(), - format!("{} send", encoding.rid), - ); - send_rids.push(encoding.rid.to_string()); - } - - media = media.with_value_attribute( - SDP_ATTRIBUTE_SIMULCAST.to_owned(), - format!("send {}", send_rids.join(";")), - ); - } - - // Send msid based on the configured track if we haven't already - // sent on this sender. If we have sent we must keep the msid line consistent, this - // is handled below. - if sender.initial_track_id().is_none() { - for stream_id in sender.associated_media_stream_ids() { - media = - media.with_property_attribute(format!("msid:{} {}", stream_id, track.id())); - } - - sender.set_initial_track_id(track.id().to_string())?; - break; - } - } - - if let Some(track_id) = sender.initial_track_id() { - // After we have include an msid attribute in an offer it must stay the same for - // all subsequent offer even if the track or transceiver direction changes. 
- // - // [RFC 8829 Section 5.2.2](https://datatracker.ietf.org/doc/html/rfc8829#section-5.2.2) - // - // For RtpTransceivers that are not stopped, the "a=msid" line or - // lines MUST stay the same if they are present in the current - // description, regardless of changes to the transceiver's direction - // or track. If no "a=msid" line is present in the current - // description, "a=msid" line(s) MUST be generated according to the - // same rules as for an initial offer. - for stream_id in sender.associated_media_stream_ids() { - media = media.with_property_attribute(format!("msid:{stream_id} {track_id}")); - } - - break; - } - } - - let direction = match params.offered_direction { - Some(offered_direction) => { - use RTCRtpTransceiverDirection::*; - let transceiver_direction = t.direction(); - - match offered_direction { - Sendonly | Recvonly => { - // If a stream is offered as sendonly, the corresponding stream MUST be - // marked as recvonly or inactive in the answer. - - // If a media stream is - // listed as recvonly in the offer, the answer MUST be marked as - // sendonly or inactive in the answer. - offered_direction.reverse().intersect(transceiver_direction) - } - // If an offered media stream is - // listed as sendrecv (or if there is no direction attribute at the - // media or session level, in which case the stream is sendrecv by - // default), the corresponding stream in the answer MAY be marked as - // sendonly, recvonly, sendrecv, or inactive - Sendrecv | Unspecified => t.direction(), - // If an offered media - // stream is listed as inactive, it MUST be marked as inactive in the - // answer. - Inactive => Inactive, - } - } - None => { - // If don't have an offered direction to intersect with just use the transceivers - // current direction. - // - // https://datatracker.ietf.org/doc/html/rfc8829#section-4.2.3 - // - // When creating offers, the transceiver direction is directly reflected - // in the output, even for re-offers. - t.direction() - } - }; - media = media.with_property_attribute(direction.to_string()); - - for fingerprint in dtls_fingerprints { - media = media.with_fingerprint( - fingerprint.algorithm.to_owned(), - fingerprint.value.to_uppercase(), - ); - } - - if should_add_candidates { - media = - add_candidates_to_media_descriptions(candidates, media, ice_gathering_state).await?; - } - - Ok((d.with_media(media), true)) -} - -#[derive(thiserror::Error, Debug, PartialEq)] -pub(crate) enum SimulcastRidParseError { - /// SyntaxIdDirSplit indicates rid-syntax could not be parsed. - #[error("RFC8851 mandates rid-syntax = %s\"a=rid:\" rid-id SP rid-dir")] - SyntaxIdDirSplit, - /// UnknownDirection indicates rid-dir was not parsed. Should be "send" or "recv". 
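// Example of the rid syntax these errors guard: a well-formed attribute looks
// like `a=rid:f send pt=97;max-width=1280;max-height=720`, i.e. an identifier,
// a space, the direction token "send" or "recv", then optional restriction
// parameters. SimulcastRid::try_from below parses exactly that shape and
// reports SyntaxIdDirSplit when a piece is missing and UnknownDirection when
// the direction token is neither "send" nor "recv".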
- #[error("RFC8851 mandates rid-dir = %s\"send\" / %s\"recv\"")] - UnknownDirection, -} - -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub(crate) enum SimulcastDirection { - Send, - Recv, -} - -impl TryFrom<&str> for SimulcastDirection { - type Error = SimulcastRidParseError; - fn try_from(value: &str) -> std::result::Result { - match value.to_lowercase().as_str() { - "send" => Ok(SimulcastDirection::Send), - "recv" => Ok(SimulcastDirection::Recv), - _ => Err(SimulcastRidParseError::UnknownDirection), - } - } -} - -#[derive(Clone, Debug)] -pub(crate) struct SimulcastRid { - pub(crate) id: String, - pub(crate) direction: SimulcastDirection, - pub(crate) params: String, - pub(crate) paused: bool, -} - -impl TryFrom<&String> for SimulcastRid { - type Error = SimulcastRidParseError; - fn try_from(value: &String) -> std::result::Result { - let mut split = value.split(' '); - let id = split - .next() - .ok_or(SimulcastRidParseError::SyntaxIdDirSplit)? - .to_owned(); - let direction = SimulcastDirection::try_from( - split - .next() - .ok_or(SimulcastRidParseError::SyntaxIdDirSplit)?, - )?; - let params = split.collect(); - - Ok(Self { - id, - direction, - params, - paused: false, - }) - } -} - -fn bundle_match(bundle: Option<&String>, id: &str) -> bool { - match bundle { - None => true, - Some(b) => b.split_whitespace().any(|s| s == id), - } -} - -#[derive(Default)] -pub(crate) struct MediaSection { - pub(crate) id: String, - pub(crate) transceivers: Vec>, - pub(crate) data: bool, - pub(crate) rid_map: Vec, - pub(crate) offered_direction: Option, -} - -pub(crate) struct PopulateSdpParams { - pub(crate) media_description_fingerprint: bool, - pub(crate) is_icelite: bool, - pub(crate) connection_role: ConnectionRole, - pub(crate) ice_gathering_state: RTCIceGatheringState, - pub(crate) match_bundle_group: Option, -} - -/// populate_sdp serializes a PeerConnections state into an SDP -pub(crate) async fn populate_sdp( - mut d: SessionDescription, - dtls_fingerprints: &[RTCDtlsFingerprint], - media_engine: &Arc, - candidates: &[RTCIceCandidate], - ice_params: &RTCIceParameters, - media_sections: &[MediaSection], - params: PopulateSdpParams, -) -> Result { - let media_dtls_fingerprints = if params.media_description_fingerprint { - dtls_fingerprints.to_vec() - } else { - vec![] - }; - - let mut bundle_value = "BUNDLE".to_owned(); - let mut bundle_count = 0; - let append_bundle = |mid_value: &str, value: &mut String, count: &mut i32| { - *value = value.clone() + " " + mid_value; - *count += 1; - }; - - for (i, m) in media_sections.iter().enumerate() { - if m.data && !m.transceivers.is_empty() { - return Err(Error::ErrSDPMediaSectionMediaDataChanInvalid); - } else if m.transceivers.len() > 1 { - return Err(Error::ErrSDPMediaSectionMultipleTrackInvalid); - } - - let should_add_candidates = i == 0; - - let should_add_id = if m.data { - let params = AddDataMediaSectionParams { - should_add_candidates, - mid_value: m.id.clone(), - ice_params: ice_params.clone(), - dtls_role: params.connection_role, - ice_gathering_state: params.ice_gathering_state, - }; - d = add_data_media_section(d, &media_dtls_fingerprints, candidates, params).await?; - true - } else { - let params = AddTransceiverSdpParams { - should_add_candidates, - mid_value: m.id.clone(), - dtls_role: params.connection_role, - ice_gathering_state: params.ice_gathering_state, - offered_direction: m.offered_direction, - }; - let (d1, should_add_id) = add_transceiver_sdp( - d, - &media_dtls_fingerprints, - media_engine, - ice_params, - candidates, - m, - 
params, - ) - .await?; - d = d1; - should_add_id - }; - - if should_add_id { - if bundle_match(params.match_bundle_group.as_ref(), &m.id) { - append_bundle(&m.id, &mut bundle_value, &mut bundle_count); - } else if let Some(desc) = d.media_descriptions.last_mut() { - desc.media_name.port = RangedPort { - value: 0, - range: None, - } - } - } - } - - if !params.media_description_fingerprint { - for fingerprint in dtls_fingerprints { - d = d.with_fingerprint( - fingerprint.algorithm.clone(), - fingerprint.value.to_uppercase(), - ); - } - } - - if params.is_icelite { - // RFC 5245 S15.3 - d = d.with_value_attribute(ATTR_KEY_ICELITE.to_owned(), ATTR_KEY_ICELITE.to_owned()); - } - - if bundle_count > 0 { - d = d.with_value_attribute(ATTR_KEY_GROUP.to_owned(), bundle_value); - } - - Ok(d) -} - -pub(crate) fn get_mid_value(media: &MediaDescription) -> Option<&String> { - for attr in &media.attributes { - if attr.key == "mid" { - return attr.value.as_ref(); - } - } - None -} - -pub(crate) fn get_peer_direction(media: &MediaDescription) -> RTCRtpTransceiverDirection { - for a in &media.attributes { - let direction = RTCRtpTransceiverDirection::from(a.key.as_str()); - if direction != RTCRtpTransceiverDirection::Unspecified { - return direction; - } - } - RTCRtpTransceiverDirection::Unspecified -} - -pub(crate) fn extract_fingerprint(desc: &SessionDescription) -> Result<(String, String)> { - let mut fingerprints = vec![]; - - if let Some(fingerprint) = desc.attribute("fingerprint") { - fingerprints.push(fingerprint.clone()); - } - - for m in &desc.media_descriptions { - if let Some(fingerprint) = m.attribute("fingerprint").and_then(|o| o) { - fingerprints.push(fingerprint.to_owned()); - } - } - - if fingerprints.is_empty() { - return Err(Error::ErrSessionDescriptionNoFingerprint); - } - - for m in 1..fingerprints.len() { - if fingerprints[m] != fingerprints[0] { - return Err(Error::ErrSessionDescriptionConflictingFingerprints); - } - } - - let parts: Vec<&str> = fingerprints[0].split(' ').collect(); - if parts.len() != 2 { - return Err(Error::ErrSessionDescriptionInvalidFingerprint); - } - - Ok((parts[1].to_owned(), parts[0].to_owned())) -} - -pub(crate) async fn extract_ice_details( - desc: &SessionDescription, -) -> Result<(String, String, Vec)> { - let mut candidates = vec![]; - let mut remote_pwds = vec![]; - let mut remote_ufrags = vec![]; - - if let Some(ufrag) = desc.attribute("ice-ufrag") { - remote_ufrags.push(ufrag.clone()); - } - if let Some(pwd) = desc.attribute("ice-pwd") { - remote_pwds.push(pwd.clone()); - } - - for m in &desc.media_descriptions { - if let Some(ufrag) = m.attribute("ice-ufrag").and_then(|o| o) { - remote_ufrags.push(ufrag.to_owned()); - } - if let Some(pwd) = m.attribute("ice-pwd").and_then(|o| o) { - remote_pwds.push(pwd.to_owned()); - } - - for a in &m.attributes { - if a.is_ice_candidate() { - if let Some(value) = &a.value { - let c: Arc = Arc::new(unmarshal_candidate(value)?); - let candidate = RTCIceCandidate::from(&c); - candidates.push(candidate); - } - } - } - } - - if remote_ufrags.is_empty() { - return Err(Error::ErrSessionDescriptionMissingIceUfrag); - } else if remote_pwds.is_empty() { - return Err(Error::ErrSessionDescriptionMissingIcePwd); - } - - for m in 1..remote_ufrags.len() { - if remote_ufrags[m] != remote_ufrags[0] { - return Err(Error::ErrSessionDescriptionConflictingIceUfrag); - } - } - - for m in 1..remote_pwds.len() { - if remote_pwds[m] != remote_pwds[0] { - return Err(Error::ErrSessionDescriptionConflictingIcePwd); - } - } - - 
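// At this point the loops above have verified that every ice-ufrag and ice-pwd
// found at session or media level is identical, so the first pair can be
// returned as the credentials for the whole description. The return shape is
// (ufrag, pwd, candidates), with candidates in the order their `a=candidate`
// attributes appeared.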
Ok((remote_ufrags[0].clone(), remote_pwds[0].clone(), candidates)) -} - -pub(crate) fn have_application_media_section(desc: &SessionDescription) -> bool { - for m in &desc.media_descriptions { - if m.media_name.media == MEDIA_SECTION_APPLICATION { - return true; - } - } - - false -} - -pub(crate) fn get_by_mid<'a>( - search_mid: &str, - desc: &'a session_description::RTCSessionDescription, -) -> Option<&'a MediaDescription> { - if let Some(parsed) = &desc.parsed { - for m in &parsed.media_descriptions { - if let Some(mid) = m.attribute(ATTR_KEY_MID).flatten() { - if mid == search_mid { - return Some(m); - } - } - } - } - None -} - -/// have_data_channel return MediaDescription with MediaName equal application -pub(crate) fn have_data_channel( - desc: &session_description::RTCSessionDescription, -) -> Option<&MediaDescription> { - if let Some(parsed) = &desc.parsed { - for d in &parsed.media_descriptions { - if d.media_name.media == MEDIA_SECTION_APPLICATION { - return Some(d); - } - } - } - None -} - -pub(crate) fn codecs_from_media_description( - m: &MediaDescription, -) -> Result> { - let s = SessionDescription { - media_descriptions: vec![m.clone()], - ..Default::default() - }; - - let mut out = vec![]; - for payload_str in &m.media_name.formats { - let payload_type: PayloadType = payload_str.parse::()?; - let codec = match s.get_codec_for_payload_type(payload_type) { - Ok(codec) => codec, - Err(err) => { - if payload_type == 0 { - continue; - } - return Err(err.into()); - } - }; - - let channels = codec.encoding_parameters.parse::().unwrap_or(0); - - let mut feedback = vec![]; - for raw in &codec.rtcp_feedback { - let split: Vec<&str> = raw.split(' ').collect(); - - let entry = if split.len() == 2 { - RTCPFeedback { - typ: split[0].to_string(), - parameter: split[1].to_string(), - } - } else { - RTCPFeedback { - typ: split[0].to_string(), - parameter: String::new(), - } - }; - - feedback.push(entry); - } - - out.push(RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: m.media_name.media.clone() + "/" + codec.name.as_str(), - clock_rate: codec.clock_rate, - channels, - sdp_fmtp_line: codec.fmtp.clone(), - rtcp_feedback: feedback, - }, - payload_type, - stats_id: String::new(), - }) - } - - Ok(out) -} - -pub(crate) fn rtp_extensions_from_media_description( - m: &MediaDescription, -) -> Result> { - let mut out = HashMap::new(); - - for a in &m.attributes { - if a.key == ATTR_KEY_EXT_MAP { - let a_str = a.to_string(); - let mut reader = BufReader::new(a_str.as_bytes()); - let e = ExtMap::unmarshal(&mut reader)?; - - if let Some(uri) = e.uri { - out.insert(uri.to_string(), e.value); - } - } - } - - Ok(out) -} - -/// update_sdp_origin saves sdp.Origin in PeerConnection when creating 1st local SDP; -/// for subsequent calling, it updates Origin for SessionDescription from saved one -/// and increments session version by one. 
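// Illustration of the behaviour documented above, with made-up values: the
// first call stores the freshly generated origin, e.g.
//   o=- 871498163 1 IN IP4 0.0.0.0
// and later calls reuse the stored session id while bumping the version,
//   o=- 871498163 2 IN IP4 0.0.0.0
// so renegotiated descriptions keep one session id and only the version grows.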
-/// -pub(crate) fn update_sdp_origin(origin: &mut Origin, d: &mut SessionDescription) { - //TODO: if atomic.CompareAndSwapUint64(&origin.SessionVersion, 0, d.Origin.SessionVersion) - if origin.session_version == 0 { - // store - origin.session_version = d.origin.session_version; - //atomic.StoreUint64(&origin.SessionID, d.Origin.SessionID) - origin.session_id = d.origin.session_id; - } else { - // load - /*for { // awaiting for saving session id - d.Origin.SessionID = atomic.LoadUint64(&origin.SessionID) - if d.Origin.SessionID != 0 { - break - } - }*/ - d.origin.session_id = origin.session_id; - - //d.Origin.SessionVersion = atomic.AddUint64(&origin.SessionVersion, 1) - origin.session_version += 1; - d.origin.session_version += 1; - } -} diff --git a/webrtc/src/peer_connection/sdp/sdp_test.rs b/webrtc/src/peer_connection/sdp/sdp_test.rs deleted file mode 100644 index ca8cf5b24..000000000 --- a/webrtc/src/peer_connection/sdp/sdp_test.rs +++ /dev/null @@ -1,1378 +0,0 @@ -use rcgen::KeyPair; -use sdp::description::common::Attribute; - -use super::*; -use crate::api::media_engine::{MIME_TYPE_OPUS, MIME_TYPE_VP8}; -use crate::api::setting_engine::SettingEngine; -use crate::api::APIBuilder; -use crate::dtls_transport::dtls_role::DEFAULT_DTLS_ROLE_OFFER; -use crate::dtls_transport::RTCDtlsTransport; -use crate::peer_connection::certificate::RTCCertificate; -use crate::rtp_transceiver::rtp_sender::RTCRtpSender; -use crate::track::track_local::track_local_static_sample::TrackLocalStaticSample; -use crate::track::track_local::TrackLocal; - -#[test] -fn test_extract_fingerprint() -> Result<()> { - //"Good Session Fingerprint" - { - let s = SessionDescription { - attributes: vec![Attribute { - key: "fingerprint".to_owned(), - value: Some("foo bar".to_owned()), - }], - ..Default::default() - }; - - let (fingerprint, hash) = extract_fingerprint(&s)?; - assert_eq!(fingerprint, "bar"); - assert_eq!(hash, "foo"); - } - - //"Good Media Fingerprint" - { - let s = SessionDescription { - media_descriptions: vec![MediaDescription { - attributes: vec![Attribute { - key: "fingerprint".to_owned(), - value: Some("foo bar".to_owned()), - }], - ..Default::default() - }], - ..Default::default() - }; - - let (fingerprint, hash) = extract_fingerprint(&s)?; - assert_eq!(fingerprint, "bar"); - assert_eq!(hash, "foo"); - } - - //"No Fingerprint" - { - let s = SessionDescription::default(); - - assert_eq!( - extract_fingerprint(&s).expect_err("fingerprint absence must be detected"), - Error::ErrSessionDescriptionNoFingerprint - ); - } - - //"Invalid Fingerprint" - { - let s = SessionDescription { - attributes: vec![Attribute { - key: "fingerprint".to_owned(), - value: Some("foo".to_owned()), - }], - ..Default::default() - }; - - assert_eq!( - extract_fingerprint(&s).expect_err("invalid fingerprint text must be detected"), - Error::ErrSessionDescriptionInvalidFingerprint - ); - } - - //"Conflicting Fingerprint" - { - let s = SessionDescription { - attributes: vec![Attribute { - key: "fingerprint".to_owned(), - value: Some("foo".to_owned()), - }], - media_descriptions: vec![MediaDescription { - attributes: vec![Attribute { - key: "fingerprint".to_owned(), - value: Some("foo bar".to_owned()), - }], - ..Default::default() - }], - ..Default::default() - }; - - assert_eq!( - extract_fingerprint(&s).expect_err("mismatching fingerprint texts must be detected"), - Error::ErrSessionDescriptionConflictingFingerprints - ); - } - - Ok(()) -} - -#[tokio::test] -async fn test_extract_ice_details() -> Result<()> { - const DEFAULT_UFRAG: 
&str = "DEFAULT_PWD"; - const DEFAULT_PWD: &str = "DEFAULT_UFRAG"; - - //"Missing ice-pwd" - { - let s = SessionDescription { - media_descriptions: vec![MediaDescription { - attributes: vec![Attribute { - key: "ice-ufrag".to_owned(), - value: Some(DEFAULT_UFRAG.to_owned()), - }], - ..Default::default() - }], - ..Default::default() - }; - - assert_eq!( - extract_ice_details(&s) - .await - .expect_err("ICE requires password for authentication"), - Error::ErrSessionDescriptionMissingIcePwd - ); - } - - //"Missing ice-ufrag" - { - let s = SessionDescription { - media_descriptions: vec![MediaDescription { - attributes: vec![Attribute { - key: "ice-pwd".to_owned(), - value: Some(DEFAULT_PWD.to_owned()), - }], - ..Default::default() - }], - ..Default::default() - }; - - assert_eq!( - extract_ice_details(&s) - .await - .expect_err("ICE requires 'user fragment' for authentication"), - Error::ErrSessionDescriptionMissingIceUfrag - ); - } - - //"ice details at session level" - { - let s = SessionDescription { - attributes: vec![ - Attribute { - key: "ice-ufrag".to_owned(), - value: Some(DEFAULT_UFRAG.to_owned()), - }, - Attribute { - key: "ice-pwd".to_owned(), - value: Some(DEFAULT_PWD.to_owned()), - }, - ], - media_descriptions: vec![], - ..Default::default() - }; - - let (ufrag, pwd, _) = extract_ice_details(&s).await?; - assert_eq!(ufrag, DEFAULT_UFRAG); - assert_eq!(pwd, DEFAULT_PWD); - } - - //"ice details at media level" - { - let s = SessionDescription { - media_descriptions: vec![MediaDescription { - attributes: vec![ - Attribute { - key: "ice-ufrag".to_owned(), - value: Some(DEFAULT_UFRAG.to_owned()), - }, - Attribute { - key: "ice-pwd".to_owned(), - value: Some(DEFAULT_PWD.to_owned()), - }, - ], - ..Default::default() - }], - ..Default::default() - }; - - let (ufrag, pwd, _) = extract_ice_details(&s).await?; - assert_eq!(ufrag, DEFAULT_UFRAG); - assert_eq!(pwd, DEFAULT_PWD); - } - - //"Conflict ufrag" - { - let s = SessionDescription { - attributes: vec![Attribute { - key: "ice-ufrag".to_owned(), - value: Some("invalidUfrag".to_owned()), - }], - media_descriptions: vec![MediaDescription { - attributes: vec![ - Attribute { - key: "ice-ufrag".to_owned(), - value: Some(DEFAULT_UFRAG.to_owned()), - }, - Attribute { - key: "ice-pwd".to_owned(), - value: Some(DEFAULT_PWD.to_owned()), - }, - ], - ..Default::default() - }], - ..Default::default() - }; - - assert_eq!( - extract_ice_details(&s) - .await - .expect_err("mismatching ICE ufrags must be detected"), - Error::ErrSessionDescriptionConflictingIceUfrag - ); - } - - //"Conflict pwd" - { - let s = SessionDescription { - attributes: vec![Attribute { - key: "ice-pwd".to_owned(), - value: Some("invalidPwd".to_owned()), - }], - media_descriptions: vec![MediaDescription { - attributes: vec![ - Attribute { - key: "ice-ufrag".to_owned(), - value: Some(DEFAULT_UFRAG.to_owned()), - }, - Attribute { - key: "ice-pwd".to_owned(), - value: Some(DEFAULT_PWD.to_owned()), - }, - ], - ..Default::default() - }], - ..Default::default() - }; - - assert_eq!( - extract_ice_details(&s) - .await - .expect_err("mismatching ICE passwords must be detected"), - Error::ErrSessionDescriptionConflictingIcePwd - ); - } - - Ok(()) -} - -#[test] -fn test_track_details_from_sdp() -> Result<()> { - //"Tracks unknown, audio and video with RTX" - { - let s = SessionDescription { - media_descriptions: vec![ - MediaDescription { - media_name: MediaName { - media: "foobar".to_owned(), - ..Default::default() - }, - attributes: vec![ - Attribute { - key: "mid".to_owned(), - value: 
Some("0".to_owned()), - }, - Attribute { - key: "sendrecv".to_owned(), - value: None, - }, - Attribute { - key: "ssrc".to_owned(), - value: Some("1000 msid:unknown_trk_label unknown_trk_guid".to_owned()), - }, - ], - ..Default::default() - }, - MediaDescription { - media_name: MediaName { - media: "audio".to_owned(), - ..Default::default() - }, - attributes: vec![ - Attribute { - key: "mid".to_owned(), - value: Some("1".to_owned()), - }, - Attribute { - key: "sendrecv".to_owned(), - value: None, - }, - Attribute { - key: "ssrc".to_owned(), - value: Some("2000 msid:audio_trk_label audio_trk_guid".to_owned()), - }, - ], - ..Default::default() - }, - MediaDescription { - media_name: MediaName { - media: "video".to_owned(), - ..Default::default() - }, - attributes: vec![ - Attribute { - key: "mid".to_owned(), - value: Some("2".to_owned()), - }, - Attribute { - key: "sendrecv".to_owned(), - value: None, - }, - Attribute { - key: "ssrc-group".to_owned(), - value: Some("FID 3000 4000".to_owned()), - }, - Attribute { - key: "ssrc".to_owned(), - value: Some("3000 msid:video_trk_label video_trk_guid".to_owned()), - }, - Attribute { - key: "ssrc".to_owned(), - value: Some("4000 msid:rtx_trk_label rtx_trck_guid".to_owned()), - }, - ], - ..Default::default() - }, - MediaDescription { - media_name: MediaName { - media: "video".to_owned(), - ..Default::default() - }, - attributes: vec![ - Attribute { - key: "mid".to_owned(), - value: Some("3".to_owned()), - }, - Attribute { - key: "sendonly".to_owned(), - value: None, - }, - Attribute { - key: "msid".to_owned(), - value: Some("video_stream_id video_trk_id".to_owned()), - }, - Attribute { - key: "ssrc".to_owned(), - value: Some("5000".to_owned()), - }, - ], - ..Default::default() - }, - MediaDescription { - media_name: MediaName { - media: "video".to_owned(), - ..Default::default() - }, - attributes: vec![ - Attribute { - key: "sendonly".to_owned(), - value: None, - }, - Attribute { - key: SDP_ATTRIBUTE_RID.to_owned(), - value: Some("f send pt=97;max-width=1280;max-height=720".to_owned()), - }, - ], - ..Default::default() - }, - ], - ..Default::default() - }; - - let tracks = track_details_from_sdp(&s, true); - assert_eq!(tracks.len(), 3); - if track_details_for_ssrc(&tracks, 1000).is_some() { - panic!("got the unknown track ssrc:1000 which should have been skipped"); - } - if let Some(track) = track_details_for_ssrc(&tracks, 2000) { - assert_eq!(track.kind, RTPCodecType::Audio); - assert_eq!(track.ssrcs[0], 2000); - assert_eq!(track.stream_id, "audio_trk_label"); - } else { - panic!("missing audio track with ssrc:2000"); - } - if let Some(track) = track_details_for_ssrc(&tracks, 3000) { - assert_eq!(track.kind, RTPCodecType::Video); - assert_eq!(track.ssrcs[0], 3000); - assert_eq!(track.stream_id, "video_trk_label"); - } else { - panic!("missing video track with ssrc:3000"); - } - if track_details_for_ssrc(&tracks, 4000).is_some() { - panic!("got the rtx track ssrc:3000 which should have been skipped"); - } - if let Some(track) = track_details_for_ssrc(&tracks, 5000) { - assert_eq!(track.kind, RTPCodecType::Video); - assert_eq!(track.ssrcs[0], 5000); - assert_eq!(track.id, "video_trk_id"); - assert_eq!(track.stream_id, "video_stream_id"); - } else { - panic!("missing video track with ssrc:5000"); - } - } - - { - let s = SessionDescription { - media_descriptions: vec![ - MediaDescription { - media_name: MediaName { - media: "video".to_owned(), - ..Default::default() - }, - attributes: vec![ - Attribute { - key: "mid".to_owned(), - value: 
Some("1".to_owned()), - }, - Attribute { - key: "inactive".to_owned(), - value: None, - }, - Attribute { - key: "ssrc".to_owned(), - value: Some("6000".to_owned()), - }, - ], - ..Default::default() - }, - MediaDescription { - media_name: MediaName { - media: "video".to_owned(), - ..Default::default() - }, - attributes: vec![ - Attribute { - key: "mid".to_owned(), - value: Some("1".to_owned()), - }, - Attribute { - key: "recvonly".to_owned(), - value: None, - }, - Attribute { - key: "ssrc".to_owned(), - value: Some("7000".to_owned()), - }, - ], - ..Default::default() - }, - ], - ..Default::default() - }; - assert_eq!( - track_details_from_sdp(&s, true).len(), - 0, - "inactive and recvonly tracks should be ignored when passing exclude_inactive: true" - ); - assert_eq!( - track_details_from_sdp(&s, false).len(), - 1, - "Inactive tracks should not be ignored when passing exclude_inactive: false" - ); - } - - Ok(()) -} - -#[test] -fn test_have_application_media_section() -> Result<()> { - //"Audio only" - { - let s = SessionDescription { - media_descriptions: vec![MediaDescription { - media_name: MediaName { - media: "audio".to_owned(), - ..Default::default() - }, - attributes: vec![ - Attribute { - key: "sendrecv".to_owned(), - value: None, - }, - Attribute { - key: "ssrc".to_owned(), - value: Some("2000".to_owned()), - }, - ], - ..Default::default() - }], - ..Default::default() - }; - - assert!(!have_application_media_section(&s)); - } - - //"Application" - { - let s = SessionDescription { - media_descriptions: vec![MediaDescription { - media_name: MediaName { - media: MEDIA_SECTION_APPLICATION.to_owned(), - ..Default::default() - }, - ..Default::default() - }], - ..Default::default() - }; - - assert!(have_application_media_section(&s)); - } - - Ok(()) -} - -async fn fingerprint_test( - certificate: &RTCCertificate, - engine: &Arc, - media: &[MediaSection], - sdpmedia_description_fingerprints: bool, - expected_fingerprint_count: usize, -) -> Result<()> { - let s = SessionDescription::default(); - - let dtls_fingerprints = certificate.get_fingerprints(); - - let params = PopulateSdpParams { - media_description_fingerprint: sdpmedia_description_fingerprints, - is_icelite: false, - connection_role: ConnectionRole::Active, - ice_gathering_state: RTCIceGatheringState::New, - match_bundle_group: None, - }; - - let s = populate_sdp( - s, - &dtls_fingerprints, - engine, - &[], - &RTCIceParameters::default(), - media, - params, - ) - .await?; - - let sdparray = s.marshal(); - - assert_eq!( - sdparray.matches("sha-256").count(), - expected_fingerprint_count - ); - - Ok(()) -} - -#[tokio::test] -async fn test_media_description_fingerprints() -> Result<()> { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - let interceptor = api.interceptor_registry.build("")?; - - let kp = KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256)?; - let certificate = RTCCertificate::from_key_pair(kp)?; - - let transport = Arc::new(RTCDtlsTransport::default()); - - let video_receiver = Arc::new(api.new_rtp_receiver( - RTPCodecType::Video, - Arc::clone(&transport), - Arc::clone(&interceptor), - )); - let audio_receiver = Arc::new(api.new_rtp_receiver( - RTPCodecType::Audio, - Arc::clone(&transport), - Arc::clone(&interceptor), - )); - - let video_sender = Arc::new( - api.new_rtp_sender(None, Arc::clone(&transport), Arc::clone(&interceptor)) - .await, - ); - - let audio_sender = Arc::new( - api.new_rtp_sender(None, Arc::clone(&transport), 
Arc::clone(&interceptor)) - .await, - ); - - let media = vec![ - MediaSection { - id: "video".to_owned(), - transceivers: vec![ - RTCRtpTransceiver::new( - video_receiver, - video_sender, - RTCRtpTransceiverDirection::Inactive, - RTPCodecType::Video, - api.media_engine.get_codecs_by_kind(RTPCodecType::Video), - Arc::clone(&api.media_engine), - None, - ) - .await, - ], - ..Default::default() - }, - MediaSection { - id: "audio".to_owned(), - transceivers: vec![ - RTCRtpTransceiver::new( - audio_receiver, - audio_sender, - RTCRtpTransceiverDirection::Inactive, - RTPCodecType::Audio, - api.media_engine.get_codecs_by_kind(RTPCodecType::Audio), - Arc::clone(&api.media_engine), - None, - ) - .await, - ], - ..Default::default() - }, - MediaSection { - id: "application".to_owned(), - data: true, - ..Default::default() - }, - ]; - - #[allow(clippy::needless_range_loop)] - for i in 0..2 { - let track: Arc = Arc::new(TrackLocalStaticSample::new( - RTCRtpCodecCapability { - mime_type: "video/vp8".to_owned(), - ..Default::default() - }, - "video".to_owned(), - "webrtc-rs".to_owned(), - )); - media[i].transceivers[0] - .set_sender(Arc::new( - RTCRtpSender::new( - api.setting_engine.get_receive_mtu(), - Some(track), - RTPCodecType::Video, - Arc::new(RTCDtlsTransport::default()), - Arc::clone(&api.media_engine), - Arc::clone(&interceptor), - false, - ) - .await, - )) - .await; - media[i].transceivers[0].set_direction_internal(RTCRtpTransceiverDirection::Sendonly); - } - - //"Per-Media Description Fingerprints", - fingerprint_test(&certificate, &api.media_engine, &media, true, 3).await?; - - //"Per-Session Description Fingerprints", - fingerprint_test(&certificate, &api.media_engine, &media, false, 1).await?; - - Ok(()) -} - -#[tokio::test] -async fn test_populate_sdp() -> Result<()> { - //"Rid" - { - let se = SettingEngine::default(); - let mut me = MediaEngine::default(); - me.register_default_codecs()?; - - let api = APIBuilder::new().with_media_engine(me).build(); - let interceptor = api.interceptor_registry.build("")?; - let transport = Arc::new(RTCDtlsTransport::default()); - - let receiver = Arc::new(api.new_rtp_receiver( - RTPCodecType::Video, - Arc::clone(&transport), - Arc::clone(&interceptor), - )); - - let sender = Arc::new( - api.new_rtp_sender(None, Arc::clone(&transport), Arc::clone(&interceptor)) - .await, - ); - - let tr = RTCRtpTransceiver::new( - receiver, - sender, - RTCRtpTransceiverDirection::Recvonly, - RTPCodecType::Video, - api.media_engine.video_codecs.clone(), - Arc::clone(&api.media_engine), - None, - ) - .await; - - let rid_map = vec![ - SimulcastRid { - id: "ridkey".to_owned(), - direction: SimulcastDirection::Recv, - params: "some".to_owned(), - paused: false, - }, - SimulcastRid { - id: "ridpaused".to_owned(), - direction: SimulcastDirection::Recv, - params: "some2".to_owned(), - paused: true, - }, - ]; - let media_sections = vec![MediaSection { - id: "video".to_owned(), - transceivers: vec![tr], - data: false, - rid_map, - ..Default::default() - }]; - - let d = SessionDescription::default(); - - let params = PopulateSdpParams { - media_description_fingerprint: se.sdp_media_level_fingerprints, - is_icelite: se.candidates.ice_lite, - connection_role: DEFAULT_DTLS_ROLE_OFFER.to_connection_role(), - ice_gathering_state: RTCIceGatheringState::Complete, - match_bundle_group: None, - }; - let offer_sdp = populate_sdp( - d, - &[], - &api.media_engine, - &[], - &RTCIceParameters::default(), - &media_sections, - params, - ) - .await?; - - // Test contains rid map keys - let mut found 
= 0; - for desc in &offer_sdp.media_descriptions { - if desc.media_name.media != "video" { - continue; - } - - let rid_map = get_rids(desc); - if let Some(rid) = rid_map.iter().find(|rid| rid.id == "ridkey") { - assert!(!rid.paused, "Rid should be active"); - assert_eq!( - rid.direction, - SimulcastDirection::Send, - "Rid should be send" - ); - found += 1; - } - if let Some(rid) = rid_map.iter().find(|rid| rid.id == "ridpaused") { - assert!(rid.paused, "Rid should be paused"); - assert_eq!( - rid.direction, - SimulcastDirection::Send, - "Rid should be send" - ); - found += 1; - } - } - assert_eq!(found, 2, "All Rid key should be present"); - } - - //"SetCodecPreferences" - { - let se = SettingEngine::default(); - let mut me = MediaEngine::default(); - me.register_default_codecs()?; - me.push_codecs(me.video_codecs.clone(), RTPCodecType::Video) - .await; - me.push_codecs(me.audio_codecs.clone(), RTPCodecType::Audio) - .await; - - let api = APIBuilder::new().with_media_engine(me).build(); - let interceptor = api.interceptor_registry.build("")?; - let transport = Arc::new(RTCDtlsTransport::default()); - let receiver = Arc::new(api.new_rtp_receiver( - RTPCodecType::Video, - Arc::clone(&transport), - Arc::clone(&interceptor), - )); - - let sender = Arc::new( - api.new_rtp_sender(None, Arc::clone(&transport), Arc::clone(&interceptor)) - .await, - ); - - let tr = RTCRtpTransceiver::new( - receiver, - sender, - RTCRtpTransceiverDirection::Recvonly, - RTPCodecType::Video, - api.media_engine.video_codecs.clone(), - Arc::clone(&api.media_engine), - None, - ) - .await; - tr.set_codec_preferences(vec![RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: "".to_owned(), - rtcp_feedback: vec![], - }, - payload_type: 96, - ..Default::default() - }]) - .await?; - - let media_sections = vec![MediaSection { - id: "video".to_owned(), - transceivers: vec![tr], - data: false, - rid_map: vec![], - ..Default::default() - }]; - - let d = SessionDescription::default(); - - let params = PopulateSdpParams { - media_description_fingerprint: se.sdp_media_level_fingerprints, - is_icelite: se.candidates.ice_lite, - connection_role: DEFAULT_DTLS_ROLE_OFFER.to_connection_role(), - ice_gathering_state: RTCIceGatheringState::Complete, - match_bundle_group: None, - }; - let offer_sdp = populate_sdp( - d, - &[], - &api.media_engine, - &[], - &RTCIceParameters::default(), - &media_sections, - params, - ) - .await?; - - // Test codecs - let mut found_vp8 = false; - for desc in &offer_sdp.media_descriptions { - if desc.media_name.media != "video" { - continue; - } - for a in &desc.attributes { - if a.key.contains("rtpmap") { - if let Some(value) = &a.value { - if value == "98 VP9/90000" { - panic!("vp9 should not be present in sdp"); - } else if value == "96 VP8/90000" { - found_vp8 = true; - } - } - } - } - } - assert!(found_vp8, "vp8 should be present in sdp"); - } - - //"Bundle all" - { - let se = SettingEngine::default(); - let mut me = MediaEngine::default(); - me.register_default_codecs()?; - - let api = APIBuilder::new().with_media_engine(me).build(); - let interceptor = api.interceptor_registry.build("")?; - let transport = Arc::new(RTCDtlsTransport::default()); - let receiver = Arc::new(api.new_rtp_receiver( - RTPCodecType::Video, - Arc::clone(&transport), - Arc::clone(&interceptor), - )); - - let sender = Arc::new( - api.new_rtp_sender(None, Arc::clone(&transport), Arc::clone(&interceptor)) - .await, - ); - - let tr = 
RTCRtpTransceiver::new( - receiver, - sender, - RTCRtpTransceiverDirection::Recvonly, - RTPCodecType::Video, - api.media_engine.video_codecs.clone(), - Arc::clone(&api.media_engine), - None, - ) - .await; - - let media_sections = vec![MediaSection { - id: "video".to_owned(), - transceivers: vec![tr], - data: false, - rid_map: vec![], - ..Default::default() - }]; - - let d = SessionDescription::default(); - - let params = PopulateSdpParams { - media_description_fingerprint: se.sdp_media_level_fingerprints, - is_icelite: se.candidates.ice_lite, - connection_role: DEFAULT_DTLS_ROLE_OFFER.to_connection_role(), - ice_gathering_state: RTCIceGatheringState::Complete, - match_bundle_group: None, - }; - let offer_sdp = populate_sdp( - d, - &[], - &api.media_engine, - &[], - &RTCIceParameters::default(), - &media_sections, - params, - ) - .await?; - - assert_eq!( - offer_sdp.attribute(ATTR_KEY_GROUP), - Some(&"BUNDLE video".to_owned()) - ); - } - - //"Bundle matched" - { - let se = SettingEngine::default(); - let mut me = MediaEngine::default(); - me.register_default_codecs()?; - - let api = APIBuilder::new().with_media_engine(me).build(); - let interceptor = api.interceptor_registry.build("")?; - let transport = Arc::new(RTCDtlsTransport::default()); - - let video_receiver = Arc::new(api.new_rtp_receiver( - RTPCodecType::Video, - Arc::clone(&transport), - Arc::clone(&interceptor), - )); - let audio_receiver = Arc::new(api.new_rtp_receiver( - RTPCodecType::Audio, - Arc::clone(&transport), - Arc::clone(&interceptor), - )); - - let video_sender = Arc::new( - api.new_rtp_sender(None, Arc::clone(&transport), Arc::clone(&interceptor)) - .await, - ); - let audio_sender = Arc::new( - api.new_rtp_sender(None, Arc::clone(&transport), Arc::clone(&interceptor)) - .await, - ); - - let trv = RTCRtpTransceiver::new( - video_receiver, - video_sender, - RTCRtpTransceiverDirection::Recvonly, - RTPCodecType::Video, - api.media_engine.video_codecs.clone(), - Arc::clone(&api.media_engine), - None, - ) - .await; - - let tra = RTCRtpTransceiver::new( - audio_receiver, - audio_sender, - RTCRtpTransceiverDirection::Recvonly, - RTPCodecType::Audio, - api.media_engine.audio_codecs.clone(), - Arc::clone(&api.media_engine), - None, - ) - .await; - - let media_sections = vec![ - MediaSection { - id: "video".to_owned(), - transceivers: vec![trv], - data: false, - rid_map: vec![], - ..Default::default() - }, - MediaSection { - id: "audio".to_owned(), - transceivers: vec![tra], - data: false, - rid_map: vec![], - ..Default::default() - }, - ]; - - let d = SessionDescription::default(); - - let params = PopulateSdpParams { - media_description_fingerprint: se.sdp_media_level_fingerprints, - is_icelite: se.candidates.ice_lite, - connection_role: DEFAULT_DTLS_ROLE_OFFER.to_connection_role(), - ice_gathering_state: RTCIceGatheringState::Complete, - match_bundle_group: Some("audio".to_owned()), - }; - let offer_sdp = populate_sdp( - d, - &[], - &api.media_engine, - &[], - &RTCIceParameters::default(), - &media_sections, - params, - ) - .await?; - - assert_eq!( - offer_sdp.attribute(ATTR_KEY_GROUP), - Some(&"BUNDLE audio".to_owned()) - ); - } - - //"empty bundle group" - { - let se = SettingEngine::default(); - let mut me = MediaEngine::default(); - me.register_default_codecs()?; - - let api = APIBuilder::new().with_media_engine(me).build(); - let interceptor = api.interceptor_registry.build("")?; - let transport = Arc::new(RTCDtlsTransport::default()); - let receiver = Arc::new(api.new_rtp_receiver( - RTPCodecType::Video, - 
Arc::clone(&transport), - Arc::clone(&interceptor), - )); - - let sender = Arc::new( - api.new_rtp_sender(None, Arc::clone(&transport), Arc::clone(&interceptor)) - .await, - ); - - let tr = RTCRtpTransceiver::new( - receiver, - sender, - RTCRtpTransceiverDirection::Recvonly, - RTPCodecType::Video, - api.media_engine.video_codecs.clone(), - Arc::clone(&api.media_engine), - None, - ) - .await; - - let media_sections = vec![MediaSection { - id: "video".to_owned(), - transceivers: vec![tr], - data: false, - rid_map: vec![], - ..Default::default() - }]; - - let d = SessionDescription::default(); - - let params = PopulateSdpParams { - media_description_fingerprint: se.sdp_media_level_fingerprints, - is_icelite: se.candidates.ice_lite, - connection_role: DEFAULT_DTLS_ROLE_OFFER.to_connection_role(), - ice_gathering_state: RTCIceGatheringState::Complete, - match_bundle_group: Some("".to_owned()), - }; - let offer_sdp = populate_sdp( - d, - &[], - &api.media_engine, - &[], - &RTCIceParameters::default(), - &media_sections, - params, - ) - .await?; - - assert_eq!(offer_sdp.attribute(ATTR_KEY_GROUP), None); - } - - Ok(()) -} - -#[tokio::test] -async fn test_populate_sdp_reject() -> Result<()> { - let se = SettingEngine::default(); - let mut me = MediaEngine::default(); - me.register_codec( - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - clock_rate: 90_000, - channels: 0, - sdp_fmtp_line: "".to_owned(), - rtcp_feedback: vec![], - }, - payload_type: 2, - stats_id: "id".to_owned(), - }, - RTPCodecType::Video, - )?; - - let api = APIBuilder::new().with_media_engine(me).build(); - let interceptor = api.interceptor_registry.build("")?; - let transport = Arc::new(RTCDtlsTransport::default()); - let video_receiver = Arc::new(api.new_rtp_receiver( - RTPCodecType::Video, - Arc::clone(&transport), - Arc::clone(&interceptor), - )); - - let video_sender = Arc::new( - api.new_rtp_sender(None, Arc::clone(&transport), Arc::clone(&interceptor)) - .await, - ); - - let trv = RTCRtpTransceiver::new( - video_receiver, - video_sender, - RTCRtpTransceiverDirection::Recvonly, - RTPCodecType::Video, - api.media_engine.video_codecs.clone(), - Arc::clone(&api.media_engine), - None, - ) - .await; - - let audio_receiver = Arc::new(api.new_rtp_receiver( - RTPCodecType::Audio, - Arc::clone(&transport), - Arc::clone(&interceptor), - )); - - let audio_sender = Arc::new( - api.new_rtp_sender(None, Arc::clone(&transport), Arc::clone(&interceptor)) - .await, - ); - - let tra = RTCRtpTransceiver::new( - audio_receiver, - audio_sender, - RTCRtpTransceiverDirection::Recvonly, - RTPCodecType::Audio, - api.media_engine.audio_codecs.clone(), - Arc::clone(&api.media_engine), - None, - ) - .await; - - let media_sections = vec![ - MediaSection { - id: "video".to_owned(), - transceivers: vec![trv], - data: false, - rid_map: vec![], - ..Default::default() - }, - MediaSection { - id: "audio".to_owned(), - transceivers: vec![tra], - data: false, - rid_map: vec![], - ..Default::default() - }, - ]; - - let d = SessionDescription::default(); - - let params = PopulateSdpParams { - media_description_fingerprint: se.sdp_media_level_fingerprints, - is_icelite: se.candidates.ice_lite, - connection_role: DEFAULT_DTLS_ROLE_OFFER.to_connection_role(), - ice_gathering_state: RTCIceGatheringState::Complete, - match_bundle_group: None, - }; - let offer_sdp = populate_sdp( - d, - &[], - &api.media_engine, - &[], - &RTCIceParameters::default(), - &media_sections, - params, - ) - .await?; - - let mut 
found_rejected_track = false; - - for desc in offer_sdp.media_descriptions { - if desc.media_name.media != "audio" { - continue; - } - found_rejected_track = true; - - assert!( - desc.connection_information.is_some(), - "connection_information should not be None, even for rejected tracks" - ); - assert_eq!( - desc.media_name.formats, - vec!["0"], - "Format for rejected track should be 0" - ); - assert_eq!( - desc.media_name.port.value, 0, - "Port for rejected track should be 0" - ); - } - - assert!( - found_rejected_track, - "There should've been a rejected track" - ); - - Ok(()) -} - -#[test] -fn test_get_rids() { - let m = vec![MediaDescription { - media_name: MediaName { - media: "video".to_owned(), - ..Default::default() - }, - attributes: vec![ - Attribute { - key: "sendonly".to_owned(), - value: None, - }, - Attribute { - key: SDP_ATTRIBUTE_RID.to_owned(), - value: Some("f send pt=97;max-width=1280;max-height=720".to_owned()), - }, - ], - ..Default::default() - }]; - - let rids = get_rids(&m[0]); - - assert!(!rids.is_empty(), "Rid mapping should be present"); - - let f = rids.iter().find(|rid| rid.id == "f"); - assert!(f.is_some(), "rid values should contain 'f'"); -} - -#[test] -fn test_codecs_from_media_description() -> Result<()> { - //"Codec Only" - { - let codecs = codecs_from_media_description(&MediaDescription { - media_name: MediaName { - media: "audio".to_owned(), - formats: vec!["111".to_owned()], - ..Default::default() - }, - attributes: vec![Attribute { - key: "rtpmap".to_owned(), - value: Some("111 opus/48000/2".to_owned()), - }], - ..Default::default() - })?; - - assert_eq!( - codecs, - vec![RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_OPUS.to_owned(), - clock_rate: 48000, - channels: 2, - sdp_fmtp_line: "".to_owned(), - rtcp_feedback: vec![], - }, - payload_type: 111, - ..Default::default() - }], - ); - } - - //"Codec with fmtp/rtcp-fb" - { - let codecs = codecs_from_media_description(&MediaDescription { - media_name: MediaName { - media: "audio".to_owned(), - formats: vec!["111".to_owned()], - ..Default::default() - }, - attributes: vec![ - Attribute { - key: "rtpmap".to_owned(), - value: Some("111 opus/48000/2".to_owned()), - }, - Attribute { - key: "fmtp".to_owned(), - value: Some("111 minptime=10;useinbandfec=1".to_owned()), - }, - Attribute { - key: "rtcp-fb".to_owned(), - value: Some("111 goog-remb".to_owned()), - }, - Attribute { - key: "rtcp-fb".to_owned(), - value: Some("111 ccm fir".to_owned()), - }, - ], - ..Default::default() - })?; - - assert_eq!( - codecs, - vec![RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_OPUS.to_owned(), - clock_rate: 48000, - channels: 2, - sdp_fmtp_line: "minptime=10;useinbandfec=1".to_owned(), - rtcp_feedback: vec![ - RTCPFeedback { - typ: "goog-remb".to_owned(), - parameter: "".to_owned() - }, - RTCPFeedback { - typ: "ccm".to_owned(), - parameter: "fir".to_owned() - } - ] - }, - payload_type: 111, - ..Default::default() - }], - ); - } - - Ok(()) -} - -#[test] -fn test_rtp_extensions_from_media_description() -> Result<()> { - let extensions = rtp_extensions_from_media_description(&MediaDescription { - media_name: MediaName { - media: "audio".to_owned(), - formats: vec!["111".to_owned()], - ..Default::default() - }, - attributes: vec![ - Attribute { - key: "extmap".to_owned(), - value: Some("1 ".to_owned() + sdp::extmap::ABS_SEND_TIME_URI), - }, - Attribute { - key: "extmap".to_owned(), - value: Some("3 ".to_owned() + sdp::extmap::SDES_MID_URI), - }, - ], 
- ..Default::default() - })?; - - assert_eq!(extensions[sdp::extmap::ABS_SEND_TIME_URI], 1); - assert_eq!(extensions[sdp::extmap::SDES_MID_URI], 3); - - Ok(()) -} diff --git a/webrtc/src/peer_connection/sdp/sdp_type.rs b/webrtc/src/peer_connection/sdp/sdp_type.rs deleted file mode 100644 index 830864f79..000000000 --- a/webrtc/src/peer_connection/sdp/sdp_type.rs +++ /dev/null @@ -1,101 +0,0 @@ -use std::fmt; - -use serde::{Deserialize, Serialize}; - -/// SDPType describes the type of an SessionDescription. -#[derive(Default, Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)] -pub enum RTCSdpType { - #[default] - Unspecified = 0, - - /// indicates that a description MUST be treated as an SDP offer. - #[serde(rename = "offer")] - Offer, - - /// indicates that a description MUST be treated as an - /// SDP answer, but not a final answer. A description used as an SDP - /// pranswer may be applied as a response to an SDP offer, or an update to - /// a previously sent SDP pranswer. - #[serde(rename = "pranswer")] - Pranswer, - - /// indicates that a description MUST be treated as an SDP - /// final answer, and the offer-answer exchange MUST be considered complete. - /// A description used as an SDP answer may be applied as a response to an - /// SDP offer or as an update to a previously sent SDP pranswer. - #[serde(rename = "answer")] - Answer, - - /// indicates that a description MUST be treated as - /// canceling the current SDP negotiation and moving the SDP offer and - /// answer back to what it was in the previous stable state. Note the - /// local or remote SDP descriptions in the previous stable state could be - /// null if there has not yet been a successful offer-answer negotiation. - #[serde(rename = "rollback")] - Rollback, -} - -const SDP_TYPE_OFFER_STR: &str = "offer"; -const SDP_TYPE_PRANSWER_STR: &str = "pranswer"; -const SDP_TYPE_ANSWER_STR: &str = "answer"; -const SDP_TYPE_ROLLBACK_STR: &str = "rollback"; - -/// creates an SDPType from a string -impl From<&str> for RTCSdpType { - fn from(raw: &str) -> Self { - match raw { - SDP_TYPE_OFFER_STR => RTCSdpType::Offer, - SDP_TYPE_PRANSWER_STR => RTCSdpType::Pranswer, - SDP_TYPE_ANSWER_STR => RTCSdpType::Answer, - SDP_TYPE_ROLLBACK_STR => RTCSdpType::Rollback, - _ => RTCSdpType::Unspecified, - } - } -} - -impl fmt::Display for RTCSdpType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self { - RTCSdpType::Offer => write!(f, "{SDP_TYPE_OFFER_STR}"), - RTCSdpType::Pranswer => write!(f, "{SDP_TYPE_PRANSWER_STR}"), - RTCSdpType::Answer => write!(f, "{SDP_TYPE_ANSWER_STR}"), - RTCSdpType::Rollback => write!(f, "{SDP_TYPE_ROLLBACK_STR}"), - _ => write!(f, "{}", crate::UNSPECIFIED_STR), - } - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_new_sdp_type() { - let tests = vec![ - ("Unspecified", RTCSdpType::Unspecified), - ("offer", RTCSdpType::Offer), - ("pranswer", RTCSdpType::Pranswer), - ("answer", RTCSdpType::Answer), - ("rollback", RTCSdpType::Rollback), - ]; - - for (sdp_type_string, expected_sdp_type) in tests { - assert_eq!(RTCSdpType::from(sdp_type_string), expected_sdp_type); - } - } - - #[test] - fn test_sdp_type_string() { - let tests = vec![ - (RTCSdpType::Unspecified, "Unspecified"), - (RTCSdpType::Offer, "offer"), - (RTCSdpType::Pranswer, "pranswer"), - (RTCSdpType::Answer, "answer"), - (RTCSdpType::Rollback, "rollback"), - ]; - - for (sdp_type, expected_string) in tests { - assert_eq!(sdp_type.to_string(), expected_string); - } - } -} diff --git 
a/webrtc/src/peer_connection/sdp/session_description.rs b/webrtc/src/peer_connection/sdp/session_description.rs deleted file mode 100644 index 7085f34dd..000000000 --- a/webrtc/src/peer_connection/sdp/session_description.rs +++ /dev/null @@ -1,238 +0,0 @@ -use std::io::Cursor; - -use sdp::description::session::SessionDescription; -use serde::{Deserialize, Serialize}; - -use super::sdp_type::RTCSdpType; -use crate::error::Result; - -/// SessionDescription is used to expose local and remote session descriptions. -#[derive(Default, Debug, Clone, Serialize, Deserialize)] -pub struct RTCSessionDescription { - #[serde(rename = "type")] - pub sdp_type: RTCSdpType, - - pub sdp: String, - - /// This will never be initialized by callers, internal use only - #[serde(skip)] - pub(crate) parsed: Option, -} - -impl RTCSessionDescription { - /// Given SDP representing an answer, wrap it in an RTCSessionDescription - /// that can be given to an RTCPeerConnection. - pub fn answer(sdp: String) -> Result { - let mut desc = RTCSessionDescription { - sdp, - sdp_type: RTCSdpType::Answer, - parsed: None, - }; - - let parsed = desc.unmarshal()?; - desc.parsed = Some(parsed); - - Ok(desc) - } - - /// Given SDP representing an offer, wrap it in an RTCSessionDescription - /// that can be given to an RTCPeerConnection. - pub fn offer(sdp: String) -> Result { - let mut desc = RTCSessionDescription { - sdp, - sdp_type: RTCSdpType::Offer, - parsed: None, - }; - - let parsed = desc.unmarshal()?; - desc.parsed = Some(parsed); - - Ok(desc) - } - - /// Given SDP representing an answer, wrap it in an RTCSessionDescription - /// that can be given to an RTCPeerConnection. `pranswer` is used when the - /// answer may not be final, or when updating a previously sent pranswer. - pub fn pranswer(sdp: String) -> Result { - let mut desc = RTCSessionDescription { - sdp, - sdp_type: RTCSdpType::Pranswer, - parsed: None, - }; - - let parsed = desc.unmarshal()?; - desc.parsed = Some(parsed); - - Ok(desc) - } - - /// Unmarshal is a helper to deserialize the sdp - pub fn unmarshal(&self) -> Result { - let mut reader = Cursor::new(self.sdp.as_bytes()); - let parsed = SessionDescription::unmarshal(&mut reader)?; - Ok(parsed) - } -} - -#[cfg(test)] -mod test { - use super::*; - use crate::api::media_engine::MediaEngine; - use crate::api::APIBuilder; - use crate::peer_connection::configuration::RTCConfiguration; - - #[test] - fn test_session_description_json() { - let tests = vec![ - ( - RTCSessionDescription { - sdp_type: RTCSdpType::Offer, - sdp: "sdp".to_owned(), - parsed: None, - }, - r#"{"type":"offer","sdp":"sdp"}"#, - ), - ( - RTCSessionDescription { - sdp_type: RTCSdpType::Pranswer, - sdp: "sdp".to_owned(), - parsed: None, - }, - r#"{"type":"pranswer","sdp":"sdp"}"#, - ), - ( - RTCSessionDescription { - sdp_type: RTCSdpType::Answer, - sdp: "sdp".to_owned(), - parsed: None, - }, - r#"{"type":"answer","sdp":"sdp"}"#, - ), - ( - RTCSessionDescription { - sdp_type: RTCSdpType::Rollback, - sdp: "sdp".to_owned(), - parsed: None, - }, - r#"{"type":"rollback","sdp":"sdp"}"#, - ), - ( - RTCSessionDescription { - sdp_type: RTCSdpType::Unspecified, - sdp: "sdp".to_owned(), - parsed: None, - }, - r#"{"type":"Unspecified","sdp":"sdp"}"#, - ), - ]; - - for (desc, expected_string) in tests { - let result = serde_json::to_string(&desc); - assert!(result.is_ok(), "testCase: marshal err: {result:?}"); - let desc_data = result.unwrap(); - assert_eq!(desc_data, expected_string, "string is not expected"); - - let result = 
serde_json::from_str::(&desc_data); - assert!(result.is_ok(), "testCase: unmarshal err: {result:?}"); - if let Ok(sd) = result { - assert!(sd.sdp == desc.sdp && sd.sdp_type == desc.sdp_type); - } - } - } - - #[tokio::test] - async fn test_session_description_answer() -> Result<()> { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - let offer_pc = api.new_peer_connection(RTCConfiguration::default()).await?; - let answer_pc = api.new_peer_connection(RTCConfiguration::default()).await?; - - let _ = offer_pc.create_data_channel("foo", None).await?; - let offer = offer_pc.create_offer(None).await?; - answer_pc.set_remote_description(offer).await?; - - let answer = answer_pc.create_answer(None).await?; - - let desc = RTCSessionDescription::answer(answer.sdp.clone())?; - - assert!(desc.sdp_type == RTCSdpType::Answer); - assert!(desc.parsed.is_some()); - - assert_eq!(answer.unmarshal()?.marshal(), desc.unmarshal()?.marshal()); - - Ok(()) - } - - #[tokio::test] - async fn test_session_description_offer() -> Result<()> { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - let pc = api.new_peer_connection(RTCConfiguration::default()).await?; - let offer = pc.create_offer(None).await?; - - let desc = RTCSessionDescription::offer(offer.sdp.clone())?; - - assert!(desc.sdp_type == RTCSdpType::Offer); - assert!(desc.parsed.is_some()); - - assert_eq!(offer.unmarshal()?.marshal(), desc.unmarshal()?.marshal()); - - Ok(()) - } - - #[tokio::test] - async fn test_session_description_pranswer() -> Result<()> { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - let offer_pc = api.new_peer_connection(RTCConfiguration::default()).await?; - let answer_pc = api.new_peer_connection(RTCConfiguration::default()).await?; - - let _ = offer_pc.create_data_channel("foo", None).await?; - let offer = offer_pc.create_offer(None).await?; - answer_pc.set_remote_description(offer).await?; - - let answer = answer_pc.create_answer(None).await?; - - let desc = RTCSessionDescription::pranswer(answer.sdp)?; - - assert!(desc.sdp_type == RTCSdpType::Pranswer); - assert!(desc.parsed.is_some()); - - Ok(()) - } - - #[tokio::test] - async fn test_session_description_unmarshal() -> Result<()> { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - let pc = api.new_peer_connection(RTCConfiguration::default()).await?; - - let offer = pc.create_offer(None).await?; - - let desc = RTCSessionDescription { - sdp_type: offer.sdp_type, - sdp: offer.sdp, - ..Default::default() - }; - - assert!(desc.parsed.is_none()); - - let parsed1 = desc.unmarshal()?; - let parsed2 = desc.unmarshal()?; - - pc.close().await?; - - // check if the two parsed results _really_ match, could be affected by internal caching - assert_eq!(parsed1.marshal(), parsed2.marshal()); - - Ok(()) - } -} diff --git a/webrtc/src/peer_connection/signaling_state.rs b/webrtc/src/peer_connection/signaling_state.rs deleted file mode 100644 index 2edd613be..000000000 --- a/webrtc/src/peer_connection/signaling_state.rs +++ /dev/null @@ -1,365 +0,0 @@ -use std::fmt; - -use crate::error::{Error, Result}; -use crate::peer_connection::sdp::sdp_type::RTCSdpType; - -#[derive(Default, Debug, Copy, Clone, PartialEq)] -pub(crate) enum StateChangeOp { - #[default] - SetLocal, - 
SetRemote, -} - -impl fmt::Display for StateChangeOp { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self { - StateChangeOp::SetLocal => write!(f, "SetLocal"), - StateChangeOp::SetRemote => write!(f, "SetRemote"), - //_ => write!(f, UNSPECIFIED_STR), - } - } -} - -/// SignalingState indicates the signaling state of the offer/answer process. -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] -pub enum RTCSignalingState { - #[default] - Unspecified = 0, - - /// SignalingStateStable indicates there is no offer/answer exchange in - /// progress. This is also the initial state, in which case the local and - /// remote descriptions are nil. - Stable, - - /// SignalingStateHaveLocalOffer indicates that a local description, of - /// type "offer", has been successfully applied. - HaveLocalOffer, - - /// SignalingStateHaveRemoteOffer indicates that a remote description, of - /// type "offer", has been successfully applied. - HaveRemoteOffer, - - /// SignalingStateHaveLocalPranswer indicates that a remote description - /// of type "offer" has been successfully applied and a local description - /// of type "pranswer" has been successfully applied. - HaveLocalPranswer, - - /// SignalingStateHaveRemotePranswer indicates that a local description - /// of type "offer" has been successfully applied and a remote description - /// of type "pranswer" has been successfully applied. - HaveRemotePranswer, - - /// SignalingStateClosed indicates The PeerConnection has been closed. - Closed, -} - -const SIGNALING_STATE_STABLE_STR: &str = "stable"; -const SIGNALING_STATE_HAVE_LOCAL_OFFER_STR: &str = "have-local-offer"; -const SIGNALING_STATE_HAVE_REMOTE_OFFER_STR: &str = "have-remote-offer"; -const SIGNALING_STATE_HAVE_LOCAL_PRANSWER_STR: &str = "have-local-pranswer"; -const SIGNALING_STATE_HAVE_REMOTE_PRANSWER_STR: &str = "have-remote-pranswer"; -const SIGNALING_STATE_CLOSED_STR: &str = "closed"; - -impl From<&str> for RTCSignalingState { - fn from(raw: &str) -> Self { - match raw { - SIGNALING_STATE_STABLE_STR => RTCSignalingState::Stable, - SIGNALING_STATE_HAVE_LOCAL_OFFER_STR => RTCSignalingState::HaveLocalOffer, - SIGNALING_STATE_HAVE_REMOTE_OFFER_STR => RTCSignalingState::HaveRemoteOffer, - SIGNALING_STATE_HAVE_LOCAL_PRANSWER_STR => RTCSignalingState::HaveLocalPranswer, - SIGNALING_STATE_HAVE_REMOTE_PRANSWER_STR => RTCSignalingState::HaveRemotePranswer, - SIGNALING_STATE_CLOSED_STR => RTCSignalingState::Closed, - _ => RTCSignalingState::Unspecified, - } - } -} - -impl fmt::Display for RTCSignalingState { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self { - RTCSignalingState::Stable => write!(f, "{SIGNALING_STATE_STABLE_STR}"), - RTCSignalingState::HaveLocalOffer => { - write!(f, "{SIGNALING_STATE_HAVE_LOCAL_OFFER_STR}") - } - RTCSignalingState::HaveRemoteOffer => { - write!(f, "{SIGNALING_STATE_HAVE_REMOTE_OFFER_STR}") - } - RTCSignalingState::HaveLocalPranswer => { - write!(f, "{SIGNALING_STATE_HAVE_LOCAL_PRANSWER_STR}") - } - RTCSignalingState::HaveRemotePranswer => { - write!(f, "{SIGNALING_STATE_HAVE_REMOTE_PRANSWER_STR}") - } - RTCSignalingState::Closed => write!(f, "{SIGNALING_STATE_CLOSED_STR}"), - _ => write!(f, "{}", crate::UNSPECIFIED_STR), - } - } -} - -impl From for RTCSignalingState { - fn from(v: u8) -> Self { - match v { - 1 => RTCSignalingState::Stable, - 2 => RTCSignalingState::HaveLocalOffer, - 3 => RTCSignalingState::HaveRemoteOffer, - 4 => RTCSignalingState::HaveLocalPranswer, - 5 => RTCSignalingState::HaveRemotePranswer, - 6 => 
RTCSignalingState::Closed, - _ => RTCSignalingState::Unspecified, - } - } -} - -pub(crate) fn check_next_signaling_state( - cur: RTCSignalingState, - next: RTCSignalingState, - op: StateChangeOp, - sdp_type: RTCSdpType, -) -> Result { - // Special case for rollbacks - if sdp_type == RTCSdpType::Rollback && cur == RTCSignalingState::Stable { - return Err(Error::ErrSignalingStateCannotRollback); - } - - // 4.3.1 valid state transitions - match cur { - RTCSignalingState::Stable => { - match op { - StateChangeOp::SetLocal => { - // stable->SetLocal(offer)->have-local-offer - if sdp_type == RTCSdpType::Offer && next == RTCSignalingState::HaveLocalOffer { - return Ok(next); - } - } - StateChangeOp::SetRemote => { - // stable->SetRemote(offer)->have-remote-offer - if sdp_type == RTCSdpType::Offer && next == RTCSignalingState::HaveRemoteOffer { - return Ok(next); - } - } - } - } - RTCSignalingState::HaveLocalOffer => { - if op == StateChangeOp::SetRemote { - match sdp_type { - // have-local-offer->SetRemote(answer)->stable - RTCSdpType::Answer => { - if next == RTCSignalingState::Stable { - return Ok(next); - } - } - // have-local-offer->SetRemote(pranswer)->have-remote-pranswer - RTCSdpType::Pranswer => { - if next == RTCSignalingState::HaveRemotePranswer { - return Ok(next); - } - } - _ => {} - } - } else if op == StateChangeOp::SetLocal - && sdp_type == RTCSdpType::Offer - && next == RTCSignalingState::HaveLocalOffer - { - return Ok(next); - } - } - RTCSignalingState::HaveRemotePranswer => { - if op == StateChangeOp::SetRemote && sdp_type == RTCSdpType::Answer { - // have-remote-pranswer->SetRemote(answer)->stable - if next == RTCSignalingState::Stable { - return Ok(next); - } - } - } - RTCSignalingState::HaveRemoteOffer => { - if op == StateChangeOp::SetLocal { - match sdp_type { - // have-remote-offer->SetLocal(answer)->stable - RTCSdpType::Answer => { - if next == RTCSignalingState::Stable { - return Ok(next); - } - } - // have-remote-offer->SetLocal(pranswer)->have-local-pranswer - RTCSdpType::Pranswer => { - if next == RTCSignalingState::HaveLocalPranswer { - return Ok(next); - } - } - _ => {} - } - } - } - RTCSignalingState::HaveLocalPranswer => { - if op == StateChangeOp::SetLocal && sdp_type == RTCSdpType::Answer { - // have-local-pranswer->SetLocal(answer)->stable - if next == RTCSignalingState::Stable { - return Ok(next); - } - } - } - _ => { - return Err(Error::ErrSignalingStateProposedTransitionInvalid { - from: cur, - applying: sdp_type, - is_local: op == StateChangeOp::SetLocal, - }); - } - }; - - Err(Error::ErrSignalingStateProposedTransitionInvalid { - from: cur, - is_local: op == StateChangeOp::SetLocal, - applying: sdp_type, - }) -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_new_signaling_state() { - let tests = vec![ - ("Unspecified", RTCSignalingState::Unspecified), - ("stable", RTCSignalingState::Stable), - ("have-local-offer", RTCSignalingState::HaveLocalOffer), - ("have-remote-offer", RTCSignalingState::HaveRemoteOffer), - ("have-local-pranswer", RTCSignalingState::HaveLocalPranswer), - ( - "have-remote-pranswer", - RTCSignalingState::HaveRemotePranswer, - ), - ("closed", RTCSignalingState::Closed), - ]; - - for (state_string, expected_state) in tests { - assert_eq!(RTCSignalingState::from(state_string), expected_state); - } - } - - #[test] - fn test_signaling_state_string() { - let tests = vec![ - (RTCSignalingState::Unspecified, "Unspecified"), - (RTCSignalingState::Stable, "stable"), - (RTCSignalingState::HaveLocalOffer, "have-local-offer"), - 
(RTCSignalingState::HaveRemoteOffer, "have-remote-offer"), - (RTCSignalingState::HaveLocalPranswer, "have-local-pranswer"), - ( - RTCSignalingState::HaveRemotePranswer, - "have-remote-pranswer", - ), - (RTCSignalingState::Closed, "closed"), - ]; - - for (state, expected_string) in tests { - assert_eq!(state.to_string(), expected_string); - } - } - - #[test] - fn test_signaling_state_transitions() { - let tests = vec![ - ( - "stable->SetLocal(offer)->have-local-offer", - RTCSignalingState::Stable, - RTCSignalingState::HaveLocalOffer, - StateChangeOp::SetLocal, - RTCSdpType::Offer, - None, - ), - ( - "stable->SetRemote(offer)->have-remote-offer", - RTCSignalingState::Stable, - RTCSignalingState::HaveRemoteOffer, - StateChangeOp::SetRemote, - RTCSdpType::Offer, - None, - ), - ( - "have-local-offer->SetRemote(answer)->stable", - RTCSignalingState::HaveLocalOffer, - RTCSignalingState::Stable, - StateChangeOp::SetRemote, - RTCSdpType::Answer, - None, - ), - ( - "have-local-offer->SetRemote(pranswer)->have-remote-pranswer", - RTCSignalingState::HaveLocalOffer, - RTCSignalingState::HaveRemotePranswer, - StateChangeOp::SetRemote, - RTCSdpType::Pranswer, - None, - ), - ( - "have-remote-pranswer->SetRemote(answer)->stable", - RTCSignalingState::HaveRemotePranswer, - RTCSignalingState::Stable, - StateChangeOp::SetRemote, - RTCSdpType::Answer, - None, - ), - ( - "have-remote-offer->SetLocal(answer)->stable", - RTCSignalingState::HaveRemoteOffer, - RTCSignalingState::Stable, - StateChangeOp::SetLocal, - RTCSdpType::Answer, - None, - ), - ( - "have-remote-offer->SetLocal(pranswer)->have-local-pranswer", - RTCSignalingState::HaveRemoteOffer, - RTCSignalingState::HaveLocalPranswer, - StateChangeOp::SetLocal, - RTCSdpType::Pranswer, - None, - ), - ( - "have-local-pranswer->SetLocal(answer)->stable", - RTCSignalingState::HaveLocalPranswer, - RTCSignalingState::Stable, - StateChangeOp::SetLocal, - RTCSdpType::Answer, - None, - ), - ( - "(invalid) stable->SetRemote(pranswer)->have-remote-pranswer", - RTCSignalingState::Stable, - RTCSignalingState::HaveRemotePranswer, - StateChangeOp::SetRemote, - RTCSdpType::Pranswer, - Some(Error::ErrSignalingStateProposedTransitionInvalid { - from: RTCSignalingState::Stable, - is_local: false, - applying: RTCSdpType::Pranswer, - }), - ), - ( - "(invalid) stable->SetRemote(rollback)->have-local-offer", - RTCSignalingState::Stable, - RTCSignalingState::HaveLocalOffer, - StateChangeOp::SetRemote, - RTCSdpType::Rollback, - Some(Error::ErrSignalingStateCannotRollback), - ), - ]; - - for (desc, cur, next, op, sdp_type, expected_err) in tests { - let result = check_next_signaling_state(cur, next, op, sdp_type); - match (&result, &expected_err) { - (Ok(got), None) => { - assert_eq!(*got, next, "{desc} state mismatch"); - } - (Err(got), Some(err)) => { - assert_eq!(got.to_string(), err.to_string(), "{desc} error mismatch"); - } - _ => { - panic!("{desc}: expected {expected_err:?}, but got {result:?}"); - } - }; - } - } -} diff --git a/webrtc/src/rtp_transceiver/fmtp/generic/generic_test.rs b/webrtc/src/rtp_transceiver/fmtp/generic/generic_test.rs deleted file mode 100644 index f37c4af5c..000000000 --- a/webrtc/src/rtp_transceiver/fmtp/generic/generic_test.rs +++ /dev/null @@ -1,160 +0,0 @@ -use super::*; - -#[test] -fn test_generic_fmtp_parse() { - let tests: Vec<(&str, &str, Box)> = vec![ - ( - "OneParam", - "key-name=value", - Box::new(GenericFmtp { - mime_type: "generic".to_owned(), - parameters: [("key-name".to_owned(), "value".to_owned())] - .iter() - .cloned() - .collect(), - 
}), - ), - ( - "OneParamWithWhiteSpeces", - "\tkey-name=value ", - Box::new(GenericFmtp { - mime_type: "generic".to_owned(), - parameters: [("key-name".to_owned(), "value".to_owned())] - .iter() - .cloned() - .collect(), - }), - ), - ( - "TwoParams", - "key-name=value;key2=value2", - Box::new(GenericFmtp { - mime_type: "generic".to_owned(), - parameters: [ - ("key-name".to_owned(), "value".to_owned()), - ("key2".to_owned(), "value2".to_owned()), - ] - .iter() - .cloned() - .collect(), - }), - ), - ( - "TwoParamsWithWhiteSpeces", - "key-name=value; \n\tkey2=value2 ", - Box::new(GenericFmtp { - mime_type: "generic".to_owned(), - parameters: [ - ("key-name".to_owned(), "value".to_owned()), - ("key2".to_owned(), "value2".to_owned()), - ] - .iter() - .cloned() - .collect(), - }), - ), - ]; - - for (name, input, expected) in tests { - let f = parse("generic", input); - assert_eq!(&f, &expected, "{name} failed"); - - assert_eq!(f.mime_type(), "generic"); - } -} - -#[test] -fn test_generic_fmtp_compare() { - let consist_string: HashMap = [ - (true, "consist".to_owned()), - (false, "inconsist".to_owned()), - ] - .iter() - .cloned() - .collect(); - - let tests = vec![ - ( - "Equal", - "key1=value1;key2=value2;key3=value3", - "key1=value1;key2=value2;key3=value3", - true, - ), - ( - "EqualWithWhitespaceVariants", - "key1=value1;key2=value2;key3=value3", - " key1=value1; \nkey2=value2;\t\nkey3=value3", - true, - ), - ( - "EqualWithCase", - "key1=value1;key2=value2;key3=value3", - "key1=value1;key2=Value2;Key3=value3", - true, - ), - ( - "OneHasExtraParam", - "key1=value1;key2=value2;key3=value3", - "key1=value1;key2=value2;key3=value3;key4=value4", - true, - ), - ( - "Inconsistent", - "key1=value1;key2=value2;key3=value3", - "key1=value1;key2=different_value;key3=value3", - false, - ), - ( - "Inconsistent_OneHasExtraParam", - "key1=value1;key2=value2;key3=value3;key4=value4", - "key1=value1;key2=different_value;key3=value3", - false, - ), - ]; - - for (name, a, b, consist) in tests { - let check = |a, b| { - let aa = parse("", a); - let bb = parse("", b); - - // test forward case here - let c = aa.match_fmtp(&*bb); - assert_eq!( - c, - consist, - "{}: '{}' and '{}' are expected to be {:?}, but treated as {:?}", - name, - a, - b, - consist_string.get(&consist), - consist_string.get(&c), - ); - - // test reverse case here - let c = bb.match_fmtp(&*aa); - assert_eq!( - c, - consist, - "{}: '{}' and '{}' are expected to be {:?}, but treated as {:?}", - name, - a, - b, - consist_string.get(&consist), - consist_string.get(&c), - ); - }; - - check(a, b); - } -} - -#[test] -fn test_generic_fmtp_compare_mime_type_case_mismatch() { - let a = parse("video/vp8", ""); - let b = parse("video/VP8", ""); - - assert!( - b.match_fmtp(&*a), - "fmtp lines should match even if they use different casing" - ); -} diff --git a/webrtc/src/rtp_transceiver/fmtp/generic/mod.rs b/webrtc/src/rtp_transceiver/fmtp/generic/mod.rs deleted file mode 100644 index 4b75ce350..000000000 --- a/webrtc/src/rtp_transceiver/fmtp/generic/mod.rs +++ /dev/null @@ -1,65 +0,0 @@ -#[cfg(test)] -mod generic_test; - -use super::*; - -/// fmtp_consist checks that two FMTP parameters are not inconsistent. 
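Read on its own, the rule above means: any parameter present on both sides must carry the same value when compared case-insensitively, while a parameter present on only one side is ignored. A small standalone sketch of that comparison, using plain std types rather than the crate's `Fmtp` machinery (the crate's own implementation follows below):

use std::collections::HashMap;

// Two fmtp parameter maps are "consistent" when every key they share maps to the
// same value, compared case-insensitively; keys found on only one side are ignored.
fn params_consistent(a: &HashMap<String, String>, b: &HashMap<String, String>) -> bool {
    a.iter()
        .all(|(k, v)| b.get(k).map_or(true, |w| v.eq_ignore_ascii_case(w)))
}

fn main() {
    // Minimal fmtp-style parser: "key=value" pairs separated by ';', keys lowercased.
    let parse = |line: &str| -> HashMap<String, String> {
        line.split(';')
            .map(|p| {
                let mut kv = p.trim().splitn(2, '=');
                (
                    kv.next().unwrap_or("").to_lowercase(),
                    kv.next().unwrap_or("").to_owned(),
                )
            })
            .collect()
    };

    let a = parse("key1=value1;key2=value2");
    let b = parse("key1=value1;key2=Value2;key3=value3");
    let c = parse("key1=value1;key2=other");

    assert!(params_consistent(&a, &b)); // extra key and case difference are consistent
    assert!(!params_consistent(&a, &c)); // shared key with a different value is not
}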
-fn fmtp_consist(a: &HashMap, b: &HashMap) -> bool { - //TODO: add unicode case-folding equal support - for (k, v) in a { - if let Some(vb) = b.get(k) { - if vb.to_uppercase() != v.to_uppercase() { - return false; - } - } - } - for (k, v) in b { - if let Some(va) = a.get(k) { - if va.to_uppercase() != v.to_uppercase() { - return false; - } - } - } - true -} - -#[derive(Debug, PartialEq)] -pub(crate) struct GenericFmtp { - pub(crate) mime_type: String, - pub(crate) parameters: HashMap, -} - -impl Fmtp for GenericFmtp { - fn mime_type(&self) -> &str { - self.mime_type.as_str() - } - - /// Match returns true if g and b are compatible fmtp descriptions - /// The generic implementation is used for MimeTypes that are not defined - fn match_fmtp(&self, f: &(dyn Fmtp)) -> bool { - if let Some(c) = f.as_any().downcast_ref::() { - if self.mime_type.to_lowercase() != c.mime_type().to_lowercase() { - return false; - } - - fmtp_consist(&self.parameters, &c.parameters) - } else { - false - } - } - - fn parameter(&self, key: &str) -> Option<&String> { - self.parameters.get(key) - } - - fn equal(&self, other: &(dyn Fmtp)) -> bool { - other - .as_any() - .downcast_ref::() - .map_or(false, |a| self == a) - } - - fn as_any(&self) -> &(dyn Any) { - self - } -} diff --git a/webrtc/src/rtp_transceiver/fmtp/h264/h264_test.rs b/webrtc/src/rtp_transceiver/fmtp/h264/h264_test.rs deleted file mode 100644 index fe97dbe05..000000000 --- a/webrtc/src/rtp_transceiver/fmtp/h264/h264_test.rs +++ /dev/null @@ -1,163 +0,0 @@ -use super::*; - -#[test] -fn test_h264_fmtp_parse() { - let tests: Vec<(&str, &str, Box)> = vec![ - ( - "OneParam", - "key-name=value", - Box::new(H264Fmtp { - parameters: [("key-name".to_owned(), "value".to_owned())] - .iter() - .cloned() - .collect(), - }), - ), - ( - "OneParamWithWhiteSpeces", - "\tkey-name=value ", - Box::new(H264Fmtp { - parameters: [("key-name".to_owned(), "value".to_owned())] - .iter() - .cloned() - .collect(), - }), - ), - ( - "TwoParams", - "key-name=value;key2=value2", - Box::new(H264Fmtp { - parameters: [ - ("key-name".to_owned(), "value".to_owned()), - ("key2".to_owned(), "value2".to_owned()), - ] - .iter() - .cloned() - .collect(), - }), - ), - ( - "TwoParamsWithWhiteSpeces", - "key-name=value; \n\tkey2=value2 ", - Box::new(H264Fmtp { - parameters: [ - ("key-name".to_owned(), "value".to_owned()), - ("key2".to_owned(), "value2".to_owned()), - ] - .iter() - .cloned() - .collect(), - }), - ), - ]; - - for (name, input, expected) in tests { - let f = parse("video/h264", input); - assert_eq!(&f, &expected, "{name} failed"); - - assert_eq!(f.mime_type(), "video/h264"); - } -} - -#[test] -fn test_h264_fmtp_compare() { - let consist_string: HashMap = [ - (true, "consist".to_owned()), - (false, "inconsist".to_owned()), - ] - .iter() - .cloned() - .collect(); - - let tests = vec![ - ( - "Equal", - "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f", - "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f", - true, - ), - ( - "EqualWithWhitespaceVariants", - "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f", - " level-asymmetry-allowed=1; \npacketization-mode=1;\t\nprofile-level-id=42e01f", - true, - ), - ( - "EqualWithCase", - "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f", - "level-asymmetry-allowed=1;packetization-mode=1;PROFILE-LEVEL-ID=42e01f", - true, - ), - ( - "OneHasExtraParam", - "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f", - 
"packetization-mode=1;profile-level-id=42e01f", - true, - ), - ( - "DifferentProfileLevelIDVersions", - "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f", - "packetization-mode=1;profile-level-id=42e029", - true, - ), - ( - "Inconsistent", - "packetization-mode=1;profile-level-id=42e029", - "packetization-mode=0;profile-level-id=42e029", - false, - ), - ( - "Inconsistent_MissingPacketizationMode", - "packetization-mode=1;profile-level-id=42e029", - "profile-level-id=42e029", - false, - ), - ( - "Inconsistent_MissingProfileLevelID", - "packetization-mode=1;profile-level-id=42e029", - "packetization-mode=1", - false, - ), - ( - "Inconsistent_InvalidProfileLevelID", - "packetization-mode=1;profile-level-id=42e029", - "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=41e029", - false, - ), - ]; - - for (name, a, b, consist) in tests { - let check = |a, b| { - let aa = parse("video/h264", a); - let bb = parse("video/h264", b); - - // test forward case here - let c = aa.match_fmtp(&*bb); - assert_eq!( - c, - consist, - "{}: '{}' and '{}' are expected to be {:?}, but treated as {:?}", - name, - a, - b, - consist_string.get(&consist), - consist_string.get(&c), - ); - - // test reverse case here - let c = bb.match_fmtp(&*aa); - assert_eq!( - c, - consist, - "{}: '{}' and '{}' are expected to be {:?}, but treated as {:?}", - name, - a, - b, - consist_string.get(&consist), - consist_string.get(&c), - ); - }; - - check(a, b); - } -} diff --git a/webrtc/src/rtp_transceiver/fmtp/h264/mod.rs b/webrtc/src/rtp_transceiver/fmtp/h264/mod.rs deleted file mode 100644 index 22a5bf594..000000000 --- a/webrtc/src/rtp_transceiver/fmtp/h264/mod.rs +++ /dev/null @@ -1,102 +0,0 @@ -#[cfg(test)] -mod h264_test; - -use super::*; - -fn profile_level_id_matches(a: &str, b: &str) -> bool { - let aa = match hex::decode(a) { - Ok(aa) => { - if aa.len() < 2 { - return false; - } - aa - } - Err(_) => return false, - }; - - let bb = match hex::decode(b) { - Ok(bb) => { - if bb.len() < 2 { - return false; - } - bb - } - Err(_) => return false, - }; - - aa[0] == bb[0] && aa[1] == bb[1] -} - -#[derive(Debug, PartialEq)] -pub(crate) struct H264Fmtp { - pub(crate) parameters: HashMap, -} - -impl Fmtp for H264Fmtp { - fn mime_type(&self) -> &str { - "video/h264" - } - - /// Match returns true if h and b are compatible fmtp descriptions - /// Based on RFC6184 Section 8.2.2: - /// The parameters identifying a media format configuration for H.264 - /// are profile-level-id and packetization-mode. These media format - /// configuration parameters (except for the level part of profile- - /// level-id) MUST be used symmetrically; that is, the answerer MUST - /// either maintain all configuration parameters or remove the media - /// format (payload type) completely if one or more of the parameter - /// values are not supported. - /// Informative note: The requirement for symmetric use does not - /// apply for the level part of profile-level-id and does not apply - /// for the other stream properties and capability parameters. 
- fn match_fmtp(&self, f: &(dyn Fmtp)) -> bool { - if let Some(c) = f.as_any().downcast_ref::() { - // test packetization-mode - let hpmode = match self.parameters.get("packetization-mode") { - Some(s) => s, - None => return false, - }; - let cpmode = match c.parameters.get("packetization-mode") { - Some(s) => s, - None => return false, - }; - - if hpmode != cpmode { - return false; - } - - // test profile-level-id - let hplid = match self.parameters.get("profile-level-id") { - Some(s) => s, - None => return false, - }; - let cplid = match c.parameters.get("profile-level-id") { - Some(s) => s, - None => return false, - }; - - if !profile_level_id_matches(hplid, cplid) { - return false; - } - - true - } else { - false - } - } - - fn parameter(&self, key: &str) -> Option<&String> { - self.parameters.get(key) - } - - fn equal(&self, other: &(dyn Fmtp)) -> bool { - other - .as_any() - .downcast_ref::() - .map_or(false, |a| self == a) - } - - fn as_any(&self) -> &(dyn Any) { - self - } -} diff --git a/webrtc/src/rtp_transceiver/fmtp/mod.rs b/webrtc/src/rtp_transceiver/fmtp/mod.rs deleted file mode 100644 index ea6e83b39..000000000 --- a/webrtc/src/rtp_transceiver/fmtp/mod.rs +++ /dev/null @@ -1,58 +0,0 @@ -pub(crate) mod generic; -pub(crate) mod h264; - -use std::any::Any; -use std::collections::HashMap; -use std::fmt; - -use crate::rtp_transceiver::fmtp::generic::GenericFmtp; -use crate::rtp_transceiver::fmtp::h264::H264Fmtp; - -/// Fmtp interface for implementing custom -/// Fmtp parsers based on mime_type -pub trait Fmtp: fmt::Debug { - /// mime_type returns the mime_type associated with - /// the fmtp - fn mime_type(&self) -> &str; - - /// match_fmtp compares two fmtp descriptions for - /// compatibility based on the mime_type - fn match_fmtp(&self, f: &(dyn Fmtp)) -> bool; - - /// parameter returns a value for the associated key - /// if contained in the parsed fmtp string - fn parameter(&self, key: &str) -> Option<&String>; - - fn equal(&self, other: &(dyn Fmtp)) -> bool; - fn as_any(&self) -> &(dyn Any); -} - -impl PartialEq for dyn Fmtp { - fn eq(&self, other: &Self) -> bool { - self.equal(other) - } -} - -/// parse parses an fmtp string based on the MimeType -pub fn parse(mime_type: &str, line: &str) -> Box { - let mut parameters = HashMap::new(); - for p in line.split(';').collect::>() { - let pp: Vec<&str> = p.trim().splitn(2, '=').collect(); - let key = pp[0].to_lowercase(); - let value = if pp.len() > 1 { - pp[1].to_owned() - } else { - String::new() - }; - parameters.insert(key, value); - } - - if mime_type.to_uppercase() == "video/h264".to_uppercase() { - Box::new(H264Fmtp { parameters }) - } else { - Box::new(GenericFmtp { - mime_type: mime_type.to_owned(), - parameters, - }) - } -} diff --git a/webrtc/src/rtp_transceiver/mod.rs b/webrtc/src/rtp_transceiver/mod.rs deleted file mode 100644 index e634cc2d4..000000000 --- a/webrtc/src/rtp_transceiver/mod.rs +++ /dev/null @@ -1,561 +0,0 @@ -#[cfg(test)] -mod rtp_transceiver_test; - -use std::fmt; -use std::future::Future; -use std::pin::Pin; -use std::sync::atomic::Ordering; -use std::sync::Arc; - -use interceptor::stream_info::{RTPHeaderExtension, StreamInfo}; -use interceptor::Attributes; -use log::trace; -use portable_atomic::{AtomicBool, AtomicU8}; -use serde::{Deserialize, Serialize}; -use smol_str::SmolStr; -use tokio::sync::{Mutex, OnceCell}; -use util::Unmarshal; - -use crate::api::media_engine::MediaEngine; -use crate::error::{Error, Result}; -use crate::rtp_transceiver::rtp_codec::*; -use 
crate::rtp_transceiver::rtp_receiver::{RTCRtpReceiver, RTPReceiverInternal}; -use crate::rtp_transceiver::rtp_sender::RTCRtpSender; -use crate::rtp_transceiver::rtp_transceiver_direction::RTCRtpTransceiverDirection; -use crate::track::track_local::TrackLocal; - -pub(crate) mod fmtp; -pub mod rtp_codec; -pub mod rtp_receiver; -pub mod rtp_sender; -pub mod rtp_transceiver_direction; -pub(crate) mod srtp_writer_future; - -/// SSRC represents a synchronization source -/// A synchronization source is a randomly chosen -/// value meant to be globally unique within a particular -/// RTP session. Used to identify a single stream of media. -/// -#[allow(clippy::upper_case_acronyms)] -pub type SSRC = u32; - -/// PayloadType identifies the format of the RTP payload and determines -/// its interpretation by the application. Each codec in a RTP Session -/// will have a different PayloadType -/// -pub type PayloadType = u8; - -/// TYPE_RTCP_FBT_RANSPORT_CC .. -pub const TYPE_RTCP_FB_TRANSPORT_CC: &str = "transport-cc"; - -/// TYPE_RTCP_FB_GOOG_REMB .. -pub const TYPE_RTCP_FB_GOOG_REMB: &str = "goog-remb"; - -/// TYPE_RTCP_FB_ACK .. -pub const TYPE_RTCP_FB_ACK: &str = "ack"; - -/// TYPE_RTCP_FB_CCM .. -pub const TYPE_RTCP_FB_CCM: &str = "ccm"; - -/// TYPE_RTCP_FB_NACK .. -pub const TYPE_RTCP_FB_NACK: &str = "nack"; - -/// rtcpfeedback signals the connection to use additional RTCP packet types. -/// -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub struct RTCPFeedback { - /// Type is the type of feedback. - /// see: - /// valid: ack, ccm, nack, goog-remb, transport-cc - pub typ: String, - - /// The parameter value depends on the type. - /// For example, type="nack" parameter="pli" will send Picture Loss Indicator packets. - pub parameter: String, -} - -/// RTPCapabilities represents the capabilities of a transceiver -/// -#[derive(Default, Debug, Clone)] -pub struct RTCRtpCapabilities { - pub codecs: Vec, - pub header_extensions: Vec, -} - -/// RTPRtxParameters dictionary contains information relating to retransmission (RTX) settings. -/// -#[derive(Default, Debug, Clone, Serialize, Deserialize)] -pub struct RTCRtpRtxParameters { - pub ssrc: SSRC, -} - -/// RTPCodingParameters provides information relating to both encoding and decoding. -/// This is a subset of the RFC since Pion WebRTC doesn't implement encoding/decoding itself -/// -#[derive(Default, Debug, Clone, Serialize, Deserialize)] -pub struct RTCRtpCodingParameters { - pub rid: SmolStr, - pub ssrc: SSRC, - pub payload_type: PayloadType, - pub rtx: RTCRtpRtxParameters, -} - -/// RTPDecodingParameters provides information relating to both encoding and decoding. -/// This is a subset of the RFC since Pion WebRTC doesn't implement decoding itself -/// -pub type RTCRtpDecodingParameters = RTCRtpCodingParameters; - -/// RTPEncodingParameters provides information relating to both encoding and decoding. -/// This is a subset of the RFC since Pion WebRTC doesn't implement encoding itself -/// -pub type RTCRtpEncodingParameters = RTCRtpCodingParameters; - -/// RTPReceiveParameters contains the RTP stack settings used by receivers -#[derive(Debug)] -pub struct RTCRtpReceiveParameters { - pub encodings: Vec, -} - -/// RTPSendParameters contains the RTP stack settings used by receivers -#[derive(Debug)] -pub struct RTCRtpSendParameters { - pub rtp_parameters: RTCRtpParameters, - pub encodings: Vec, -} - -/// RTPTransceiverInit dictionary is used when calling the WebRTC function addTransceiver() to provide configuration options for the new transceiver. 
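For illustration, filling the dictionary in for a transceiver that sends two simulcast layers identified by RID might look like the sketch below; the import paths are assumed from the file layout in this diff and the RID values are made up (the struct itself is declared just below):

use webrtc::rtp_transceiver::rtp_transceiver_direction::RTCRtpTransceiverDirection;
use webrtc::rtp_transceiver::{RTCRtpCodingParameters, RTCRtpTransceiverInit};

fn simulcast_init() -> RTCRtpTransceiverInit {
    RTCRtpTransceiverInit {
        // This transceiver only sends; the encodings below describe its layers.
        direction: RTCRtpTransceiverDirection::Sendonly,
        send_encodings: vec![
            RTCRtpCodingParameters {
                rid: "hi".into(),
                ..Default::default()
            },
            RTCRtpCodingParameters {
                rid: "lo".into(),
                ..Default::default()
            },
        ],
    }
}

fn main() {
    let _init = simulcast_init();
}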
-pub struct RTCRtpTransceiverInit { - pub direction: RTCRtpTransceiverDirection, - pub send_encodings: Vec, - // Streams []*Track -} - -pub(crate) fn create_stream_info( - id: String, - ssrc: SSRC, - payload_type: PayloadType, - codec: RTCRtpCodecCapability, - webrtc_header_extensions: &[RTCRtpHeaderExtensionParameters], -) -> StreamInfo { - let header_extensions: Vec = webrtc_header_extensions - .iter() - .map(|h| RTPHeaderExtension { - id: h.id, - uri: h.uri.clone(), - }) - .collect(); - - let feedbacks: Vec<_> = codec - .rtcp_feedback - .iter() - .map(|f| interceptor::stream_info::RTCPFeedback { - typ: f.typ.clone(), - parameter: f.parameter.clone(), - }) - .collect(); - - StreamInfo { - id, - attributes: Attributes::new(), - ssrc, - payload_type, - rtp_header_extensions: header_extensions, - mime_type: codec.mime_type, - clock_rate: codec.clock_rate, - channels: codec.channels, - sdp_fmtp_line: codec.sdp_fmtp_line, - rtcp_feedback: feedbacks, - } -} - -pub type TriggerNegotiationNeededFnOption = - Option Pin + Send + Sync>> + Send + Sync>>; - -/// RTPTransceiver represents a combination of an RTPSender and an RTPReceiver that share a common mid. -pub struct RTCRtpTransceiver { - mid: OnceCell, //atomic.Value - sender: Mutex>, //atomic.Value - receiver: Mutex>, //atomic.Value - - direction: AtomicU8, //RTPTransceiverDirection - current_direction: AtomicU8, //RTPTransceiverDirection - - codecs: Arc>>, // User provided codecs via set_codec_preferences - - pub(crate) stopped: AtomicBool, - pub(crate) kind: RTPCodecType, - - media_engine: Arc, - - trigger_negotiation_needed: Mutex, -} - -impl RTCRtpTransceiver { - pub async fn new( - receiver: Arc, - sender: Arc, - direction: RTCRtpTransceiverDirection, - kind: RTPCodecType, - codecs: Vec, - media_engine: Arc, - trigger_negotiation_needed: TriggerNegotiationNeededFnOption, - ) -> Arc { - let codecs = Arc::new(Mutex::new(codecs)); - receiver.set_transceiver_codecs(Some(Arc::clone(&codecs))); - - let t = Arc::new(RTCRtpTransceiver { - mid: OnceCell::new(), - sender: Mutex::new(sender), - receiver: Mutex::new(receiver), - - direction: AtomicU8::new(direction as u8), - current_direction: AtomicU8::new(RTCRtpTransceiverDirection::Unspecified as u8), - - codecs, - stopped: AtomicBool::new(false), - kind, - media_engine, - trigger_negotiation_needed: Mutex::new(trigger_negotiation_needed), - }); - t.sender() - .await - .set_rtp_transceiver(Some(Arc::downgrade(&t))); - - t - } - - /// set_codec_preferences sets preferred list of supported codecs - /// if codecs is empty or nil we reset to default from MediaEngine - pub async fn set_codec_preferences(&self, codecs: Vec) -> Result<()> { - for codec in &codecs { - let media_engine_codecs = self.media_engine.get_codecs_by_kind(self.kind); - let (_, match_type) = codec_parameters_fuzzy_search(codec, &media_engine_codecs); - if match_type == CodecMatch::None { - return Err(Error::ErrRTPTransceiverCodecUnsupported); - } - } - - { - let mut c = self.codecs.lock().await; - *c = codecs; - } - Ok(()) - } - - /// Codecs returns list of supported codecs - pub(crate) async fn get_codecs(&self) -> Vec { - let mut codecs = self.codecs.lock().await; - RTPReceiverInternal::get_codecs(&mut codecs, self.kind, &self.media_engine) - } - - /// sender returns the RTPTransceiver's RTPSender if it has one - pub async fn sender(&self) -> Arc { - let sender = self.sender.lock().await; - sender.clone() - } - - /// set_sender_track sets the RTPSender and Track to current transceiver - pub async fn set_sender_track( - self: 
&Arc, - sender: Arc, - track: Option>, - ) -> Result<()> { - self.set_sender(sender).await; - self.set_sending_track(track).await - } - - pub async fn set_sender(self: &Arc, s: Arc) { - s.set_rtp_transceiver(Some(Arc::downgrade(self))); - - let prev_sender = self.sender().await; - prev_sender.set_rtp_transceiver(None); - - { - let mut sender = self.sender.lock().await; - *sender = s; - } - } - - /// receiver returns the RTPTransceiver's RTPReceiver if it has one - pub async fn receiver(&self) -> Arc { - let receiver = self.receiver.lock().await; - receiver.clone() - } - - pub(crate) async fn set_receiver(&self, r: Arc) { - r.set_transceiver_codecs(Some(Arc::clone(&self.codecs))); - - { - let mut receiver = self.receiver.lock().await; - (*receiver).set_transceiver_codecs(None); - - *receiver = r; - } - } - - /// set_mid sets the RTPTransceiver's mid. If it was already set, will return an error. - pub(crate) fn set_mid(&self, mid: SmolStr) -> Result<()> { - self.mid - .set(mid) - .map_err(|_| Error::ErrRTPTransceiverCannotChangeMid) - } - - /// mid gets the Transceiver's mid value. When not already set, this value will be set in CreateOffer or create_answer. - pub fn mid(&self) -> Option { - self.mid.get().cloned() - } - - /// kind returns RTPTransceiver's kind. - pub fn kind(&self) -> RTPCodecType { - self.kind - } - - /// direction returns the RTPTransceiver's desired direction. - pub fn direction(&self) -> RTCRtpTransceiverDirection { - self.direction.load(Ordering::SeqCst).into() - } - - /// Set the direction of this transceiver. This might trigger a renegotiation. - pub async fn set_direction(&self, d: RTCRtpTransceiverDirection) { - let changed = self.set_direction_internal(d); - - if changed { - let lock = self.trigger_negotiation_needed.lock().await; - if let Some(trigger) = &*lock { - (trigger)().await; - } - } - } - - pub(crate) fn set_direction_internal(&self, d: RTCRtpTransceiverDirection) -> bool { - let previous: RTCRtpTransceiverDirection = - self.direction.swap(d as u8, Ordering::SeqCst).into(); - - let changed = d != previous; - - if changed { - trace!( - "Changing direction of transceiver from {} to {}", - previous, - d - ); - } - - changed - } - - /// current_direction returns the RTPTransceiver's current direction as negotiated. - /// - /// If this transceiver has never been negotiated or if it's stopped this returns [`RTCRtpTransceiverDirection::Unspecified`]. - pub fn current_direction(&self) -> RTCRtpTransceiverDirection { - if self.stopped.load(Ordering::SeqCst) { - return RTCRtpTransceiverDirection::Unspecified; - } - - self.current_direction.load(Ordering::SeqCst).into() - } - - pub(crate) fn set_current_direction(&self, d: RTCRtpTransceiverDirection) { - let previous: RTCRtpTransceiverDirection = self - .current_direction - .swap(d as u8, Ordering::SeqCst) - .into(); - - if d != previous { - trace!( - "Changing current direction of transceiver from {} to {}", - previous, - d, - ); - } - } - - /// Perform any subsequent actions after altering the transceiver's direction. - /// - /// After changing the transceiver's direction this method should be called to perform any - /// side-effects that results from the new direction, such as pausing/resuming the RTP receiver. 
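Those side effects reduce to a simple rule: pause the receiver whenever the newly negotiated direction does not receive, and pause the sender whenever it does not send. A standalone sketch of that mapping, using a local enum purely for illustration (the method itself follows below):

// Maps a negotiated direction to the pause decisions applied after the change.
#[derive(Clone, Copy, Debug)]
enum Direction {
    Sendrecv,
    Sendonly,
    Recvonly,
    Inactive,
}

impl Direction {
    fn has_send(self) -> bool {
        matches!(self, Direction::Sendrecv | Direction::Sendonly)
    }
    fn has_recv(self) -> bool {
        matches!(self, Direction::Sendrecv | Direction::Recvonly)
    }
}

fn main() {
    for d in [
        Direction::Sendrecv,
        Direction::Sendonly,
        Direction::Recvonly,
        Direction::Inactive,
    ] {
        let pause_receiver = !d.has_recv();
        let pause_sender = !d.has_send();
        println!("{d:?}: pause_receiver={pause_receiver}, pause_sender={pause_sender}");
    }
}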
- pub(crate) async fn process_new_current_direction( - &self, - previous_direction: RTCRtpTransceiverDirection, - ) -> Result<()> { - if self.stopped.load(Ordering::SeqCst) { - return Ok(()); - } - - let current_direction = self.current_direction(); - if previous_direction != current_direction { - let mid = self.mid(); - trace!( - "Processing transceiver({:?}) direction change from {} to {}", - mid, - previous_direction, - current_direction - ); - } else { - // no change. - return Ok(()); - } - - { - let receiver = self.receiver.lock().await; - let pause_receiver = !current_direction.has_recv(); - - if pause_receiver { - receiver.pause().await?; - } else { - receiver.resume().await?; - } - } - - let pause_sender = !current_direction.has_send(); - { - let sender = &*self.sender.lock().await; - sender.set_paused(pause_sender); - } - - Ok(()) - } - - /// stop irreversibly stops the RTPTransceiver - pub async fn stop(&self) -> Result<()> { - if self.stopped.load(Ordering::SeqCst) { - return Ok(()); - } - - self.stopped.store(true, Ordering::SeqCst); - - { - let sender = self.sender.lock().await; - sender.stop().await?; - } - { - let r = self.receiver.lock().await; - r.stop().await?; - } - - self.set_direction_internal(RTCRtpTransceiverDirection::Inactive); - - Ok(()) - } - - pub(crate) async fn set_sending_track( - &self, - track: Option>, - ) -> Result<()> { - let track_is_none = track.is_none(); - { - let sender = self.sender.lock().await; - sender.replace_track(track).await?; - } - - let direction = self.direction(); - let should_send = !track_is_none; - let should_recv = direction.has_recv(); - self.set_direction_internal(RTCRtpTransceiverDirection::from_send_recv( - should_send, - should_recv, - )); - - Ok(()) - } -} - -impl fmt::Debug for RTCRtpTransceiver { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("RTCRtpTransceiver") - .field("mid", &self.mid) - .field("sender", &self.sender) - .field("receiver", &self.receiver) - .field("direction", &self.direction) - .field("current_direction", &self.current_direction) - .field("codecs", &self.codecs) - .field("stopped", &self.stopped) - .field("kind", &self.kind) - .finish() - } -} - -pub(crate) async fn find_by_mid( - mid: &str, - local_transceivers: &mut Vec>, -) -> Option> { - for (i, t) in local_transceivers.iter().enumerate() { - if t.mid() == Some(SmolStr::from(mid)) { - return Some(local_transceivers.remove(i)); - } - } - - None -} - -/// Given a direction+type pluck a transceiver from the passed list -/// if no entry satisfies the requested type+direction return a inactive Transceiver -pub(crate) async fn satisfy_type_and_direction( - remote_kind: RTPCodecType, - remote_direction: RTCRtpTransceiverDirection, - local_transceivers: &mut Vec>, -) -> Option> { - // Get direction order from most preferred to least - let get_preferred_directions = || -> Vec { - match remote_direction { - RTCRtpTransceiverDirection::Sendrecv => vec![ - RTCRtpTransceiverDirection::Recvonly, - RTCRtpTransceiverDirection::Sendrecv, - ], - RTCRtpTransceiverDirection::Sendonly => vec![RTCRtpTransceiverDirection::Recvonly], - RTCRtpTransceiverDirection::Recvonly => vec![ - RTCRtpTransceiverDirection::Sendonly, - RTCRtpTransceiverDirection::Sendrecv, - ], - _ => vec![], - } - }; - - for possible_direction in get_preferred_directions() { - for (i, t) in local_transceivers.iter().enumerate() { - if t.mid().is_none() && t.kind == remote_kind && possible_direction == t.direction() { - return Some(local_transceivers.remove(i)); - } - } - } 
- - None -} - -/// handle_unknown_rtp_packet consumes a single RTP Packet and returns information that is helpful -/// for demuxing and handling an unknown SSRC (usually for Simulcast) -pub(crate) fn handle_unknown_rtp_packet( - buf: &[u8], - mid_extension_id: u8, - sid_extension_id: u8, - rsid_extension_id: u8, -) -> Result<(String, String, String, PayloadType)> { - let mut reader = buf; - let rp = rtp::packet::Packet::unmarshal(&mut reader)?; - - if !rp.header.extension { - return Ok((String::new(), String::new(), String::new(), 0)); - } - - let payload_type = rp.header.payload_type; - - let mid = if let Some(payload) = rp.header.get_extension(mid_extension_id) { - String::from_utf8(payload.to_vec())? - } else { - String::new() - }; - - let rid = if let Some(payload) = rp.header.get_extension(sid_extension_id) { - String::from_utf8(payload.to_vec())? - } else { - String::new() - }; - - let srid = if let Some(payload) = rp.header.get_extension(rsid_extension_id) { - String::from_utf8(payload.to_vec())? - } else { - String::new() - }; - - Ok((mid, rid, srid, payload_type)) -} diff --git a/webrtc/src/rtp_transceiver/rtp_codec.rs b/webrtc/src/rtp_transceiver/rtp_codec.rs deleted file mode 100644 index 5bf57501c..000000000 --- a/webrtc/src/rtp_transceiver/rtp_codec.rs +++ /dev/null @@ -1,165 +0,0 @@ -use std::fmt; - -use super::*; -use crate::api::media_engine::*; -use crate::error::{Error, Result}; -use crate::rtp_transceiver::fmtp; - -/// RTPCodecType determines the type of a codec -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] -pub enum RTPCodecType { - #[default] - Unspecified = 0, - - /// RTPCodecTypeAudio indicates this is an audio codec - Audio = 1, - - /// RTPCodecTypeVideo indicates this is a video codec - Video = 2, -} - -impl From<&str> for RTPCodecType { - fn from(raw: &str) -> Self { - match raw { - "audio" => RTPCodecType::Audio, - "video" => RTPCodecType::Video, - _ => RTPCodecType::Unspecified, - } - } -} - -impl From for RTPCodecType { - fn from(v: u8) -> Self { - match v { - 1 => RTPCodecType::Audio, - 2 => RTPCodecType::Video, - _ => RTPCodecType::Unspecified, - } - } -} - -impl fmt::Display for RTPCodecType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - RTPCodecType::Audio => "audio", - RTPCodecType::Video => "video", - RTPCodecType::Unspecified => crate::UNSPECIFIED_STR, - }; - write!(f, "{s}") - } -} - -/// RTPCodecCapability provides information about codec capabilities. 
-/// -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub struct RTCRtpCodecCapability { - pub mime_type: String, - pub clock_rate: u32, - pub channels: u16, - pub sdp_fmtp_line: String, - pub rtcp_feedback: Vec, -} - -impl RTCRtpCodecCapability { - /// Turn codec capability into a `packetizer::Payloader` - pub fn payloader_for_codec(&self) -> Result> { - let mime_type = self.mime_type.to_lowercase(); - if mime_type == MIME_TYPE_H264.to_lowercase() { - Ok(Box::::default()) - } else if mime_type == MIME_TYPE_HEVC.to_lowercase() { - Ok(Box::::default()) - } else if mime_type == MIME_TYPE_VP8.to_lowercase() { - let mut vp8_payloader = rtp::codecs::vp8::Vp8Payloader::default(); - vp8_payloader.enable_picture_id = true; - Ok(Box::new(vp8_payloader)) - } else if mime_type == MIME_TYPE_VP9.to_lowercase() { - Ok(Box::::default()) - } else if mime_type == MIME_TYPE_OPUS.to_lowercase() { - Ok(Box::::default()) - } else if mime_type == MIME_TYPE_G722.to_lowercase() - || mime_type == MIME_TYPE_PCMU.to_lowercase() - || mime_type == MIME_TYPE_PCMA.to_lowercase() - || mime_type == MIME_TYPE_TELEPHONE_EVENT.to_lowercase() - { - Ok(Box::::default()) - } else if mime_type == MIME_TYPE_AV1.to_lowercase() { - Ok(Box::::default()) - } else { - Err(Error::ErrNoPayloaderForCodec) - } - } -} - -/// RTPHeaderExtensionCapability is used to define a RFC5285 RTP header extension supported by the codec. -/// -#[derive(Default, Debug, Clone)] -pub struct RTCRtpHeaderExtensionCapability { - pub uri: String, -} - -/// RTPHeaderExtensionParameter represents a negotiated RFC5285 RTP header extension. -/// -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub struct RTCRtpHeaderExtensionParameters { - pub uri: String, - pub id: isize, -} - -/// RTPCodecParameters is a sequence containing the media codecs that an RtpSender -/// will choose from, as well as entries for RTX, RED and FEC mechanisms. 
This also -/// includes the PayloadType that has been negotiated -/// -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub struct RTCRtpCodecParameters { - pub capability: RTCRtpCodecCapability, - pub payload_type: PayloadType, - pub stats_id: String, -} - -/// RTPParameters is a list of negotiated codecs and header extensions -/// -#[derive(Default, Debug, Clone)] -pub struct RTCRtpParameters { - pub header_extensions: Vec, - pub codecs: Vec, -} - -#[derive(Default, Debug, Copy, Clone, PartialEq)] -pub(crate) enum CodecMatch { - #[default] - None = 0, - Partial = 1, - Exact = 2, -} - -/// Do a fuzzy find for a codec in the list of codecs -/// Used for lookup up a codec in an existing list to find a match -/// Returns codecMatchExact, codecMatchPartial, or codecMatchNone -pub(crate) fn codec_parameters_fuzzy_search( - needle: &RTCRtpCodecParameters, - haystack: &[RTCRtpCodecParameters], -) -> (RTCRtpCodecParameters, CodecMatch) { - let needle_fmtp = fmtp::parse( - &needle.capability.mime_type, - &needle.capability.sdp_fmtp_line, - ); - - //TODO: add unicode case-folding equal support - - // First attempt to match on mime_type + sdpfmtp_line - for c in haystack { - let cfmpt = fmtp::parse(&c.capability.mime_type, &c.capability.sdp_fmtp_line); - if needle_fmtp.match_fmtp(&*cfmpt) { - return (c.clone(), CodecMatch::Exact); - } - } - - // Fallback to just mime_type - for c in haystack { - if c.capability.mime_type.to_uppercase() == needle.capability.mime_type.to_uppercase() { - return (c.clone(), CodecMatch::Partial); - } - } - - (RTCRtpCodecParameters::default(), CodecMatch::None) -} diff --git a/webrtc/src/rtp_transceiver/rtp_receiver/mod.rs b/webrtc/src/rtp_transceiver/rtp_receiver/mod.rs deleted file mode 100644 index e589ee575..000000000 --- a/webrtc/src/rtp_transceiver/rtp_receiver/mod.rs +++ /dev/null @@ -1,853 +0,0 @@ -#[cfg(test)] -mod rtp_receiver_test; - -use std::fmt; -use std::sync::Arc; - -use arc_swap::ArcSwapOption; -use interceptor::stream_info::RTPHeaderExtension; -use interceptor::{Attributes, Interceptor}; -use log::trace; -use smol_str::SmolStr; -use tokio::sync::{watch, Mutex, RwLock}; - -use crate::api::media_engine::MediaEngine; -use crate::dtls_transport::RTCDtlsTransport; -use crate::error::{flatten_errs, Error, Result}; -use crate::peer_connection::sdp::TrackDetails; -use crate::rtp_transceiver::rtp_codec::{ - codec_parameters_fuzzy_search, CodecMatch, RTCRtpCodecCapability, RTCRtpCodecParameters, - RTCRtpParameters, RTPCodecType, -}; -use crate::rtp_transceiver::rtp_transceiver_direction::RTCRtpTransceiverDirection; -use crate::rtp_transceiver::{ - create_stream_info, RTCRtpDecodingParameters, RTCRtpReceiveParameters, SSRC, -}; -use crate::track::track_remote::TrackRemote; -use crate::track::{TrackStream, TrackStreams}; - -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -#[repr(u8)] -pub enum State { - /// We haven't started yet. - Unstarted = 0, - /// We haven't started yet and additionally we've been paused. - UnstartedPaused = 1, - - /// We have started and are running. - Started = 2, - - /// We have been paused after starting. - Paused = 3, - - /// We have been stopped. 
- Stopped = 4, -} - -impl From for State { - fn from(value: u8) -> Self { - match value { - v if v == State::Unstarted as u8 => State::Unstarted, - v if v == State::UnstartedPaused as u8 => State::UnstartedPaused, - v if v == State::Started as u8 => State::Started, - v if v == State::Paused as u8 => State::Paused, - v if v == State::Stopped as u8 => State::Stopped, - _ => unreachable!( - "Invalid serialization of {}: {}", - std::any::type_name::(), - value - ), - } - } -} - -impl fmt::Display for State { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - State::Unstarted => write!(f, "Unstarted"), - State::UnstartedPaused => write!(f, "UnstartedPaused"), - State::Started => write!(f, "Running"), - State::Paused => write!(f, "Paused"), - State::Stopped => write!(f, "Closed"), - } - } -} - -impl State { - fn transition(to: Self, tx: &watch::Sender) -> Result<()> { - let current = *tx.borrow(); - if current == to { - // Already in this state - return Ok(()); - } - - match current { - Self::Unstarted - if matches!(to, Self::Started | Self::Stopped | Self::UnstartedPaused) => - { - let _ = tx.send(to); - return Ok(()); - } - Self::UnstartedPaused - if matches!(to, Self::Unstarted | Self::Stopped | Self::Paused) => - { - let _ = tx.send(to); - return Ok(()); - } - State::Started if matches!(to, Self::Paused | Self::Stopped) => { - let _ = tx.send(to); - return Ok(()); - } - State::Paused if matches!(to, Self::Started | Self::Stopped) => { - let _ = tx.send(to); - return Ok(()); - } - _ => {} - } - - Err(Error::ErrRTPReceiverStateChangeInvalid { from: current, to }) - } - - async fn wait_for(rx: &mut watch::Receiver, states: &[State]) -> Result<()> { - loop { - let state = *rx.borrow(); - - match state { - _ if states.contains(&state) => return Ok(()), - State::Stopped => { - return Err(Error::ErrClosedPipe); - } - _ => {} - } - - if rx.changed().await.is_err() { - return Err(Error::ErrClosedPipe); - } - } - } - - async fn error_on_close(rx: &mut watch::Receiver) -> Result<()> { - if rx.changed().await.is_err() { - return Err(Error::ErrClosedPipe); - } - - let state = *rx.borrow(); - if state == State::Stopped { - return Err(Error::ErrClosedPipe); - } - - Ok(()) - } - - fn is_started(&self) -> bool { - matches!(self, Self::Started | Self::Paused) - } -} - -pub struct RTPReceiverInternal { - pub(crate) kind: RTPCodecType, - - // State is stored within the channel - state_tx: watch::Sender, - state_rx: watch::Receiver, - - tracks: RwLock>, - - transceiver_codecs: ArcSwapOption>>, - - transport: Arc, - media_engine: Arc, - interceptor: Arc, -} - -impl RTPReceiverInternal { - /// read reads incoming RTCP for this RTPReceiver - async fn read( - &self, - b: &mut [u8], - ) -> Result<(Vec>, Attributes)> { - let mut state_watch_rx = self.state_tx.subscribe(); - // Ensure we are running or paused. When paused we still receive RTCP even if RTP traffic - // isn't flowing. - State::wait_for(&mut state_watch_rx, &[State::Started, State::Paused]).await?; - - let tracks = self.tracks.read().await; - if let Some(t) = tracks.first() { - if let Some(rtcp_interceptor) = &t.stream.rtcp_interceptor { - let a = Attributes::new(); - loop { - tokio::select! { - res = State::error_on_close(&mut state_watch_rx) => { - res? - } - result = rtcp_interceptor.read(b, &a) => { - return Ok(result?) 
- } - } - } - } else { - Err(Error::ErrInterceptorNotBind) - } - } else { - Err(Error::ErrExistingTrack) - } - } - - /// read_simulcast reads incoming RTCP for this RTPReceiver for given rid - async fn read_simulcast( - &self, - b: &mut [u8], - rid: &str, - ) -> Result<(Vec>, Attributes)> { - let mut state_watch_rx = self.state_tx.subscribe(); - - // Ensure we are running or paused. When paused we still receive RTCP even if RTP traffic - // isn't flowing. - State::wait_for(&mut state_watch_rx, &[State::Started, State::Paused]).await?; - - let tracks = self.tracks.read().await; - for t in &*tracks { - if t.track.rid() == rid { - if let Some(rtcp_interceptor) = &t.stream.rtcp_interceptor { - let a = Attributes::new(); - - loop { - tokio::select! { - res = State::error_on_close(&mut state_watch_rx) => { - res? - } - result = rtcp_interceptor.read(b, &a) => { - return Ok(result?); - } - } - } - } else { - return Err(Error::ErrInterceptorNotBind); - } - } - } - Err(Error::ErrRTPReceiverForRIDTrackStreamNotFound) - } - - /// read_rtcp is a convenience method that wraps Read and unmarshal for you. - /// It also runs any configured interceptors. - async fn read_rtcp( - &self, - receive_mtu: usize, - ) -> Result<(Vec>, Attributes)> { - let mut b = vec![0u8; receive_mtu]; - let (pkts, attributes) = self.read(&mut b).await?; - - Ok((pkts, attributes)) - } - - /// read_simulcast_rtcp is a convenience method that wraps ReadSimulcast and unmarshal for you - async fn read_simulcast_rtcp( - &self, - rid: &str, - receive_mtu: usize, - ) -> Result<(Vec>, Attributes)> { - let mut b = vec![0u8; receive_mtu]; - let (pkts, attributes) = self.read_simulcast(&mut b, rid).await?; - - Ok((pkts, attributes)) - } - - pub(crate) async fn read_rtp( - &self, - b: &mut [u8], - tid: usize, - ) -> Result<(rtp::packet::Packet, Attributes)> { - let mut state_watch_rx = self.state_tx.subscribe(); - - // Ensure we are running. - State::wait_for(&mut state_watch_rx, &[State::Started]).await?; - - //log::debug!("read_rtp enter tracks tid {}", tid); - let mut rtp_interceptor = None; - //let mut ssrc = 0; - { - let tracks = self.tracks.read().await; - for t in &*tracks { - if t.track.tid() == tid { - rtp_interceptor.clone_from(&t.stream.rtp_interceptor); - //ssrc = t.track.ssrc(); - break; - } - } - }; - /*log::debug!( - "read_rtp exit tracks with rtp_interceptor {} with tid {}", - rtp_interceptor.is_some(), - tid, - );*/ - - if let Some(rtp_interceptor) = rtp_interceptor { - let a = Attributes::new(); - //println!( - // "read_rtp rtp_interceptor.read enter with tid {} ssrc {}", - // tid, ssrc - //); - let mut current_state = *state_watch_rx.borrow(); - loop { - tokio::select! 
{ - _ = state_watch_rx.changed() => { - let new_state = *state_watch_rx.borrow(); - - if new_state == State::Stopped { - return Err(Error::ErrClosedPipe); - } - current_state = new_state; - } - result = rtp_interceptor.read(b, &a) => { - let result = result?; - - if current_state == State::Paused { - trace!("Dropping {} read bytes received while RTPReceiver was paused", result.0); - continue; - } - return Ok(result); - } - } - } - } else { - //log::debug!("read_rtp exit tracks with ErrRTPReceiverWithSSRCTrackStreamNotFound"); - Err(Error::ErrRTPReceiverWithSSRCTrackStreamNotFound) - } - } - - async fn get_parameters(&self) -> RTCRtpParameters { - let mut parameters = self - .media_engine - .get_rtp_parameters_by_kind(self.kind, RTCRtpTransceiverDirection::Recvonly); - - let transceiver_codecs = self.transceiver_codecs.load(); - if let Some(codecs) = &*transceiver_codecs { - let mut c = codecs.lock().await; - parameters.codecs = - RTPReceiverInternal::get_codecs(&mut c, self.kind, &self.media_engine); - } - - parameters - } - - pub(crate) fn get_codecs( - codecs: &mut [RTCRtpCodecParameters], - kind: RTPCodecType, - media_engine: &Arc, - ) -> Vec { - let media_engine_codecs = media_engine.get_codecs_by_kind(kind); - if codecs.is_empty() { - return media_engine_codecs; - } - let mut filtered_codecs = vec![]; - for codec in codecs { - let (c, match_type) = codec_parameters_fuzzy_search(codec, &media_engine_codecs); - if match_type != CodecMatch::None { - if codec.payload_type == 0 { - codec.payload_type = c.payload_type; - } - filtered_codecs.push(codec.clone()); - } - } - - filtered_codecs - } - - // State - - /// Get the current state and a receiver for the next state change. - pub(crate) fn current_state(&self) -> State { - *self.state_rx.borrow() - } - - pub(crate) fn start(&self) -> Result<()> { - State::transition(State::Started, &self.state_tx) - } - - pub(crate) fn pause(&self) -> Result<()> { - let current = self.current_state(); - - match current { - State::Unstarted => State::transition(State::UnstartedPaused, &self.state_tx), - State::Started => State::transition(State::Paused, &self.state_tx), - _ => Ok(()), - } - } - - pub(crate) fn resume(&self) -> Result<()> { - let current = self.current_state(); - - match current { - State::UnstartedPaused => State::transition(State::Unstarted, &self.state_tx), - State::Paused => State::transition(State::Started, &self.state_tx), - _ => Ok(()), - } - } - - pub(crate) fn close(&self) -> Result<()> { - State::transition(State::Stopped, &self.state_tx) - } -} - -/// RTPReceiver allows an application to inspect the receipt of a TrackRemote -pub struct RTCRtpReceiver { - receive_mtu: usize, - kind: RTPCodecType, - transport: Arc, - - pub internal: Arc, -} - -impl std::fmt::Debug for RTCRtpReceiver { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("RTCRtpReceiver") - .field("kind", &self.kind) - .finish() - } -} - -impl RTCRtpReceiver { - pub fn new( - receive_mtu: usize, - kind: RTPCodecType, - transport: Arc, - media_engine: Arc, - interceptor: Arc, - ) -> Self { - let (state_tx, state_rx) = watch::channel(State::Unstarted); - - RTCRtpReceiver { - receive_mtu, - kind, - transport: Arc::clone(&transport), - - internal: Arc::new(RTPReceiverInternal { - kind, - - tracks: RwLock::new(vec![]), - transport, - media_engine, - interceptor, - - state_tx, - state_rx, - - transceiver_codecs: ArcSwapOption::new(None), - }), - } - } - - pub fn kind(&self) -> RTPCodecType { - self.kind - } - - pub(crate) fn 
set_transceiver_codecs( - &self, - codecs: Option>>>, - ) { - self.internal.transceiver_codecs.store(codecs); - } - - /// transport returns the currently-configured *DTLSTransport or nil - /// if one has not yet been configured - pub fn transport(&self) -> Arc { - Arc::clone(&self.transport) - } - - /// get_parameters describes the current configuration for the encoding and - /// transmission of media on the receiver's track. - pub async fn get_parameters(&self) -> RTCRtpParameters { - self.internal.get_parameters().await - } - - /// SetRTPParameters applies provided RTPParameters the RTPReceiver's tracks. - /// This method is part of the ORTC API. It is not - /// meant to be used together with the basic WebRTC API. - /// The amount of provided codecs must match the number of tracks on the receiver. - pub async fn set_rtp_parameters(&self, params: RTCRtpParameters) { - let mut header_extensions = vec![]; - for h in ¶ms.header_extensions { - header_extensions.push(RTPHeaderExtension { - id: h.id, - uri: h.uri.clone(), - }); - } - - let mut tracks = self.internal.tracks.write().await; - for (idx, codec) in params.codecs.iter().enumerate() { - let t = &mut tracks[idx]; - if let Some(stream_info) = &mut t.stream.stream_info { - stream_info - .rtp_header_extensions - .clone_from(&header_extensions); - } - - let current_track = &t.track; - current_track.set_codec(codec.clone()); - current_track.set_params(params.clone()); - } - } - - /// tracks returns the RtpTransceiver traclockks - /// A RTPReceiver to support Simulcast may now have multiple tracks - pub async fn tracks(&self) -> Vec> { - let tracks = self.internal.tracks.read().await; - tracks.iter().map(|t| Arc::clone(&t.track)).collect() - } - - /// receive initialize the track and starts all the transports - pub async fn receive(&self, parameters: &RTCRtpReceiveParameters) -> Result<()> { - let receiver = Arc::downgrade(&self.internal); - - let current_state = self.internal.current_state(); - if current_state.is_started() { - return Err(Error::ErrRTPReceiverReceiveAlreadyCalled); - } - self.internal.start()?; - - let (global_params, interceptor, media_engine) = { - ( - self.internal.get_parameters().await, - Arc::clone(&self.internal.interceptor), - Arc::clone(&self.internal.media_engine), - ) - }; - - let codec = if let Some(codec) = global_params.codecs.first() { - codec.capability.clone() - } else { - RTCRtpCodecCapability::default() - }; - - for encoding in ¶meters.encodings { - let (stream_info, rtp_read_stream, rtp_interceptor, rtcp_read_stream, rtcp_interceptor) = - if encoding.ssrc != 0 { - let stream_info = create_stream_info( - "".to_owned(), - encoding.ssrc, - 0, - codec.clone(), - &global_params.header_extensions, - ); - let (rtp_read_stream, rtp_interceptor, rtcp_read_stream, rtcp_interceptor) = - self.transport - .streams_for_ssrc(encoding.ssrc, &stream_info, &interceptor) - .await?; - - ( - Some(stream_info), - Some(rtp_read_stream), - Some(rtp_interceptor), - Some(rtcp_read_stream), - Some(rtcp_interceptor), - ) - } else { - (None, None, None, None, None) - }; - - let t = TrackStreams { - track: Arc::new(TrackRemote::new( - self.receive_mtu, - self.kind, - encoding.ssrc, - encoding.rid.clone(), - receiver.clone(), - Arc::clone(&media_engine), - Arc::clone(&interceptor), - )), - stream: TrackStream { - stream_info, - rtp_read_stream, - rtp_interceptor, - rtcp_read_stream, - rtcp_interceptor, - }, - - repair_stream: TrackStream { - stream_info: None, - rtp_read_stream: None, - rtp_interceptor: None, - rtcp_read_stream: None, 
- rtcp_interceptor: None, - }, - }; - - { - let mut tracks = self.internal.tracks.write().await; - tracks.push(t); - }; - - let rtx_ssrc = encoding.rtx.ssrc; - if rtx_ssrc != 0 { - let stream_info = create_stream_info( - "".to_owned(), - rtx_ssrc, - 0, - codec.clone(), - &global_params.header_extensions, - ); - let (rtp_read_stream, rtp_interceptor, rtcp_read_stream, rtcp_interceptor) = self - .transport - .streams_for_ssrc(rtx_ssrc, &stream_info, &interceptor) - .await?; - - self.receive_for_rtx( - rtx_ssrc, - "".to_owned(), - TrackStream { - stream_info: Some(stream_info), - rtp_read_stream: Some(rtp_read_stream), - rtp_interceptor: Some(rtp_interceptor), - rtcp_read_stream: Some(rtcp_read_stream), - rtcp_interceptor: Some(rtcp_interceptor), - }, - ) - .await?; - } - } - - Ok(()) - } - - /// read reads incoming RTCP for this RTPReceiver - pub async fn read( - &self, - b: &mut [u8], - ) -> Result<(Vec>, Attributes)> { - self.internal.read(b).await - } - - /// read_simulcast reads incoming RTCP for this RTPReceiver for given rid - pub async fn read_simulcast( - &self, - b: &mut [u8], - rid: &str, - ) -> Result<(Vec>, Attributes)> { - self.internal.read_simulcast(b, rid).await - } - - /// read_rtcp is a convenience method that wraps Read and unmarshal for you. - /// It also runs any configured interceptors. - pub async fn read_rtcp( - &self, - ) -> Result<(Vec>, Attributes)> { - self.internal.read_rtcp(self.receive_mtu).await - } - - /// read_simulcast_rtcp is a convenience method that wraps ReadSimulcast and unmarshal for you - pub async fn read_simulcast_rtcp( - &self, - rid: &str, - ) -> Result<(Vec>, Attributes)> { - self.internal - .read_simulcast_rtcp(rid, self.receive_mtu) - .await - } - - pub(crate) async fn have_received(&self) -> bool { - self.internal.current_state().is_started() - } - - pub(crate) async fn start(&self, incoming: &TrackDetails) { - let mut encoding_size = incoming.ssrcs.len(); - if incoming.rids.len() >= encoding_size { - encoding_size = incoming.rids.len(); - }; - - let mut encodings = vec![RTCRtpDecodingParameters::default(); encoding_size]; - for (i, encoding) in encodings.iter_mut().enumerate() { - if incoming.rids.len() > i { - encoding.rid = incoming.rids[i].clone(); - } - if incoming.ssrcs.len() > i { - encoding.ssrc = incoming.ssrcs[i]; - } - - encoding.rtx.ssrc = incoming.repair_ssrc; - } - - if let Err(err) = self.receive(&RTCRtpReceiveParameters { encodings }).await { - log::warn!("RTPReceiver Receive failed {}", err); - return; - } - - // set track id and label early so they can be set as new track information - // is received from the SDP. 
- let is_unpaused = self.current_state() == State::Started; - for track_remote in &self.tracks().await { - track_remote.set_id(incoming.id.clone()); - track_remote.set_stream_id(incoming.stream_id.clone()); - - if is_unpaused { - track_remote.fire_onunmute().await; - } - } - } - - /// Stop irreversibly stops the RTPReceiver - pub async fn stop(&self) -> Result<()> { - let previous_state = self.internal.current_state(); - self.internal.close()?; - - let mut errs = vec![]; - let was_ever_started = previous_state.is_started(); - if was_ever_started { - let tracks = self.internal.tracks.write().await; - for t in &*tracks { - if let Some(rtcp_read_stream) = &t.stream.rtcp_read_stream { - if let Err(err) = rtcp_read_stream.close().await { - errs.push(err); - } - } - - if let Some(rtp_read_stream) = &t.stream.rtp_read_stream { - if let Err(err) = rtp_read_stream.close().await { - errs.push(err); - } - } - - if let Some(repair_rtcp_read_stream) = &t.repair_stream.rtcp_read_stream { - if let Err(err) = repair_rtcp_read_stream.close().await { - errs.push(err); - } - } - - if let Some(repair_rtp_read_stream) = &t.repair_stream.rtp_read_stream { - if let Err(err) = repair_rtp_read_stream.close().await { - errs.push(err); - } - } - - if let Some(stream_info) = &t.stream.stream_info { - self.internal - .interceptor - .unbind_remote_stream(stream_info) - .await; - } - - if let Some(repair_stream_info) = &t.repair_stream.stream_info { - self.internal - .interceptor - .unbind_remote_stream(repair_stream_info) - .await; - } - } - } - - flatten_errs(errs) - } - - /// read_rtp should only be called by a track, this only exists so we can keep state in one place - pub(crate) async fn read_rtp( - &self, - b: &mut [u8], - tid: usize, - ) -> Result<(rtp::packet::Packet, Attributes)> { - self.internal.read_rtp(b, tid).await - } - - /// receive_for_rid is the sibling of Receive expect for RIDs instead of SSRCs - /// It populates all the internal state for the given RID - pub(crate) async fn receive_for_rid( - &self, - rid: SmolStr, - params: RTCRtpParameters, - stream: TrackStream, - ) -> Result> { - let mut tracks = self.internal.tracks.write().await; - for t in &mut *tracks { - if *t.track.rid() == rid { - t.track.set_kind(self.kind); - if let Some(codec) = params.codecs.first() { - t.track.set_codec(codec.clone()); - } - t.track.set_params(params.clone()); - t.track - .set_ssrc(stream.stream_info.as_ref().map_or(0, |s| s.ssrc)); - t.stream = stream; - return Ok(Arc::clone(&t.track)); - } - } - - Err(Error::ErrRTPReceiverForRIDTrackStreamNotFound) - } - - /// receiveForRtx starts a routine that processes the repair stream - /// These packets aren't exposed to the user yet, but we need to process them for - /// TWCC - pub(crate) async fn receive_for_rtx( - &self, - ssrc: SSRC, - rsid: String, - repair_stream: TrackStream, - ) -> Result<()> { - let mut tracks = self.internal.tracks.write().await; - let l = tracks.len(); - for t in &mut *tracks { - if (ssrc != 0 && l == 1) || t.track.rid() == rsid { - t.repair_stream = repair_stream; - - let receive_mtu = self.receive_mtu; - let track = t.clone(); - tokio::spawn(async move { - let a = Attributes::new(); - let mut b = vec![0u8; receive_mtu]; - while let Some(repair_rtp_interceptor) = &track.repair_stream.rtp_interceptor { - //TODO: cancel repair_rtp_interceptor.read gracefully - //println!("repair_rtp_interceptor read begin with ssrc={}", ssrc); - if repair_rtp_interceptor.read(&mut b, &a).await.is_err() { - break; - } - } - }); - - return Ok(()); - } - } - - 
Err(Error::ErrRTPReceiverForRIDTrackStreamNotFound) - } - - // State - - pub(crate) fn current_state(&self) -> State { - self.internal.current_state() - } - - pub(crate) async fn pause(&self) -> Result<()> { - self.internal.pause()?; - - if !self.internal.current_state().is_started() { - return Ok(()); - } - - let streams = self.internal.tracks.read().await; - - for stream in streams.iter() { - // TODO: If we introduce futures as a direct dependency this and other futures could be - // ran concurrently with [`join_all`](https://docs.rs/futures/0.3.21/futures/future/fn.join_all.html) - stream.track.fire_onmute().await; - } - - Ok(()) - } - - pub(crate) async fn resume(&self) -> Result<()> { - self.internal.resume()?; - - if !self.internal.current_state().is_started() { - return Ok(()); - } - - let streams = self.internal.tracks.read().await; - - for stream in streams.iter() { - // TODO: If we introduce futures as a direct dependency this and other futures could be - // ran concurrently with [`join_all`](https://docs.rs/futures/0.3.21/futures/future/fn.join_all.html) - stream.track.fire_onunmute().await; - } - - Ok(()) - } -} diff --git a/webrtc/src/rtp_transceiver/rtp_receiver/rtp_receiver_test.rs b/webrtc/src/rtp_transceiver/rtp_receiver/rtp_receiver_test.rs deleted file mode 100644 index 7520667db..000000000 --- a/webrtc/src/rtp_transceiver/rtp_receiver/rtp_receiver_test.rs +++ /dev/null @@ -1,233 +0,0 @@ -use bytes::Bytes; -use media::Sample; -use tokio::sync::mpsc; -use tokio::time::Duration; -use waitgroup::WaitGroup; - -use super::*; -use crate::api::media_engine::{MIME_TYPE_OPUS, MIME_TYPE_VP8}; -use crate::error::Result; -use crate::peer_connection::peer_connection_state::RTCPeerConnectionState; -use crate::peer_connection::peer_connection_test::{ - close_pair_now, create_vnet_pair, signal_pair, until_connection_state, -}; -use crate::rtp_transceiver::rtp_codec::RTCRtpHeaderExtensionParameters; -use crate::rtp_transceiver::RTCPFeedback; -use crate::track::track_local::track_local_static_sample::TrackLocalStaticSample; -use crate::track::track_local::TrackLocal; - -lazy_static! 
{ - static ref P: RTCRtpParameters = RTCRtpParameters { - codecs: vec![RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_OPUS.to_string(), - clock_rate: 48000, - channels: 2, - sdp_fmtp_line: "minptime=10;useinbandfec=1".to_string(), - rtcp_feedback: vec![RTCPFeedback { - typ: "nack".to_owned(), - parameter: "".to_owned(), - }], - }, - payload_type: 111, - ..Default::default() - }], - header_extensions: vec![ - RTCRtpHeaderExtensionParameters { - uri: "urn:ietf:params:rtp-hdrext:sdes:mid".to_owned(), - ..Default::default() - }, - RTCRtpHeaderExtensionParameters { - uri: "urn:ietf:params:rtp-hdrext:sdes:rtp-stream-id".to_owned(), - ..Default::default() - }, - RTCRtpHeaderExtensionParameters { - uri: "urn:ietf:params:rtp-hdrext:sdes:repaired-rtp-stream-id".to_owned(), - ..Default::default() - }, - ], - }; -} - -//use log::LevelFilter; -//use std::io::Write; - -#[tokio::test] -async fn test_set_rtp_parameters() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, LevelFilter::Trace) - .init();*/ - - let (mut sender, mut receiver, wan) = create_vnet_pair().await?; - - let outgoing_track: Arc = Arc::new(TrackLocalStaticSample::new( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - "video".to_owned(), - "webrtc-rs".to_owned(), - )); - - sender.add_track(Arc::clone(&outgoing_track)).await?; - - // Those parameters wouldn't make sense in a real application, - // but for the sake of the test we just need different values. - - let (seen_packet_tx, mut seen_packet_rx) = mpsc::channel::<()>(1); - let seen_packet_tx = Arc::new(Mutex::new(Some(seen_packet_tx))); - receiver.on_track(Box::new(move |_, receiver, _| { - let seen_packet_tx2 = Arc::clone(&seen_packet_tx); - Box::pin(async move { - receiver.set_rtp_parameters(P.clone()).await; - - let tracks = receiver.tracks().await; - assert_eq!(tracks.len(), 1); - let t = tracks.first().unwrap(); - - let incoming_track_codecs = t.codec(); - - assert_eq!(P.header_extensions, t.params().header_extensions); - assert_eq!( - P.codecs[0].capability.mime_type, - incoming_track_codecs.capability.mime_type - ); - assert_eq!( - P.codecs[0].capability.clock_rate, - incoming_track_codecs.capability.clock_rate - ); - assert_eq!( - P.codecs[0].capability.channels, - incoming_track_codecs.capability.channels - ); - assert_eq!( - P.codecs[0].capability.sdp_fmtp_line, - incoming_track_codecs.capability.sdp_fmtp_line - ); - assert_eq!( - P.codecs[0].capability.rtcp_feedback, - incoming_track_codecs.capability.rtcp_feedback - ); - assert_eq!(P.codecs[0].payload_type, incoming_track_codecs.payload_type); - - { - let mut done = seen_packet_tx2.lock().await; - done.take(); - } - }) - })); - - let wg = WaitGroup::new(); - - until_connection_state(&mut sender, &wg, RTCPeerConnectionState::Connected).await; - until_connection_state(&mut receiver, &wg, RTCPeerConnectionState::Connected).await; - - signal_pair(&mut sender, &mut receiver).await?; - - wg.wait().await; - - if let Some(v) = outgoing_track - .as_any() - .downcast_ref::() - { - v.write_sample(&Sample { - data: Bytes::from_static(&[0xAA]), - duration: Duration::from_secs(1), - ..Default::default() - }) - .await?; - } else { - panic!(); - } - - let _ = seen_packet_rx.recv().await; - { - let mut w = 
wan.lock().await; - w.stop().await?; - } - close_pair_now(&sender, &receiver).await; - - Ok(()) -} - -// Assert that SetReadDeadline works as expected -// This test uses VNet since we must have zero loss -#[tokio::test] -async fn test_rtp_receiver_set_read_deadline() -> Result<()> { - let (mut sender, mut receiver, wan) = create_vnet_pair().await?; - - let track: Arc = Arc::new(TrackLocalStaticSample::new( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - "video".to_owned(), - "webrtc-rs".to_owned(), - )); - - sender.add_track(Arc::clone(&track)).await?; - - let (seen_packet_tx, mut seen_packet_rx) = mpsc::channel::<()>(1); - let seen_packet_tx = Arc::new(Mutex::new(Some(seen_packet_tx))); - receiver.on_track(Box::new(move |track, receiver, _| { - let seen_packet_tx2 = Arc::clone(&seen_packet_tx); - Box::pin(async move { - // First call will not error because we cache for probing - let result = tokio::time::timeout(Duration::from_secs(1), track.read_rtp()).await; - assert!( - result.is_ok(), - " First call will not error because we cache for probing" - ); - - let result = tokio::time::timeout(Duration::from_secs(1), track.read_rtp()).await; - assert!(result.is_err()); - - let result = tokio::time::timeout(Duration::from_secs(1), receiver.read_rtcp()).await; - assert!(result.is_err()); - - { - let mut done = seen_packet_tx2.lock().await; - done.take(); - } - }) - })); - - let wg = WaitGroup::new(); - until_connection_state(&mut sender, &wg, RTCPeerConnectionState::Connected).await; - until_connection_state(&mut receiver, &wg, RTCPeerConnectionState::Connected).await; - - signal_pair(&mut sender, &mut receiver).await?; - - wg.wait().await; - - if let Some(v) = track.as_any().downcast_ref::() { - v.write_sample(&Sample { - data: Bytes::from_static(&[0xAA]), - duration: Duration::from_secs(1), - ..Default::default() - }) - .await?; - } else { - panic!(); - } - - let _ = seen_packet_rx.recv().await; - { - let mut w = wan.lock().await; - w.stop().await?; - } - close_pair_now(&sender, &receiver).await; - - Ok(()) -} diff --git a/webrtc/src/rtp_transceiver/rtp_sender/mod.rs b/webrtc/src/rtp_transceiver/rtp_sender/mod.rs deleted file mode 100644 index 9ae813e06..000000000 --- a/webrtc/src/rtp_transceiver/rtp_sender/mod.rs +++ /dev/null @@ -1,593 +0,0 @@ -#[cfg(test)] -mod rtp_sender_test; - -use std::sync::atomic::Ordering; -use std::sync::{Arc, Weak}; - -use ice::rand::generate_crypto_random_string; -use interceptor::stream_info::StreamInfo; -use interceptor::{Attributes, Interceptor, RTCPReader, RTPWriter}; -use portable_atomic::AtomicBool; -use tokio::sync::{mpsc, Mutex, Notify}; -use util::sync::Mutex as SyncMutex; - -use super::srtp_writer_future::SequenceTransformer; -use crate::api::media_engine::MediaEngine; -use crate::dtls_transport::RTCDtlsTransport; -use crate::error::{Error, Result}; -use crate::rtp_transceiver::rtp_codec::RTPCodecType; -use crate::rtp_transceiver::rtp_transceiver_direction::RTCRtpTransceiverDirection; -use crate::rtp_transceiver::srtp_writer_future::SrtpWriterFuture; -use crate::rtp_transceiver::{ - create_stream_info, PayloadType, RTCRtpEncodingParameters, RTCRtpSendParameters, - RTCRtpTransceiver, SSRC, -}; -use crate::track::track_local::{ - InterceptorToTrackLocalWriter, TrackLocal, TrackLocalContext, TrackLocalWriter, -}; - -pub(crate) struct RTPSenderInternal { - pub(crate) send_called_rx: Mutex>, - pub(crate) stop_called_rx: Arc, - pub(crate) stop_called_signal: Arc, -} - -pub(crate) struct TrackEncoding { - 
pub(crate) track: Arc, - pub(crate) srtp_stream: Arc, - pub(crate) rtcp_interceptor: Arc, - pub(crate) stream_info: Mutex, - pub(crate) context: Mutex, - - pub(crate) ssrc: SSRC, -} - -/// RTPSender allows an application to control how a given Track is encoded and transmitted to a remote peer -pub struct RTCRtpSender { - pub(crate) track_encodings: Mutex>, - - seq_trans: Arc, - - pub(crate) transport: Arc, - - pub(crate) kind: RTPCodecType, - pub(crate) payload_type: PayloadType, - receive_mtu: usize, - - /// a transceiver sender since we can just check the - /// transceiver negotiation status - pub(crate) negotiated: AtomicBool, - - pub(crate) media_engine: Arc, - pub(crate) interceptor: Arc, - - pub(crate) id: String, - - /// The id of the initial track, even if we later change to a different - /// track id should be use when negotiating. - pub(crate) initial_track_id: std::sync::Mutex>, - /// AssociatedMediaStreamIds from the WebRTC specifications - pub(crate) associated_media_stream_ids: std::sync::Mutex>, - - rtp_transceiver: SyncMutex>>, - - send_called_tx: SyncMutex>>, - stop_called_tx: Arc, - stop_called_signal: Arc, - - pub(crate) paused: Arc, - - internal: Arc, -} - -impl std::fmt::Debug for RTCRtpSender { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("RTCRtpSender") - .field("id", &self.id) - .finish() - } -} - -impl RTCRtpSender { - pub async fn new( - receive_mtu: usize, - track: Option>, - kind: RTPCodecType, - transport: Arc, - media_engine: Arc, - interceptor: Arc, - start_paused: bool, - ) -> Self { - let id = generate_crypto_random_string( - 32, - b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ", - ); - let (send_called_tx, send_called_rx) = mpsc::channel(1); - let stop_called_tx = Arc::new(Notify::new()); - let stop_called_rx = stop_called_tx.clone(); - let stop_called_signal = Arc::new(AtomicBool::new(false)); - - let internal = Arc::new(RTPSenderInternal { - send_called_rx: Mutex::new(send_called_rx), - stop_called_rx, - stop_called_signal: Arc::clone(&stop_called_signal), - }); - - let seq_trans = Arc::new(SequenceTransformer::new()); - - let stream_ids = track - .as_ref() - .map(|track| vec![track.stream_id().to_string()]) - .unwrap_or_default(); - let ret = Self { - track_encodings: Mutex::new(vec![]), - - seq_trans, - - transport, - - kind, - payload_type: 0, - receive_mtu, - - negotiated: AtomicBool::new(false), - - media_engine, - interceptor, - - id, - initial_track_id: std::sync::Mutex::new(None), - associated_media_stream_ids: std::sync::Mutex::new(stream_ids), - - rtp_transceiver: SyncMutex::new(None), - - send_called_tx: SyncMutex::new(Some(send_called_tx)), - stop_called_tx, - stop_called_signal, - - paused: Arc::new(AtomicBool::new(start_paused)), - - internal, - }; - - if let Some(track) = track { - let mut track_encodings = ret.track_encodings.lock().await; - let _ = ret.add_encoding_internal(&mut track_encodings, track).await; - } - - ret - } - - /// AddEncoding adds an encoding to RTPSender. Used by simulcast senders. 
- pub async fn add_encoding(&self, track: Arc) -> Result<()> { - let mut track_encodings = self.track_encodings.lock().await; - - if track.rid().is_none() { - return Err(Error::ErrRTPSenderRidNil); - } - - if self.has_stopped().await { - return Err(Error::ErrRTPSenderStopped); - } - - if self.has_sent() { - return Err(Error::ErrRTPSenderSendAlreadyCalled); - } - - let base_track = track_encodings - .first() - .map(|e| &e.track) - .ok_or(Error::ErrRTPSenderNoBaseEncoding)?; - if base_track.rid().is_none() { - return Err(Error::ErrRTPSenderNoBaseEncoding); - } - - if base_track.id() != track.id() - || base_track.stream_id() != track.stream_id() - || base_track.kind() != track.kind() - { - return Err(Error::ErrRTPSenderBaseEncodingMismatch); - } - - if track_encodings.iter().any(|e| e.track.rid() == track.rid()) { - return Err(Error::ErrRTPSenderRIDCollision); - } - - self.add_encoding_internal(&mut track_encodings, track) - .await - } - - async fn add_encoding_internal( - &self, - track_encodings: &mut Vec, - track: Arc, - ) -> Result<()> { - let ssrc = rand::random::(); - let srtp_stream = Arc::new(SrtpWriterFuture { - closed: AtomicBool::new(false), - ssrc, - rtp_sender: Arc::downgrade(&self.internal), - rtp_transport: Arc::clone(&self.transport), - rtcp_read_stream: Mutex::new(None), - rtp_write_session: Mutex::new(None), - seq_trans: Arc::clone(&self.seq_trans), - }); - - let srtp_rtcp_reader = Arc::clone(&srtp_stream) as Arc; - let rtcp_interceptor = self.interceptor.bind_rtcp_reader(srtp_rtcp_reader).await; - - let encoding = TrackEncoding { - track, - srtp_stream, - rtcp_interceptor, - stream_info: Mutex::new(StreamInfo::default()), - context: Mutex::new(TrackLocalContext::default()), - ssrc, - }; - - track_encodings.push(encoding); - - Ok(()) - } - - pub(crate) fn is_negotiated(&self) -> bool { - self.negotiated.load(Ordering::SeqCst) - } - - pub(crate) fn set_negotiated(&self) { - self.negotiated.store(true, Ordering::SeqCst); - } - - pub(crate) fn set_rtp_transceiver(&self, rtp_transceiver: Option>) { - if let Some(t) = rtp_transceiver.as_ref().and_then(|t| t.upgrade()) { - self.set_paused(!t.direction().has_send()); - } - let mut tr = self.rtp_transceiver.lock(); - *tr = rtp_transceiver; - } - - pub(crate) fn set_paused(&self, paused: bool) { - self.paused.store(paused, Ordering::SeqCst); - } - - /// transport returns the currently-configured DTLSTransport - /// if one has not yet been configured - pub fn transport(&self) -> Arc { - Arc::clone(&self.transport) - } - - /// get_parameters describes the current configuration for the encoding and - /// transmission of media on the sender's track. 
- pub async fn get_parameters(&self) -> RTCRtpSendParameters { - let encodings = { - let track_encodings = self.track_encodings.lock().await; - let mut encodings = Vec::with_capacity(track_encodings.len()); - for e in track_encodings.iter() { - encodings.push(RTCRtpEncodingParameters { - rid: e.track.rid().unwrap_or_default().into(), - ssrc: e.ssrc, - payload_type: self.payload_type, - ..Default::default() - }); - } - - encodings - }; - - let mut rtp_parameters = self - .media_engine - .get_rtp_parameters_by_kind(self.kind, RTCRtpTransceiverDirection::Sendonly); - rtp_parameters.codecs = { - let tr = self - .rtp_transceiver - .lock() - .clone() - .and_then(|t| t.upgrade()); - if let Some(t) = &tr { - t.get_codecs().await - } else { - self.media_engine.get_codecs_by_kind(self.kind) - } - }; - - RTCRtpSendParameters { - rtp_parameters, - encodings, - } - } - - /// track returns the RTCRtpTransceiver track, or nil - pub async fn track(&self) -> Option> { - self.track_encodings - .lock() - .await - .first() - .map(|e| Arc::clone(&e.track)) - } - - /// replace_track replaces the track currently being used as the sender's source with a new TrackLocal. - /// The new track must be of the same media kind (audio, video, etc) and switching the track should not - /// require negotiation. - pub async fn replace_track( - &self, - track: Option>, - ) -> Result<()> { - let mut track_encodings = self.track_encodings.lock().await; - - if let Some(t) = &track { - if self.kind != t.kind() { - return Err(Error::ErrRTPSenderNewTrackHasIncorrectKind); - } - - // cannot replace simulcast envelope - if track_encodings.len() > 1 { - return Err(Error::ErrRTPSenderNewTrackHasIncorrectEnvelope); - } - - let encoding = track_encodings - .first_mut() - .ok_or(Error::ErrRTPSenderNewTrackHasIncorrectEnvelope)?; - - let mut context = encoding.context.lock().await; - if self.has_sent() { - encoding.track.unbind(&context).await?; - } - - self.seq_trans.reset_offset(); - - let mid = self - .rtp_transceiver - .lock() - .clone() - .and_then(|t| t.upgrade()) - .and_then(|t| t.mid()); - - let new_context = TrackLocalContext { - id: context.id.clone(), - params: self - .media_engine - .get_rtp_parameters_by_kind(t.kind(), RTCRtpTransceiverDirection::Sendonly), - ssrc: context.ssrc, - write_stream: context.write_stream.clone(), - paused: self.paused.clone(), - mid, - }; - - match t.bind(&new_context).await { - Err(err) => { - // Re-bind the original track - encoding.track.bind(&context).await?; - - Err(err) - } - Ok(codec) => { - // Codec has changed - context.params.codecs = vec![codec]; - encoding.track = Arc::clone(t); - Ok(()) - } - } - } else { - if self.has_sent() { - for encoding in track_encodings.drain(..) { - let context = encoding.context.lock().await; - encoding.track.unbind(&context).await?; - } - } else { - track_encodings.clear(); - } - - Ok(()) - } - } - - /// send Attempts to set the parameters controlling the sending of media. 
- pub async fn send(&self, parameters: &RTCRtpSendParameters) -> Result<()> { - if self.has_sent() { - return Err(Error::ErrRTPSenderSendAlreadyCalled); - } - let track_encodings = self.track_encodings.lock().await; - if track_encodings.is_empty() { - return Err(Error::ErrRTPSenderTrackRemoved); - } - - let mid = self - .rtp_transceiver - .lock() - .clone() - .and_then(|t| t.upgrade()) - .and_then(|t| t.mid()); - - for (idx, encoding) in track_encodings.iter().enumerate() { - let write_stream = Arc::new(InterceptorToTrackLocalWriter::new(self.paused.clone())); - let mut context = TrackLocalContext { - id: self.id.clone(), - params: self.media_engine.get_rtp_parameters_by_kind( - encoding.track.kind(), - RTCRtpTransceiverDirection::Sendonly, - ), - ssrc: parameters.encodings[idx].ssrc, - write_stream: Some( - Arc::clone(&write_stream) as Arc - ), - paused: self.paused.clone(), - mid: mid.to_owned(), - }; - - let codec = encoding.track.bind(&context).await?; - let stream_info = create_stream_info( - self.id.clone(), - parameters.encodings[idx].ssrc, - codec.payload_type, - codec.capability.clone(), - ¶meters.rtp_parameters.header_extensions, - ); - context.params.codecs = vec![codec]; - - let srtp_writer = Arc::clone(&encoding.srtp_stream) as Arc; - let rtp_writer = self - .interceptor - .bind_local_stream(&stream_info, srtp_writer) - .await; - - *encoding.context.lock().await = context; - *encoding.stream_info.lock().await = stream_info; - *write_stream.interceptor_rtp_writer.lock().await = Some(rtp_writer); - } - - self.send_called_tx.lock().take(); - Ok(()) - } - - /// stop irreversibly stops the RTPSender - pub async fn stop(&self) -> Result<()> { - if self.stop_called_signal.load(Ordering::SeqCst) { - return Ok(()); - } - self.stop_called_signal.store(true, Ordering::SeqCst); - self.stop_called_tx.notify_waiters(); - - if !self.has_sent() { - return Ok(()); - } - - self.replace_track(None).await?; - - let track_encodings = self.track_encodings.lock().await; - for encoding in track_encodings.iter() { - let stream_info = encoding.stream_info.lock().await; - self.interceptor.unbind_local_stream(&stream_info).await; - - encoding.srtp_stream.close().await?; - } - - Ok(()) - } - - /// read reads incoming RTCP for this RTPReceiver - pub async fn read( - &self, - b: &mut [u8], - ) -> Result<(Vec>, Attributes)> { - let mut send_called_rx = self.internal.send_called_rx.lock().await; - - tokio::select! { - _ = send_called_rx.recv() => { - let rtcp_interceptor = { - let track_encodings = self.track_encodings.lock().await; - track_encodings.first().map(|e|e.rtcp_interceptor.clone()) - }.ok_or(Error::ErrInterceptorNotBind)?; - let a = Attributes::new(); - tokio::select! { - _ = self.internal.stop_called_rx.notified() => Err(Error::ErrClosedPipe), - result = rtcp_interceptor.read(b, &a) => Ok(result?), - } - } - _ = self.internal.stop_called_rx.notified() => Err(Error::ErrClosedPipe), - } - } - - /// read_rtcp is a convenience method that wraps Read and unmarshals for you. - pub async fn read_rtcp( - &self, - ) -> Result<(Vec>, Attributes)> { - let mut b = vec![0u8; self.receive_mtu]; - let (pkts, attributes) = self.read(&mut b).await?; - - Ok((pkts, attributes)) - } - - /// ReadSimulcast reads incoming RTCP for this RTPSender for given rid - pub async fn read_simulcast( - &self, - b: &mut [u8], - rid: &str, - ) -> Result<(Vec>, Attributes)> { - let mut send_called_rx = self.internal.send_called_rx.lock().await; - - tokio::select! 
{ - _ = send_called_rx.recv() => { - let rtcp_interceptor = { - let track_encodings = self.track_encodings.lock().await; - track_encodings.iter().find(|e| e.track.rid() == Some(rid)).map(|e| e.rtcp_interceptor.clone()) - }.ok_or(Error::ErrRTPSenderNoTrackForRID)?; - let a = Attributes::new(); - tokio::select! { - _ = self.internal.stop_called_rx.notified() => Err(Error::ErrClosedPipe), - result = rtcp_interceptor.read(b, &a) => Ok(result?), - } - } - _ = self.internal.stop_called_rx.notified() => Err(Error::ErrClosedPipe), - } - } - - /// ReadSimulcastRTCP is a convenience method that wraps ReadSimulcast and unmarshal for you - pub async fn read_rtcp_simulcast( - &self, - rid: &str, - ) -> Result<(Vec>, Attributes)> { - let mut b = vec![0u8; self.receive_mtu]; - let (pkts, attributes) = self.read_simulcast(&mut b, rid).await?; - - Ok((pkts, attributes)) - } - - /// Enables overriding outgoing `RTP` packets' `sequence number`s. - /// - /// Must be called once before any data sent or never called at all. - /// - /// # Errors - /// - /// Errors if this [`RTCRtpSender`] has started to send data or sequence - /// transforming has been already enabled. - pub fn enable_seq_transformer(&self) -> Result<()> { - self.seq_trans.enable() - } - - /// has_sent tells if data has been ever sent for this instance - pub(crate) fn has_sent(&self) -> bool { - let send_called_tx = self.send_called_tx.lock(); - send_called_tx.is_none() - } - - /// has_stopped tells if stop has been called - pub(crate) async fn has_stopped(&self) -> bool { - self.stop_called_signal.load(Ordering::SeqCst) - } - - pub(crate) fn initial_track_id(&self) -> Option { - let lock = self.initial_track_id.lock().unwrap(); - - lock.clone() - } - - pub(crate) fn set_initial_track_id(&self, id: String) -> Result<()> { - let mut lock = self.initial_track_id.lock().unwrap(); - - if lock.is_some() { - return Err(Error::ErrSenderInitialTrackIdAlreadySet); - } - - *lock = Some(id); - - Ok(()) - } - - pub(crate) fn associate_media_stream_id(&self, id: String) -> bool { - let mut lock = self.associated_media_stream_ids.lock().unwrap(); - - if lock.contains(&id) { - return false; - } - - lock.push(id); - - true - } - - pub(crate) fn associated_media_stream_ids(&self) -> Vec { - let lock = self.associated_media_stream_ids.lock().unwrap(); - - lock.clone() - } -} diff --git a/webrtc/src/rtp_transceiver/rtp_sender/rtp_sender_test.rs b/webrtc/src/rtp_transceiver/rtp_sender/rtp_sender_test.rs deleted file mode 100644 index f65cdd24d..000000000 --- a/webrtc/src/rtp_transceiver/rtp_sender/rtp_sender_test.rs +++ /dev/null @@ -1,622 +0,0 @@ -use bytes::Bytes; -use portable_atomic::AtomicU64; -use tokio::time::Duration; -use waitgroup::WaitGroup; - -use super::*; -use crate::api::media_engine::{MIME_TYPE_H264, MIME_TYPE_OPUS, MIME_TYPE_VP8, MIME_TYPE_VP9}; -use crate::api::setting_engine::SettingEngine; -use crate::api::APIBuilder; -use crate::error::Result; -use crate::peer_connection::peer_connection_state::RTCPeerConnectionState; -use crate::peer_connection::peer_connection_test::{ - close_pair_now, create_vnet_pair, new_pair, send_video_until_done, signal_pair, - until_connection_state, -}; -use crate::rtp_transceiver::rtp_codec::RTCRtpCodecCapability; -use crate::rtp_transceiver::RTCRtpCodecParameters; -use crate::track::track_local::track_local_static_sample::TrackLocalStaticSample; - -#[tokio::test] -async fn test_rtp_sender_replace_track() -> Result<()> { - let mut s = SettingEngine::default(); - s.disable_srtp_replay_protection(true); - - let mut m = 
MediaEngine::default(); - m.register_default_codecs()?; - - let api = APIBuilder::new() - .with_setting_engine(s) - .with_media_engine(m) - .build(); - - let (mut sender, mut receiver) = new_pair(&api).await?; - - let track_a = Arc::new(TrackLocalStaticSample::new( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - "video".to_owned(), - "webrtc-rs".to_owned(), - )); - - let track_b = Arc::new(TrackLocalStaticSample::new( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_H264.to_owned(), - ..Default::default() - }, - "video".to_owned(), - "webrtc-rs".to_owned(), - )); - - let rtp_sender = sender - .add_track(Arc::clone(&track_a) as Arc) - .await?; - - let (seen_packet_a_tx, seen_packet_a_rx) = mpsc::channel::<()>(1); - let (seen_packet_b_tx, seen_packet_b_rx) = mpsc::channel::<()>(1); - - let seen_packet_a_tx = Arc::new(seen_packet_a_tx); - let seen_packet_b_tx = Arc::new(seen_packet_b_tx); - let on_track_count = Arc::new(AtomicU64::new(0)); - receiver.on_track(Box::new(move |track, _, _| { - assert_eq!(on_track_count.fetch_add(1, Ordering::SeqCst), 0); - let seen_packet_a_tx2 = Arc::clone(&seen_packet_a_tx); - let seen_packet_b_tx2 = Arc::clone(&seen_packet_b_tx); - Box::pin(async move { - let pkt = match track.read_rtp().await { - Ok((pkt, _)) => pkt, - Err(err) => { - //assert!(errors.Is(io.EOF, err)) - log::debug!("{}", err); - return; - } - }; - - let last = pkt.payload[pkt.payload.len() - 1]; - if last == 0xAA { - assert_eq!(track.codec().capability.mime_type, MIME_TYPE_VP8); - let _ = seen_packet_a_tx2.send(()).await; - } else if last == 0xBB { - assert_eq!(track.codec().capability.mime_type, MIME_TYPE_H264); - let _ = seen_packet_b_tx2.send(()).await; - } else { - panic!("Unexpected RTP Data {last:02x}"); - } - }) - })); - - signal_pair(&mut sender, &mut receiver).await?; - - // Block Until packet with 0xAA has been seen - tokio::spawn(async move { - send_video_until_done( - seen_packet_a_rx, - vec![track_a], - Bytes::from_static(&[0xAA]), - None, - ) - .await; - }); - - rtp_sender - .replace_track(Some( - Arc::clone(&track_b) as Arc - )) - .await?; - - // Block Until packet with 0xBB has been seen - tokio::spawn(async move { - send_video_until_done( - seen_packet_b_rx, - vec![track_b], - Bytes::from_static(&[0xBB]), - None, - ) - .await; - }); - - close_pair_now(&sender, &receiver).await; - Ok(()) -} - -#[tokio::test] -async fn test_rtp_sender_get_parameters() -> Result<()> { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - let (mut offerer, mut answerer) = new_pair(&api).await?; - - let rtp_transceiver = offerer - .add_transceiver_from_kind(RTPCodecType::Video, None) - .await?; - - signal_pair(&mut offerer, &mut answerer).await?; - - let sender = rtp_transceiver.sender().await; - assert!(sender.track().await.is_some()); - let parameters = sender.get_parameters().await; - assert_ne!(0, parameters.rtp_parameters.codecs.len()); - assert_eq!(1, parameters.encodings.len()); - assert_eq!( - sender.track_encodings.lock().await[0].ssrc, - parameters.encodings[0].ssrc - ); - assert!(parameters.encodings[0].rid.is_empty()); - - close_pair_now(&offerer, &answerer).await; - Ok(()) -} - -#[tokio::test] -async fn test_rtp_sender_get_parameters_with_rid() -> Result<()> { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - let (mut offerer, mut answerer) = new_pair(&api).await?; - - let 
rtp_transceiver = offerer - .add_transceiver_from_kind(RTPCodecType::Video, None) - .await?; - - signal_pair(&mut offerer, &mut answerer).await?; - - let rid = "moo"; - let track = Arc::new(TrackLocalStaticSample::new_with_rid( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - "video".to_owned(), - rid.to_owned(), - "webrtc-rs".to_owned(), - )); - rtp_transceiver.set_sending_track(Some(track)).await?; - - let sender = rtp_transceiver.sender().await; - assert!(sender.track().await.is_some()); - let parameters = sender.get_parameters().await; - assert_ne!(0, parameters.rtp_parameters.codecs.len()); - assert_eq!(1, parameters.encodings.len()); - assert_eq!( - sender.track_encodings.lock().await[0].ssrc, - parameters.encodings[0].ssrc - ); - assert_eq!(rid, parameters.encodings[0].rid); - - close_pair_now(&offerer, &answerer).await; - Ok(()) -} - -#[tokio::test] -async fn test_rtp_sender_set_read_deadline() -> Result<()> { - let (mut sender, mut receiver, wan) = create_vnet_pair().await?; - - let track = Arc::new(TrackLocalStaticSample::new( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - "video".to_owned(), - "webrtc-rs".to_owned(), - )); - - let rtp_sender = sender - .add_track(Arc::clone(&track) as Arc) - .await?; - - let peer_connections_connected = WaitGroup::new(); - until_connection_state( - &mut sender, - &peer_connections_connected, - RTCPeerConnectionState::Connected, - ) - .await; - until_connection_state( - &mut receiver, - &peer_connections_connected, - RTCPeerConnectionState::Connected, - ) - .await; - - signal_pair(&mut sender, &mut receiver).await?; - - peer_connections_connected.wait().await; - - let result = tokio::time::timeout(Duration::from_secs(1), rtp_sender.read_rtcp()).await; - assert!(result.is_err()); - - { - let mut w = wan.lock().await; - w.stop().await?; - } - close_pair_now(&sender, &receiver).await; - - Ok(()) -} - -#[tokio::test] -async fn test_rtp_sender_replace_track_invalid_track_kind_change() -> Result<()> { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - let (mut sender, mut receiver) = new_pair(&api).await?; - - let track_a = Arc::new(TrackLocalStaticSample::new( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - "video".to_owned(), - "webrtc-rs".to_owned(), - )); - - let track_b = Arc::new(TrackLocalStaticSample::new( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_OPUS.to_owned(), - ..Default::default() - }, - "audio".to_owned(), - "webrtc-rs".to_owned(), - )); - - let rtp_sender = sender - .add_track(Arc::clone(&track_a) as Arc) - .await?; - - signal_pair(&mut sender, &mut receiver).await?; - - let (seen_packet_tx, seen_packet_rx) = mpsc::channel::<()>(1); - let seen_packet_tx = Arc::new(seen_packet_tx); - receiver.on_track(Box::new(move |_, _, _| { - let seen_packet_tx2 = Arc::clone(&seen_packet_tx); - Box::pin(async move { - let _ = seen_packet_tx2.send(()).await; - }) - })); - - tokio::spawn(async move { - send_video_until_done( - seen_packet_rx, - vec![track_a], - Bytes::from_static(&[0xAA]), - None, - ) - .await; - }); - - if let Err(err) = rtp_sender.replace_track(Some(track_b)).await { - assert_eq!(err, Error::ErrRTPSenderNewTrackHasIncorrectKind); - } else { - panic!(); - } - - close_pair_now(&sender, &receiver).await; - Ok(()) -} - -#[tokio::test] -async fn test_rtp_sender_replace_track_invalid_codec_change() -> Result<()> { 
- let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - let (mut sender, mut receiver) = new_pair(&api).await?; - - let track_a = Arc::new(TrackLocalStaticSample::new( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - "video".to_owned(), - "webrtc-rs".to_owned(), - )); - - let track_b = Arc::new(TrackLocalStaticSample::new( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP9.to_owned(), - ..Default::default() - }, - "video".to_owned(), - "webrtc-rs".to_owned(), - )); - - let rtp_sender = sender - .add_track(Arc::clone(&track_a) as Arc) - .await?; - - { - let tr = rtp_sender.rtp_transceiver.lock(); - let t = tr - .as_ref() - .and_then(|t| t.upgrade()) - .expect("Weak transceiver valid"); - t.set_codec_preferences(vec![RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - payload_type: 96, - ..Default::default() - }]) - .await?; - } - - signal_pair(&mut sender, &mut receiver).await?; - - let (seen_packet_tx, seen_packet_rx) = mpsc::channel::<()>(1); - let seen_packet_tx = Arc::new(seen_packet_tx); - receiver.on_track(Box::new(move |_, _, _| { - let seen_packet_tx2 = Arc::clone(&seen_packet_tx); - Box::pin(async move { - let _ = seen_packet_tx2.send(()).await; - }) - })); - - tokio::spawn(async move { - send_video_until_done( - seen_packet_rx, - vec![track_a], - Bytes::from_static(&[0xAA]), - None, - ) - .await; - }); - - if let Err(err) = rtp_sender.replace_track(Some(track_b)).await { - assert_eq!(err, Error::ErrUnsupportedCodec); - } else { - panic!(); - } - - close_pair_now(&sender, &receiver).await; - Ok(()) -} - -#[tokio::test] -async fn test_rtp_sender_get_parameters_replaced() -> Result<()> { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - let (sender, receiver) = new_pair(&api).await?; - let track = Arc::new(TrackLocalStaticSample::new( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - "video".to_owned(), - "webrtc-rs".to_owned(), - )); - let rtp_sender = sender.add_track(track).await?; - let param = rtp_sender.get_parameters().await; - assert_eq!(1, param.encodings.len()); - - rtp_sender.replace_track(None).await?; - let param = rtp_sender.get_parameters().await; - assert_eq!(0, param.encodings.len()); - - close_pair_now(&sender, &receiver).await; - Ok(()) -} - -#[tokio::test] -async fn test_rtp_sender_send() -> Result<()> { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - let (sender, receiver) = new_pair(&api).await?; - let track = Arc::new(TrackLocalStaticSample::new( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - "video".to_owned(), - "webrtc-rs".to_owned(), - )); - let rtp_sender = sender.add_track(track).await?; - let param = rtp_sender.get_parameters().await; - assert_eq!(1, param.encodings.len()); - - rtp_sender.send(¶m).await?; - - assert_eq!( - Error::ErrRTPSenderSendAlreadyCalled, - rtp_sender.send(¶m).await.unwrap_err() - ); - - close_pair_now(&sender, &receiver).await; - Ok(()) -} - -#[tokio::test] -async fn test_rtp_sender_send_track_removed() -> Result<()> { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - let (sender, receiver) 
= new_pair(&api).await?; - let track = Arc::new(TrackLocalStaticSample::new( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - "video".to_owned(), - "webrtc-rs".to_owned(), - )); - let rtp_sender = sender.add_track(track).await?; - let param = rtp_sender.get_parameters().await; - assert_eq!(1, param.encodings.len()); - - sender.remove_track(&rtp_sender).await?; - assert_eq!( - Error::ErrRTPSenderTrackRemoved, - rtp_sender.send(¶m).await.unwrap_err() - ); - - close_pair_now(&sender, &receiver).await; - Ok(()) -} - -#[tokio::test] -async fn test_rtp_sender_add_encoding() -> Result<()> { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - let (sender, receiver) = new_pair(&api).await?; - let track = Arc::new(TrackLocalStaticSample::new( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - "video".to_owned(), - "webrtc-rs".to_owned(), - )); - let rtp_sender = sender.add_track(track).await?; - - let track = Arc::new(TrackLocalStaticSample::new( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - "video".to_owned(), - "webrtc-rs".to_owned(), - )); - assert_eq!( - Error::ErrRTPSenderRidNil, - rtp_sender.add_encoding(track).await.unwrap_err() - ); - - let track = Arc::new(TrackLocalStaticSample::new_with_rid( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - "video".to_owned(), - "h".to_owned(), - "webrtc-rs".to_owned(), - )); - assert_eq!( - Error::ErrRTPSenderNoBaseEncoding, - rtp_sender.add_encoding(track).await.unwrap_err() - ); - - let track = Arc::new(TrackLocalStaticSample::new_with_rid( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - "video".to_owned(), - "f".to_owned(), - "webrtc-rs".to_owned(), - )); - let rtp_sender = sender.add_track(track).await?; - - let track = Arc::new(TrackLocalStaticSample::new_with_rid( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - "video-foobar".to_owned(), - "h".to_owned(), - "webrtc-rs".to_owned(), - )); - assert_eq!( - Error::ErrRTPSenderBaseEncodingMismatch, - rtp_sender.add_encoding(track).await.unwrap_err() - ); - - let track = Arc::new(TrackLocalStaticSample::new_with_rid( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - "video".to_owned(), - "h".to_owned(), - "webrtc-rs-foobar".to_owned(), - )); - assert_eq!( - Error::ErrRTPSenderBaseEncodingMismatch, - rtp_sender.add_encoding(track).await.unwrap_err() - ); - - let track = Arc::new(TrackLocalStaticSample::new_with_rid( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_OPUS.to_owned(), - ..Default::default() - }, - "video".to_owned(), - "h".to_owned(), - "webrtc-rs".to_owned(), - )); - assert_eq!( - Error::ErrRTPSenderBaseEncodingMismatch, - rtp_sender.add_encoding(track).await.unwrap_err() - ); - - let track = Arc::new(TrackLocalStaticSample::new_with_rid( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - "video".to_owned(), - "f".to_owned(), - "webrtc-rs".to_owned(), - )); - assert_eq!( - Error::ErrRTPSenderRIDCollision, - rtp_sender.add_encoding(track).await.unwrap_err() - ); - - let track = Arc::new(TrackLocalStaticSample::new_with_rid( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - "video".to_owned(), - 
"h".to_owned(), - "webrtc-rs".to_owned(), - )); - rtp_sender.add_encoding(track).await?; - - rtp_sender.send(&rtp_sender.get_parameters().await).await?; - - let track = Arc::new(TrackLocalStaticSample::new_with_rid( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - "video".to_owned(), - "f".to_owned(), - "webrtc-rs".to_owned(), - )); - assert_eq!( - Error::ErrRTPSenderSendAlreadyCalled, - rtp_sender.add_encoding(track).await.unwrap_err() - ); - - rtp_sender.stop().await?; - - let track = Arc::new(TrackLocalStaticSample::new_with_rid( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - "video".to_owned(), - "f".to_owned(), - "webrtc-rs".to_owned(), - )); - assert_eq!( - Error::ErrRTPSenderStopped, - rtp_sender.add_encoding(track).await.unwrap_err() - ); - - close_pair_now(&sender, &receiver).await; - Ok(()) -} diff --git a/webrtc/src/rtp_transceiver/rtp_transceiver_direction.rs b/webrtc/src/rtp_transceiver/rtp_transceiver_direction.rs deleted file mode 100644 index 756731ede..000000000 --- a/webrtc/src/rtp_transceiver/rtp_transceiver_direction.rs +++ /dev/null @@ -1,210 +0,0 @@ -use std::fmt; - -/// RTPTransceiverDirection indicates the direction of the RTPTransceiver. -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub enum RTCRtpTransceiverDirection { - Unspecified, - - /// Sendrecv indicates the RTPSender will offer - /// to send RTP and RTPReceiver the will offer to receive RTP. - Sendrecv, - - /// Sendonly indicates the RTPSender will offer to send RTP. - Sendonly, - - /// Recvonly indicates the RTPReceiver the will offer to receive RTP. - Recvonly, - - /// Inactive indicates the RTPSender won't offer - /// to send RTP and RTPReceiver the won't offer to receive RTP. - Inactive, -} - -const RTP_TRANSCEIVER_DIRECTION_SENDRECV_STR: &str = "sendrecv"; -const RTP_TRANSCEIVER_DIRECTION_SENDONLY_STR: &str = "sendonly"; -const RTP_TRANSCEIVER_DIRECTION_RECVONLY_STR: &str = "recvonly"; -const RTP_TRANSCEIVER_DIRECTION_INACTIVE_STR: &str = "inactive"; - -/// defines a procedure for creating a new -/// RTPTransceiverDirection from a raw string naming the transceiver direction. 
-impl From<&str> for RTCRtpTransceiverDirection { - fn from(raw: &str) -> Self { - match raw { - RTP_TRANSCEIVER_DIRECTION_SENDRECV_STR => RTCRtpTransceiverDirection::Sendrecv, - RTP_TRANSCEIVER_DIRECTION_SENDONLY_STR => RTCRtpTransceiverDirection::Sendonly, - RTP_TRANSCEIVER_DIRECTION_RECVONLY_STR => RTCRtpTransceiverDirection::Recvonly, - RTP_TRANSCEIVER_DIRECTION_INACTIVE_STR => RTCRtpTransceiverDirection::Inactive, - _ => RTCRtpTransceiverDirection::Unspecified, - } - } -} - -impl From for RTCRtpTransceiverDirection { - fn from(v: u8) -> Self { - match v { - 1 => RTCRtpTransceiverDirection::Sendrecv, - 2 => RTCRtpTransceiverDirection::Sendonly, - 3 => RTCRtpTransceiverDirection::Recvonly, - 4 => RTCRtpTransceiverDirection::Inactive, - _ => RTCRtpTransceiverDirection::Unspecified, - } - } -} - -impl fmt::Display for RTCRtpTransceiverDirection { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self { - RTCRtpTransceiverDirection::Sendrecv => { - write!(f, "{RTP_TRANSCEIVER_DIRECTION_SENDRECV_STR}") - } - RTCRtpTransceiverDirection::Sendonly => { - write!(f, "{RTP_TRANSCEIVER_DIRECTION_SENDONLY_STR}") - } - RTCRtpTransceiverDirection::Recvonly => { - write!(f, "{RTP_TRANSCEIVER_DIRECTION_RECVONLY_STR}") - } - RTCRtpTransceiverDirection::Inactive => { - write!(f, "{RTP_TRANSCEIVER_DIRECTION_INACTIVE_STR}") - } - _ => write!(f, "{}", crate::UNSPECIFIED_STR), - } - } -} - -impl RTCRtpTransceiverDirection { - /// reverse indicate the opposite direction - pub fn reverse(&self) -> RTCRtpTransceiverDirection { - match *self { - RTCRtpTransceiverDirection::Sendonly => RTCRtpTransceiverDirection::Recvonly, - RTCRtpTransceiverDirection::Recvonly => RTCRtpTransceiverDirection::Sendonly, - _ => *self, - } - } - - pub fn intersect(&self, other: RTCRtpTransceiverDirection) -> RTCRtpTransceiverDirection { - Self::from_send_recv( - self.has_send() && other.has_send(), - self.has_recv() && other.has_recv(), - ) - } - - pub fn from_send_recv(send: bool, recv: bool) -> RTCRtpTransceiverDirection { - match (send, recv) { - (true, true) => Self::Sendrecv, - (true, false) => Self::Sendonly, - (false, true) => Self::Recvonly, - (false, false) => Self::Inactive, - } - } - - pub fn has_send(&self) -> bool { - matches!(self, Self::Sendrecv | Self::Sendonly) - } - - pub fn has_recv(&self) -> bool { - matches!(self, Self::Sendrecv | Self::Recvonly) - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_new_rtp_transceiver_direction() { - let tests = vec![ - ("Unspecified", RTCRtpTransceiverDirection::Unspecified), - ("sendrecv", RTCRtpTransceiverDirection::Sendrecv), - ("sendonly", RTCRtpTransceiverDirection::Sendonly), - ("recvonly", RTCRtpTransceiverDirection::Recvonly), - ("inactive", RTCRtpTransceiverDirection::Inactive), - ]; - - for (ct_str, expected_type) in tests { - assert_eq!(RTCRtpTransceiverDirection::from(ct_str), expected_type); - } - } - - #[test] - fn test_rtp_transceiver_direction_string() { - let tests = vec![ - (RTCRtpTransceiverDirection::Unspecified, "Unspecified"), - (RTCRtpTransceiverDirection::Sendrecv, "sendrecv"), - (RTCRtpTransceiverDirection::Sendonly, "sendonly"), - (RTCRtpTransceiverDirection::Recvonly, "recvonly"), - (RTCRtpTransceiverDirection::Inactive, "inactive"), - ]; - - for (d, expected_string) in tests { - assert_eq!(d.to_string(), expected_string); - } - } - - #[test] - fn test_rtp_transceiver_has_send() { - let tests = vec![ - (RTCRtpTransceiverDirection::Unspecified, false), - (RTCRtpTransceiverDirection::Sendrecv, true), - 
(RTCRtpTransceiverDirection::Sendonly, true), - (RTCRtpTransceiverDirection::Recvonly, false), - (RTCRtpTransceiverDirection::Inactive, false), - ]; - - for (d, expected_value) in tests { - assert_eq!(d.has_send(), expected_value); - } - } - - #[test] - fn test_rtp_transceiver_has_recv() { - let tests = vec![ - (RTCRtpTransceiverDirection::Unspecified, false), - (RTCRtpTransceiverDirection::Sendrecv, true), - (RTCRtpTransceiverDirection::Sendonly, false), - (RTCRtpTransceiverDirection::Recvonly, true), - (RTCRtpTransceiverDirection::Inactive, false), - ]; - - for (d, expected_value) in tests { - assert_eq!(d.has_recv(), expected_value); - } - } - - #[test] - fn test_rtp_transceiver_from_send_recv() { - let tests = vec![ - (RTCRtpTransceiverDirection::Sendrecv, (true, true)), - (RTCRtpTransceiverDirection::Sendonly, (true, false)), - (RTCRtpTransceiverDirection::Recvonly, (false, true)), - (RTCRtpTransceiverDirection::Inactive, (false, false)), - ]; - - for (expected_value, (send, recv)) in tests { - assert_eq!( - RTCRtpTransceiverDirection::from_send_recv(send, recv), - expected_value - ); - } - } - - #[test] - fn test_rtp_transceiver_intersect() { - use RTCRtpTransceiverDirection::*; - - let tests = vec![ - ((Sendrecv, Recvonly), Recvonly), - ((Sendrecv, Sendonly), Sendonly), - ((Sendrecv, Inactive), Inactive), - ((Sendonly, Inactive), Inactive), - ((Recvonly, Inactive), Inactive), - ((Recvonly, Sendrecv), Recvonly), - ((Sendonly, Sendrecv), Sendonly), - ((Sendonly, Recvonly), Inactive), - ((Recvonly, Recvonly), Recvonly), - ]; - - for ((a, b), expected_direction) in tests { - assert_eq!(a.intersect(b), expected_direction); - } - } -} diff --git a/webrtc/src/rtp_transceiver/rtp_transceiver_test.rs b/webrtc/src/rtp_transceiver/rtp_transceiver_test.rs deleted file mode 100644 index 6e6cc75a9..000000000 --- a/webrtc/src/rtp_transceiver/rtp_transceiver_test.rs +++ /dev/null @@ -1,356 +0,0 @@ -use portable_atomic::AtomicUsize; - -use super::*; -use crate::api::media_engine::{MIME_TYPE_OPUS, MIME_TYPE_VP8, MIME_TYPE_VP9}; -use crate::api::APIBuilder; -use crate::dtls_transport::RTCDtlsTransport; -use crate::peer_connection::configuration::RTCConfiguration; -use crate::peer_connection::peer_connection_test::{close_pair_now, create_vnet_pair}; - -#[tokio::test] -async fn test_rtp_transceiver_set_codec_preferences() -> Result<()> { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - m.push_codecs(m.video_codecs.clone(), RTPCodecType::Video) - .await; - m.push_codecs(m.audio_codecs.clone(), RTPCodecType::Audio) - .await; - - let media_video_codecs = m.video_codecs.clone(); - - let api = APIBuilder::new().with_media_engine(m).build(); - let interceptor = api.interceptor_registry.build("")?; - let transport = Arc::new(RTCDtlsTransport::default()); - let receiver = Arc::new(api.new_rtp_receiver( - RTPCodecType::Video, - Arc::clone(&transport), - Arc::clone(&interceptor), - )); - - let sender = Arc::new( - api.new_rtp_sender(None, Arc::clone(&transport), Arc::clone(&interceptor)) - .await, - ); - - let tr = RTCRtpTransceiver::new( - receiver, - sender, - RTCRtpTransceiverDirection::Unspecified, - RTPCodecType::Video, - media_video_codecs.clone(), - Arc::clone(&api.media_engine), - None, - ) - .await; - - assert_eq!(&tr.get_codecs().await, &media_video_codecs); - - let fail_test_cases = vec![ - vec![RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_OPUS.to_string(), - clock_rate: 48000, - channels: 2, - sdp_fmtp_line: 
"minptime=10;useinbandfec=1".to_string(), - rtcp_feedback: vec![], - }, - payload_type: 111, - ..Default::default() - }], - vec![ - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_string(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: "".to_string(), - rtcp_feedback: vec![], - }, - payload_type: 96, - ..Default::default() - }, - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_OPUS.to_string(), - clock_rate: 48000, - channels: 2, - sdp_fmtp_line: "minptime=10;useinbandfec=1".to_string(), - rtcp_feedback: vec![], - }, - payload_type: 111, - ..Default::default() - }, - ], - ]; - - for test_case in fail_test_cases { - if let Err(err) = tr.set_codec_preferences(test_case).await { - assert_eq!(err, Error::ErrRTPTransceiverCodecUnsupported); - } else { - panic!(); - } - } - - let success_test_cases = vec![ - vec![RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_string(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: "".to_string(), - rtcp_feedback: vec![], - }, - payload_type: 96, - ..Default::default() - }], - vec![ - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_string(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: "".to_string(), - rtcp_feedback: vec![], - }, - payload_type: 96, - ..Default::default() - }, - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP9.to_string(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: "profile-id=0".to_string(), - rtcp_feedback: vec![], - }, - payload_type: 98, - ..Default::default() - }, - ], - ]; - - for test_case in success_test_cases { - tr.set_codec_preferences(test_case).await?; - } - - tr.set_codec_preferences(vec![]).await?; - assert_ne!(0, tr.get_codecs().await.len()); - - Ok(()) -} - -// Assert that SetCodecPreferences properly filters codecs and PayloadTypes are respected -#[tokio::test] -async fn test_rtp_transceiver_set_codec_preferences_payload_type() -> Result<()> { - let test_codec = RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: "video/test_codec".to_string(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: "".to_string(), - rtcp_feedback: vec![], - }, - payload_type: 50, - ..Default::default() - }; - - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - let offer_pc = api.new_peer_connection(RTCConfiguration::default()).await?; - - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - m.register_codec(test_codec.clone(), RTPCodecType::Video)?; - let api = APIBuilder::new().with_media_engine(m).build(); - let answer_pc = api.new_peer_connection(RTCConfiguration::default()).await?; - - let _ = offer_pc - .add_transceiver_from_kind(RTPCodecType::Video, None) - .await?; - - let answer_transceiver = answer_pc - .add_transceiver_from_kind(RTPCodecType::Video, None) - .await?; - - answer_transceiver - .set_codec_preferences(vec![ - test_codec, - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_string(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: "".to_string(), - rtcp_feedback: vec![], - }, - payload_type: 51, - ..Default::default() - }, - ]) - .await?; - - let offer = offer_pc.create_offer(None).await?; - - offer_pc.set_local_description(offer.clone()).await?; - answer_pc.set_remote_description(offer).await?; - - let answer = 
answer_pc.create_answer(None).await?; - - // VP8 with proper PayloadType - assert!( - answer.sdp.contains("a=rtpmap:51 VP8/90000"), - "{}", - answer.sdp - ); - - // test_codec is ignored since offerer doesn't support - assert!(!answer.sdp.contains("test_codec")); - - close_pair_now(&offer_pc, &answer_pc).await; - - Ok(()) -} - -#[tokio::test] -async fn test_rtp_transceiver_direction_change() -> Result<()> { - let (offer_pc, answer_pc, _) = create_vnet_pair().await?; - - let offer_transceiver = offer_pc - .add_transceiver_from_kind(RTPCodecType::Video, None) - .await?; - - let _ = answer_pc - .add_transceiver_from_kind(RTPCodecType::Video, None) - .await?; - - let offer = offer_pc.create_offer(None).await?; - - offer_pc.set_local_description(offer.clone()).await?; - answer_pc.set_remote_description(offer).await?; - - let answer = answer_pc.create_answer(None).await?; - assert!(answer.sdp.contains("a=sendrecv"),); - answer_pc.set_local_description(answer.clone()).await?; - offer_pc.set_remote_description(answer).await?; - - offer_transceiver - .set_direction(RTCRtpTransceiverDirection::Inactive) - .await; - - let offer = offer_pc.create_offer(None).await?; - assert!(offer.sdp.contains("a=inactive"),); - - offer_pc.set_local_description(offer.clone()).await?; - answer_pc.set_remote_description(offer).await?; - - let answer = answer_pc.create_answer(None).await?; - assert!(answer.sdp.contains("a=inactive"),); - offer_pc.set_remote_description(answer).await?; - - close_pair_now(&offer_pc, &answer_pc).await; - - Ok(()) -} - -#[tokio::test] -async fn test_rtp_transceiver_set_direction_causing_negotiation() -> Result<()> { - let (offer_pc, answer_pc, _) = create_vnet_pair().await?; - - let count = Arc::new(AtomicUsize::new(0)); - - { - let count = count.clone(); - offer_pc.on_negotiation_needed(Box::new(move || { - let count = count.clone(); - Box::pin(async move { - count.fetch_add(1, Ordering::SeqCst); - }) - })); - } - - let offer_transceiver = offer_pc - .add_transceiver_from_kind(RTPCodecType::Video, None) - .await?; - - let _ = answer_pc - .add_transceiver_from_kind(RTPCodecType::Video, None) - .await?; - - let offer = offer_pc.create_offer(None).await?; - offer_pc.set_local_description(offer.clone()).await?; - answer_pc.set_remote_description(offer).await?; - - let answer = answer_pc.create_answer(None).await?; - answer_pc.set_local_description(answer.clone()).await?; - offer_pc.set_remote_description(answer).await?; - - assert_eq!(count.load(Ordering::SeqCst), 0); - - let offer = offer_pc.create_offer(None).await?; - offer_pc.set_local_description(offer.clone()).await?; - answer_pc.set_remote_description(offer).await?; - - let answer = answer_pc.create_answer(None).await?; - answer_pc.set_local_description(answer.clone()).await?; - offer_pc.set_remote_description(answer).await?; - - assert_eq!(count.load(Ordering::SeqCst), 0); - - offer_transceiver - .set_direction(RTCRtpTransceiverDirection::Inactive) - .await; - - // wait for negotiation ops queue to finish. 
- offer_pc.internal.ops.done().await; - - assert_eq!(count.load(Ordering::SeqCst), 1); - - close_pair_now(&offer_pc, &answer_pc).await; - - Ok(()) -} - -#[ignore] -#[tokio::test] -async fn test_rtp_transceiver_stopping() -> Result<()> { - let (offer_pc, answer_pc, _) = create_vnet_pair().await?; - - let offer_transceiver = offer_pc - .add_transceiver_from_kind(RTPCodecType::Video, None) - .await?; - - let _ = answer_pc - .add_transceiver_from_kind(RTPCodecType::Video, None) - .await?; - - let offer = offer_pc.create_offer(None).await?; - - offer_pc.set_local_description(offer.clone()).await?; - answer_pc.set_remote_description(offer).await?; - - let answer = answer_pc.create_answer(None).await?; - assert!(answer.sdp.contains("a=sendrecv"),); - answer_pc.set_local_description(answer.clone()).await?; - offer_pc.set_remote_description(answer).await?; - - assert!( - offer_transceiver.mid().is_some(), - "A mid should have been associated with the transceiver when applying the answer" - ); - // Stop the transceiver - offer_transceiver.stop().await?; - - let offer = offer_pc.create_offer(None).await?; - assert!(offer.sdp.contains("a=inactive"),); - let parsed = offer.parsed.unwrap(); - let m = &parsed.media_descriptions[0]; - assert_eq!( - m.media_name.port.value, 0, - "After stopping a transceiver it should be rejected in offers" - ); - - close_pair_now(&offer_pc, &answer_pc).await; - - Ok(()) -} diff --git a/webrtc/src/rtp_transceiver/srtp_writer_future.rs b/webrtc/src/rtp_transceiver/srtp_writer_future.rs deleted file mode 100644 index 5ceadc828..000000000 --- a/webrtc/src/rtp_transceiver/srtp_writer_future.rs +++ /dev/null @@ -1,290 +0,0 @@ -use std::sync::atomic::Ordering; -use std::sync::{Arc, Weak}; - -use async_trait::async_trait; -use bytes::Bytes; -use interceptor::{Attributes, RTCPReader, RTPWriter}; -use portable_atomic::AtomicBool; -use srtp::session::Session; -use srtp::stream::Stream; -use tokio::sync::Mutex; -use util; - -use crate::dtls_transport::RTCDtlsTransport; -use crate::error::{Error, Result}; -use crate::rtp_transceiver::rtp_sender::RTPSenderInternal; -use crate::rtp_transceiver::SSRC; - -/// `RTP` packet sequence number manager. -/// -/// Used to override outgoing `RTP` packets' sequence numbers. On creating it is -/// unabled and can be enabled before sending data beginning. Once data sending -/// began it can not be enabled any more. -pub(crate) struct SequenceTransformer(util::sync::Mutex); - -/// [`SequenceTransformer`] inner. -struct SequenceTransformerInner { - offset: u16, - last_sq: u16, - reset_needed: bool, - enabled: bool, - data_sent: bool, -} - -impl SequenceTransformer { - /// Creates a new [`SequenceTransformer`]. - pub(crate) fn new() -> Self { - Self(util::sync::Mutex::new(SequenceTransformerInner { - offset: 0, - last_sq: rand::random(), - reset_needed: false, - enabled: false, - data_sent: false, - })) - } - - /// Enables this [`SequenceTransformer`]. - /// - /// # Errors - /// - /// With [`Error::ErrRTPSenderSeqTransEnabled`] on trying to enable already - /// enabled [`SequenceTransformer`]. - /// - /// With [`Error::ErrRTPSenderSeqTransEnabled`] on trying to enable - /// [`SequenceTransformer`] after data sending began. 
- pub(crate) fn enable(&self) -> Result<()> { - let mut guard = self.0.lock(); - - if guard.enabled { - return Err(Error::ErrRTPSenderSeqTransEnabled); - } - - (!guard.data_sent) - .then(|| { - guard.enabled = true; - }) - .ok_or(Error::ErrRTPSenderDataSent) - } - - /// Indicates [`SequenceTransformer`] about necessity of recalculating - /// `offset`. - pub(crate) fn reset_offset(&self) { - self.0.lock().reset_needed = true; - } - - /// Gets [`Some`] consistent `sequence number` if this [`SequenceTransformer`] is - /// enabled or [`None`] if it is not. - /// - /// Once this method is called, considers data sending began. - fn seq_number(&self, raw_sn: u16) -> Option { - let mut guard = self.0.lock(); - guard.data_sent = true; - - if !guard.enabled { - return None; - } - - let offset = guard - .reset_needed - .then(|| { - guard.reset_needed = false; - let offset = guard.last_sq.overflowing_sub(raw_sn.overflowing_sub(1).0).0; - guard.offset = offset; - offset - }) - .unwrap_or(guard.offset); - let next = raw_sn.overflowing_add(offset).0; - guard.last_sq = next; - - Some(next) - } -} - -/// SrtpWriterFuture blocks Read/Write calls until -/// the SRTP Session is available -pub(crate) struct SrtpWriterFuture { - pub(crate) closed: AtomicBool, - pub(crate) ssrc: SSRC, - pub(crate) rtp_sender: Weak, - pub(crate) rtp_transport: Arc, - pub(crate) rtcp_read_stream: Mutex>>, // atomic.Value // * - pub(crate) rtp_write_session: Mutex>>, // atomic.Value // * - pub(crate) seq_trans: Arc, -} - -impl SrtpWriterFuture { - async fn init(&self, return_when_no_srtp: bool) -> Result<()> { - if return_when_no_srtp { - { - if let Some(rtp_sender) = self.rtp_sender.upgrade() { - if rtp_sender.stop_called_signal.load(Ordering::SeqCst) { - return Err(Error::ErrClosedPipe); - } - } else { - return Err(Error::ErrClosedPipe); - } - } - - if !self.rtp_transport.srtp_ready_signal.load(Ordering::SeqCst) { - return Ok(()); - } - } else { - let mut rx = self.rtp_transport.srtp_ready_rx.lock().await; - if let Some(srtp_ready_rx) = &mut *rx { - if let Some(rtp_sender) = self.rtp_sender.upgrade() { - tokio::select! { - _ = rtp_sender.stop_called_rx.notified()=> return Err(Error::ErrClosedPipe), - _ = srtp_ready_rx.recv() =>{} - } - } else { - return Err(Error::ErrClosedPipe); - } - } - } - - if self.closed.load(Ordering::SeqCst) { - return Err(Error::ErrClosedPipe); - } - - if let Some(srtcp_session) = self.rtp_transport.get_srtcp_session().await { - let rtcp_read_stream = srtcp_session.open(self.ssrc).await; - let mut stream = self.rtcp_read_stream.lock().await; - *stream = Some(rtcp_read_stream); - } - - { - let srtp_session = self.rtp_transport.get_srtp_session().await; - let mut session = self.rtp_write_session.lock().await; - *session = srtp_session; - } - - Ok(()) - } - - pub async fn close(&self) -> Result<()> { - if self.closed.load(Ordering::SeqCst) { - return Ok(()); - } - self.closed.store(true, Ordering::SeqCst); - - let stream = { - let mut stream = self.rtcp_read_stream.lock().await; - stream.take() - }; - if let Some(rtcp_read_stream) = stream { - Ok(rtcp_read_stream.close().await?) 
- } else { - Ok(()) - } - } - - pub async fn read(&self, b: &mut [u8]) -> Result { - { - let stream = { - let stream = self.rtcp_read_stream.lock().await; - stream.clone() - }; - if let Some(rtcp_read_stream) = stream { - return Ok(rtcp_read_stream.read(b).await?); - } - } - - self.init(false).await?; - - { - let stream = { - let stream = self.rtcp_read_stream.lock().await; - stream.clone() - }; - if let Some(rtcp_read_stream) = stream { - return Ok(rtcp_read_stream.read(b).await?); - } - } - - Ok(0) - } - - pub async fn write_rtp(&self, pkt: &rtp::packet::Packet) -> Result { - { - let session = { - let session = self.rtp_write_session.lock().await; - session.clone() - }; - if let Some(rtp_write_session) = session { - return Ok(rtp_write_session.write_rtp(pkt).await?); - } - } - - self.init(true).await?; - - { - let session = { - let session = self.rtp_write_session.lock().await; - session.clone() - }; - if let Some(rtp_write_session) = session { - return Ok(rtp_write_session.write_rtp(pkt).await?); - } - } - - Ok(0) - } - - pub async fn write(&self, b: &Bytes) -> Result { - { - let session = { - let session = self.rtp_write_session.lock().await; - session.clone() - }; - if let Some(rtp_write_session) = session { - return Ok(rtp_write_session.write(b, true).await?); - } - } - - self.init(true).await?; - - { - let session = { - let session = self.rtp_write_session.lock().await; - session.clone() - }; - if let Some(rtp_write_session) = session { - return Ok(rtp_write_session.write(b, true).await?); - } - } - - Ok(0) - } -} - -type IResult = std::result::Result; - -#[async_trait] -impl RTCPReader for SrtpWriterFuture { - async fn read( - &self, - buf: &mut [u8], - a: &Attributes, - ) -> IResult<(Vec>, Attributes)> { - let read = self.read(buf).await?; - let pkt = rtcp::packet::unmarshal(&mut &buf[..read])?; - - Ok((pkt, a.clone())) - } -} - -#[async_trait] -impl RTPWriter for SrtpWriterFuture { - async fn write(&self, pkt: &rtp::packet::Packet, _a: &Attributes) -> IResult { - Ok( - match self.seq_trans.seq_number(pkt.header.sequence_number) { - Some(seq_num) => { - let mut new_pkt = pkt.clone(); - new_pkt.header.sequence_number = seq_num; - self.write_rtp(&new_pkt).await? 
- } - None => self.write_rtp(pkt).await?, - }, - ) - } -} diff --git a/webrtc/src/sctp_transport/mod.rs b/webrtc/src/sctp_transport/mod.rs deleted file mode 100644 index f4914c301..000000000 --- a/webrtc/src/sctp_transport/mod.rs +++ /dev/null @@ -1,442 +0,0 @@ -#[cfg(test)] -mod sctp_transport_test; - -pub mod sctp_transport_capabilities; -pub mod sctp_transport_state; - -use std::collections::{HashMap, HashSet}; -use std::future::Future; -use std::pin::Pin; -use std::sync::atomic::Ordering; -use std::sync::Arc; - -use arc_swap::ArcSwapOption; -use data::data_channel::DataChannel; -use data::message::message_channel_open::ChannelType; -use portable_atomic::{AtomicBool, AtomicU32, AtomicU8}; -use sctp::association::Association; -use sctp_transport_state::RTCSctpTransportState; -use tokio::sync::{Mutex, Notify}; -use util::Conn; - -use crate::api::setting_engine::SettingEngine; -use crate::data_channel::data_channel_parameters::DataChannelParameters; -use crate::data_channel::data_channel_state::RTCDataChannelState; -use crate::data_channel::RTCDataChannel; -use crate::dtls_transport::dtls_role::DTLSRole; -use crate::dtls_transport::*; -use crate::error::*; -use crate::sctp_transport::sctp_transport_capabilities::SCTPTransportCapabilities; -use crate::stats::stats_collector::StatsCollector; -use crate::stats::StatsReportType::{PeerConnection, SCTPTransport}; -use crate::stats::{ICETransportStats, PeerConnectionStats}; - -const SCTP_MAX_CHANNELS: u16 = u16::MAX; - -pub type OnDataChannelHdlrFn = Box< - dyn (FnMut(Arc) -> Pin + Send + 'static>>) - + Send - + Sync, ->; - -pub type OnDataChannelOpenedHdlrFn = Box< - dyn (FnMut(Arc) -> Pin + Send + 'static>>) - + Send - + Sync, ->; - -struct AcceptDataChannelParams { - notify_rx: Arc, - sctp_association: Arc, - data_channels: Arc>>>, - on_error_handler: Arc>>, - on_data_channel_handler: Arc>>, - on_data_channel_opened_handler: Arc>>, - data_channels_opened: Arc, - data_channels_accepted: Arc, - setting_engine: Arc, -} - -/// SCTPTransport provides details about the SCTP transport. -#[derive(Default)] -pub struct RTCSctpTransport { - pub(crate) dtls_transport: Arc, - - // State represents the current state of the SCTP transport. - state: AtomicU8, // RTCSctpTransportState - - // SCTPTransportState doesn't have an enum to distinguish between New/Connecting - // so we need a dedicated field - is_started: AtomicBool, - - // max_message_size represents the maximum size of data that can be passed to - // DataChannel's send() method. - max_message_size: usize, - - // max_channels represents the maximum amount of DataChannel's that can - // be used simultaneously. 
- max_channels: u16, - - sctp_association: Mutex>>, - - on_error_handler: Arc>>, - on_data_channel_handler: Arc>>, - on_data_channel_opened_handler: Arc>>, - - // DataChannels - pub(crate) data_channels: Arc>>>, - pub(crate) data_channels_opened: Arc, - pub(crate) data_channels_requested: Arc, - data_channels_accepted: Arc, - - notify_tx: Arc, - - setting_engine: Arc, -} - -impl RTCSctpTransport { - pub(crate) fn new( - dtls_transport: Arc, - setting_engine: Arc, - ) -> Self { - RTCSctpTransport { - dtls_transport, - state: AtomicU8::new(RTCSctpTransportState::Connecting as u8), - is_started: AtomicBool::new(false), - max_message_size: RTCSctpTransport::calc_message_size(65536, 65536), - max_channels: SCTP_MAX_CHANNELS, - sctp_association: Mutex::new(None), - on_error_handler: Arc::new(ArcSwapOption::empty()), - on_data_channel_handler: Arc::new(ArcSwapOption::empty()), - on_data_channel_opened_handler: Arc::new(ArcSwapOption::empty()), - - data_channels: Arc::new(Mutex::new(vec![])), - data_channels_opened: Arc::new(AtomicU32::new(0)), - data_channels_requested: Arc::new(AtomicU32::new(0)), - data_channels_accepted: Arc::new(AtomicU32::new(0)), - - notify_tx: Arc::new(Notify::new()), - - setting_engine, - } - } - - /// transport returns the DTLSTransport instance the SCTPTransport is sending over. - pub fn transport(&self) -> Arc { - Arc::clone(&self.dtls_transport) - } - - /// get_capabilities returns the SCTPCapabilities of the SCTPTransport. - pub fn get_capabilities(&self) -> SCTPTransportCapabilities { - SCTPTransportCapabilities { - max_message_size: 0, - } - } - - /// Start the SCTPTransport. Since both local and remote parties must mutually - /// create an SCTPTransport, SCTP SO (Simultaneous Open) is used to establish - /// a connection over SCTP. - pub async fn start(&self, _remote_caps: SCTPTransportCapabilities) -> Result<()> { - if self.is_started.load(Ordering::SeqCst) { - return Ok(()); - } - self.is_started.store(true, Ordering::SeqCst); - - let dtls_transport = self.transport(); - if let Some(net_conn) = &dtls_transport.conn().await { - let sctp_association = loop { - tokio::select! { - _ = self.notify_tx.notified() => { - // It seems like notify_tx is only notified on Stop so perhaps this check - // is redundant. - // TODO: Consider renaming notify_tx to shutdown_tx. 
- if self.state.load(Ordering::SeqCst) == RTCSctpTransportState::Closed as u8 { - return Err(Error::ErrSCTPTransportDTLS); - } - }, - association = sctp::association::Association::client(sctp::association::Config { - net_conn: Arc::clone(net_conn) as Arc, - max_receive_buffer_size: 0, - max_message_size: 0, - name: String::new(), - }) => { - break Arc::new(association?); - } - }; - }; - - { - let mut sa = self.sctp_association.lock().await; - *sa = Some(Arc::clone(&sctp_association)); - } - self.state - .store(RTCSctpTransportState::Connected as u8, Ordering::SeqCst); - - let param = AcceptDataChannelParams { - notify_rx: self.notify_tx.clone(), - sctp_association, - data_channels: Arc::clone(&self.data_channels), - on_error_handler: Arc::clone(&self.on_error_handler), - on_data_channel_handler: Arc::clone(&self.on_data_channel_handler), - on_data_channel_opened_handler: Arc::clone(&self.on_data_channel_opened_handler), - data_channels_opened: Arc::clone(&self.data_channels_opened), - data_channels_accepted: Arc::clone(&self.data_channels_accepted), - setting_engine: Arc::clone(&self.setting_engine), - }; - tokio::spawn(async move { - RTCSctpTransport::accept_data_channels(param).await; - }); - - Ok(()) - } else { - Err(Error::ErrSCTPTransportDTLS) - } - } - - /// Stop stops the SCTPTransport - pub async fn stop(&self) -> Result<()> { - { - let mut sctp_association = self.sctp_association.lock().await; - if let Some(sa) = sctp_association.take() { - sa.close().await?; - } - } - - self.state - .store(RTCSctpTransportState::Closed as u8, Ordering::SeqCst); - - self.notify_tx.notify_waiters(); - - Ok(()) - } - - async fn accept_data_channels(param: AcceptDataChannelParams) { - let dcs = param.data_channels.lock().await; - let mut existing_data_channels = Vec::new(); - for dc in dcs.iter() { - if let Some(dc) = dc.data_channel.lock().await.clone() { - existing_data_channels.push(dc); - } - } - drop(dcs); - - loop { - let dc = tokio::select! 
{ - _ = param.notify_rx.notified() => break, - result = DataChannel::accept( - ¶m.sctp_association, - data::data_channel::Config::default(), - &existing_data_channels, - ) => { - match result { - Ok(dc) => dc, - Err(err) => { - if data::Error::ErrStreamClosed == err { - log::error!("Failed to accept data channel: {}", err); - if let Some(handler) = &*param.on_error_handler.load() { - let mut f = handler.lock().await; - f(err.into()).await; - } - } - break; - } - } - } - }; - - let mut max_retransmits = 0; - let mut max_packet_lifetime = 0; - let val = dc.config.reliability_parameter as u16; - let ordered; - - match dc.config.channel_type { - ChannelType::Reliable => { - ordered = true; - } - ChannelType::ReliableUnordered => { - ordered = false; - } - ChannelType::PartialReliableRexmit => { - ordered = true; - max_retransmits = val; - } - ChannelType::PartialReliableRexmitUnordered => { - ordered = false; - max_retransmits = val; - } - ChannelType::PartialReliableTimed => { - ordered = true; - max_packet_lifetime = val; - } - ChannelType::PartialReliableTimedUnordered => { - ordered = false; - max_packet_lifetime = val; - } - }; - - let negotiated = if dc.config.negotiated { - Some(dc.stream_identifier()) - } else { - None - }; - let rtc_dc = Arc::new(RTCDataChannel::new( - DataChannelParameters { - label: dc.config.label.clone(), - protocol: dc.config.protocol.clone(), - negotiated, - ordered, - max_packet_life_time: max_packet_lifetime, - max_retransmits, - }, - Arc::clone(¶m.setting_engine), - )); - - if let Some(handler) = &*param.on_data_channel_handler.load() { - let mut f = handler.lock().await; - f(Arc::clone(&rtc_dc)).await; - - param.data_channels_accepted.fetch_add(1, Ordering::SeqCst); - - let mut dcs = param.data_channels.lock().await; - dcs.push(Arc::clone(&rtc_dc)); - } - - rtc_dc.handle_open(Arc::new(dc)).await; - - if let Some(handler) = &*param.on_data_channel_opened_handler.load() { - let mut f = handler.lock().await; - f(rtc_dc).await; - param.data_channels_opened.fetch_add(1, Ordering::SeqCst); - } - } - } - - /// on_error sets an event handler which is invoked when - /// the SCTP connection error occurs. - pub fn on_error(&self, f: OnErrorHdlrFn) { - self.on_error_handler.store(Some(Arc::new(Mutex::new(f)))); - } - - /// on_data_channel sets an event handler which is invoked when a data - /// channel message arrives from a remote peer. - pub fn on_data_channel(&self, f: OnDataChannelHdlrFn) { - self.on_data_channel_handler - .store(Some(Arc::new(Mutex::new(f)))); - } - - /// on_data_channel_opened sets an event handler which is invoked when a data - /// channel is opened - pub fn on_data_channel_opened(&self, f: OnDataChannelOpenedHdlrFn) { - self.on_data_channel_opened_handler - .store(Some(Arc::new(Mutex::new(f)))); - } - - fn calc_message_size(remote_max_message_size: usize, can_send_size: usize) -> usize { - if remote_max_message_size == 0 && can_send_size == 0 { - usize::MAX - } else if remote_max_message_size == 0 { - can_send_size - } else if can_send_size == 0 || can_send_size > remote_max_message_size { - remote_max_message_size - } else { - can_send_size - } - } - - /// max_channels is the maximum number of RTCDataChannels that can be open simultaneously. 
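// A minimal sketch of the rule in `calc_message_size` above: 0 means "no limit",
// otherwise the smaller of the two bounds wins. The function names are illustrative.
fn effective_max_message_size(remote_max: usize, can_send: usize) -> usize {
    match (remote_max, can_send) {
        (0, 0) => usize::MAX,
        (0, n) | (n, 0) => n,
        (r, s) => r.min(s),
    }
}

fn message_size_examples() {
    assert_eq!(effective_max_message_size(0, 0), usize::MAX);
    assert_eq!(effective_max_message_size(0, 65_536), 65_536);
    assert_eq!(effective_max_message_size(16_384, 65_536), 16_384);
}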
- pub fn max_channels(&self) -> u16 { - if self.max_channels == 0 { - SCTP_MAX_CHANNELS - } else { - self.max_channels - } - } - - /// state returns the current state of the SCTPTransport - pub fn state(&self) -> RTCSctpTransportState { - self.state.load(Ordering::SeqCst).into() - } - - pub(crate) async fn collect_stats( - &self, - collector: &StatsCollector, - peer_connection_id: String, - ) { - let dtls_transport = self.transport(); - - // TODO: should this be collected? - dtls_transport.collect_stats(collector).await; - - // data channels - let mut data_channels_closed = 0; - let data_channels = self.data_channels.lock().await; - for data_channel in &*data_channels { - match data_channel.ready_state() { - RTCDataChannelState::Connecting => (), - RTCDataChannelState::Open => (), - _ => data_channels_closed += 1, - } - data_channel.collect_stats(collector).await; - } - - let mut reports = HashMap::new(); - let peer_connection_stats = - PeerConnectionStats::new(self, peer_connection_id.clone(), data_channels_closed); - reports.insert(peer_connection_id, PeerConnection(peer_connection_stats)); - - // conn - if let Some(agent) = dtls_transport.ice_transport.gatherer.get_agent().await { - let stats = ICETransportStats::new("sctp_transport".to_owned(), agent); - reports.insert(stats.id.clone(), SCTPTransport(stats)); - } - - collector.merge(reports); - } - - pub(crate) async fn generate_and_set_data_channel_id( - &self, - dtls_role: DTLSRole, - ) -> Result { - let mut id = 0u16; - if dtls_role != DTLSRole::Client { - id += 1; - } - - // Create map of ids so we can compare without double-looping each time. - let mut ids_map = HashSet::new(); - { - let data_channels = self.data_channels.lock().await; - for dc in &*data_channels { - ids_map.insert(dc.id()); - } - } - - let max = self.max_channels(); - while id < max - 1 { - if ids_map.contains(&id) { - id += 2; - } else { - return Ok(id); - } - } - - Err(Error::ErrMaxDataChannelID) - } - - pub(crate) async fn association(&self) -> Option> { - let sctp_association = self.sctp_association.lock().await; - sctp_association.clone() - } - - pub(crate) fn data_channels_accepted(&self) -> u32 { - self.data_channels_accepted.load(Ordering::SeqCst) - } - - pub(crate) fn data_channels_opened(&self) -> u32 { - self.data_channels_opened.load(Ordering::SeqCst) - } - - pub(crate) fn data_channels_requested(&self) -> u32 { - self.data_channels_requested.load(Ordering::SeqCst) - } -} diff --git a/webrtc/src/sctp_transport/sctp_transport_capabilities.rs b/webrtc/src/sctp_transport/sctp_transport_capabilities.rs deleted file mode 100644 index ee4b2a7cc..000000000 --- a/webrtc/src/sctp_transport/sctp_transport_capabilities.rs +++ /dev/null @@ -1,7 +0,0 @@ -use serde::{Deserialize, Serialize}; - -/// SCTPTransportCapabilities indicates the capabilities of the SCTPTransport. -#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)] -pub struct SCTPTransportCapabilities { - pub max_message_size: u32, -} diff --git a/webrtc/src/sctp_transport/sctp_transport_state.rs b/webrtc/src/sctp_transport/sctp_transport_state.rs deleted file mode 100644 index 310b814c5..000000000 --- a/webrtc/src/sctp_transport/sctp_transport_state.rs +++ /dev/null @@ -1,99 +0,0 @@ -use std::fmt; - -/// SCTPTransportState indicates the state of the SCTP transport. 
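// A minimal sketch of the allocation scheme in `generate_and_set_data_channel_id`
// above: the DTLS client starts at 0, the other side at 1, and both step by 2 past
// ids already in use, so the two roles never collide. Names here are illustrative.
use std::collections::HashSet;

fn next_channel_id(is_dtls_client: bool, in_use: &HashSet<u16>, max: u16) -> Option<u16> {
    let mut id = if is_dtls_client { 0u16 } else { 1u16 };
    while id < max - 1 {
        if in_use.contains(&id) {
            id += 2;
        } else {
            return Some(id);
        }
    }
    None
}

fn channel_id_examples() {
    let in_use: HashSet<u16> = [0u16, 2].into_iter().collect();
    // Mirrors the cases exercised by the deleted test further down:
    assert_eq!(next_channel_id(true, &in_use, u16::MAX), Some(4)); // client skips 0 and 2
    assert_eq!(next_channel_id(false, &in_use, u16::MAX), Some(1)); // server starts at 1
}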
-#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] -#[repr(u8)] -pub enum RTCSctpTransportState { - #[default] - Unspecified, - - /// SCTPTransportStateConnecting indicates the SCTPTransport is in the - /// process of negotiating an association. This is the initial state of the - /// SCTPTransportState when an SCTPTransport is created. - Connecting, - - /// SCTPTransportStateConnected indicates the negotiation of an - /// association is completed. - Connected, - - /// SCTPTransportStateClosed indicates a SHUTDOWN or ABORT chunk is - /// received or when the SCTP association has been closed intentionally, - /// such as by closing the peer connection or applying a remote description - /// that rejects data or changes the SCTP port. - Closed, -} - -const SCTP_TRANSPORT_STATE_CONNECTING_STR: &str = "connecting"; -const SCTP_TRANSPORT_STATE_CONNECTED_STR: &str = "connected"; -const SCTP_TRANSPORT_STATE_CLOSED_STR: &str = "closed"; - -impl From<&str> for RTCSctpTransportState { - fn from(raw: &str) -> Self { - match raw { - SCTP_TRANSPORT_STATE_CONNECTING_STR => RTCSctpTransportState::Connecting, - SCTP_TRANSPORT_STATE_CONNECTED_STR => RTCSctpTransportState::Connected, - SCTP_TRANSPORT_STATE_CLOSED_STR => RTCSctpTransportState::Closed, - _ => RTCSctpTransportState::Unspecified, - } - } -} - -impl From for RTCSctpTransportState { - fn from(v: u8) -> Self { - match v { - 1 => RTCSctpTransportState::Connecting, - 2 => RTCSctpTransportState::Connected, - 3 => RTCSctpTransportState::Closed, - _ => RTCSctpTransportState::Unspecified, - } - } -} - -impl fmt::Display for RTCSctpTransportState { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match *self { - RTCSctpTransportState::Connecting => SCTP_TRANSPORT_STATE_CONNECTING_STR, - RTCSctpTransportState::Connected => SCTP_TRANSPORT_STATE_CONNECTED_STR, - RTCSctpTransportState::Closed => SCTP_TRANSPORT_STATE_CLOSED_STR, - RTCSctpTransportState::Unspecified => crate::UNSPECIFIED_STR, - }; - write!(f, "{s}") - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_new_sctp_transport_state() { - let tests = vec![ - (crate::UNSPECIFIED_STR, RTCSctpTransportState::Unspecified), - ("connecting", RTCSctpTransportState::Connecting), - ("connected", RTCSctpTransportState::Connected), - ("closed", RTCSctpTransportState::Closed), - ]; - - for (state_string, expected_state) in tests { - assert_eq!( - RTCSctpTransportState::from(state_string), - expected_state, - "testCase: {expected_state}", - ); - } - } - - #[test] - fn test_sctp_transport_state_string() { - let tests = vec![ - (RTCSctpTransportState::Unspecified, crate::UNSPECIFIED_STR), - (RTCSctpTransportState::Connecting, "connecting"), - (RTCSctpTransportState::Connected, "connected"), - (RTCSctpTransportState::Closed, "closed"), - ]; - - for (state, expected_string) in tests { - assert_eq!(state.to_string(), expected_string) - } - } -} diff --git a/webrtc/src/sctp_transport/sctp_transport_test.rs b/webrtc/src/sctp_transport/sctp_transport_test.rs deleted file mode 100644 index 39a14cb80..000000000 --- a/webrtc/src/sctp_transport/sctp_transport_test.rs +++ /dev/null @@ -1,43 +0,0 @@ -use portable_atomic::AtomicU16; - -use super::*; - -#[tokio::test] -async fn test_generate_data_channel_id() -> Result<()> { - let sctp_transport_with_channels = |ids: &[u16]| -> RTCSctpTransport { - let mut data_channels = vec![]; - for id in ids { - data_channels.push(Arc::new(RTCDataChannel { - id: AtomicU16::new(*id), - ..Default::default() - })); - } - - RTCSctpTransport { - 
data_channels: Arc::new(Mutex::new(data_channels)), - ..Default::default() - } - }; - - let tests = vec![ - (DTLSRole::Client, sctp_transport_with_channels(&[]), 0), - (DTLSRole::Client, sctp_transport_with_channels(&[1]), 0), - (DTLSRole::Client, sctp_transport_with_channels(&[0]), 2), - (DTLSRole::Client, sctp_transport_with_channels(&[0, 2]), 4), - (DTLSRole::Client, sctp_transport_with_channels(&[0, 4]), 2), - (DTLSRole::Server, sctp_transport_with_channels(&[]), 1), - (DTLSRole::Server, sctp_transport_with_channels(&[0]), 1), - (DTLSRole::Server, sctp_transport_with_channels(&[1]), 3), - (DTLSRole::Server, sctp_transport_with_channels(&[1, 3]), 5), - (DTLSRole::Server, sctp_transport_with_channels(&[1, 5]), 3), - ]; - - for (role, s, expected) in tests { - match s.generate_and_set_data_channel_id(role).await { - Ok(actual) => assert_eq!(actual, expected), - Err(err) => panic!("failed to generate id: {err}"), - }; - } - - Ok(()) -} diff --git a/webrtc/src/stats/mod.rs b/webrtc/src/stats/mod.rs deleted file mode 100644 index 928059e8b..000000000 --- a/webrtc/src/stats/mod.rs +++ /dev/null @@ -1,688 +0,0 @@ -use std::collections::HashMap; -use std::sync::Arc; -use std::time::SystemTime; - -use ice::agent::agent_stats::{CandidatePairStats, CandidateStats}; -use ice::agent::Agent; -use ice::candidate::{CandidatePairState, CandidateType}; -use ice::network_type::NetworkType; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use smol_str::SmolStr; -use stats_collector::StatsCollector; -use tokio::time::Instant; - -use crate::data_channel::data_channel_state::RTCDataChannelState; -use crate::data_channel::RTCDataChannel; -use crate::dtls_transport::dtls_fingerprint::RTCDtlsFingerprint; -use crate::peer_connection::certificate::RTCCertificate; -use crate::rtp_transceiver::rtp_codec::RTCRtpCodecParameters; -use crate::rtp_transceiver::{PayloadType, SSRC}; -use crate::sctp_transport::RTCSctpTransport; - -mod serialize; -pub mod stats_collector; - -#[derive(Debug, Serialize, Deserialize)] -pub enum RTCStatsType { - #[serde(rename = "candidate-pair")] - CandidatePair, - #[serde(rename = "certificate")] - Certificate, - #[serde(rename = "codec")] - Codec, - #[serde(rename = "csrc")] - CSRC, - #[serde(rename = "data-channel")] - DataChannel, - #[serde(rename = "inbound-rtp")] - InboundRTP, - #[serde(rename = "local-candidate")] - LocalCandidate, - #[serde(rename = "outbound-rtp")] - OutboundRTP, - #[serde(rename = "peer-connection")] - PeerConnection, - #[serde(rename = "receiver")] - Receiver, - #[serde(rename = "remote-candidate")] - RemoteCandidate, - #[serde(rename = "remote-inbound-rtp")] - RemoteInboundRTP, - #[serde(rename = "remote-outbound-rtp")] - RemoteOutboundRTP, - #[serde(rename = "sender")] - Sender, - #[serde(rename = "transport")] - Transport, -} - -pub enum SourceStatsType { - LocalCandidate(CandidateStats), - RemoteCandidate(CandidateStats), -} - -#[derive(Debug)] -pub enum StatsReportType { - CandidatePair(ICECandidatePairStats), - CertificateStats(CertificateStats), - Codec(CodecStats), - DataChannel(DataChannelStats), - LocalCandidate(ICECandidateStats), - PeerConnection(PeerConnectionStats), - RemoteCandidate(ICECandidateStats), - SCTPTransport(ICETransportStats), - Transport(ICETransportStats), - InboundRTP(InboundRTPStats), - OutboundRTP(OutboundRTPStats), - RemoteInboundRTP(RemoteInboundRTPStats), - RemoteOutboundRTP(RemoteOutboundRTPStats), -} - -impl From for StatsReportType { - fn from(stats: SourceStatsType) -> Self { - match stats { - 
SourceStatsType::LocalCandidate(stats) => StatsReportType::LocalCandidate( - ICECandidateStats::new(stats, RTCStatsType::LocalCandidate), - ), - SourceStatsType::RemoteCandidate(stats) => StatsReportType::RemoteCandidate( - ICECandidateStats::new(stats, RTCStatsType::RemoteCandidate), - ), - } - } -} - -impl Serialize for StatsReportType { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - match self { - StatsReportType::CandidatePair(stats) => stats.serialize(serializer), - StatsReportType::CertificateStats(stats) => stats.serialize(serializer), - StatsReportType::Codec(stats) => stats.serialize(serializer), - StatsReportType::DataChannel(stats) => stats.serialize(serializer), - StatsReportType::LocalCandidate(stats) => stats.serialize(serializer), - StatsReportType::PeerConnection(stats) => stats.serialize(serializer), - StatsReportType::RemoteCandidate(stats) => stats.serialize(serializer), - StatsReportType::SCTPTransport(stats) => stats.serialize(serializer), - StatsReportType::Transport(stats) => stats.serialize(serializer), - StatsReportType::InboundRTP(stats) => stats.serialize(serializer), - StatsReportType::OutboundRTP(stats) => stats.serialize(serializer), - StatsReportType::RemoteInboundRTP(stats) => stats.serialize(serializer), - StatsReportType::RemoteOutboundRTP(stats) => stats.serialize(serializer), - } - } -} - -impl<'de> Deserialize<'de> for StatsReportType { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let value = serde_json::Value::deserialize(deserializer)?; - let type_field = value - .get("type") - .ok_or_else(|| serde::de::Error::missing_field("type"))?; - let rtc_type: RTCStatsType = serde_json::from_value(type_field.clone()).map_err(|e| { - serde::de::Error::custom(format!( - "failed to deserialize RTCStatsType from the `type` field ({}): {}", - type_field, e - )) - })?; - - match rtc_type { - RTCStatsType::CandidatePair => { - let stats = serde_json::from_value(value).map_err(serde::de::Error::custom)?; - Ok(StatsReportType::CandidatePair(stats)) - } - RTCStatsType::Certificate => { - let stats = serde_json::from_value(value).map_err(serde::de::Error::custom)?; - Ok(StatsReportType::CertificateStats(stats)) - } - RTCStatsType::Codec => { - let stats = serde_json::from_value(value).map_err(serde::de::Error::custom)?; - Ok(StatsReportType::Codec(stats)) - } - RTCStatsType::CSRC => { - todo!() - } - RTCStatsType::DataChannel => { - let stats = serde_json::from_value(value).map_err(serde::de::Error::custom)?; - Ok(StatsReportType::DataChannel(stats)) - } - RTCStatsType::InboundRTP => { - let stats = serde_json::from_value(value).map_err(serde::de::Error::custom)?; - Ok(StatsReportType::InboundRTP(stats)) - } - RTCStatsType::LocalCandidate => { - let stats = serde_json::from_value(value).map_err(serde::de::Error::custom)?; - Ok(StatsReportType::LocalCandidate(stats)) - } - RTCStatsType::OutboundRTP => { - let stats = serde_json::from_value(value).map_err(serde::de::Error::custom)?; - Ok(StatsReportType::OutboundRTP(stats)) - } - RTCStatsType::PeerConnection => { - let stats = serde_json::from_value(value).map_err(serde::de::Error::custom)?; - Ok(StatsReportType::PeerConnection(stats)) - } - RTCStatsType::Receiver => { - todo!() - } - RTCStatsType::RemoteCandidate => { - let stats = serde_json::from_value(value).map_err(serde::de::Error::custom)?; - Ok(StatsReportType::RemoteCandidate(stats)) - } - RTCStatsType::RemoteInboundRTP => { - let stats = 
serde_json::from_value(value).map_err(serde::de::Error::custom)?; - Ok(StatsReportType::RemoteInboundRTP(stats)) - } - RTCStatsType::RemoteOutboundRTP => { - let stats = serde_json::from_value(value).map_err(serde::de::Error::custom)?; - Ok(StatsReportType::RemoteOutboundRTP(stats)) - } - RTCStatsType::Sender => { - todo!() - } - RTCStatsType::Transport => { - let stats = serde_json::from_value(value).map_err(serde::de::Error::custom)?; - Ok(StatsReportType::Transport(stats)) - } - } - } -} - -#[derive(Debug)] -pub struct StatsReport { - pub reports: HashMap, -} - -impl From for StatsReport { - fn from(collector: StatsCollector) -> Self { - StatsReport { - reports: collector.into_reports(), - } - } -} - -impl Serialize for StatsReport { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - self.reports.serialize(serializer) - } -} - -impl<'de> Deserialize<'de> for StatsReport { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let value = serde_json::Value::deserialize(deserializer)?; - let root = value - .as_object() - .ok_or(serde::de::Error::custom("root object missing"))?; - - let mut reports = HashMap::new(); - for (key, value) in root { - let report = serde_json::from_value(value.clone()).map_err(|e| { - serde::de::Error::custom(format!( - "failed to deserialize `StatsReportType` from key={}, value={}: {}", - key, value, e - )) - })?; - reports.insert(key.clone(), report); - } - Ok(Self { reports }) - } -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ICECandidatePairStats { - // RTCStats - #[serde(with = "serialize::instant_to_epoch_seconds")] - pub timestamp: Instant, - #[serde(rename = "type")] - pub stats_type: RTCStatsType, - pub id: String, - - // RTCIceCandidatePairStats - // TODO: Add `transportId` - pub local_candidate_id: String, - pub remote_candidate_id: String, - pub state: CandidatePairState, - pub nominated: bool, - pub packets_sent: u32, - pub packets_received: u32, - pub bytes_sent: u64, - pub bytes_received: u64, - #[serde(with = "serialize::instant_to_epoch_seconds")] - pub last_packet_sent_timestamp: Instant, - #[serde(with = "serialize::instant_to_epoch_seconds")] - pub last_packet_received_timestamp: Instant, - pub total_round_trip_time: f64, - pub current_round_trip_time: f64, - pub available_outgoing_bitrate: f64, - pub available_incoming_bitrate: f64, - pub requests_received: u64, - pub requests_sent: u64, - pub responses_received: u64, - pub responses_sent: u64, - pub consent_requests_sent: u64, - // TODO: Add `packetsDiscardedOnSend` - // TODO: Add `bytesDiscardedOnSend` - - // Non-canon - pub circuit_breaker_trigger_count: u32, - #[serde(with = "serialize::instant_to_epoch_seconds")] - pub consent_expired_timestamp: Instant, - #[serde(with = "serialize::instant_to_epoch_seconds")] - pub first_request_timestamp: Instant, - #[serde(with = "serialize::instant_to_epoch_seconds")] - pub last_request_timestamp: Instant, - pub retransmissions_sent: u64, -} - -impl From for ICECandidatePairStats { - fn from(stats: CandidatePairStats) -> Self { - ICECandidatePairStats { - available_incoming_bitrate: stats.available_incoming_bitrate, - available_outgoing_bitrate: stats.available_outgoing_bitrate, - bytes_received: stats.bytes_received, - bytes_sent: stats.bytes_sent, - circuit_breaker_trigger_count: stats.circuit_breaker_trigger_count, - consent_expired_timestamp: stats.consent_expired_timestamp, - consent_requests_sent: stats.consent_requests_sent, - 
current_round_trip_time: stats.current_round_trip_time, - first_request_timestamp: stats.first_request_timestamp, - id: format!("{}-{}", stats.local_candidate_id, stats.remote_candidate_id), - last_packet_received_timestamp: stats.last_packet_received_timestamp, - last_packet_sent_timestamp: stats.last_packet_sent_timestamp, - last_request_timestamp: stats.last_request_timestamp, - local_candidate_id: stats.local_candidate_id, - nominated: stats.nominated, - packets_received: stats.packets_received, - packets_sent: stats.packets_sent, - remote_candidate_id: stats.remote_candidate_id, - requests_received: stats.requests_received, - requests_sent: stats.requests_sent, - responses_received: stats.responses_received, - responses_sent: stats.responses_sent, - retransmissions_sent: stats.retransmissions_sent, - state: stats.state, - stats_type: RTCStatsType::CandidatePair, - timestamp: stats.timestamp, - total_round_trip_time: stats.total_round_trip_time, - } - } -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ICECandidateStats { - // RTCStats - #[serde(with = "serialize::instant_to_epoch_seconds")] - pub timestamp: Instant, - #[serde(rename = "type")] - pub stats_type: RTCStatsType, - pub id: String, - - // RTCIceCandidateStats - pub candidate_type: CandidateType, - pub deleted: bool, - pub ip: String, - pub network_type: NetworkType, - pub port: u16, - pub priority: u32, - pub relay_protocol: String, - pub url: String, -} - -impl ICECandidateStats { - fn new(stats: CandidateStats, stats_type: RTCStatsType) -> Self { - ICECandidateStats { - candidate_type: stats.candidate_type, - deleted: stats.deleted, - id: stats.id, - ip: stats.ip, - network_type: stats.network_type, - port: stats.port, - priority: stats.priority, - relay_protocol: stats.relay_protocol, - stats_type, - timestamp: stats.timestamp, - url: stats.url, - } - } -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ICETransportStats { - // RTCStats - #[serde(with = "serialize::instant_to_epoch_seconds")] - pub timestamp: Instant, - #[serde(rename = "type")] - pub stats_type: RTCStatsType, - pub id: String, - - // Non-canon - pub bytes_received: usize, - pub bytes_sent: usize, -} - -impl ICETransportStats { - pub(crate) fn new(id: String, agent: Arc) -> Self { - ICETransportStats { - id, - bytes_received: agent.get_bytes_received(), - bytes_sent: agent.get_bytes_sent(), - stats_type: RTCStatsType::Transport, - timestamp: Instant::now(), - } - } -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct CertificateStats { - // RTCStats - #[serde(with = "serialize::instant_to_epoch_seconds")] - pub timestamp: Instant, - #[serde(rename = "type")] - pub stats_type: RTCStatsType, - pub id: String, - - // RTCCertificateStats - pub fingerprint: String, - pub fingerprint_algorithm: String, - // TODO: Add `base64Certificate` and `issuerCertificateId`. 
-} - -impl CertificateStats { - pub(crate) fn new(cert: &RTCCertificate, fingerprint: RTCDtlsFingerprint) -> Self { - CertificateStats { - // TODO: base64_certificate - fingerprint: fingerprint.value, - fingerprint_algorithm: fingerprint.algorithm, - id: cert.stats_id.clone(), - // TODO: issuer_certificate_id - stats_type: RTCStatsType::Certificate, - timestamp: Instant::now(), - } - } -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct CodecStats { - // RTCStats - #[serde(with = "serialize::instant_to_epoch_seconds")] - pub timestamp: Instant, - #[serde(rename = "type")] - pub stats_type: RTCStatsType, - pub id: String, - - // RTCCodecStats - pub payload_type: PayloadType, - pub mime_type: String, - pub channels: u16, - pub clock_rate: u32, - pub sdp_fmtp_line: String, - // TODO: Add `transportId` -} - -impl From<&RTCRtpCodecParameters> for CodecStats { - fn from(codec: &RTCRtpCodecParameters) -> Self { - CodecStats { - channels: codec.capability.channels, - clock_rate: codec.capability.clock_rate, - id: codec.stats_id.clone(), - mime_type: codec.capability.mime_type.clone(), - payload_type: codec.payload_type, - sdp_fmtp_line: codec.capability.sdp_fmtp_line.clone(), - stats_type: RTCStatsType::Codec, - timestamp: Instant::now(), - } - } -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct DataChannelStats { - // RTCStats - #[serde(with = "serialize::instant_to_epoch_seconds")] - pub timestamp: Instant, - #[serde(rename = "type")] - pub stats_type: RTCStatsType, - pub id: String, - - // RTCDataChannelStats - pub bytes_received: usize, - pub bytes_sent: usize, - pub data_channel_identifier: u16, - pub label: String, - pub messages_received: usize, - pub messages_sent: usize, - pub protocol: String, - pub state: RTCDataChannelState, -} - -impl DataChannelStats { - pub(crate) async fn from(data_channel: &RTCDataChannel) -> Self { - let state = data_channel.ready_state(); - - let mut bytes_received = 0; - let mut bytes_sent = 0; - let mut messages_received = 0; - let mut messages_sent = 0; - - let lock = data_channel.data_channel.lock().await; - - if let Some(internal) = &*lock { - bytes_received = internal.bytes_received(); - bytes_sent = internal.bytes_sent(); - messages_received = internal.messages_received(); - messages_sent = internal.messages_sent(); - } - - Self { - bytes_received, - bytes_sent, - data_channel_identifier: data_channel.id(), // TODO: "The value is initially null" - id: data_channel.stats_id.clone(), - label: data_channel.label.clone(), - messages_received, - messages_sent, - protocol: data_channel.protocol.clone(), - state, - stats_type: RTCStatsType::DataChannel, - timestamp: Instant::now(), - } - } -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct PeerConnectionStats { - // RTCStats - #[serde(with = "serialize::instant_to_epoch_seconds")] - pub timestamp: Instant, - #[serde(rename = "type")] - pub stats_type: RTCStatsType, - pub id: String, - - // RTCPeerConnectionStats - pub data_channels_closed: u32, - pub data_channels_opened: u32, - - // Non-canon - pub data_channels_accepted: u32, - pub data_channels_requested: u32, -} - -impl PeerConnectionStats { - pub fn new(transport: &RTCSctpTransport, stats_id: String, data_channels_closed: u32) -> Self { - PeerConnectionStats { - data_channels_accepted: transport.data_channels_accepted(), - data_channels_closed, - data_channels_opened: transport.data_channels_opened(), - data_channels_requested: 
transport.data_channels_requested(), - id: stats_id, - stats_type: RTCStatsType::PeerConnection, - timestamp: Instant::now(), - } - } -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct InboundRTPStats { - // RTCStats - #[serde(with = "serialize::instant_to_epoch_seconds")] - pub timestamp: Instant, - #[serde(rename = "type")] - pub stats_type: RTCStatsType, - pub id: String, - - // RTCRtpStreamStats - pub ssrc: SSRC, - pub kind: String, // Either "video" or "audio" - // TODO: Add transportId - // TODO: Add codecId - - // RTCReceivedRtpStreamStats - pub packets_received: u64, - // TODO: packetsLost - // TODO: jitter(maybe, might be uattainable for the same reason as `framesDropped`) - // NB: `framesDropped` can't be produced since we aren't decoding, might be worth introducing a - // way for consumers to control this in the future. - - // RTCInboundRtpStreamStats - pub track_identifier: String, - pub mid: SmolStr, - // TODO: `remoteId` - // NB: `framesDecoded`, `frameWidth`, frameHeight`, `framesPerSecond`, `qpSum`, - // `totalDecodeTime`, `totalInterFrameDelay`, and `totalSquaredInterFrameDelay` are all decoder - // specific values and can't be produced since we aren't decoding. - pub last_packet_received_timestamp: Option, - pub header_bytes_received: u64, - // TODO: `packetsDiscarded`. This value only makes sense if we have jitter buffer, which we - // cannot assume. - // TODO: `fecPacketsReceived`, `fecPacketsDiscarded` - pub bytes_received: u64, - pub nack_count: u64, - pub fir_count: Option, - pub pli_count: Option, - // NB: `totalProcessingDelay`, `estimatedPlayoutTimestamp`, `jitterBufferDelay`, - // `jitterBufferTargetDelay`, `jitterBufferEmittedCount`, `jitterBufferMinimumDelay`, - // `totalSamplesReceived`, `concealedSamples`, `silentConcealedSamples`, `concealmentEvents`, - // `insertedSamplesForDeceleration`, `removedSamplesForAcceleration`, `audioLevel`, - // `totalAudioEneregy`, `totalSampleDuration`, `framesReceived, and `decoderImplementation` are - // all decoder specific and can't be produced since we aren't decoding. -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct OutboundRTPStats { - // RTCStats - #[serde(with = "serialize::instant_to_epoch_seconds")] - pub timestamp: Instant, - #[serde(rename = "type")] - pub stats_type: RTCStatsType, - pub id: String, - - // RTCRtpStreamStats - pub ssrc: SSRC, - pub kind: String, // Either "video" or "audio" - // TODO: Add transportId - // TODO: Add codecId - - // RTCSentRtpStreamStats - pub packets_sent: u64, - pub bytes_sent: u64, - - // RTCOutboundRtpStreamStats - // NB: non-canon in browsers this is available via `RTCMediaSourceStats` which we are unlikely to implement - pub track_identifier: String, - pub mid: SmolStr, - // TODO: `mediaSourceId` and `remoteId` - pub rid: Option, - pub header_bytes_sent: u64, - // TODO: `retransmittedPacketsSent` and `retransmittedPacketsSent` - // NB: `targetBitrate`, `totalEncodedBytesTarget`, `frameWidth` `frameHeight`, `framesPerSecond`, `framesSent`, - // `hugeFramesSent`, `framesEncoded`, `keyFramesEncoded`, `qpSum`, and `totalEncodeTime` are - // all encoder specific and can't be produced snce we aren't encoding. - // TODO: `totalPacketSendDelay` time from `TrackLocalWriter::write_rtp` to being written to - // socket. - - // NB: `qualityLimitationReason`, `qualityLimitationDurations`, and `qualityLimitationResolutionChanges` are all - // encoder specific and can't be produced since we aren't encoding. 
- pub nack_count: u64, - pub fir_count: Option, - pub pli_count: Option, - // NB: `encoderImplementation` is encoder specific and can't be produced since we aren't - // encoding. -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct RemoteInboundRTPStats { - // RTCStats - #[serde(with = "serialize::instant_to_epoch_seconds")] - pub timestamp: Instant, - #[serde(rename = "type")] - pub stats_type: RTCStatsType, - pub id: String, - - // RTCRtpStreamStats - pub ssrc: SSRC, - pub kind: String, // Either "video" or "audio" - // TODO: Add transportId - // TODO: Add codecId - - // RTCReceivedRtpStreamStats - pub packets_received: u64, - pub packets_lost: i64, - // TODO: jitter(maybe, might be uattainable for the same reason as `framesDropped`) - // NB: `framesDropped` can't be produced since we aren't decoding, might be worth introducing a - // way for consumers to control this in the future. - - // RTCRemoteInboundRtpStreamStats - pub local_id: String, - pub round_trip_time: Option, - pub total_round_trip_time: f64, - pub fraction_lost: f64, - pub round_trip_time_measurements: u64, -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct RemoteOutboundRTPStats { - // RTCStats - #[serde(with = "serialize::instant_to_epoch_seconds")] - pub timestamp: Instant, - #[serde(rename = "type")] - pub stats_type: RTCStatsType, - pub id: String, - - // RTCRtpStreamStats - pub ssrc: SSRC, - pub kind: String, // Either "video" or "audio" - // TODO: Add transportId - // TODO: Add codecId - - // RTCSentRtpStreamStats - pub packets_sent: u64, - pub bytes_sent: u64, - - // RTCRemoteOutboundRtpStreamStats - pub local_id: String, - // TODO: `remote_timestamp` - pub round_trip_time: Option, - pub reports_sent: u64, - pub total_round_trip_time: f64, - pub round_trip_time_measurements: u64, -} diff --git a/webrtc/src/stats/serialize.rs b/webrtc/src/stats/serialize.rs deleted file mode 100644 index b1abd394b..000000000 --- a/webrtc/src/stats/serialize.rs +++ /dev/null @@ -1,50 +0,0 @@ -/// Serializes a `tokio::time::Instant` to an approximation of epoch time in the form -/// of an `f64` where the integer portion is seconds and the decimal portion is milliseconds. -/// For instance, `Monday, May 30, 2022 10:45:26.456 PM UTC` converts to `1653950726.456`. -/// -/// Note that an `Instant` is not connected to real world time, so this conversion is -/// approximate. 
-pub mod instant_to_epoch_seconds { - use serde::{Deserialize, Deserializer, Serialize, Serializer}; - use std::time::{Duration, SystemTime, UNIX_EPOCH}; - use tokio::time::Instant; - - pub fn serialize(instant: &Instant, serializer: S) -> Result - where - S: Serializer, - { - let system_now = SystemTime::now(); - let instant_now = Instant::now(); - let approx = system_now - (instant_now - *instant); - let epoch = approx - .duration_since(UNIX_EPOCH) - .expect("Time went backwards"); - - let epoch_ms = epoch.as_millis() as f64 / 1000.0; - - epoch_ms.serialize(serializer) - } - - pub fn deserialize<'de, D>(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let epoch_seconds: f64 = Deserialize::deserialize(deserializer)?; - - let since_epoch = Duration::from_secs_f64(epoch_seconds); - - let system_now = SystemTime::now(); - let instant_now = Instant::now(); - - let deserialized_system_time = UNIX_EPOCH + since_epoch; - - let adjustment = match deserialized_system_time.duration_since(system_now) { - Ok(duration) => -duration.as_secs_f64(), - Err(e) => e.duration().as_secs_f64(), - }; - - let adjusted_instant = instant_now + Duration::from_secs_f64(adjustment); - - Ok(adjusted_instant) - } -} diff --git a/webrtc/src/stats/stats_collector.rs b/webrtc/src/stats/stats_collector.rs deleted file mode 100644 index e0228dfe9..000000000 --- a/webrtc/src/stats/stats_collector.rs +++ /dev/null @@ -1,32 +0,0 @@ -use std::collections::HashMap; - -use util::sync::Mutex; - -use super::StatsReportType; - -#[derive(Debug, Default)] -pub struct StatsCollector { - pub(crate) reports: Mutex>, -} - -impl StatsCollector { - pub(crate) fn new() -> Self { - StatsCollector { - ..Default::default() - } - } - - pub(crate) fn insert(&self, id: String, stats: StatsReportType) { - let mut reports = self.reports.lock(); - reports.insert(id, stats); - } - - pub(crate) fn merge(&self, stats: HashMap) { - let mut reports = self.reports.lock(); - reports.extend(stats) - } - - pub(crate) fn into_reports(self) -> HashMap { - self.reports.into_inner() - } -} diff --git a/webrtc/src/track/mod.rs b/webrtc/src/track/mod.rs deleted file mode 100644 index 8c0a2d5be..000000000 --- a/webrtc/src/track/mod.rs +++ /dev/null @@ -1,29 +0,0 @@ -pub mod track_local; -pub mod track_remote; - -use std::sync::Arc; - -use interceptor::stream_info::StreamInfo; -use interceptor::{RTCPReader, RTPReader}; -use track_remote::*; - -pub(crate) const RTP_OUTBOUND_MTU: usize = 1200; -pub(crate) const RTP_PAYLOAD_TYPE_BITMASK: u8 = 0x7F; - -#[derive(Clone)] -pub(crate) struct TrackStream { - pub(crate) stream_info: Option, - pub(crate) rtp_read_stream: Option>, - pub(crate) rtp_interceptor: Option>, - pub(crate) rtcp_read_stream: Option>, - pub(crate) rtcp_interceptor: Option>, -} - -/// TrackStreams maintains a mapping of RTP/RTCP streams to a specific track -/// a RTPReceiver may contain multiple streams if we are dealing with Simulcast -#[derive(Clone)] -pub(crate) struct TrackStreams { - pub(crate) track: Arc, - pub(crate) stream: TrackStream, - pub(crate) repair_stream: TrackStream, -} diff --git a/webrtc/src/track/track_local/mod.rs b/webrtc/src/track/track_local/mod.rs deleted file mode 100644 index fcfd8bda7..000000000 --- a/webrtc/src/track/track_local/mod.rs +++ /dev/null @@ -1,170 +0,0 @@ -#[cfg(test)] -mod track_local_static_test; - -pub mod track_local_static_rtp; -pub mod track_local_static_sample; - -use std::any::Any; -use std::fmt; -use std::sync::atomic::Ordering; -use std::sync::Arc; - -use async_trait::async_trait; -use 
interceptor::{Attributes, RTPWriter}; -use portable_atomic::AtomicBool; -use smol_str::SmolStr; -use tokio::sync::Mutex; -use util::Unmarshal; - -use crate::error::{Error, Result}; -use crate::rtp_transceiver::rtp_codec::*; -use crate::rtp_transceiver::*; - -/// TrackLocalWriter is the Writer for outbound RTP Packets -#[async_trait] -pub trait TrackLocalWriter: fmt::Debug { - /// write_rtp encrypts a RTP packet and writes to the connection - async fn write_rtp(&self, p: &rtp::packet::Packet) -> Result; - - /// write encrypts and writes a full RTP packet - async fn write(&self, b: &[u8]) -> Result; -} - -/// TrackLocalContext is the Context passed when a TrackLocal has been Binded/Unbinded from a PeerConnection, and used -/// in Interceptors. -#[derive(Default, Debug, Clone)] -pub struct TrackLocalContext { - pub(crate) id: String, - pub(crate) params: RTCRtpParameters, - pub(crate) ssrc: SSRC, - pub(crate) write_stream: Option>, - pub(crate) paused: Arc, - pub(crate) mid: Option, -} - -impl TrackLocalContext { - /// codec_parameters returns the negotiated RTPCodecParameters. These are the codecs supported by both - /// PeerConnections and the SSRC/PayloadTypes - pub fn codec_parameters(&self) -> &[RTCRtpCodecParameters] { - &self.params.codecs - } - - /// header_extensions returns the negotiated RTPHeaderExtensionParameters. These are the header extensions supported by - /// both PeerConnections and the SSRC/PayloadTypes - pub fn header_extensions(&self) -> &[RTCRtpHeaderExtensionParameters] { - &self.params.header_extensions - } - - /// ssrc requires the negotiated SSRC of this track - /// This track may have multiple if RTX is enabled - pub fn ssrc(&self) -> SSRC { - self.ssrc - } - - /// write_stream returns the write_stream for this TrackLocal. The implementer writes the outbound - /// media packets to it - pub fn write_stream(&self) -> Option> { - self.write_stream.clone() - } - - /// id is a unique identifier that is used for both bind/unbind - pub fn id(&self) -> String { - self.id.clone() - } -} -/// TrackLocal is an interface that controls how the user can send media -/// The user can provide their own TrackLocal implementations, or use -/// the implementations in pkg/media -#[async_trait] -pub trait TrackLocal { - /// bind should implement the way how the media data flows from the Track to the PeerConnection - /// This will be called internally after signaling is complete and the list of available - /// codecs has been determined - async fn bind(&self, t: &TrackLocalContext) -> Result; - - /// unbind should implement the teardown logic when the track is no longer needed. This happens - /// because a track has been stopped. - async fn unbind(&self, t: &TrackLocalContext) -> Result<()>; - - /// id is the unique identifier for this Track. This should be unique for the - /// stream, but doesn't have to globally unique. A common example would be 'audio' or 'video' - /// and stream_id would be 'desktop' or 'webcam' - fn id(&self) -> &str; - - /// RID is the RTP Stream ID for this track. - fn rid(&self) -> Option<&str>; - - /// stream_id is the group this track belongs too. 
This must be unique - fn stream_id(&self) -> &str; - - /// kind controls if this TrackLocal is audio or video - fn kind(&self) -> RTPCodecType; - - fn as_any(&self) -> &dyn Any; -} - -/// TrackBinding is a single bind for a Track -/// Bind can be called multiple times, this stores the -/// result for a single bind call so that it can be used when writing -#[derive(Default, Debug)] -pub(crate) struct TrackBinding { - id: String, - ssrc: SSRC, - payload_type: PayloadType, - params: RTCRtpParameters, - write_stream: Option>, - sender_paused: Arc, - hdr_ext_ids: Vec, -} - -impl TrackBinding { - pub fn is_sender_paused(&self) -> bool { - self.sender_paused.load(Ordering::SeqCst) - } -} - -pub(crate) struct InterceptorToTrackLocalWriter { - pub(crate) interceptor_rtp_writer: Mutex>>, - sender_paused: Arc, -} - -impl InterceptorToTrackLocalWriter { - pub(crate) fn new(paused: Arc) -> Self { - InterceptorToTrackLocalWriter { - interceptor_rtp_writer: Mutex::new(None), - sender_paused: paused, - } - } - - fn is_sender_paused(&self) -> bool { - self.sender_paused.load(Ordering::SeqCst) - } -} - -impl std::fmt::Debug for InterceptorToTrackLocalWriter { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("InterceptorToTrackLocalWriter").finish() - } -} - -#[async_trait] -impl TrackLocalWriter for InterceptorToTrackLocalWriter { - async fn write_rtp(&self, pkt: &rtp::packet::Packet) -> Result { - if self.is_sender_paused() { - return Ok(0); - } - - let interceptor_rtp_writer = self.interceptor_rtp_writer.lock().await; - if let Some(writer) = &*interceptor_rtp_writer { - let a = Attributes::new(); - Ok(writer.write(pkt, &a).await?) - } else { - Ok(0) - } - } - - async fn write(&self, mut b: &[u8]) -> Result { - let pkt = rtp::packet::Packet::unmarshal(&mut b)?; - self.write_rtp(&pkt).await - } -} diff --git a/webrtc/src/track/track_local/track_local_static_rtp.rs b/webrtc/src/track/track_local/track_local_static_rtp.rs deleted file mode 100644 index 18cdaed7e..000000000 --- a/webrtc/src/track/track_local/track_local_static_rtp.rs +++ /dev/null @@ -1,295 +0,0 @@ -use std::collections::HashMap; - -use bytes::{Bytes, BytesMut}; -use tokio::sync::Mutex; -use util::{Marshal, MarshalSize}; - -use super::*; -use crate::error::flatten_errs; - -/// TrackLocalStaticRTP is a TrackLocal that has a pre-set codec and accepts RTP Packets. -/// If you wish to send a media.Sample use TrackLocalStaticSample -#[derive(Debug)] -pub struct TrackLocalStaticRTP { - pub(crate) bindings: Mutex>>, - codec: RTCRtpCodecCapability, - id: String, - rid: Option, - stream_id: String, -} - -impl TrackLocalStaticRTP { - /// returns a TrackLocalStaticRTP without rid. - pub fn new(codec: RTCRtpCodecCapability, id: String, stream_id: String) -> Self { - TrackLocalStaticRTP { - codec, - bindings: Mutex::new(vec![]), - id, - rid: None, - stream_id, - } - } - - /// returns a TrackLocalStaticRTP with rid. 
- pub fn new_with_rid( - codec: RTCRtpCodecCapability, - id: String, - rid: String, - stream_id: String, - ) -> Self { - TrackLocalStaticRTP { - codec, - bindings: Mutex::new(vec![]), - id, - rid: Some(rid), - stream_id, - } - } - - /// codec gets the Codec of the track - pub fn codec(&self) -> RTCRtpCodecCapability { - self.codec.clone() - } - - pub async fn any_binding_paused(&self) -> bool { - let bindings = self.bindings.lock().await; - bindings - .iter() - .any(|b| b.sender_paused.load(Ordering::SeqCst)) - } - - pub async fn all_binding_paused(&self) -> bool { - let bindings = self.bindings.lock().await; - bindings - .iter() - .all(|b| b.sender_paused.load(Ordering::SeqCst)) - } - - /// write_rtp_with_extensions writes a RTP Packet to the TrackLocalStaticRTP - /// If one PeerConnection fails the packets will still be sent to - /// all PeerConnections. The error message will contain the ID of the failed - /// PeerConnections so you can remove them - /// - /// If the RTCRtpSender direction is such that no packets should be sent, any call to this - /// function are blocked internally. Care must be taken to not increase the sequence number - /// while the sender is paused. While the actual _sending_ is blocked, the receiver will - /// miss out when the sequence number "rolls over", which in turn will break SRTP. - /// - /// Extensions that are already configured on the packet are overwritten by extensions in - /// `extensions`. - pub async fn write_rtp_with_extensions( - &self, - p: &rtp::packet::Packet, - extensions: &[rtp::extension::HeaderExtension], - ) -> Result { - let mut n = 0; - let mut write_errs = vec![]; - let mut pkt = p.clone(); - - let bindings = { - let bindings = self.bindings.lock().await; - bindings.clone() - }; - // Prepare the extensions data - let extension_data: HashMap<_, _> = extensions - .iter() - .flat_map(|extension| { - let buf = { - let mut buf = BytesMut::with_capacity(extension.marshal_size()); - buf.resize(extension.marshal_size(), 0); - if let Err(err) = extension.marshal_to(&mut buf) { - write_errs.push(Error::Util(err)); - return None; - } - - buf.freeze() - }; - - Some((extension.uri(), buf)) - }) - .collect(); - - for b in bindings.into_iter() { - if b.is_sender_paused() { - // See caveat in function doc. - continue; - } - pkt.header.ssrc = b.ssrc; - pkt.header.payload_type = b.payload_type; - - for ext in b.hdr_ext_ids.iter() { - let payload = ext.payload.to_owned(); - if let Err(err) = pkt.header.set_extension(ext.id, payload) { - write_errs.push(Error::Rtp(err)); - } - } - - for (uri, data) in extension_data.iter() { - if let Some(id) = b - .params - .header_extensions - .iter() - .find(|ext| &ext.uri == uri) - .map(|ext| ext.id) - { - if let Err(err) = pkt.header.set_extension(id as u8, data.clone()) { - write_errs.push(Error::Rtp(err)); - continue; - } - } - } - - if let Some(write_stream) = &b.write_stream { - match write_stream.write_rtp(&pkt).await { - Ok(m) => { - n += m; - } - Err(err) => { - write_errs.push(err); - } - } - } else { - write_errs.push(Error::new("track binding has none write_stream".to_owned())); - } - } - - flatten_errs(write_errs)?; - Ok(n) - } -} - -#[async_trait] -impl TrackLocal for TrackLocalStaticRTP { - /// bind is called by the PeerConnection after negotiation is complete - /// This asserts that the code requested is supported by the remote peer. 
- /// If so it setups all the state (SSRC and PayloadType) to have a call - async fn bind(&self, t: &TrackLocalContext) -> Result { - let parameters = RTCRtpCodecParameters { - capability: self.codec.clone(), - ..Default::default() - }; - let mut hdr_ext_ids = vec![]; - if let Some(id) = t - .header_extensions() - .iter() - .find(|e| e.uri == ::sdp::extmap::SDES_MID_URI) - .map(|e| e.id as u8) - { - if let Some(payload) = t - .mid - .as_ref() - .map(|mid| Bytes::copy_from_slice(mid.as_bytes())) - { - hdr_ext_ids.push(rtp::header::Extension { id, payload }); - } - } - - if let Some(id) = t - .header_extensions() - .iter() - .find(|e| e.uri == ::sdp::extmap::SDES_RTP_STREAM_ID_URI) - .map(|e| e.id as u8) - { - if let Some(payload) = self.rid().map(|rid| rid.to_owned().into()) { - hdr_ext_ids.push(rtp::header::Extension { id, payload }); - } - } - - let (codec, match_type) = codec_parameters_fuzzy_search(¶meters, t.codec_parameters()); - if match_type != CodecMatch::None { - { - let mut bindings = self.bindings.lock().await; - bindings.push(Arc::new(TrackBinding { - id: t.id(), - ssrc: t.ssrc(), - payload_type: codec.payload_type, - params: t.params.clone(), - write_stream: t.write_stream(), - sender_paused: t.paused.clone(), - hdr_ext_ids, - })); - } - - Ok(codec) - } else { - Err(Error::ErrUnsupportedCodec) - } - } - - /// unbind implements the teardown logic when the track is no longer needed. This happens - /// because a track has been stopped. - async fn unbind(&self, t: &TrackLocalContext) -> Result<()> { - let mut bindings = self.bindings.lock().await; - let mut idx = None; - for (index, binding) in bindings.iter().enumerate() { - if binding.id == t.id() { - idx = Some(index); - break; - } - } - if let Some(index) = idx { - bindings.remove(index); - Ok(()) - } else { - Err(Error::ErrUnbindFailed) - } - } - - /// id is the unique identifier for this Track. This should be unique for the - /// stream, but doesn't have to globally unique. A common example would be 'audio' or 'video' - /// and StreamID would be 'desktop' or 'webcam' - fn id(&self) -> &str { - self.id.as_str() - } - - /// RID is the RTP Stream ID for this track. - fn rid(&self) -> Option<&str> { - self.rid.as_deref() - } - - /// stream_id is the group this track belongs too. This must be unique - fn stream_id(&self) -> &str { - self.stream_id.as_str() - } - - /// kind controls if this TrackLocal is audio or video - fn kind(&self) -> RTPCodecType { - if self.codec.mime_type.starts_with("audio/") { - RTPCodecType::Audio - } else if self.codec.mime_type.starts_with("video/") { - RTPCodecType::Video - } else { - RTPCodecType::Unspecified - } - } - - fn as_any(&self) -> &dyn Any { - self - } -} - -#[async_trait] -impl TrackLocalWriter for TrackLocalStaticRTP { - /// write_rtp writes a RTP Packet to the TrackLocalStaticRTP - /// If one PeerConnection fails the packets will still be sent to - /// all PeerConnections. The error message will contain the ID of the failed - /// PeerConnections so you can remove them - /// - /// If the RTCRtpSender direction is such that no packets should be sent, any call to this - /// function are blocked internally. Care must be taken to not increase the sequence number - /// while the sender is paused. While the actual _sending_ is blocked, the receiver will - /// miss out when the sequence number "rolls over", which in turn will break SRTP. 
- async fn write_rtp(&self, p: &rtp::packet::Packet) -> Result { - self.write_rtp_with_extensions(p, &[]).await - } - - /// write writes a RTP Packet as a buffer to the TrackLocalStaticRTP - /// If one PeerConnection fails the packets will still be sent to - /// all PeerConnections. The error message will contain the ID of the failed - /// PeerConnections so you can remove them - async fn write(&self, mut b: &[u8]) -> Result { - let pkt = rtp::packet::Packet::unmarshal(&mut b)?; - self.write_rtp(&pkt).await?; - Ok(b.len()) - } -} diff --git a/webrtc/src/track/track_local/track_local_static_sample.rs b/webrtc/src/track/track_local/track_local_static_sample.rs deleted file mode 100644 index 2c4d5166e..000000000 --- a/webrtc/src/track/track_local/track_local_static_sample.rs +++ /dev/null @@ -1,324 +0,0 @@ -use log::warn; -use media::Sample; -use tokio::sync::Mutex; - -use super::track_local_static_rtp::TrackLocalStaticRTP; -use super::*; -use crate::error::flatten_errs; -use crate::track::RTP_OUTBOUND_MTU; - -#[derive(Debug, Clone)] -struct TrackLocalStaticSampleInternal { - packetizer: Option>, - sequencer: Option>, - clock_rate: f64, - did_warn_about_wonky_pause: bool, -} - -/// TrackLocalStaticSample is a TrackLocal that has a pre-set codec and accepts Samples. -/// If you wish to send a RTP Packet use TrackLocalStaticRTP -#[derive(Debug)] -pub struct TrackLocalStaticSample { - rtp_track: TrackLocalStaticRTP, - internal: Mutex, -} - -impl TrackLocalStaticSample { - /// returns a TrackLocalStaticSample without RID - pub fn new(codec: RTCRtpCodecCapability, id: String, stream_id: String) -> Self { - let rtp_track = TrackLocalStaticRTP::new(codec, id, stream_id); - - TrackLocalStaticSample { - rtp_track, - internal: Mutex::new(TrackLocalStaticSampleInternal { - packetizer: None, - sequencer: None, - clock_rate: 0.0f64, - did_warn_about_wonky_pause: false, - }), - } - } - - /// returns a TrackLocalStaticSample with RID - pub fn new_with_rid( - codec: RTCRtpCodecCapability, - id: String, - rid: String, - stream_id: String, - ) -> Self { - let rtp_track = TrackLocalStaticRTP::new_with_rid(codec, id, rid, stream_id); - - TrackLocalStaticSample { - rtp_track, - internal: Mutex::new(TrackLocalStaticSampleInternal { - packetizer: None, - sequencer: None, - clock_rate: 0.0f64, - did_warn_about_wonky_pause: false, - }), - } - } - - /// codec gets the Codec of the track - pub fn codec(&self) -> RTCRtpCodecCapability { - self.rtp_track.codec() - } - - /// write_sample writes a Sample to the TrackLocalStaticSample - /// If one PeerConnection fails the packets will still be sent to - /// all PeerConnections. The error message will contain the ID of the failed - /// PeerConnections so you can remove them - pub async fn write_sample(&self, sample: &Sample) -> Result<()> { - self.write_sample_with_extensions(sample, &[]).await - } - - /// Write a sample with provided RTP extensions. - /// - /// Alternatively to this method [`TrackLocalStaticSample::sample_writer`] can be used instead. - /// - /// See [`TrackLocalStaticSample::write_sample`] for further details. 
- pub async fn write_sample_with_extensions( - &self, - sample: &Sample, - extensions: &[rtp::extension::HeaderExtension], - ) -> Result<()> { - let mut internal = self.internal.lock().await; - - if internal.packetizer.is_none() || internal.sequencer.is_none() { - return Ok(()); - } - - let (any_paused, all_paused) = ( - self.rtp_track.any_binding_paused().await, - self.rtp_track.all_binding_paused().await, - ); - - if all_paused { - // Abort already here to not increment sequence numbers. - return Ok(()); - } - - if any_paused { - // This is a problem state due to how this impl is structured. The sequencer will allocate - // one sequence number per RTP packet regardless of how many TrackBinding that will send - // the packet. I.e. we get the same sequence number per multiple SSRC, which is not good - // for SRTP, but that's how it works. - // - // SRTP has a further problem with regards to jumps in sequence number. Consider this: - // - // 1. Create track local - // 2. Bind track local to track 1. - // 3. Bind track local to track 2. - // 4. Pause track 1. - // 5. Keep sending... - // - // At this point, the track local will keep incrementing the sequence number, because we have - // one binding that is still active. However SRTP hmac verifying (tag), can only accept a - // relatively small jump in sequence numbers since it uses the ROC (i.e. how many times the - // sequence number has rolled over), which means if this pause state of one binding persists - // for a longer time, the track can never be resumed since the receiver would have missed - // the rollovers. - if !internal.did_warn_about_wonky_pause { - internal.did_warn_about_wonky_pause = true; - warn!("Detected multiple track bindings where only one was paused"); - } - } - - // skip packets by the number of previously dropped packets - if let Some(sequencer) = &internal.sequencer { - for _ in 0..sample.prev_dropped_packets { - sequencer.next_sequence_number(); - } - } - - let clock_rate = internal.clock_rate; - - let packets = if let Some(packetizer) = &mut internal.packetizer { - let samples = (sample.duration.as_secs_f64() * clock_rate) as u32; - if sample.prev_dropped_packets > 0 { - packetizer.skip_samples(samples * sample.prev_dropped_packets as u32); - } - packetizer.packetize(&sample.data, samples)? - } else { - vec![] - }; - - let mut write_errs = vec![]; - for p in packets { - if let Err(err) = self - .rtp_track - .write_rtp_with_extensions(&p, extensions) - .await - { - write_errs.push(err); - } - } - - flatten_errs(write_errs) - } - - /// Create a builder for writing samples with additional data. 
- /// - /// # Example - /// ```no_run - /// use rtp::extension::audio_level_extension::AudioLevelExtension; - /// use std::time::Duration; - /// use webrtc::api::media_engine::MIME_TYPE_VP8; - /// use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability; - /// use webrtc::track::track_local::track_local_static_sample::TrackLocalStaticSample; - /// - /// #[tokio::main] - /// async fn main() { - /// let track = TrackLocalStaticSample::new( - /// RTCRtpCodecCapability { - /// mime_type: MIME_TYPE_VP8.to_owned(), - /// ..Default::default() - /// }, - /// "video".to_owned(), - /// "webrtc-rs".to_owned(), - /// ); - /// let result = track - /// .sample_writer() - /// .with_audio_level(AudioLevelExtension { - /// level: 10, - /// voice: true, - /// }) - /// .write_sample(&media::Sample{ - /// data: bytes::Bytes::new(), - /// duration: Duration::from_secs(1), - /// ..Default::default() - /// }) - /// .await; - /// } - /// ``` - pub fn sample_writer(&self) -> SampleWriter<'_> { - SampleWriter::new(self) - } -} - -#[async_trait] -impl TrackLocal for TrackLocalStaticSample { - /// Bind is called by the PeerConnection after negotiation is complete - /// This asserts that the code requested is supported by the remote peer. - /// If so it setups all the state (SSRC and PayloadType) to have a call - async fn bind(&self, t: &TrackLocalContext) -> Result { - let codec = self.rtp_track.bind(t).await?; - - let mut internal = self.internal.lock().await; - - // We only need one packetizer - if internal.packetizer.is_some() { - return Ok(codec); - } - - let payloader = codec.capability.payloader_for_codec()?; - let sequencer: Box = - Box::new(rtp::sequence::new_random_sequencer()); - internal.packetizer = Some(Box::new(rtp::packetizer::new_packetizer( - RTP_OUTBOUND_MTU, - 0, // Value is handled when writing - 0, // Value is handled when writing - payloader, - sequencer.clone(), - codec.capability.clock_rate, - ))); - internal.sequencer = Some(sequencer); - internal.clock_rate = codec.capability.clock_rate as f64; - - Ok(codec) - } - - /// unbind implements the teardown logic when the track is no longer needed. This happens - /// because a track has been stopped. - async fn unbind(&self, t: &TrackLocalContext) -> Result<()> { - self.rtp_track.unbind(t).await - } - - /// id is the unique identifier for this Track. This should be unique for the - /// stream, but doesn't have to globally unique. A common example would be 'audio' or 'video' - /// and StreamID would be 'desktop' or 'webcam' - fn id(&self) -> &str { - self.rtp_track.id() - } - - /// RID is the RTP Stream ID for this track. - fn rid(&self) -> Option<&str> { - self.rtp_track.rid() - } - - /// stream_id is the group this track belongs too. This must be unique - fn stream_id(&self) -> &str { - self.rtp_track.stream_id() - } - - /// kind controls if this TrackLocal is audio or video - fn kind(&self) -> RTPCodecType { - self.rtp_track.kind() - } - - fn as_any(&self) -> &dyn Any { - self - } -} - -mod sample_writer { - use media::Sample; - use rtp::extension::audio_level_extension::AudioLevelExtension; - use rtp::extension::video_orientation_extension::VideoOrientationExtension; - use rtp::extension::HeaderExtension; - - use super::TrackLocalStaticSample; - use crate::error::Result; - - /// Helper for writing Samples via [`TrackLocalStaticSample`] that carry extra RTP data. - /// - /// Created via [`TrackLocalStaticSample::sample_writer`]. 
- pub struct SampleWriter<'track> { - track: &'track TrackLocalStaticSample, - extensions: Vec, - } - - impl<'track> SampleWriter<'track> { - pub(super) fn new(track: &'track TrackLocalStaticSample) -> Self { - Self { - track, - extensions: vec![], - } - } - - /// Add a RTP audio level extension to all packets written for the sample. - /// - /// This overwrites any previously configured audio level extension. - pub fn with_audio_level(self, ext: AudioLevelExtension) -> Self { - self.with_extension(HeaderExtension::AudioLevel(ext)) - } - - /// Add a RTP video orientation extension to all packets written for the sample. - /// - /// This overwrites any previously configured video orientation extension. - pub fn with_video_orientation(self, ext: VideoOrientationExtension) -> Self { - self.with_extension(HeaderExtension::VideoOrientation(ext)) - } - - /// Add any RTP extension to all packets written for the sample. - pub fn with_extension(mut self, ext: HeaderExtension) -> Self { - self.extensions.retain(|e| !e.is_same(&ext)); - - self.extensions.push(ext); - - self - } - - /// Write the sample to the track. - /// - /// Creates one or more RTP packets with any extensions specified for each packet and sends - /// them. - pub async fn write_sample(self, sample: &Sample) -> Result<()> { - self.track - .write_sample_with_extensions(sample, &self.extensions) - .await - } - } -} - -pub use sample_writer::SampleWriter; diff --git a/webrtc/src/track/track_local/track_local_static_test.rs b/webrtc/src/track/track_local/track_local_static_test.rs deleted file mode 100644 index b385d98ac..000000000 --- a/webrtc/src/track/track_local/track_local_static_test.rs +++ /dev/null @@ -1,434 +0,0 @@ -use std::sync::Arc; - -use bytes::Bytes; -use tokio::sync::{mpsc, Mutex}; - -use super::track_local_static_rtp::*; -use super::track_local_static_sample::*; -use super::*; -use crate::api::media_engine::{MediaEngine, MIME_TYPE_VP8}; -use crate::api::APIBuilder; -use crate::peer_connection::configuration::RTCConfiguration; -use crate::peer_connection::peer_connection_test::*; - -// If a remote doesn't support a Codec used by a `TrackLocalStatic` -// an error should be returned to the user -#[tokio::test] -async fn test_track_local_static_no_codec_intersection() -> Result<()> { - let track: Arc = Arc::new(TrackLocalStaticSample::new( - RTCRtpCodecCapability { - mime_type: "video/vp8".to_owned(), - ..Default::default() - }, - "video".to_owned(), - "webrtc-rs".to_owned(), - )); - - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - //"Offerer" - { - let mut pc = api.new_peer_connection(RTCConfiguration::default()).await?; - - let mut no_codec_pc = APIBuilder::new() - .build() - .new_peer_connection(RTCConfiguration::default()) - .await?; - - pc.add_track(Arc::clone(&track)).await?; - - if let Err(err) = signal_pair(&mut pc, &mut no_codec_pc).await { - assert_eq!(err, Error::ErrUnsupportedCodec); - } else { - panic!(); - } - - close_pair_now(&no_codec_pc, &pc).await; - } - - //"Answerer" - { - let mut pc = api.new_peer_connection(RTCConfiguration::default()).await?; - - let mut m = MediaEngine::default(); - m.register_codec( - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: "video/VP9".to_owned(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: "".to_owned(), - rtcp_feedback: vec![], - }, - payload_type: 96, - ..Default::default() - }, - RTPCodecType::Video, - )?; - let mut vp9only_pc = APIBuilder::new() - 
.with_media_engine(m) - .build() - .new_peer_connection(RTCConfiguration::default()) - .await?; - - vp9only_pc - .add_transceiver_from_kind(RTPCodecType::Video, None) - .await?; - - pc.add_track(Arc::clone(&track)).await?; - - if let Err(err) = signal_pair(&mut vp9only_pc, &mut pc).await { - assert_eq!( - err, - Error::ErrUnsupportedCodec, - "expected {}, but got {}", - Error::ErrUnsupportedCodec, - err - ); - } else { - panic!(); - } - - close_pair_now(&vp9only_pc, &pc).await; - } - - //"Local" - { - let (mut offerer, mut answerer) = new_pair(&api).await?; - - let invalid_codec_track = TrackLocalStaticSample::new( - RTCRtpCodecCapability { - mime_type: "video/invalid-codec".to_owned(), - ..Default::default() - }, - "video".to_owned(), - "webrtc-rs".to_owned(), - ); - - offerer.add_track(Arc::new(invalid_codec_track)).await?; - - if let Err(err) = signal_pair(&mut offerer, &mut answerer).await { - assert_eq!(err, Error::ErrUnsupportedCodec); - } else { - panic!(); - } - - close_pair_now(&offerer, &answerer).await; - } - - Ok(()) -} - -// Assert that Bind/Unbind happens when expected -#[tokio::test] -async fn test_track_local_static_closed() -> Result<()> { - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - let (mut pc_offer, mut pc_answer) = new_pair(&api).await?; - - pc_answer - .add_transceiver_from_kind(RTPCodecType::Video, None) - .await?; - - let vp8writer: Arc = Arc::new(TrackLocalStaticRTP::new( - RTCRtpCodecCapability { - mime_type: "video/vp8".to_owned(), - ..Default::default() - }, - "video".to_owned(), - "webrtc-rs".to_owned(), - )); - - pc_offer.add_track(Arc::clone(&vp8writer)).await?; - - if let Some(v) = vp8writer.as_any().downcast_ref::() { - let bindings = v.bindings.lock().await; - assert_eq!( - bindings.len(), - 0, - "No binding should exist before signaling" - ); - } else { - panic!(); - } - - signal_pair(&mut pc_offer, &mut pc_answer).await?; - - if let Some(v) = vp8writer.as_any().downcast_ref::() { - let bindings = v.bindings.lock().await; - assert_eq!(bindings.len(), 1, "binding should exist after signaling"); - } else { - panic!(); - } - - close_pair_now(&pc_offer, &pc_answer).await; - - if let Some(v) = vp8writer.as_any().downcast_ref::() { - let bindings = v.bindings.lock().await; - assert_eq!(bindings.len(), 0, "No binding should exist after close"); - } else { - panic!(); - } - - Ok(()) -} - -//use log::LevelFilter; -//use std::io::Write; - -#[tokio::test] -async fn test_track_local_static_payload_type() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, LevelFilter::Trace) - .init();*/ - - let mut media_engine_one = MediaEngine::default(); - media_engine_one.register_codec( - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: "".to_owned(), - rtcp_feedback: vec![], - }, - payload_type: 100, - ..Default::default() - }, - RTPCodecType::Video, - )?; - - let mut media_engine_two = MediaEngine::default(); - media_engine_two.register_codec( - RTCRtpCodecParameters { - capability: RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - clock_rate: 90000, - channels: 0, - sdp_fmtp_line: "".to_owned(), - rtcp_feedback: vec![], - }, - 
payload_type: 200, - ..Default::default() - }, - RTPCodecType::Video, - )?; - - let mut offerer = APIBuilder::new() - .with_media_engine(media_engine_one) - .build() - .new_peer_connection(RTCConfiguration::default()) - .await?; - let mut answerer = APIBuilder::new() - .with_media_engine(media_engine_two) - .build() - .new_peer_connection(RTCConfiguration::default()) - .await?; - - let track = Arc::new(TrackLocalStaticSample::new( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - "video".to_owned(), - "webrtc-rs".to_owned(), - )); - offerer - .add_transceiver_from_kind(RTPCodecType::Video, None) - .await?; - - answerer - .add_track(Arc::clone(&track) as Arc) - .await?; - - let (on_track_fired_tx, on_track_fired_rx) = mpsc::channel::<()>(1); - let on_track_fired_tx = Arc::new(Mutex::new(Some(on_track_fired_tx))); - offerer.on_track(Box::new(move |track, _, _| { - let on_track_fired_tx2 = Arc::clone(&on_track_fired_tx); - Box::pin(async move { - assert_eq!(track.payload_type(), 100); - assert_eq!(track.codec().capability.mime_type, MIME_TYPE_VP8); - { - log::debug!("onTrackFiredFunc!!!"); - let mut done = on_track_fired_tx2.lock().await; - done.take(); - } - }) - })); - - signal_pair(&mut offerer, &mut answerer).await?; - - send_video_until_done( - on_track_fired_rx, - vec![track], - Bytes::from_static(&[0x00]), - None, - ) - .await; - - close_pair_now(&offerer, &answerer).await; - - Ok(()) -} - -// Assert that writing to a Track doesn't modify the input -// Even though we can pass a pointer we shouldn't modify the incoming value -#[tokio::test] -async fn test_track_local_static_mutate_input() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, LevelFilter::Trace) - .init();*/ - - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = APIBuilder::new().with_media_engine(m).build(); - - let (mut pc_offer, mut pc_answer) = new_pair(&api).await?; - - let vp8writer: Arc = Arc::new(TrackLocalStaticRTP::new( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - "video".to_owned(), - "webrtc-rs".to_owned(), - )); - - pc_offer.add_track(Arc::clone(&vp8writer)).await?; - - signal_pair(&mut pc_offer, &mut pc_answer).await?; - - let pkt = rtp::packet::Packet { - header: rtp::header::Header { - ssrc: 1, - payload_type: 1, - ..Default::default() - }, - ..Default::default() - }; - if let Some(v) = vp8writer.as_any().downcast_ref::() { - v.write_rtp(&pkt).await?; - } else { - panic!(); - } - - assert_eq!(pkt.header.ssrc, 1); - assert_eq!(pkt.header.payload_type, 1); - - close_pair_now(&pc_offer, &pc_answer).await; - - Ok(()) -} - -//use std::io::Write; -//use log::LevelFilter; - -// Assert that writing to a Track that has Binded (but not connected) -// does not block -#[tokio::test] -async fn test_track_local_static_binding_non_blocking() -> Result<()> { - /*env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{}:{} [{}] {} - {}", - record.file().unwrap_or("unknown"), - record.line().unwrap_or(0), - record.level(), - chrono::Local::now().format("%H:%M:%S.%6f"), - record.args() - ) - }) - .filter(None, LevelFilter::Trace) - .init();*/ - - let mut m = MediaEngine::default(); - m.register_default_codecs()?; - let api = 
APIBuilder::new().with_media_engine(m).build(); - - let (pc_offer, pc_answer) = new_pair(&api).await?; - - pc_offer - .add_transceiver_from_kind(RTPCodecType::Video, None) - .await?; - - let vp8writer: Arc = Arc::new(TrackLocalStaticRTP::new( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_VP8.to_owned(), - ..Default::default() - }, - "video".to_owned(), - "webrtc-rs".to_owned(), - )); - - pc_answer.add_track(Arc::clone(&vp8writer)).await?; - - let offer = pc_offer.create_offer(None).await?; - pc_answer.set_remote_description(offer).await?; - - let answer = pc_answer.create_answer(None).await?; - pc_answer.set_local_description(answer).await?; - - if let Some(v) = vp8writer.as_any().downcast_ref::() { - v.write(&[0u8; 20]).await?; - } else { - panic!(); - } - - close_pair_now(&pc_offer, &pc_answer).await; - - Ok(()) -} - -/* -//TODO: func BenchmarkTrackLocalWrite(b *testing.B) { - offerPC, answerPC, err := newPair() - defer closePairNow(b, offerPC, answerPC) - if err != nil { - b.Fatalf("Failed to create a PC pair for testing") - } - - track, err := NewTrackLocalStaticRTP(RTPCodecCapability{mime_type: MIME_TYPE_VP8}, "video", "pion") - assert.NoError(b, err) - - _, err = offerPC.AddTrack(track) - assert.NoError(b, err) - - _, err = answerPC.AddTransceiverFromKind(RTPCodecTypeVideo) - assert.NoError(b, err) - - b.SetBytes(1024) - - buf := make([]byte, 1024) - for i := 0; i < b.N; i++ { - _, err := track.Write(buf) - assert.NoError(b, err) - } -} -*/ diff --git a/webrtc/src/track/track_remote/mod.rs b/webrtc/src/track/track_remote/mod.rs deleted file mode 100644 index 7d0565cf9..000000000 --- a/webrtc/src/track/track_remote/mod.rs +++ /dev/null @@ -1,321 +0,0 @@ -use std::collections::VecDeque; -use std::future::Future; -use std::pin::Pin; -use std::sync::atomic::Ordering; -use std::sync::{Arc, Weak}; - -use arc_swap::ArcSwapOption; -use interceptor::{Attributes, Interceptor}; -use portable_atomic::{AtomicU32, AtomicU8, AtomicUsize}; -use smol_str::SmolStr; -use tokio::sync::Mutex; -use util::sync::Mutex as SyncMutex; - -use crate::api::media_engine::MediaEngine; -use crate::error::{Error, Result}; -use crate::rtp_transceiver::rtp_codec::{RTCRtpCodecParameters, RTCRtpParameters, RTPCodecType}; -use crate::rtp_transceiver::rtp_receiver::RTPReceiverInternal; -use crate::rtp_transceiver::{PayloadType, SSRC}; - -lazy_static! 
{ - static ref TRACK_REMOTE_UNIQUE_ID: AtomicUsize = AtomicUsize::new(0); -} -pub type OnMuteHdlrFn = Box< - dyn (FnMut() -> Pin + Send + 'static>>) + Send + Sync + 'static, ->; - -#[derive(Default)] -struct Handlers { - on_mute: ArcSwapOption>, - on_unmute: ArcSwapOption>, -} - -#[derive(Default)] -struct TrackRemoteInternal { - peeked: VecDeque<(rtp::packet::Packet, Attributes)>, -} - -/// TrackRemote represents a single inbound source of media -pub struct TrackRemote { - tid: usize, - - id: SyncMutex, - stream_id: SyncMutex, - - receive_mtu: usize, - payload_type: AtomicU8, //PayloadType, - kind: AtomicU8, //RTPCodecType, - ssrc: AtomicU32, //SSRC, - codec: SyncMutex, - pub(crate) params: SyncMutex, - rid: SmolStr, - - media_engine: Arc, - interceptor: Arc, - - handlers: Arc, - - receiver: Option>, - internal: Mutex, -} - -impl std::fmt::Debug for TrackRemote { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("TrackRemote") - .field("id", &self.id) - .field("stream_id", &self.stream_id) - .field("payload_type", &self.payload_type) - .field("kind", &self.kind) - .field("ssrc", &self.ssrc) - .field("codec", &self.codec) - .field("params", &self.params) - .field("rid", &self.rid) - .finish() - } -} - -impl TrackRemote { - pub(crate) fn new( - receive_mtu: usize, - kind: RTPCodecType, - ssrc: SSRC, - rid: SmolStr, - receiver: Weak, - media_engine: Arc, - interceptor: Arc, - ) -> Self { - TrackRemote { - tid: TRACK_REMOTE_UNIQUE_ID.fetch_add(1, Ordering::SeqCst), - id: Default::default(), - stream_id: Default::default(), - receive_mtu, - payload_type: Default::default(), - kind: AtomicU8::new(kind as u8), - ssrc: AtomicU32::new(ssrc), - codec: Default::default(), - params: Default::default(), - rid, - receiver: Some(receiver), - media_engine, - interceptor, - handlers: Default::default(), - - internal: Default::default(), - } - } - - pub fn tid(&self) -> usize { - self.tid - } - - /// id is the unique identifier for this Track. This should be unique for the - /// stream, but doesn't have to globally unique. A common example would be 'audio' or 'video' - /// and StreamID would be 'desktop' or 'webcam' - pub fn id(&self) -> String { - let id = self.id.lock(); - id.clone() - } - - pub fn set_id(&self, s: String) { - let mut id = self.id.lock(); - *id = s; - } - - /// stream_id is the group this track belongs too. This must be unique - pub fn stream_id(&self) -> String { - let stream_id = self.stream_id.lock(); - stream_id.clone() - } - - pub fn set_stream_id(&self, s: String) { - let mut stream_id = self.stream_id.lock(); - *stream_id = s; - } - - /// rid gets the RTP Stream ID of this Track - /// With Simulcast you will have multiple tracks with the same ID, but different RID values. 
- /// In many cases a TrackRemote will not have an RID, so it is important to assert it is non-zero - pub fn rid(&self) -> &str { - self.rid.as_str() - } - - /// payload_type gets the PayloadType of the track - pub fn payload_type(&self) -> PayloadType { - self.payload_type.load(Ordering::SeqCst) - } - - pub fn set_payload_type(&self, payload_type: PayloadType) { - self.payload_type.store(payload_type, Ordering::SeqCst); - } - - /// kind gets the Kind of the track - pub fn kind(&self) -> RTPCodecType { - self.kind.load(Ordering::SeqCst).into() - } - - pub fn set_kind(&self, kind: RTPCodecType) { - self.kind.store(kind as u8, Ordering::SeqCst); - } - - /// ssrc gets the SSRC of the track - pub fn ssrc(&self) -> SSRC { - self.ssrc.load(Ordering::SeqCst) - } - - pub fn set_ssrc(&self, ssrc: SSRC) { - self.ssrc.store(ssrc, Ordering::SeqCst); - } - - /// msid gets the Msid of the track - pub fn msid(&self) -> String { - format!("{} {}", self.stream_id(), self.id()) - } - - /// codec gets the Codec of the track - pub fn codec(&self) -> RTCRtpCodecParameters { - let codec = self.codec.lock(); - codec.clone() - } - - pub fn set_codec(&self, codec: RTCRtpCodecParameters) { - let mut c = self.codec.lock(); - *c = codec; - } - - pub fn params(&self) -> RTCRtpParameters { - let p = self.params.lock(); - p.clone() - } - - pub fn set_params(&self, params: RTCRtpParameters) { - let mut p = self.params.lock(); - *p = params; - } - - pub fn onmute(&self, handler: F) - where - F: FnMut() -> Pin + Send + 'static>> + Send + 'static + Sync, - { - self.handlers - .on_mute - .store(Some(Arc::new(Mutex::new(Box::new(handler))))); - } - - pub fn onunmute(&self, handler: F) - where - F: FnMut() -> Pin + Send + 'static>> + Send + 'static + Sync, - { - self.handlers - .on_unmute - .store(Some(Arc::new(Mutex::new(Box::new(handler))))); - } - - /// Reads data from the track. - /// - /// **Cancel Safety:** This method is not cancel safe. Dropping the resulting [`Future`] before - /// it returns [`std::task::Poll::Ready`] will cause data loss. 
- pub async fn read(&self, b: &mut [u8]) -> Result<(rtp::packet::Packet, Attributes)> { - { - // Internal lock scope - let mut internal = self.internal.lock().await; - if let Some((pkt, attributes)) = internal.peeked.pop_front() { - self.check_and_update_track(&pkt).await?; - - return Ok((pkt, attributes)); - } - }; - - let receiver = match self.receiver.as_ref().and_then(|r| r.upgrade()) { - Some(r) => r, - None => return Err(Error::ErrRTPReceiverNil), - }; - - let (pkt, attributes) = receiver.read_rtp(b, self.tid).await?; - self.check_and_update_track(&pkt).await?; - Ok((pkt, attributes)) - } - - /// check_and_update_track checks payloadType for every incoming packet - /// once a different payloadType is detected the track will be updated - pub(crate) async fn check_and_update_track(&self, pkt: &rtp::packet::Packet) -> Result<()> { - let payload_type = pkt.header.payload_type; - if payload_type != self.payload_type() { - let p = self - .media_engine - .get_rtp_parameters_by_payload_type(payload_type) - .await?; - - if let Some(receiver) = &self.receiver { - if let Some(receiver) = receiver.upgrade() { - self.kind.store(receiver.kind as u8, Ordering::SeqCst); - } - } - self.payload_type.store(payload_type, Ordering::SeqCst); - { - let mut codec = self.codec.lock(); - *codec = if let Some(codec) = p.codecs.first() { - codec.clone() - } else { - return Err(Error::ErrCodecNotFound); - }; - } - { - let mut params = self.params.lock(); - *params = p; - } - } - - Ok(()) - } - - /// read_rtp is a convenience method that wraps Read and unmarshals for you. - pub async fn read_rtp(&self) -> Result<(rtp::packet::Packet, Attributes)> { - let mut b = vec![0u8; self.receive_mtu]; - let (pkt, attributes) = self.read(&mut b).await?; - - Ok((pkt, attributes)) - } - - /// peek is like Read, but it doesn't discard the packet read - pub(crate) async fn peek(&self, b: &mut [u8]) -> Result<(rtp::packet::Packet, Attributes)> { - let (pkt, a) = self.read(b).await?; - - // this might overwrite data if somebody peeked between the Read - // and us getting the lock. Oh well, we'll just drop a packet in - // that case. - { - let mut internal = self.internal.lock().await; - internal.peeked.push_back((pkt.clone(), a.clone())); - } - Ok((pkt, a)) - } - - /// Set the initially peeked data for this track. - /// - /// This is useful when a track is first created to populate data read from the track in the - /// process of identifying the track as part of simulcast probing. Using this during other - /// parts of the track's lifecycle is probably an error. - pub(crate) async fn prepopulate_peeked_data( - &self, - data: VecDeque<(rtp::packet::Packet, Attributes)>, - ) { - let mut internal = self.internal.lock().await; - internal.peeked = data; - } - - pub(crate) async fn fire_onmute(&self) { - let on_mute = self.handlers.on_mute.load(); - - if let Some(f) = on_mute.as_ref() { - (f.lock().await)().await - }; - } - - pub(crate) async fn fire_onunmute(&self) { - let on_unmute = self.handlers.on_unmute.load(); - - if let Some(f) = on_unmute.as_ref() { - (f.lock().await)().await - }; - } -}