Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for collection overlays #5289

Merged
merged 8 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 28 additions & 9 deletions app/packages/looker/src/worker/disk-overlay-decoder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,29 @@ export const decodeOverlayOnDisk = async (
sources: { [path: string]: string },
cls: string,
maskPathDecodingPromises: Promise<void>[] = [],
maskTargetsBuffers: ArrayBuffer[] = []
maskTargetsBuffers: ArrayBuffer[] = [],
overlayCollectionProcessingParams:
| { idx: number; cls: string }
| undefined = undefined
) => {
// handle all list types here
if (cls === DETECTIONS) {
if (cls === DETECTIONS && label.detections) {
const promises: Promise<void>[] = [];
for (const detection of label.detections) {

for (let i = 0; i < label.detections.length; i++) {
const detection = label.detections[i];
promises.push(
decodeOverlayOnDisk(
field,
detection,
coloring,
customizeColorSetting,
colorscale,
{},
sources,
DETECTION,
maskPathDecodingPromises,
maskTargetsBuffers
maskTargetsBuffers,
{ idx: i, cls: DETECTIONS }
)
);
}
Expand Down Expand Up @@ -74,16 +80,29 @@ export const decodeOverlayOnDisk = async (
return;
}

// if we have an explicit source defined from sample.urls, use that
// otherwise, use the path field from the label
let source = sources[`${field}.${overlayPathField}`];

if (typeof overlayCollectionProcessingParams !== "undefined") {
// example: for detections, we need to access the source from the parent label
// like: if field is "prediction_masks", we're trying to get "predictiion_masks.detections[INDEX].mask"
source =
sources[
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any concern that the key we're constructing here doesn't exist on sources?

Copy link
Contributor Author

@sashankaryal sashankaryal Dec 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in which case it'd return undefined. In fact, it won't exist in OSS

a = {}

assert typeof a['foo.bar'] === 'undefined';

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my eyes told me the ]. deref was outside the key so thought we might be derefing an undefined. my mistake 🙏

`${field}.${overlayCollectionProcessingParams.cls.toLocaleLowerCase()}[${
overlayCollectionProcessingParams.idx
}].${overlayPathField}`
];
}

// convert absolute file path to a URL that we can "fetch" from
const overlayImageUrl = getSampleSrc(
sources[`${field}.${overlayPathField}`] || label[overlayPathField]
);
const overlayImageUrl = getSampleSrc(source || label[overlayPathField]);
const urlTokens = overlayImageUrl.split("?");

let baseUrl = overlayImageUrl;

// remove query params if not local URL
if (!urlTokens.at(1)?.startsWith("filepath=")) {
if (!urlTokens.at(1)?.startsWith("filepath=") && !source) {
baseUrl = overlayImageUrl.split("?")[0];
}

Expand Down
56 changes: 41 additions & 15 deletions fiftyone/server/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import typing as t

from functools import reduce
from pydash import get

import asyncio
import aiofiles
Expand All @@ -31,6 +32,8 @@
logger = logging.getLogger(__name__)

_ADDITIONAL_MEDIA_FIELDS = {
fol.Detection: "mask_path",
fol.Detections: "mask_path",
fol.Heatmap: "map_path",
fol.Segmentation: "mask_path",
OrthographicProjectionMetadata: "filepath",
Expand Down Expand Up @@ -68,7 +71,11 @@ async def get_metadata(
filepath = sample["filepath"]
metadata = sample.get("metadata", None)

opm_field, additional_fields = _get_additional_media_fields(collection)
(
opm_field,
detections_fields,
additional_fields,
) = _get_additional_media_fields(collection)

filepath_result, filepath_source, urls = _create_media_urls(
collection,
Expand All @@ -77,6 +84,7 @@ async def get_metadata(
url_cache,
additional_fields=additional_fields,
opm_field=opm_field,
detections_fields=detections_fields,
)
if filepath_result is not None:
filepath = filepath_result
Expand Down Expand Up @@ -389,13 +397,31 @@ def _create_media_urls(
cache: t.Dict,
additional_fields: t.Optional[t.List[str]] = None,
opm_field: t.Optional[str] = None,
detections_fields: t.Optional[t.List[str]] = None,
) -> t.Dict[str, str]:
filepath_source = None
media_fields = collection.app_config.media_fields.copy()

if additional_fields is not None:
media_fields.extend(additional_fields)

if detections_fields is not None:
for field in detections_fields:
detections = get(sample, field)

if not detections:
continue

detections_list = get(detections, "detections")

if not detections_list or len(detections_list) == 0:
continue

len_detections = len(detections_list)

for i in range(len_detections):
media_fields.append(f"{field}.detections[{i}].mask_path")

if (
sample_media_type == fom.POINT_CLOUD
or sample_media_type == fom.THREE_D
Expand All @@ -413,7 +439,10 @@ def _create_media_urls(
media_urls = []

for field in media_fields:
path = _deep_get(sample, field)
path = get(sample, field)

if not path:
continue

if path not in cache:
cache[path] = path
Expand All @@ -435,6 +464,8 @@ def _get_additional_media_fields(
) -> t.List[str]:
additional = []
opm_field = None
detections_fields = None

for cls, subfield_name in _ADDITIONAL_MEDIA_FIELDS.items():
for field_name, field in collection.get_field_schema(
flat=True
Expand All @@ -447,18 +478,13 @@ def _get_additional_media_fields(
if cls == OrthographicProjectionMetadata:
opm_field = field_name

additional.append(f"{field_name}.{subfield_name}")

return opm_field, additional
if cls == fol.Detections:
if detections_fields is None:
detections_fields = [field_name]
else:
detections_fields.append(field_name)

else:
additional.append(f"{field_name}.{subfield_name}")

def _deep_get(sample, keys, default=None):
"""
Get a value from a nested dictionary by specifying keys delimited by '.',
similar to lodash's ``_.get()``.
"""
return reduce(
lambda d, key: d.get(key, default) if isinstance(d, dict) else default,
keys.split("."),
sample,
)
return opm_field, detections_fields, additional
Loading