Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix progress reporting for key import #4131

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 98 additions & 8 deletions spec/integ/crypto/megolm-backup.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ import { IKeyBackup } from "../../../src/crypto/backup";
import { flushPromises } from "../../test-utils/flushPromises";
import { defer, IDeferred } from "../../../src/utils";
import { DecryptionFailureCode } from "../../../src/crypto-api";
import { ImportRoomKeysOpts } from "../../../src/crypto-api";

const ROOM_ID = testData.TEST_ROOM_ID;

Expand Down Expand Up @@ -311,6 +312,7 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe

describe("recover from backup", () => {
let aliceCrypto: Crypto.CryptoApi;
let importMockImpl: jest.Mock;

beforeEach(async () => {
fetchMock.get("path:/_matrix/client/v3/room_keys/version", testData.SIGNED_BACKUP_DATA);
Expand All @@ -322,6 +324,22 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe
// tell Alice to trust the dummy device that signed the backup
await waitForDeviceList();
await aliceCrypto.setDeviceVerified(testData.TEST_USER_ID, testData.TEST_DEVICE_ID);

importMockImpl = jest
.fn()
.mockImplementation((keys: IMegolmSessionData[], version: String, opts?: ImportRoomKeysOpts) => {
// need to report progress
if (opts?.progressCallback) {
opts.progressCallback({
stage: "load_keys",
successes: keys.length,
failures: 0,
total: keys.length,
});
}
});
// @ts-ignore - mock a private method for testing purpose
aliceCrypto.importBackedUpRoomKeys = importMockImpl;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should definitely not do this. Doing so turns all of these tests from integration tests (where we test the whole stack, including the rust bits) into unit tests (where we are just testing the behaviour of MatrixClient).

To be honest, it is a problem that we have existing tests that mock out CryptoApi.importBackedUpRoomKeys but which claim to be integration tests, but that's a problem for another day.

});

it("can restore from backup (Curve25519 version)", async function () {
Expand Down Expand Up @@ -397,10 +415,6 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe
}

it("Should import full backup in chunks", async function () {
const importMockImpl = jest.fn();
// @ts-ignore - mock a private method for testing purpose
aliceCrypto.importBackedUpRoomKeys = importMockImpl;

// We need several rooms with several sessions to test chunking
const { response, expectedTotal } = createBackupDownloadResponse([45, 300, 345, 12, 130]);

Expand Down Expand Up @@ -459,7 +473,7 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe
throw new Error("test error");
})
// Ok for other chunks
.mockResolvedValue(undefined);
.mockImplementation(importMockImpl);

const { response, expectedTotal } = createBackupDownloadResponse([100, 300]);

Expand Down Expand Up @@ -498,9 +512,6 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe
});

it("Should continue if some keys fails to decrypt", async function () {
// @ts-ignore - mock a private method for testing purpose
aliceCrypto.importBackedUpRoomKeys = jest.fn();

const decryptionFailureCount = 2;

const mockDecryptor = {
Expand Down Expand Up @@ -540,6 +551,85 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe
expect(result.imported).toStrictEqual(expectedTotal - decryptionFailureCount);
});

it("Should report failures when decryption works but import fails", async function () {
// @ts-ignore - mock a private method for testing purpose
aliceCrypto.importBackedUpRoomKeys = jest
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again: by mocking out importBackedUpRoomKeys, these are unit tests rather than integration tests. We have no real evidence that the actual RustCrypto implementation behaves the same way that these mocks do.

Could we not make up some data which causes RustCrypto.importBackedUpRoomKeys to report a failure, without having to mock this out?

.fn()
.mockImplementationOnce((keys: IMegolmSessionData[], version: String, opts?: ImportRoomKeysOpts) => {
// report 10 failures to import
opts!.progressCallback!({
stage: "load_keys",
successes: 20,
failures: 10,
total: 30,
});
return Promise.resolve();
})
// Ok for other chunks
.mockResolvedValue(importMockImpl);

const { response, expectedTotal } = createBackupDownloadResponse([30]);

fetchMock.get("express:/_matrix/client/v3/room_keys/keys", response);

const check = await aliceCrypto.checkKeyBackupAndEnable();

const progressCallback = jest.fn();
const result = await aliceClient.restoreKeyBackupWithRecoveryKey(
testData.BACKUP_DECRYPTION_KEY_BASE58,
undefined,
undefined,
check!.backupInfo!,
{
progressCallback,
},
);

expect(result.total).toStrictEqual(expectedTotal);
// A chunk failed to import
expect(result.imported).toStrictEqual(20);
});

it("Should report failures when decryption works but import fails - per room variant", async function () {
// @ts-ignore - mock a private method for testing purpose
aliceCrypto.importBackedUpRoomKeys = jest
.fn()
.mockImplementationOnce((keys: IMegolmSessionData[], version: String, opts?: ImportRoomKeysOpts) => {
// report 10 failures to import
opts!.progressCallback!({
stage: "load_keys",
successes: 20,
failures: 10,
total: 30,
});
return Promise.resolve();
})
// Ok for other chunks
.mockResolvedValue(importMockImpl);

const { response, expectedTotal } = createBackupDownloadResponse([30]);
const roomId = Object.keys(response.rooms)[0];

fetchMock.get(`express:/_matrix/client/v3/room_keys/keys/${roomId}`, response.rooms[roomId]);

const check = await aliceCrypto.checkKeyBackupAndEnable();

const progressCallback = jest.fn();
const result = await aliceClient.restoreKeyBackupWithRecoveryKey(
testData.BACKUP_DECRYPTION_KEY_BASE58,
roomId,
undefined,
check!.backupInfo!,
{
progressCallback,
},
);

expect(result.total).toStrictEqual(expectedTotal);
// A chunk failed to import
expect(result.imported).toStrictEqual(20);
});

it("recover specific session from backup", async function () {
fetchMock.get(
"express:/_matrix/client/v3/room_keys/keys/:room_id/:session_id",
Expand Down
44 changes: 36 additions & 8 deletions src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,13 @@ import { LocalNotificationSettings } from "./@types/local_notifications";
import { buildFeatureSupportMap, Feature, ServerSupport } from "./feature";
import { BackupDecryptor, CryptoBackend } from "./common-crypto/CryptoBackend";
import { RUST_SDK_STORE_PREFIX } from "./rust-crypto/constants";
import { BootstrapCrossSigningOpts, CrossSigningKeyInfo, CryptoApi, ImportRoomKeysOpts } from "./crypto-api";
import {
BootstrapCrossSigningOpts,
CrossSigningKeyInfo,
CryptoApi,
ImportRoomKeyProgressData,
ImportRoomKeysOpts,
} from "./crypto-api";
import { DeviceInfoMap } from "./crypto/DeviceList";
import {
AddSecretStorageKeyOpts,
Expand Down Expand Up @@ -3916,11 +3922,18 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
async (chunk) => {
// We have a chunk of decrypted keys: import them
try {
const backupVersion = backupInfo.version!;
let success = 0;
let failures = 0;
Comment on lines +3925 to +3926
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we give these clearer names, like thisChunkSuccesses / thisChunkFailures ?

const partialProgress = (stage: ImportRoomKeyProgressData): void => {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems like stage is a confusing name, when ImportRoomKeyProgressData also has a property called stage ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be clearer to inline this function? YMMV

success = stage.successes ?? 0;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am led to wonder if we couldn't just change the type definition for ImportRoomKeyProgressData to make successes and failures mandatory.

failures = stage.failures ?? 0;
};
await this.cryptoBackend!.importBackedUpRoomKeys(chunk, backupVersion, {
untrusted,
progressCallback: partialProgress,
});
totalImported += chunk.length;
totalImported += success;
totalFailures += failures;
} catch (e) {
totalFailures += chunk.length;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this looks kinda wrong now: we could import some keys and then throw an exception, so we would end up with success + failure > total.

Maybe:

Suggested change
totalFailures += chunk.length;
totalFailures += (chunk.length - success);

// We failed to import some keys, but we should still try to import the rest?
Expand All @@ -3947,11 +3960,25 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
for (const k of keys) {
k.room_id = targetRoomId!;
}
await this.cryptoBackend.importBackedUpRoomKeys(keys, backupVersion, {
progressCallback,
untrusted,
});
totalImported = keys.length;
try {
let success = 0;
let failures = 0;
const partialProgress = (stage: ImportRoomKeyProgressData): void => {
success = stage.successes ?? 0;
failures = stage.failures ?? 0;
};
await this.cryptoBackend!.importBackedUpRoomKeys(keys, backupVersion, {
untrusted,
progressCallback: partialProgress,
});
totalImported += success;
totalFailures += failures;
} catch (e) {
totalFailures += keys.length;
// We failed to import some keys, but we should still try to import the rest?
// Log the error and continue
logger.error("Error importing keys from backup", e);
}
} else {
totalKeyCount = 1;
try {
Expand All @@ -3967,6 +3994,7 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
});
totalImported = 1;
} catch (e) {
totalFailures = 1;
this.logger.debug("Failed to decrypt megolm session from backup", e);
}
}
Expand Down
13 changes: 12 additions & 1 deletion src/rust-crypto/backup.ts
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ export class RustBackupManager extends TypedEventEmitter<RustBackupCryptoEvents,
}
keysByRoom.get(roomId)!.set(key.session_id, key);
}
await this.olmMachine.importBackedUpRoomKeys(
const result: RustSdkCryptoJs.RoomKeyImportResult = await this.olmMachine.importBackedUpRoomKeys(
keysByRoom,
(progress: BigInt, total: BigInt, failures: BigInt): void => {
const importOpt: ImportRoomKeyProgressData = {
Expand All @@ -265,6 +265,17 @@ export class RustBackupManager extends TypedEventEmitter<RustBackupCryptoEvents,
},
backupVersion,
);
// call the progress callback one last time with the final state
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// call the progress callback one last time with the final state
// Call the progress callback one last time with the final state.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you also extend this comment to explain why this is required? It seems like OlmMachine.importBackedUpRoomKeys runs the callback after each key so it's not obvious.

if (opts?.progressCallback) {
// We use total count here and not imported count.
// Imported count could be 0 if all the keys were already imported.
opts.progressCallback({
total: result.totalCount,
successes: result.totalCount,
stage: "load_keys",
failures: keys.length - result.totalCount,
});
}
}

private keyBackupCheckInProgress: Promise<KeyBackupCheck | null> | null = null;
Expand Down
Loading