Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: key ratchet/derive. #66

Merged
merged 27 commits into from
Apr 25, 2023
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,9 @@
/xcodebuild
/.vscode
!webrtc/*
/tmp.patch
/out-release
/out-debug
/node_modules
/libwebrtc
/args.txt
246 changes: 192 additions & 54 deletions api/crypto/frame_crypto_transformer.cc

Large diffs are not rendered by default.

183 changes: 177 additions & 6 deletions api/crypto/frame_crypto_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,39 +17,209 @@
#ifndef WEBRTC_FRAME_CRYPTOR_TRANSFORMER_H_
#define WEBRTC_FRAME_CRYPTOR_TRANSFORMER_H_

#include <unordered_map>

#include "api/frame_transformer_interface.h"
#include "rtc_base/buffer.h"
#include "rtc_base/synchronization/mutex.h"
#include "rtc_base/system/rtc_export.h"
#include "rtc_base/thread.h"

int DerivePBKDF2KeyFromRawKey(const std::vector<uint8_t> raw_key,
const std::vector<uint8_t>& salt,
unsigned int optional_length_bits,
std::vector<uint8_t>* derived_key);

namespace webrtc {

const size_t KEYRING_SIZE = 16;

struct KeyProviderOptions {
bool shared_key;
std::vector<uint8_t> ratchet_salt;
std::vector<uint8_t> uncrypted_magic_bytes;
int ratchet_window_size;
KeyProviderOptions() : shared_key(false), ratchet_window_size(0) {}
KeyProviderOptions(KeyProviderOptions& copy)
: shared_key(copy.shared_key),
ratchet_salt(copy.ratchet_salt),
uncrypted_magic_bytes(copy.uncrypted_magic_bytes),
ratchet_window_size(copy.ratchet_window_size) {}
};

class ParticipantKeyHandler {
public:
struct KeySet {
std::vector<uint8_t> material;
std::vector<uint8_t> encryption_key;
KeySet(std::vector<uint8_t> material, std::vector<uint8_t> encryptionKey)
: material(material), encryption_key(encryptionKey) {}
};

public:
ParticipantKeyHandler(KeyProviderOptions options) : options_(options) {
cryptoKeyRing_.resize(KEYRING_SIZE);
}

virtual ~ParticipantKeyHandler() = default;

virtual std::vector<uint8_t> RatchetKey(int keyIndex) {
auto currentMaterial = GetKeySet(keyIndex)->material;
std::vector<uint8_t> newMaterial;
if (DerivePBKDF2KeyFromRawKey(currentMaterial, options_.ratchet_salt, 256,
&newMaterial) != 0) {
return std::vector<uint8_t>();
}
SetKeyFromMaterial(newMaterial,
keyIndex != -1 ? keyIndex : currentKeyIndex);
return newMaterial;
}

virtual std::shared_ptr<KeySet> GetKeySet(int keyIndex) {
return cryptoKeyRing_[keyIndex != -1 ? keyIndex : currentKeyIndex];
}

virtual void SetKeyFromMaterial(std::vector<uint8_t> password, int keyIndex) {
if (keyIndex >= 0) {
currentKeyIndex = keyIndex % cryptoKeyRing_.size();
}
cryptoKeyRing_[currentKeyIndex] =
DeriveKeys(password, options_.ratchet_salt, 128);
}

virtual KeyProviderOptions& options() { return options_; }

std::shared_ptr<KeySet> DeriveKeys(std::vector<uint8_t> password,
std::vector<uint8_t> ratchet_salt,
unsigned int optional_length_bits) {
std::vector<uint8_t> derived_key;
if (DerivePBKDF2KeyFromRawKey(password, ratchet_salt, optional_length_bits,
&derived_key) == 0) {
return std::make_shared<KeySet>(password, derived_key);
}
return nullptr;
}

std::vector<uint8_t> RatchetKeyMaterial(
std::vector<uint8_t> currentMaterial) {
std::vector<uint8_t> newMaterial;
if (DerivePBKDF2KeyFromRawKey(currentMaterial, options_.ratchet_salt, 256,
&newMaterial) != 0) {
return std::vector<uint8_t>();
}
return newMaterial;
}

private:
int currentKeyIndex = 0;
KeyProviderOptions options_;
std::vector<std::shared_ptr<KeySet>> cryptoKeyRing_;
};

class KeyManager : public rtc::RefCountInterface {
public:
enum { kRawKeySize = 32 };

public:
virtual const std::vector<std::vector<uint8_t>> keys(
virtual const std::shared_ptr<ParticipantKeyHandler> GetKey(
const std::string participant_id) const = 0;

virtual bool SetKey(const std::string participant_id,
int index,
std::vector<uint8_t> key) = 0;

virtual const std::vector<uint8_t> RatchetKey(
const std::string participant_id,
int key_index) = 0;

virtual const std::vector<uint8_t> ExportKey(const std::string participant_id,
int key_index) const = 0;

virtual KeyProviderOptions& options() = 0;

protected:
virtual ~KeyManager() {}
};

enum FrameCryptionError {
class DefaultKeyManagerImpl : public KeyManager {
public:
DefaultKeyManagerImpl(KeyProviderOptions options) : options_(options) {}
~DefaultKeyManagerImpl() override = default;

/// Set the key at the given index.
bool SetKey(const std::string participant_id,
int index,
std::vector<uint8_t> key) override {
webrtc::MutexLock lock(&mutex_);

if (keys_.find(participant_id) == keys_.end()) {
keys_[participant_id] = std::make_shared<ParticipantKeyHandler>(options_);
}

auto keyHandler = keys_[participant_id];
keyHandler->SetKeyFromMaterial(key, index);

return true;
}

const std::shared_ptr<ParticipantKeyHandler> GetKey(
const std::string participant_id) const override {
webrtc::MutexLock lock(&mutex_);
if (keys_.find(participant_id) == keys_.end()) {
return nullptr;
}

return keys_.find(participant_id)->second;
}

const std::vector<uint8_t> RatchetKey(const std::string participant_id,
int key_index) override {
webrtc::MutexLock lock(&mutex_);
if (keys_.find(participant_id) == keys_.end()) {
return std::vector<uint8_t>();
}

return keys_[participant_id]->RatchetKey(key_index);
}

const std::vector<uint8_t> ExportKey(const std::string participant_id,
int key_index) const override {
webrtc::MutexLock lock(&mutex_);
if (keys_.find(participant_id) == keys_.end()) {
return std::vector<uint8_t>();
}

auto keySet = GetKey(participant_id);

if (!keySet) {
return std::vector<uint8_t>();
}

return keySet->GetKeySet(key_index)->material;
}

KeyProviderOptions& options() override { return options_; }

private:
mutable webrtc::Mutex mutex_;
KeyProviderOptions options_;
std::unordered_map<std::string, std::shared_ptr<ParticipantKeyHandler>> keys_;
};

enum FrameCryptionState {
kNew = 0,
kOk,
kEncryptionFailed,
kDecryptionFailed,
kMissingKey,
kKeyRatcheted,
kInternalError,
};

class FrameCryptorTransformerObserver {
public:
virtual void OnFrameCryptionError(const std::string participant_id,
FrameCryptionError error) = 0;
virtual void OnFrameCryptionStateChanged(const std::string participant_id,
FrameCryptionState error) = 0;

protected:
virtual ~FrameCryptorTransformerObserver() {}
Expand Down Expand Up @@ -85,6 +255,7 @@ class RTC_EXPORT FrameCryptorTransformer
}

virtual int key_index() const { return key_index_; };

virtual void SetEnabled(bool enabled) {
webrtc::MutexLock lock(&mutex_);
enabled_cryption_ = enabled;
Expand Down Expand Up @@ -143,8 +314,8 @@ class RTC_EXPORT FrameCryptorTransformer
rtc::scoped_refptr<KeyManager> key_manager_;
FrameCryptorTransformerObserver* observer_ = nullptr;
std::unique_ptr<rtc::Thread> thread_;
FrameCryptionError last_enc_error_ = FrameCryptionError::kNew;
FrameCryptionError last_dec_error_ = FrameCryptionError::kNew;
FrameCryptionState last_enc_error_ = FrameCryptionState::kNew;
FrameCryptionState last_dec_error_ = FrameCryptionState::kNew;
};

} // namespace webrtc
Expand Down
18 changes: 9 additions & 9 deletions sdk/android/api/org/webrtc/FrameCryptor.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,24 @@
import androidx.annotation.Nullable;

public class FrameCryptor {

public enum FrameCryptorErrorState {
public enum FrameCryptionState {
NEW,
OK,
ENCRYPTIONFAILED,
DECRYPTIONFAILED,
MISSINGKEY,
KEYRATCHETED,
INTERNALERROR;

@CalledByNative("FrameCryptorErrorState")
static FrameCryptorErrorState fromNativeIndex(int nativeIndex) {
@CalledByNative("FrameCryptionState")
static FrameCryptionState fromNativeIndex(int nativeIndex) {
return values()[nativeIndex];
}
}

public static interface Observer {
@CalledByNative("Observer")
void onFrameCryptorErrorState(String participantId, FrameCryptorErrorState newState);
void onFrameCryptionStateChanged(String participantId, FrameCryptionState newState);
}

private long nativeFrameCryptor;
Expand Down Expand Up @@ -74,23 +74,23 @@ public void setKeyIndex(int index) {

public void dispose() {
checkFrameCryptorExists();
nativeUnSetObserver(nativeFrameCryptor);
JniCommon.nativeReleaseRef(nativeFrameCryptor);
nativeFrameCryptor = 0;
if(observerPtr != 0) {
if (observerPtr != 0) {
JniCommon.nativeReleaseRef(observerPtr);
observerPtr = 0;
}
nativeUnSetObserver(nativeFrameCryptor);
}

public void setObserver(@Nullable Observer observer) {
checkFrameCryptorExists();
long newPtr = nativeSetObserver(nativeFrameCryptor, observer);
if(observerPtr != 0) {
if (observerPtr != 0) {
JniCommon.nativeReleaseRef(observerPtr);
observerPtr = 0;
}
newPtr= observerPtr;
newPtr = observerPtr;
}

private void checkFrameCryptorExists() {
Expand Down
8 changes: 5 additions & 3 deletions sdk/android/api/org/webrtc/FrameCryptorFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
package org.webrtc;

public class FrameCryptorFactory {
public static FrameCryptorKeyManager createFrameCryptorKeyManager() {
return nativeCreateFrameCryptorKeyManager();
public static FrameCryptorKeyManager createFrameCryptorKeyManager(
boolean sharedKey, byte[] ratchetSalt, int ratchetWindowSize, byte[] uncryptedMagicBytes) {
return nativeCreateFrameCryptorKeyManager(sharedKey, ratchetSalt, ratchetWindowSize, uncryptedMagicBytes);
}

public static FrameCryptor createFrameCryptorForRtpSender(RtpSender rtpSender,
Expand All @@ -38,5 +39,6 @@ private static native FrameCryptor nativeCreateFrameCryptorForRtpSender(
private static native FrameCryptor nativeCreateFrameCryptorForRtpReceiver(
long rtpReceiver, String participantId, int algorithm, long nativeFrameCryptorKeyManager);

private static native FrameCryptorKeyManager nativeCreateFrameCryptorKeyManager();
private static native FrameCryptorKeyManager nativeCreateFrameCryptorKeyManager(
boolean sharedKey, byte[] ratchetSalt, int ratchetWindowSize, byte[] uncryptedMagicBytes);
}
20 changes: 11 additions & 9 deletions sdk/android/api/org/webrtc/FrameCryptorKeyManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,18 @@ public long getNativeKeyManager() {
}

public boolean setKey(String participantId, int index, byte[] key) {
checkKeyManagerExists();
return nativeSetKey(nativeKeyManager, participantId, index, key);
}

public boolean setKeys(String participantId, ArrayList<byte[]> keys) {
return nativeSetKeys(nativeKeyManager, participantId, keys);
public byte[] ratchetKey(String participantId, int index) {
checkKeyManagerExists();
return nativeRatchetKey(nativeKeyManager, participantId, index);
}

public ArrayList<byte[]> getKeys(String participantId) {
return nativeGetKeys(nativeKeyManager, participantId);
public byte[] exportKey(String participantId, int index) {
checkKeyManagerExists();
return nativeExportKey(nativeKeyManager, participantId, index);
}

public void dispose() {
Expand All @@ -54,11 +57,10 @@ private void checkKeyManagerExists() {
}
}

private static native long createNativeKeyManager();
private static native boolean nativeSetKey(
long keyManagerPointer, String participantId, int index, byte[] key);
private static native boolean nativeSetKeys(
long keyManagerPointer, String participantId, ArrayList<byte[]> keys);
private static native ArrayList<byte[]> nativeGetKeys(
long keyManagerPointer, String participantId);
private static native byte[] nativeRatchetKey(
long keyManagerPointer, String participantId, int index);
private static native byte[] nativeExportKey(
long keyManagerPointer, String participantId, int index);
}
Loading