Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/send frame cryptor events from signaling thread #95

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 26 additions & 32 deletions api/crypto/frame_crypto_transformer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -283,15 +283,16 @@ int AesEncryptDecrypt(EncryptOrDecrypt mode,
return AesCbcEncryptDecrypt(mode, raw_key, iv, data, buffer);
}
}

namespace webrtc {

FrameCryptorTransformer::FrameCryptorTransformer(
rtc::Thread* signaling_thread,
const std::string participant_id,
MediaType type,
Algorithm algorithm,
rtc::scoped_refptr<KeyProvider> key_provider)
: participant_id_(participant_id),
: signaling_thread_(signaling_thread),
participant_id_(participant_id),
type_(type),
algorithm_(algorithm),
key_provider_(key_provider) {
Expand Down Expand Up @@ -341,9 +342,7 @@ void FrameCryptorTransformer::encryptFrame(
<< "FrameCryptorTransformer::encryptFrame() sink_callback is NULL";
if (last_enc_error_ != FrameCryptionState::kInternalError) {
last_enc_error_ = FrameCryptionState::kInternalError;
if (observer_)
observer_->OnFrameCryptionStateChanged(participant_id_,
last_enc_error_);
onFrameCryptionStateChanged(last_dec_error_);
cloudwebrtc marked this conversation as resolved.
Show resolved Hide resolved
}
return;
}
Expand All @@ -365,9 +364,7 @@ void FrameCryptorTransformer::encryptFrame(
<< participant_id_;
if (last_enc_error_ != FrameCryptionState::kMissingKey) {
last_enc_error_ = FrameCryptionState::kMissingKey;
if (observer_)
observer_->OnFrameCryptionStateChanged(participant_id_,
last_enc_error_);
onFrameCryptionStateChanged(last_dec_error_);
}
return;
}
Expand Down Expand Up @@ -417,17 +414,13 @@ void FrameCryptorTransformer::encryptFrame(
<< " iv=" << to_hex(iv.data(), iv.size());
if (last_enc_error_ != FrameCryptionState::kOk) {
last_enc_error_ = FrameCryptionState::kOk;
if (observer_)
observer_->OnFrameCryptionStateChanged(participant_id_,
last_enc_error_);
onFrameCryptionStateChanged(last_dec_error_);
}
sink_callback->OnTransformedFrame(std::move(frame));
} else {
if (last_enc_error_ != FrameCryptionState::kEncryptionFailed) {
last_enc_error_ = FrameCryptionState::kEncryptionFailed;
if (observer_)
observer_->OnFrameCryptionStateChanged(participant_id_,
last_enc_error_);
onFrameCryptionStateChanged(last_dec_error_);
}
RTC_LOG(LS_ERROR) << "FrameCryptorTransformer::encryptFrame() failed";
}
Expand All @@ -452,9 +445,7 @@ void FrameCryptorTransformer::decryptFrame(
<< "FrameCryptorTransformer::decryptFrame() sink_callback is NULL";
if (last_dec_error_ != FrameCryptionState::kInternalError) {
last_dec_error_ = FrameCryptionState::kInternalError;
if (observer_)
observer_->OnFrameCryptionStateChanged(participant_id_,
last_dec_error_);
onFrameCryptionStateChanged(last_dec_error_);
}
return;
}
Expand Down Expand Up @@ -515,9 +506,7 @@ void FrameCryptorTransformer::decryptFrame(
<< static_cast<int>(getIvSize()) << "]";
if (last_dec_error_ != FrameCryptionState::kDecryptionFailed) {
last_dec_error_ = FrameCryptionState::kDecryptionFailed;
if (observer_)
observer_->OnFrameCryptionStateChanged(participant_id_,
last_dec_error_);
onFrameCryptionStateChanged(last_dec_error_);
}
return;
}
Expand All @@ -534,9 +523,7 @@ void FrameCryptorTransformer::decryptFrame(
<< participant_id_;
if (last_dec_error_ != FrameCryptionState::kMissingKey) {
last_dec_error_ = FrameCryptionState::kMissingKey;
if (observer_)
observer_->OnFrameCryptionStateChanged(participant_id_,
last_dec_error_);
onFrameCryptionStateChanged(last_dec_error_);
}
return;
}
Expand Down Expand Up @@ -570,7 +557,7 @@ void FrameCryptorTransformer::decryptFrame(
decryption_success = true;
} else {
RTC_LOG(LS_WARNING) << "FrameCryptorTransformer::decryptFrame() failed";
std::shared_ptr<ParticipantKeyHandler::KeySet> ratcheted_key_set;
rtc::scoped_refptr<ParticipantKeyHandler::KeySet> ratcheted_key_set;
auto currentKeyMaterial = key_set->material;
if (key_provider_->options().ratchet_window_size > 0) {
while (ratchet_count < key_provider_->options().ratchet_window_size) {
Expand All @@ -596,9 +583,7 @@ void FrameCryptorTransformer::decryptFrame(
key_handler->SetHasValidKey();
if (last_dec_error_ != FrameCryptionState::kKeyRatcheted) {
last_dec_error_ = FrameCryptionState::kKeyRatcheted;
if (observer_)
observer_->OnFrameCryptionStateChanged(participant_id_,
last_dec_error_);
onFrameCryptionStateChanged(last_dec_error_);
}
break;
}
Expand All @@ -623,9 +608,7 @@ void FrameCryptorTransformer::decryptFrame(
if (last_dec_error_ != FrameCryptionState::kDecryptionFailed) {
last_dec_error_ = FrameCryptionState::kDecryptionFailed;
key_handler->DecryptionFailure();
if (observer_)
observer_->OnFrameCryptionStateChanged(participant_id_,
last_dec_error_);
onFrameCryptionStateChanged(last_dec_error_);
}
return;
}
Expand All @@ -647,12 +630,23 @@ void FrameCryptorTransformer::decryptFrame(

if (last_dec_error_ != FrameCryptionState::kOk) {
last_dec_error_ = FrameCryptionState::kOk;
if (observer_)
observer_->OnFrameCryptionStateChanged(participant_id_, last_dec_error_);
onFrameCryptionStateChanged(last_dec_error_);
}
sink_callback->OnTransformedFrame(std::move(frame));
}

void FrameCryptorTransformer::onFrameCryptionStateChanged(FrameCryptionState state) {
webrtc::MutexLock lock(&mutex_);
if(observer_) {
RTC_DCHECK(signaling_thread_ != nullptr);
signaling_thread_->PostTask(
[observer = observer_, state = state, participant_id = participant_id_]() mutable {
observer->OnFrameCryptionStateChanged(participant_id, state);
}
);
}
}

rtc::Buffer FrameCryptorTransformer::makeIv(uint32_t ssrc, uint32_t timestamp) {
uint32_t send_count = 0;
if (send_counts_.find(ssrc) == send_counts_.end()) {
Expand Down
53 changes: 31 additions & 22 deletions api/crypto/frame_crypto_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include <unordered_map>

#include "api/frame_transformer_interface.h"
#include "api/task_queue/pending_task_safety_flag.h"
#include "api/task_queue/task_queue_base.h"
#include "rtc_base/buffer.h"
#include "rtc_base/synchronization/mutex.h"
#include "rtc_base/system/rtc_export.h"
Expand Down Expand Up @@ -56,7 +58,7 @@ class KeyProvider : public rtc::RefCountInterface {

virtual bool SetSharedKey(int key_index, std::vector<uint8_t> key) = 0;

virtual const std::shared_ptr<ParticipantKeyHandler> GetSharedKey(const std::string participant_id) = 0;
virtual const rtc::scoped_refptr<ParticipantKeyHandler> GetSharedKey(const std::string participant_id) = 0;

virtual const std::vector<uint8_t> RatchetSharedKey(int key_index) = 0;

Expand All @@ -66,7 +68,7 @@ class KeyProvider : public rtc::RefCountInterface {
int key_index,
std::vector<uint8_t> key) = 0;

virtual const std::shared_ptr<ParticipantKeyHandler> GetKey(
virtual const rtc::scoped_refptr<ParticipantKeyHandler> GetKey(
const std::string participant_id) const = 0;

virtual const std::vector<uint8_t> RatchetKey(
Expand All @@ -84,9 +86,9 @@ class KeyProvider : public rtc::RefCountInterface {
virtual ~KeyProvider() {}
};

class ParticipantKeyHandler {
class ParticipantKeyHandler : public rtc::RefCountInterface {
public:
struct KeySet {
struct KeySet : public rtc::RefCountInterface {
std::vector<uint8_t> material;
std::vector<uint8_t> encryption_key;
KeySet(std::vector<uint8_t> material, std::vector<uint8_t> encryptionKey)
Expand All @@ -99,8 +101,8 @@ class ParticipantKeyHandler {

virtual ~ParticipantKeyHandler() = default;

std::shared_ptr<ParticipantKeyHandler> Clone() {
auto clone = std::make_shared<ParticipantKeyHandler>(key_provider_);
rtc::scoped_refptr<ParticipantKeyHandler> Clone() {
auto clone = rtc::make_ref_counted<ParticipantKeyHandler>(key_provider_);
clone->crypto_key_ring_ = crypto_key_ring_;
clone->current_key_index_ = current_key_index_;
clone->has_valid_key_ = has_valid_key_;
Expand All @@ -124,7 +126,7 @@ class ParticipantKeyHandler {
return new_material;
}

virtual std::shared_ptr<KeySet> GetKeySet(int key_index) {
virtual rtc::scoped_refptr<KeySet> GetKeySet(int key_index) {
webrtc::MutexLock lock(&mutex_);
return crypto_key_ring_[key_index != -1 ? key_index : current_key_index_];
}
Expand All @@ -144,13 +146,13 @@ class ParticipantKeyHandler {
return new_material;
}

std::shared_ptr<KeySet> DeriveKeys(std::vector<uint8_t> password,
rtc::scoped_refptr<KeySet> DeriveKeys(std::vector<uint8_t> password,
std::vector<uint8_t> ratchet_salt,
unsigned int optional_length_bits) {
std::vector<uint8_t> derived_key;
if (DerivePBKDF2KeyFromRawKey(password, ratchet_salt, optional_length_bits,
&derived_key) == 0) {
return std::make_shared<KeySet>(password, derived_key);
return rtc::make_ref_counted<KeySet>(password, derived_key);
}
return nullptr;
}
Expand Down Expand Up @@ -193,7 +195,7 @@ class ParticipantKeyHandler {
mutable webrtc::Mutex mutex_;
int current_key_index_ = 0;
KeyProvider* key_provider_;
std::vector<std::shared_ptr<KeySet>> crypto_key_ring_;
std::vector<rtc::scoped_refptr<KeySet>> crypto_key_ring_;
};

class DefaultKeyProviderImpl : public KeyProvider {
Expand All @@ -206,7 +208,7 @@ class DefaultKeyProviderImpl : public KeyProvider {
webrtc::MutexLock lock(&mutex_);
if(options_.shared_key) {
if (keys_.find("shared") == keys_.end()) {
keys_["shared"] = std::make_shared<ParticipantKeyHandler>(this);
keys_["shared"] = rtc::make_ref_counted<ParticipantKeyHandler>(this);
}

auto key_handler = keys_["shared"];
Expand Down Expand Up @@ -252,7 +254,7 @@ class DefaultKeyProviderImpl : public KeyProvider {
return std::vector<uint8_t>();
}

const std::shared_ptr<ParticipantKeyHandler> GetSharedKey(const std::string participant_id) override {
const rtc::scoped_refptr<ParticipantKeyHandler> GetSharedKey(const std::string participant_id) override {
webrtc::MutexLock lock(&mutex_);
if(options_.shared_key && keys_.find("shared") != keys_.end()) {
auto shared_key_handler = keys_["shared"];
Expand All @@ -274,15 +276,15 @@ class DefaultKeyProviderImpl : public KeyProvider {
webrtc::MutexLock lock(&mutex_);

if (keys_.find(participant_id) == keys_.end()) {
keys_[participant_id] = std::make_shared<ParticipantKeyHandler>(this);
keys_[participant_id] = rtc::make_ref_counted<ParticipantKeyHandler>(this);
}

auto key_handler = keys_[participant_id];
key_handler->SetKey(key, index);
return true;
}

const std::shared_ptr<ParticipantKeyHandler> GetKey(
const rtc::scoped_refptr<ParticipantKeyHandler> GetKey(
const std::string participant_id) const override {
webrtc::MutexLock lock(&mutex_);

Expand Down Expand Up @@ -324,7 +326,7 @@ class DefaultKeyProviderImpl : public KeyProvider {
private:
mutable webrtc::Mutex mutex_;
KeyProviderOptions options_;
std::unordered_map<std::string, std::shared_ptr<ParticipantKeyHandler>> keys_;
std::unordered_map<std::string, rtc::scoped_refptr<ParticipantKeyHandler>> keys_;
};

enum FrameCryptionState {
Expand All @@ -337,7 +339,7 @@ enum FrameCryptionState {
kInternalError,
};

class FrameCryptorTransformerObserver {
class FrameCryptorTransformerObserver : public rtc::RefCountInterface {
public:
virtual void OnFrameCryptionStateChanged(const std::string participant_id,
FrameCryptionState error) = 0;
Expand All @@ -359,17 +361,23 @@ class RTC_EXPORT FrameCryptorTransformer
kAesCbc,
};

explicit FrameCryptorTransformer(const std::string participant_id,
explicit FrameCryptorTransformer(rtc::Thread* signaling_thread,
const std::string participant_id,
MediaType type,
Algorithm algorithm,
rtc::scoped_refptr<KeyProvider> key_provider);

virtual void SetFrameCryptorTransformerObserver(
FrameCryptorTransformerObserver* observer) {
virtual void RegisterFrameCryptorTransformerObserver(
rtc::scoped_refptr<FrameCryptorTransformerObserver> observer) {
webrtc::MutexLock lock(&mutex_);
observer_ = observer;
}

virtual void UnRegisterFrameCryptorTransformerObserver() {
webrtc::MutexLock lock(&mutex_);
observer_ = nullptr;
}

virtual void SetKeyIndex(int index) {
webrtc::MutexLock lock(&mutex_);
key_index_ = index;
Expand Down Expand Up @@ -417,10 +425,12 @@ class RTC_EXPORT FrameCryptorTransformer
private:
void encryptFrame(std::unique_ptr<webrtc::TransformableFrameInterface> frame);
void decryptFrame(std::unique_ptr<webrtc::TransformableFrameInterface> frame);
void onFrameCryptionStateChanged(FrameCryptionState error);
rtc::Buffer makeIv(uint32_t ssrc, uint32_t timestamp);
uint8_t getIvSize();

private:
TaskQueueBase* const signaling_thread_;
cloudwebrtc marked this conversation as resolved.
Show resolved Hide resolved
std::string participant_id_;
mutable webrtc::Mutex mutex_;
mutable webrtc::Mutex sink_mutex_;
Expand All @@ -433,10 +443,9 @@ class RTC_EXPORT FrameCryptorTransformer
int key_index_ = 0;
std::map<uint32_t, uint32_t> send_counts_;
rtc::scoped_refptr<KeyProvider> key_provider_;
FrameCryptorTransformerObserver* observer_ = nullptr;
std::unique_ptr<rtc::Thread> thread_;
rtc::scoped_refptr<FrameCryptorTransformerObserver> observer_;
FrameCryptionState last_enc_error_ = FrameCryptionState::kNew;
FrameCryptionState last_dec_error_ = FrameCryptionState::kNew;
FrameCryptionState last_dec_error_ = FrameCryptionState::kNew;
};

} // namespace webrtc
Expand Down
12 changes: 6 additions & 6 deletions sdk/android/api/org/webrtc/FrameCryptorFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,21 @@ public static FrameCryptorKeyProvider createFrameCryptorKeyProvider(
return nativeCreateFrameCryptorKeyProvider(sharedKey, ratchetSalt, ratchetWindowSize, uncryptedMagicBytes, failureTolerance);
}

public static FrameCryptor createFrameCryptorForRtpSender(RtpSender rtpSender,
public static FrameCryptor createFrameCryptorForRtpSender(PeerConnectionFactory factory, RtpSender rtpSender,
String participantId, FrameCryptorAlgorithm algorithm, FrameCryptorKeyProvider keyProvider) {
return nativeCreateFrameCryptorForRtpSender(rtpSender.getNativeRtpSender(), participantId,
return nativeCreateFrameCryptorForRtpSender(factory.getNativePeerConnectionFactory(),rtpSender.getNativeRtpSender(), participantId,
algorithm.ordinal(), keyProvider.getNativeKeyProvider());
}

public static FrameCryptor createFrameCryptorForRtpReceiver(RtpReceiver rtpReceiver,
public static FrameCryptor createFrameCryptorForRtpReceiver(PeerConnectionFactory factory, RtpReceiver rtpReceiver,
String participantId, FrameCryptorAlgorithm algorithm, FrameCryptorKeyProvider keyProvider) {
return nativeCreateFrameCryptorForRtpReceiver(rtpReceiver.getNativeRtpReceiver(), participantId,
return nativeCreateFrameCryptorForRtpReceiver(factory.getNativePeerConnectionFactory(), rtpReceiver.getNativeRtpReceiver(), participantId,
algorithm.ordinal(), keyProvider.getNativeKeyProvider());
}

private static native FrameCryptor nativeCreateFrameCryptorForRtpSender(
private static native FrameCryptor nativeCreateFrameCryptorForRtpSender(long factory,
long rtpSender, String participantId, int algorithm, long nativeFrameCryptorKeyProvider);
private static native FrameCryptor nativeCreateFrameCryptorForRtpReceiver(
private static native FrameCryptor nativeCreateFrameCryptorForRtpReceiver(long factory,
long rtpReceiver, String participantId, int algorithm, long nativeFrameCryptorKeyProvider);

private static native FrameCryptorKeyProvider nativeCreateFrameCryptorKeyProvider(
Expand Down
Loading