Skip to content

Commit

Permalink
feat: key ratchet/derive. (#66)
Browse files Browse the repository at this point in the history
* feat: key derive.

* update.

* add ParticipantKeyHandler.

* update.

* update.

* update.

* update.

* update.

* fix key derive.

* chore: add kKeyRatcheted state.

* fixed key ratchet.

* update api for darwin.

* chore: for android.

* fix crash.

* fix bug for setkey

* chore: add ExportKey for KeyManager.

* chore: key export for android.

* chore: clang-format.

* chore: exportKey for darwin.

* chore: When ratchet and material derivation fail, the current keyset will not be updated until the decryption is successful or the ratchet count window is exceeded.

* chore: magic bytes.

* update.

* update for darwin.

* rename KeyManager to KeyProvider.

* fix compile for android.

* fix key retchet.

* Emit the KeyRatcheted state after a successful ratchet.

---------

Co-authored-by: root <root@WIN-13900KF>
  • Loading branch information
cloudwebrtc and root committed Jun 6, 2023
1 parent 5351ab6 commit c2f54cf
Show file tree
Hide file tree
Showing 22 changed files with 756 additions and 499 deletions.
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,9 @@
/xcodebuild
/.vscode
!webrtc/*
/tmp.patch
/out-release
/out-debug
/node_modules
/libwebrtc
/args.txt
251 changes: 194 additions & 57 deletions api/crypto/frame_crypto_transformer.cc

Large diffs are not rendered by default.

191 changes: 181 additions & 10 deletions api/crypto/frame_crypto_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,39 +17,209 @@
#ifndef WEBRTC_FRAME_CRYPTOR_TRANSFORMER_H_
#define WEBRTC_FRAME_CRYPTOR_TRANSFORMER_H_

#include <unordered_map>

#include "api/frame_transformer_interface.h"
#include "rtc_base/buffer.h"
#include "rtc_base/synchronization/mutex.h"
#include "rtc_base/system/rtc_export.h"
#include "rtc_base/thread.h"

int DerivePBKDF2KeyFromRawKey(const std::vector<uint8_t> raw_key,
const std::vector<uint8_t>& salt,
unsigned int optional_length_bits,
std::vector<uint8_t>* derived_key);

namespace webrtc {

class KeyManager : public rtc::RefCountInterface {
const size_t KEYRING_SIZE = 16;

struct KeyProviderOptions {
bool shared_key;
std::vector<uint8_t> ratchet_salt;
std::vector<uint8_t> uncrypted_magic_bytes;
int ratchet_window_size;
KeyProviderOptions() : shared_key(false), ratchet_window_size(0) {}
KeyProviderOptions(KeyProviderOptions& copy)
: shared_key(copy.shared_key),
ratchet_salt(copy.ratchet_salt),
uncrypted_magic_bytes(copy.uncrypted_magic_bytes),
ratchet_window_size(copy.ratchet_window_size) {}
};

class ParticipantKeyHandler {
public:
struct KeySet {
std::vector<uint8_t> material;
std::vector<uint8_t> encryption_key;
KeySet(std::vector<uint8_t> material, std::vector<uint8_t> encryptionKey)
: material(material), encryption_key(encryptionKey) {}
};

public:
ParticipantKeyHandler(KeyProviderOptions options) : options_(options) {
cryptoKeyRing_.resize(KEYRING_SIZE);
}

virtual ~ParticipantKeyHandler() = default;

virtual std::vector<uint8_t> RatchetKey(int keyIndex) {
auto currentMaterial = GetKeySet(keyIndex)->material;
std::vector<uint8_t> newMaterial;
if (DerivePBKDF2KeyFromRawKey(currentMaterial, options_.ratchet_salt, 256,
&newMaterial) != 0) {
return std::vector<uint8_t>();
}
SetKeyFromMaterial(newMaterial,
keyIndex != -1 ? keyIndex : currentKeyIndex);
return newMaterial;
}

virtual std::shared_ptr<KeySet> GetKeySet(int keyIndex) {
return cryptoKeyRing_[keyIndex != -1 ? keyIndex : currentKeyIndex];
}

virtual void SetKeyFromMaterial(std::vector<uint8_t> password, int keyIndex) {
if (keyIndex >= 0) {
currentKeyIndex = keyIndex % cryptoKeyRing_.size();
}
cryptoKeyRing_[currentKeyIndex] =
DeriveKeys(password, options_.ratchet_salt, 128);
}

virtual KeyProviderOptions& options() { return options_; }

std::shared_ptr<KeySet> DeriveKeys(std::vector<uint8_t> password,
std::vector<uint8_t> ratchet_salt,
unsigned int optional_length_bits) {
std::vector<uint8_t> derived_key;
if (DerivePBKDF2KeyFromRawKey(password, ratchet_salt, optional_length_bits,
&derived_key) == 0) {
return std::make_shared<KeySet>(password, derived_key);
}
return nullptr;
}

std::vector<uint8_t> RatchetKeyMaterial(
std::vector<uint8_t> currentMaterial) {
std::vector<uint8_t> newMaterial;
if (DerivePBKDF2KeyFromRawKey(currentMaterial, options_.ratchet_salt, 256,
&newMaterial) != 0) {
return std::vector<uint8_t>();
}
return newMaterial;
}

private:
int currentKeyIndex = 0;
KeyProviderOptions options_;
std::vector<std::shared_ptr<KeySet>> cryptoKeyRing_;
};

class KeyProvider : public rtc::RefCountInterface {
public:
enum { kRawKeySize = 32 };

public:
virtual const std::vector<std::vector<uint8_t>> keys(
virtual const std::shared_ptr<ParticipantKeyHandler> GetKey(
const std::string participant_id) const = 0;

virtual bool SetKey(const std::string participant_id,
int index,
std::vector<uint8_t> key) = 0;

virtual const std::vector<uint8_t> RatchetKey(
const std::string participant_id,
int key_index) = 0;

virtual const std::vector<uint8_t> ExportKey(const std::string participant_id,
int key_index) const = 0;

virtual KeyProviderOptions& options() = 0;

protected:
virtual ~KeyManager() {}
virtual ~KeyProvider() {}
};

enum FrameCryptionError {
class DefaultKeyProviderImpl : public KeyProvider {
public:
DefaultKeyProviderImpl(KeyProviderOptions options) : options_(options) {}
~DefaultKeyProviderImpl() override = default;

/// Set the key at the given index.
bool SetKey(const std::string participant_id,
int index,
std::vector<uint8_t> key) override {
webrtc::MutexLock lock(&mutex_);

if (keys_.find(participant_id) == keys_.end()) {
keys_[participant_id] = std::make_shared<ParticipantKeyHandler>(options_);
}

auto keyHandler = keys_[participant_id];
keyHandler->SetKeyFromMaterial(key, index);

return true;
}

const std::shared_ptr<ParticipantKeyHandler> GetKey(
const std::string participant_id) const override {
webrtc::MutexLock lock(&mutex_);
if (keys_.find(participant_id) == keys_.end()) {
return nullptr;
}

return keys_.find(participant_id)->second;
}

const std::vector<uint8_t> RatchetKey(const std::string participant_id,
int key_index) override {
webrtc::MutexLock lock(&mutex_);
if (keys_.find(participant_id) == keys_.end()) {
return std::vector<uint8_t>();
}

return keys_[participant_id]->RatchetKey(key_index);
}

const std::vector<uint8_t> ExportKey(const std::string participant_id,
int key_index) const override {
webrtc::MutexLock lock(&mutex_);
if (keys_.find(participant_id) == keys_.end()) {
return std::vector<uint8_t>();
}

auto keySet = GetKey(participant_id);

if (!keySet) {
return std::vector<uint8_t>();
}

return keySet->GetKeySet(key_index)->material;
}

KeyProviderOptions& options() override { return options_; }

private:
mutable webrtc::Mutex mutex_;
KeyProviderOptions options_;
std::unordered_map<std::string, std::shared_ptr<ParticipantKeyHandler>> keys_;
};

enum FrameCryptionState {
kNew = 0,
kOk,
kEncryptionFailed,
kDecryptionFailed,
kMissingKey,
kKeyRatcheted,
kInternalError,
};

class FrameCryptorTransformerObserver {
public:
virtual void OnFrameCryptionError(const std::string participant_id,
FrameCryptionError error) = 0;
virtual void OnFrameCryptionStateChanged(const std::string participant_id,
FrameCryptionState error) = 0;

protected:
virtual ~FrameCryptorTransformerObserver() {}
Expand All @@ -71,7 +241,7 @@ class RTC_EXPORT FrameCryptorTransformer
explicit FrameCryptorTransformer(const std::string participant_id,
MediaType type,
Algorithm algorithm,
rtc::scoped_refptr<KeyManager> key_manager);
rtc::scoped_refptr<KeyProvider> key_provider);

virtual void SetFrameCryptorTransformerObserver(
FrameCryptorTransformerObserver* observer) {
Expand All @@ -85,6 +255,7 @@ class RTC_EXPORT FrameCryptorTransformer
}

virtual int key_index() const { return key_index_; }

virtual void SetEnabled(bool enabled) {
webrtc::MutexLock lock(&mutex_);
enabled_cryption_ = enabled;
Expand Down Expand Up @@ -140,11 +311,11 @@ class RTC_EXPORT FrameCryptorTransformer
sink_callbacks_;
int key_index_ = 0;
std::map<uint32_t, uint32_t> sendCounts_;
rtc::scoped_refptr<KeyManager> key_manager_;
rtc::scoped_refptr<KeyProvider> key_provider_;
FrameCryptorTransformerObserver* observer_ = nullptr;
std::unique_ptr<rtc::Thread> thread_;
FrameCryptionError last_enc_error_ = FrameCryptionError::kNew;
FrameCryptionError last_dec_error_ = FrameCryptionError::kNew;
FrameCryptionState last_enc_error_ = FrameCryptionState::kNew;
FrameCryptionState last_dec_error_ = FrameCryptionState::kNew;
};

} // namespace webrtc
Expand Down
10 changes: 5 additions & 5 deletions sdk/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -1024,9 +1024,9 @@ if (is_ios || is_mac) {
"objc/api/peerconnection/RTCFrameCryptor+Private.h",
"objc/api/peerconnection/RTCFrameCryptor.h",
"objc/api/peerconnection/RTCFrameCryptor.mm",
"objc/api/peerconnection/RTCFrameCryptorKeyManager+Private.h",
"objc/api/peerconnection/RTCFrameCryptorKeyManager.h",
"objc/api/peerconnection/RTCFrameCryptorKeyManager.mm",
"objc/api/peerconnection/RTCFrameCryptorKeyProvider+Private.h",
"objc/api/peerconnection/RTCFrameCryptorKeyProvider.h",
"objc/api/peerconnection/RTCFrameCryptorKeyProvider.mm",
"objc/api/peerconnection/RTCIceCandidate+Private.h",
"objc/api/peerconnection/RTCIceCandidate.h",
"objc/api/peerconnection/RTCIceCandidate.mm",
Expand Down Expand Up @@ -1381,7 +1381,7 @@ if (is_ios || is_mac) {
"objc/api/peerconnection/RTCDataChannel.h",
"objc/api/peerconnection/RTCDataChannelConfiguration.h",
"objc/api/peerconnection/RTCFrameCryptor.h",
"objc/api/peerconnection/RTCFrameCryptorKeyManager.h",
"objc/api/peerconnection/RTCFrameCryptorKeyProvider.h",
"objc/api/peerconnection/RTCFieldTrials.h",
"objc/api/peerconnection/RTCIceCandidate.h",
"objc/api/peerconnection/RTCIceCandidateErrorEvent.h",
Expand Down Expand Up @@ -1505,7 +1505,7 @@ if (is_ios || is_mac) {
"objc/api/peerconnection/RTCDataChannelConfiguration.h",
"objc/api/peerconnection/RTCDtmfSender.h",
"objc/api/peerconnection/RTCFrameCryptor.h",
"objc/api/peerconnection/RTCFrameCryptorKeyManager.h",
"objc/api/peerconnection/RTCFrameCryptorKeyProvider.h",
"objc/api/peerconnection/RTCFieldTrials.h",
"objc/api/peerconnection/RTCIceCandidate.h",
"objc/api/peerconnection/RTCIceCandidateErrorEvent.h",
Expand Down
10 changes: 7 additions & 3 deletions sdk/android/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,10 @@ if (is_android) {
"api/org/webrtc/RTCStatsCollectorCallback.java",
"api/org/webrtc/RTCStatsReport.java",
"api/org/webrtc/RtcCertificatePem.java",
"api/org/webrtc/FrameCryptor.java",
"api/org/webrtc/FrameCryptorAlgorithm.java",
"api/org/webrtc/FrameCryptorFactory.java",
"api/org/webrtc/FrameCryptorKeyProvider.java",
"api/org/webrtc/RtpCapabilities.java",
"api/org/webrtc/RtpParameters.java",
"api/org/webrtc/RtpReceiver.java",
Expand Down Expand Up @@ -724,8 +728,8 @@ if (current_os == "linux" || is_android) {
"src/jni/pc/dtmf_sender.cc",
"src/jni/pc/frame_cryptor.cc",
"src/jni/pc/frame_cryptor.h",
"src/jni/pc/frame_cryptor_key_manager.cc",
"src/jni/pc/frame_cryptor_key_manager.h",
"src/jni/pc/frame_cryptor_key_provider.cc",
"src/jni/pc/frame_cryptor_key_provider.h",
"src/jni/pc/ice_candidate.cc",
"src/jni/pc/ice_candidate.h",
"src/jni/pc/media_constraints.cc",
Expand Down Expand Up @@ -1411,7 +1415,7 @@ if (current_os == "linux" || is_android) {
"api/org/webrtc/DtmfSender.java",
"api/org/webrtc/FrameCryptor.java",
"api/org/webrtc/FrameCryptorFactory.java",
"api/org/webrtc/FrameCryptorKeyManager.java",
"api/org/webrtc/FrameCryptorKeyProvider.java",
"api/org/webrtc/IceCandidate.java",
"api/org/webrtc/IceCandidateErrorEvent.java",
"api/org/webrtc/MediaConstraints.java",
Expand Down
18 changes: 9 additions & 9 deletions sdk/android/api/org/webrtc/FrameCryptor.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,24 @@
import androidx.annotation.Nullable;

public class FrameCryptor {

public enum FrameCryptorErrorState {
public enum FrameCryptionState {
NEW,
OK,
ENCRYPTIONFAILED,
DECRYPTIONFAILED,
MISSINGKEY,
KEYRATCHETED,
INTERNALERROR;

@CalledByNative("FrameCryptorErrorState")
static FrameCryptorErrorState fromNativeIndex(int nativeIndex) {
@CalledByNative("FrameCryptionState")
static FrameCryptionState fromNativeIndex(int nativeIndex) {
return values()[nativeIndex];
}
}

public static interface Observer {
@CalledByNative("Observer")
void onFrameCryptorErrorState(String participantId, FrameCryptorErrorState newState);
void onFrameCryptionStateChanged(String participantId, FrameCryptionState newState);
}

private long nativeFrameCryptor;
Expand Down Expand Up @@ -74,23 +74,23 @@ public void setKeyIndex(int index) {

public void dispose() {
checkFrameCryptorExists();
nativeUnSetObserver(nativeFrameCryptor);
JniCommon.nativeReleaseRef(nativeFrameCryptor);
nativeFrameCryptor = 0;
if(observerPtr != 0) {
if (observerPtr != 0) {
JniCommon.nativeReleaseRef(observerPtr);
observerPtr = 0;
}
nativeUnSetObserver(nativeFrameCryptor);
}

public void setObserver(@Nullable Observer observer) {
checkFrameCryptorExists();
long newPtr = nativeSetObserver(nativeFrameCryptor, observer);
if(observerPtr != 0) {
if (observerPtr != 0) {
JniCommon.nativeReleaseRef(observerPtr);
observerPtr = 0;
}
newPtr= observerPtr;
newPtr = observerPtr;
}

private void checkFrameCryptorExists() {
Expand Down
Loading

0 comments on commit c2f54cf

Please sign in to comment.