Skip to content

Commit

Permalink
more improvements for E2EE. (#96)
Browse files Browse the repository at this point in the history
* other improvements for E2EE.

* fix log.

* update.

* fix.

* revert changes.

* clang-format.
  • Loading branch information
cloudwebrtc authored Sep 21, 2023
1 parent 6af5bfd commit 0649214
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 68 deletions.
85 changes: 50 additions & 35 deletions api/crypto/frame_crypto_transformer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ const EVP_CIPHER* GetAesCbcAlgorithmFromKeySize(size_t key_size_bytes) {
}

inline bool FrameIsH264(webrtc::TransformableFrameInterface* frame,
webrtc::FrameCryptorTransformer::MediaType type) {
webrtc::FrameCryptorTransformer::MediaType type) {
switch (type) {
case webrtc::FrameCryptorTransformer::MediaType::kVideoFrame: {
auto videoFrame =
Expand Down Expand Up @@ -314,11 +314,18 @@ FrameCryptorTransformer::FrameCryptorTransformer(
Algorithm algorithm,
rtc::scoped_refptr<KeyProvider> key_provider)
: signaling_thread_(signaling_thread),
thread_(rtc::Thread::Create()),
participant_id_(participant_id),
type_(type),
algorithm_(algorithm),
key_provider_(key_provider) {
RTC_DCHECK(key_provider_ != nullptr);
thread_->SetName("FrameCryptorTransformer", this);
thread_->Start();
}

FrameCryptorTransformer::~FrameCryptorTransformer() {
thread_->Stop();
}

void FrameCryptorTransformer::Transform(
Expand All @@ -333,10 +340,16 @@ void FrameCryptorTransformer::Transform(
// do encrypt or decrypt here...
switch (frame->GetDirection()) {
case webrtc::TransformableFrameInterface::Direction::kSender:
encryptFrame(std::move(frame));
RTC_DCHECK(thread_ != nullptr);
thread_->PostTask([frame = std::move(frame), this]() mutable {
encryptFrame(std::move(frame));
});
break;
case webrtc::TransformableFrameInterface::Direction::kReceiver:
decryptFrame(std::move(frame));
RTC_DCHECK(thread_ != nullptr);
thread_->PostTask([frame = std::move(frame), this]() mutable {
decryptFrame(std::move(frame));
});
break;
case webrtc::TransformableFrameInterface::Direction::kUnknown:
// do nothing
Expand Down Expand Up @@ -371,6 +384,8 @@ void FrameCryptorTransformer::encryptFrame(

rtc::ArrayView<const uint8_t> date_in = frame->GetData();
if (date_in.size() == 0 || !enabled_cryption) {
RTC_LOG(LS_WARNING) << "FrameCryptorTransformer::encryptFrame() "
"date_in.size() == 0 || enabled_cryption == false";
sink_callback->OnTransformedFrame(std::move(frame));
return;
}
Expand Down Expand Up @@ -425,7 +440,8 @@ void FrameCryptorTransformer::encryptFrame(
data_out.AppendData(frame_header);

if (FrameIsH264(frame.get(), type_)) {
H264::WriteRbsp(data_without_header.data(),data_without_header.size(), &data_out);
H264::WriteRbsp(data_without_header.data(), data_without_header.size(),
&data_out);
} else {
data_out.AppendData(data_without_header);
RTC_CHECK_EQ(data_out.size(), frame_header.size() +
Expand Down Expand Up @@ -490,34 +506,31 @@ void FrameCryptorTransformer::decryptFrame(
rtc::ArrayView<const uint8_t> date_in = frame->GetData();

if (date_in.size() == 0 || !enabled_cryption) {
RTC_LOG(LS_WARNING) << "FrameCryptorTransformer::decryptFrame() "
"date_in.size() == 0 || enabled_cryption == false";
sink_callback->OnTransformedFrame(std::move(frame));
return;
}

auto uncrypted_magic_bytes = key_provider_->options().uncrypted_magic_bytes;
if (uncrypted_magic_bytes.size() > 0 &&
date_in.size() >= uncrypted_magic_bytes.size() + 1) {
auto tmp =
date_in.subview(date_in.size() - (uncrypted_magic_bytes.size() + 1),
uncrypted_magic_bytes.size());

if (uncrypted_magic_bytes == std::vector<uint8_t>(tmp.begin(), tmp.end())) {
date_in.size() >= uncrypted_magic_bytes.size()) {
auto tmp = date_in.subview(date_in.size() - (uncrypted_magic_bytes.size()),
uncrypted_magic_bytes.size());
auto data = std::vector<uint8_t>(tmp.begin(), tmp.end());
if (uncrypted_magic_bytes == data) {
RTC_CHECK_EQ(tmp.size(), uncrypted_magic_bytes.size());
auto frame_type = date_in.subview(date_in.size() - 1, 1);
RTC_CHECK_EQ(frame_type.size(), 1);

RTC_LOG(LS_INFO)
<< "FrameCryptorTransformer::uncrypted_magic_bytes( type "
<< frame_type[0] << ", tmp " << to_hex(tmp.data(), tmp.size())
<< ", magic bytes "
<< to_hex(uncrypted_magic_bytes.data(), uncrypted_magic_bytes.size())
<< ")";
RTC_LOG(LS_INFO) << "FrameCryptorTransformer::uncrypted_magic_bytes( tmp "
<< to_hex(tmp.data(), tmp.size()) << ", magic bytes "
<< to_hex(uncrypted_magic_bytes.data(),
uncrypted_magic_bytes.size())
<< ")";

// magic bytes detected, this is a non-encrypted frame, skip frame
// decryption.
rtc::Buffer data_out;
data_out.AppendData(date_in.subview(
0, date_in.size() - uncrypted_magic_bytes.size() - 1));
data_out.AppendData(
date_in.subview(0, date_in.size() - uncrypted_magic_bytes.size()));
frame->SetData(data_out);
sink_callback->OnTransformedFrame(std::move(frame));
return;
Expand All @@ -539,8 +552,8 @@ void FrameCryptorTransformer::decryptFrame(

if (ivLength != getIvSize()) {
RTC_LOG(LS_WARNING) << "FrameCryptorTransformer::decryptFrame() ivLength["
<< static_cast<int>(ivLength) << "] != getIvSize()["
<< static_cast<int>(getIvSize()) << "]";
<< static_cast<int>(ivLength) << "] != getIvSize()["
<< static_cast<int>(getIvSize()) << "]";
if (last_dec_error_ != FrameCryptionState::kDecryptionFailed) {
last_dec_error_ = FrameCryptionState::kDecryptionFailed;
onFrameCryptionStateChanged(last_dec_error_);
Expand Down Expand Up @@ -585,7 +598,8 @@ void FrameCryptorTransformer::decryptFrame(

if (FrameIsH264(frame.get(), type_) &&
NeedsRbspUnescaping(encrypted_buffer.data(), encrypted_buffer.size())) {
encrypted_buffer.SetData(H264::ParseRbsp(encrypted_buffer.data(), encrypted_buffer.size()));
encrypted_buffer.SetData(
H264::ParseRbsp(encrypted_buffer.data(), encrypted_buffer.size()));
}

rtc::Buffer encrypted_payload(encrypted_buffer.size() - ivLength - 2);
Expand Down Expand Up @@ -665,10 +679,11 @@ void FrameCryptorTransformer::decryptFrame(
}

if (!decryption_success) {
if (last_dec_error_ != FrameCryptionState::kDecryptionFailed) {
last_dec_error_ = FrameCryptionState::kDecryptionFailed;
key_handler->DecryptionFailure();
onFrameCryptionStateChanged(last_dec_error_);
if (key_handler->DecryptionFailure()) {
if (last_dec_error_ != FrameCryptionState::kDecryptionFailed) {
last_dec_error_ = FrameCryptionState::kDecryptionFailed;
onFrameCryptionStateChanged(last_dec_error_);
}
}
return;
}
Expand All @@ -686,15 +701,15 @@ void FrameCryptorTransformer::decryptFrame(
sink_callback->OnTransformedFrame(std::move(frame));
}

void FrameCryptorTransformer::onFrameCryptionStateChanged(FrameCryptionState state) {
void FrameCryptorTransformer::onFrameCryptionStateChanged(
FrameCryptionState state) {
webrtc::MutexLock lock(&mutex_);
if(observer_) {
if (observer_) {
RTC_DCHECK(signaling_thread_ != nullptr);
signaling_thread_->PostTask(
[observer = observer_, state = state, participant_id = participant_id_]() mutable {
observer->OnFrameCryptionStateChanged(participant_id, state);
}
);
signaling_thread_->PostTask([observer = observer_, state = state,
participant_id = participant_id_]() mutable {
observer->OnFrameCryptionStateChanged(participant_id, state);
});
}
}

Expand Down
79 changes: 46 additions & 33 deletions api/crypto/frame_crypto_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ struct KeyProviderOptions {
std::vector<uint8_t> uncrypted_magic_bytes;
int ratchet_window_size;
int failure_tolerance;
KeyProviderOptions() : shared_key(false), ratchet_window_size(0), failure_tolerance(-1) {}
KeyProviderOptions()
: shared_key(false), ratchet_window_size(0), failure_tolerance(-1) {}
KeyProviderOptions(KeyProviderOptions& copy)
: shared_key(copy.shared_key),
ratchet_salt(copy.ratchet_salt),
Expand All @@ -55,10 +56,10 @@ struct KeyProviderOptions {

class KeyProvider : public rtc::RefCountInterface {
public:

virtual bool SetSharedKey(int key_index, std::vector<uint8_t> key) = 0;

virtual const rtc::scoped_refptr<ParticipantKeyHandler> GetSharedKey(const std::string participant_id) = 0;
virtual const rtc::scoped_refptr<ParticipantKeyHandler> GetSharedKey(
const std::string participant_id) = 0;

virtual const std::vector<uint8_t> RatchetSharedKey(int key_index) = 0;

Expand Down Expand Up @@ -94,8 +95,10 @@ class ParticipantKeyHandler : public rtc::RefCountInterface {
KeySet(std::vector<uint8_t> material, std::vector<uint8_t> encryptionKey)
: material(material), encryption_key(encryptionKey) {}
};

public:
ParticipantKeyHandler(KeyProvider* key_provider) : key_provider_(key_provider) {
ParticipantKeyHandler(KeyProvider* key_provider)
: key_provider_(key_provider) {
crypto_key_ring_.resize(KEYRING_SIZE);
}

Expand All @@ -116,7 +119,8 @@ class ParticipantKeyHandler : public rtc::RefCountInterface {
}
auto current_material = key_set->material;
std::vector<uint8_t> new_material;
if (DerivePBKDF2KeyFromRawKey(current_material, key_provider_->options().ratchet_salt, 256,
if (DerivePBKDF2KeyFromRawKey(current_material,
key_provider_->options().ratchet_salt, 256,
&new_material) != 0) {
return std::vector<uint8_t>();
}
Expand All @@ -139,16 +143,17 @@ class ParticipantKeyHandler : public rtc::RefCountInterface {
std::vector<uint8_t> RatchetKeyMaterial(
std::vector<uint8_t> current_material) {
std::vector<uint8_t> new_material;
if (DerivePBKDF2KeyFromRawKey(current_material, key_provider_->options().ratchet_salt, 256,
if (DerivePBKDF2KeyFromRawKey(current_material,
key_provider_->options().ratchet_salt, 256,
&new_material) != 0) {
return std::vector<uint8_t>();
}
return new_material;
}

rtc::scoped_refptr<KeySet> DeriveKeys(std::vector<uint8_t> password,
std::vector<uint8_t> ratchet_salt,
unsigned int optional_length_bits) {
std::vector<uint8_t> ratchet_salt,
unsigned int optional_length_bits) {
std::vector<uint8_t> derived_key;
if (DerivePBKDF2KeyFromRawKey(password, ratchet_salt, optional_length_bits,
&derived_key) == 0) {
Expand Down Expand Up @@ -177,16 +182,19 @@ class ParticipantKeyHandler : public rtc::RefCountInterface {
DeriveKeys(password, key_provider_->options().ratchet_salt, 128);
}

void DecryptionFailure() {
bool DecryptionFailure() {
webrtc::MutexLock lock(&mutex_);
if (key_provider_->options().failure_tolerance < 0) {
return;
return false;
}
decryption_failure_count_ += 1;

if (decryption_failure_count_ > key_provider_->options().failure_tolerance) {
if (decryption_failure_count_ >
key_provider_->options().failure_tolerance) {
has_valid_key_ = false;
return true;
}
return false;
}

private:
Expand All @@ -206,16 +214,16 @@ class DefaultKeyProviderImpl : public KeyProvider {
/// Set the shared key.
bool SetSharedKey(int key_index, std::vector<uint8_t> key) override {
webrtc::MutexLock lock(&mutex_);
if(options_.shared_key) {
if (options_.shared_key) {
if (keys_.find("shared") == keys_.end()) {
keys_["shared"] = rtc::make_ref_counted<ParticipantKeyHandler>(this);
}

auto key_handler = keys_["shared"];
key_handler->SetKey(key, key_index);

for(auto& key_pair : keys_) {
if(key_pair.first != "shared") {
for (auto& key_pair : keys_) {
if (key_pair.first != "shared") {
key_pair.second->SetKey(key, key_index);
}
}
Expand All @@ -227,13 +235,13 @@ class DefaultKeyProviderImpl : public KeyProvider {
const std::vector<uint8_t> RatchetSharedKey(int key_index) override {
webrtc::MutexLock lock(&mutex_);
auto it = keys_.find("shared");
if(it == keys_.end()) {
if (it == keys_.end()) {
return std::vector<uint8_t>();
}
auto new_key = it->second->RatchetKey(key_index);
if(options_.shared_key) {
for(auto& key_pair : keys_) {
if(key_pair.first != "shared") {
if (options_.shared_key) {
for (auto& key_pair : keys_) {
if (key_pair.first != "shared") {
key_pair.second->SetKey(new_key, key_index);
}
}
Expand All @@ -244,19 +252,20 @@ class DefaultKeyProviderImpl : public KeyProvider {
const std::vector<uint8_t> ExportSharedKey(int key_index) const override {
webrtc::MutexLock lock(&mutex_);
auto it = keys_.find("shared");
if(it == keys_.end()) {
if (it == keys_.end()) {
return std::vector<uint8_t>();
}
auto key_set = it->second->GetKeySet(key_index);
if(key_set) {
if (key_set) {
return key_set->material;
}
return std::vector<uint8_t>();
}

const rtc::scoped_refptr<ParticipantKeyHandler> GetSharedKey(const std::string participant_id) override {
const rtc::scoped_refptr<ParticipantKeyHandler> GetSharedKey(
const std::string participant_id) override {
webrtc::MutexLock lock(&mutex_);
if(options_.shared_key && keys_.find("shared") != keys_.end()) {
if (options_.shared_key && keys_.find("shared") != keys_.end()) {
auto shared_key_handler = keys_["shared"];
if (keys_.find(participant_id) != keys_.end()) {
return keys_[participant_id];
Expand All @@ -276,7 +285,8 @@ class DefaultKeyProviderImpl : public KeyProvider {
webrtc::MutexLock lock(&mutex_);

if (keys_.find(participant_id) == keys_.end()) {
keys_[participant_id] = rtc::make_ref_counted<ParticipantKeyHandler>(this);
keys_[participant_id] =
rtc::make_ref_counted<ParticipantKeyHandler>(this);
}

auto key_handler = keys_[participant_id];
Expand Down Expand Up @@ -326,7 +336,8 @@ class DefaultKeyProviderImpl : public KeyProvider {
private:
mutable webrtc::Mutex mutex_;
KeyProviderOptions options_;
std::unordered_map<std::string, rtc::scoped_refptr<ParticipantKeyHandler>> keys_;
std::unordered_map<std::string, rtc::scoped_refptr<ParticipantKeyHandler>>
keys_;
};

enum FrameCryptionState {
Expand Down Expand Up @@ -361,19 +372,20 @@ class RTC_EXPORT FrameCryptorTransformer
kAesCbc,
};

explicit FrameCryptorTransformer(rtc::Thread* signaling_thread,
const std::string participant_id,
MediaType type,
Algorithm algorithm,
rtc::scoped_refptr<KeyProvider> key_provider);

explicit FrameCryptorTransformer(
rtc::Thread* signaling_thread,
const std::string participant_id,
MediaType type,
Algorithm algorithm,
rtc::scoped_refptr<KeyProvider> key_provider);
~FrameCryptorTransformer();
virtual void RegisterFrameCryptorTransformerObserver(
rtc::scoped_refptr<FrameCryptorTransformerObserver> observer) {
rtc::scoped_refptr<FrameCryptorTransformerObserver> observer) {
webrtc::MutexLock lock(&mutex_);
observer_ = observer;
}

virtual void UnRegisterFrameCryptorTransformerObserver() {
virtual void UnRegisterFrameCryptorTransformerObserver() {
webrtc::MutexLock lock(&mutex_);
observer_ = nullptr;
}
Expand Down Expand Up @@ -431,6 +443,7 @@ class RTC_EXPORT FrameCryptorTransformer

private:
TaskQueueBase* const signaling_thread_;
std::unique_ptr<rtc::Thread> thread_;
std::string participant_id_;
mutable webrtc::Mutex mutex_;
mutable webrtc::Mutex sink_mutex_;
Expand All @@ -445,7 +458,7 @@ class RTC_EXPORT FrameCryptorTransformer
rtc::scoped_refptr<KeyProvider> key_provider_;
rtc::scoped_refptr<FrameCryptorTransformerObserver> observer_;
FrameCryptionState last_enc_error_ = FrameCryptionState::kNew;
FrameCryptionState last_dec_error_ = FrameCryptionState::kNew;
FrameCryptionState last_dec_error_ = FrameCryptionState::kNew;
};

} // namespace webrtc
Expand Down

0 comments on commit 0649214

Please sign in to comment.