diff --git a/.github/workflows/full.yml b/.github/workflows/full.yml
index 9502e7a0b185..c6369340b5c3 100644
--- a/.github/workflows/full.yml
+++ b/.github/workflows/full.yml
@@ -165,6 +165,8 @@ jobs:
id: run_tests
run: |
pytest tests/python/
+ ONE_RUNNING_JOB_IN_QUEUE_PER_USER="true" pytest tests/python/rest_api/test_queues.py
+ CVAT_ALLOW_STATIC_CACHE="true" pytest -k "TestTaskData" tests/python
- name: Creating a log file from cvat containers
if: failure() && steps.run_tests.conclusion == 'failure'
diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 75b200597f47..0c9211b0c4a5 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -177,8 +177,9 @@ jobs:
COVERAGE_PROCESS_START: ".coveragerc"
run: |
pytest tests/python/ --cov --cov-report=json
- for COVERAGE_FILE in `find -name "coverage*.json" -type f -printf "%f\n"`; do mv ${COVERAGE_FILE} "${COVERAGE_FILE%%.*}_0.json"; done
ONE_RUNNING_JOB_IN_QUEUE_PER_USER="true" pytest tests/python/rest_api/test_queues.py --cov --cov-report=json
+ CVAT_ALLOW_STATIC_CACHE="true" pytest -k "TestTaskData" tests/python --cov --cov-report=json
+ for COVERAGE_FILE in `find -name "coverage*.json" -type f -printf "%f\n"`; do mv ${COVERAGE_FILE} "${COVERAGE_FILE%%.*}_0.json"; done
- name: Uploading code coverage results as an artifact
uses: actions/upload-artifact@v4
diff --git a/.github/workflows/schedule.yml b/.github/workflows/schedule.yml
index d8e514cbb449..c2071cd85d13 100644
--- a/.github/workflows/schedule.yml
+++ b/.github/workflows/schedule.yml
@@ -170,6 +170,12 @@ jobs:
pytest tests/python/
pytest tests/python/ --stop-services
+ ONE_RUNNING_JOB_IN_QUEUE_PER_USER="true" pytest tests/python/rest_api/test_queues.py
+ pytest tests/python/ --stop-services
+
+ CVAT_ALLOW_STATIC_CACHE="true" pytest tests/python
+ pytest tests/python/ --stop-services
+
- name: Unit tests
env:
HOST_COVERAGE_DATA_DIR: ${{ github.workspace }}
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 76694d291264..c5ce8ad0c8db 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,6 +16,45 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
+
+## \[2.19.1\] - 2024-09-26
+
+### Security
+
+- Fixed a security issue that occurred in PATCH requests to projects|tasks|jobs|memberships
+ ()
+
+
+## \[2.19.0\] - 2024-09-20
+
+### Added
+
+- Quality management tab on `quality control` allows to enabling/disabling GT frames
+ ()
+
+### Changed
+
+- Moved quality control from `analytics` page to `quality control` page
+ ()
+
+### Removed
+
+- Quality report no longer available in CVAT community version
+ ()
+
+### Fixed
+
+- Fixing a problem when project export does not export skeleton tracks
+ ()
+
+### Security
+
+- Fixed an XSS vulnerability in request-related endpoints
+ ()
+
+- Fixed an XSS vulnerability in the quality report data endpoint
+ ()
+
## \[2.18.0\] - 2024-09-10
diff --git a/changelog.d/20240812_161617_mzhiltso_job_chunks.md b/changelog.d/20240812_161617_mzhiltso_job_chunks.md
new file mode 100644
index 000000000000..af931641d6df
--- /dev/null
+++ b/changelog.d/20240812_161617_mzhiltso_job_chunks.md
@@ -0,0 +1,24 @@
+### Added
+
+- A server setting to enable or disable storage of permanent media chunks on the server filesystem
+ ()
+- \[Server API\] `GET /api/jobs/{id}/data/?type=chunk&index=x` parameter combination.
+ The new `index` parameter allows to retrieve job chunks using 0-based index in each job,
+ instead of the `number` parameter, which used task chunk ids.
+ ()
+
+### Changed
+
+- Job assignees will not receive frames from adjacent jobs in chunks
+ ()
+
+### Deprecated
+
+- \[Server API\] `GET /api/jobs/{id}/data/?type=chunk&number=x` parameter combination
+ ()
+
+
+### Fixed
+
+- Various memory leaks in video reading on the server
+ ()
diff --git a/changelog.d/20240826_093730_klakhov_support_quality_plugin.md b/changelog.d/20240826_093730_klakhov_support_quality_plugin.md
deleted file mode 100644
index 2b9324c2deb8..000000000000
--- a/changelog.d/20240826_093730_klakhov_support_quality_plugin.md
+++ /dev/null
@@ -1,14 +0,0 @@
-### Changed
-
-- Moved quality control from `analytics` page to `quality control` page
- ()
-
-### Removed
-
-- Quality report no longer available in CVAT community version
- ()
-
-### Added
-
-- Quality management tab on `quality control` allows to enabling/disabling GT frames
- ()
diff --git a/changelog.d/20240910_134547_dmitrii.lavrukhin_fix_skeleton_project_export.md b/changelog.d/20240910_134547_dmitrii.lavrukhin_fix_skeleton_project_export.md
deleted file mode 100644
index 241a24222121..000000000000
--- a/changelog.d/20240910_134547_dmitrii.lavrukhin_fix_skeleton_project_export.md
+++ /dev/null
@@ -1,4 +0,0 @@
-### Fixed
-
-- Fixing a problem when project export does not export skeleton tracks
- ()
diff --git a/cvat-cli/requirements/base.txt b/cvat-cli/requirements/base.txt
index faf23813c8aa..5803973a9bcf 100644
--- a/cvat-cli/requirements/base.txt
+++ b/cvat-cli/requirements/base.txt
@@ -1,3 +1,3 @@
-cvat-sdk~=2.19.0
+cvat-sdk~=2.20.0
Pillow>=10.3.0
setuptools>=70.0.0 # not directly required, pinned by Snyk to avoid a vulnerability
diff --git a/cvat-cli/src/cvat_cli/version.py b/cvat-cli/src/cvat_cli/version.py
index e982493d5113..27f8f2d92f8d 100644
--- a/cvat-cli/src/cvat_cli/version.py
+++ b/cvat-cli/src/cvat_cli/version.py
@@ -1 +1 @@
-VERSION = "2.19.0"
+VERSION = "2.20.0"
diff --git a/cvat-core/src/frames.ts b/cvat-core/src/frames.ts
index 96295af7d57d..dda847cf7a72 100644
--- a/cvat-core/src/frames.ts
+++ b/cvat-core/src/frames.ts
@@ -3,7 +3,7 @@
//
// SPDX-License-Identifier: MIT
-import _ from 'lodash';
+import _, { range, sortedIndexOf } from 'lodash';
import {
FrameDecoder, BlockType, DimensionType, ChunkQuality, decodeContextImages, RequestOutdatedError,
} from 'cvat-data';
@@ -25,7 +25,7 @@ const frameDataCache: Record | null;
activeContextRequest: Promise> | null;
@@ -34,7 +34,7 @@ const frameDataCache: Record;
- getChunk: (chunkNumber: number, quality: ChunkQuality) => Promise;
+ getChunk: (chunkIndex: number, quality: ChunkQuality) => Promise;
}> = {};
// frame meta data storage by job id
@@ -55,6 +55,8 @@ export class FramesMetaData {
public size: number;
public startFrame: number;
public stopFrame: number;
+ public frameStep: number;
+ public chunkCount: number;
#updateTrigger: FieldUpdateTrigger;
@@ -103,6 +105,17 @@ export class FramesMetaData {
}
}
+ const frameStep: number = (() => {
+ if (data.frame_filter) {
+ const frameStepParts = data.frame_filter.split('=', 2);
+ if (frameStepParts.length !== 2) {
+ throw new ArgumentError(`Invalid frame filter '${data.frame_filter}'`);
+ }
+ return +frameStepParts[1];
+ }
+ return 1;
+ })();
+
Object.defineProperties(
this,
Object.freeze({
@@ -133,6 +146,20 @@ export class FramesMetaData {
stopFrame: {
get: () => data.stop_frame,
},
+ frameStep: {
+ get: () => frameStep,
+ },
+ }),
+ );
+
+ const chunkCount: number = Math.ceil(this.getDataFrameNumbers().length / this.chunkSize);
+
+ Object.defineProperties(
+ this,
+ Object.freeze({
+ chunkCount: {
+ get: () => chunkCount,
+ },
}),
);
}
@@ -144,6 +171,40 @@ export class FramesMetaData {
resetUpdated(): void {
this.#updateTrigger.reset();
}
+
+ getFrameIndex(dataFrameNumber: number): number {
+ // Here we use absolute (task source data) frame numbers.
+ // TODO: migrate from data frame numbers to local frame numbers to simplify code.
+ // Requires server changes in api/jobs/{id}/data/meta/
+ // for included_frames, start_frame, stop_frame fields
+
+ if (dataFrameNumber < this.startFrame || dataFrameNumber > this.stopFrame) {
+ throw new ArgumentError(`Frame number ${dataFrameNumber} doesn't belong to the job`);
+ }
+
+ let frameIndex = null;
+ if (this.includedFrames) {
+ frameIndex = sortedIndexOf(this.includedFrames, dataFrameNumber);
+ if (frameIndex === -1) {
+ throw new ArgumentError(`Frame number ${dataFrameNumber} doesn't belong to the job`);
+ }
+ } else {
+ frameIndex = Math.floor((dataFrameNumber - this.startFrame) / this.frameStep);
+ }
+ return frameIndex;
+ }
+
+ getFrameChunkIndex(dataFrameNumber: number): number {
+ return Math.floor(this.getFrameIndex(dataFrameNumber) / this.chunkSize);
+ }
+
+ getDataFrameNumbers(): number[] {
+ if (this.includedFrames) {
+ return this.includedFrames;
+ }
+
+ return range(this.startFrame, this.stopFrame + 1, this.frameStep);
+ }
}
export class FrameData {
@@ -206,12 +267,14 @@ export class FrameData {
}
class PrefetchAnalyzer {
- #chunkSize: number;
#requestedFrames: number[];
+ #meta: FramesMetaData;
+ #getDataFrameNumber: (frameNumber: number) => number;
- constructor(chunkSize) {
- this.#chunkSize = chunkSize;
+ constructor(meta: FramesMetaData, dataFrameNumberGetter: (frameNumber: number) => number) {
this.#requestedFrames = [];
+ this.#meta = meta;
+ this.#getDataFrameNumber = dataFrameNumberGetter;
}
shouldPrefetchNext(current: number, isPlaying: boolean, isChunkCached: (chunk) => boolean): boolean {
@@ -219,13 +282,16 @@ class PrefetchAnalyzer {
return true;
}
- const currentChunk = Math.floor(current / this.#chunkSize);
+ const currentDataFrameNumber = this.#getDataFrameNumber(current);
+ const currentChunk = this.#meta.getFrameChunkIndex(currentDataFrameNumber);
const { length } = this.#requestedFrames;
const isIncreasingOrder = this.#requestedFrames
.every((val, index) => index === 0 || val > this.#requestedFrames[index - 1]);
if (
length && (isIncreasingOrder && current > this.#requestedFrames[length - 1]) &&
- (current % this.#chunkSize) >= Math.ceil(this.#chunkSize / 2) &&
+ (
+ this.#meta.getFrameIndex(currentDataFrameNumber) % this.#meta.chunkSize
+ ) >= Math.ceil(this.#meta.chunkSize / 2) &&
!isChunkCached(currentChunk + 1)
) {
// is increasing order including the current frame
@@ -247,13 +313,25 @@ class PrefetchAnalyzer {
this.#requestedFrames.push(frame);
// only half of chunk size is considered in this logic
- const limit = Math.ceil(this.#chunkSize / 2);
+ const limit = Math.ceil(this.#meta.chunkSize / 2);
if (this.#requestedFrames.length > limit) {
this.#requestedFrames.shift();
}
}
}
+function getDataStartFrame(meta: FramesMetaData, localStartFrame: number): number {
+ return meta.startFrame - localStartFrame * meta.frameStep;
+}
+
+function getDataFrameNumber(frameNumber: number, dataStartFrame: number, step: number): number {
+ return frameNumber * step + dataStartFrame;
+}
+
+function getFrameNumber(dataFrameNumber: number, dataStartFrame: number, step: number): number {
+ return (dataFrameNumber - dataStartFrame) / step;
+}
+
Object.defineProperty(FrameData.prototype.data, 'implementation', {
value(this: FrameData, onServerRequest) {
return new Promise<{
@@ -262,40 +340,57 @@ Object.defineProperty(FrameData.prototype.data, 'implementation', {
imageData: ImageBitmap | Blob;
} | Blob>((resolve, reject) => {
const {
- provider, prefetchAnalizer, chunkSize, stopFrame, decodeForward, forwardStep, decodedBlocksCacheSize,
+ meta, provider, prefetchAnalyzer, chunkSize, startFrame,
+ decodeForward, forwardStep, decodedBlocksCacheSize,
} = frameDataCache[this.jobID];
const requestId = +_.uniqueId();
- const chunkNumber = Math.floor(this.number / chunkSize);
+ const dataStartFrame = getDataStartFrame(meta, startFrame);
+ const requestedDataFrameNumber = getDataFrameNumber(
+ this.number, dataStartFrame, meta.frameStep,
+ );
+ const chunkIndex = meta.getFrameChunkIndex(requestedDataFrameNumber);
+ const segmentFrameNumbers = meta.getDataFrameNumbers().map(
+ (dataFrameNumber: number) => getFrameNumber(
+ dataFrameNumber, dataStartFrame, meta.frameStep,
+ ),
+ );
const frame = provider.frame(this.number);
- function findTheNextNotDecodedChunk(searchFrom: number): number {
- let firstFrameInNextChunk = searchFrom + forwardStep;
- let nextChunkNumber = Math.floor(firstFrameInNextChunk / chunkSize);
- while (nextChunkNumber === chunkNumber) {
- firstFrameInNextChunk += forwardStep;
- nextChunkNumber = Math.floor(firstFrameInNextChunk / chunkSize);
+ function findTheNextNotDecodedChunk(currentFrameIndex: number): number | null {
+ const { chunkCount } = meta;
+ let nextFrameIndex = currentFrameIndex + forwardStep;
+ let nextChunkIndex = Math.floor(nextFrameIndex / chunkSize);
+ while (nextChunkIndex === chunkIndex) {
+ nextFrameIndex += forwardStep;
+ nextChunkIndex = Math.floor(nextFrameIndex / chunkSize);
}
- if (provider.isChunkCached(nextChunkNumber)) {
- return findTheNextNotDecodedChunk(firstFrameInNextChunk);
+ if (nextChunkIndex < 0 || chunkCount <= nextChunkIndex) {
+ return null;
}
- return nextChunkNumber;
+ if (provider.isChunkCached(nextChunkIndex)) {
+ return findTheNextNotDecodedChunk(nextFrameIndex);
+ }
+
+ return nextChunkIndex;
}
if (frame) {
if (
- prefetchAnalizer.shouldPrefetchNext(
+ prefetchAnalyzer.shouldPrefetchNext(
this.number,
decodeForward,
(chunk) => provider.isChunkCached(chunk),
) && decodedBlocksCacheSize > 1 && !frameDataCache[this.jobID].activeChunkRequest
) {
- const nextChunkNumber = findTheNextNotDecodedChunk(this.number);
+ const nextChunkIndex = findTheNextNotDecodedChunk(
+ meta.getFrameIndex(requestedDataFrameNumber),
+ );
const predecodeChunksMax = Math.floor(decodedBlocksCacheSize / 2);
- if (nextChunkNumber * chunkSize <= stopFrame &&
- nextChunkNumber <= chunkNumber + predecodeChunksMax
+ if (nextChunkIndex !== null &&
+ nextChunkIndex <= chunkIndex + predecodeChunksMax
) {
frameDataCache[this.jobID].activeChunkRequest = new Promise((resolveForward) => {
const releasePromise = (): void => {
@@ -304,7 +399,7 @@ Object.defineProperty(FrameData.prototype.data, 'implementation', {
};
frameDataCache[this.jobID].getChunk(
- nextChunkNumber, ChunkQuality.COMPRESSED,
+ nextChunkIndex, ChunkQuality.COMPRESSED,
).then((chunk: ArrayBuffer) => {
if (!(this.jobID in frameDataCache)) {
// check if frameDataCache still exist
@@ -316,8 +411,11 @@ Object.defineProperty(FrameData.prototype.data, 'implementation', {
provider.cleanup(1);
provider.requestDecodeBlock(
chunk,
- nextChunkNumber * chunkSize,
- Math.min(stopFrame, (nextChunkNumber + 1) * chunkSize - 1),
+ nextChunkIndex,
+ segmentFrameNumbers.slice(
+ nextChunkIndex * chunkSize,
+ (nextChunkIndex + 1) * chunkSize,
+ ),
() => {},
releasePromise,
releasePromise,
@@ -334,7 +432,7 @@ Object.defineProperty(FrameData.prototype.data, 'implementation', {
renderHeight: this.height,
imageData: frame,
});
- prefetchAnalizer.addRequested(this.number);
+ prefetchAnalyzer.addRequested(this.number);
return;
}
@@ -355,7 +453,7 @@ Object.defineProperty(FrameData.prototype.data, 'implementation', {
renderHeight: this.height,
imageData: currentFrame,
});
- prefetchAnalizer.addRequested(this.number);
+ prefetchAnalyzer.addRequested(this.number);
return;
}
@@ -364,7 +462,7 @@ Object.defineProperty(FrameData.prototype.data, 'implementation', {
) => {
let wasResolved = false;
frameDataCache[this.jobID].getChunk(
- chunkNumber, ChunkQuality.COMPRESSED,
+ chunkIndex, ChunkQuality.COMPRESSED,
).then((chunk: ArrayBuffer) => {
try {
if (!(this.jobID in frameDataCache)) {
@@ -378,8 +476,11 @@ Object.defineProperty(FrameData.prototype.data, 'implementation', {
provider
.requestDecodeBlock(
chunk,
- chunkNumber * chunkSize,
- Math.min(stopFrame, (chunkNumber + 1) * chunkSize - 1),
+ chunkIndex,
+ segmentFrameNumbers.slice(
+ chunkIndex * chunkSize,
+ (chunkIndex + 1) * chunkSize,
+ ),
(_frame: number, bitmap: ImageBitmap | Blob) => {
if (decodeForward) {
// resolve immediately only if is not playing
@@ -395,7 +496,7 @@ Object.defineProperty(FrameData.prototype.data, 'implementation', {
renderHeight: this.height,
imageData: bitmap,
});
- prefetchAnalizer.addRequested(this.number);
+ prefetchAnalyzer.addRequested(this.number);
}
}, () => {
frameDataCache[this.jobID].activeChunkRequest = null;
@@ -592,7 +693,7 @@ export async function getFrame(
isPlaying: boolean,
step: number,
dimension: DimensionType,
- getChunk: (chunkNumber: number, quality: ChunkQuality) => Promise,
+ getChunk: (chunkIndex: number, quality: ChunkQuality) => Promise,
): Promise {
if (!(jobID in frameDataCache)) {
const blockType = chunkType === 'video' ? BlockType.MP4VIDEO : BlockType.ARCHIVE;
@@ -608,6 +709,13 @@ export async function getFrame(
const decodedBlocksCacheSize = Math.min(
Math.floor((2048 * 1024 * 1024) / ((mean + stdDev) * 4 * chunkSize)) || 1, 10,
);
+
+ // TODO: migrate to local frame numbers
+ const dataStartFrame = getDataStartFrame(meta, startFrame);
+ const dataFrameNumberGetter = (frameNumber: number): number => (
+ getDataFrameNumber(frameNumber, dataStartFrame, meta.frameStep)
+ );
+
frameDataCache[jobID] = {
meta,
chunkSize,
@@ -618,11 +726,13 @@ export async function getFrame(
forwardStep: step,
provider: new FrameDecoder(
blockType,
- chunkSize,
decodedBlocksCacheSize,
+ (frameNumber: number): number => (
+ meta.getFrameChunkIndex(dataFrameNumberGetter(frameNumber))
+ ),
dimension,
),
- prefetchAnalizer: new PrefetchAnalyzer(chunkSize),
+ prefetchAnalyzer: new PrefetchAnalyzer(meta, dataFrameNumberGetter),
decodedBlocksCacheSize,
activeChunkRequest: null,
activeContextRequest: null,
@@ -697,8 +807,11 @@ export async function findFrame(
let lastUndeletedFrame = null;
const check = (frame): boolean => {
if (meta.includedFrames) {
- return (meta.includedFrames.includes(frame)) &&
- (!filters.notDeleted || !(frame in meta.deletedFrames));
+ // meta.includedFrames contains input frame numbers now
+ const dataStartFrame = meta.startFrame; // this is only true when includedFrames is set
+ return (meta.includedFrames.includes(
+ getDataFrameNumber(frame, dataStartFrame, meta.frameStep))
+ ) && (!filters.notDeleted || !(frame in meta.deletedFrames));
}
if (filters.notDeleted) {
return !(frame in meta.deletedFrames);
@@ -726,6 +839,18 @@ export function getCachedChunks(jobID): number[] {
return frameDataCache[jobID].provider.cachedChunks(true);
}
+export function getJobFrameNumbers(jobID): number[] {
+ if (!(jobID in frameDataCache)) {
+ return [];
+ }
+
+ const { meta, startFrame } = frameDataCache[jobID];
+ const dataStartFrame = getDataStartFrame(meta, startFrame);
+ return meta.getDataFrameNumbers().map((dataFrameNumber: number): number => (
+ getFrameNumber(dataFrameNumber, dataStartFrame, meta.frameStep)
+ ));
+}
+
export function clear(jobID: number): void {
if (jobID in frameDataCache) {
frameDataCache[jobID].provider.close();
diff --git a/cvat-core/src/server-proxy.ts b/cvat-core/src/server-proxy.ts
index 91dc52a71821..51309426198a 100644
--- a/cvat-core/src/server-proxy.ts
+++ b/cvat-core/src/server-proxy.ts
@@ -1438,7 +1438,7 @@ async function getData(jid: number, chunk: number, quality: ChunkQuality, retry
...enableOrganization(),
quality,
type: 'chunk',
- number: chunk,
+ index: chunk,
},
responseType: 'arraybuffer',
});
diff --git a/cvat-core/src/session-implementation.ts b/cvat-core/src/session-implementation.ts
index fa77c934abde..961771708726 100644
--- a/cvat-core/src/session-implementation.ts
+++ b/cvat-core/src/session-implementation.ts
@@ -18,6 +18,7 @@ import {
deleteFrame,
restoreFrame,
getCachedChunks,
+ getJobFrameNumbers,
clear as clearFrames,
findFrame,
getContextImage,
@@ -189,7 +190,7 @@ export function implementJob(Job: typeof JobClass): typeof JobClass {
isPlaying,
step,
this.dimension,
- (chunkNumber, quality) => this.frames.chunk(chunkNumber, quality),
+ (chunkIndex, quality) => this.frames.chunk(chunkIndex, quality),
);
},
});
@@ -244,6 +245,14 @@ export function implementJob(Job: typeof JobClass): typeof JobClass {
},
});
+ Object.defineProperty(Job.prototype.frames.frameNumbers, 'implementation', {
+ value: function includedFramesImplementation(
+ this: JobClass,
+ ): ReturnType {
+ return Promise.resolve(getJobFrameNumbers(this.id));
+ },
+ });
+
Object.defineProperty(Job.prototype.frames.preview, 'implementation', {
value: function previewImplementation(
this: JobClass,
@@ -273,10 +282,10 @@ export function implementJob(Job: typeof JobClass): typeof JobClass {
Object.defineProperty(Job.prototype.frames.chunk, 'implementation', {
value: function chunkImplementation(
this: JobClass,
- chunkNumber: Parameters[0],
+ chunkIndex: Parameters[0],
quality: Parameters[1],
): ReturnType {
- return serverProxy.frames.getData(this.id, chunkNumber, quality);
+ return serverProxy.frames.getData(this.id, chunkIndex, quality);
},
});
@@ -829,7 +838,7 @@ export function implementTask(Task: typeof TaskClass): typeof TaskClass {
isPlaying,
step,
this.dimension,
- (chunkNumber, quality) => job.frames.chunk(chunkNumber, quality),
+ (chunkIndex, quality) => job.frames.chunk(chunkIndex, quality),
);
return result;
},
diff --git a/cvat-core/src/session.ts b/cvat-core/src/session.ts
index 1985a72b2683..54133ff6b667 100644
--- a/cvat-core/src/session.ts
+++ b/cvat-core/src/session.ts
@@ -233,6 +233,10 @@ function buildDuplicatedAPI(prototype) {
const result = await PluginRegistry.apiWrapper.call(this, prototype.frames.cachedChunks);
return result;
},
+ async frameNumbers() {
+ const result = await PluginRegistry.apiWrapper.call(this, prototype.frames.frameNumbers);
+ return result;
+ },
async preview() {
const result = await PluginRegistry.apiWrapper.call(this, prototype.frames.preview);
return result;
@@ -255,11 +259,11 @@ function buildDuplicatedAPI(prototype) {
);
return result;
},
- async chunk(chunkNumber, quality) {
+ async chunk(chunkIndex, quality) {
const result = await PluginRegistry.apiWrapper.call(
this,
prototype.frames.chunk,
- chunkNumber,
+ chunkIndex,
quality,
);
return result;
@@ -380,6 +384,7 @@ export class Session {
restore: (frame: number) => Promise;
save: () => Promise;
cachedChunks: () => Promise;
+ frameNumbers: () => Promise;
preview: () => Promise;
contextImage: (frame: number) => Promise>;
search: (
@@ -443,6 +448,7 @@ export class Session {
restore: Object.getPrototypeOf(this).frames.restore.bind(this),
save: Object.getPrototypeOf(this).frames.save.bind(this),
cachedChunks: Object.getPrototypeOf(this).frames.cachedChunks.bind(this),
+ frameNumbers: Object.getPrototypeOf(this).frames.frameNumbers.bind(this),
preview: Object.getPrototypeOf(this).frames.preview.bind(this),
search: Object.getPrototypeOf(this).frames.search.bind(this),
contextImage: Object.getPrototypeOf(this).frames.contextImage.bind(this),
diff --git a/cvat-data/src/ts/cvat-data.ts b/cvat-data/src/ts/cvat-data.ts
index 2f832ac9d3f5..baf00ac443c1 100644
--- a/cvat-data/src/ts/cvat-data.ts
+++ b/cvat-data/src/ts/cvat-data.ts
@@ -72,8 +72,8 @@ export function decodeContextImages(
decodeContextImages.mutex = new Mutex();
interface BlockToDecode {
- start: number;
- end: number;
+ chunkFrameNumbers: number[];
+ chunkIndex: number;
block: ArrayBuffer;
onDecodeAll(): void;
onDecode(frame: number, bitmap: ImageBitmap | Blob): void;
@@ -82,7 +82,6 @@ interface BlockToDecode {
export class FrameDecoder {
private blockType: BlockType;
- private chunkSize: number;
/*
ImageBitmap when decode zip or video chunks
Blob when 3D dimension
@@ -100,11 +99,12 @@ export class FrameDecoder {
private renderHeight: number;
private zipWorker: Worker | null;
private videoWorker: Worker | null;
+ private getChunkIndex: (frame: number) => number;
constructor(
blockType: BlockType,
- chunkSize: number,
cachedBlockCount: number,
+ getChunkIndex: (frame: number) => number,
dimension: DimensionType = DimensionType.DIMENSION_2D,
) {
this.mutex = new Mutex();
@@ -117,7 +117,7 @@ export class FrameDecoder {
this.renderWidth = 1920;
this.renderHeight = 1080;
- this.chunkSize = chunkSize;
+ this.getChunkIndex = getChunkIndex;
this.blockType = blockType;
this.decodedChunks = {};
@@ -125,8 +125,8 @@ export class FrameDecoder {
this.chunkIsBeingDecoded = null;
}
- isChunkCached(chunkNumber: number): boolean {
- return chunkNumber in this.decodedChunks;
+ isChunkCached(chunkIndex: number): boolean {
+ return chunkIndex in this.decodedChunks;
}
hasFreeSpace(): boolean {
@@ -155,17 +155,37 @@ export class FrameDecoder {
}
}
+ private validateFrameNumbers(frameNumbers: number[]): void {
+ if (!Array.isArray(frameNumbers) || !frameNumbers.length) {
+ throw new Error('chunkFrameNumbers must not be empty');
+ }
+
+ // ensure is ordered
+ for (let i = 1; i < frameNumbers.length; ++i) {
+ const prev = frameNumbers[i - 1];
+ const current = frameNumbers[i];
+ if (current <= prev) {
+ throw new Error(
+ 'chunkFrameNumbers must be sorted in the ascending order, ' +
+ `got a (${prev}, ${current}) pair instead`,
+ );
+ }
+ }
+ }
+
requestDecodeBlock(
block: ArrayBuffer,
- start: number,
- end: number,
+ chunkIndex: number,
+ chunkFrameNumbers: number[],
onDecode: (frame: number, bitmap: ImageBitmap | Blob) => void,
onDecodeAll: () => void,
onReject: (e: Error) => void,
): void {
+ this.validateFrameNumbers(chunkFrameNumbers);
+
if (this.requestedChunkToDecode !== null) {
// a chunk was already requested to be decoded, but decoding didn't start yet
- if (start === this.requestedChunkToDecode.start && end === this.requestedChunkToDecode.end) {
+ if (chunkIndex === this.requestedChunkToDecode.chunkIndex) {
// it was the same chunk
this.requestedChunkToDecode.onReject(new RequestOutdatedError());
@@ -175,12 +195,14 @@ export class FrameDecoder {
// it was other chunk
this.requestedChunkToDecode.onReject(new RequestOutdatedError());
}
- } else if (this.chunkIsBeingDecoded === null || this.chunkIsBeingDecoded.start !== start) {
+ } else if (this.chunkIsBeingDecoded === null ||
+ chunkIndex !== this.chunkIsBeingDecoded.chunkIndex
+ ) {
// everything was decoded or decoding other chunk is in process
this.requestedChunkToDecode = {
+ chunkFrameNumbers,
+ chunkIndex,
block,
- start,
- end,
onDecode,
onDecodeAll,
onReject,
@@ -203,9 +225,9 @@ export class FrameDecoder {
}
frame(frameNumber: number): ImageBitmap | Blob | null {
- const chunkNumber = Math.floor(frameNumber / this.chunkSize);
- if (chunkNumber in this.decodedChunks) {
- return this.decodedChunks[chunkNumber][frameNumber];
+ const chunkIndex = this.getChunkIndex(frameNumber);
+ if (chunkIndex in this.decodedChunks) {
+ return this.decodedChunks[chunkIndex][frameNumber];
}
return null;
@@ -253,8 +275,8 @@ export class FrameDecoder {
releaseMutex();
};
try {
- const { start, end, block } = this.requestedChunkToDecode;
- if (start !== blockToDecode.start) {
+ const { chunkFrameNumbers, chunkIndex, block } = this.requestedChunkToDecode;
+ if (chunkIndex !== blockToDecode.chunkIndex) {
// request is not relevant, another block was already requested
// it happens when A is being decoded, B comes and wait for mutex, C comes and wait for mutex
// B is not necessary anymore, because C already was requested
@@ -262,8 +284,11 @@ export class FrameDecoder {
throw new RequestOutdatedError();
}
- const chunkNumber = Math.floor(start / this.chunkSize);
- this.orderedStack = [chunkNumber, ...this.orderedStack];
+ const getFrameNumber = (chunkFrameIndex: number): number => (
+ chunkFrameNumbers[chunkFrameIndex]
+ );
+
+ this.orderedStack = [chunkIndex, ...this.orderedStack];
this.cleanup();
const decodedFrames: Record = {};
this.chunkIsBeingDecoded = this.requestedChunkToDecode;
@@ -273,7 +298,7 @@ export class FrameDecoder {
this.videoWorker = new Worker(
new URL('./3rdparty/Decoder.worker', import.meta.url),
);
- let index = start;
+ let index = 0;
this.videoWorker.onmessage = (e) => {
if (e.data.consoleLog) {
@@ -281,6 +306,7 @@ export class FrameDecoder {
return;
}
const keptIndex = index;
+ const frameNumber = getFrameNumber(keptIndex);
// do not use e.data.height and e.data.width because they might be not correct
// instead, try to understand real height and width of decoded image via scale factor
@@ -295,11 +321,11 @@ export class FrameDecoder {
width,
height,
)).then((bitmap) => {
- decodedFrames[keptIndex] = bitmap;
- this.chunkIsBeingDecoded.onDecode(keptIndex, decodedFrames[keptIndex]);
+ decodedFrames[frameNumber] = bitmap;
+ this.chunkIsBeingDecoded.onDecode(frameNumber, decodedFrames[frameNumber]);
- if (keptIndex === end) {
- this.decodedChunks[chunkNumber] = decodedFrames;
+ if (keptIndex === chunkFrameNumbers.length - 1) {
+ this.decodedChunks[chunkIndex] = decodedFrames;
this.chunkIsBeingDecoded.onDecodeAll();
this.chunkIsBeingDecoded = null;
release();
@@ -343,7 +369,7 @@ export class FrameDecoder {
this.zipWorker = this.zipWorker || new Worker(
new URL('./unzip_imgs.worker', import.meta.url),
);
- let index = start;
+ let decodedCount = 0;
this.zipWorker.onmessage = async (event) => {
if (event.data.error) {
@@ -353,16 +379,18 @@ export class FrameDecoder {
return;
}
- decodedFrames[event.data.index] = event.data.data as ImageBitmap | Blob;
- this.chunkIsBeingDecoded.onDecode(event.data.index, decodedFrames[event.data.index]);
+ const frameNumber = getFrameNumber(event.data.index);
+ decodedFrames[frameNumber] = event.data.data as ImageBitmap | Blob;
+ this.chunkIsBeingDecoded.onDecode(frameNumber, decodedFrames[frameNumber]);
- if (index === end) {
- this.decodedChunks[chunkNumber] = decodedFrames;
+ if (decodedCount === chunkFrameNumbers.length - 1) {
+ this.decodedChunks[chunkIndex] = decodedFrames;
this.chunkIsBeingDecoded.onDecodeAll();
this.chunkIsBeingDecoded = null;
release();
}
- index++;
+
+ decodedCount++;
};
this.zipWorker.onerror = (event: ErrorEvent) => {
@@ -373,8 +401,8 @@ export class FrameDecoder {
this.zipWorker.postMessage({
block,
- start,
- end,
+ start: 0,
+ end: chunkFrameNumbers.length - 1,
dimension: this.dimension,
dimension2D: DimensionType.DIMENSION_2D,
});
@@ -400,9 +428,12 @@ export class FrameDecoder {
}
public cachedChunks(includeInProgress = false): number[] {
- const chunkIsBeingDecoded = includeInProgress && this.chunkIsBeingDecoded ?
- Math.floor(this.chunkIsBeingDecoded.start / this.chunkSize) : null;
- return Object.keys(this.decodedChunks).map((chunkNumber: string) => +chunkNumber).concat(
+ const chunkIsBeingDecoded = (
+ includeInProgress && this.chunkIsBeingDecoded ?
+ this.chunkIsBeingDecoded.chunkIndex :
+ null
+ );
+ return Object.keys(this.decodedChunks).map((chunkIndex: string) => +chunkIndex).concat(
...(chunkIsBeingDecoded !== null ? [chunkIsBeingDecoded] : []),
).sort((a, b) => a - b);
}
diff --git a/cvat-sdk/gen/generate.sh b/cvat-sdk/gen/generate.sh
index 8e89efeb2bad..1b5eb7befde7 100755
--- a/cvat-sdk/gen/generate.sh
+++ b/cvat-sdk/gen/generate.sh
@@ -8,7 +8,7 @@ set -e
GENERATOR_VERSION="v6.0.1"
-VERSION="2.19.0"
+VERSION="2.20.0"
LIB_NAME="cvat_sdk"
LAYER1_LIB_NAME="${LIB_NAME}/api_client"
DST_DIR="$(cd "$(dirname -- "$0")/.." && pwd)"
diff --git a/cvat-ui/src/actions/annotation-actions.ts b/cvat-ui/src/actions/annotation-actions.ts
index 31b73314a131..b3fa8b503aaa 100644
--- a/cvat-ui/src/actions/annotation-actions.ts
+++ b/cvat-ui/src/actions/annotation-actions.ts
@@ -587,12 +587,13 @@ export function confirmCanvasReadyAsync(): ThunkAction {
const { instance: job } = state.annotation.job;
const { changeFrameEvent } = state.annotation.player.frame;
const chunks = await job.frames.cachedChunks() as number[];
- const { startFrame, stopFrame, dataChunkSize } = job;
+ const includedFrames = await job.frames.frameNumbers() as number[];
+ const { frameCount, dataChunkSize } = job;
const ranges = chunks.map((chunk) => (
[
- Math.max(startFrame, chunk * dataChunkSize),
- Math.min(stopFrame, (chunk + 1) * dataChunkSize - 1),
+ includedFrames[chunk * dataChunkSize],
+ includedFrames[Math.min(frameCount - 1, (chunk + 1) * dataChunkSize - 1)],
]
)).reduce>((acc, val) => {
if (acc.length && acc[acc.length - 1][1] + 1 === val[0]) {
@@ -905,7 +906,8 @@ export function getJobAsync({
// frame query parameter does not work for GT job
const frameNumber = Number.isInteger(initialFrame) && gtJob?.id !== job.id ?
- initialFrame as number : (await job.frames.search(
+ initialFrame as number :
+ (await job.frames.search(
{ notDeleted: !showDeletedFrames }, job.startFrame, job.stopFrame,
)) || job.startFrame;
diff --git a/cvat-ui/src/components/annotation-page/top-bar/player-navigation.tsx b/cvat-ui/src/components/annotation-page/top-bar/player-navigation.tsx
index 2088d14d7ccf..f1a2e9cf2892 100644
--- a/cvat-ui/src/components/annotation-page/top-bar/player-navigation.tsx
+++ b/cvat-ui/src/components/annotation-page/top-bar/player-navigation.tsx
@@ -169,17 +169,14 @@ function PlayerNavigation(props: Props): JSX.Element {
{!!ranges && (
)}
diff --git a/cvat-ui/src/components/quality-control/task-quality/allocation-table.tsx b/cvat-ui/src/components/quality-control/task-quality/allocation-table.tsx
index e1a4862ad822..b1619bf6fcd3 100644
--- a/cvat-ui/src/components/quality-control/task-quality/allocation-table.tsx
+++ b/cvat-ui/src/components/quality-control/task-quality/allocation-table.tsx
@@ -109,10 +109,12 @@ function AllocationTable(props: Readonly): JSX.Element {
render: (active: boolean, record: RowData): JSX.Element => (
active ? (
{ onDeleteFrames([record.frame]); }}
/>
) : (
{ onRestoreFrames([record.frame]); }}
component={RestoreIcon}
/>
@@ -130,7 +132,7 @@ function AllocationTable(props: Readonly): JSX.Element {
{
selection.selectedRowKeys.length !== 0 ? (
<>
-
+
{
const framesToUpdate = selection.selectedRows
@@ -141,7 +143,7 @@ function AllocationTable(props: Readonly): JSX.Element {
}}
/>
-
+
{
const framesToUpdate = selection.selectedRows
@@ -163,7 +165,7 @@ function AllocationTable(props: Readonly): JSX.Element {
if (!rowData.active) {
return 'cvat-allocation-frame-row cvat-allocation-frame-row-excluded';
}
- return 'cvat-allocation-frame';
+ return 'cvat-allocation-frame-row';
}}
columns={columns}
dataSource={data}
diff --git a/cvat-ui/src/components/quality-control/task-quality/summary.tsx b/cvat-ui/src/components/quality-control/task-quality/summary.tsx
index be249341b37d..59767192ed0d 100644
--- a/cvat-ui/src/components/quality-control/task-quality/summary.tsx
+++ b/cvat-ui/src/components/quality-control/task-quality/summary.tsx
@@ -20,14 +20,14 @@ export default function SummaryComponent(props: Readonly): JSX.Element {
-
+
Excluded count:
{' '}
{excludedCount}
-
+
Total count:
{' '}
@@ -36,7 +36,7 @@ export default function SummaryComponent(props: Readonly): JSX.Element {
-
+
Active count:
{' '}
diff --git a/cvat/__init__.py b/cvat/__init__.py
index 3b50a1130077..764ba8d6be04 100644
--- a/cvat/__init__.py
+++ b/cvat/__init__.py
@@ -4,6 +4,6 @@
from cvat.utils.version import get_version
-VERSION = (2, 19, 0, 'alpha', 0)
+VERSION = (2, 20, 0, 'alpha', 0)
__version__ = get_version(VERSION)
diff --git a/cvat/apps/dataset_manager/bindings.py b/cvat/apps/dataset_manager/bindings.py
index 534f885449e1..eb8fdf26b52c 100644
--- a/cvat/apps/dataset_manager/bindings.py
+++ b/cvat/apps/dataset_manager/bindings.py
@@ -30,8 +30,8 @@
from cvat.apps.dataset_manager.formats.utils import get_label_color
from cvat.apps.dataset_manager.util import add_prefetch_fields
-from cvat.apps.engine.frame_provider import FrameProvider
-from cvat.apps.engine.models import (AttributeSpec, AttributeType, Data, DimensionType, Job,
+from cvat.apps.engine.frame_provider import TaskFrameProvider, FrameQuality, FrameOutputType
+from cvat.apps.engine.models import (AttributeSpec, AttributeType, DimensionType, Job,
JobType, Label, LabelType, Project, SegmentType, ShapeType,
Task)
from cvat.apps.engine.rq_job_handler import RQJobMetaField
@@ -301,7 +301,7 @@ def start(self) -> int:
@property
def stop(self) -> int:
- return len(self)
+ return max(0, len(self) - 1)
def _get_queryset(self):
raise NotImplementedError()
@@ -437,7 +437,7 @@ def _export_tag(self, tag):
def _export_track(self, track, idx):
track['shapes'] = list(filter(lambda x: not self._is_frame_deleted(x['frame']), track['shapes']))
tracked_shapes = TrackManager.get_interpolated_shapes(
- track, 0, self.stop, self._annotation_ir.dimension)
+ track, 0, self.stop + 1, self._annotation_ir.dimension)
for tracked_shape in tracked_shapes:
tracked_shape["attributes"] += track["attributes"]
tracked_shape["track_id"] = track["track_id"] if self._use_server_track_ids else idx
@@ -493,7 +493,7 @@ def get_frame(idx):
anno_manager = AnnotationManager(self._annotation_ir)
for shape in sorted(
- anno_manager.to_shapes(self.stop, self._annotation_ir.dimension,
+ anno_manager.to_shapes(self.stop + 1, self._annotation_ir.dimension,
# Skip outside, deleted and excluded frames
included_frames=included_frames,
include_outside=False,
@@ -840,7 +840,7 @@ def start(self) -> int:
@property
def stop(self) -> int:
segment = self._db_job.segment
- return segment.stop_frame + 1
+ return segment.stop_frame
@property
def db_instance(self):
@@ -1410,7 +1410,7 @@ def add_task(self, task, files):
@attrs(frozen=True, auto_attribs=True)
class ImageSource:
- db_data: Data
+ db_task: Task
is_video: bool = attrib(kw_only=True)
class ImageProvider:
@@ -1439,8 +1439,10 @@ def video_frame_loader(_):
# optimization for videos: use numpy arrays instead of bytes
# some formats or transforms can require image data
return self._frame_provider.get_frame(frame_index,
- quality=FrameProvider.Quality.ORIGINAL,
- out_type=FrameProvider.Type.NUMPY_ARRAY)[0]
+ quality=FrameQuality.ORIGINAL,
+ out_type=FrameOutputType.NUMPY_ARRAY
+ ).data
+
return dm.Image(data=video_frame_loader, **image_kwargs)
else:
def image_loader(_):
@@ -1448,8 +1450,10 @@ def image_loader(_):
# for images use encoded data to avoid recoding
return self._frame_provider.get_frame(frame_index,
- quality=FrameProvider.Quality.ORIGINAL,
- out_type=FrameProvider.Type.BUFFER)[0].getvalue()
+ quality=FrameQuality.ORIGINAL,
+ out_type=FrameOutputType.BUFFER
+ ).data.getvalue()
+
return dm.ByteImage(data=image_loader, **image_kwargs)
def _load_source(self, source_id: int, source: ImageSource) -> None:
@@ -1457,7 +1461,7 @@ def _load_source(self, source_id: int, source: ImageSource) -> None:
return
self._unload_source()
- self._frame_provider = FrameProvider(source.db_data)
+ self._frame_provider = TaskFrameProvider(source.db_task)
self._current_source_id = source_id
def _unload_source(self) -> None:
@@ -1473,7 +1477,7 @@ def __init__(self, sources: Dict[int, ImageSource]) -> None:
self._images_per_source = {
source_id: {
image.id: image
- for image in source.db_data.images.prefetch_related('related_files')
+ for image in source.db_task.data.images.prefetch_related('related_files')
}
for source_id, source in sources.items()
}
@@ -1482,7 +1486,7 @@ def get_image_for_frame(self, source_id: int, frame_id: int, **image_kwargs):
source = self._sources[source_id]
point_cloud_path = osp.join(
- source.db_data.get_upload_dirname(), image_kwargs['path'],
+ source.db_task.data.get_upload_dirname(), image_kwargs['path'],
)
image = self._images_per_source[source_id][frame_id]
@@ -1595,11 +1599,18 @@ def __init__(
is_video = instance_meta['mode'] == 'interpolation'
ext = ''
if is_video:
- ext = FrameProvider.VIDEO_FRAME_EXT
+ ext = TaskFrameProvider.VIDEO_FRAME_EXT
if dimension == DimensionType.DIM_3D or include_images:
+ if isinstance(instance_data, TaskData):
+ db_task = instance_data.db_instance
+ elif isinstance(instance_data, JobData):
+ db_task = instance_data.db_instance.segment.task
+ else:
+ assert False
+
self._image_provider = IMAGE_PROVIDERS_BY_DIMENSION[dimension](
- {0: ImageSource(instance_data.db_data, is_video=is_video)}
+ {0: ImageSource(db_task, is_video=is_video)}
)
for frame_data in instance_data.group_by_frame(include_empty=True):
@@ -1681,13 +1692,13 @@ def __init__(
if self._dimension == DimensionType.DIM_3D or include_images:
self._image_provider = IMAGE_PROVIDERS_BY_DIMENSION[self._dimension](
{
- task.id: ImageSource(task.data, is_video=task.mode == 'interpolation')
+ task.id: ImageSource(task, is_video=task.mode == 'interpolation')
for task in project_data.tasks
}
)
ext_per_task: Dict[int, str] = {
- task.id: FrameProvider.VIDEO_FRAME_EXT if is_video else ''
+ task.id: TaskFrameProvider.VIDEO_FRAME_EXT if is_video else ''
for task in project_data.tasks
for is_video in [task.mode == 'interpolation']
}
diff --git a/cvat/apps/dataset_manager/formats/cvat.py b/cvat/apps/dataset_manager/formats/cvat.py
index 0191dfe1c8c4..4651fd398451 100644
--- a/cvat/apps/dataset_manager/formats/cvat.py
+++ b/cvat/apps/dataset_manager/formats/cvat.py
@@ -27,7 +27,7 @@
import_dm_annotations,
match_dm_item)
from cvat.apps.dataset_manager.util import make_zip_archive
-from cvat.apps.engine.frame_provider import FrameProvider
+from cvat.apps.engine.frame_provider import FrameQuality, FrameOutputType, make_frame_provider
from .registry import dm_env, exporter, importer
@@ -1371,16 +1371,19 @@ def dump_project_anno(dst_file: BufferedWriter, project_data: ProjectData, callb
dumper.close_document()
def dump_media_files(instance_data: CommonData, img_dir: str, project_data: ProjectData = None):
+ frame_provider = make_frame_provider(instance_data.db_instance)
+
ext = ''
if instance_data.meta[instance_data.META_FIELD]['mode'] == 'interpolation':
- ext = FrameProvider.VIDEO_FRAME_EXT
-
- frame_provider = FrameProvider(instance_data.db_data)
- frames = frame_provider.get_frames(
- instance_data.start, instance_data.stop,
- frame_provider.Quality.ORIGINAL,
- frame_provider.Type.BUFFER)
- for frame_id, (frame_data, _) in zip(instance_data.rel_range, frames):
+ ext = frame_provider.VIDEO_FRAME_EXT
+
+ frames = frame_provider.iterate_frames(
+ start_frame=instance_data.start,
+ stop_frame=instance_data.stop,
+ quality=FrameQuality.ORIGINAL,
+ out_type=FrameOutputType.BUFFER,
+ )
+ for frame_id, frame in zip(instance_data.rel_range, frames):
if (project_data is not None and (instance_data.db_instance.id, frame_id) in project_data.deleted_frames) \
or frame_id in instance_data.deleted_frames:
continue
@@ -1389,7 +1392,7 @@ def dump_media_files(instance_data: CommonData, img_dir: str, project_data: Proj
img_path = osp.join(img_dir, frame_name + ext)
os.makedirs(osp.dirname(img_path), exist_ok=True)
with open(img_path, 'wb') as f:
- f.write(frame_data.getvalue())
+ f.write(frame.data.getvalue())
def _export_task_or_job(dst_file, temp_dir, instance_data, anno_callback, save_images=False):
with open(osp.join(temp_dir, 'annotations.xml'), 'wb') as f:
diff --git a/cvat/apps/dataset_manager/tests/test_formats.py b/cvat/apps/dataset_manager/tests/test_formats.py
index 3bdc2e3df610..7b9eebd5ab97 100644
--- a/cvat/apps/dataset_manager/tests/test_formats.py
+++ b/cvat/apps/dataset_manager/tests/test_formats.py
@@ -1,6 +1,6 @@
# Copyright (C) 2020-2022 Intel Corporation
-# Copyright (C) 2022 CVAT.ai Corporation
+# Copyright (C) 2022-2024 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
@@ -14,10 +14,8 @@
from datumaro.components.dataset import Dataset, DatasetItem
from datumaro.components.annotation import Mask
from django.contrib.auth.models import Group, User
-from PIL import Image
from rest_framework import status
-from rest_framework.test import APIClient, APITestCase
import cvat.apps.dataset_manager as dm
from cvat.apps.dataset_manager.annotation import AnnotationIR
@@ -26,36 +24,13 @@
from cvat.apps.dataset_manager.task import TaskAnnotation
from cvat.apps.dataset_manager.util import make_zip_archive
from cvat.apps.engine.models import Task
-from cvat.apps.engine.tests.utils import get_paginated_collection
+from cvat.apps.engine.tests.utils import (
+ get_paginated_collection, ForceLogin, generate_image_file, ApiTestBase
+)
-
-def generate_image_file(filename, size=(100, 100)):
- f = BytesIO()
- image = Image.new('RGB', size=size)
- image.save(f, 'jpeg')
- f.name = filename
- f.seek(0)
- return f
-
-class ForceLogin:
- def __init__(self, user, client):
- self.user = user
- self.client = client
-
- def __enter__(self):
- if self.user:
- self.client.force_login(self.user,
- backend='django.contrib.auth.backends.ModelBackend')
-
- return self
-
- def __exit__(self, exception_type, exception_value, traceback):
- if self.user:
- self.client.logout()
-
-class _DbTestBase(APITestCase):
+class _DbTestBase(ApiTestBase):
def setUp(self):
- self.client = APIClient()
+ super().setUp()
@classmethod
def setUpTestData(cls):
@@ -94,6 +69,11 @@ def _create_task(self, data, image_data):
response = self.client.post("/api/tasks/%s/data" % tid,
data=image_data)
assert response.status_code == status.HTTP_202_ACCEPTED, response.status_code
+ rq_id = response.json()["rq_id"]
+
+ response = self.client.get(f"/api/requests/{rq_id}")
+ assert response.status_code == status.HTTP_200_OK, response.status_code
+ assert response.json()["status"] == "finished", response.json().get("status")
response = self.client.get("/api/tasks/%s" % tid)
diff --git a/cvat/apps/dataset_manager/tests/test_rest_api_formats.py b/cvat/apps/dataset_manager/tests/test_rest_api_formats.py
index 86aecfb42ddb..a4eef83174c3 100644
--- a/cvat/apps/dataset_manager/tests/test_rest_api_formats.py
+++ b/cvat/apps/dataset_manager/tests/test_rest_api_formats.py
@@ -171,6 +171,11 @@ def _create_task(self, data, image_data):
response = self.client.post("/api/tasks/%s/data" % tid,
data=image_data)
assert response.status_code == status.HTTP_202_ACCEPTED, response.status_code
+ rq_id = response.json()["rq_id"]
+
+ response = self.client.get(f"/api/requests/{rq_id}")
+ assert response.status_code == status.HTTP_200_OK, response.status_code
+ assert response.json()["status"] == "finished", response.json().get("status")
response = self.client.get("/api/tasks/%s" % tid)
@@ -413,7 +418,7 @@ def test_api_v2_dump_and_upload_annotations_with_objects_type_is_shape(self):
url = self._generate_url_dump_tasks_annotations(task_id)
for user, edata in list(expected.items()):
- self._clear_rq_jobs() # clean up from previous tests and iterations
+ self._clear_temp_data() # clean up from previous tests and iterations
user_name = edata['name']
file_zip_name = osp.join(test_dir, f'{test_name}_{user_name}_{dump_format_name}.zip')
@@ -522,7 +527,7 @@ def test_api_v2_dump_annotations_with_objects_type_is_track(self):
url = self._generate_url_dump_tasks_annotations(task_id)
for user, edata in list(expected.items()):
- self._clear_rq_jobs() # clean up from previous tests and iterations
+ self._clear_temp_data() # clean up from previous tests and iterations
user_name = edata['name']
file_zip_name = osp.join(test_dir, f'{test_name}_{user_name}_{dump_format_name}.zip')
@@ -607,7 +612,7 @@ def test_api_v2_dump_tag_annotations(self):
for user, edata in list(expected.items()):
with self.subTest(format=f"{edata['name']}"):
with TestDir() as test_dir:
- self._clear_rq_jobs() # clean up from previous tests and iterations
+ self._clear_temp_data() # clean up from previous tests and iterations
user_name = edata['name']
url = self._generate_url_dump_tasks_annotations(task_id)
@@ -849,7 +854,7 @@ def test_api_v2_export_dataset(self):
# dump annotations
url = self._generate_url_dump_task_dataset(task_id)
for user, edata in list(expected.items()):
- self._clear_rq_jobs() # clean up from previous tests and iterations
+ self._clear_temp_data() # clean up from previous tests and iterations
user_name = edata['name']
file_zip_name = osp.join(test_dir, f'{test_name}_{user_name}_{dump_format_name}.zip')
@@ -2112,7 +2117,7 @@ def test_api_v2_export_import_dataset(self):
self._create_annotations(task, dump_format_name, "random")
for user, edata in list(expected.items()):
- self._clear_rq_jobs() # clean up from previous tests and iterations
+ self._clear_temp_data() # clean up from previous tests and iterations
user_name = edata['name']
file_zip_name = osp.join(test_dir, f'{test_name}_{user_name}_{dump_format_name}.zip')
@@ -2175,7 +2180,7 @@ def test_api_v2_export_annotations(self):
url = self._generate_url_dump_project_annotations(project['id'], dump_format_name)
for user, edata in list(expected.items()):
- self._clear_rq_jobs() # clean up from previous tests and iterations
+ self._clear_temp_data() # clean up from previous tests and iterations
user_name = edata['name']
file_zip_name = osp.join(test_dir, f'{test_name}_{user_name}_{dump_format_name}.zip')
diff --git a/cvat/apps/engine/apps.py b/cvat/apps/engine/apps.py
index 326920e8b494..bcad84510f5d 100644
--- a/cvat/apps/engine/apps.py
+++ b/cvat/apps/engine/apps.py
@@ -10,6 +10,14 @@ class EngineConfig(AppConfig):
name = 'cvat.apps.engine'
def ready(self):
+ from django.conf import settings
+
+ from . import default_settings
+
+ for key in dir(default_settings):
+ if key.isupper() and not hasattr(settings, key):
+ setattr(settings, key, getattr(default_settings, key))
+
# Required to define signals in application
import cvat.apps.engine.signals
# Required in order to silent "unused-import" in pyflake
diff --git a/cvat/apps/engine/cache.py b/cvat/apps/engine/cache.py
index 2603c2fd5a13..bc4c8616bd7f 100644
--- a/cvat/apps/engine/cache.py
+++ b/cvat/apps/engine/cache.py
@@ -1,349 +1,700 @@
# Copyright (C) 2020-2022 Intel Corporation
-# Copyright (C) 2022-2023 CVAT.ai Corporation
+# Copyright (C) 2022-2024 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
+from __future__ import annotations
+
import io
import os
-import zipfile
-from datetime import datetime, timezone
-from io import BytesIO
-import shutil
+import os.path
+import pickle # nosec
import tempfile
+import zipfile
import zlib
-
-from typing import Optional, Tuple
-
+from contextlib import ExitStack, closing
+from datetime import datetime, timezone
+from itertools import groupby, pairwise
+from typing import (
+ Any,
+ Callable,
+ Collection,
+ Generator,
+ Iterator,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
+ overload,
+)
+
+import av
import cv2
import PIL.Image
-import pickle # nosec
-from django.conf import settings
+import PIL.ImageOps
from django.core.cache import caches
from rest_framework.exceptions import NotFound, ValidationError
-from cvat.apps.engine.cloud_provider import (Credentials,
- db_storage_to_storage_instance,
- get_cloud_storage_instance)
+from cvat.apps.engine import models
+from cvat.apps.engine.cloud_provider import (
+ Credentials,
+ db_storage_to_storage_instance,
+ get_cloud_storage_instance,
+)
from cvat.apps.engine.log import ServerLogManager
-from cvat.apps.engine.media_extractors import (ImageDatasetManifestReader,
- Mpeg4ChunkWriter,
- Mpeg4CompressedChunkWriter,
- VideoDatasetManifestReader,
- ZipChunkWriter,
- ZipCompressedChunkWriter)
-from cvat.apps.engine.mime_types import mimetypes
-from cvat.apps.engine.models import (DataChoice, DimensionType, Job, Image,
- StorageChoice, CloudStorage)
+from cvat.apps.engine.media_extractors import (
+ FrameQuality,
+ IChunkWriter,
+ ImageReaderWithManifest,
+ Mpeg4ChunkWriter,
+ Mpeg4CompressedChunkWriter,
+ VideoReader,
+ VideoReaderWithManifest,
+ ZipChunkWriter,
+ ZipCompressedChunkWriter,
+)
from cvat.apps.engine.utils import md5_hash, preload_images
from utils.dataset_manifest import ImageManifestManager
slogger = ServerLogManager(__name__)
+
+DataWithMime = Tuple[io.BytesIO, str]
+_CacheItem = Tuple[io.BytesIO, str, int]
+
+
class MediaCache:
- def __init__(self, dimension=DimensionType.DIM_2D):
- self._dimension = dimension
- self._cache = caches['media']
-
- def _get_or_set_cache_item(self, key, create_function):
- def create_item():
- slogger.glob.info(f'Starting to prepare chunk: key {key}')
- item = create_function()
- slogger.glob.info(f'Ending to prepare chunk: key {key}')
-
- if item[0]:
- item = (item[0], item[1], zlib.crc32(item[0].getbuffer()))
+ def __init__(self) -> None:
+ self._cache = caches["media"]
+
+ def _get_checksum(self, value: bytes) -> int:
+ return zlib.crc32(value)
+
+ def _get_or_set_cache_item(
+ self, key: str, create_callback: Callable[[], DataWithMime]
+ ) -> _CacheItem:
+ def create_item() -> _CacheItem:
+ slogger.glob.info(f"Starting to prepare chunk: key {key}")
+ item_data = create_callback()
+ slogger.glob.info(f"Ending to prepare chunk: key {key}")
+
+ item_data_bytes = item_data[0].getvalue()
+ item = (item_data[0], item_data[1], self._get_checksum(item_data_bytes))
+ if item_data_bytes:
self._cache.set(key, item)
return item
- slogger.glob.info(f'Starting to get chunk from cache: key {key}')
- try:
- item = self._cache.get(key)
- except pickle.UnpicklingError:
- slogger.glob.error(f'Unable to get item from cache: key {key}', exc_info=True)
- item = None
- slogger.glob.info(f'Ending to get chunk from cache: key {key}, is_cached {bool(item)}')
-
+ item = self._get_cache_item(key)
if not item:
item = create_item()
else:
# compare checksum
item_data = item[0].getbuffer() if isinstance(item[0], io.BytesIO) else item[0]
item_checksum = item[2] if len(item) == 3 else None
- if item_checksum != zlib.crc32(item_data):
- slogger.glob.info(f'Recreating cache item {key} due to checksum mismatch')
+ if item_checksum != self._get_checksum(item_data):
+ slogger.glob.info(f"Recreating cache item {key} due to checksum mismatch")
item = create_item()
- return item[0], item[1]
+ return item
- def get_task_chunk_data_with_mime(self, chunk_number, quality, db_data):
- item = self._get_or_set_cache_item(
- key=f'{db_data.id}_{chunk_number}_{quality}',
- create_function=lambda: self._prepare_task_chunk(db_data, quality, chunk_number),
- )
+ def _get_cache_item(self, key: str) -> Optional[_CacheItem]:
+ slogger.glob.info(f"Starting to get chunk from cache: key {key}")
+ try:
+ item = self._cache.get(key)
+ except pickle.UnpicklingError:
+ slogger.glob.error(f"Unable to get item from cache: key {key}", exc_info=True)
+ item = None
+ slogger.glob.info(f"Ending to get chunk from cache: key {key}, is_cached {bool(item)}")
return item
- def get_selective_job_chunk_data_with_mime(self, chunk_number, quality, job):
- item = self._get_or_set_cache_item(
- key=f'job_{job.id}_{chunk_number}_{quality}',
- create_function=lambda: self.prepare_selective_job_chunk(job, quality, chunk_number),
- )
+ def _has_key(self, key: str) -> bool:
+ return self._cache.has_key(key)
+
+ def _make_cache_key_prefix(
+ self, obj: Union[models.Task, models.Segment, models.Job, models.CloudStorage]
+ ) -> str:
+ if isinstance(obj, models.Task):
+ return f"task_{obj.id}"
+ elif isinstance(obj, models.Segment):
+ return f"segment_{obj.id}"
+ elif isinstance(obj, models.Job):
+ return f"job_{obj.id}"
+ elif isinstance(obj, models.CloudStorage):
+ return f"cloudstorage_{obj.id}"
+ else:
+ assert False, f"Unexpected object type {type(obj)}"
- return item
+ def _make_chunk_key(
+ self,
+ db_obj: Union[models.Task, models.Segment, models.Job],
+ chunk_number: int,
+ *,
+ quality: FrameQuality,
+ ) -> str:
+ return f"{self._make_cache_key_prefix(db_obj)}_chunk_{chunk_number}_{quality}"
+
+ def _make_preview_key(self, db_obj: Union[models.Segment, models.CloudStorage]) -> str:
+ return f"{self._make_cache_key_prefix(db_obj)}_preview"
- def get_local_preview_with_mime(self, frame_number, db_data):
- item = self._get_or_set_cache_item(
- key=f'data_{db_data.id}_{frame_number}_preview',
- create_function=lambda: self._prepare_local_preview(frame_number, db_data),
+ def _make_segment_task_chunk_key(
+ self,
+ db_obj: models.Segment,
+ chunk_number: int,
+ *,
+ quality: FrameQuality,
+ ) -> str:
+ return f"{self._make_cache_key_prefix(db_obj)}_task_chunk_{chunk_number}_{quality}"
+
+ def _make_context_image_preview_key(self, db_data: models.Data, frame_number: int) -> str:
+ return f"context_image_{db_data.id}_{frame_number}_preview"
+
+ @overload
+ def _to_data_with_mime(self, cache_item: _CacheItem) -> DataWithMime: ...
+
+ @overload
+ def _to_data_with_mime(self, cache_item: Optional[_CacheItem]) -> Optional[DataWithMime]: ...
+
+ def _to_data_with_mime(self, cache_item: Optional[_CacheItem]) -> Optional[DataWithMime]:
+ if not cache_item:
+ return None
+
+ return cache_item[:2]
+
+ def get_or_set_segment_chunk(
+ self, db_segment: models.Segment, chunk_number: int, *, quality: FrameQuality
+ ) -> DataWithMime:
+ return self._to_data_with_mime(
+ self._get_or_set_cache_item(
+ key=self._make_chunk_key(db_segment, chunk_number, quality=quality),
+ create_callback=lambda: self.prepare_segment_chunk(
+ db_segment, chunk_number, quality=quality
+ ),
+ )
)
- return item
+ def get_task_chunk(
+ self, db_task: models.Task, chunk_number: int, *, quality: FrameQuality
+ ) -> Optional[DataWithMime]:
+ return self._to_data_with_mime(
+ self._get_cache_item(key=self._make_chunk_key(db_task, chunk_number, quality=quality))
+ )
- def get_cloud_preview_with_mime(
+ def get_or_set_task_chunk(
self,
- db_storage: CloudStorage,
- ) -> Optional[Tuple[io.BytesIO, str]]:
- key = f'cloudstorage_{db_storage.id}_preview'
- return self._cache.get(key)
+ db_task: models.Task,
+ chunk_number: int,
+ *,
+ quality: FrameQuality,
+ set_callback: Callable[[], DataWithMime],
+ ) -> DataWithMime:
+ return self._to_data_with_mime(
+ self._get_or_set_cache_item(
+ key=self._make_chunk_key(db_task, chunk_number, quality=quality),
+ create_callback=set_callback,
+ )
+ )
- def get_or_set_cloud_preview_with_mime(
- self,
- db_storage: CloudStorage,
- ) -> Tuple[io.BytesIO, str]:
- key = f'cloudstorage_{db_storage.id}_preview'
+ def get_segment_task_chunk(
+ self, db_segment: models.Segment, chunk_number: int, *, quality: FrameQuality
+ ) -> Optional[DataWithMime]:
+ return self._to_data_with_mime(
+ self._get_cache_item(
+ key=self._make_segment_task_chunk_key(db_segment, chunk_number, quality=quality)
+ )
+ )
- item = self._get_or_set_cache_item(
- key, create_function=lambda: self._prepare_cloud_preview(db_storage)
+ def get_or_set_segment_task_chunk(
+ self,
+ db_segment: models.Segment,
+ chunk_number: int,
+ *,
+ quality: FrameQuality,
+ set_callback: Callable[[], DataWithMime],
+ ) -> DataWithMime:
+ return self._to_data_with_mime(
+ self._get_or_set_cache_item(
+ key=self._make_segment_task_chunk_key(db_segment, chunk_number, quality=quality),
+ create_callback=set_callback,
+ )
)
- return item
+ def get_or_set_selective_job_chunk(
+ self, db_job: models.Job, chunk_number: int, *, quality: FrameQuality
+ ) -> DataWithMime:
+ return self._to_data_with_mime(
+ self._get_or_set_cache_item(
+ key=self._make_chunk_key(db_job, chunk_number, quality=quality),
+ create_callback=lambda: self.prepare_masked_range_segment_chunk(
+ db_job.segment, chunk_number, quality=quality
+ ),
+ )
+ )
- def get_frame_context_images(self, db_data, frame_number):
- item = self._get_or_set_cache_item(
- key=f'context_image_{db_data.id}_{frame_number}',
- create_function=lambda: self._prepare_context_image(db_data, frame_number)
+ def get_or_set_segment_preview(self, db_segment: models.Segment) -> DataWithMime:
+ return self._to_data_with_mime(
+ self._get_or_set_cache_item(
+ self._make_preview_key(db_segment),
+ create_callback=lambda: self._prepare_segment_preview(db_segment),
+ )
)
- return item
+ def get_cloud_preview(self, db_storage: models.CloudStorage) -> Optional[DataWithMime]:
+ return self._to_data_with_mime(self._get_cache_item(self._make_preview_key(db_storage)))
- @staticmethod
- def _get_frame_provider_class():
- from cvat.apps.engine.frame_provider import \
- FrameProvider # TODO: remove circular dependency
- return FrameProvider
-
- from contextlib import contextmanager
-
- @staticmethod
- @contextmanager
- def _get_images(db_data, chunk_number, dimension):
- images = []
- tmp_dir = None
- upload_dir = {
- StorageChoice.LOCAL: db_data.get_upload_dirname(),
- StorageChoice.SHARE: settings.SHARE_ROOT,
- StorageChoice.CLOUD_STORAGE: db_data.get_upload_dirname(),
- }[db_data.storage]
+ def get_or_set_cloud_preview(self, db_storage: models.CloudStorage) -> DataWithMime:
+ return self._to_data_with_mime(
+ self._get_or_set_cache_item(
+ self._make_preview_key(db_storage),
+ create_callback=lambda: self._prepare_cloud_preview(db_storage),
+ )
+ )
- try:
- if hasattr(db_data, 'video'):
- source_path = os.path.join(upload_dir, db_data.video.path)
-
- reader = VideoDatasetManifestReader(manifest_path=db_data.get_manifest_path(),
- source_path=source_path, chunk_number=chunk_number,
- chunk_size=db_data.chunk_size, start=db_data.start_frame,
- stop=db_data.stop_frame, step=db_data.get_frame_step())
- for frame in reader:
- images.append((frame, source_path, None))
- else:
- reader = ImageDatasetManifestReader(manifest_path=db_data.get_manifest_path(),
- chunk_number=chunk_number, chunk_size=db_data.chunk_size,
- start=db_data.start_frame, stop=db_data.stop_frame,
- step=db_data.get_frame_step())
- if db_data.storage == StorageChoice.CLOUD_STORAGE:
- db_cloud_storage = db_data.cloud_storage
- assert db_cloud_storage, 'Cloud storage instance was deleted'
- credentials = Credentials()
- credentials.convert_from_db({
- 'type': db_cloud_storage.credentials_type,
- 'value': db_cloud_storage.credentials,
- })
- details = {
- 'resource': db_cloud_storage.resource,
- 'credentials': credentials,
- 'specific_attributes': db_cloud_storage.get_specific_attributes()
+ def get_or_set_frame_context_images_chunk(
+ self, db_data: models.Data, frame_number: int
+ ) -> DataWithMime:
+ return self._to_data_with_mime(
+ self._get_or_set_cache_item(
+ key=self._make_context_image_preview_key(db_data, frame_number),
+ create_callback=lambda: self.prepare_context_images_chunk(db_data, frame_number),
+ )
+ )
+
+ def _read_raw_images(
+ self,
+ db_task: models.Task,
+ frame_ids: Sequence[int],
+ *,
+ manifest_path: str,
+ ):
+ db_data = db_task.data
+
+ if os.path.isfile(manifest_path) and db_data.storage == models.StorageChoice.CLOUD_STORAGE:
+ reader = ImageReaderWithManifest(manifest_path)
+ with ExitStack() as es:
+ db_cloud_storage = db_data.cloud_storage
+ assert db_cloud_storage, "Cloud storage instance was deleted"
+ credentials = Credentials()
+ credentials.convert_from_db(
+ {
+ "type": db_cloud_storage.credentials_type,
+ "value": db_cloud_storage.credentials,
}
- cloud_storage_instance = get_cloud_storage_instance(cloud_provider=db_cloud_storage.provider_type, **details)
+ )
+ details = {
+ "resource": db_cloud_storage.resource,
+ "credentials": credentials,
+ "specific_attributes": db_cloud_storage.get_specific_attributes(),
+ }
+ cloud_storage_instance = get_cloud_storage_instance(
+ cloud_provider=db_cloud_storage.provider_type, **details
+ )
+
+ tmp_dir = es.enter_context(tempfile.TemporaryDirectory(prefix="cvat"))
+ files_to_download = []
+ checksums = []
+ media = []
+ for item in reader.iterate_frames(frame_ids):
+ file_name = f"{item['name']}{item['extension']}"
+ fs_filename = os.path.join(tmp_dir, file_name)
+
+ files_to_download.append(file_name)
+ checksums.append(item.get("checksum", None))
+ media.append((fs_filename, fs_filename, None))
+
+ cloud_storage_instance.bulk_download_to_dir(
+ files=files_to_download, upload_dir=tmp_dir
+ )
+ media = preload_images(media)
+
+ for checksum, (_, fs_filename, _) in zip(checksums, media):
+ if checksum and not md5_hash(fs_filename) == checksum:
+ slogger.cloud_storage[db_cloud_storage.id].warning(
+ "Hash sums of files {} do not match".format(file_name)
+ )
+
+ yield from media
+ else:
+ requested_frame_iter = iter(frame_ids)
+ next_requested_frame_id = next(requested_frame_iter, None)
+ if next_requested_frame_id is None:
+ return
+
+ # TODO: find a way to use prefetched results, if provided
+ db_images = (
+ db_data.images.order_by("frame")
+ .filter(frame__gte=frame_ids[0], frame__lte=frame_ids[-1])
+ .values_list("frame", "path")
+ .all()
+ )
- tmp_dir = tempfile.mkdtemp(prefix='cvat')
- files_to_download = []
- checksums = []
- for item in reader:
- file_name = f"{item['name']}{item['extension']}"
- fs_filename = os.path.join(tmp_dir, file_name)
+ raw_data_dir = db_data.get_raw_data_dirname()
+ media = []
+ for frame_id, frame_path in db_images:
+ if frame_id == next_requested_frame_id:
+ source_path = os.path.join(raw_data_dir, frame_path)
+ media.append((source_path, source_path, None))
- files_to_download.append(file_name)
- checksums.append(item.get('checksum', None))
- images.append((fs_filename, fs_filename, None))
+ next_requested_frame_id = next(requested_frame_iter, None)
- cloud_storage_instance.bulk_download_to_dir(files=files_to_download, upload_dir=tmp_dir)
- images = preload_images(images)
+ if next_requested_frame_id is None:
+ break
- for checksum, (_, fs_filename, _) in zip(checksums, images):
- if checksum and not md5_hash(fs_filename) == checksum:
- slogger.cloud_storage[db_cloud_storage.id].warning('Hash sums of files {} do not match'.format(file_name))
- else:
- for item in reader:
- source_path = os.path.join(upload_dir, f"{item['name']}{item['extension']}")
- images.append((source_path, source_path, None))
- if dimension == DimensionType.DIM_2D:
- images = preload_images(images)
-
- yield images
- finally:
- if db_data.storage == StorageChoice.CLOUD_STORAGE and tmp_dir is not None:
- shutil.rmtree(tmp_dir)
-
- def _prepare_task_chunk(self, db_data, quality, chunk_number):
- FrameProvider = self._get_frame_provider_class()
-
- writer_classes = {
- FrameProvider.Quality.COMPRESSED : Mpeg4CompressedChunkWriter if db_data.compressed_chunk_type == DataChoice.VIDEO else ZipCompressedChunkWriter,
- FrameProvider.Quality.ORIGINAL : Mpeg4ChunkWriter if db_data.original_chunk_type == DataChoice.VIDEO else ZipChunkWriter,
- }
-
- image_quality = 100 if writer_classes[quality] in [Mpeg4ChunkWriter, ZipChunkWriter] else db_data.image_quality
- mime_type = 'video/mp4' if writer_classes[quality] in [Mpeg4ChunkWriter, Mpeg4CompressedChunkWriter] else 'application/zip'
-
- kwargs = {}
- if self._dimension == DimensionType.DIM_3D:
- kwargs["dimension"] = DimensionType.DIM_3D
- writer = writer_classes[quality](image_quality, **kwargs)
-
- buff = BytesIO()
- with self._get_images(db_data, chunk_number, self._dimension) as images:
- writer.save_as_chunk(images, buff)
- buff.seek(0)
+ assert next_requested_frame_id is None
- return buff, mime_type
+ if db_task.dimension == models.DimensionType.DIM_2D:
+ media = preload_images(media)
- def prepare_selective_job_chunk(self, db_job: Job, quality, chunk_number: int):
- db_data = db_job.segment.task.data
+ yield from media
- FrameProvider = self._get_frame_provider_class()
- frame_provider = FrameProvider(db_data, self._dimension)
+ def _read_raw_frames(
+ self, db_task: models.Task, frame_ids: Sequence[int]
+ ) -> Generator[Tuple[Union[av.VideoFrame, PIL.Image.Image], str, str], None, None]:
+ for prev_frame, cur_frame in pairwise(frame_ids):
+ assert (
+ prev_frame <= cur_frame
+ ), f"Requested frame ids must be sorted, got a ({prev_frame}, {cur_frame}) pair"
- frame_set = db_job.segment.frame_set
- frame_step = db_data.get_frame_step()
- chunk_frames = []
+ db_data = db_task.data
- writer = ZipCompressedChunkWriter(db_data.image_quality, dimension=self._dimension)
- dummy_frame = BytesIO()
- PIL.Image.new('RGB', (1, 1)).save(dummy_frame, writer.IMAGE_EXT)
+ manifest_path = db_data.get_manifest_path()
- if hasattr(db_data, 'video'):
- frame_size = (db_data.video.width, db_data.video.height)
- else:
- frame_size = None
+ if hasattr(db_data, "video"):
+ source_path = os.path.join(db_data.get_raw_data_dirname(), db_data.video.path)
- for frame_idx in range(db_data.chunk_size):
- frame_idx = (
- db_data.start_frame + chunk_number * db_data.chunk_size + frame_idx * frame_step
+ reader = VideoReaderWithManifest(
+ manifest_path=manifest_path,
+ source_path=source_path,
+ allow_threading=False,
)
- if db_data.stop_frame < frame_idx:
- break
-
- frame_bytes = None
-
- if frame_idx in frame_set:
- frame_bytes = frame_provider.get_frame(frame_idx, quality=quality)[0]
+ if not os.path.isfile(manifest_path):
+ try:
+ reader.manifest.link(source_path, force=True)
+ reader.manifest.create()
+ except Exception as e:
+ slogger.task[db_task.id].warning(
+ f"Failed to create video manifest: {e}", exc_info=True
+ )
+ reader = None
+
+ if reader:
+ for frame in reader.iterate_frames(frame_filter=frame_ids):
+ yield (frame, source_path, None)
+ else:
+ reader = VideoReader([source_path], allow_threading=False)
- if frame_size is not None:
- # Decoded video frames can have different size, restore the original one
+ for frame_tuple in reader.iterate_frames(frame_filter=frame_ids):
+ yield frame_tuple
+ else:
+ yield from self._read_raw_images(db_task, frame_ids, manifest_path=manifest_path)
+
+ def prepare_segment_chunk(
+ self, db_segment: models.Segment, chunk_number: int, *, quality: FrameQuality
+ ) -> DataWithMime:
+ if db_segment.type == models.SegmentType.RANGE:
+ return self.prepare_range_segment_chunk(db_segment, chunk_number, quality=quality)
+ elif db_segment.type == models.SegmentType.SPECIFIC_FRAMES:
+ return self.prepare_masked_range_segment_chunk(
+ db_segment, chunk_number, quality=quality
+ )
+ else:
+ assert False, f"Unknown segment type {db_segment.type}"
+
+ def prepare_range_segment_chunk(
+ self, db_segment: models.Segment, chunk_number: int, *, quality: FrameQuality
+ ) -> DataWithMime:
+ db_task = db_segment.task
+ db_data = db_task.data
+
+ chunk_size = db_data.chunk_size
+ chunk_frame_ids = list(db_segment.frame_set)[
+ chunk_size * chunk_number : chunk_size * (chunk_number + 1)
+ ]
+
+ return self.prepare_custom_range_segment_chunk(db_task, chunk_frame_ids, quality=quality)
+
+ def prepare_custom_range_segment_chunk(
+ self, db_task: models.Task, frame_ids: Sequence[int], *, quality: FrameQuality
+ ) -> DataWithMime:
+ with closing(self._read_raw_frames(db_task, frame_ids=frame_ids)) as frame_iter:
+ return prepare_chunk(frame_iter, quality=quality, db_task=db_task)
+
+ def prepare_masked_range_segment_chunk(
+ self, db_segment: models.Segment, chunk_number: int, *, quality: FrameQuality
+ ) -> DataWithMime:
+ db_task = db_segment.task
+ db_data = db_task.data
+
+ chunk_size = db_data.chunk_size
+ chunk_frame_ids = sorted(db_segment.frame_set)[
+ chunk_size * chunk_number : chunk_size * (chunk_number + 1)
+ ]
+
+ return self.prepare_custom_masked_range_segment_chunk(
+ db_task, chunk_frame_ids, chunk_number, quality=quality
+ )
- frame = PIL.Image.open(frame_bytes)
- if frame.size != frame_size:
- frame = frame.resize(frame_size)
+ def prepare_custom_masked_range_segment_chunk(
+ self,
+ db_task: models.Task,
+ frame_ids: Collection[int],
+ chunk_number: int,
+ *,
+ quality: FrameQuality,
+ insert_placeholders: bool = False,
+ ) -> DataWithMime:
+ db_data = db_task.data
- frame_bytes = BytesIO()
- frame.save(frame_bytes, writer.IMAGE_EXT)
- frame_bytes.seek(0)
+ frame_step = db_data.get_frame_step()
- else:
- # Populate skipped frames with placeholder data,
- # this is required for video chunk decoding implementation in UI
- frame_bytes = BytesIO(dummy_frame.getvalue())
+ image_quality = 100 if quality == FrameQuality.ORIGINAL else db_data.image_quality
+ writer = ZipCompressedChunkWriter(image_quality, dimension=db_task.dimension)
+
+ dummy_frame = io.BytesIO()
+ PIL.Image.new("RGB", (1, 1)).save(dummy_frame, writer.IMAGE_EXT)
+
+ # Optimize frame access if all the required frames are already cached
+ # Otherwise we might need to download files.
+ # This is not needed for video tasks, as it will reduce performance
+ from cvat.apps.engine.frame_provider import FrameOutputType, TaskFrameProvider
+
+ task_frame_provider = TaskFrameProvider(db_task)
+
+ use_cached_data = False
+ if db_task.mode != "interpolation":
+ required_frame_set = set(frame_ids)
+ available_chunks = [
+ self._has_key(self._make_chunk_key(db_segment, chunk_number, quality=quality))
+ for db_segment in db_task.segment_set.filter(type=models.SegmentType.RANGE).all()
+ for chunk_number, _ in groupby(
+ sorted(required_frame_set.intersection(db_segment.frame_set)),
+ key=lambda frame: frame // db_data.chunk_size,
+ )
+ ]
+ use_cached_data = bool(available_chunks) and all(available_chunks)
+
+ if hasattr(db_data, "video"):
+ frame_size = (db_data.video.width, db_data.video.height)
+ else:
+ frame_size = None
- if frame_bytes is not None:
- chunk_frames.append((frame_bytes, None, None))
+ def get_frames():
+ with ExitStack() as es:
+ es.callback(task_frame_provider.unload)
+
+ if insert_placeholders:
+ frame_range = (
+ (
+ db_data.start_frame
+ + chunk_number * db_data.chunk_size
+ + chunk_frame_idx * frame_step
+ )
+ for chunk_frame_idx in range(db_data.chunk_size)
+ )
+ else:
+ frame_range = frame_ids
+
+ if not use_cached_data:
+ frames_gen = self._read_raw_frames(db_task, frame_ids)
+ frames_iter = iter(es.enter_context(closing(frames_gen)))
+
+ for abs_frame_idx in frame_range:
+ if db_data.stop_frame < abs_frame_idx:
+ break
+
+ if abs_frame_idx in frame_ids:
+ if use_cached_data:
+ frame_data = task_frame_provider.get_frame(
+ task_frame_provider.get_rel_frame_number(abs_frame_idx),
+ quality=quality,
+ out_type=FrameOutputType.BUFFER,
+ )
+ frame = frame_data.data
+ else:
+ frame, _, _ = next(frames_iter)
+
+ if hasattr(db_data, "video"):
+ # Decoded video frames can have different size, restore the original one
+
+ if isinstance(frame, av.VideoFrame):
+ frame = frame.to_image()
+ else:
+ frame = PIL.Image.open(frame)
+
+ if frame.size != frame_size:
+ frame = frame.resize(frame_size)
+ else:
+ # Populate skipped frames with placeholder data,
+ # this is required for video chunk decoding implementation in UI
+ frame = io.BytesIO(dummy_frame.getvalue())
+
+ yield (frame, None, None)
+
+ buff = io.BytesIO()
+ with closing(get_frames()) as frame_iter:
+ writer.save_as_chunk(
+ frame_iter,
+ buff,
+ zip_compress_level=1,
+ # there are likely to be many skips with repeated placeholder frames
+ # in SPECIFIC_FRAMES segments, it makes sense to compress the archive
+ )
- buff = BytesIO()
- writer.save_as_chunk(chunk_frames, buff, compress_frames=False,
- zip_compress_level=1 # these are likely to be many skips in SPECIFIC_FRAMES segments
- )
buff.seek(0)
+ return buff, get_chunk_mime_type_for_writer(writer)
- return buff, 'application/zip'
+ def _prepare_segment_preview(self, db_segment: models.Segment) -> DataWithMime:
+ if db_segment.task.dimension == models.DimensionType.DIM_3D:
+ # TODO
+ preview = PIL.Image.open(
+ os.path.join(os.path.dirname(__file__), "assets/3d_preview.jpeg")
+ )
+ else:
+ from cvat.apps.engine.frame_provider import ( # avoid circular import
+ FrameOutputType,
+ make_frame_provider,
+ )
- def _prepare_local_preview(self, frame_number, db_data):
- FrameProvider = self._get_frame_provider_class()
- frame_provider = FrameProvider(db_data, self._dimension)
- buff, mime_type = frame_provider.get_preview(frame_number)
+ task_frame_provider = make_frame_provider(db_segment.task)
+ segment_frame_provider = make_frame_provider(db_segment)
+ preview = segment_frame_provider.get_frame(
+ task_frame_provider.get_rel_frame_number(min(db_segment.frame_set)),
+ quality=FrameQuality.COMPRESSED,
+ out_type=FrameOutputType.PIL,
+ ).data
- return buff, mime_type
+ return prepare_preview_image(preview)
- def _prepare_cloud_preview(self, db_storage):
+ def _prepare_cloud_preview(self, db_storage: models.CloudStorage) -> DataWithMime:
storage = db_storage_to_storage_instance(db_storage)
if not db_storage.manifests.count():
- raise ValidationError('Cannot get the cloud storage preview. There is no manifest file')
+ raise ValidationError("Cannot get the cloud storage preview. There is no manifest file")
+
preview_path = None
- for manifest_model in db_storage.manifests.all():
- manifest_prefix = os.path.dirname(manifest_model.filename)
- full_manifest_path = os.path.join(db_storage.get_storage_dirname(), manifest_model.filename)
- if not os.path.exists(full_manifest_path) or \
- datetime.fromtimestamp(os.path.getmtime(full_manifest_path), tz=timezone.utc) < storage.get_file_last_modified(manifest_model.filename):
- storage.download_file(manifest_model.filename, full_manifest_path)
+ for db_manifest in db_storage.manifests.all():
+ manifest_prefix = os.path.dirname(db_manifest.filename)
+
+ full_manifest_path = os.path.join(
+ db_storage.get_storage_dirname(), db_manifest.filename
+ )
+ if not os.path.exists(full_manifest_path) or datetime.fromtimestamp(
+ os.path.getmtime(full_manifest_path), tz=timezone.utc
+ ) < storage.get_file_last_modified(db_manifest.filename):
+ storage.download_file(db_manifest.filename, full_manifest_path)
+
manifest = ImageManifestManager(
- os.path.join(db_storage.get_storage_dirname(), manifest_model.filename),
- db_storage.get_storage_dirname()
+ os.path.join(db_storage.get_storage_dirname(), db_manifest.filename),
+ db_storage.get_storage_dirname(),
)
# need to update index
manifest.set_index()
if not len(manifest):
continue
+
preview_info = manifest[0]
- preview_filename = ''.join([preview_info['name'], preview_info['extension']])
+ preview_filename = "".join([preview_info["name"], preview_info["extension"]])
preview_path = os.path.join(manifest_prefix, preview_filename)
break
+
if not preview_path:
- msg = 'Cloud storage {} does not contain any images'.format(db_storage.pk)
+ msg = "Cloud storage {} does not contain any images".format(db_storage.pk)
slogger.cloud_storage[db_storage.pk].info(msg)
raise NotFound(msg)
buff = storage.download_fileobj(preview_path)
- mime_type = mimetypes.guess_type(preview_path)[0]
+ image = PIL.Image.open(buff)
+ return prepare_preview_image(image)
- return buff, mime_type
+ def prepare_context_images_chunk(self, db_data: models.Data, frame_number: int) -> DataWithMime:
+ zip_buffer = io.BytesIO()
- def _prepare_context_image(self, db_data, frame_number):
- zip_buffer = BytesIO()
- try:
- image = Image.objects.get(data_id=db_data.id, frame=frame_number)
- except Image.DoesNotExist:
- return None, None
- with zipfile.ZipFile(zip_buffer, 'a', zipfile.ZIP_DEFLATED, False) as zip_file:
- if not image.related_files.count():
- return None, None
- common_path = os.path.commonpath(list(map(lambda x: str(x.path), image.related_files.all())))
- for i in image.related_files.all():
+ related_images = db_data.related_files.filter(primary_image__frame=frame_number).all()
+ if not related_images:
+ return zip_buffer, ""
+
+ with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED, False) as zip_file:
+ common_path = os.path.commonpath(list(map(lambda x: str(x.path), related_images)))
+ for i in related_images:
path = os.path.realpath(str(i.path))
name = os.path.relpath(str(i.path), common_path)
image = cv2.imread(path)
- success, result = cv2.imencode('.JPEG', image)
+ success, result = cv2.imencode(".JPEG", image)
if not success:
raise Exception('Failed to encode image to ".jpeg" format')
- zip_file.writestr(f'{name}.jpg', result.tobytes())
- mime_type = 'application/zip'
+ zip_file.writestr(f"{name}.jpg", result.tobytes())
+
zip_buffer.seek(0)
+ mime_type = "application/zip"
return zip_buffer, mime_type
+
+
+def prepare_preview_image(image: PIL.Image.Image) -> DataWithMime:
+ PREVIEW_SIZE = (256, 256)
+ PREVIEW_MIME = "image/jpeg"
+
+ image = PIL.ImageOps.exif_transpose(image)
+ image.thumbnail(PREVIEW_SIZE)
+
+ output_buf = io.BytesIO()
+ image.convert("RGB").save(output_buf, format="JPEG")
+ return output_buf, PREVIEW_MIME
+
+
+def prepare_chunk(
+ task_chunk_frames: Iterator[Tuple[Any, str, int]],
+ *,
+ quality: FrameQuality,
+ db_task: models.Task,
+ dump_unchanged: bool = False,
+) -> DataWithMime:
+ # TODO: refactor all chunk building into another class
+
+ db_data = db_task.data
+
+ writer_classes: dict[FrameQuality, Type[IChunkWriter]] = {
+ FrameQuality.COMPRESSED: (
+ Mpeg4CompressedChunkWriter
+ if db_data.compressed_chunk_type == models.DataChoice.VIDEO
+ else ZipCompressedChunkWriter
+ ),
+ FrameQuality.ORIGINAL: (
+ Mpeg4ChunkWriter
+ if db_data.original_chunk_type == models.DataChoice.VIDEO
+ else ZipChunkWriter
+ ),
+ }
+
+ writer_class = writer_classes[quality]
+
+ image_quality = 100 if quality == FrameQuality.ORIGINAL else db_data.image_quality
+
+ writer_kwargs = {}
+ if db_task.dimension == models.DimensionType.DIM_3D:
+ writer_kwargs["dimension"] = models.DimensionType.DIM_3D
+ merged_chunk_writer = writer_class(image_quality, **writer_kwargs)
+
+ writer_kwargs = {}
+ if dump_unchanged and isinstance(merged_chunk_writer, ZipCompressedChunkWriter):
+ writer_kwargs = dict(compress_frames=False, zip_compress_level=1)
+
+ buffer = io.BytesIO()
+ merged_chunk_writer.save_as_chunk(task_chunk_frames, buffer, **writer_kwargs)
+
+ buffer.seek(0)
+ return buffer, get_chunk_mime_type_for_writer(writer_class)
+
+
+def get_chunk_mime_type_for_writer(writer: Union[IChunkWriter, Type[IChunkWriter]]) -> str:
+ if isinstance(writer, IChunkWriter):
+ writer_class = type(writer)
+ else:
+ writer_class = writer
+
+ if issubclass(writer_class, ZipChunkWriter):
+ return "application/zip"
+ elif issubclass(writer_class, Mpeg4ChunkWriter):
+ return "video/mp4"
+ else:
+ assert False, f"Unknown chunk writer class {writer_class}"
diff --git a/cvat/apps/engine/default_settings.py b/cvat/apps/engine/default_settings.py
new file mode 100644
index 000000000000..826fe1c9bef2
--- /dev/null
+++ b/cvat/apps/engine/default_settings.py
@@ -0,0 +1,16 @@
+# Copyright (C) 2024 CVAT.ai Corporation
+#
+# SPDX-License-Identifier: MIT
+
+import os
+
+from attrs.converters import to_bool
+
+MEDIA_CACHE_ALLOW_STATIC_CACHE = to_bool(os.getenv("CVAT_ALLOW_STATIC_CACHE", False))
+"""
+Allow or disallow static media cache.
+If disabled, CVAT will only use the dynamic media cache. New tasks requesting static media cache
+will be automatically switched to the dynamic cache.
+When enabled, this option can increase data access speed and reduce server load,
+but significantly increase disk space occupied by tasks.
+"""
diff --git a/cvat/apps/engine/frame_provider.py b/cvat/apps/engine/frame_provider.py
index 4e2f42ef7933..ea14b40a75ad 100644
--- a/cvat/apps/engine/frame_provider.py
+++ b/cvat/apps/engine/frame_provider.py
@@ -3,226 +3,693 @@
#
# SPDX-License-Identifier: MIT
+from __future__ import annotations
+
+import io
+import itertools
import math
-from enum import Enum
+from abc import ABCMeta, abstractmethod
+from dataclasses import dataclass
+from enum import Enum, auto
from io import BytesIO
-import os
-
+from typing import (
+ Any,
+ Callable,
+ Generic,
+ Iterator,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ overload,
+)
+
+import av
import cv2
import numpy as np
-from PIL import Image, ImageOps
+from datumaro.util import take_by
+from django.conf import settings
+from PIL import Image
+from rest_framework.exceptions import ValidationError
-from cvat.apps.engine.cache import MediaCache
-from cvat.apps.engine.media_extractors import VideoReader, ZipReader
+from cvat.apps.engine import models
+from cvat.apps.engine.cache import DataWithMime, MediaCache, prepare_chunk
+from cvat.apps.engine.media_extractors import (
+ FrameQuality,
+ IMediaReader,
+ RandomAccessIterator,
+ VideoReader,
+ ZipReader,
+)
from cvat.apps.engine.mime_types import mimetypes
-from cvat.apps.engine.models import DataChoice, StorageMethodChoice, DimensionType
-from rest_framework.exceptions import ValidationError
-class RandomAccessIterator:
- def __init__(self, iterable):
- self.iterable = iterable
- self.iterator = None
- self.pos = -1
-
- def __iter__(self):
- return self
-
- def __next__(self):
- return self[self.pos + 1]
-
- def __getitem__(self, idx):
- assert 0 <= idx
- if self.iterator is None or idx <= self.pos:
- self.reset()
- v = None
- while self.pos < idx:
- # NOTE: don't keep the last item in self, it can be expensive
- v = next(self.iterator)
- self.pos += 1
- return v
-
- def reset(self):
- self.close()
- self.iterator = iter(self.iterable)
-
- def close(self):
- if self.iterator is not None:
- if close := getattr(self.iterator, 'close', None):
- close()
- self.iterator = None
- self.pos = -1
-
-class FrameProvider:
- VIDEO_FRAME_EXT = '.PNG'
- VIDEO_FRAME_MIME = 'image/png'
-
- class Quality(Enum):
- COMPRESSED = 0
- ORIGINAL = 100
-
- class Type(Enum):
- BUFFER = 0
- PIL = 1
- NUMPY_ARRAY = 2
-
- class ChunkLoader:
- def __init__(self, reader_class, path_getter):
- self.chunk_id = None
+_T = TypeVar("_T")
+
+
+class _ChunkLoader(metaclass=ABCMeta):
+ def __init__(
+ self,
+ reader_class: Type[IMediaReader],
+ *,
+ reader_params: Optional[dict] = None,
+ ) -> None:
+ self.chunk_id: Optional[int] = None
+ self.chunk_reader: Optional[RandomAccessIterator] = None
+ self.reader_class = reader_class
+ self.reader_params = reader_params
+
+ def load(self, chunk_id: int) -> RandomAccessIterator[Tuple[Any, str, int]]:
+ if self.chunk_id != chunk_id:
+ self.unload()
+
+ self.chunk_id = chunk_id
+ self.chunk_reader = RandomAccessIterator(
+ self.reader_class(
+ [self.read_chunk(chunk_id)[0]],
+ **(self.reader_params or {}),
+ )
+ )
+ return self.chunk_reader
+
+ def unload(self):
+ self.chunk_id = None
+ if self.chunk_reader:
+ self.chunk_reader.close()
self.chunk_reader = None
- self.reader_class = reader_class
- self.get_chunk_path = path_getter
-
- def load(self, chunk_id):
- if self.chunk_id != chunk_id:
- self.unload()
-
- self.chunk_id = chunk_id
- self.chunk_reader = RandomAccessIterator(
- self.reader_class([self.get_chunk_path(chunk_id)]))
- return self.chunk_reader
-
- def unload(self):
- self.chunk_id = None
- if self.chunk_reader:
- self.chunk_reader.close()
- self.chunk_reader = None
-
- class BuffChunkLoader(ChunkLoader):
- def __init__(self, reader_class, path_getter, quality, db_data):
- super().__init__(reader_class, path_getter)
- self.quality = quality
- self.db_data = db_data
-
- def load(self, chunk_id):
- if self.chunk_id != chunk_id:
- self.chunk_id = chunk_id
- self.chunk_reader = RandomAccessIterator(
- self.reader_class([self.get_chunk_path(chunk_id, self.quality, self.db_data)[0]]))
- return self.chunk_reader
-
- def __init__(self, db_data, dimension=DimensionType.DIM_2D):
- self._db_data = db_data
- self._dimension = dimension
- self._loaders = {}
-
- reader_class = {
- DataChoice.IMAGESET: ZipReader,
- DataChoice.VIDEO: VideoReader,
- }
- if db_data.storage_method == StorageMethodChoice.CACHE:
- cache = MediaCache(dimension=dimension)
-
- self._loaders[self.Quality.COMPRESSED] = self.BuffChunkLoader(
- reader_class[db_data.compressed_chunk_type],
- cache.get_task_chunk_data_with_mime,
- self.Quality.COMPRESSED,
- self._db_data)
- self._loaders[self.Quality.ORIGINAL] = self.BuffChunkLoader(
- reader_class[db_data.original_chunk_type],
- cache.get_task_chunk_data_with_mime,
- self.Quality.ORIGINAL,
- self._db_data)
- else:
- self._loaders[self.Quality.COMPRESSED] = self.ChunkLoader(
- reader_class[db_data.compressed_chunk_type],
- db_data.get_compressed_chunk_path)
- self._loaders[self.Quality.ORIGINAL] = self.ChunkLoader(
- reader_class[db_data.original_chunk_type],
- db_data.get_original_chunk_path)
+ @abstractmethod
+ def read_chunk(self, chunk_id: int) -> DataWithMime: ...
- def __len__(self):
- return self._db_data.size
- def unload(self):
- for loader in self._loaders.values():
- loader.unload()
+class _FileChunkLoader(_ChunkLoader):
+ def __init__(
+ self,
+ reader_class: Type[IMediaReader],
+ get_chunk_path_callback: Callable[[int], str],
+ *,
+ reader_params: Optional[dict] = None,
+ ) -> None:
+ super().__init__(reader_class, reader_params=reader_params)
+ self.get_chunk_path = get_chunk_path_callback
+
+ def read_chunk(self, chunk_id: int) -> DataWithMime:
+ chunk_path = self.get_chunk_path(chunk_id)
+ with open(chunk_path, "rb") as f:
+ return (
+ io.BytesIO(f.read()),
+ mimetypes.guess_type(chunk_path)[0],
+ )
+
+
+class _BufferChunkLoader(_ChunkLoader):
+ def __init__(
+ self,
+ reader_class: Type[IMediaReader],
+ get_chunk_callback: Callable[[int], DataWithMime],
+ *,
+ reader_params: Optional[dict] = None,
+ ) -> None:
+ super().__init__(reader_class, reader_params=reader_params)
+ self.get_chunk = get_chunk_callback
+
+ def read_chunk(self, chunk_id: int) -> DataWithMime:
+ return self.get_chunk(chunk_id)
+
- def _validate_frame_number(self, frame_number):
- frame_number_ = int(frame_number)
- if frame_number_ < 0 or frame_number_ >= self._db_data.size:
- raise ValidationError('Incorrect requested frame number: {}'.format(frame_number_))
+class FrameOutputType(Enum):
+ BUFFER = auto()
+ PIL = auto()
+ NUMPY_ARRAY = auto()
- chunk_number = frame_number_ // self._db_data.chunk_size
- frame_offset = frame_number_ % self._db_data.chunk_size
- return frame_number_, chunk_number, frame_offset
+Frame2d = Union[BytesIO, np.ndarray, Image.Image]
+Frame3d = BytesIO
+AnyFrame = Union[Frame2d, Frame3d]
- def get_chunk_number(self, frame_number):
- return int(frame_number) // self._db_data.chunk_size
- def _validate_chunk_number(self, chunk_number):
- chunk_number_ = int(chunk_number)
- if chunk_number_ < 0 or chunk_number_ >= math.ceil(self._db_data.size / self._db_data.chunk_size):
- raise ValidationError('requested chunk does not exist')
+@dataclass
+class DataWithMeta(Generic[_T]):
+ data: _T
+ mime: str
- return chunk_number_
+
+class IFrameProvider(metaclass=ABCMeta):
+ VIDEO_FRAME_EXT = ".PNG"
+ VIDEO_FRAME_MIME = "image/png"
+
+ def unload(self):
+ pass
@classmethod
- def _av_frame_to_png_bytes(cls, av_frame):
+ def _av_frame_to_png_bytes(cls, av_frame: av.VideoFrame) -> BytesIO:
ext = cls.VIDEO_FRAME_EXT
- image = av_frame.to_ndarray(format='bgr24')
+ image = av_frame.to_ndarray(format="bgr24")
success, result = cv2.imencode(ext, image)
if not success:
- raise RuntimeError("Failed to encode image to '%s' format" % (ext))
+ raise RuntimeError(f"Failed to encode image to '{ext}' format")
return BytesIO(result.tobytes())
- def _convert_frame(self, frame, reader_class, out_type):
- if out_type == self.Type.BUFFER:
- return self._av_frame_to_png_bytes(frame) if reader_class is VideoReader else frame
- elif out_type == self.Type.PIL:
- return frame.to_image() if reader_class is VideoReader else Image.open(frame)
- elif out_type == self.Type.NUMPY_ARRAY:
- if reader_class is VideoReader:
- image = frame.to_ndarray(format='bgr24')
+ def _convert_frame(
+ self, frame: Any, reader_class: Type[IMediaReader], out_type: FrameOutputType
+ ) -> AnyFrame:
+ if out_type == FrameOutputType.BUFFER:
+ return (
+ self._av_frame_to_png_bytes(frame)
+ if issubclass(reader_class, VideoReader)
+ else frame
+ )
+ elif out_type == FrameOutputType.PIL:
+ return frame.to_image() if issubclass(reader_class, VideoReader) else Image.open(frame)
+ elif out_type == FrameOutputType.NUMPY_ARRAY:
+ if issubclass(reader_class, VideoReader):
+ image = frame.to_ndarray(format="bgr24")
else:
image = np.array(Image.open(frame))
if len(image.shape) == 3 and image.shape[2] in {3, 4}:
- image[:, :, :3] = image[:, :, 2::-1] # RGB to BGR
+ image[:, :, :3] = image[:, :, 2::-1] # RGB to BGR
return image
else:
- raise RuntimeError('unsupported output type')
+ raise RuntimeError("unsupported output type")
+
+ @abstractmethod
+ def validate_frame_number(self, frame_number: int) -> int: ...
+
+ @abstractmethod
+ def validate_chunk_number(self, chunk_number: int) -> int: ...
+
+ @abstractmethod
+ def get_chunk_number(self, frame_number: int) -> int: ...
+
+ @abstractmethod
+ def get_preview(self) -> DataWithMeta[BytesIO]: ...
+
+ @abstractmethod
+ def get_chunk(
+ self, chunk_number: int, *, quality: FrameQuality = FrameQuality.ORIGINAL
+ ) -> DataWithMeta[BytesIO]: ...
+
+ @abstractmethod
+ def get_frame(
+ self,
+ frame_number: int,
+ *,
+ quality: FrameQuality = FrameQuality.ORIGINAL,
+ out_type: FrameOutputType = FrameOutputType.BUFFER,
+ ) -> DataWithMeta[AnyFrame]: ...
+
+ @abstractmethod
+ def get_frame_context_images_chunk(
+ self,
+ frame_number: int,
+ ) -> Optional[DataWithMeta[BytesIO]]: ...
+
+ @abstractmethod
+ def iterate_frames(
+ self,
+ *,
+ start_frame: Optional[int] = None,
+ stop_frame: Optional[int] = None,
+ quality: FrameQuality = FrameQuality.ORIGINAL,
+ out_type: FrameOutputType = FrameOutputType.BUFFER,
+ ) -> Iterator[DataWithMeta[AnyFrame]]: ...
+
+ def _get_abs_frame_number(self, db_data: models.Data, rel_frame_number: int) -> int:
+ return db_data.start_frame + rel_frame_number * db_data.get_frame_step()
+
+ def _get_rel_frame_number(self, db_data: models.Data, abs_frame_number: int) -> int:
+ return (abs_frame_number - db_data.start_frame) // db_data.get_frame_step()
+
+
+class TaskFrameProvider(IFrameProvider):
+ def __init__(self, db_task: models.Task) -> None:
+ self._db_task = db_task
+
+ def validate_frame_number(self, frame_number: int) -> int:
+ if frame_number not in range(0, self._db_task.data.size):
+ raise ValidationError(
+ f"Invalid frame '{frame_number}'. "
+ f"The frame number should be in the [0, {self._db_task.data.size}] range"
+ )
+
+ return frame_number
+
+ def validate_chunk_number(self, chunk_number: int) -> int:
+ last_chunk = math.ceil(self._db_task.data.size / self._db_task.data.chunk_size) - 1
+ if not 0 <= chunk_number <= last_chunk:
+ raise ValidationError(
+ f"Invalid chunk number '{chunk_number}'. "
+ f"The chunk number should be in the [0, {last_chunk}] range"
+ )
+
+ return chunk_number
+
+ def get_chunk_number(self, frame_number: int) -> int:
+ return int(frame_number) // self._db_task.data.chunk_size
+
+ def get_abs_frame_number(self, rel_frame_number: int) -> int:
+ "Returns absolute frame number in the task (in the range [start, stop, step])"
+ return super()._get_abs_frame_number(self._db_task.data, rel_frame_number)
+
+ def get_rel_frame_number(self, abs_frame_number: int) -> int:
+ """
+ Returns relative frame number in the task (in the range [0, task_size - 1]).
+ This is the "normal" frame number, expected in other methods.
+ """
+ return super()._get_rel_frame_number(self._db_task.data, abs_frame_number)
+
+ def get_preview(self) -> DataWithMeta[BytesIO]:
+ return self._get_segment_frame_provider(0).get_preview()
+
+ def get_chunk(
+ self, chunk_number: int, *, quality: FrameQuality = FrameQuality.ORIGINAL
+ ) -> DataWithMeta[BytesIO]:
+ return_type = DataWithMeta[BytesIO]
+ chunk_number = self.validate_chunk_number(chunk_number)
+
+ cache = MediaCache()
+ cached_chunk = cache.get_task_chunk(self._db_task, chunk_number, quality=quality)
+ if cached_chunk:
+ return return_type(cached_chunk[0], cached_chunk[1])
+
+ db_data = self._db_task.data
+ step = db_data.get_frame_step()
+ task_chunk_start_frame = chunk_number * db_data.chunk_size
+ task_chunk_stop_frame = (chunk_number + 1) * db_data.chunk_size - 1
+ task_chunk_frame_set = set(
+ range(
+ db_data.start_frame + task_chunk_start_frame * step,
+ min(db_data.start_frame + task_chunk_stop_frame * step, db_data.stop_frame) + step,
+ step,
+ )
+ )
+
+ matching_segments: list[models.Segment] = sorted(
+ [
+ s
+ for s in self._db_task.segment_set.all()
+ if s.type == models.SegmentType.RANGE
+ if not task_chunk_frame_set.isdisjoint(s.frame_set)
+ ],
+ key=lambda s: s.start_frame,
+ )
+ assert matching_segments
+
+ # Don't put this into set_callback to avoid data duplication in the cache
+
+ if len(matching_segments) == 1:
+ segment_frame_provider = SegmentFrameProvider(matching_segments[0])
+ matching_chunk_index = segment_frame_provider.find_matching_chunk(
+ sorted(task_chunk_frame_set)
+ )
+ if matching_chunk_index is not None:
+ # The requested frames match one of the job chunks, we can use it directly
+ return segment_frame_provider.get_chunk(matching_chunk_index, quality=quality)
+
+ def _set_callback() -> DataWithMime:
+ # Create and return a joined / cleaned chunk
+ task_chunk_frames = {}
+ for db_segment in matching_segments:
+ segment_frame_provider = SegmentFrameProvider(db_segment)
+ segment_frame_set = db_segment.frame_set
+
+ for task_chunk_frame_id in sorted(task_chunk_frame_set):
+ if (
+ task_chunk_frame_id not in segment_frame_set
+ or task_chunk_frame_id in task_chunk_frames
+ ):
+ continue
+
+ frame, frame_name, _ = segment_frame_provider._get_raw_frame(
+ self.get_rel_frame_number(task_chunk_frame_id), quality=quality
+ )
+ task_chunk_frames[task_chunk_frame_id] = (frame, frame_name, None)
+
+ return prepare_chunk(
+ task_chunk_frames.values(),
+ quality=quality,
+ db_task=self._db_task,
+ dump_unchanged=True,
+ )
+
+ buffer, mime_type = cache.get_or_set_task_chunk(
+ self._db_task, chunk_number, quality=quality, set_callback=_set_callback
+ )
+
+ return return_type(data=buffer, mime=mime_type)
+
+ def get_frame(
+ self,
+ frame_number: int,
+ *,
+ quality: FrameQuality = FrameQuality.ORIGINAL,
+ out_type: FrameOutputType = FrameOutputType.BUFFER,
+ ) -> DataWithMeta[AnyFrame]:
+ return self._get_segment_frame_provider(frame_number).get_frame(
+ frame_number, quality=quality, out_type=out_type
+ )
+
+ def get_frame_context_images_chunk(
+ self,
+ frame_number: int,
+ ) -> Optional[DataWithMeta[BytesIO]]:
+ return self._get_segment_frame_provider(frame_number).get_frame_context_images_chunk(
+ frame_number
+ )
+
+ def iterate_frames(
+ self,
+ *,
+ start_frame: Optional[int] = None,
+ stop_frame: Optional[int] = None,
+ quality: FrameQuality = FrameQuality.ORIGINAL,
+ out_type: FrameOutputType = FrameOutputType.BUFFER,
+ ) -> Iterator[DataWithMeta[AnyFrame]]:
+ frame_range = itertools.count(start_frame, self._db_task.data.get_frame_step())
+ if stop_frame:
+ frame_range = itertools.takewhile(lambda x: x <= stop_frame, frame_range)
+
+ db_segment = None
+ db_segment_frame_set = None
+ db_segment_frame_provider = None
+ for idx in frame_range:
+ if db_segment and idx not in db_segment_frame_set:
+ db_segment = None
+ db_segment_frame_set = None
+ db_segment_frame_provider = None
+
+ if not db_segment:
+ db_segment = self._get_segment(idx)
+ db_segment_frame_set = set(db_segment.frame_set)
+ db_segment_frame_provider = SegmentFrameProvider(db_segment)
+
+ yield db_segment_frame_provider.get_frame(idx, quality=quality, out_type=out_type)
+
+ def _get_segment(self, validated_frame_number: int) -> models.Segment:
+ if not self._db_task.data or not self._db_task.data.size:
+ raise ValidationError("Task has no data")
+
+ abs_frame_number = self.get_abs_frame_number(validated_frame_number)
+
+ return next(
+ s
+ for s in self._db_task.segment_set.all()
+ if s.type == models.SegmentType.RANGE
+ if abs_frame_number in s.frame_set
+ )
+
+ def _get_segment_frame_provider(self, frame_number: int) -> SegmentFrameProvider:
+ return SegmentFrameProvider(self._get_segment(self.validate_frame_number(frame_number)))
+
+
+class SegmentFrameProvider(IFrameProvider):
+ def __init__(self, db_segment: models.Segment) -> None:
+ super().__init__()
+ self._db_segment = db_segment
+
+ db_data = db_segment.task.data
+
+ reader_class: dict[models.DataChoice, Tuple[Type[IMediaReader], Optional[dict]]] = {
+ models.DataChoice.IMAGESET: (ZipReader, None),
+ models.DataChoice.VIDEO: (
+ VideoReader,
+ {
+ "allow_threading": False
+ # disable threading to avoid unpredictable server
+ # resource consumption during reading in endpoints
+ # can be enabled for other clients
+ },
+ ),
+ }
+
+ self._loaders: dict[FrameQuality, _ChunkLoader] = {}
+ if (
+ db_data.storage_method == models.StorageMethodChoice.CACHE
+ or not settings.MEDIA_CACHE_ALLOW_STATIC_CACHE
+ # TODO: separate handling, extract cache creation logic from media cache
+ ):
+ cache = MediaCache()
+
+ self._loaders[FrameQuality.COMPRESSED] = _BufferChunkLoader(
+ reader_class=reader_class[db_data.compressed_chunk_type][0],
+ reader_params=reader_class[db_data.compressed_chunk_type][1],
+ get_chunk_callback=lambda chunk_idx: cache.get_or_set_segment_chunk(
+ db_segment, chunk_idx, quality=FrameQuality.COMPRESSED
+ ),
+ )
+
+ self._loaders[FrameQuality.ORIGINAL] = _BufferChunkLoader(
+ reader_class=reader_class[db_data.original_chunk_type][0],
+ reader_params=reader_class[db_data.original_chunk_type][1],
+ get_chunk_callback=lambda chunk_idx: cache.get_or_set_segment_chunk(
+ db_segment, chunk_idx, quality=FrameQuality.ORIGINAL
+ ),
+ )
+ else:
+ self._loaders[FrameQuality.COMPRESSED] = _FileChunkLoader(
+ reader_class=reader_class[db_data.compressed_chunk_type][0],
+ reader_params=reader_class[db_data.compressed_chunk_type][1],
+ get_chunk_path_callback=lambda chunk_idx: db_data.get_compressed_segment_chunk_path(
+ chunk_idx, segment_id=db_segment.id
+ ),
+ )
+
+ self._loaders[FrameQuality.ORIGINAL] = _FileChunkLoader(
+ reader_class=reader_class[db_data.original_chunk_type][0],
+ reader_params=reader_class[db_data.original_chunk_type][1],
+ get_chunk_path_callback=lambda chunk_idx: db_data.get_original_segment_chunk_path(
+ chunk_idx, segment_id=db_segment.id
+ ),
+ )
+
+ def unload(self):
+ for loader in self._loaders.values():
+ loader.unload()
+
+ def __len__(self):
+ return self._db_segment.frame_count
+
+ def validate_frame_number(self, frame_number: int) -> Tuple[int, int, int]:
+ frame_sequence = list(self._db_segment.frame_set)
+ abs_frame_number = self._get_abs_frame_number(self._db_segment.task.data, frame_number)
+ if abs_frame_number not in frame_sequence:
+ raise ValidationError(f"Incorrect requested frame number: {frame_number}")
+
+ # TODO: maybe optimize search
+ chunk_number, frame_position = divmod(
+ frame_sequence.index(abs_frame_number), self._db_segment.task.data.chunk_size
+ )
+ return frame_number, chunk_number, frame_position
+
+ def get_chunk_number(self, frame_number: int) -> int:
+ return int(frame_number) // self._db_segment.task.data.chunk_size
+
+ def find_matching_chunk(self, frames: Sequence[int]) -> Optional[int]:
+ return next(
+ (
+ i
+ for i, chunk_frames in enumerate(
+ take_by(
+ sorted(self._db_segment.frame_set), self._db_segment.task.data.chunk_size
+ )
+ )
+ if frames == set(chunk_frames)
+ ),
+ None,
+ )
+
+ def validate_chunk_number(self, chunk_number: int) -> int:
+ segment_size = self._db_segment.frame_count
+ last_chunk = math.ceil(segment_size / self._db_segment.task.data.chunk_size) - 1
+ if not 0 <= chunk_number <= last_chunk:
+ raise ValidationError(
+ f"Invalid chunk number '{chunk_number}'. "
+ f"The chunk number should be in the [0, {last_chunk}] range"
+ )
+
+ return chunk_number
+
+ def get_preview(self) -> DataWithMeta[BytesIO]:
+ cache = MediaCache()
+ preview, mime = cache.get_or_set_segment_preview(self._db_segment)
+ return DataWithMeta[BytesIO](preview, mime=mime)
+
+ def get_chunk(
+ self, chunk_number: int, *, quality: FrameQuality = FrameQuality.ORIGINAL
+ ) -> DataWithMeta[BytesIO]:
+ chunk_number = self.validate_chunk_number(chunk_number)
+ chunk_data, mime = self._loaders[quality].read_chunk(chunk_number)
+ return DataWithMeta[BytesIO](chunk_data, mime=mime)
+
+ def _get_raw_frame(
+ self,
+ frame_number: int,
+ *,
+ quality: FrameQuality = FrameQuality.ORIGINAL,
+ ) -> Tuple[Any, str, Type[IMediaReader]]:
+ _, chunk_number, frame_offset = self.validate_frame_number(frame_number)
+ loader = self._loaders[quality]
+ chunk_reader = loader.load(chunk_number)
+ frame, frame_name, _ = chunk_reader[frame_offset]
+ return frame, frame_name, loader.reader_class
+
+ def get_frame(
+ self,
+ frame_number: int,
+ *,
+ quality: FrameQuality = FrameQuality.ORIGINAL,
+ out_type: FrameOutputType = FrameOutputType.BUFFER,
+ ) -> DataWithMeta[AnyFrame]:
+ return_type = DataWithMeta[AnyFrame]
- def get_preview(self, frame_number):
- PREVIEW_SIZE = (256, 256)
- PREVIEW_MIME = 'image/jpeg'
+ frame, frame_name, reader_class = self._get_raw_frame(frame_number, quality=quality)
- if self._dimension == DimensionType.DIM_3D:
- # TODO
- preview = Image.open(os.path.join(os.path.dirname(__file__), 'assets/3d_preview.jpeg'))
+ frame = self._convert_frame(frame, reader_class, out_type)
+ if issubclass(reader_class, VideoReader):
+ return return_type(frame, mime=self.VIDEO_FRAME_MIME)
+
+ return return_type(frame, mime=mimetypes.guess_type(frame_name)[0])
+
+ def get_frame_context_images_chunk(
+ self,
+ frame_number: int,
+ ) -> Optional[DataWithMeta[BytesIO]]:
+ self.validate_frame_number(frame_number)
+
+ db_data = self._db_segment.task.data
+
+ cache = MediaCache()
+ if db_data.storage_method == models.StorageMethodChoice.CACHE:
+ data, mime = cache.get_or_set_frame_context_images_chunk(db_data, frame_number)
else:
- preview, _ = self.get_frame(frame_number, self.Quality.COMPRESSED, self.Type.PIL)
+ data, mime = cache.prepare_context_images_chunk(db_data, frame_number)
+
+ if not data.getvalue():
+ return None
+
+ return DataWithMeta[BytesIO](data, mime=mime)
+
+ def iterate_frames(
+ self,
+ *,
+ start_frame: Optional[int] = None,
+ stop_frame: Optional[int] = None,
+ quality: FrameQuality = FrameQuality.ORIGINAL,
+ out_type: FrameOutputType = FrameOutputType.BUFFER,
+ ) -> Iterator[DataWithMeta[AnyFrame]]:
+ frame_range = itertools.count(start_frame)
+ if stop_frame:
+ frame_range = itertools.takewhile(lambda x: x <= stop_frame, frame_range)
+
+ segment_frame_set = set(self._db_segment.frame_set)
+ for idx in frame_range:
+ if self._get_abs_frame_number(self._db_segment.task.data, idx) in segment_frame_set:
+ yield self.get_frame(idx, quality=quality, out_type=out_type)
+
+
+class JobFrameProvider(SegmentFrameProvider):
+ def __init__(self, db_job: models.Job) -> None:
+ super().__init__(db_job.segment)
+
+ def get_chunk(
+ self,
+ chunk_number: int,
+ *,
+ quality: FrameQuality = FrameQuality.ORIGINAL,
+ is_task_chunk: bool = False,
+ ) -> DataWithMeta[BytesIO]:
+ if not is_task_chunk:
+ return super().get_chunk(chunk_number, quality=quality)
+
+ # Backward compatibility for the "number" parameter
+ # Reproduce the task chunks, limited by this job
+ return_type = DataWithMeta[BytesIO]
+
+ task_frame_provider = TaskFrameProvider(self._db_segment.task)
+ segment_start_chunk = task_frame_provider.get_chunk_number(self._db_segment.start_frame)
+ segment_stop_chunk = task_frame_provider.get_chunk_number(self._db_segment.stop_frame)
+ if not segment_start_chunk <= chunk_number <= segment_stop_chunk:
+ raise ValidationError(
+ f"Invalid chunk number '{chunk_number}'. "
+ "The chunk number should be in the "
+ f"[{segment_start_chunk}, {segment_stop_chunk}] range"
+ )
+
+ cache = MediaCache()
+ cached_chunk = cache.get_segment_task_chunk(self._db_segment, chunk_number, quality=quality)
+ if cached_chunk:
+ return return_type(cached_chunk[0], cached_chunk[1])
+
+ db_data = self._db_segment.task.data
+ step = db_data.get_frame_step()
+ task_chunk_start_frame = chunk_number * db_data.chunk_size
+ task_chunk_stop_frame = (chunk_number + 1) * db_data.chunk_size - 1
+ task_chunk_frame_set = set(
+ range(
+ db_data.start_frame + task_chunk_start_frame * step,
+ min(db_data.start_frame + task_chunk_stop_frame * step, db_data.stop_frame) + step,
+ step,
+ )
+ )
+
+ # Don't put this into set_callback to avoid data duplication in the cache
+ matching_chunk = self.find_matching_chunk(sorted(task_chunk_frame_set))
+ if matching_chunk is not None:
+ return self.get_chunk(matching_chunk, quality=quality)
+
+ def _set_callback() -> DataWithMime:
+ # Create and return a joined / cleaned chunk
+ segment_chunk_frame_ids = sorted(
+ task_chunk_frame_set.intersection(self._db_segment.frame_set)
+ )
+
+ if self._db_segment.type == models.SegmentType.RANGE:
+ return cache.prepare_custom_range_segment_chunk(
+ db_task=self._db_segment.task,
+ frame_ids=segment_chunk_frame_ids,
+ quality=quality,
+ )
+ elif self._db_segment.type == models.SegmentType.SPECIFIC_FRAMES:
+ return cache.prepare_custom_masked_range_segment_chunk(
+ db_task=self._db_segment.task,
+ frame_ids=segment_chunk_frame_ids,
+ chunk_number=chunk_number,
+ quality=quality,
+ insert_placeholders=True,
+ )
+ else:
+ assert False
- preview = ImageOps.exif_transpose(preview)
- preview.thumbnail(PREVIEW_SIZE)
+ buffer, mime_type = cache.get_or_set_segment_task_chunk(
+ self._db_segment, chunk_number, quality=quality, set_callback=_set_callback
+ )
- output_buf = BytesIO()
- preview.convert('RGB').save(output_buf, format="JPEG")
+ return return_type(data=buffer, mime=mime_type)
- return output_buf, PREVIEW_MIME
- def get_chunk(self, chunk_number, quality=Quality.ORIGINAL):
- chunk_number = self._validate_chunk_number(chunk_number)
- if self._db_data.storage_method == StorageMethodChoice.CACHE:
- return self._loaders[quality].get_chunk_path(chunk_number, quality, self._db_data)
- return self._loaders[quality].get_chunk_path(chunk_number)
+@overload
+def make_frame_provider(data_source: models.Job) -> JobFrameProvider: ...
- def get_frame(self, frame_number, quality=Quality.ORIGINAL,
- out_type=Type.BUFFER):
- _, chunk_number, frame_offset = self._validate_frame_number(frame_number)
- loader = self._loaders[quality]
- chunk_reader = loader.load(chunk_number)
- frame, frame_name, _ = chunk_reader[frame_offset]
- frame = self._convert_frame(frame, loader.reader_class, out_type)
- if loader.reader_class is VideoReader:
- return (frame, self.VIDEO_FRAME_MIME)
- return (frame, mimetypes.guess_type(frame_name)[0])
+@overload
+def make_frame_provider(data_source: models.Segment) -> SegmentFrameProvider: ...
+
+
+@overload
+def make_frame_provider(data_source: models.Task) -> TaskFrameProvider: ...
+
- def get_frames(self, start_frame, stop_frame, quality=Quality.ORIGINAL, out_type=Type.BUFFER):
- for idx in range(start_frame, stop_frame):
- yield self.get_frame(idx, quality=quality, out_type=out_type)
+def make_frame_provider(
+ data_source: Union[models.Job, models.Segment, models.Task, Any]
+) -> IFrameProvider:
+ if isinstance(data_source, models.Task):
+ frame_provider = TaskFrameProvider(data_source)
+ elif isinstance(data_source, models.Segment):
+ frame_provider = SegmentFrameProvider(data_source)
+ elif isinstance(data_source, models.Job):
+ frame_provider = JobFrameProvider(data_source)
+ else:
+ raise TypeError(f"Unexpected data source type {type(data_source)}")
- @property
- def data_id(self):
- return self._db_data.id
+ return frame_provider
diff --git a/cvat/apps/engine/log.py b/cvat/apps/engine/log.py
index 5f123d33eef8..6f1740e74fd4 100644
--- a/cvat/apps/engine/log.py
+++ b/cvat/apps/engine/log.py
@@ -59,24 +59,31 @@ def get_logger(logger_name, log_file):
vlogger = logging.getLogger('vector')
+
+def get_migration_log_dir() -> str:
+ return settings.MIGRATIONS_LOGS_ROOT
+
+def get_migration_log_file_path(migration_name: str) -> str:
+ return osp.join(get_migration_log_dir(), f'{migration_name}.log')
+
@contextmanager
def get_migration_logger(migration_name):
- migration_log_file = '{}.log'.format(migration_name)
+ migration_log_file_path = get_migration_log_file_path(migration_name)
stdout = sys.stdout
stderr = sys.stderr
+
# redirect all stdout to the file
- log_file_object = open(osp.join(settings.MIGRATIONS_LOGS_ROOT, migration_log_file), 'w')
- sys.stdout = log_file_object
- sys.stderr = log_file_object
-
- log = logging.getLogger(migration_name)
- log.addHandler(logging.StreamHandler(stdout))
- log.addHandler(logging.StreamHandler(log_file_object))
- log.setLevel(logging.INFO)
-
- try:
- yield log
- finally:
- log_file_object.close()
- sys.stdout = stdout
- sys.stderr = stderr
+ with open(migration_log_file_path, 'w') as log_file_object:
+ sys.stdout = log_file_object
+ sys.stderr = log_file_object
+
+ log = logging.getLogger(migration_name)
+ log.addHandler(logging.StreamHandler(stdout))
+ log.addHandler(logging.StreamHandler(log_file_object))
+ log.setLevel(logging.INFO)
+
+ try:
+ yield log
+ finally:
+ sys.stdout = stdout
+ sys.stderr = stderr
diff --git a/cvat/apps/engine/media_extractors.py b/cvat/apps/engine/media_extractors.py
index 9a352c3b930c..9ddbad10e3a8 100644
--- a/cvat/apps/engine/media_extractors.py
+++ b/cvat/apps/engine/media_extractors.py
@@ -3,6 +3,8 @@
#
# SPDX-License-Identifier: MIT
+from __future__ import annotations
+
import os
import sysconfig
import tempfile
@@ -11,12 +13,20 @@
import io
import itertools
import struct
-from enum import IntEnum
from abc import ABC, abstractmethod
-from contextlib import closing
-from typing import Iterable
+from bisect import bisect
+from contextlib import ExitStack, closing, contextmanager
+from dataclasses import dataclass
+from enum import IntEnum
+from typing import (
+ Any, Callable, ContextManager, Generator, Iterable, Iterator, Optional, Protocol,
+ Sequence, Tuple, TypeVar, Union
+)
import av
+import av.codec
+import av.container
+import av.video.stream
import numpy as np
from natsort import os_sorted
from pyunpack import Archive
@@ -45,6 +55,10 @@ class ORIENTATION(IntEnum):
MIRROR_HORIZONTAL_90_ROTATED=7
NORMAL_270_ROTATED=8
+class FrameQuality(IntEnum):
+ COMPRESSED = 0
+ ORIGINAL = 100
+
def get_mime(name):
for type_name, type_def in MEDIA_TYPES.items():
if type_def['has_mime_type'](name):
@@ -78,21 +92,126 @@ def sort(images, sorting_method=SortingMethod.LEXICOGRAPHICAL, func=None):
else:
raise NotImplementedError()
-def image_size_within_orientation(img: Image):
+def image_size_within_orientation(img: Image.Image):
orientation = img.getexif().get(ORIENTATION_EXIF_TAG, ORIENTATION.NORMAL_HORIZONTAL)
if orientation > 4:
return img.height, img.width
return img.width, img.height
-def has_exif_rotation(img: Image):
+def has_exif_rotation(img: Image.Image):
return img.getexif().get(ORIENTATION_EXIF_TAG, ORIENTATION.NORMAL_HORIZONTAL) != ORIENTATION.NORMAL_HORIZONTAL
+_T = TypeVar("_T")
+
+
+class RandomAccessIterator(Iterator[_T]):
+ def __init__(self, iterable: Iterable[_T]):
+ self.iterable: Iterable[_T] = iterable
+ self.iterator: Optional[Iterator[_T]] = None
+ self.pos: int = -1
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ return self[self.pos + 1]
+
+ def __getitem__(self, idx: int) -> Optional[_T]:
+ assert 0 <= idx
+ if self.iterator is None or idx <= self.pos:
+ self.reset()
+ v = None
+ while self.pos < idx:
+ # NOTE: don't keep the last item in self, it can be expensive
+ v = next(self.iterator)
+ self.pos += 1
+ return v
+
+ def reset(self):
+ self.close()
+ self.iterator = iter(self.iterable)
+
+ def close(self):
+ if self.iterator is not None:
+ if close := getattr(self.iterator, "close", None):
+ close()
+ self.iterator = None
+ self.pos = -1
+
+
+class Sized(Protocol):
+ def get_size(self) -> int: ...
+
+_MediaT = TypeVar("_MediaT", bound=Sized)
+
+class CachingMediaIterator(RandomAccessIterator[_MediaT]):
+ @dataclass
+ class _CacheItem:
+ value: _MediaT
+ size: int
+
+ def __init__(
+ self,
+ iterable: Iterable,
+ *,
+ max_cache_memory: int,
+ max_cache_entries: int,
+ object_size_callback: Optional[Callable[[_MediaT], int]] = None,
+ ):
+ super().__init__(iterable)
+ self.max_cache_entries = max_cache_entries
+ self.max_cache_memory = max_cache_memory
+ self._get_object_size_callback = object_size_callback
+ self.used_cache_memory = 0
+ self._cache: dict[int, self._CacheItem] = {}
+
+ def _get_object_size(self, obj: _MediaT) -> int:
+ if self._get_object_size_callback:
+ return self._get_object_size_callback(obj)
+
+ return obj.get_size()
+
+ def __getitem__(self, idx: int):
+ cache_item = self._cache.get(idx)
+ if cache_item:
+ return cache_item.value
+
+ value = super().__getitem__(idx)
+ value_size = self._get_object_size(value)
+
+ while (
+ len(self._cache) + 1 > self.max_cache_entries or
+ self.used_cache_memory + value_size > self.max_cache_memory
+ ):
+ min_key = min(self._cache.keys())
+ self._cache.pop(min_key)
+
+ if self.used_cache_memory + value_size <= self.max_cache_memory:
+ self._cache[idx] = self._CacheItem(value, value_size)
+
+ return value
+
+
class IMediaReader(ABC):
- def __init__(self, source_path, step, start, stop, dimension):
+ def __init__(
+ self,
+ source_path,
+ *,
+ start: int = 0,
+ stop: Optional[int] = None,
+ step: int = 1,
+ dimension: DimensionType = DimensionType.DIM_2D
+ ):
self._source_path = source_path
+
self._step = step
+
self._start = start
+ "The first included index"
+
self._stop = stop
+ "The last included index"
+
self._dimension = dimension
@abstractmethod
@@ -140,30 +259,25 @@ def _get_preview(obj):
def get_image_size(self, i):
pass
- def __len__(self):
- return len(self.frame_range)
-
- @property
- def frame_range(self):
- return range(self._start, self._stop, self._step)
-
class ImageListReader(IMediaReader):
def __init__(self,
- source_path,
- step=1,
- start=0,
- stop=None,
- dimension=DimensionType.DIM_2D,
- sorting_method=SortingMethod.LEXICOGRAPHICAL):
+ source_path,
+ step: int = 1,
+ start: int = 0,
+ stop: Optional[int] = None,
+ dimension: DimensionType = DimensionType.DIM_2D,
+ sorting_method: SortingMethod = SortingMethod.LEXICOGRAPHICAL,
+ ):
if not source_path:
raise Exception('No image found')
if not stop:
- stop = len(source_path)
+ stop = len(source_path) - 1
else:
- stop = min(len(source_path), stop + 1)
+ stop = min(len(source_path) - 1, stop)
+
step = max(step, 1)
- assert stop > start
+ assert stop >= start
super().__init__(
source_path=sort(source_path, sorting_method),
@@ -176,7 +290,7 @@ def __init__(self,
self._sorting_method = sorting_method
def __iter__(self):
- for i in range(self._start, self._stop, self._step):
+ for i in self.frame_range:
yield (self.get_image(i), self.get_path(i), i)
def __contains__(self, media_file):
@@ -189,7 +303,7 @@ def filter(self, callback):
source_path,
step=self._step,
start=self._start,
- stop=self._stop - 1,
+ stop=self._stop,
dimension=self._dimension,
sorting_method=self._sorting_method
)
@@ -201,7 +315,7 @@ def get_image(self, i):
return self._source_path[i]
def get_progress(self, pos):
- return (pos - self._start + 1) / (self._stop - self._start)
+ return (pos + 1) / (len(self.frame_range) or 1)
def get_preview(self, frame):
if self._dimension == DimensionType.DIM_3D:
@@ -233,6 +347,13 @@ def reconcile(self, source_files, step=1, start=0, stop=None, dimension=Dimensio
def absolute_source_paths(self):
return [self.get_path(idx) for idx, _ in enumerate(self._source_path)]
+ def __len__(self):
+ return len(self.frame_range)
+
+ @property
+ def frame_range(self):
+ return range(self._start, self._stop + 1, self._step)
+
class DirectoryReader(ImageListReader):
def __init__(self,
source_path,
@@ -403,57 +524,149 @@ def extract(self):
if not self.extract_dir:
os.remove(self._zip_source.filename)
+class _AvVideoReading:
+ @contextmanager
+ def read_av_container(self, source: Union[str, io.BytesIO]) -> av.container.InputContainer:
+ if isinstance(source, io.BytesIO):
+ source.seek(0) # required for re-reading
+
+ container = av.open(source)
+ try:
+ yield container
+ finally:
+ # fixes a memory leak in input container closing
+ # https://github.com/PyAV-Org/PyAV/issues/1117
+ for stream in container.streams:
+ context = stream.codec_context
+ if context and context.is_open:
+ context.close()
+
+ if container.open_files:
+ container.close()
+
+ def decode_stream(
+ self, container: av.container.Container, video_stream: av.video.stream.VideoStream
+ ) -> Generator[av.VideoFrame, None, None]:
+ demux_iter = container.demux(video_stream)
+ try:
+ for packet in demux_iter:
+ yield from packet.decode()
+ finally:
+ # av v9.2.0 seems to have a memory corruption or a deadlock
+ # in exception handling for demux() in the multithreaded mode.
+ # Instead of breaking the iteration, we iterate over packets till the end.
+ # Fixed in av v12.2.0.
+ if av.__version__ == "9.2.0" and video_stream.thread_type == 'AUTO':
+ exhausted = object()
+ while next(demux_iter, exhausted) is not exhausted:
+ pass
+
class VideoReader(IMediaReader):
- def __init__(self, source_path, step=1, start=0, stop=None, dimension=DimensionType.DIM_2D):
+ def __init__(
+ self,
+ source_path: Union[str, io.BytesIO],
+ step: int = 1,
+ start: int = 0,
+ stop: Optional[int] = None,
+ dimension: DimensionType = DimensionType.DIM_2D,
+ *,
+ allow_threading: bool = True,
+ ):
super().__init__(
source_path=source_path,
step=step,
start=start,
- stop=stop + 1 if stop is not None else stop,
+ stop=stop,
dimension=dimension,
)
- def _has_frame(self, i):
- if i >= self._start:
- if (i - self._start) % self._step == 0:
- if self._stop is None or i < self._stop:
- return True
+ self.allow_threading = allow_threading
+ self._frame_count: Optional[int] = None
+ self._frame_size: Optional[tuple[int, int]] = None # (w, h)
- return False
+ def iterate_frames(
+ self,
+ *,
+ frame_filter: Union[bool, Iterable[int]] = True,
+ video_stream: Optional[av.video.stream.VideoStream] = None,
+ ) -> Iterator[Tuple[av.VideoFrame, str, int]]:
+ """
+ If provided, frame_filter must be an ordered sequence in the ascending order.
+ 'True' means using the frames configured in the reader object.
+ 'False' or 'None' means returning all the video frames.
+ """
- def __iter__(self):
- with self._get_av_container() as container:
- stream = container.streams.video[0]
- stream.thread_type = 'AUTO'
- frame_num = 0
- for packet in container.demux(stream):
- for image in packet.decode():
- frame_num += 1
- if self._has_frame(frame_num - 1):
- if packet.stream.metadata.get('rotate'):
- pts = image.pts
- image = av.VideoFrame().from_ndarray(
+ if frame_filter is True:
+ frame_filter = itertools.count(self._start, self._step)
+ if self._stop:
+ frame_filter = itertools.takewhile(lambda x: x <= self._stop, frame_filter)
+ elif not frame_filter:
+ frame_filter = itertools.count()
+
+ frame_filter_iter = iter(frame_filter)
+ next_frame_filter_frame = next(frame_filter_iter, None)
+ if next_frame_filter_frame is None:
+ return
+
+ es = ExitStack()
+
+ needs_init = video_stream is None
+ if needs_init:
+ container = es.enter_context(self._read_av_container())
+ else:
+ container = video_stream.container
+
+ with es:
+ if needs_init:
+ video_stream = container.streams.video[0]
+
+ if self.allow_threading:
+ video_stream.thread_type = 'AUTO'
+
+ frame_counter = itertools.count()
+ with closing(self._decode_stream(container, video_stream)) as stream_decoder:
+ for frame, frame_number in zip(stream_decoder, frame_counter):
+ if frame_number == next_frame_filter_frame:
+ if video_stream.metadata.get('rotate'):
+ pts = frame.pts
+ frame = av.VideoFrame().from_ndarray(
rotate_image(
- image.to_ndarray(format='bgr24'),
- 360 - int(stream.metadata.get('rotate'))
+ frame.to_ndarray(format='bgr24'),
+ 360 - int(video_stream.metadata.get('rotate'))
),
format ='bgr24'
)
- image.pts = pts
- yield (image, self._source_path[0], image.pts)
+ frame.pts = pts
+
+ if self._frame_size is None:
+ self._frame_size = (frame.width, frame.height)
+
+ yield (frame, self._source_path[0], frame.pts)
+
+ next_frame_filter_frame = next(frame_filter_iter, None)
+
+ if next_frame_filter_frame is None:
+ return
+
+ def __iter__(self) -> Iterator[Tuple[av.VideoFrame, str, int]]:
+ return self.iterate_frames()
def get_progress(self, pos):
duration = self._get_duration()
return pos / duration if duration else None
- def _get_av_container(self):
- if isinstance(self._source_path[0], io.BytesIO):
- self._source_path[0].seek(0) # required for re-reading
- return av.open(self._source_path[0])
+ def _read_av_container(self) -> ContextManager[av.container.InputContainer]:
+ return _AvVideoReading().read_av_container(self._source_path[0])
+
+ def _decode_stream(
+ self, container: av.container.Container, video_stream: av.video.stream.VideoStream
+ ) -> Generator[av.VideoFrame, None, None]:
+ return _AvVideoReading().decode_stream(container, video_stream)
def _get_duration(self):
- with self._get_av_container() as container:
+ with self._read_av_container() as container:
stream = container.streams.video[0]
+
duration = None
if stream.duration:
duration = stream.duration
@@ -468,122 +681,128 @@ def _get_duration(self):
return duration
def get_preview(self, frame):
- with self._get_av_container() as container:
+ with self._read_av_container() as container:
stream = container.streams.video[0]
+
tb_denominator = stream.time_base.denominator
needed_time = int((frame / stream.guessed_rate) * tb_denominator)
container.seek(offset=needed_time, stream=stream)
- for packet in container.demux(stream):
- for frame in packet.decode():
- return self._get_preview(frame.to_image() if not stream.metadata.get('rotate') \
- else av.VideoFrame().from_ndarray(
- rotate_image(
- frame.to_ndarray(format='bgr24'),
- 360 - int(container.streams.video[0].metadata.get('rotate'))
- ),
- format ='bgr24'
- ).to_image()
- )
+
+ with closing(self.iterate_frames(video_stream=stream)) as frame_iter:
+ return self._get_preview(next(frame_iter))
def get_image_size(self, i):
- image = (next(iter(self)))[0]
- return image.width, image.height
+ if self._frame_size is not None:
+ return self._frame_size
-class FragmentMediaReader:
- def __init__(self, chunk_number, chunk_size, start, stop, step=1):
- self._start = start
- self._stop = stop + 1 # up to the last inclusive
- self._step = step
- self._chunk_number = chunk_number
- self._chunk_size = chunk_size
- self._start_chunk_frame_number = \
- self._start + self._chunk_number * self._chunk_size * self._step
- self._end_chunk_frame_number = min(self._start_chunk_frame_number \
- + (self._chunk_size - 1) * self._step + 1, self._stop)
- self._frame_range = self._get_frame_range()
+ with closing(iter(self)) as frame_iter:
+ frame = next(frame_iter)[0]
+ self._frame_size = (frame.width, frame.height)
- @property
- def frame_range(self):
- return self._frame_range
+ return self._frame_size
- def _get_frame_range(self):
- frame_range = []
- for idx in range(self._start, self._stop, self._step):
- if idx < self._start_chunk_frame_number:
- continue
- elif idx < self._end_chunk_frame_number and \
- not (idx - self._start_chunk_frame_number) % self._step:
- frame_range.append(idx)
- elif (idx - self._start_chunk_frame_number) % self._step:
- continue
- else:
- break
- return frame_range
+ def get_frame_count(self) -> int:
+ """
+ Returns total frame count in the video
+
+ Note that not all videos provide length / duration metainfo, so the
+ result may require full video decoding.
+
+ The total count is NOT affected by the frame filtering options of the object,
+ i.e. start frame, end frame and frame step.
+ """
+ # It's possible to retrieve frame count from the stream.frames,
+ # but the number may be incorrect.
+ # https://superuser.com/questions/1512575/why-total-frame-count-is-different-in-ffmpeg-than-ffprobe
+ if self._frame_count is not None:
+ return self._frame_count
+
+ frame_count = 0
+ for _ in self.iterate_frames(frame_filter=False):
+ frame_count += 1
+
+ self._frame_count = frame_count
-class ImageDatasetManifestReader(FragmentMediaReader):
- def __init__(self, manifest_path, **kwargs):
- super().__init__(**kwargs)
+ return frame_count
+
+
+class ImageReaderWithManifest:
+ def __init__(self, manifest_path: str):
self._manifest = ImageManifestManager(manifest_path)
self._manifest.init_index()
- def __iter__(self):
- for idx in self._frame_range:
+ def iterate_frames(self, frame_ids: Iterable[int]):
+ for idx in frame_ids:
yield self._manifest[idx]
-class VideoDatasetManifestReader(FragmentMediaReader):
- def __init__(self, manifest_path, **kwargs):
- self.source_path = kwargs.pop('source_path')
- super().__init__(**kwargs)
- self._manifest = VideoManifestManager(manifest_path)
- self._manifest.init_index()
+class VideoReaderWithManifest:
+ # TODO: merge this class with VideoReader
+
+ def __init__(self, manifest_path: str, source_path: str, *, allow_threading: bool = False):
+ self.source_path = source_path
+ self.manifest = VideoManifestManager(manifest_path)
+ if self.manifest.exists:
+ self.manifest.init_index()
+
+ self.allow_threading = allow_threading
+
+ def _read_av_container(self) -> ContextManager[av.container.InputContainer]:
+ return _AvVideoReading().read_av_container(self.source_path)
+
+ def _decode_stream(
+ self, container: av.container.Container, video_stream: av.video.stream.VideoStream
+ ) -> Generator[av.VideoFrame, None, None]:
+ return _AvVideoReading().decode_stream(container, video_stream)
- def _get_nearest_left_key_frame(self):
- if self._start_chunk_frame_number >= \
- self._manifest[len(self._manifest) - 1].get('number'):
- left_border = len(self._manifest) - 1
+ def _get_nearest_left_key_frame(self, frame_id: int) -> tuple[int, int]:
+ nearest_left_keyframe_pos = bisect(
+ self.manifest, frame_id, key=lambda entry: entry.get('number')
+ )
+ if nearest_left_keyframe_pos:
+ frame_number = self.manifest[nearest_left_keyframe_pos - 1].get('number')
+ timestamp = self.manifest[nearest_left_keyframe_pos - 1].get('pts')
else:
- left_border = 0
- delta = len(self._manifest)
- while delta:
- step = delta // 2
- cur_position = left_border + step
- if self._manifest[cur_position].get('number') < self._start_chunk_frame_number:
- cur_position += 1
- left_border = cur_position
- delta -= step + 1
- else:
- delta = step
- if self._manifest[cur_position].get('number') > self._start_chunk_frame_number:
- left_border -= 1
- frame_number = self._manifest[left_border].get('number')
- timestamp = self._manifest[left_border].get('pts')
+ frame_number = 0
+ timestamp = 0
return frame_number, timestamp
- def __iter__(self):
- start_decode_frame_number, start_decode_timestamp = self._get_nearest_left_key_frame()
- with closing(av.open(self.source_path, mode='r')) as container:
- video_stream = next(stream for stream in container.streams if stream.type == 'video')
- video_stream.thread_type = 'AUTO'
+ def iterate_frames(self, *, frame_filter: Iterable[int]) -> Iterable[av.VideoFrame]:
+ "frame_ids must be an ordered sequence in the ascending order"
+
+ frame_filter_iter = iter(frame_filter)
+ next_frame_filter_frame = next(frame_filter_iter, None)
+ if next_frame_filter_frame is None:
+ return
+
+ start_decode_frame_number, start_decode_timestamp = self._get_nearest_left_key_frame(
+ next_frame_filter_frame
+ )
+
+ with self._read_av_container() as container:
+ video_stream = container.streams.video[0]
+ if self.allow_threading:
+ video_stream.thread_type = 'AUTO'
container.seek(offset=start_decode_timestamp, stream=video_stream)
- frame_number = start_decode_frame_number - 1
- for packet in container.demux(video_stream):
- for frame in packet.decode():
- frame_number += 1
- if frame_number in self._frame_range:
+ frame_counter = itertools.count(start_decode_frame_number)
+ with closing(self._decode_stream(container, video_stream)) as stream_decoder:
+ for frame, frame_number in zip(stream_decoder, frame_counter):
+ if frame_number == next_frame_filter_frame:
if video_stream.metadata.get('rotate'):
frame = av.VideoFrame().from_ndarray(
rotate_image(
frame.to_ndarray(format='bgr24'),
- 360 - int(container.streams.video[0].metadata.get('rotate'))
+ 360 - int(video_stream.metadata.get('rotate'))
),
format ='bgr24'
)
+
yield frame
- elif frame_number < self._frame_range[-1]:
- continue
- else:
+
+ next_frame_filter_frame = next(frame_filter_iter, None)
+
+ if next_frame_filter_frame is None:
return
class IChunkWriter(ABC):
@@ -648,33 +867,37 @@ class ZipChunkWriter(IChunkWriter):
POINT_CLOUD_EXT = 'pcd'
def _write_pcd_file(self, image: str|io.BytesIO) -> tuple[io.BytesIO, str, int, int]:
- image_buf = open(image, "rb") if isinstance(image, str) else image
- try:
+ with ExitStack() as es:
+ if isinstance(image, str):
+ image_buf = es.enter_context(open(image, "rb"))
+ else:
+ image_buf = image
+
properties = ValidateDimension.get_pcd_properties(image_buf)
w, h = int(properties["WIDTH"]), int(properties["HEIGHT"])
image_buf.seek(0, 0)
return io.BytesIO(image_buf.read()), self.POINT_CLOUD_EXT, w, h
- finally:
- if isinstance(image, str):
- image_buf.close()
- def save_as_chunk(self, images: Iterable[tuple[Image.Image|io.IOBase|str, str, str]], chunk_path: str):
+ def save_as_chunk(self, images: Iterator[tuple[Image.Image|io.IOBase|str, str, str]], chunk_path: str):
with zipfile.ZipFile(chunk_path, 'x') as zip_chunk:
for idx, (image, path, _) in enumerate(images):
ext = os.path.splitext(path)[1].replace('.', '')
- output = io.BytesIO()
+
if self._dimension == DimensionType.DIM_2D:
# current version of Pillow applies exif rotation immediately when TIFF image opened
# and it removes rotation tag after that
# so, has_exif_rotation(image) will return False for TIFF images even if they were actually rotated
# and original files will be added to the archive (without applied rotation)
# that is why we need the second part of the condition
- if has_exif_rotation(image) or image.format == 'TIFF':
+ if isinstance(image, Image.Image) and (
+ has_exif_rotation(image) or image.format == 'TIFF'
+ ):
+ output = io.BytesIO()
rot_image = ImageOps.exif_transpose(image)
try:
if image.format == 'TIFF':
# https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html
- # use loseless lzw compression for tiff images
+ # use lossless lzw compression for tiff images
rot_image.save(output, format='TIFF', compression='tiff_lzw')
else:
rot_image.save(
@@ -686,16 +909,22 @@ def save_as_chunk(self, images: Iterable[tuple[Image.Image|io.IOBase|str, str, s
)
finally:
rot_image.close()
+ elif isinstance(image, io.IOBase):
+ output = image
else:
output = path
else:
- output, ext = self._write_pcd_file(path)[0:2]
- arcname = '{:06d}.{}'.format(idx, ext)
+ if isinstance(image, io.BytesIO):
+ output, ext = self._write_pcd_file(image)[0:2]
+ else:
+ output, ext = self._write_pcd_file(path)[0:2]
+ arcname = '{:06d}.{}'.format(idx, ext)
if isinstance(output, io.BytesIO):
zip_chunk.writestr(arcname, output.getvalue())
else:
zip_chunk.write(filename=output, arcname=arcname)
+
# return empty list because ZipChunkWriter write files as is
# and does not decode it to know img size.
return []
@@ -703,7 +932,7 @@ def save_as_chunk(self, images: Iterable[tuple[Image.Image|io.IOBase|str, str, s
class ZipCompressedChunkWriter(ZipChunkWriter):
def save_as_chunk(
self,
- images: Iterable[tuple[Image.Image|io.IOBase|str, str, str]],
+ images: Iterator[tuple[Image.Image|io.IOBase|str, str, str]],
chunk_path: str, *, compress_frames: bool = True, zip_compress_level: int = 0
):
image_sizes = []
@@ -719,7 +948,11 @@ def save_as_chunk(
w, h = img.size
extension = self.IMAGE_EXT
else:
- image_buf, extension, w, h = self._write_pcd_file(path)
+ if isinstance(image, io.BytesIO):
+ image_buf, extension, w, h = self._write_pcd_file(image)
+ else:
+ image_buf, extension, w, h = self._write_pcd_file(path)
+
image_sizes.append((w, h))
arcname = '{:06d}.{}'.format(idx, extension)
zip_chunk.writestr(arcname, image_buf.getvalue())
@@ -751,7 +984,7 @@ def __init__(self, quality=67):
"preset": "ultrafast",
}
- def _add_video_stream(self, container, w, h, rate, options):
+ def _add_video_stream(self, container: av.container.OutputContainer, w, h, rate, options):
# x264 requires width and height must be divisible by 2 for yuv420p
if h % 2:
h += 1
@@ -772,12 +1005,28 @@ def _add_video_stream(self, container, w, h, rate, options):
return video_stream
- def save_as_chunk(self, images, chunk_path):
- if not images:
+ FrameDescriptor = Tuple[av.VideoFrame, Any, Any]
+
+ def _peek_first_frame(
+ self, frame_iter: Iterator[FrameDescriptor]
+ ) -> Tuple[Optional[FrameDescriptor], Iterator[FrameDescriptor]]:
+ "Gets the first frame and returns the same full iterator"
+
+ if not hasattr(frame_iter, '__next__'):
+ frame_iter = iter(frame_iter)
+
+ first_frame = next(frame_iter, None)
+ return first_frame, itertools.chain((first_frame, ), frame_iter)
+
+ def save_as_chunk(
+ self, images: Iterator[FrameDescriptor], chunk_path: str
+ ) -> Sequence[Tuple[int, int]]:
+ first_frame, images = self._peek_first_frame(images)
+ if not first_frame:
raise Exception('no images to save')
- input_w = images[0][0].width
- input_h = images[0][0].height
+ input_w = first_frame[0].width
+ input_h = first_frame[0].height
with av.open(chunk_path, 'w', format=self.FORMAT) as output_container:
output_v_stream = self._add_video_stream(
@@ -788,11 +1037,15 @@ def save_as_chunk(self, images, chunk_path):
options=self._codec_opts,
)
- self._encode_images(images, output_container, output_v_stream)
+ with closing(output_v_stream):
+ self._encode_images(images, output_container, output_v_stream)
+
return [(input_w, input_h)]
@staticmethod
- def _encode_images(images, container, stream):
+ def _encode_images(
+ images, container: av.container.OutputContainer, stream: av.video.stream.VideoStream
+ ):
for frame, _, _ in images:
# let libav set the correct pts and time_base
frame.pts = None
@@ -818,11 +1071,12 @@ def __init__(self, quality):
}
def save_as_chunk(self, images, chunk_path):
- if not images:
+ first_frame, images = self._peek_first_frame(images)
+ if not first_frame:
raise Exception('no images to save')
- input_w = images[0][0].width
- input_h = images[0][0].height
+ input_w = first_frame[0].width
+ input_h = first_frame[0].height
downscale_factor = 1
while input_h / downscale_factor >= 1080:
@@ -840,7 +1094,9 @@ def save_as_chunk(self, images, chunk_path):
options=self._codec_opts,
)
- self._encode_images(images, output_container, output_v_stream)
+ with closing(output_v_stream):
+ self._encode_images(images, output_container, output_v_stream)
+
return [(input_w, input_h)]
def _is_archive(path):
diff --git a/cvat/apps/engine/migrations/0083_move_to_segment_chunks.py b/cvat/apps/engine/migrations/0083_move_to_segment_chunks.py
new file mode 100644
index 000000000000..8ef887d4c54b
--- /dev/null
+++ b/cvat/apps/engine/migrations/0083_move_to_segment_chunks.py
@@ -0,0 +1,118 @@
+# Generated by Django 4.2.13 on 2024-08-12 09:49
+
+import os
+from itertools import islice
+from typing import Iterable, TypeVar
+
+from django.db import migrations
+
+from cvat.apps.engine.log import get_migration_log_dir, get_migration_logger
+
+T = TypeVar("T")
+
+
+def take_by(iterable: Iterable[T], count: int) -> Iterable[T]:
+ """
+ Yields the input iterable's elements in successive batches (lists) of up to N items.
+ ('abcdefg', 3) -> ['a', 'b', 'c'], ['d', 'e', 'f'], ['g']
+ """
+
+ it = iter(iterable)
+ while True:
+ batch = list(islice(it, count))
+ if len(batch) == 0:
+ break
+
+ yield batch
+
+
+def get_migration_name() -> str:
+ return os.path.splitext(os.path.basename(__file__))[0]
+
+
+def get_updated_ids_filename(log_dir: str, migration_name: str) -> str:
+ return os.path.join(log_dir, migration_name + "-data_ids.log")
+
+
+MIGRATION_LOG_HEADER = (
+ 'The following Data ids have been switched from using "filesystem" chunk storage ' 'to "cache":'
+)
+
+
+def switch_tasks_with_static_chunks_to_dynamic_chunks(apps, schema_editor):
+ migration_name = get_migration_name()
+ migration_log_dir = get_migration_log_dir()
+ with get_migration_logger(migration_name) as common_logger:
+ Data = apps.get_model("engine", "Data")
+
+ data_with_static_cache_query = Data.objects.filter(storage_method="file_system")
+
+ data_with_static_cache_ids = list(
+ v[0]
+ for v in (
+ data_with_static_cache_query.order_by("id")
+ .values_list("id")
+ .iterator(chunk_size=100000)
+ )
+ )
+
+ data_with_static_cache_query.update(storage_method="cache")
+
+ updated_ids_filename = get_updated_ids_filename(migration_log_dir, migration_name)
+ with open(updated_ids_filename, "w") as data_ids_file:
+ print(MIGRATION_LOG_HEADER, file=data_ids_file)
+
+ for data_id in data_with_static_cache_ids:
+ print(data_id, file=data_ids_file)
+
+ common_logger.info(
+ "Information about migrated tasks is available in the migration log file: "
+ "{}. You will need to remove data manually for these tasks.".format(
+ updated_ids_filename
+ )
+ )
+
+
+def revert_switch_tasks_with_static_chunks_to_dynamic_chunks(apps, schema_editor):
+ migration_name = get_migration_name()
+ migration_log_dir = get_migration_log_dir()
+
+ updated_ids_filename = get_updated_ids_filename(migration_log_dir, migration_name)
+ if not os.path.isfile(updated_ids_filename):
+ raise FileNotFoundError(
+ "Can't revert the migration: can't file forward migration logfile at "
+ f"'{updated_ids_filename}'."
+ )
+
+ with open(updated_ids_filename, "r") as data_ids_file:
+ header = data_ids_file.readline().strip()
+ if header != MIGRATION_LOG_HEADER:
+ raise ValueError(
+ "Can't revert the migration: the migration log file has unexpected header"
+ )
+
+ forward_updated_ids = tuple(map(int, data_ids_file))
+
+ if not forward_updated_ids:
+ return
+
+ Data = apps.get_model("engine", "Data")
+
+ for id_batch in take_by(forward_updated_ids, 1000):
+ Data.objects.filter(storage_method="cache", id__in=id_batch).update(
+ storage_method="file_system"
+ )
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ("engine", "0082_alter_labeledimage_job_and_more"),
+ ]
+
+ operations = [
+ migrations.RunPython(
+ switch_tasks_with_static_chunks_to_dynamic_chunks,
+ reverse_code=revert_switch_tasks_with_static_chunks_to_dynamic_chunks,
+ )
+ ]
diff --git a/cvat/apps/engine/models.py b/cvat/apps/engine/models.py
index eda765e6bebb..c57eb0371d5e 100644
--- a/cvat/apps/engine/models.py
+++ b/cvat/apps/engine/models.py
@@ -252,6 +252,13 @@ def get_data_dirname(self):
def get_upload_dirname(self):
return os.path.join(self.get_data_dirname(), "raw")
+ def get_raw_data_dirname(self) -> str:
+ return {
+ StorageChoice.LOCAL: self.get_upload_dirname(),
+ StorageChoice.SHARE: settings.SHARE_ROOT,
+ StorageChoice.CLOUD_STORAGE: self.get_upload_dirname(),
+ }[self.storage]
+
def get_compressed_cache_dirname(self):
return os.path.join(self.get_data_dirname(), "compressed")
@@ -259,7 +266,7 @@ def get_original_cache_dirname(self):
return os.path.join(self.get_data_dirname(), "original")
@staticmethod
- def _get_chunk_name(chunk_number, chunk_type):
+ def _get_chunk_name(segment_id: int, chunk_number: int, chunk_type: DataChoice | str) -> str:
if chunk_type == DataChoice.VIDEO:
ext = 'mp4'
elif chunk_type == DataChoice.IMAGESET:
@@ -267,21 +274,21 @@ def _get_chunk_name(chunk_number, chunk_type):
else:
ext = 'list'
- return '{}.{}'.format(chunk_number, ext)
+ return 'segment_{}-{}.{}'.format(segment_id, chunk_number, ext)
- def _get_compressed_chunk_name(self, chunk_number):
- return self._get_chunk_name(chunk_number, self.compressed_chunk_type)
+ def _get_compressed_chunk_name(self, segment_id: int, chunk_number: int) -> str:
+ return self._get_chunk_name(segment_id, chunk_number, self.compressed_chunk_type)
- def _get_original_chunk_name(self, chunk_number):
- return self._get_chunk_name(chunk_number, self.original_chunk_type)
+ def _get_original_chunk_name(self, segment_id: int, chunk_number: int) -> str:
+ return self._get_chunk_name(segment_id, chunk_number, self.original_chunk_type)
- def get_original_chunk_path(self, chunk_number):
+ def get_original_segment_chunk_path(self, chunk_number: int, segment_id: int) -> str:
return os.path.join(self.get_original_cache_dirname(),
- self._get_original_chunk_name(chunk_number))
+ self._get_original_chunk_name(segment_id, chunk_number))
- def get_compressed_chunk_path(self, chunk_number):
+ def get_compressed_segment_chunk_path(self, chunk_number: int, segment_id: int) -> str:
return os.path.join(self.get_compressed_cache_dirname(),
- self._get_compressed_chunk_name(chunk_number))
+ self._get_compressed_chunk_name(segment_id, chunk_number))
def get_manifest_path(self):
return os.path.join(self.get_upload_dirname(), 'manifest.jsonl')
@@ -600,7 +607,7 @@ def __str__(self):
class Segment(models.Model):
# Common fields
- task = models.ForeignKey(Task, on_delete=models.CASCADE)
+ task = models.ForeignKey(Task, on_delete=models.CASCADE) # TODO: add related name
start_frame = models.IntegerField()
stop_frame = models.IntegerField()
type = models.CharField(choices=SegmentType.choices(), default=SegmentType.RANGE, max_length=32)
diff --git a/cvat/apps/engine/permissions.py b/cvat/apps/engine/permissions.py
index 70c147a4661f..dacc57f836d4 100644
--- a/cvat/apps/engine/permissions.py
+++ b/cvat/apps/engine/permissions.py
@@ -210,6 +210,7 @@ class Scopes(StrEnum):
UPDATE_ASSIGNEE = 'update:assignee'
UPDATE_DESC = 'update:desc'
UPDATE_ORG = 'update:organization'
+ UPDATE_ASSOCIATED_STORAGE = 'update:associated_storage'
VIEW = 'view'
IMPORT_DATASET = 'import:dataset'
EXPORT_ANNOTATIONS = 'export:annotations'
@@ -285,6 +286,9 @@ def get_scopes(request, view, obj):
scopes = []
if scope == Scopes.UPDATE:
+ # user should have permissions to view a project
+ scopes.append(Scopes.VIEW)
+
if any(k in request.data for k in ('owner_id', 'owner')):
owner_id = request.data.get('owner_id') or request.data.get('owner')
if owner_id != getattr(obj.owner, 'id', None):
@@ -299,6 +303,9 @@ def get_scopes(request, view, obj):
break
if 'organization' in request.data:
scopes.append(Scopes.UPDATE_ORG)
+
+ if {'source_storage', 'target_storage'} & request.data.keys():
+ scopes.append(Scopes.UPDATE_ASSOCIATED_STORAGE)
else:
scopes.append(scope)
@@ -376,6 +383,7 @@ class Scopes(StrEnum):
UPDATE_ASSIGNEE = 'update:assignee'
UPDATE_PROJECT = 'update:project'
UPDATE_OWNER = 'update:owner'
+ UPDATE_ASSOCIATED_STORAGE = 'update:associated_storage'
DELETE = 'delete'
VIEW_ANNOTATIONS = 'view:annotations'
UPDATE_ANNOTATIONS = 'update:annotations'
@@ -509,6 +517,8 @@ def get_scopes(request, view, obj) -> List[Scopes]:
scopes.append(scope)
elif scope == Scopes.UPDATE:
+ # user should have permissions to view a task
+ scopes.append(Scopes.VIEW)
if any(k in request.data for k in ('owner_id', 'owner')):
owner_id = request.data.get('owner_id') or request.data.get('owner')
if owner_id != getattr(obj.owner, 'id', None):
@@ -530,6 +540,9 @@ def get_scopes(request, view, obj) -> List[Scopes]:
if request.data.get('organization'):
scopes.append(Scopes.UPDATE_ORGANIZATION)
+ if {'source_storage', 'target_storage'} & request.data.keys():
+ scopes.append(Scopes.UPDATE_ASSOCIATED_STORAGE)
+
elif scope == Scopes.VIEW_ANNOTATIONS:
if 'format' in request.query_params:
scope = Scopes.EXPORT_ANNOTATIONS
@@ -614,11 +627,8 @@ class Scopes(StrEnum):
VIEW = 'view'
UPDATE = 'update'
UPDATE_ASSIGNEE = 'update:assignee'
- UPDATE_OWNER = 'update:owner'
- UPDATE_PROJECT = 'update:project'
UPDATE_STAGE = 'update:stage'
UPDATE_STATE = 'update:state'
- UPDATE_DESC = 'update:desc'
DELETE = 'delete'
VIEW_ANNOTATIONS = 'view:annotations'
UPDATE_ANNOTATIONS = 'update:annotations'
@@ -728,25 +738,19 @@ def get_scopes(request, view, obj):
scopes = []
if scope == Scopes.UPDATE:
- if any(k in request.data for k in ('owner_id', 'owner')):
- owner_id = request.data.get('owner_id') or request.data.get('owner')
- if owner_id != getattr(obj.owner, 'id', None):
- scopes.append(Scopes.UPDATE_OWNER)
+ # user should have permissions to view a job
+ scopes.append(Scopes.VIEW)
+
if any(k in request.data for k in ('assignee_id', 'assignee')):
assignee_id = request.data.get('assignee_id') or request.data.get('assignee')
if assignee_id != getattr(obj.assignee, 'id', None):
scopes.append(Scopes.UPDATE_ASSIGNEE)
- if any(k in request.data for k in ('project_id', 'project')):
- project_id = request.data.get('project_id') or request.data.get('project')
- if project_id != getattr(obj.project, 'id', None):
- scopes.append(Scopes.UPDATE_PROJECT)
+
if 'stage' in request.data:
scopes.append(Scopes.UPDATE_STAGE)
if 'state' in request.data:
scopes.append(Scopes.UPDATE_STATE)
- if any(k in request.data for k in ('name', 'labels', 'bug_tracker', 'subset')):
- scopes.append(Scopes.UPDATE_DESC)
elif scope == Scopes.VIEW_ANNOTATIONS:
if 'format' in request.query_params:
scope = Scopes.EXPORT_ANNOTATIONS
diff --git a/cvat/apps/engine/pyproject.toml b/cvat/apps/engine/pyproject.toml
new file mode 100644
index 000000000000..567b78362580
--- /dev/null
+++ b/cvat/apps/engine/pyproject.toml
@@ -0,0 +1,12 @@
+[tool.isort]
+profile = "black"
+forced_separate = ["tests"]
+line_length = 100
+skip_gitignore = true # align tool behavior with Black
+known_first_party = ["cvat"]
+
+# Can't just use a pyproject in the root dir, so duplicate
+# https://github.com/psf/black/issues/2863
+[tool.black]
+line-length = 100
+target-version = ['py38']
diff --git a/cvat/apps/engine/rules/projects.rego b/cvat/apps/engine/rules/projects.rego
index dadebdc894ad..8e40ddc43c8d 100644
--- a/cvat/apps/engine/rules/projects.rego
+++ b/cvat/apps/engine/rules/projects.rego
@@ -7,7 +7,7 @@ import data.organizations
# input: {
# "scope": <"create"|"list"|"update:desc"|"update:owner"|"update:assignee"|
-# "view"|"delete"|"export:dataset"|"export:annotations"|
+# "update:associated_storage"|"view"|"delete"|"export:dataset"|"export:annotations"|
# "import:dataset"> or null,
# "auth": {
# "user": {
@@ -127,14 +127,14 @@ allow if {
allow if {
- input.scope in {utils.DELETE, utils.UPDATE_ORG}
+ input.scope in {utils.DELETE, utils.UPDATE_ORG, utils.UPDATE_ASSOCIATED_STORAGE}
utils.is_sandbox
utils.has_perm(utils.WORKER)
utils.is_resource_owner
}
allow if {
- input.scope in {utils.DELETE, utils.UPDATE_ORG}
+ input.scope in {utils.DELETE, utils.UPDATE_ORG, utils.UPDATE_ASSOCIATED_STORAGE}
input.auth.organization.id == input.resource.organization.id
utils.has_perm(utils.WORKER)
organizations.is_member
@@ -142,7 +142,7 @@ allow if {
}
allow if {
- input.scope in {utils.DELETE, utils.UPDATE_ORG}
+ input.scope in {utils.DELETE, utils.UPDATE_ORG, utils.UPDATE_ASSOCIATED_STORAGE}
input.auth.organization.id == input.resource.organization.id
utils.has_perm(utils.USER)
organizations.is_staff
diff --git a/cvat/apps/engine/rules/tasks.rego b/cvat/apps/engine/rules/tasks.rego
index 9f1b7fa951a9..7f7d592bdb01 100644
--- a/cvat/apps/engine/rules/tasks.rego
+++ b/cvat/apps/engine/rules/tasks.rego
@@ -7,8 +7,8 @@ import data.organizations
# input: {
# "scope": <"create"|"create@project"|"view"|"list"|"update:desc"|
-# "update:owner"|"update:assignee"|"update:project"|"delete"|
-# "view:annotations"|"update:annotations"|"delete:annotations"|
+# "update:owner"|"update:assignee"|"update:project"|"update:associated_storage"|
+# "delete"|"view:annotations"|"update:annotations"|"delete:annotations"|
# "export:dataset"|"view:data"|"upload:data"|"export:annotations"> or null,
# "auth": {
# "user": {
@@ -250,10 +250,17 @@ allow if {
utils.has_perm(utils.WORKER)
}
+allow if {
+ input.scope == utils.UPDATE_ASSOCIATED_STORAGE
+ utils.is_sandbox
+ is_project_owner
+ utils.has_perm(utils.WORKER)
+}
+
allow if {
input.scope in {
utils.UPDATE_OWNER, utils.UPDATE_ASSIGNEE, utils.UPDATE_PROJECT,
- utils.DELETE, utils.UPDATE_ORG
+ utils.DELETE, utils.UPDATE_ORG, utils.UPDATE_ASSOCIATED_STORAGE
}
utils.is_sandbox
is_task_owner
@@ -263,7 +270,7 @@ allow if {
allow if {
input.scope in {
utils.UPDATE_OWNER, utils.UPDATE_ASSIGNEE, utils.UPDATE_PROJECT,
- utils.DELETE, utils.UPDATE_ORG
+ utils.DELETE, utils.UPDATE_ORG, utils.UPDATE_ASSOCIATED_STORAGE
}
input.auth.organization.id == input.resource.organization.id
utils.has_perm(utils.USER)
@@ -273,13 +280,20 @@ allow if {
allow if {
input.scope in {
utils.UPDATE_OWNER, utils.UPDATE_ASSIGNEE, utils.UPDATE_PROJECT,
- utils.DELETE, utils.UPDATE_ORG
+ utils.DELETE, utils.UPDATE_ORG, utils.UPDATE_ASSOCIATED_STORAGE
}
input.auth.organization.id == input.resource.organization.id
utils.has_perm(utils.WORKER)
organizations.has_perm(organizations.WORKER)
is_task_owner
}
+allow if {
+ input.scope == utils.UPDATE_ASSOCIATED_STORAGE
+ input.auth.organization.id == input.resource.organization.id
+ utils.has_perm(utils.WORKER)
+ organizations.has_perm(organizations.WORKER)
+ is_project_owner
+}
allow if {
input.scope in {
diff --git a/cvat/apps/engine/rules/tests/configs/projects.csv b/cvat/apps/engine/rules/tests/configs/projects.csv
index bfe3bbff08ad..5d42ba4a9234 100644
--- a/cvat/apps/engine/rules/tests/configs/projects.csv
+++ b/cvat/apps/engine/rules/tests/configs/projects.csv
@@ -43,4 +43,9 @@ export:backup,Project,Organization,None,,GET,/projects/{id}/backup,User,Maintain
update:organization,"Project, Organization",Sandbox,"None, Assignee",,PATCH,/projects/{id},Admin,N/A
update:organization,"Project, Organization",Sandbox,Owner,,PATCH,/projects/{id},Worker,N/A
update:organization,"Project, Organization",Organization,"None, Assignee",,PATCH,/projects/{id},User,Maintainer
-update:organization,"Project, Organization",Organization,Owner,,PATCH,/projects/{id},Worker,Worker
\ No newline at end of file
+update:organization,"Project, Organization",Organization,Owner,,PATCH,/projects/{id},Worker,Worker
+update:associated_storage,Project,Sandbox,None,,PATCH,/projects/{id},Admin,N/A
+update:associated_storage,Project,Sandbox,Owner,,PATCH,/projects/{id},Worker,N/A
+update:associated_storage,Project,Organization,None,,PATCH,/projects/{id},Admin,N/A
+update:associated_storage,Project,Organization,"None, Assignee",,PATCH,/projects/{id},User,Maintainer
+update:associated_storage,Project,Organization,Owner,,PATCH,/projects/{id},Worker,Worker
diff --git a/cvat/apps/engine/rules/tests/configs/tasks.csv b/cvat/apps/engine/rules/tests/configs/tasks.csv
index a748f5efb2df..e5c17c348104 100644
--- a/cvat/apps/engine/rules/tests/configs/tasks.csv
+++ b/cvat/apps/engine/rules/tests/configs/tasks.csv
@@ -29,6 +29,11 @@ update:project,"Task, Project",Sandbox,"None, Assignee",,PATCH,/tasks/{id},Admin
update:project,"Task, Project",Sandbox,"Owner, Project:owner, Project:assignee",,PATCH,/tasks/{id},Worker,N/A
update:project,"Task, Project",Organization,"None, Assignee",,PATCH,/tasks/{id},User,Maintainer
update:project,"Task, Project",Organization,"Owner, Project:owner, Project:assignee",,PATCH,/tasks/{id},Worker,Worker
+update:associated_storage,Task,Sandbox,None,,PATCH,/tasks/{id},Admin,N/A
+update:associated_storage,Task,Sandbox,"Owner, Project:owner",,PATCH,/tasks/{id},Worker,N/A
+update:associated_storage,Task,Organization,None,,PATCH,/tasks/{id},Admin,N/A
+update:associated_storage,Task,Organization,"None, Assignee, Project:assignee",,PATCH,/tasks/{id},User,Maintainer
+update:associated_storage,Task,Organization,"Owner, Project:owner",,PATCH,/tasks/{id},Worker,Worker
delete,Task,Sandbox,"None, Assignee",,DELETE,/tasks/{id},Admin,N/A
delete,Task,Sandbox,"Owner, Project:owner, Project:assignee",,DELETE,/tasks/{id},Worker,N/A
delete,Task,Organization,"None, Assignee",,DELETE,/tasks/{id},User,Maintainer
diff --git a/cvat/apps/engine/serializers.py b/cvat/apps/engine/serializers.py
index 9d66b1716c17..ed937a993ffe 100644
--- a/cvat/apps/engine/serializers.py
+++ b/cvat/apps/engine/serializers.py
@@ -594,6 +594,7 @@ class JobReadSerializer(serializers.ModelSerializer):
dimension = serializers.CharField(max_length=2, source='segment.task.dimension', read_only=True)
data_chunk_size = serializers.ReadOnlyField(source='segment.task.data.chunk_size')
organization = serializers.ReadOnlyField(source='segment.task.organization.id', allow_null=True)
+ data_original_chunk_type = serializers.ReadOnlyField(source='segment.task.data.original_chunk_type')
data_compressed_chunk_type = serializers.ReadOnlyField(source='segment.task.data.compressed_chunk_type')
mode = serializers.ReadOnlyField(source='segment.task.mode')
bug_tracker = serializers.CharField(max_length=2000, source='get_bug_tracker',
@@ -607,7 +608,8 @@ class Meta:
model = models.Job
fields = ('url', 'id', 'task_id', 'project_id', 'assignee', 'guide_id',
'dimension', 'bug_tracker', 'status', 'stage', 'state', 'mode', 'frame_count',
- 'start_frame', 'stop_frame', 'data_chunk_size', 'data_compressed_chunk_type',
+ 'start_frame', 'stop_frame',
+ 'data_chunk_size', 'data_compressed_chunk_type', 'data_original_chunk_type',
'created_date', 'updated_date', 'issues', 'labels', 'type', 'organization',
'target_storage', 'source_storage', 'assignee_updated_date')
read_only_fields = fields
diff --git a/cvat/apps/engine/task.py b/cvat/apps/engine/task.py
index 0db84cebc32b..f24cd686a587 100644
--- a/cvat/apps/engine/task.py
+++ b/cvat/apps/engine/task.py
@@ -1,32 +1,37 @@
# Copyright (C) 2018-2022 Intel Corporation
-# Copyright (C) 2022-2023 CVAT.ai Corporation
+# Copyright (C) 2022-2024 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
+import concurrent.futures
import itertools
import fnmatch
import os
-from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Union, Iterable
-from rest_framework.serializers import ValidationError
-import rq
import re
+import rq
import shutil
+from contextlib import closing
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Sequence, Tuple, Union
from urllib import parse as urlparse
from urllib import request as urlrequest
-import django_rq
-import concurrent.futures
-import queue
+import av
+import attrs
+import django_rq
from django.conf import settings
from django.db import transaction
from django.http import HttpRequest
-from datetime import datetime, timezone
-from pathlib import Path
+from rest_framework.serializers import ValidationError
from cvat.apps.engine import models
from cvat.apps.engine.log import ServerLogManager
-from cvat.apps.engine.media_extractors import (MEDIA_TYPES, ImageListReader, Mpeg4ChunkWriter, Mpeg4CompressedChunkWriter,
- ValidateDimension, ZipChunkWriter, ZipCompressedChunkWriter, get_mime, sort)
+from cvat.apps.engine.media_extractors import (
+ MEDIA_TYPES, CachingMediaIterator, IMediaReader, ImageListReader,
+ Mpeg4ChunkWriter, Mpeg4CompressedChunkWriter, RandomAccessIterator,
+ ValidateDimension, ZipChunkWriter, ZipCompressedChunkWriter, get_mime, sort
+)
from cvat.apps.engine.models import RequestAction, RequestTarget
from cvat.apps.engine.utils import (
av_scan_paths,get_rq_job_meta, define_dependent_job, get_rq_lock_by_user, preload_images
@@ -71,6 +76,8 @@ def create(
class SegmentParams(NamedTuple):
start_frame: int
stop_frame: int
+ type: models.SegmentType = models.SegmentType.RANGE
+ frames: Optional[Sequence[int]] = []
class SegmentsParams(NamedTuple):
segments: Iterator[SegmentParams]
@@ -116,7 +123,7 @@ def _copy_data_from_share_point(
os.makedirs(target_dir)
shutil.copyfile(source_path, target_path)
-def _get_task_segment_data(
+def _generate_segment_params(
db_task: models.Task,
*,
data_size: Optional[int] = None,
@@ -127,10 +134,14 @@ def _segments():
# It is assumed here that files are already saved ordered in the task
# Here we just need to create segments by the job sizes
start_frame = 0
- for jf in job_file_mapping:
- segment_size = len(jf)
+ for job_files in job_file_mapping:
+ segment_size = len(job_files)
stop_frame = start_frame + segment_size - 1
- yield SegmentParams(start_frame, stop_frame)
+ yield SegmentParams(
+ start_frame=start_frame,
+ stop_frame=stop_frame,
+ type=models.SegmentType.RANGE,
+ )
start_frame = stop_frame + 1
@@ -153,31 +164,39 @@ def _segments():
)
segments = (
- SegmentParams(start_frame, min(start_frame + segment_size - 1, data_size - 1))
+ SegmentParams(
+ start_frame=start_frame,
+ stop_frame=min(start_frame + segment_size - 1, data_size - 1),
+ type=models.SegmentType.RANGE
+ )
for start_frame in range(0, data_size - overlap, segment_size - overlap)
)
return SegmentsParams(segments, segment_size, overlap)
-def _save_task_to_db(db_task: models.Task, *, job_file_mapping: Optional[JobFileMapping] = None):
- job = rq.get_current_job()
- job.meta['status'] = 'Task is being saved in database'
- job.save_meta()
+def _create_segments_and_jobs(
+ db_task: models.Task,
+ *,
+ job_file_mapping: Optional[JobFileMapping] = None,
+):
+ rq_job = rq.get_current_job()
+ rq_job.meta['status'] = 'Task is being saved in database'
+ rq_job.save_meta()
- segments, segment_size, overlap = _get_task_segment_data(
- db_task=db_task, job_file_mapping=job_file_mapping
+ segments, segment_size, overlap = _generate_segment_params(
+ db_task=db_task, job_file_mapping=job_file_mapping,
)
db_task.segment_size = segment_size
db_task.overlap = overlap
- for segment_idx, (start_frame, stop_frame) in enumerate(segments):
- slogger.glob.info("New segment for task #{}: idx = {}, start_frame = {}, \
- stop_frame = {}".format(db_task.id, segment_idx, start_frame, stop_frame))
+ for segment_idx, segment_params in enumerate(segments):
+ slogger.glob.info(
+ "New segment for task #{task_id}: idx = {segment_idx}, start_frame = {start_frame}, "
+ "stop_frame = {stop_frame}".format(
+ task_id=db_task.id, segment_idx=segment_idx, **segment_params._asdict()
+ ))
- db_segment = models.Segment()
- db_segment.task = db_task
- db_segment.start_frame = start_frame
- db_segment.stop_frame = stop_frame
+ db_segment = models.Segment(task=db_task, **segment_params._asdict())
db_segment.save()
db_job = models.Job(segment=db_segment)
@@ -322,48 +341,28 @@ def _validate_manifest(
*,
is_in_cloud: bool,
db_cloud_storage: Optional[Any],
- data_storage_method: str,
- data_sorting_method: str,
- isBackupRestore: bool,
) -> Optional[str]:
- if manifests:
- if len(manifests) != 1:
- raise ValidationError('Only one manifest file can be attached to data')
- manifest_file = manifests[0]
- full_manifest_path = os.path.join(root_dir, manifests[0])
-
- if is_in_cloud:
- cloud_storage_instance = db_storage_to_storage_instance(db_cloud_storage)
- # check that cloud storage manifest file exists and is up to date
- if not os.path.exists(full_manifest_path) or \
- datetime.fromtimestamp(os.path.getmtime(full_manifest_path), tz=timezone.utc) \
- < cloud_storage_instance.get_file_last_modified(manifest_file):
- cloud_storage_instance.download_file(manifest_file, full_manifest_path)
-
- if is_manifest(full_manifest_path):
- if not (
- data_sorting_method == models.SortingMethod.PREDEFINED or
- (settings.USE_CACHE and data_storage_method == models.StorageMethodChoice.CACHE) or
- isBackupRestore or is_in_cloud
- ):
- cache_disabled_message = ""
- if data_storage_method == models.StorageMethodChoice.CACHE and not settings.USE_CACHE:
- cache_disabled_message = (
- "This server doesn't allow to use cache for data. "
- "Please turn 'use cache' off and try to recreate the task"
- )
- slogger.glob.warning(cache_disabled_message)
-
- raise ValidationError(
- "A manifest file can only be used with the 'use cache' option "
- "or when 'sorting_method' is 'predefined'" + \
- (". " + cache_disabled_message if cache_disabled_message else "")
- )
- return manifest_file
+ if not manifests:
+ return None
+ if len(manifests) != 1:
+ raise ValidationError('Only one manifest file can be attached to data')
+ manifest_file = manifests[0]
+ full_manifest_path = os.path.join(root_dir, manifests[0])
+
+ if is_in_cloud:
+ cloud_storage_instance = db_storage_to_storage_instance(db_cloud_storage)
+ # check that cloud storage manifest file exists and is up to date
+ if not os.path.exists(full_manifest_path) or (
+ datetime.fromtimestamp(os.path.getmtime(full_manifest_path), tz=timezone.utc) \
+ < cloud_storage_instance.get_file_last_modified(manifest_file)
+ ):
+ cloud_storage_instance.download_file(manifest_file, full_manifest_path)
+
+ if not is_manifest(full_manifest_path):
raise ValidationError('Invalid manifest was uploaded')
- return None
+ return manifest_file
def _validate_scheme(url):
ALLOWED_SCHEMES = ['http', 'https']
@@ -522,18 +521,18 @@ def _create_thread(
slogger.glob.info("create task #{}".format(db_task.id))
- job_file_mapping = _validate_job_file_mapping(db_task, data)
-
- db_data = db_task.data
- upload_dir = db_data.get_upload_dirname() if db_data.storage != models.StorageChoice.SHARE else settings.SHARE_ROOT
- is_data_in_cloud = db_data.storage == models.StorageChoice.CLOUD_STORAGE
-
job = rq.get_current_job()
def _update_status(msg: str) -> None:
job.meta['status'] = msg
job.save_meta()
+ job_file_mapping = _validate_job_file_mapping(db_task, data)
+
+ db_data = db_task.data
+ upload_dir = db_data.get_upload_dirname() if db_data.storage != models.StorageChoice.SHARE else settings.SHARE_ROOT
+ is_data_in_cloud = db_data.storage == models.StorageChoice.CLOUD_STORAGE
+
if data['remote_files'] and not isDatasetImport:
data['remote_files'] = _download_data(data['remote_files'], upload_dir)
@@ -551,14 +550,17 @@ def _update_status(msg: str) -> None:
else:
assert False, f"Unknown file storage {db_data.storage}"
+ if (
+ db_data.storage_method == models.StorageMethodChoice.FILE_SYSTEM and
+ not settings.MEDIA_CACHE_ALLOW_STATIC_CACHE
+ ):
+ db_data.storage_method = models.StorageMethodChoice.CACHE
+
manifest_file = _validate_manifest(
manifest_files,
manifest_root,
is_in_cloud=is_data_in_cloud,
db_cloud_storage=db_data.cloud_storage if is_data_in_cloud else None,
- data_storage_method=db_data.storage_method,
- data_sorting_method=data['sorting_method'],
- isBackupRestore=isBackupRestore,
)
manifest = None
@@ -668,14 +670,16 @@ def _update_status(msg: str) -> None:
is_media_sorted = False
if is_data_in_cloud:
- # first we need to filter files and keep only supported ones
- if any([v for k, v in media.items() if k != 'image']) and db_data.storage_method == models.StorageMethodChoice.CACHE:
- # FUTURE-FIXME: This is a temporary workaround for creating tasks
- # with unsupported cloud storage data (video, archive, pdf) when use_cache is enabled
- db_data.storage_method = models.StorageMethodChoice.FILE_SYSTEM
- _update_status("The 'use cache' option is ignored")
-
- if db_data.storage_method == models.StorageMethodChoice.FILE_SYSTEM or not settings.USE_CACHE:
+ if (
+ # Download remote data if local storage is requested
+ # TODO: maybe move into cache building to fail faster on invalid task configurations
+ db_data.storage_method == models.StorageMethodChoice.FILE_SYSTEM or
+
+ # Packed media must be downloaded for task creation
+ any(v for k, v in media.items() if k != 'image')
+ ):
+ _update_status("Downloading input media")
+
filtered_data = []
for files in (i for i in media.values() if i):
filtered_data.extend(files)
@@ -690,9 +694,11 @@ def _update_status(msg: str) -> None:
step = db_data.get_frame_step()
if start_frame or step != 1 or stop_frame != len(filtered_data) - 1:
media_to_download = filtered_data[start_frame : stop_frame + 1: step]
+
_download_data_from_cloud_storage(db_data.cloud_storage, media_to_download, upload_dir)
del media_to_download
del filtered_data
+
is_data_in_cloud = False
db_data.storage = models.StorageChoice.LOCAL
else:
@@ -757,7 +763,7 @@ def _update_status(msg: str) -> None:
)
# Extract input data
- extractor = None
+ extractor: Optional[IMediaReader] = None
manifest_index = _get_manifest_frame_indexer()
for media_type, media_files in media.items():
if not media_files:
@@ -917,38 +923,9 @@ def _update_status(msg: str) -> None:
db_data.compressed_chunk_type = models.DataChoice.VIDEO if task_mode == 'interpolation' and not data['use_zip_chunks'] else models.DataChoice.IMAGESET
db_data.original_chunk_type = models.DataChoice.VIDEO if task_mode == 'interpolation' else models.DataChoice.IMAGESET
- def update_progress(progress):
- progress_animation = '|/-\\'
- if not hasattr(update_progress, 'call_counter'):
- update_progress.call_counter = 0
-
- status_message = 'CVAT is preparing data chunks'
- if not progress:
- status_message = '{} {}'.format(status_message, progress_animation[update_progress.call_counter])
- job.meta['status'] = status_message
- job.meta['task_progress'] = progress or 0.
- job.save_meta()
- update_progress.call_counter = (update_progress.call_counter + 1) % len(progress_animation)
-
- compressed_chunk_writer_class = Mpeg4CompressedChunkWriter if db_data.compressed_chunk_type == models.DataChoice.VIDEO else ZipCompressedChunkWriter
- if db_data.original_chunk_type == models.DataChoice.VIDEO:
- original_chunk_writer_class = Mpeg4ChunkWriter
- # Let's use QP=17 (that is 67 for 0-100 range) for the original chunks, which should be visually lossless or nearly so.
- # A lower value will significantly increase the chunk size with a slight increase of quality.
- original_quality = 67
- else:
- original_chunk_writer_class = ZipChunkWriter
- original_quality = 100
-
- kwargs = {}
- if validate_dimension.dimension == models.DimensionType.DIM_3D:
- kwargs["dimension"] = validate_dimension.dimension
- compressed_chunk_writer = compressed_chunk_writer_class(db_data.image_quality, **kwargs)
- original_chunk_writer = original_chunk_writer_class(original_quality, **kwargs)
-
# calculate chunk size if it isn't specified
if db_data.chunk_size is None:
- if isinstance(compressed_chunk_writer, ZipCompressedChunkWriter):
+ if db_data.compressed_chunk_type == models.DataChoice.IMAGESET:
first_image_idx = db_data.start_frame
if not is_data_in_cloud:
w, h = extractor.get_image_size(first_image_idx)
@@ -960,206 +937,317 @@ def update_progress(progress):
else:
db_data.chunk_size = 36
- video_path = ""
- video_size = (0, 0)
+ # TODO: try to pull up
+ # replace the manifest file (e.g. 'subdir/manifest.jsonl' or 'some_manifest.jsonl' was uploaded)
+ if (manifest_file and not os.path.exists(db_data.get_manifest_path())):
+ shutil.copyfile(os.path.join(manifest_root, manifest_file),
+ db_data.get_manifest_path())
+ if manifest_root and manifest_root.startswith(db_data.get_upload_dirname()):
+ os.remove(os.path.join(manifest_root, manifest_file))
+ manifest_file = os.path.relpath(db_data.get_manifest_path(), upload_dir)
- db_images = []
+ # Create task frames from the metadata collected
+ video_path: str = ""
+ video_frame_size: tuple[int, int] = (0, 0)
- if settings.USE_CACHE and db_data.storage_method == models.StorageMethodChoice.CACHE:
- for media_type, media_files in media.items():
- if not media_files:
- continue
+ images: list[models.Image] = []
- # replace manifest file (e.g was uploaded 'subdir/manifest.jsonl' or 'some_manifest.jsonl')
- if manifest_file and not os.path.exists(db_data.get_manifest_path()):
- shutil.copyfile(os.path.join(manifest_root, manifest_file),
- db_data.get_manifest_path())
- if manifest_root and manifest_root.startswith(db_data.get_upload_dirname()):
- os.remove(os.path.join(manifest_root, manifest_file))
- manifest_file = os.path.relpath(db_data.get_manifest_path(), upload_dir)
+ for media_type, media_files in media.items():
+ if not media_files:
+ continue
- if task_mode == MEDIA_TYPES['video']['mode']:
+ if task_mode == MEDIA_TYPES['video']['mode']:
+ if manifest_file:
try:
- manifest_is_prepared = False
- if manifest_file:
- try:
- manifest = VideoManifestValidator(source_path=os.path.join(upload_dir, media_files[0]),
- manifest_path=db_data.get_manifest_path())
- manifest.init_index()
- manifest.validate_seek_key_frames()
- assert len(manifest) > 0, 'No key frames.'
-
- all_frames = manifest.video_length
- video_size = manifest.video_resolution
- manifest_is_prepared = True
- except Exception as ex:
- manifest.remove()
- if isinstance(ex, AssertionError):
- base_msg = str(ex)
- else:
- base_msg = 'Invalid manifest file was upload.'
- slogger.glob.warning(str(ex))
- _update_status('{} Start prepare a valid manifest file.'.format(base_msg))
-
- if not manifest_is_prepared:
- _update_status('Start prepare a manifest file')
- manifest = VideoManifestManager(db_data.get_manifest_path())
- manifest.link(
- media_file=media_files[0],
- upload_dir=upload_dir,
- chunk_size=db_data.chunk_size
- )
- manifest.create()
- _update_status('A manifest had been created')
+ _update_status('Validating the input manifest file')
- all_frames = len(manifest.reader)
- video_size = manifest.reader.resolution
- manifest_is_prepared = True
+ manifest = VideoManifestValidator(
+ source_path=os.path.join(upload_dir, media_files[0]),
+ manifest_path=db_data.get_manifest_path()
+ )
+ manifest.init_index()
+ manifest.validate_seek_key_frames()
+
+ if not len(manifest):
+ raise ValidationError("No key frames found in the manifest")
- db_data.size = len(range(db_data.start_frame, min(data['stop_frame'] + 1 \
- if data['stop_frame'] else all_frames, all_frames), db_data.get_frame_step()))
- video_path = os.path.join(upload_dir, media_files[0])
except Exception as ex:
- db_data.storage_method = models.StorageMethodChoice.FILE_SYSTEM
manifest.remove()
- del manifest
- base_msg = str(ex) if isinstance(ex, AssertionError) \
- else "Uploaded video does not support a quick way of task creating."
- _update_status("{} The task will be created using the old method".format(base_msg))
- else: # images, archive, pdf
- db_data.size = len(extractor)
- manifest = ImageManifestManager(db_data.get_manifest_path())
-
- if not manifest.exists:
+ manifest = None
+
+ if isinstance(ex, (ValidationError, AssertionError)):
+ base_msg = f"Invalid manifest file was uploaded: {ex}"
+ else:
+ base_msg = "Failed to parse the uploaded manifest file"
+ slogger.glob.warning(ex, exc_info=True)
+
+ _update_status(base_msg)
+ else:
+ manifest = None
+
+ if not manifest:
+ try:
+ _update_status('Preparing a manifest file')
+
+ # TODO: maybe generate manifest in a temp directory
+ manifest = VideoManifestManager(db_data.get_manifest_path())
manifest.link(
- sources=extractor.absolute_source_paths,
- meta={ k: {'related_images': related_images[k] } for k in related_images },
- data_dir=upload_dir,
- DIM_3D=(db_task.dimension == models.DimensionType.DIM_3D),
+ media_file=media_files[0],
+ upload_dir=upload_dir,
+ chunk_size=db_data.chunk_size, # TODO: why is it needed here?
+ force=True
)
manifest.create()
- else:
- manifest.init_index()
- counter = itertools.count()
- for _, chunk_frames in itertools.groupby(extractor.frame_range, lambda x: next(counter) // db_data.chunk_size):
- chunk_paths = [(extractor.get_path(i), i) for i in chunk_frames]
- img_sizes = []
-
- for chunk_path, frame_id in chunk_paths:
- properties = manifest[manifest_index(frame_id)]
-
- # check mapping
- if not chunk_path.endswith(f"{properties['name']}{properties['extension']}"):
- raise Exception('Incorrect file mapping to manifest content')
-
- if db_task.dimension == models.DimensionType.DIM_2D and (
- properties.get('width') is not None and
- properties.get('height') is not None
- ):
- resolution = (properties['width'], properties['height'])
- elif is_data_in_cloud:
- raise Exception(
- "Can't find image '{}' width or height info in the manifest"
- .format(f"{properties['name']}{properties['extension']}")
- )
- else:
- resolution = extractor.get_image_size(frame_id)
- img_sizes.append(resolution)
-
- db_images.extend([
- models.Image(data=db_data,
- path=os.path.relpath(path, upload_dir),
- frame=frame, width=w, height=h)
- for (path, frame), (w, h) in zip(chunk_paths, img_sizes)
- ])
- if db_data.storage_method == models.StorageMethodChoice.FILE_SYSTEM or not settings.USE_CACHE:
- counter = itertools.count()
- generator = itertools.groupby(extractor, lambda _: next(counter) // db_data.chunk_size)
- generator = ((idx, list(chunk_data)) for idx, chunk_data in generator)
-
- def save_chunks(
- executor: concurrent.futures.ThreadPoolExecutor,
- chunk_idx: int,
- chunk_data: Iterable[tuple[str, str, str]]) -> list[tuple[str, int, tuple[int, int]]]:
- nonlocal db_data, db_task, extractor, original_chunk_writer, compressed_chunk_writer
- if (db_task.dimension == models.DimensionType.DIM_2D and
- isinstance(extractor, (
- MEDIA_TYPES['image']['extractor'],
- MEDIA_TYPES['zip']['extractor'],
- MEDIA_TYPES['pdf']['extractor'],
- MEDIA_TYPES['archive']['extractor'],
- ))):
- chunk_data = preload_images(chunk_data)
-
- fs_original = executor.submit(
- original_chunk_writer.save_as_chunk,
- images=chunk_data,
- chunk_path=db_data.get_original_chunk_path(chunk_idx)
- )
- fs_compressed = executor.submit(
- compressed_chunk_writer.save_as_chunk,
- images=chunk_data,
- chunk_path=db_data.get_compressed_chunk_path(chunk_idx),
- )
- fs_original.result()
- image_sizes = fs_compressed.result()
- # (path, frame, size)
- return list((i[0][1], i[0][2], i[1]) for i in zip(chunk_data, image_sizes))
+ _update_status('A manifest has been created')
- def process_results(img_meta: list[tuple[str, int, tuple[int, int]]]):
- nonlocal db_images, db_data, video_path, video_size
+ except Exception as ex:
+ manifest.remove()
+ manifest = None
- if db_task.mode == 'annotation':
- db_images.extend(
- models.Image(
- data=db_data,
- path=os.path.relpath(frame_path, upload_dir),
- frame=frame_number,
- width=frame_size[0],
- height=frame_size[1])
- for frame_path, frame_number, frame_size in img_meta)
+ if isinstance(ex, AssertionError):
+ base_msg = f": {ex}"
+ else:
+ base_msg = ""
+ slogger.glob.warning(ex, exc_info=True)
+
+ _update_status(
+ f"Failed to create manifest for the uploaded video{base_msg}. "
+ "A manifest will not be used in this task"
+ )
+
+ if manifest:
+ video_frame_count = manifest.video_length
+ video_frame_size = manifest.video_resolution
else:
- video_size = img_meta[0][2]
- video_path = img_meta[0][0]
+ video_frame_count = extractor.get_frame_count()
+ video_frame_size = extractor.get_image_size(0)
+
+ db_data.size = len(range(
+ db_data.start_frame,
+ min(
+ data['stop_frame'] + 1 if data['stop_frame'] else video_frame_count,
+ video_frame_count,
+ ),
+ db_data.get_frame_step()
+ ))
+ video_path = os.path.join(upload_dir, media_files[0])
+ else: # images, archive, pdf
+ db_data.size = len(extractor)
- progress = extractor.get_progress(img_meta[-1][1])
- update_progress(progress)
+ manifest = ImageManifestManager(db_data.get_manifest_path())
+ if not manifest.exists:
+ manifest.link(
+ sources=extractor.absolute_source_paths,
+ meta={
+ k: {'related_images': related_images[k] }
+ for k in related_images
+ },
+ data_dir=upload_dir,
+ DIM_3D=(db_task.dimension == models.DimensionType.DIM_3D),
+ )
+ manifest.create()
+ else:
+ manifest.init_index()
+
+ for frame_id in extractor.frame_range:
+ image_path = extractor.get_path(frame_id)
+ image_size = None
+
+ if manifest:
+ image_info = manifest[manifest_index(frame_id)]
+
+ # check mapping
+ if not image_path.endswith(f"{image_info['name']}{image_info['extension']}"):
+ raise ValidationError('Incorrect file mapping to manifest content')
+
+ if db_task.dimension == models.DimensionType.DIM_2D and (
+ image_info.get('width') is not None and
+ image_info.get('height') is not None
+ ):
+ image_size = (image_info['width'], image_info['height'])
+ elif is_data_in_cloud:
+ raise ValidationError(
+ "Can't find image '{}' width or height info in the manifest"
+ .format(f"{image_info['name']}{image_info['extension']}")
+ )
- futures = queue.Queue(maxsize=settings.CVAT_CONCURRENT_CHUNK_PROCESSING)
- with concurrent.futures.ThreadPoolExecutor(max_workers=2*settings.CVAT_CONCURRENT_CHUNK_PROCESSING) as executor:
- for chunk_idx, chunk_data in generator:
- db_data.size += len(chunk_data)
- if futures.full():
- process_results(futures.get().result())
- futures.put(executor.submit(save_chunks, executor, chunk_idx, chunk_data))
+ if not image_size:
+ image_size = extractor.get_image_size(frame_id)
- while not futures.empty():
- process_results(futures.get().result())
+ images.append(
+ models.Image(
+ data=db_data,
+ path=os.path.relpath(image_path, upload_dir),
+ frame=frame_id,
+ width=image_size[0],
+ height=image_size[1],
+ )
+ )
if db_task.mode == 'annotation':
- models.Image.objects.bulk_create(db_images)
- created_images = models.Image.objects.filter(data_id=db_data.id)
+ models.Image.objects.bulk_create(images)
+ images = models.Image.objects.filter(data_id=db_data.id)
db_related_files = [
models.RelatedFile(data=image.data, primary_image=image, path=os.path.join(upload_dir, related_file_path))
- for image in created_images
+ for image in images
for related_file_path in related_images.get(image.path, [])
]
models.RelatedFile.objects.bulk_create(db_related_files)
- db_images = []
else:
models.Video.objects.create(
data=db_data,
path=os.path.relpath(video_path, upload_dir),
- width=video_size[0], height=video_size[1])
+ width=video_frame_size[0], height=video_frame_size[1]
+ )
+ # validate stop_frame
if db_data.stop_frame == 0:
db_data.stop_frame = db_data.start_frame + (db_data.size - 1) * db_data.get_frame_step()
else:
- # validate stop_frame
db_data.stop_frame = min(db_data.stop_frame, \
db_data.start_frame + (db_data.size - 1) * db_data.get_frame_step())
slogger.glob.info("Found frames {} for Data #{}".format(db_data.size, db_data.id))
- _save_task_to_db(db_task, job_file_mapping=job_file_mapping)
+ _create_segments_and_jobs(db_task, job_file_mapping=job_file_mapping)
+
+ if (
+ settings.MEDIA_CACHE_ALLOW_STATIC_CACHE and
+ db_data.storage_method == models.StorageMethodChoice.FILE_SYSTEM
+ ):
+ _create_static_chunks(db_task, media_extractor=extractor)
+
+def _create_static_chunks(db_task: models.Task, *, media_extractor: IMediaReader):
+ @attrs.define
+ class _ChunkProgressUpdater:
+ _call_counter: int = attrs.field(default=0, init=False)
+ _rq_job: rq.job.Job = attrs.field(factory=rq.get_current_job)
+
+ def update_progress(self, progress: float):
+ progress_animation = '|/-\\'
+
+ status_message = 'CVAT is preparing data chunks'
+ if not progress:
+ status_message = '{} {}'.format(
+ status_message, progress_animation[self._call_counter]
+ )
+
+ self._rq_job.meta['status'] = status_message
+ self._rq_job.meta['task_progress'] = progress or 0.
+ self._rq_job.save_meta()
+
+ self._call_counter = (self._call_counter + 1) % len(progress_animation)
+
+ def save_chunks(
+ executor: concurrent.futures.ThreadPoolExecutor,
+ db_segment: models.Segment,
+ chunk_idx: int,
+ chunk_frame_ids: Sequence[int]
+ ):
+ chunk_data = [media_iterator[frame_idx] for frame_idx in chunk_frame_ids]
+
+ if (
+ db_task.dimension == models.DimensionType.DIM_2D and
+ isinstance(media_extractor, (
+ MEDIA_TYPES['image']['extractor'],
+ MEDIA_TYPES['zip']['extractor'],
+ MEDIA_TYPES['pdf']['extractor'],
+ MEDIA_TYPES['archive']['extractor'],
+ ))
+ ):
+ chunk_data = preload_images(chunk_data)
+
+ # TODO: extract into a class
+
+ fs_original = executor.submit(
+ original_chunk_writer.save_as_chunk,
+ images=chunk_data,
+ chunk_path=db_data.get_original_segment_chunk_path(
+ chunk_idx, segment_id=db_segment.id
+ ),
+ )
+ compressed_chunk_writer.save_as_chunk(
+ images=chunk_data,
+ chunk_path=db_data.get_compressed_segment_chunk_path(
+ chunk_idx, segment_id=db_segment.id
+ ),
+ )
+
+ fs_original.result()
+
+ db_data = db_task.data
+
+ if db_data.compressed_chunk_type == models.DataChoice.VIDEO:
+ compressed_chunk_writer_class = Mpeg4CompressedChunkWriter
+ else:
+ compressed_chunk_writer_class = ZipCompressedChunkWriter
+
+ if db_data.original_chunk_type == models.DataChoice.VIDEO:
+ original_chunk_writer_class = Mpeg4ChunkWriter
+
+ # Let's use QP=17 (that is 67 for 0-100 range) for the original chunks,
+ # which should be visually lossless or nearly so.
+ # A lower value will significantly increase the chunk size with a slight increase of quality.
+ original_quality = 67 # TODO: fix discrepancy in values in different parts of code
+ else:
+ original_chunk_writer_class = ZipChunkWriter
+ original_quality = 100
+
+ chunk_writer_kwargs = {}
+ if db_task.dimension == models.DimensionType.DIM_3D:
+ chunk_writer_kwargs["dimension"] = db_task.dimension
+ compressed_chunk_writer = compressed_chunk_writer_class(
+ db_data.image_quality, **chunk_writer_kwargs
+ )
+ original_chunk_writer = original_chunk_writer_class(original_quality, **chunk_writer_kwargs)
+
+ db_segments = db_task.segment_set.all()
+
+ if isinstance(media_extractor, MEDIA_TYPES['video']['extractor']):
+ def _get_frame_size(frame_tuple: Tuple[av.VideoFrame, Any, Any]) -> int:
+ # There is no need to be absolutely precise here;
+ # we just need to provide a reasonable upper bound.
+ # Return bytes needed for 1 frame
+ frame = frame_tuple[0]
+ return frame.width * frame.height * (frame.format.padded_bits_per_pixel // 8)
+
+ # Currently, we only optimize video creation for sequential
+ # chunks with potential overlap, so parallel processing is likely to
+ # help only for image datasets
+ media_iterator = CachingMediaIterator(
+ media_extractor,
+ max_cache_memory=2 ** 30, max_cache_entries=db_task.overlap,
+ object_size_callback=_get_frame_size
+ )
+ else:
+ media_iterator = RandomAccessIterator(media_extractor)
+
+ with closing(media_iterator):
+ progress_updater = _ChunkProgressUpdater()
+
+ # TODO: remove 2 * or the configuration option
+ # TODO: maybe add real multithreading support; currently the code is limited to 1
+ # video segment chunk, even if more threads are available
+ max_concurrency = 2 * settings.CVAT_CONCURRENT_CHUNK_PROCESSING if not isinstance(
+ media_extractor, MEDIA_TYPES['video']['extractor']
+ ) else 2
+ with concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrency) as executor:
+ frame_step = db_data.get_frame_step()
+ for segment_idx, db_segment in enumerate(db_segments):
+ frame_counter = itertools.count()
+ for chunk_idx, chunk_frame_ids in (
+ (chunk_idx, list(chunk_frame_ids))
+ for chunk_idx, chunk_frame_ids in itertools.groupby(
+ (
+ # Convert absolute to relative ids (extractor output positions)
+ # Extractor will skip frames outside requested
+ (abs_frame_id - db_data.start_frame) // frame_step
+ for abs_frame_id in db_segment.frame_set
+ ),
+ lambda _: next(frame_counter) // db_data.chunk_size
+ )
+ ):
+ save_chunks(executor, db_segment, chunk_idx, chunk_frame_ids)
+
+ progress_updater.update_progress(segment_idx / len(db_segments))
diff --git a/cvat/apps/engine/tests/test_rest_api.py b/cvat/apps/engine/tests/test_rest_api.py
index ae0200b6a2aa..e7ae8ae9ba7b 100644
--- a/cvat/apps/engine/tests/test_rest_api.py
+++ b/cvat/apps/engine/tests/test_rest_api.py
@@ -1422,7 +1422,13 @@ def _create_task(task_data, media_data):
if isinstance(media, io.BytesIO):
media.seek(0)
response = cls.client.post("/api/tasks/{}/data".format(tid), data=media_data)
- assert response.status_code == status.HTTP_202_ACCEPTED
+ assert response.status_code == status.HTTP_202_ACCEPTED, response.status_code
+ rq_id = response.json()["rq_id"]
+
+ response = cls.client.get(f"/api/requests/{rq_id}")
+ assert response.status_code == status.HTTP_200_OK, response.status_code
+ assert response.json()["status"] == "finished", response.json().get("status")
+
response = cls.client.get("/api/tasks/{}".format(tid))
data_id = response.data["data"]
cls.tasks.append({
@@ -1766,6 +1772,12 @@ def _create_task(task_data, media_data):
media.seek(0)
response = self.client.post("/api/tasks/{}/data".format(tid), data=media_data)
assert response.status_code == status.HTTP_202_ACCEPTED
+ rq_id = response.json()["rq_id"]
+
+ response = self.client.get(f"/api/requests/{rq_id}")
+ assert response.status_code == status.HTTP_200_OK, response.status_code
+ assert response.json()["status"] == "finished", response.json().get("status")
+
response = self.client.get("/api/tasks/{}".format(tid))
data_id = response.data["data"]
self.tasks.append({
@@ -2882,6 +2894,12 @@ def _create_task(task_data, media_data):
media.seek(0)
response = self.client.post("/api/tasks/{}/data".format(tid), data=media_data)
assert response.status_code == status.HTTP_202_ACCEPTED
+ rq_id = response.json()["rq_id"]
+
+ response = self.client.get(f"/api/requests/{rq_id}")
+ assert response.status_code == status.HTTP_200_OK, response.status_code
+ assert response.json()["status"] == "finished", response.json().get("status")
+
response = self.client.get("/api/tasks/{}".format(tid))
data_id = response.data["data"]
self.tasks.append({
@@ -3433,7 +3451,7 @@ def _test_api_v2_tasks_id_data_spec(self, user, spec, data,
expected_compressed_type,
expected_original_type,
expected_image_sizes,
- expected_storage_method=StorageMethodChoice.FILE_SYSTEM,
+ expected_storage_method=None,
expected_uploaded_data_location=StorageChoice.LOCAL,
dimension=DimensionType.DIM_2D,
expected_task_creation_status_state='Finished',
@@ -3448,6 +3466,12 @@ def _test_api_v2_tasks_id_data_spec(self, user, spec, data,
if get_status_callback is None:
get_status_callback = self._get_task_creation_status
+ if expected_storage_method is None:
+ if settings.MEDIA_CACHE_ALLOW_STATIC_CACHE:
+ expected_storage_method = StorageMethodChoice.FILE_SYSTEM
+ else:
+ expected_storage_method = StorageMethodChoice.CACHE
+
# create task
response = self._create_task(user, spec)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
@@ -4007,7 +4031,7 @@ def _test_api_v2_tasks_id_data_create_can_use_chunked_local_video(self, user):
image_sizes = self._share_image_sizes['test_rotated_90_video.mp4']
self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data, self.ChunkType.IMAGESET,
- self.ChunkType.VIDEO, image_sizes, StorageMethodChoice.FILE_SYSTEM)
+ self.ChunkType.VIDEO, image_sizes, StorageMethodChoice.CACHE)
def _test_api_v2_tasks_id_data_create_can_use_chunked_cached_local_video(self, user):
task_spec = {
@@ -4104,7 +4128,6 @@ def _test_api_v2_tasks_id_data_create_can_use_server_images_and_manifest(self, u
task_data = {
"image_quality": 70,
- "use_cache": True
}
manifest_name = "images_manifest_sorted.jsonl"
@@ -4115,79 +4138,34 @@ def _test_api_v2_tasks_id_data_create_can_use_server_images_and_manifest(self, u
for i, fn in enumerate(images + [manifest_name])
})
- for copy_data in [True, False]:
- with self.subTest(current_function_name(), copy=copy_data):
- task_spec = task_spec_common.copy()
- task_spec['name'] = task_spec['name'] + f' copy={copy_data}'
- task_data['copy_data'] = copy_data
- self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data,
- self.ChunkType.IMAGESET, self.ChunkType.IMAGESET,
- image_sizes, StorageMethodChoice.CACHE,
- StorageChoice.LOCAL if copy_data else StorageChoice.SHARE)
-
- with self.subTest(current_function_name() + ' file order mismatch'):
- task_spec = task_spec_common.copy()
- task_spec['name'] = task_spec['name'] + f' mismatching file order'
- task_data_copy = task_data.copy()
- task_data_copy[f'server_files[{len(images)}]'] = "images_manifest.jsonl"
- self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data_copy,
- self.ChunkType.IMAGESET, self.ChunkType.IMAGESET,
- image_sizes, StorageMethodChoice.CACHE, StorageChoice.SHARE,
- expected_task_creation_status_state='Failed',
- expected_task_creation_status_reason='Incorrect file mapping to manifest content')
-
- for copy_data in [True, False]:
- with self.subTest(current_function_name(), copy=copy_data):
- task_spec = task_spec_common.copy()
- task_spec['name'] = task_spec['name'] + f' copy={copy_data}'
- task_data['copy_data'] = copy_data
- self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data,
- self.ChunkType.IMAGESET, self.ChunkType.IMAGESET,
- image_sizes, StorageMethodChoice.CACHE,
- StorageChoice.LOCAL if copy_data else StorageChoice.SHARE)
-
- with self.subTest(current_function_name() + ' file order mismatch'):
- task_spec = task_spec_common.copy()
- task_spec['name'] = task_spec['name'] + f' mismatching file order'
- task_data_copy = task_data.copy()
- task_data_copy[f'server_files[{len(images)}]'] = "images_manifest.jsonl"
- self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data_copy,
- self.ChunkType.IMAGESET, self.ChunkType.IMAGESET,
- image_sizes, StorageMethodChoice.CACHE, StorageChoice.SHARE,
- expected_task_creation_status_state='Failed',
- expected_task_creation_status_reason='Incorrect file mapping to manifest content')
-
- for copy_data in [True, False]:
- with self.subTest(current_function_name(), copy=copy_data):
+ for use_cache in [True, False]:
+ task_data['use_cache'] = use_cache
+
+ for copy_data in [True, False]:
+ with self.subTest(current_function_name(), copy=copy_data, use_cache=use_cache):
+ task_spec = task_spec_common.copy()
+ task_spec['name'] = task_spec['name'] + f' copy={copy_data}'
+ task_data_copy = task_data.copy()
+ task_data_copy['copy_data'] = copy_data
+ self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data_copy,
+ self.ChunkType.IMAGESET, self.ChunkType.IMAGESET,
+ image_sizes,
+ expected_uploaded_data_location=(
+ StorageChoice.LOCAL if copy_data else StorageChoice.SHARE
+ )
+ )
+
+ with self.subTest(current_function_name() + ' file order mismatch', use_cache=use_cache):
task_spec = task_spec_common.copy()
- task_spec['name'] = task_spec['name'] + f' copy={copy_data}'
- task_data['copy_data'] = copy_data
- self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data,
+ task_spec['name'] = task_spec['name'] + f' mismatching file order'
+ task_data_copy = task_data.copy()
+ task_data_copy[f'server_files[{len(images)}]'] = "images_manifest.jsonl"
+ self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data_copy,
self.ChunkType.IMAGESET, self.ChunkType.IMAGESET,
- image_sizes, StorageMethodChoice.CACHE,
- StorageChoice.LOCAL if copy_data else StorageChoice.SHARE)
-
- with self.subTest(current_function_name() + ' file order mismatch'):
- task_spec = task_spec_common.copy()
- task_spec['name'] = task_spec['name'] + f' mismatching file order'
- task_data_copy = task_data.copy()
- task_data_copy[f'server_files[{len(images)}]'] = "images_manifest.jsonl"
- self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data_copy,
- self.ChunkType.IMAGESET, self.ChunkType.IMAGESET,
- image_sizes, StorageMethodChoice.CACHE, StorageChoice.SHARE,
- expected_task_creation_status_state='Failed',
- expected_task_creation_status_reason='Incorrect file mapping to manifest content')
-
- with self.subTest(current_function_name() + ' without use cache'):
- task_spec = task_spec_common.copy()
- task_spec['name'] = task_spec['name'] + f' manifest without cache'
- task_data_copy = task_data.copy()
- task_data_copy['use_cache'] = False
- self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data_copy,
- self.ChunkType.IMAGESET, self.ChunkType.IMAGESET,
- image_sizes, StorageMethodChoice.CACHE, StorageChoice.SHARE,
- expected_task_creation_status_state='Failed',
- expected_task_creation_status_reason="A manifest file can only be used with the 'use cache' option")
+ image_sizes,
+ expected_uploaded_data_location=StorageChoice.SHARE,
+ expected_task_creation_status_state='Failed',
+ expected_task_creation_status_reason='Incorrect file mapping to manifest content')
def _test_api_v2_tasks_id_data_create_can_use_server_images_with_predefined_sorting(self, user):
task_spec = {
@@ -4219,7 +4197,7 @@ def _test_api_v2_tasks_id_data_create_can_use_server_images_with_predefined_sort
task_data = task_data_common.copy()
task_data["use_cache"] = caching_enabled
- if caching_enabled:
+ if caching_enabled or not settings.MEDIA_CACHE_ALLOW_STATIC_CACHE:
storage_method = StorageMethodChoice.CACHE
else:
storage_method = StorageMethodChoice.FILE_SYSTEM
@@ -4278,7 +4256,7 @@ def _test_api_v2_tasks_id_data_create_can_use_local_images_with_predefined_sorti
sorting_method=SortingMethod.PREDEFINED)
task_data_common["use_cache"] = caching_enabled
- if caching_enabled:
+ if caching_enabled or not settings.MEDIA_CACHE_ALLOW_STATIC_CACHE:
storage_method = StorageMethodChoice.CACHE
else:
storage_method = StorageMethodChoice.FILE_SYSTEM
@@ -4339,7 +4317,7 @@ def _test_api_v2_tasks_id_data_create_can_use_server_archive_with_predefined_sor
task_data = task_data_common.copy()
task_data["use_cache"] = caching_enabled
- if caching_enabled:
+ if caching_enabled or not settings.MEDIA_CACHE_ALLOW_STATIC_CACHE:
storage_method = StorageMethodChoice.CACHE
else:
storage_method = StorageMethodChoice.FILE_SYSTEM
@@ -4412,7 +4390,7 @@ def _test_api_v2_tasks_id_data_create_can_use_local_archive_with_predefined_sort
sorting_method=SortingMethod.PREDEFINED)
task_data["use_cache"] = caching_enabled
- if caching_enabled:
+ if caching_enabled or not settings.MEDIA_CACHE_ALLOW_STATIC_CACHE:
storage_method = StorageMethodChoice.CACHE
else:
storage_method = StorageMethodChoice.FILE_SYSTEM
@@ -4590,7 +4568,7 @@ def _send_data_and_fail(*args, **kwargs):
self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data,
self.ChunkType.IMAGESET, self.ChunkType.IMAGESET,
- image_sizes, StorageMethodChoice.FILE_SYSTEM, StorageChoice.LOCAL,
+ image_sizes, expected_uploaded_data_location=StorageChoice.LOCAL,
send_data_callback=_send_data)
with self.subTest(current_function_name() + ' mismatching file sets - extra files'):
@@ -4604,7 +4582,7 @@ def _send_data_and_fail(*args, **kwargs):
with self.assertRaisesMessage(Exception, "(extra)"):
self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data,
self.ChunkType.IMAGESET, self.ChunkType.IMAGESET,
- image_sizes, StorageMethodChoice.FILE_SYSTEM, StorageChoice.LOCAL,
+ image_sizes, expected_uploaded_data_location=StorageChoice.LOCAL,
send_data_callback=_send_data_and_fail)
with self.subTest(current_function_name() + ' mismatching file sets - missing files'):
@@ -4618,7 +4596,7 @@ def _send_data_and_fail(*args, **kwargs):
with self.assertRaisesMessage(Exception, "(missing)"):
self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data,
self.ChunkType.IMAGESET, self.ChunkType.IMAGESET,
- image_sizes, StorageMethodChoice.FILE_SYSTEM, StorageChoice.LOCAL,
+ image_sizes, expected_uploaded_data_location=StorageChoice.LOCAL,
send_data_callback=_send_data_and_fail)
def _test_api_v2_tasks_id_data_create_can_use_server_rar(self, user):
diff --git a/cvat/apps/engine/tests/test_rest_api_3D.py b/cvat/apps/engine/tests/test_rest_api_3D.py
index a67a79109f33..9f000be5d218 100644
--- a/cvat/apps/engine/tests/test_rest_api_3D.py
+++ b/cvat/apps/engine/tests/test_rest_api_3D.py
@@ -86,9 +86,13 @@ def _create_task(self, data, image_data):
assert response.status_code == status.HTTP_201_CREATED, response.status_code
tid = response.data["id"]
- response = self.client.post("/api/tasks/%s/data" % tid,
- data=image_data)
+ response = self.client.post("/api/tasks/%s/data" % tid, data=image_data)
assert response.status_code == status.HTTP_202_ACCEPTED, response.status_code
+ rq_id = response.json()["rq_id"]
+
+ response = self.client.get(f"/api/requests/{rq_id}")
+ assert response.status_code == status.HTTP_200_OK, response.status_code
+ assert response.json()["status"] == "finished", response.json().get("status")
response = self.client.get("/api/tasks/%s" % tid)
@@ -527,7 +531,7 @@ def test_api_v2_dump_and_upload_annotation(self):
for user, edata in list(self.expected_dump_upload.items()):
with self.subTest(format=f"{format_name}_{edata['name']}_dump"):
- self._clear_rq_jobs() # clean up from previous tests and iterations
+ self._clear_temp_data() # clean up from previous tests and iterations
url = self._generate_url_dump_tasks_annotations(task_id)
file_name = osp.join(test_dir, f"{format_name}_{edata['name']}.zip")
@@ -718,7 +722,7 @@ def test_api_v2_export_dataset(self):
for user, edata in list(self.expected_dump_upload.items()):
with self.subTest(format=f"{format_name}_{edata['name']}_export"):
- self._clear_rq_jobs() # clean up from previous tests and iterations
+ self._clear_temp_data() # clean up from previous tests and iterations
url = self._generate_url_dump_dataset(task_id)
file_name = osp.join(test_dir, f"{format_name}_{edata['name']}.zip")
@@ -740,6 +744,8 @@ def test_api_v2_export_dataset(self):
content = io.BytesIO(b"".join(response.streaming_content))
with open(file_name, "wb") as f:
f.write(content.getvalue())
- self.assertEqual(osp.exists(file_name), edata['file_exists'])
- self._check_dump_content(content, task_ann_prev.data, format_name,related_files=False)
+ self.assertEqual(osp.exists(file_name), edata['file_exists'])
+ self._check_dump_content(
+ content, task_ann_prev.data, format_name, related_files=False
+ )
diff --git a/cvat/apps/engine/tests/utils.py b/cvat/apps/engine/tests/utils.py
index b884b3e9b4c4..3d2a533d1e97 100644
--- a/cvat/apps/engine/tests/utils.py
+++ b/cvat/apps/engine/tests/utils.py
@@ -13,7 +13,7 @@
from django.core.cache import caches
from django.http.response import HttpResponse
from PIL import Image
-from rest_framework.test import APIClient, APITestCase
+from rest_framework.test import APITestCase
import av
import django_rq
import numpy as np
@@ -92,14 +92,7 @@ def clear_rq_jobs():
class ApiTestBase(APITestCase):
- def _clear_rq_jobs(self):
- clear_rq_jobs()
-
- def setUp(self):
- super().setUp()
- self.client = APIClient()
-
- def tearDown(self):
+ def _clear_temp_data(self):
# Clear server frame/chunk cache.
# The parent class clears DB changes, and it can lead to under-cleaned task data,
# which can affect other tests.
@@ -112,7 +105,14 @@ def tearDown(self):
# Clear any remaining RQ jobs produced by the tests executed
self._clear_rq_jobs()
- return super().tearDown()
+ def _clear_rq_jobs(self):
+ clear_rq_jobs()
+
+ def setUp(self):
+ self._clear_temp_data()
+
+ super().setUp()
+ self.client = self.client_class()
def generate_image_file(filename, size=(100, 100)):
diff --git a/cvat/apps/engine/views.py b/cvat/apps/engine/views.py
index 05a50857b28f..3cb7e34c5c40 100644
--- a/cvat/apps/engine/views.py
+++ b/cvat/apps/engine/views.py
@@ -3,6 +3,7 @@
#
# SPDX-License-Identifier: MIT
+from abc import ABCMeta, abstractmethod
import os
import os.path as osp
import re
@@ -12,7 +13,7 @@
from contextlib import suppress
from PIL import Image
from types import SimpleNamespace
-from typing import Optional, Any, Dict, List, cast, Callable, Mapping, Iterable
+from typing import Optional, Any, Dict, List, Union, cast, Callable, Mapping, Iterable
import traceback
import textwrap
from collections import namedtuple
@@ -58,12 +59,14 @@
from cvat.apps.events.handlers import handle_dataset_import
from cvat.apps.dataset_manager.bindings import CvatImportError
from cvat.apps.dataset_manager.serializers import DatasetFormatsSerializer
-from cvat.apps.engine.frame_provider import FrameProvider
+from cvat.apps.engine.frame_provider import (
+ IFrameProvider, TaskFrameProvider, JobFrameProvider, FrameQuality
+)
from cvat.apps.engine.filters import NonModelSimpleFilter, NonModelOrderingFilter, NonModelJsonLogicFilter
from cvat.apps.engine.media_extractors import get_mime
from cvat.apps.engine.permissions import AnnotationGuidePermission, get_iam_context
from cvat.apps.engine.models import (
- ClientFile, Job, JobType, Label, SegmentType, Task, Project, Issue, Data,
+ ClientFile, Job, JobType, Label, Task, Project, Issue, Data,
Comment, StorageMethodChoice, StorageChoice,
CloudProviderChoice, Location, CloudStorage as CloudStorageModel,
Asset, AnnotationGuide, RequestStatus, RequestAction, RequestTarget, RequestSubresource
@@ -631,19 +634,17 @@ def append_backup_chunk(self, request, file_id):
def preview(self, request, pk):
self._object = self.get_object() # call check_object_permissions as well
- first_task = self._object.tasks.select_related('data__video').order_by('-id').first()
+ first_task: Optional[models.Task] = self._object.tasks.order_by('-id').first()
if not first_task:
return HttpResponseNotFound('Project image preview not found')
- data_getter = DataChunkGetter(
+ data_getter = _TaskDataGetter(
+ db_task=first_task,
data_type='preview',
data_quality='compressed',
- data_num=first_task.data.start_frame,
- task_dim=first_task.dimension
)
- return data_getter(request, first_task.data.start_frame,
- first_task.data.stop_frame, first_task.data)
+ return data_getter()
@staticmethod
def _get_rq_response(queue, job_id):
@@ -663,80 +664,50 @@ def _get_rq_response(queue, job_id):
return response
-class DataChunkGetter:
- def __init__(self, data_type, data_num, data_quality, task_dim):
+class _DataGetter(metaclass=ABCMeta):
+ def __init__(
+ self, data_type: str, data_num: Optional[Union[str, int]], data_quality: str
+ ) -> None:
possible_data_type_values = ('chunk', 'frame', 'preview', 'context_image')
possible_quality_values = ('compressed', 'original')
if not data_type or data_type not in possible_data_type_values:
raise ValidationError('Data type not specified or has wrong value')
elif data_type == 'chunk' or data_type == 'frame' or data_type == 'preview':
- if data_num is None:
+ if data_num is None and data_type != 'preview':
raise ValidationError('Number is not specified')
elif data_quality not in possible_quality_values:
raise ValidationError('Wrong quality value')
self.type = data_type
self.number = int(data_num) if data_num is not None else None
- self.quality = FrameProvider.Quality.COMPRESSED \
- if data_quality == 'compressed' else FrameProvider.Quality.ORIGINAL
-
- self.dimension = task_dim
-
- def _check_frame_range(self, frame: int):
- frame_range = range(self._start, self._stop + 1, self._db_data.get_frame_step())
- if frame not in frame_range:
- raise ValidationError(
- f'The frame number should be in the [{self._start}, {self._stop}] range'
- )
-
- def __call__(self, request, start: int, stop: int, db_data: Optional[Data]):
- if not db_data:
- raise NotFound(detail='Cannot find requested data')
+ self.quality = FrameQuality.COMPRESSED \
+ if data_quality == 'compressed' else FrameQuality.ORIGINAL
- self._start = start
- self._stop = stop
- self._db_data = db_data
+ @abstractmethod
+ def _get_frame_provider(self) -> IFrameProvider: ...
- frame_provider = FrameProvider(db_data, self.dimension)
+ def __call__(self):
+ frame_provider = self._get_frame_provider()
try:
if self.type == 'chunk':
- start_chunk = frame_provider.get_chunk_number(start)
- stop_chunk = frame_provider.get_chunk_number(stop)
- # pylint: disable=superfluous-parens
- if not (start_chunk <= self.number <= stop_chunk):
- raise ValidationError('The chunk number should be in the ' +
- f'[{start_chunk}, {stop_chunk}] range')
-
- # TODO: av.FFmpegError processing
- if settings.USE_CACHE and db_data.storage_method == StorageMethodChoice.CACHE:
- buff, mime_type = frame_provider.get_chunk(self.number, self.quality)
- return HttpResponse(buff.getvalue(), content_type=mime_type)
-
- # Follow symbol links if the chunk is a link on a real image otherwise
- # mimetype detection inside sendfile will work incorrectly.
- path = os.path.realpath(frame_provider.get_chunk(self.number, self.quality))
- return sendfile(request, path)
+ data = frame_provider.get_chunk(self.number, quality=self.quality)
+ return HttpResponse(data.data.getvalue(), content_type=data.mime)
elif self.type == 'frame' or self.type == 'preview':
- self._check_frame_range(self.number)
-
if self.type == 'preview':
- cache = MediaCache(self.dimension)
- buf, mime = cache.get_local_preview_with_mime(self.number, db_data)
+ data = frame_provider.get_preview()
else:
- buf, mime = frame_provider.get_frame(self.number, self.quality)
+ data = frame_provider.get_frame(self.number, quality=self.quality)
- return HttpResponse(buf.getvalue(), content_type=mime)
+ return HttpResponse(data.data.getvalue(), content_type=data.mime)
elif self.type == 'context_image':
- self._check_frame_range(self.number)
-
- cache = MediaCache(self.dimension)
- buff, mime = cache.get_frame_context_images(db_data, self.number)
- if not buff:
+ data = frame_provider.get_frame_context_images_chunk(self.number)
+ if not data:
return HttpResponseNotFound()
- return HttpResponse(buff, content_type=mime)
+
+ return HttpResponse(data.data, content_type=data.mime)
else:
return Response(data='unknown data type {}.'.format(self.type),
status=status.HTTP_400_BAD_REQUEST)
@@ -745,44 +716,78 @@ def __call__(self, request, start: int, stop: int, db_data: Optional[Data]):
'\n'.join([str(d) for d in ex.detail])
return Response(data=msg, status=ex.status_code)
+class _TaskDataGetter(_DataGetter):
+ def __init__(
+ self,
+ db_task: models.Task,
+ *,
+ data_type: str,
+ data_quality: str,
+ data_num: Optional[Union[str, int]] = None,
+ ) -> None:
+ super().__init__(data_type=data_type, data_num=data_num, data_quality=data_quality)
+ self._db_task = db_task
+
+ def _get_frame_provider(self) -> TaskFrameProvider:
+ return TaskFrameProvider(self._db_task)
+
+
+class _JobDataGetter(_DataGetter):
+ def __init__(
+ self,
+ db_job: models.Job,
+ *,
+ data_type: str,
+ data_quality: str,
+ data_num: Optional[Union[str, int]] = None,
+ data_index: Optional[Union[str, int]] = None,
+ ) -> None:
+ possible_data_type_values = ('chunk', 'frame', 'preview', 'context_image')
+ possible_quality_values = ('compressed', 'original')
+
+ if not data_type or data_type not in possible_data_type_values:
+ raise ValidationError('Data type not specified or has wrong value')
+ elif data_type == 'chunk' or data_type == 'frame' or data_type == 'preview':
+ if data_type == 'chunk':
+ if data_num is None and data_index is None:
+ raise ValidationError('Number or Index is not specified')
+ if data_num is not None and data_index is not None:
+ raise ValidationError('Number and Index cannot be used together')
+ elif data_num is None and data_type != 'preview':
+ raise ValidationError('Number is not specified')
+ elif data_quality not in possible_quality_values:
+ raise ValidationError('Wrong quality value')
+
+ self.type = data_type
-class JobDataGetter(DataChunkGetter):
- def __init__(self, job: Job, data_type, data_num, data_quality):
- super().__init__(data_type, data_num, data_quality, task_dim=job.segment.task.dimension)
- self.job = job
+ self.index = int(data_index) if data_index is not None else None
+ self.number = int(data_num) if data_num is not None else None
- def _check_frame_range(self, frame: int):
- frame_range = self.job.segment.frame_set
- if frame not in frame_range:
- raise ValidationError("The frame number doesn't belong to the job")
+ self.quality = FrameQuality.COMPRESSED \
+ if data_quality == 'compressed' else FrameQuality.ORIGINAL
- def __call__(self, request, start, stop, db_data):
- if self.type == 'chunk' and self.job.segment.type == SegmentType.SPECIFIC_FRAMES:
- frame_provider = FrameProvider(db_data, self.dimension)
+ self._db_job = db_job
- start_chunk = frame_provider.get_chunk_number(start)
- stop_chunk = frame_provider.get_chunk_number(stop)
- # pylint: disable=superfluous-parens
- if not (start_chunk <= self.number <= stop_chunk):
- raise ValidationError('The chunk number should be in the ' +
- f'[{start_chunk}, {stop_chunk}] range')
+ def _get_frame_provider(self) -> JobFrameProvider:
+ return JobFrameProvider(self._db_job)
- cache = MediaCache()
+ def __call__(self):
+ if self.type == 'chunk':
+ # Reproduce the task chunk indexing
+ frame_provider = self._get_frame_provider()
- if settings.USE_CACHE and db_data.storage_method == StorageMethodChoice.CACHE:
- buf, mime = cache.get_selective_job_chunk_data_with_mime(
- chunk_number=self.number, quality=self.quality, job=self.job
+ if self.index is not None:
+ data = frame_provider.get_chunk(
+ self.index, quality=self.quality, is_task_chunk=False
)
else:
- buf, mime = cache.prepare_selective_job_chunk(
- chunk_number=self.number, quality=self.quality, db_job=self.job
+ data = frame_provider.get_chunk(
+ self.number, quality=self.quality, is_task_chunk=True
)
- return HttpResponse(buf.getvalue(), content_type=mime)
-
+ return HttpResponse(data.data.getvalue(), content_type=data.mime)
else:
- return super().__call__(request, start, stop, db_data)
-
+ return super().__call__()
@extend_schema(tags=['tasks'])
@extend_schema_view(
@@ -1306,11 +1311,10 @@ def data(self, request, pk):
data_num = request.query_params.get('number', None)
data_quality = request.query_params.get('quality', 'compressed')
- data_getter = DataChunkGetter(data_type, data_num, data_quality,
- self._object.dimension)
-
- return data_getter(request, self._object.data.start_frame,
- self._object.data.stop_frame, self._object.data)
+ data_getter = _TaskDataGetter(
+ self._object, data_type=data_type, data_num=data_num, data_quality=data_quality
+ )
+ return data_getter()
@tus_chunk_action(detail=True, suffix_base="data")
def append_data_chunk(self, request, pk, file_id):
@@ -1651,15 +1655,12 @@ def preview(self, request, pk):
if not self._object.data:
return HttpResponseNotFound('Task image preview not found')
- data_getter = DataChunkGetter(
+ data_getter = _TaskDataGetter(
+ db_task=self._object,
data_type='preview',
data_quality='compressed',
- data_num=self._object.data.start_frame,
- task_dim=self._object.dimension
)
-
- return data_getter(request, self._object.data.start_frame,
- self._object.data.stop_frame, self._object.data)
+ return data_getter()
@extend_schema(tags=['jobs'])
@@ -2026,8 +2027,14 @@ def get_export_callback(self, save_images: bool) -> Callable:
OpenApiParameter('quality', location=OpenApiParameter.QUERY, required=False,
type=OpenApiTypes.STR, enum=['compressed', 'original'],
description="Specifies the quality level of the requested data"),
- OpenApiParameter('number', location=OpenApiParameter.QUERY, required=False, type=OpenApiTypes.INT,
- description="A unique number value identifying chunk or frame"),
+ OpenApiParameter('number',
+ location=OpenApiParameter.QUERY, required=False, type=OpenApiTypes.INT,
+ description="A unique number value identifying chunk or frame. "
+ "The numbers are the same as for the task. "
+ "Deprecated for chunks in favor of 'index'"),
+ OpenApiParameter('index',
+ location=OpenApiParameter.QUERY, required=False, type=OpenApiTypes.INT,
+ description="A unique number value identifying chunk, starts from 0 for each job"),
],
responses={
'200': OpenApiResponse(OpenApiTypes.BINARY, description='Data of a specific type'),
@@ -2039,12 +2046,15 @@ def data(self, request, pk):
db_job = self.get_object() # call check_object_permissions as well
data_type = request.query_params.get('type', None)
data_num = request.query_params.get('number', None)
+ data_index = request.query_params.get('index', None)
data_quality = request.query_params.get('quality', 'compressed')
- data_getter = JobDataGetter(db_job, data_type, data_num, data_quality)
-
- return data_getter(request, db_job.segment.start_frame,
- db_job.segment.stop_frame, db_job.segment.task.data)
+ data_getter = _JobDataGetter(
+ db_job,
+ data_type=data_type, data_quality=data_quality,
+ data_index=data_index, data_num=data_num
+ )
+ return data_getter()
@extend_schema(methods=['GET'], summary='Get metainformation for media files in a job',
@@ -2137,15 +2147,12 @@ def metadata(self, request, pk):
def preview(self, request, pk):
self._object = self.get_object() # call check_object_permissions as well
- data_getter = DataChunkGetter(
+ data_getter = _JobDataGetter(
+ db_job=self._object,
data_type='preview',
data_quality='compressed',
- data_num=self._object.segment.start_frame,
- task_dim=self._object.segment.task.dimension
)
-
- return data_getter(request, self._object.segment.start_frame,
- self._object.segment.stop_frame, self._object.segment.task.data)
+ return data_getter()
@extend_schema(tags=['issues'])
@@ -2716,13 +2723,13 @@ def preview(self, request, pk):
# The idea is try to define real manifest preview only for the storages that have related manifests
# because otherwise it can lead to extra calls to a bucket, that are usually not free.
if not db_storage.has_at_least_one_manifest:
- result = cache.get_cloud_preview_with_mime(db_storage)
+ result = cache.get_cloud_preview(db_storage)
if not result:
return HttpResponseNotFound('Cloud storage preview not found')
- return HttpResponse(result[0], result[1])
+ return HttpResponse(result[0].getvalue(), result[1])
- preview, mime = cache.get_or_set_cloud_preview_with_mime(db_storage)
- return HttpResponse(preview, mime)
+ preview, mime = cache.get_or_set_cloud_preview(db_storage)
+ return HttpResponse(preview.getvalue(), mime)
except CloudStorageModel.DoesNotExist:
message = f"Storage {pk} does not exist"
slogger.glob.error(message)
@@ -3391,7 +3398,7 @@ def retrieve(self, request: HttpRequest, pk: str):
job = self._get_rq_job_by_id(pk)
if not job:
- return HttpResponseNotFound(f"There is no request with specified id: {pk}")
+ return HttpResponseNotFound("There is no request with specified id")
self.check_object_permissions(request, job)
@@ -3428,7 +3435,7 @@ def cancel(self, request: HttpRequest, pk: str):
rq_job = self._get_rq_job_by_id(pk)
if not rq_job:
- return HttpResponseNotFound(f"There is no request with specified id: {pk!r}")
+ return HttpResponseNotFound("There is no request with specified id")
self.check_object_permissions(request, rq_job)
diff --git a/cvat/apps/iam/rules/utils.rego b/cvat/apps/iam/rules/utils.rego
index c0f719c63957..7371f2e35edb 100644
--- a/cvat/apps/iam/rules/utils.rego
+++ b/cvat/apps/iam/rules/utils.rego
@@ -36,6 +36,7 @@ RESEND := "resend"
UPDATE_DESC := "update:desc"
UPDATE_ASSIGNEE := "update:assignee"
UPDATE_OWNER := "update:owner"
+UPDATE_ASSOCIATED_STORAGE := "update:associated_storage"
EXPORT_ANNOTATIONS := "export:annotations"
EXPORT_DATASET := "export:dataset"
CREATE_IN_PROJECT := "create@project"
diff --git a/cvat/apps/lambda_manager/tests/test_lambda.py b/cvat/apps/lambda_manager/tests/test_lambda.py
index c86b4eaa61af..e49b93e24f12 100644
--- a/cvat/apps/lambda_manager/tests/test_lambda.py
+++ b/cvat/apps/lambda_manager/tests/test_lambda.py
@@ -1,11 +1,10 @@
# Copyright (C) 2021-2022 Intel Corporation
-# Copyright (C) 2023 CVAT.ai Corporation
+# Copyright (C) 2023-2024 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from collections import OrderedDict
from itertools import groupby
-from io import BytesIO
from typing import Dict, Optional
from unittest import mock, skip
import json
@@ -14,11 +13,11 @@
import requests
from django.contrib.auth.models import Group, User
from django.http import HttpResponseNotFound, HttpResponseServerError
-from PIL import Image
from rest_framework import status
-from rest_framework.test import APIClient, APITestCase
-from cvat.apps.engine.tests.utils import filter_dict, get_paginated_collection
+from cvat.apps.engine.tests.utils import (
+ ApiTestBase, filter_dict, ForceLogin, generate_image_file, get_paginated_collection
+)
LAMBDA_ROOT_PATH = '/api/lambda'
LAMBDA_FUNCTIONS_PATH = f'{LAMBDA_ROOT_PATH}/functions'
@@ -49,34 +48,11 @@
with open(path) as f:
functions = json.load(f)
-
-def generate_image_file(filename, size=(100, 100)):
- f = BytesIO()
- image = Image.new('RGB', size=size)
- image.save(f, 'jpeg')
- f.name = filename
- f.seek(0)
- return f
-
-
-class ForceLogin:
- def __init__(self, user, client):
- self.user = user
- self.client = client
-
- def __enter__(self):
- if self.user:
- self.client.force_login(self.user, backend='django.contrib.auth.backends.ModelBackend')
-
- return self
-
- def __exit__(self, exception_type, exception_value, traceback):
- if self.user:
- self.client.logout()
-
-class _LambdaTestCaseBase(APITestCase):
+class _LambdaTestCaseBase(ApiTestBase):
def setUp(self):
- self.client = APIClient(raise_request_exception=False)
+ super().setUp()
+
+ self.client = self.client_class(raise_request_exception=False)
http_patcher = mock.patch('cvat.apps.lambda_manager.views.LambdaGateway._http', side_effect = self._get_data_from_lambda_manager_http)
self.addCleanup(http_patcher.stop)
@@ -181,6 +157,11 @@ def _create_task(self, task_spec, data, *, owner=None, org_id=None):
data=data,
QUERY_STRING=f'org_id={org_id}' if org_id is not None else None)
assert response.status_code == status.HTTP_202_ACCEPTED, response.status_code
+ rq_id = response.json()["rq_id"]
+
+ response = self.client.get(f"/api/requests/{rq_id}")
+ assert response.status_code == status.HTTP_200_OK, response.status_code
+ assert response.json()["status"] == "finished", response.json().get("status")
response = self.client.get("/api/tasks/%s" % tid,
QUERY_STRING=f'org_id={org_id}' if org_id is not None else None)
diff --git a/cvat/apps/lambda_manager/views.py b/cvat/apps/lambda_manager/views.py
index 286b8b4cc985..143537985fd7 100644
--- a/cvat/apps/lambda_manager/views.py
+++ b/cvat/apps/lambda_manager/views.py
@@ -1,5 +1,5 @@
# Copyright (C) 2022 Intel Corporation
-# Copyright (C) 2022-2023 CVAT.ai Corporation
+# Copyright (C) 2022-2024 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
@@ -32,9 +32,9 @@
from rest_framework.request import Request
import cvat.apps.dataset_manager as dm
-from cvat.apps.engine.frame_provider import FrameProvider
+from cvat.apps.engine.frame_provider import FrameQuality, TaskFrameProvider
from cvat.apps.engine.models import (
- Job, ShapeType, SourceType, Task, Label, RequestAction, RequestTarget,
+ Job, ShapeType, SourceType, Task, Label, RequestAction, RequestTarget
)
from cvat.apps.engine.rq_job_handler import RQId, RQJobMetaField
from cvat.apps.engine.serializers import LabeledDataSerializer
@@ -489,19 +489,19 @@ def transform_attributes(input_attributes, attr_mapping, db_attributes):
def _get_image(self, db_task, frame, quality):
if quality is None or quality == "original":
- quality = FrameProvider.Quality.ORIGINAL
+ quality = FrameQuality.ORIGINAL
elif quality == "compressed":
- quality = FrameProvider.Quality.COMPRESSED
+ quality = FrameQuality.COMPRESSED
else:
raise ValidationError(
'`{}` lambda function was run '.format(self.id) +
'with wrong arguments (quality={})'.format(quality),
code=status.HTTP_400_BAD_REQUEST)
- frame_provider = FrameProvider(db_task.data)
+ frame_provider = TaskFrameProvider(db_task)
image = frame_provider.get_frame(frame, quality=quality)
- return base64.b64encode(image[0].getvalue()).decode('utf-8')
+ return base64.b64encode(image.data.getvalue()).decode('utf-8')
class LambdaQueue:
RESULT_TTL = timedelta(minutes=30)
diff --git a/cvat/apps/organizations/permissions.py b/cvat/apps/organizations/permissions.py
index 3ee36fba1481..3bf1eb655fe4 100644
--- a/cvat/apps/organizations/permissions.py
+++ b/cvat/apps/organizations/permissions.py
@@ -3,8 +3,6 @@
#
# SPDX-License-Identifier: MIT
-from typing import cast
-
from django.conf import settings
from cvat.apps.iam.permissions import OpenPolicyAgentPermission, StrEnum
@@ -175,7 +173,10 @@ def get_scopes(request, view, obj):
}[view.action]
if scope == Scopes.UPDATE:
- if request.data.get('role') != cast(Membership, obj).role:
+ # user should have permissions to view a membership
+ scopes.append(Scopes.VIEW)
+
+ if 'role' in request.data.keys():
scopes.append(Scopes.UPDATE_ROLE)
else:
scopes.append(scope)
diff --git a/cvat/apps/quality_control/views.py b/cvat/apps/quality_control/views.py
index 83cc9dc49e64..00f327539f96 100644
--- a/cvat/apps/quality_control/views.py
+++ b/cvat/apps/quality_control/views.py
@@ -329,7 +329,7 @@ def create(self, request, *args, **kwargs):
def data(self, request, pk):
report = self.get_object() # check permissions
json_report = qc.prepare_report_for_downloading(report, host=get_server_url(request))
- return HttpResponse(json_report.encode())
+ return HttpResponse(json_report.encode(), content_type="application/json")
@extend_schema(tags=["quality"])
diff --git a/cvat/requirements/base.in b/cvat/requirements/base.in
index 3440163d0038..7e9103815df5 100644
--- a/cvat/requirements/base.in
+++ b/cvat/requirements/base.in
@@ -1,7 +1,13 @@
-r ../../utils/dataset_manifest/requirements.in
attrs==21.4.0
+
+# This is the last version of av that supports ffmpeg we depend on.
+# Changing ffmpeg is undesirable, as there might be video decoding differences
+# between versions.
+# TODO: try to move to the newer version
av==9.2.0
+
azure-storage-blob==12.13.0
boto3==1.17.61
clickhouse-connect==0.6.8
diff --git a/cvat/schema.yml b/cvat/schema.yml
index badefe355b8d..779b08fe376f 100644
--- a/cvat/schema.yml
+++ b/cvat/schema.yml
@@ -1,7 +1,7 @@
openapi: 3.0.3
info:
title: CVAT REST API
- version: 2.19.0
+ version: 2.20.0
description: REST API for Computer Vision Annotation Tool (CVAT)
termsOfService: https://www.google.com/policies/terms/
contact:
@@ -2322,11 +2322,18 @@ paths:
type: integer
description: A unique integer value identifying this job.
required: true
+ - in: query
+ name: index
+ schema:
+ type: integer
+ description: A unique number value identifying chunk, starts from 0 for each
+ job
- in: query
name: number
schema:
type: integer
- description: A unique number value identifying chunk or frame
+ description: A unique number value identifying chunk or frame. The numbers
+ are the same as for the task. Deprecated for chunks in favor of 'index'
- in: query
name: quality
schema:
@@ -8074,6 +8081,10 @@ components:
allOf:
- $ref: '#/components/schemas/ChunkType'
readOnly: true
+ data_original_chunk_type:
+ allOf:
+ - $ref: '#/components/schemas/ChunkType'
+ readOnly: true
created_date:
type: string
format: date-time
diff --git a/dev/format_python_code.sh b/dev/format_python_code.sh
index 5b455a296f4d..7eff923abb8a 100755
--- a/dev/format_python_code.sh
+++ b/dev/format_python_code.sh
@@ -25,6 +25,9 @@ for paths in \
"cvat/apps/analytics_report" \
"cvat/apps/engine/lazy_list.py" \
"cvat/apps/engine/background.py" \
+ "cvat/apps/engine/frame_provider.py" \
+ "cvat/apps/engine/cache.py" \
+ "cvat/apps/engine/default_settings.py" \
; do
${BLACK} -- ${paths}
${ISORT} -- ${paths}
diff --git a/docker-compose.yml b/docker-compose.yml
index 051bd0bfd8cf..569e163e9fe5 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -10,6 +10,7 @@ x-backend-env: &backend-env
CVAT_REDIS_ONDISK_HOST: cvat_redis_ondisk
CVAT_REDIS_ONDISK_PORT: 6666
CVAT_LOG_IMPORT_ERRORS: 'true'
+ CVAT_ALLOW_STATIC_CACHE: '${CVAT_ALLOW_STATIC_CACHE:-no}'
DJANGO_LOG_SERVER_HOST: vector
DJANGO_LOG_SERVER_PORT: 80
no_proxy: clickhouse,grafana,vector,nuclio,opa,${no_proxy:-}
diff --git a/helm-chart/test.values.yaml b/helm-chart/test.values.yaml
index 5a5fa8fe6bab..73edaa815d70 100644
--- a/helm-chart/test.values.yaml
+++ b/helm-chart/test.values.yaml
@@ -27,6 +27,12 @@ cvat:
frontend:
imagePullPolicy: Never
+redis:
+ master:
+ # The "flushall" command, which we use in tests, is disabled in helm by default
+ # https://artifacthub.io/packages/helm/bitnami/redis#redis-master-configuration-parameters
+ disableCommands: []
+
keydb:
resources:
requests:
diff --git a/tests/cypress/e2e/features/ground_truth_jobs.js b/tests/cypress/e2e/features/ground_truth_jobs.js
index 9eba445b76a2..0753d59839cc 100644
--- a/tests/cypress/e2e/features/ground_truth_jobs.js
+++ b/tests/cypress/e2e/features/ground_truth_jobs.js
@@ -88,6 +88,15 @@ context('Ground truth jobs', () => {
.should('be.visible');
}
+ function openManagementTab() {
+ cy.clickInTaskMenu('Quality control', true);
+ cy.get('.cvat-task-control-tabs')
+ .within(() => {
+ cy.contains('Management').click();
+ });
+ cy.get('.cvat-quality-control-management-tab').should('exist').and('be.visible');
+ }
+
before(() => {
cy.visit('auth/login');
cy.login();
@@ -187,6 +196,121 @@ context('Ground truth jobs', () => {
});
});
+ describe('Testing ground truth management basics', () => {
+ const serverFiles = ['images/image_1.jpg', 'images/image_2.jpg', 'images/image_3.jpg'];
+
+ before(() => {
+ cy.headlessCreateTask({
+ labels: [{ name: labelName, attributes: [], type: 'any' }],
+ name: taskName,
+ project_id: null,
+ source_storage: { location: 'local' },
+ target_storage: { location: 'local' },
+ }, {
+ server_files: serverFiles,
+ image_quality: 70,
+ use_zip_chunks: true,
+ use_cache: true,
+ sorting_method: 'lexicographical',
+ }).then((taskResponse) => {
+ taskID = taskResponse.taskID;
+ [jobID] = taskResponse.jobIDs;
+ }).then(() => (
+ cy.headlessCreateJob({
+ task_id: taskID,
+ frame_count: 3,
+ type: 'ground_truth',
+ frame_selection_method: 'random_uniform',
+ })
+ )).then((jobResponse) => {
+ groundTruthJobID = jobResponse.jobID;
+ }).then(() => {
+ cy.visit(`/tasks/${taskID}/quality-control#management`);
+ cy.get('.cvat-quality-control-management-tab').should('exist').and('be.visible');
+ cy.get('.cvat-annotations-quality-allocation-table-summary').should('exist').and('be.visible');
+ });
+ });
+
+ after(() => {
+ cy.headlessDeleteTask(taskID);
+ });
+
+ it('Check management page contents.', () => {
+ cy.get('.cvat-annotations-quality-allocation-table-summary').should('exist');
+ cy.contains('.cvat-allocation-summary-excluded', '0').should('exist');
+ cy.contains('.cvat-allocation-summary-total', '3').should('exist');
+ cy.contains('.cvat-allocation-summary-active', '3').should('exist');
+
+ cy.get('.cvat-frame-allocation-table').should('exist');
+ cy.get('.cvat-allocation-frame-row').should('have.length', 3);
+ cy.get('.cvat-allocation-frame-row').each(($el, index) => {
+ cy.wrap($el).within(() => {
+ cy.contains(`#${index}`).should('exist');
+ cy.contains(`images/image_${index + 1}.jpg`).should('exist');
+ });
+ });
+ });
+
+ it('Check link to frame.', () => {
+ cy.get('.cvat-allocation-frame-row').last().within(() => {
+ cy.get('.cvat-open-frame-button').first().click();
+ });
+ cy.get('.cvat-spinner').should('not.exist');
+ cy.url().should('contain', `/tasks/${taskID}/jobs/${groundTruthJobID}`);
+ cy.checkFrameNum(2);
+
+ cy.interactMenu('Open the task');
+ openManagementTab();
+ });
+
+ it('Disable single frame, enable it back.', () => {
+ cy.get('.cvat-allocation-frame-row').last().within(() => {
+ cy.get('.cvat-allocation-frame-delete').click();
+ });
+ cy.get('.cvat-spinner').should('not.exist');
+
+ cy.get('.cvat-allocation-frame-row-excluded').should('exist');
+ cy.contains('.cvat-allocation-summary-excluded', '1').should('exist');
+ cy.contains('.cvat-allocation-summary-active', '2').should('exist');
+
+ cy.get('.cvat-allocation-frame-row-excluded').within(() => {
+ cy.get('.cvat-allocation-frame-restore').click();
+ });
+ cy.get('.cvat-spinner').should('not.exist');
+ cy.get('.cvat-allocation-frame-row-excluded').should('not.exist');
+ cy.contains('.cvat-allocation-summary-excluded', '0').should('exist');
+ cy.contains('.cvat-allocation-summary-active', '3').should('exist');
+ });
+
+ it('Select several frames, use group operations.', () => {
+ function selectFrames() {
+ cy.get('.cvat-allocation-frame-row').each(($el, index) => {
+ if (index !== 0) {
+ cy.wrap($el).within(() => {
+ cy.get('.ant-table-selection-column input[type="checkbox"]').should('not.be.checked').check();
+ });
+ }
+ });
+ }
+
+ selectFrames();
+ cy.get('.cvat-allocation-selection-frame-delete').click();
+ cy.get('.cvat-spinner').should('not.exist');
+
+ cy.get('.cvat-allocation-frame-row-excluded').should('have.length', 2);
+ cy.contains('.cvat-allocation-summary-excluded', '2').should('exist');
+ cy.contains('.cvat-allocation-summary-active', '1').should('exist');
+
+ selectFrames();
+ cy.get('.cvat-allocation-selection-frame-restore').click();
+ cy.get('.cvat-spinner').should('not.exist');
+
+ cy.get('.cvat-allocation-frame-row-excluded').should('not.exist');
+ cy.contains('.cvat-allocation-summary-excluded', '0').should('exist');
+ cy.contains('.cvat-allocation-summary-active', '3').should('exist');
+ });
+ });
+
describe('Regression tests', () => {
const imagesCount = 20;
const imageFileName = 'ground_truth_2';
@@ -205,8 +329,12 @@ context('Ground truth jobs', () => {
cy.imageGenerator(imagesFolder, imageFileName, width, height, color, posX, posY, labelName, imagesCount);
cy.createZipArchive(directoryToArchive, archivePath);
cy.createAnnotationTask(
- taskName, labelName, attrName,
- textDefaultValue, archiveName, false,
+ taskName,
+ labelName,
+ attrName,
+ textDefaultValue,
+ archiveName,
+ false,
{ multiJobs: true, segmentSize: 1 },
);
cy.openTask(taskName);
diff --git a/tests/cypress/e2e/features/requests_page.js b/tests/cypress/e2e/features/requests_page.js
index fb34cf9fba38..3cdf187a9825 100644
--- a/tests/cypress/e2e/features/requests_page.js
+++ b/tests/cypress/e2e/features/requests_page.js
@@ -323,7 +323,8 @@ context('Requests page', () => {
cy.getJobIDFromIdx(0).then((jobID) => {
const closeExportNotification = () => {
cy.contains('Export is finished').should('be.visible');
- cy.closeNotification('.ant-notification-notice-info');
+ cy.contains('Export is finished').parents('.ant-notification-notice')
+ .find('span[aria-label="close"]').click();
};
const exportParams = {
diff --git a/tests/cypress/support/commands.js b/tests/cypress/support/commands.js
index 83628abfb66b..76f7ffe2640c 100644
--- a/tests/cypress/support/commands.js
+++ b/tests/cypress/support/commands.js
@@ -171,11 +171,11 @@ Cypress.Commands.add(
attrName = 'Some attr name',
textDefaultValue = 'Some default value for type Text',
image = 'image.png',
- multiAttrParams,
- advancedConfigurationParams,
+ multiAttrParams = null,
+ advancedConfigurationParams = null,
forProject = false,
attachToProject = false,
- projectName,
+ projectName = '',
expectedResult = 'success',
projectSubsetFieldValue = 'Test',
) => {
@@ -365,6 +365,19 @@ Cypress.Commands.add('headlessLogout', () => {
cy.clearAllLocalStorage();
});
+Cypress.Commands.add('headlessCreateJob', (jobSpec) => {
+ cy.window().then(async ($win) => {
+ const data = {
+ ...jobSpec,
+ };
+
+ const job = new $win.cvat.classes.Job(data);
+
+ const result = await job.save(data);
+ return cy.wrap({ jobID: result.id });
+ });
+});
+
Cypress.Commands.add('openTask', (taskName, projectSubsetFieldValue) => {
cy.contains('strong', new RegExp(`^${taskName}$`))
.parents('.cvat-tasks-list-item')
diff --git a/tests/python/rest_api/test_jobs.py b/tests/python/rest_api/test_jobs.py
index 4fbea276e0a7..290b3689ba60 100644
--- a/tests/python/rest_api/test_jobs.py
+++ b/tests/python/rest_api/test_jobs.py
@@ -11,13 +11,14 @@
from copy import deepcopy
from http import HTTPStatus
from io import BytesIO
-from itertools import product
+from itertools import groupby, product
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import pytest
from cvat_sdk import models
from cvat_sdk.api_client.api_client import ApiClient, Endpoint
+from cvat_sdk.api_client.exceptions import ForbiddenException
from cvat_sdk.core.helpers import get_paginated_collection
from deepdiff import DeepDiff
from PIL import Image
@@ -361,7 +362,7 @@ def _test_destroy_job_fails(self, user, job_id, *, expected_status: int, **kwarg
assert response.status == expected_status
return response
- @pytest.mark.usefixtures("restore_cvat_data")
+ @pytest.mark.usefixtures("restore_cvat_data_per_function")
@pytest.mark.parametrize("job_type, allow", (("ground_truth", True), ("annotation", False)))
def test_destroy_job(self, admin_user, jobs, job_type, allow):
job = next(j for j in jobs if j["type"] == job_type)
@@ -603,12 +604,8 @@ def test_get_gt_job_in_org_task(
self._test_get_job_403(user["username"], job["id"])
-@pytest.mark.usefixtures(
- # if the db is restored per test, there are conflicts with the server data cache
- # if we don't clean the db, the gt jobs created will be reused, and their
- # ids won't conflict
- "restore_db_per_class"
-)
+@pytest.mark.usefixtures("restore_db_per_class")
+@pytest.mark.usefixtures("restore_redis_ondisk_per_class")
class TestGetGtJobData:
def _delete_gt_job(self, user, gt_job_id):
with make_api_client(user) as api_client:
@@ -636,12 +633,11 @@ def test_can_get_gt_job_meta(self, admin_user, tasks, jobs, task_mode, request):
:job_frame_count
]
gt_job = self._create_gt_job(admin_user, task_id, job_frame_ids)
+ request.addfinalizer(lambda: self._delete_gt_job(user, gt_job.id))
with make_api_client(user) as api_client:
(gt_job_meta, _) = api_client.jobs_api.retrieve_data_meta(gt_job.id)
- request.addfinalizer(lambda: self._delete_gt_job(user, gt_job.id))
-
# These values are relative to the resulting task frames, unlike meta values
assert 0 == gt_job.start_frame
assert task_meta.size - 1 == gt_job.stop_frame
@@ -691,12 +687,11 @@ def test_can_get_gt_job_meta_with_complex_frame_setup(self, admin_user, request)
task_frame_ids = range(start_frame, stop_frame, frame_step)
job_frame_ids = list(task_frame_ids[::3])
gt_job = self._create_gt_job(admin_user, task_id, job_frame_ids)
+ request.addfinalizer(lambda: self._delete_gt_job(admin_user, gt_job.id))
with make_api_client(admin_user) as api_client:
(gt_job_meta, _) = api_client.jobs_api.retrieve_data_meta(gt_job.id)
- request.addfinalizer(lambda: self._delete_gt_job(admin_user, gt_job.id))
-
# These values are relative to the resulting task frames, unlike meta values
assert 0 == gt_job.start_frame
assert len(task_frame_ids) - 1 == gt_job.stop_frame
@@ -717,7 +712,10 @@ def test_can_get_gt_job_meta_with_complex_frame_setup(self, admin_user, request)
@pytest.mark.parametrize("task_mode", ["annotation", "interpolation"])
@pytest.mark.parametrize("quality", ["compressed", "original"])
- def test_can_get_gt_job_chunk(self, admin_user, tasks, jobs, task_mode, quality, request):
+ @pytest.mark.parametrize("indexing", ["absolute", "relative"])
+ def test_can_get_gt_job_chunk(
+ self, admin_user, tasks, jobs, task_mode, quality, request, indexing
+ ):
user = admin_user
job_frame_count = 4
task = next(
@@ -734,41 +732,49 @@ def test_can_get_gt_job_chunk(self, admin_user, tasks, jobs, task_mode, quality,
(task_meta, _) = api_client.tasks_api.retrieve_data_meta(task_id)
frame_step = int(task_meta.frame_filter.split("=")[-1]) if task_meta.frame_filter else 1
- job_frame_ids = list(range(task_meta.start_frame, task_meta.stop_frame, frame_step))[
- :job_frame_count
- ]
+ task_frame_ids = range(task_meta.start_frame, task_meta.stop_frame + 1, frame_step)
+ rng = np.random.Generator(np.random.MT19937(42))
+ job_frame_ids = sorted(rng.choice(task_frame_ids, job_frame_count, replace=False).tolist())
+
gt_job = self._create_gt_job(admin_user, task_id, job_frame_ids)
+ request.addfinalizer(lambda: self._delete_gt_job(admin_user, gt_job.id))
- with make_api_client(admin_user) as api_client:
- (chunk_file, response) = api_client.jobs_api.retrieve_data(
- gt_job.id, number=0, quality=quality, type="chunk"
- )
- assert response.status == HTTPStatus.OK
+ if indexing == "absolute":
+ chunk_iter = groupby(task_frame_ids, key=lambda f: f // task_meta.chunk_size)
+ else:
+ chunk_iter = groupby(job_frame_ids, key=lambda f: f // task_meta.chunk_size)
- request.addfinalizer(lambda: self._delete_gt_job(admin_user, gt_job.id))
+ for chunk_id, chunk_frames in chunk_iter:
+ chunk_frames = list(chunk_frames)
- frame_range = range(
- task_meta.start_frame, min(task_meta.stop_frame + 1, task_meta.chunk_size), frame_step
- )
- included_frames = job_frame_ids
+ if indexing == "absolute":
+ kwargs = {"number": chunk_id}
+ else:
+ kwargs = {"index": chunk_id}
- # The frame count is the same as in the whole range
- # with placeholders in the frames outside the job.
- # This is required by the UI implementation
- with zipfile.ZipFile(chunk_file) as chunk:
- assert set(chunk.namelist()) == set("{:06d}.jpeg".format(i) for i in frame_range)
+ with make_api_client(admin_user) as api_client:
+ (chunk_file, response) = api_client.jobs_api.retrieve_data(
+ gt_job.id, **kwargs, quality=quality, type="chunk"
+ )
+ assert response.status == HTTPStatus.OK
+
+ # The frame count is the same as in the whole range
+ # with placeholders in the frames outside the job.
+ # This is required by the UI implementation
+ with zipfile.ZipFile(chunk_file) as chunk:
+ assert set(chunk.namelist()) == set(
+ f"{i:06d}.jpeg" for i in range(len(chunk_frames))
+ )
- for file_info in chunk.filelist:
- with chunk.open(file_info) as image_file:
- image = Image.open(image_file)
- image_data = np.array(image)
+ for file_info in chunk.filelist:
+ with chunk.open(file_info) as image_file:
+ image = Image.open(image_file)
- if int(os.path.splitext(file_info.filename)[0]) not in included_frames:
- assert image.size == (1, 1)
- assert np.all(image_data == 0), image_data
- else:
- assert image.size > (1, 1)
- assert np.any(image_data != 0)
+ chunk_frame_id = int(os.path.splitext(file_info.filename)[0])
+ if chunk_frames[chunk_frame_id] not in job_frame_ids:
+ assert image.size == (1, 1)
+ else:
+ assert image.size > (1, 1)
def _create_gt_job(self, user, task_id, frames):
with make_api_client(user) as api_client:
@@ -813,6 +819,7 @@ def test_can_get_gt_job_frame(self, admin_user, tasks, jobs, task_mode, quality,
:job_frame_count
]
gt_job = self._create_gt_job(admin_user, task_id, job_frame_ids)
+ request.addfinalizer(lambda: self._delete_gt_job(admin_user, gt_job.id))
frame_range = range(
task_meta.start_frame, min(task_meta.stop_frame + 1, task_meta.chunk_size), frame_step
@@ -830,15 +837,13 @@ def test_can_get_gt_job_frame(self, admin_user, tasks, jobs, task_mode, quality,
_check_status=False,
)
assert response.status == HTTPStatus.BAD_REQUEST
- assert b"The frame number doesn't belong to the job" in response.data
+ assert b"Incorrect requested frame number" in response.data
(_, response) = api_client.jobs_api.retrieve_data(
gt_job.id, number=included_frames[0], quality=quality, type="frame"
)
assert response.status == HTTPStatus.OK
- request.addfinalizer(lambda: self._delete_gt_job(admin_user, gt_job.id))
-
@pytest.mark.usefixtures("restore_db_per_class")
class TestListJobs:
@@ -1275,6 +1280,15 @@ def test_can_update_assignee_updated_date_on_assignee_updates(
else:
assert updated_job.assignee is None
+ def test_malefactor_cannot_obtain_job_details_via_empty_partial_update_request(
+ self, regular_lonely_user, jobs
+ ):
+ job = next(iter(jobs))
+
+ with make_api_client(regular_lonely_user) as api_client:
+ with pytest.raises(ForbiddenException):
+ api_client.jobs_api.partial_update(job["id"])
+
def _check_coco_job_annotations(content, values_to_be_checked):
exported_annotations = json.loads(content)
diff --git a/tests/python/rest_api/test_memberships.py b/tests/python/rest_api/test_memberships.py
index 98f145cbd846..abaca68ae2df 100644
--- a/tests/python/rest_api/test_memberships.py
+++ b/tests/python/rest_api/test_memberships.py
@@ -8,6 +8,7 @@
import pytest
from cvat_sdk.api_client.api_client import ApiClient, Endpoint
+from cvat_sdk.api_client.exceptions import ForbiddenException
from deepdiff import DeepDiff
from shared.utils.config import get_method, make_api_client, patch_method
@@ -137,6 +138,15 @@ def test_user_cannot_change_self_role(self, who: str, find_users):
user["username"], user["membership_id"], self.ROLES[abs(self.ROLES.index(who) - 1)]
)
+ def test_malefactor_cannot_obtain_membership_details_via_empty_partial_update_request(
+ self, regular_lonely_user, memberships
+ ):
+ membership = next(iter(memberships))
+
+ with make_api_client(regular_lonely_user) as api_client:
+ with pytest.raises(ForbiddenException):
+ api_client.memberships_api.partial_update(membership["id"])
+
@pytest.mark.usefixtures("restore_db_per_function")
class TestDeleteMemberships:
diff --git a/tests/python/rest_api/test_projects.py b/tests/python/rest_api/test_projects.py
index 3105388b153d..61b893cc2e78 100644
--- a/tests/python/rest_api/test_projects.py
+++ b/tests/python/rest_api/test_projects.py
@@ -1398,3 +1398,12 @@ def test_can_update_assignee_updated_date_on_assignee_updates(
assert updated_project.assignee.id == new_assignee_id
else:
assert updated_project.assignee is None
+
+ def test_malefactor_cannot_obtain_project_details_via_empty_partial_update_request(
+ self, regular_lonely_user, projects
+ ):
+ project = next(iter(projects))
+
+ with make_api_client(regular_lonely_user) as api_client:
+ with pytest.raises(ForbiddenException):
+ api_client.projects_api.partial_update(project["id"])
diff --git a/tests/python/rest_api/test_queues.py b/tests/python/rest_api/test_queues.py
index f801e661e426..4ce314b865b2 100644
--- a/tests/python/rest_api/test_queues.py
+++ b/tests/python/rest_api/test_queues.py
@@ -18,7 +18,7 @@
@pytest.mark.usefixtures("restore_db_per_function")
-@pytest.mark.usefixtures("restore_cvat_data")
+@pytest.mark.usefixtures("restore_cvat_data_per_function")
@pytest.mark.usefixtures("restore_redis_inmem_per_function")
class TestRQQueueWorking:
_USER_1 = "admin1"
diff --git a/tests/python/rest_api/test_resource_import_export.py b/tests/python/rest_api/test_resource_import_export.py
index 833661fcfab8..39f4be22a011 100644
--- a/tests/python/rest_api/test_resource_import_export.py
+++ b/tests/python/rest_api/test_resource_import_export.py
@@ -177,7 +177,7 @@ def test_user_cannot_export_to_cloud_storage_with_specific_location_without_acce
@pytest.mark.usefixtures("restore_db_per_function")
-@pytest.mark.usefixtures("restore_cvat_data")
+@pytest.mark.usefixtures("restore_cvat_data_per_function")
class TestImportResourceFromS3(_S3ResourceTest):
@pytest.mark.usefixtures("restore_redis_inmem_per_function")
@pytest.mark.parametrize("cloud_storage_id", [3])
diff --git a/tests/python/rest_api/test_tasks.py b/tests/python/rest_api/test_tasks.py
index e849244361fc..fd5db6388b31 100644
--- a/tests/python/rest_api/test_tasks.py
+++ b/tests/python/rest_api/test_tasks.py
@@ -6,10 +6,14 @@
import io
import itertools
import json
+import math
import os
import os.path as osp
import zipfile
+from abc import ABCMeta, abstractmethod
+from contextlib import closing
from copy import deepcopy
+from enum import Enum
from functools import partial
from http import HTTPStatus
from itertools import chain, product
@@ -18,18 +22,22 @@
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory
from time import sleep, time
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, ClassVar, Dict, Generator, List, Optional, Sequence, Tuple, Union
+import attrs
+import numpy as np
import pytest
from cvat_sdk import Client, Config, exceptions
from cvat_sdk.api_client import models
from cvat_sdk.api_client.api_client import ApiClient, ApiException, Endpoint
+from cvat_sdk.api_client.exceptions import ForbiddenException
from cvat_sdk.core.helpers import get_paginated_collection
from cvat_sdk.core.progress import NullProgressReporter
from cvat_sdk.core.proxies.tasks import ResourceType, Task
from cvat_sdk.core.uploading import Uploader
from deepdiff import DeepDiff
from PIL import Image
+from pytest_cases import fixture_ref, parametrize
import shared.utils.s3 as s3
from shared.fixtures.init import docker_exec_cvat, kube_exec_cvat
@@ -48,6 +56,7 @@
generate_image_files,
generate_manifest,
generate_video_file,
+ read_video_file,
)
from .utils import (
@@ -903,7 +912,7 @@ def test_uses_subset_name(
@pytest.mark.usefixtures("restore_db_per_function")
-@pytest.mark.usefixtures("restore_cvat_data")
+@pytest.mark.usefixtures("restore_cvat_data_per_function")
@pytest.mark.usefixtures("restore_redis_ondisk_per_function")
class TestPostTaskData:
_USERNAME = "admin1"
@@ -2028,6 +2037,525 @@ def test_create_task_with_cloud_storage_directories_and_default_bucket_prefix(
assert task.size == expected_task_size
+class _SourceDataType(str, Enum):
+ images = "images"
+ video = "video"
+
+
+class _TaskSpec(models.ITaskWriteRequest, models.IDataRequest, metaclass=ABCMeta):
+ size: int
+ frame_step: int
+ source_data_type: _SourceDataType
+
+ @abstractmethod
+ def read_frame(self, i: int) -> Image.Image: ...
+
+
+@attrs.define
+class _TaskSpecBase(_TaskSpec):
+ _params: Union[Dict, models.TaskWriteRequest]
+ _data_params: Union[Dict, models.DataRequest]
+ size: int = attrs.field(kw_only=True)
+
+ @property
+ def frame_step(self) -> int:
+ v = getattr(self, "frame_filter", "step=1")
+ return int(v.split("=")[-1])
+
+ def __getattr__(self, k: str) -> Any:
+ notfound = object()
+
+ for params in [self._params, self._data_params]:
+ if isinstance(params, dict):
+ v = params.get(k, notfound)
+ else:
+ v = getattr(params, k, notfound)
+
+ if v is not notfound:
+ return v
+
+ raise AttributeError(k)
+
+
+@attrs.define
+class _ImagesTaskSpec(_TaskSpecBase):
+ source_data_type: ClassVar[_SourceDataType] = _SourceDataType.images
+
+ _get_frame: Callable[[int], bytes] = attrs.field(kw_only=True)
+
+ def read_frame(self, i: int) -> Image.Image:
+ return Image.open(io.BytesIO(self._get_frame(i)))
+
+
+@attrs.define
+class _VideoTaskSpec(_TaskSpecBase):
+ source_data_type: ClassVar[_SourceDataType] = _SourceDataType.video
+
+ _get_video_file: Callable[[], io.IOBase] = attrs.field(kw_only=True)
+
+ def read_frame(self, i: int) -> Image.Image:
+ with closing(read_video_file(self._get_video_file())) as reader:
+ for _ in range(i + 1):
+ frame = next(reader)
+
+ return frame
+
+
+@pytest.mark.usefixtures("restore_db_per_class")
+@pytest.mark.usefixtures("restore_redis_ondisk_per_class")
+@pytest.mark.usefixtures("restore_cvat_data_per_function")
+class TestTaskData:
+ _USERNAME = "admin1"
+
+ def _uploaded_images_task_fxt_base(
+ self,
+ request: pytest.FixtureRequest,
+ *,
+ frame_count: int = 10,
+ segment_size: Optional[int] = None,
+ ) -> Generator[Tuple[_TaskSpec, int], None, None]:
+ task_params = {
+ "name": request.node.name,
+ "labels": [{"name": "a"}],
+ }
+ if segment_size:
+ task_params["segment_size"] = segment_size
+
+ image_files = generate_image_files(frame_count)
+ images_data = [f.getvalue() for f in image_files]
+ data_params = {
+ "image_quality": 70,
+ "client_files": image_files,
+ }
+
+ def get_frame(i: int) -> bytes:
+ return images_data[i]
+
+ task_id, _ = create_task(self._USERNAME, spec=task_params, data=data_params)
+ yield _ImagesTaskSpec(
+ models.TaskWriteRequest._from_openapi_data(**task_params),
+ models.DataRequest._from_openapi_data(**data_params),
+ get_frame=get_frame,
+ size=len(images_data),
+ ), task_id
+
+ @pytest.fixture(scope="class")
+ def fxt_uploaded_images_task(
+ self, request: pytest.FixtureRequest
+ ) -> Generator[Tuple[_TaskSpec, int], None, None]:
+ yield from self._uploaded_images_task_fxt_base(request=request)
+
+ @pytest.fixture(scope="class")
+ def fxt_uploaded_images_task_with_segments(
+ self, request: pytest.FixtureRequest
+ ) -> Generator[Tuple[_TaskSpec, int], None, None]:
+ yield from self._uploaded_images_task_fxt_base(request=request, segment_size=4)
+
+ def _uploaded_video_task_fxt_base(
+ self,
+ request: pytest.FixtureRequest,
+ *,
+ frame_count: int = 10,
+ segment_size: Optional[int] = None,
+ ) -> Generator[Tuple[_TaskSpec, int], None, None]:
+ task_params = {
+ "name": request.node.name,
+ "labels": [{"name": "a"}],
+ }
+ if segment_size:
+ task_params["segment_size"] = segment_size
+
+ video_file = generate_video_file(frame_count)
+ video_data = video_file.getvalue()
+ data_params = {
+ "image_quality": 70,
+ "client_files": [video_file],
+ }
+
+ def get_video_file() -> io.BytesIO:
+ return io.BytesIO(video_data)
+
+ task_id, _ = create_task(self._USERNAME, spec=task_params, data=data_params)
+ yield _VideoTaskSpec(
+ models.TaskWriteRequest._from_openapi_data(**task_params),
+ models.DataRequest._from_openapi_data(**data_params),
+ get_video_file=get_video_file,
+ size=frame_count,
+ ), task_id
+
+ @pytest.fixture(scope="class")
+ def fxt_uploaded_video_task(
+ self,
+ request: pytest.FixtureRequest,
+ ) -> Generator[Tuple[_TaskSpec, int], None, None]:
+ yield from self._uploaded_video_task_fxt_base(request=request)
+
+ @pytest.fixture(scope="class")
+ def fxt_uploaded_video_task_with_segments(
+ self, request: pytest.FixtureRequest
+ ) -> Generator[Tuple[_TaskSpec, int], None, None]:
+ yield from self._uploaded_video_task_fxt_base(request=request, segment_size=4)
+
+ def _compute_segment_params(self, task_spec: _TaskSpec) -> List[Tuple[int, int]]:
+ segment_params = []
+ segment_size = getattr(task_spec, "segment_size", 0) or task_spec.size
+ start_frame = getattr(task_spec, "start_frame", 0)
+ end_frame = (getattr(task_spec, "stop_frame", None) or (task_spec.size - 1)) + 1
+ overlap = min(
+ (
+ getattr(task_spec, "overlap", None) or 0
+ if task_spec.source_data_type == _SourceDataType.images
+ else 5
+ ),
+ segment_size // 2,
+ )
+ segment_start = start_frame
+ while segment_start < end_frame:
+ if start_frame < segment_start:
+ segment_start -= overlap * task_spec.frame_step
+
+ segment_end = segment_start + task_spec.frame_step * segment_size
+
+ segment_params.append((segment_start, min(segment_end, end_frame) - 1))
+ segment_start = segment_end
+
+ return segment_params
+
+ @staticmethod
+ def _compare_images(
+ expected: Image.Image, actual: Image.Image, *, must_be_identical: bool = True
+ ):
+ expected_pixels = np.array(expected)
+ chunk_frame_pixels = np.array(actual)
+ assert expected_pixels.shape == chunk_frame_pixels.shape
+
+ if not must_be_identical:
+ # video chunks can have slightly changed colors, due to codec specifics
+ # compressed images can also be distorted
+ assert np.allclose(chunk_frame_pixels, expected_pixels, atol=2)
+ else:
+ assert np.array_equal(chunk_frame_pixels, expected_pixels)
+
+ _default_task_cases = [
+ fixture_ref("fxt_uploaded_images_task"),
+ fixture_ref("fxt_uploaded_images_task_with_segments"),
+ fixture_ref("fxt_uploaded_video_task"),
+ fixture_ref("fxt_uploaded_video_task_with_segments"),
+ ]
+
+ @parametrize("task_spec, task_id", _default_task_cases)
+ def test_can_get_task_meta(self, task_spec: _TaskSpec, task_id: int):
+ with make_api_client(self._USERNAME) as api_client:
+ (task_meta, _) = api_client.tasks_api.retrieve_data_meta(task_id)
+
+ assert task_meta.size == task_spec.size
+ assert task_meta.start_frame == getattr(task_spec, "start_frame", 0)
+ assert task_meta.stop_frame == getattr(task_spec, "stop_frame", None) or task_spec.size
+ assert task_meta.frame_filter == getattr(task_spec, "frame_filter", "")
+
+ task_frame_set = set(
+ range(task_meta.start_frame, task_meta.stop_frame + 1, task_spec.frame_step)
+ )
+ assert len(task_frame_set) == task_meta.size
+
+ if getattr(task_spec, "chunk_size", None):
+ assert task_meta.chunk_size == task_spec.chunk_size
+
+ if task_spec.source_data_type == _SourceDataType.video:
+ assert len(task_meta.frames) == 1
+ else:
+ assert len(task_meta.frames) == task_meta.size
+
+ @parametrize("task_spec, task_id", _default_task_cases)
+ def test_can_get_task_frames(self, task_spec: _TaskSpec, task_id: int):
+ with make_api_client(self._USERNAME) as api_client:
+ (task_meta, _) = api_client.tasks_api.retrieve_data_meta(task_id)
+
+ for quality, abs_frame_id in product(
+ ["original", "compressed"],
+ range(task_meta.start_frame, task_meta.stop_frame + 1, task_spec.frame_step),
+ ):
+ rel_frame_id = (
+ abs_frame_id - getattr(task_spec, "start_frame", 0) // task_spec.frame_step
+ )
+ (_, response) = api_client.tasks_api.retrieve_data(
+ task_id,
+ type="frame",
+ quality=quality,
+ number=rel_frame_id,
+ _parse_response=False,
+ )
+
+ if task_spec.source_data_type == _SourceDataType.video:
+ frame_size = (task_meta.frames[0].width, task_meta.frames[0].height)
+ else:
+ frame_size = (
+ task_meta.frames[rel_frame_id].width,
+ task_meta.frames[rel_frame_id].height,
+ )
+
+ frame = Image.open(io.BytesIO(response.data))
+ assert frame_size == frame.size
+
+ self._compare_images(
+ task_spec.read_frame(abs_frame_id),
+ frame,
+ must_be_identical=(
+ task_spec.source_data_type == _SourceDataType.images
+ and quality == "original"
+ ),
+ )
+
+ @parametrize("task_spec, task_id", _default_task_cases)
+ def test_can_get_task_chunks(self, task_spec: _TaskSpec, task_id: int):
+ with make_api_client(self._USERNAME) as api_client:
+ (task, _) = api_client.tasks_api.retrieve(task_id)
+ (task_meta, _) = api_client.tasks_api.retrieve_data_meta(task_id)
+
+ if task_spec.source_data_type == _SourceDataType.images:
+ assert task.data_original_chunk_type == "imageset"
+ assert task.data_compressed_chunk_type == "imageset"
+ elif task_spec.source_data_type == _SourceDataType.video:
+ assert task.data_original_chunk_type == "video"
+
+ if getattr(task_spec, "use_zip_chunks", False):
+ assert task.data_compressed_chunk_type == "imageset"
+ else:
+ assert task.data_compressed_chunk_type == "video"
+ else:
+ assert False
+
+ chunk_count = math.ceil(task_meta.size / task_meta.chunk_size)
+ for quality, chunk_id in product(["original", "compressed"], range(chunk_count)):
+ expected_chunk_frame_ids = range(
+ chunk_id * task_meta.chunk_size,
+ min((chunk_id + 1) * task_meta.chunk_size, task_meta.size),
+ )
+
+ (_, response) = api_client.tasks_api.retrieve_data(
+ task_id, type="chunk", quality=quality, number=chunk_id, _parse_response=False
+ )
+
+ chunk_file = io.BytesIO(response.data)
+ if zipfile.is_zipfile(chunk_file):
+ with zipfile.ZipFile(chunk_file, "r") as chunk_archive:
+ chunk_images = {
+ int(os.path.splitext(name)[0]): np.array(
+ Image.open(io.BytesIO(chunk_archive.read(name)))
+ )
+ for name in chunk_archive.namelist()
+ }
+ chunk_images = dict(sorted(chunk_images.items(), key=lambda e: e[0]))
+ else:
+ chunk_images = dict(enumerate(read_video_file(chunk_file)))
+
+ assert sorted(chunk_images.keys()) == list(range(len(expected_chunk_frame_ids)))
+
+ for chunk_frame, abs_frame_id in zip(chunk_images, expected_chunk_frame_ids):
+ self._compare_images(
+ task_spec.read_frame(abs_frame_id),
+ chunk_images[chunk_frame],
+ must_be_identical=(
+ task_spec.source_data_type == _SourceDataType.images
+ and quality == "original"
+ ),
+ )
+
+ @parametrize("task_spec, task_id", _default_task_cases)
+ def test_can_get_job_meta(self, task_spec: _TaskSpec, task_id: int):
+ segment_params = self._compute_segment_params(task_spec)
+ with make_api_client(self._USERNAME) as api_client:
+ jobs = sorted(
+ get_paginated_collection(api_client.jobs_api.list_endpoint, task_id=task_id),
+ key=lambda j: j.start_frame,
+ )
+ assert len(jobs) == len(segment_params)
+
+ for (segment_start, segment_end), job in zip(segment_params, jobs):
+ (job_meta, _) = api_client.jobs_api.retrieve_data_meta(job.id)
+
+ assert (job_meta.start_frame, job_meta.stop_frame) == (segment_start, segment_end)
+ assert job_meta.frame_filter == getattr(task_spec, "frame_filter", "")
+
+ segment_size = segment_end - segment_start + 1
+ assert job_meta.size == segment_size
+
+ task_frame_set = set(
+ range(job_meta.start_frame, job_meta.stop_frame + 1, task_spec.frame_step)
+ )
+ assert len(task_frame_set) == job_meta.size
+
+ if getattr(task_spec, "chunk_size", None):
+ assert job_meta.chunk_size == task_spec.chunk_size
+
+ if task_spec.source_data_type == _SourceDataType.video:
+ assert len(job_meta.frames) == 1
+ else:
+ assert len(job_meta.frames) == job_meta.size
+
+ @parametrize("task_spec, task_id", _default_task_cases)
+ def test_can_get_job_frames(self, task_spec: _TaskSpec, task_id: int):
+ with make_api_client(self._USERNAME) as api_client:
+ jobs = sorted(
+ get_paginated_collection(api_client.jobs_api.list_endpoint, task_id=task_id),
+ key=lambda j: j.start_frame,
+ )
+ for job in jobs:
+ (job_meta, _) = api_client.jobs_api.retrieve_data_meta(job.id)
+
+ for quality, (frame_pos, abs_frame_id) in product(
+ ["original", "compressed"],
+ enumerate(range(job_meta.start_frame, job_meta.stop_frame)),
+ ):
+ rel_frame_id = (
+ abs_frame_id - getattr(task_spec, "start_frame", 0) // task_spec.frame_step
+ )
+ (_, response) = api_client.jobs_api.retrieve_data(
+ job.id,
+ type="frame",
+ quality=quality,
+ number=rel_frame_id,
+ _parse_response=False,
+ )
+
+ if task_spec.source_data_type == _SourceDataType.video:
+ frame_size = (job_meta.frames[0].width, job_meta.frames[0].height)
+ else:
+ frame_size = (
+ job_meta.frames[frame_pos].width,
+ job_meta.frames[frame_pos].height,
+ )
+
+ frame = Image.open(io.BytesIO(response.data))
+ assert frame_size == frame.size
+
+ self._compare_images(
+ task_spec.read_frame(abs_frame_id),
+ frame,
+ must_be_identical=(
+ task_spec.source_data_type == _SourceDataType.images
+ and quality == "original"
+ ),
+ )
+
+ @parametrize("task_spec, task_id", _default_task_cases)
+ @parametrize("indexing", ["absolute", "relative"])
+ def test_can_get_job_chunks(self, task_spec: _TaskSpec, task_id: int, indexing: str):
+ with make_api_client(self._USERNAME) as api_client:
+ jobs = sorted(
+ get_paginated_collection(api_client.jobs_api.list_endpoint, task_id=task_id),
+ key=lambda j: j.start_frame,
+ )
+
+ (task_meta, _) = api_client.tasks_api.retrieve_data_meta(task_id)
+
+ for job in jobs:
+ (job_meta, _) = api_client.jobs_api.retrieve_data_meta(job.id)
+
+ if task_spec.source_data_type == _SourceDataType.images:
+ assert job.data_original_chunk_type == "imageset"
+ assert job.data_compressed_chunk_type == "imageset"
+ elif task_spec.source_data_type == _SourceDataType.video:
+ assert job.data_original_chunk_type == "video"
+
+ if getattr(task_spec, "use_zip_chunks", False):
+ assert job.data_compressed_chunk_type == "imageset"
+ else:
+ assert job.data_compressed_chunk_type == "video"
+ else:
+ assert False
+
+ if indexing == "absolute":
+ chunk_count = math.ceil(task_meta.size / job_meta.chunk_size)
+
+ def get_task_chunk_abs_frame_ids(chunk_id: int) -> Sequence[int]:
+ return range(
+ task_meta.start_frame
+ + chunk_id * task_meta.chunk_size * task_spec.frame_step,
+ task_meta.start_frame
+ + min((chunk_id + 1) * task_meta.chunk_size, task_meta.size)
+ * task_spec.frame_step,
+ )
+
+ def get_job_frame_ids() -> Sequence[int]:
+ return range(
+ job_meta.start_frame, job_meta.stop_frame + 1, task_spec.frame_step
+ )
+
+ def get_expected_chunk_abs_frame_ids(chunk_id: int):
+ return sorted(
+ set(get_task_chunk_abs_frame_ids(chunk_id)) & set(get_job_frame_ids())
+ )
+
+ job_chunk_ids = (
+ task_chunk_id
+ for task_chunk_id in range(chunk_count)
+ if get_expected_chunk_abs_frame_ids(task_chunk_id)
+ )
+ else:
+ chunk_count = math.ceil(job_meta.size / job_meta.chunk_size)
+ job_chunk_ids = range(chunk_count)
+
+ def get_expected_chunk_abs_frame_ids(chunk_id: int):
+ return sorted(
+ frame
+ for frame in range(
+ job_meta.start_frame
+ + chunk_id * job_meta.chunk_size * task_spec.frame_step,
+ job_meta.start_frame
+ + min((chunk_id + 1) * job_meta.chunk_size, job_meta.size)
+ * task_spec.frame_step,
+ )
+ if not job_meta.included_frames or frame in job_meta.included_frames
+ )
+
+ for quality, chunk_id in product(["original", "compressed"], job_chunk_ids):
+ expected_chunk_abs_frame_ids = get_expected_chunk_abs_frame_ids(chunk_id)
+
+ kwargs = {}
+ if indexing == "absolute":
+ kwargs["number"] = chunk_id
+ elif indexing == "relative":
+ kwargs["index"] = chunk_id
+ else:
+ assert False
+
+ (_, response) = api_client.jobs_api.retrieve_data(
+ job.id,
+ type="chunk",
+ quality=quality,
+ **kwargs,
+ _parse_response=False,
+ )
+
+ chunk_file = io.BytesIO(response.data)
+ if zipfile.is_zipfile(chunk_file):
+ with zipfile.ZipFile(chunk_file, "r") as chunk_archive:
+ chunk_images = {
+ int(os.path.splitext(name)[0]): np.array(
+ Image.open(io.BytesIO(chunk_archive.read(name)))
+ )
+ for name in chunk_archive.namelist()
+ }
+ chunk_images = dict(sorted(chunk_images.items(), key=lambda e: e[0]))
+ else:
+ chunk_images = dict(enumerate(read_video_file(chunk_file)))
+
+                    assert sorted(chunk_images.keys()) == list(range(len(expected_chunk_abs_frame_ids)))
+
+ for chunk_frame, abs_frame_id in zip(
+ chunk_images, expected_chunk_abs_frame_ids
+ ):
+ self._compare_images(
+ task_spec.read_frame(abs_frame_id),
+ chunk_images[chunk_frame],
+ must_be_identical=(
+ task_spec.source_data_type == _SourceDataType.images
+ and quality == "original"
+ ),
+ )
+
+
@pytest.mark.usefixtures("restore_db_per_function")
class TestPatchTaskLabel:
def _get_task_labels(self, pid, user, **kwargs) -> List[models.Label]:
@@ -2229,7 +2757,7 @@ def test_admin_can_add_skeleton(self, tasks, admin_user):
@pytest.mark.usefixtures("restore_db_per_function")
-@pytest.mark.usefixtures("restore_cvat_data")
+@pytest.mark.usefixtures("restore_cvat_data_per_function")
@pytest.mark.usefixtures("restore_redis_ondisk_per_function")
class TestWorkWithTask:
_USERNAME = "admin1"
@@ -2286,7 +2814,13 @@ def _make_client(self) -> Client:
return Client(BASE_URL, config=Config(status_check_period=0.01))
@pytest.fixture(autouse=True)
- def setup(self, restore_db_per_function, restore_cvat_data, tmp_path: Path, admin_user: str):
+ def setup(
+ self,
+ restore_db_per_function,
+ restore_cvat_data_per_function,
+ tmp_path: Path,
+ admin_user: str,
+ ):
self.tmp_dir = tmp_path
self.client = self._make_client()
@@ -2778,6 +3312,15 @@ def test_user_cannot_update_task_with_cloud_storage_without_access(
)
assert response.status == HTTPStatus.FORBIDDEN
+ def test_malefactor_cannot_obtain_task_details_via_empty_partial_update_request(
+ self, regular_lonely_user, tasks
+ ):
+ task = next(iter(tasks))
+
+ with make_api_client(regular_lonely_user) as api_client:
+ with pytest.raises(ForbiddenException):
+ api_client.tasks_api.partial_update(task["id"])
+
@pytest.mark.parametrize("has_old_assignee", [False, True])
@pytest.mark.parametrize("new_assignee", [None, "same", "different"])
def test_can_update_assignee_updated_date_on_assignee_updates(
diff --git a/tests/python/sdk/test_auto_annotation.py b/tests/python/sdk/test_auto_annotation.py
index 142c4354c4d1..e7ac8418b69a 100644
--- a/tests/python/sdk/test_auto_annotation.py
+++ b/tests/python/sdk/test_auto_annotation.py
@@ -29,6 +29,7 @@ def _common_setup(
tmp_path: Path,
fxt_login: Tuple[Client, str],
fxt_logger: Tuple[Logger, io.StringIO],
+ restore_redis_ondisk_per_function,
):
logger = fxt_logger[0]
client = fxt_login[0]
diff --git a/tests/python/sdk/test_datasets.py b/tests/python/sdk/test_datasets.py
index d5fbc0957eb7..542ad9a1e80c 100644
--- a/tests/python/sdk/test_datasets.py
+++ b/tests/python/sdk/test_datasets.py
@@ -23,6 +23,7 @@ def _common_setup(
tmp_path: Path,
fxt_login: Tuple[Client, str],
fxt_logger: Tuple[Logger, io.StringIO],
+ restore_redis_ondisk_per_function,
):
logger = fxt_logger[0]
client = fxt_login[0]
diff --git a/tests/python/sdk/test_jobs.py b/tests/python/sdk/test_jobs.py
index ef46fcb8cf0e..3202e2957ff0 100644
--- a/tests/python/sdk/test_jobs.py
+++ b/tests/python/sdk/test_jobs.py
@@ -29,6 +29,7 @@ def setup(
fxt_login: Tuple[Client, str],
fxt_logger: Tuple[Logger, io.StringIO],
fxt_stdout: io.StringIO,
+ restore_redis_ondisk_per_function,
):
self.tmp_path = tmp_path
logger, self.logger_stream = fxt_logger
diff --git a/tests/python/sdk/test_projects.py b/tests/python/sdk/test_projects.py
index 43d6257c03c6..b03df660d87a 100644
--- a/tests/python/sdk/test_projects.py
+++ b/tests/python/sdk/test_projects.py
@@ -32,6 +32,7 @@ def setup(
fxt_login: Tuple[Client, str],
fxt_logger: Tuple[Logger, io.StringIO],
fxt_stdout: io.StringIO,
+ restore_redis_ondisk_per_function,
):
self.tmp_path = tmp_path
logger, self.logger_stream = fxt_logger
diff --git a/tests/python/sdk/test_pytorch.py b/tests/python/sdk/test_pytorch.py
index 722cb37ab003..2bcbd122abff 100644
--- a/tests/python/sdk/test_pytorch.py
+++ b/tests/python/sdk/test_pytorch.py
@@ -36,6 +36,7 @@ def _common_setup(
tmp_path: Path,
fxt_login: Tuple[Client, str],
fxt_logger: Tuple[Logger, io.StringIO],
+ restore_redis_ondisk_per_function,
):
logger = fxt_logger[0]
client = fxt_login[0]
diff --git a/tests/python/sdk/test_tasks.py b/tests/python/sdk/test_tasks.py
index 0dc5c0694e9c..54e0823d3311 100644
--- a/tests/python/sdk/test_tasks.py
+++ b/tests/python/sdk/test_tasks.py
@@ -33,6 +33,7 @@ def setup(
fxt_login: Tuple[Client, str],
fxt_logger: Tuple[Logger, io.StringIO],
fxt_stdout: io.StringIO,
+ restore_redis_ondisk_per_function,
):
self.tmp_path = tmp_path
logger, self.logger_stream = fxt_logger
diff --git a/tests/python/shared/assets/jobs.json b/tests/python/shared/assets/jobs.json
index d4add795c783..415fb67d44cd 100644
--- a/tests/python/shared/assets/jobs.json
+++ b/tests/python/shared/assets/jobs.json
@@ -10,6 +10,7 @@
"created_date": "2024-07-15T15:34:53.594000Z",
"data_chunk_size": 72,
"data_compressed_chunk_type": "imageset",
+ "data_original_chunk_type": "imageset",
"dimension": "2d",
"frame_count": 1,
"guide_id": null,
@@ -51,6 +52,7 @@
"created_date": "2024-07-15T15:33:10.549000Z",
"data_chunk_size": 72,
"data_compressed_chunk_type": "imageset",
+ "data_original_chunk_type": "imageset",
"dimension": "2d",
"frame_count": 1,
"guide_id": null,
@@ -92,6 +94,7 @@
"created_date": "2024-03-21T20:50:05.838000Z",
"data_chunk_size": 72,
"data_compressed_chunk_type": "imageset",
+ "data_original_chunk_type": "imageset",
"dimension": "2d",
"frame_count": 3,
"guide_id": null,
@@ -125,6 +128,7 @@
"created_date": "2024-03-21T20:50:05.815000Z",
"data_chunk_size": 72,
"data_compressed_chunk_type": "imageset",
+ "data_original_chunk_type": "imageset",
"dimension": "2d",
"frame_count": 1,
"guide_id": null,
@@ -158,6 +162,7 @@
"created_date": "2024-03-21T20:50:05.811000Z",
"data_chunk_size": 72,
"data_compressed_chunk_type": "imageset",
+ "data_original_chunk_type": "imageset",
"dimension": "2d",
"frame_count": 5,
"guide_id": null,
@@ -191,6 +196,7 @@
"created_date": "2024-03-21T20:50:05.805000Z",
"data_chunk_size": 72,
"data_compressed_chunk_type": "imageset",
+ "data_original_chunk_type": "imageset",
"dimension": "2d",
"frame_count": 5,
"guide_id": null,
@@ -224,6 +230,7 @@
"created_date": "2023-05-26T16:11:23.946000Z",
"data_chunk_size": 72,
"data_compressed_chunk_type": "imageset",
+ "data_original_chunk_type": "imageset",
"dimension": "2d",
"frame_count": 3,
"guide_id": null,
@@ -257,6 +264,7 @@
"created_date": "2023-05-26T16:11:23.880000Z",
"data_chunk_size": 72,
"data_compressed_chunk_type": "imageset",
+ "data_original_chunk_type": "imageset",
"dimension": "2d",
"frame_count": 11,
"guide_id": null,
@@ -290,6 +298,7 @@
"created_date": "2023-03-27T19:08:07.649000Z",
"data_chunk_size": 72,
"data_compressed_chunk_type": "imageset",
+ "data_original_chunk_type": "imageset",
"dimension": "2d",
"frame_count": 4,
"guide_id": null,
@@ -331,6 +340,7 @@
"created_date": "2023-03-27T19:08:07.649000Z",
"data_chunk_size": 72,
"data_compressed_chunk_type": "imageset",
+ "data_original_chunk_type": "imageset",
"dimension": "2d",
"frame_count": 6,
"guide_id": null,
@@ -372,6 +382,7 @@
"created_date": "2023-03-10T11:57:31.614000Z",
"data_chunk_size": 72,
"data_compressed_chunk_type": "imageset",
+ "data_original_chunk_type": "imageset",
"dimension": "2d",
"frame_count": 2,
"guide_id": null,
@@ -413,6 +424,7 @@
"created_date": "2023-03-10T11:56:33.757000Z",
"data_chunk_size": 72,
"data_compressed_chunk_type": "imageset",
+ "data_original_chunk_type": "imageset",
"dimension": "2d",
"frame_count": 2,
"guide_id": null,
@@ -454,6 +466,7 @@
"created_date": "2023-03-01T15:36:26.668000Z",
"data_chunk_size": 72,
"data_compressed_chunk_type": "imageset",
+ "data_original_chunk_type": "imageset",
"dimension": "2d",
"frame_count": 2,
"guide_id": null,
@@ -495,6 +508,7 @@
"created_date": "2023-02-10T14:05:25.947000Z",
"data_chunk_size": 72,
"data_compressed_chunk_type": "imageset",
+ "data_original_chunk_type": "imageset",
"dimension": "2d",
"frame_count": 5,
"guide_id": null,
@@ -528,6 +542,7 @@
"created_date": "2022-12-01T12:53:10.425000Z",
"data_chunk_size": 72,
"data_compressed_chunk_type": "imageset",
+ "data_original_chunk_type": "video",
"dimension": "2d",
"frame_count": 25,
"guide_id": null,
@@ -569,6 +584,7 @@
"created_date": "2022-09-22T14:22:25.820000Z",
"data_chunk_size": 72,
"data_compressed_chunk_type": "imageset",
+ "data_original_chunk_type": "imageset",
"dimension": "2d",
"frame_count": 8,
"guide_id": null,
@@ -610,6 +626,7 @@
"created_date": "2022-06-08T08:33:06.505000Z",
"data_chunk_size": 72,
"data_compressed_chunk_type": "imageset",
+ "data_original_chunk_type": "imageset",
"dimension": "2d",
"frame_count": 5,
"guide_id": null,
@@ -649,6 +666,7 @@
"created_date": "2022-03-05T10:32:19.149000Z",
"data_chunk_size": 72,
"data_compressed_chunk_type": "imageset",
+ "data_original_chunk_type": "imageset",
"dimension": "2d",
"frame_count": 11,
"guide_id": null,
@@ -690,6 +708,7 @@
"created_date": "2022-03-05T09:33:10.420000Z",
"data_chunk_size": 72,
"data_compressed_chunk_type": "imageset",
+ "data_original_chunk_type": "imageset",
"dimension": "2d",
"frame_count": 5,
"guide_id": null,
@@ -723,6 +742,7 @@
"created_date": "2022-03-05T09:33:10.420000Z",
"data_chunk_size": 72,
"data_compressed_chunk_type": "imageset",
+ "data_original_chunk_type": "imageset",
"dimension": "2d",
"frame_count": 5,
"guide_id": null,
@@ -756,6 +776,7 @@
"created_date": "2022-03-05T09:33:10.420000Z",
"data_chunk_size": 72,
"data_compressed_chunk_type": "imageset",
+ "data_original_chunk_type": "imageset",
"dimension": "2d",
"frame_count": 5,
"guide_id": null,
@@ -795,6 +816,7 @@
"created_date": "2022-03-05T09:33:10.420000Z",
"data_chunk_size": 72,
"data_compressed_chunk_type": "imageset",
+ "data_original_chunk_type": "imageset",
"dimension": "2d",
"frame_count": 5,
"guide_id": null,
@@ -834,6 +856,7 @@
"created_date": "2022-03-05T08:30:48.612000Z",
"data_chunk_size": 72,
"data_compressed_chunk_type": "imageset",
+ "data_original_chunk_type": "imageset",
"dimension": "2d",
"frame_count": 14,
"guide_id": null,
@@ -867,6 +890,7 @@
"created_date": "2022-02-21T10:31:52.429000Z",
"data_chunk_size": 72,
"data_compressed_chunk_type": "imageset",
+ "data_original_chunk_type": "imageset",
"dimension": "2d",
"frame_count": 11,
"guide_id": null,
@@ -900,6 +924,7 @@
"created_date": "2022-02-16T06:26:54.631000Z",
"data_chunk_size": 72,
"data_compressed_chunk_type": "imageset",
+ "data_original_chunk_type": "imageset",
"dimension": "3d",
"frame_count": 1,
"guide_id": null,
@@ -939,6 +964,7 @@
"created_date": "2022-02-16T06:25:48.168000Z",
"data_chunk_size": 72,
"data_compressed_chunk_type": "imageset",
+ "data_original_chunk_type": "video",
"dimension": "2d",
"frame_count": 25,
"guide_id": null,
@@ -978,6 +1004,7 @@
"created_date": "2021-12-14T18:50:29.458000Z",
"data_chunk_size": 72,
"data_compressed_chunk_type": "imageset",
+ "data_original_chunk_type": "imageset",
"dimension": "2d",
"frame_count": 23,
"guide_id": null,
diff --git a/tests/python/shared/fixtures/init.py b/tests/python/shared/fixtures/init.py
index 8e9d334f7a47..4a17454617d0 100644
--- a/tests/python/shared/fixtures/init.py
+++ b/tests/python/shared/fixtures/init.py
@@ -96,12 +96,20 @@ def pytest_addoption(parser):
def _run(command, capture_output=True):
_command = command.split() if isinstance(command, str) else command
try:
+ logger.debug(f"Executing a command: {_command}")
+
stdout, stderr = "", ""
if capture_output:
proc = run(_command, check=True, stdout=PIPE, stderr=PIPE) # nosec
stdout, stderr = proc.stdout.decode(), proc.stderr.decode()
else:
proc = run(_command) # nosec
+
+ if stdout:
+ logger.debug(f"Output (stdout): {stdout}")
+ if stderr:
+ logger.debug(f"Output (stderr): {stderr}")
+
return stdout, stderr
except CalledProcessError as exc:
message = f"Command failed: {' '.join(map(shlex.quote, _command))}."
@@ -232,20 +240,20 @@ def kube_restore_clickhouse_db():
def docker_restore_redis_inmem():
- docker_exec_redis_inmem(["redis-cli", "flushall"])
+ docker_exec_redis_inmem(["redis-cli", "-e", "flushall"])
def kube_restore_redis_inmem():
- kube_exec_redis_inmem(["redis-cli", "flushall"])
+ kube_exec_redis_inmem(["sh", "-c", 'redis-cli -e -a "${REDIS_PASSWORD}" flushall'])
def docker_restore_redis_ondisk():
- docker_exec_redis_ondisk(["redis-cli", "-p", "6666", "flushall"])
+ docker_exec_redis_ondisk(["redis-cli", "-e", "-p", "6666", "flushall"])
def kube_restore_redis_ondisk():
kube_exec_redis_ondisk(
- ["redis-cli", "-p", "6666", "-a", "${CVAT_REDIS_ONDISK_PASSWORD}", "flushall"]
+ ["sh", "-c", 'redis-cli -e -p 6666 -a "${CVAT_REDIS_ONDISK_PASSWORD}" flushall']
)
@@ -551,7 +559,7 @@ def restore_db_per_class(request):
@pytest.fixture(scope="function")
-def restore_cvat_data(request):
+def restore_cvat_data_per_function(request):
platform = request.config.getoption("--platform")
if platform == "local":
docker_restore_data_volumes()
@@ -592,6 +600,15 @@ def restore_redis_inmem_per_function(request):
kube_restore_redis_inmem()
+@pytest.fixture(scope="class")
+def restore_redis_inmem_per_class(request):
+ platform = request.config.getoption("--platform")
+ if platform == "local":
+ docker_restore_redis_inmem()
+ else:
+ kube_restore_redis_inmem()
+
+
@pytest.fixture(scope="function")
def restore_redis_ondisk_per_function(request):
platform = request.config.getoption("--platform")
@@ -599,3 +616,12 @@ def restore_redis_ondisk_per_function(request):
docker_restore_redis_ondisk()
else:
kube_restore_redis_ondisk()
+
+
+@pytest.fixture(scope="class")
+def restore_redis_ondisk_per_class(request):
+ platform = request.config.getoption("--platform")
+ if platform == "local":
+ docker_restore_redis_ondisk()
+ else:
+ kube_restore_redis_ondisk()
diff --git a/tests/python/shared/utils/helpers.py b/tests/python/shared/utils/helpers.py
index f336cb3f9111..ac5948182d78 100644
--- a/tests/python/shared/utils/helpers.py
+++ b/tests/python/shared/utils/helpers.py
@@ -1,10 +1,11 @@
-# Copyright (C) 2022 CVAT.ai Corporation
+# Copyright (C) 2022-2024 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
import subprocess
+from contextlib import closing
from io import BytesIO
-from typing import List, Optional
+from typing import Generator, List, Optional
import av
import av.video.reformatter
@@ -13,7 +14,7 @@
from shared.fixtures.init import get_server_image_tag
-def generate_image_file(filename="image.png", size=(50, 50), color=(0, 0, 0)):
+def generate_image_file(filename="image.png", size=(100, 50), color=(0, 0, 0)):
f = BytesIO()
f.name = filename
image = Image.new("RGB", size=size, color=color)
@@ -40,7 +41,7 @@ def generate_image_files(
return images
-def generate_video_file(num_frames: int, size=(50, 50)) -> BytesIO:
+def generate_video_file(num_frames: int, size=(100, 50)) -> BytesIO:
f = BytesIO()
f.name = "video.avi"
@@ -60,6 +61,19 @@ def generate_video_file(num_frames: int, size=(50, 50)) -> BytesIO:
return f
+def read_video_file(file: BytesIO) -> Generator[Image.Image, None, None]:
+ file.seek(0)
+
+ with av.open(file) as container:
+ video_stream = container.streams.video[0]
+
+ with closing(video_stream.codec_context): # pyav has a memory leak in stream.close()
+ with closing(container.demux(video_stream)) as demux_iter:
+ for packet in demux_iter:
+ for frame in packet.decode():
+ yield frame.to_image()
+
+
def generate_manifest(path: str) -> None:
command = [
"docker",