diff --git a/app/packages/core/src/components/Filters/BooleanFieldFilter.tsx b/app/packages/core/src/components/Filters/BooleanFieldFilter.tsx index 3407a01193..12769a7491 100644 --- a/app/packages/core/src/components/Filters/BooleanFieldFilter.tsx +++ b/app/packages/core/src/components/Filters/BooleanFieldFilter.tsx @@ -1,12 +1,10 @@ -import React from "react"; - import { + boolExcludeAtom, + boolIsMatchingAtom, booleanCountResults, booleanSelectedValuesAtom, - boolIsMatchingAtom, - boolOnlyMatchAtom, - boolExcludeAtom, } from "@fiftyone/state"; +import React from "react"; import CategoricalFilter from "./categoricalFilter/CategoricalFilter"; const BooleanFieldFilter = ({ @@ -27,7 +25,6 @@ const BooleanFieldFilter = ({ selectedValuesAtom={booleanSelectedValuesAtom({ path, modal })} isMatchingAtom={boolIsMatchingAtom({ path, modal })} - onlyMatchAtom={boolOnlyMatchAtom({ path, modal })} excludeAtom={boolExcludeAtom({ path, modal })} countsAtom={booleanCountResults({ path, diff --git a/app/packages/core/src/components/Filters/LabelFieldFilter.tsx b/app/packages/core/src/components/Filters/LabelFieldFilter.tsx index ea672b231f..14cfc9b2e9 100644 --- a/app/packages/core/src/components/Filters/LabelFieldFilter.tsx +++ b/app/packages/core/src/components/Filters/LabelFieldFilter.tsx @@ -2,12 +2,11 @@ import React from "react"; import { isMatchingAtom, - onlyMatchAtom, stringExcludeAtom, stringSelectedValuesAtom, } from "@fiftyone/state"; -import CategoricalFilter from "./categoricalFilter/CategoricalFilter"; import { labelTagsCount } from "../Sidebar/Entries/EntryCounts"; +import CategoricalFilter from "./categoricalFilter/CategoricalFilter"; const LabelTagFieldFilter = ({ path, @@ -27,7 +26,6 @@ const LabelTagFieldFilter = ({ selectedValuesAtom={stringSelectedValuesAtom({ modal, path })} excludeAtom={stringExcludeAtom({ modal, path })} - onlyMatchAtom={onlyMatchAtom({ modal, path })} isMatchingAtom={isMatchingAtom({ modal, path })} countsAtom={labelTagsCount({ modal, extended: false })} path={path} diff --git a/app/packages/core/src/components/Filters/NumericFieldFilter.tsx b/app/packages/core/src/components/Filters/NumericFieldFilter.tsx index 890fd81764..e33b386204 100644 --- a/app/packages/core/src/components/Filters/NumericFieldFilter.tsx +++ b/app/packages/core/src/components/Filters/NumericFieldFilter.tsx @@ -10,14 +10,14 @@ import styled from "styled-components"; import * as fos from "@fiftyone/state"; -import RangeSlider from "../Common/RangeSlider"; -import Checkbox from "../Common/Checkbox"; -import { Button } from "../utils"; import { DATE_FIELD, DATE_TIME_FIELD, FLOAT_FIELD } from "@fiftyone/utilities"; import { formatDateTime } from "../../utils/generic"; -import withSuspense from "./withSuspense"; +import Checkbox from "../Common/Checkbox"; +import RangeSlider from "../Common/RangeSlider"; import FieldLabelAndInfo from "../FieldLabelAndInfo"; +import { Button } from "../utils"; import FilterOption from "./categoricalFilter/filterOption/FilterOption"; +import withSuspense from "./withSuspense"; const NamedRangeSliderContainer = styled.div` margin: 3px; @@ -135,11 +135,6 @@ const NumericFieldFilter = ({ modal, defaultRange, }); - const onlyMatchAtom = fos.numericOnlyMatchAtom({ - path, - modal, - defaultRange, - }); const values = useRecoilValue( fos.rangeAtom({ modal, @@ -149,7 +144,6 @@ const NumericFieldFilter = ({ }) ); const setExcluded = excludeAtom ? useSetRecoilState(excludeAtom) : null; - const setOnlyMatch = onlyMatchAtom ? useSetRecoilState(onlyMatchAtom) : null; const setIsMatching = isMatchingAtom ? useSetRecoilState(isMatchingAtom) : null; @@ -207,7 +201,6 @@ const NumericFieldFilter = ({ const initializeSettings = () => { setFilter([null, null]); setExcluded && setExcluded(false); - setOnlyMatch && setOnlyMatch(true); setIsMatching && setIsMatching(!nestedField); }; @@ -304,7 +297,6 @@ const NumericFieldFilter = ({ nestedField={nestedField} shouldNotShowExclude={false} // only boolean fields don't use exclude excludeAtom={excludeAtom} - onlyMatchAtom={onlyMatchAtom} isMatchingAtom={isMatchingAtom} valueName={field?.name ?? ""} path={path} diff --git a/app/packages/core/src/components/Filters/StringFieldFilter.tsx b/app/packages/core/src/components/Filters/StringFieldFilter.tsx index aa8f3df7c1..a2697f5598 100644 --- a/app/packages/core/src/components/Filters/StringFieldFilter.tsx +++ b/app/packages/core/src/components/Filters/StringFieldFilter.tsx @@ -1,12 +1,10 @@ -import React from "react"; import * as fos from "@fiftyone/state"; - import { isMatchingAtom, - onlyMatchAtom, stringExcludeAtom, stringSelectedValuesAtom, } from "@fiftyone/state"; +import React from "react"; import CategoricalFilter from "./categoricalFilter/CategoricalFilter"; const StringFieldFilter = ({ @@ -27,7 +25,6 @@ const StringFieldFilter = ({ selectedValuesAtom={stringSelectedValuesAtom({ modal, path })} excludeAtom={stringExcludeAtom({ modal, path })} - onlyMatchAtom={onlyMatchAtom({ modal, path })} isMatchingAtom={isMatchingAtom({ modal, path })} countsAtom={fos.stringCountResults({ modal, diff --git a/app/packages/core/src/components/Filters/categoricalFilter/CategoricalFilter.tsx b/app/packages/core/src/components/Filters/categoricalFilter/CategoricalFilter.tsx index efe73f4366..bcb5ceac19 100644 --- a/app/packages/core/src/components/Filters/categoricalFilter/CategoricalFilter.tsx +++ b/app/packages/core/src/components/Filters/categoricalFilter/CategoricalFilter.tsx @@ -175,7 +175,6 @@ interface Props { selectedValuesAtom: RecoilState; excludeAtom: RecoilState; // toggles select or exclude isMatchingAtom: RecoilState; // toggles match or filter - onlyMatchAtom: RecoilState; // toggles onlyMatch mode (omit empty samples) countsAtom: RecoilValue<{ count: number; results: [T["value"], number][]; @@ -190,7 +189,6 @@ const CategoricalFilter = ({ countsAtom, selectedValuesAtom, excludeAtom, - onlyMatchAtom, isMatchingAtom, path, modal, @@ -203,6 +201,7 @@ const CategoricalFilter = ({ : path.startsWith("_label_tags") ? "label tag" : name; + const selectedCounts = useRef(new Map()); const onSelect = useOnSelect(selectedValuesAtom, selectedCounts); const useSearch = getUseSearch({ modal, path }); @@ -213,7 +212,7 @@ const CategoricalFilter = ({ // id fields should always use filter mode const neverShowExpansion = field?.ftype?.includes("ObjectIdField"); - + if (countsLoadable.state === "hasError") throw countsLoadable.contents; if (countsLoadable.state !== "hasValue") return null; const { count, results } = countsLoadable.contents; @@ -267,7 +266,6 @@ const CategoricalFilter = ({ selectedValuesAtom={selectedValuesAtom} excludeAtom={excludeAtom} isMatchingAtom={isMatchingAtom} - onlyMatchAtom={onlyMatchAtom} modal={modal} totalCount={count} selectedCounts={selectedCounts} diff --git a/app/packages/core/src/components/Filters/categoricalFilter/Wrapper.tsx b/app/packages/core/src/components/Filters/categoricalFilter/Wrapper.tsx index 2853f8b614..381371ca3a 100644 --- a/app/packages/core/src/components/Filters/categoricalFilter/Wrapper.tsx +++ b/app/packages/core/src/components/Filters/categoricalFilter/Wrapper.tsx @@ -8,18 +8,17 @@ import { import * as fos from "@fiftyone/state"; -import FilterOption from "./filterOption/FilterOption"; import Checkbox from "../../Common/Checkbox"; import { Button } from "../../utils"; import { CHECKBOX_LIMIT, nullSort } from "../utils"; -import { isKeypointLabel, V } from "./CategoricalFilter"; +import { V, isKeypointLabel } from "./CategoricalFilter"; +import FilterOption from "./filterOption/FilterOption"; interface WrapperProps { results: [V["value"], number][]; selectedValuesAtom: RecoilState; excludeAtom: RecoilState; isMatchingAtom: RecoilState; - onlyMatchAtom: RecoilState; color: string; totalCount: number; modal: boolean; @@ -34,7 +33,6 @@ const Wrapper = ({ selectedValuesAtom, excludeAtom, isMatchingAtom, - onlyMatchAtom, modal, path, selectedCounts, @@ -44,7 +42,6 @@ const Wrapper = ({ const [selected, setSelected] = useRecoilState(selectedValuesAtom); const selectedSet = new Set(selected); const setExcluded = excludeAtom ? useSetRecoilState(excludeAtom) : null; - const setOnlyMatch = onlyMatchAtom ? useSetRecoilState(onlyMatchAtom) : null; const setIsMatching = isMatchingAtom ? useSetRecoilState(isMatchingAtom) : null; @@ -87,7 +84,6 @@ const Wrapper = ({ const initializeSettings = () => { setExcluded && setExcluded(false); - setOnlyMatch && setOnlyMatch(true); setIsMatching && setIsMatching(!nestedField); }; @@ -144,7 +140,6 @@ const Wrapper = ({ nestedField={nestedField} shouldNotShowExclude={shouldNotShowExclude} excludeAtom={excludeAtom} - onlyMatchAtom={onlyMatchAtom} isMatchingAtom={isMatchingAtom} valueName={name} color={color} diff --git a/app/packages/core/src/components/Filters/categoricalFilter/filterOption/FilterOption.tsx b/app/packages/core/src/components/Filters/categoricalFilter/filterOption/FilterOption.tsx index 8cb7239ff6..29d5afa561 100644 --- a/app/packages/core/src/components/Filters/categoricalFilter/filterOption/FilterOption.tsx +++ b/app/packages/core/src/components/Filters/categoricalFilter/filterOption/FilterOption.tsx @@ -1,27 +1,24 @@ -import React, { PropsWithChildren, useEffect } from "react"; -import styled from "styled-components"; -import { RecoilState, useRecoilState, useSetRecoilState } from "recoil"; import FilterAltIcon from "@mui/icons-material/FilterAlt"; import FilterAltOffIcon from "@mui/icons-material/FilterAltOff"; -import ImageIcon from "@mui/icons-material/Image"; import HideImageIcon from "@mui/icons-material/HideImage"; +import ImageIcon from "@mui/icons-material/Image"; import { IconButton } from "@mui/material"; -import { useSpring } from "framer-motion"; import Color from "color"; +import React, { useEffect } from "react"; +import { RecoilState, useRecoilState, useSetRecoilState } from "recoil"; +import styled from "styled-components"; -import { useOutsideClick } from "@fiftyone/state"; import { useTheme } from "@fiftyone/components/src/components/ThemeProvider"; import Tooltip from "@fiftyone/components/src/components/Tooltip"; +import { useOutsideClick } from "@fiftyone/state"; -import { PopoutDiv } from "../../../utils"; -import Item from "./FilterItem"; import { Popout } from "@fiftyone/components"; +import Item from "./FilterItem"; interface Props { nestedField: string | undefined; // nested ListFields only ("detections") shouldNotShowExclude: boolean; // for BooleanFields excludeAtom: RecoilState; - onlyMatchAtom: RecoilState; isMatchingAtom: RecoilState; valueName: string; color: string; @@ -132,7 +129,6 @@ const FilterOption: React.FC = ({ nestedField, shouldNotShowExclude, excludeAtom, - onlyMatchAtom, isMatchingAtom, }) => { const isLabelTag = path?.startsWith("_label_tags"); @@ -140,7 +136,6 @@ const FilterOption: React.FC = ({ const [open, setOpen] = React.useState(false); const [excluded, setExcluded] = useRecoilState(excludeAtom); - const setOnlyMatch = onlyMatchAtom ? useSetRecoilState(onlyMatchAtom) : null; const setIsMatching = isMatchingAtom ? useSetRecoilState(isMatchingAtom) : null; @@ -213,25 +208,21 @@ const FilterOption: React.FC = ({ const onSelectFilter = () => { setExcluded && setExcluded(false); setIsMatching && setIsMatching(false); - setOnlyMatch && setOnlyMatch(true); }; const onSelectNegativeFilter = () => { setExcluded && setExcluded(true); setIsMatching && setIsMatching(false); - setOnlyMatch && setOnlyMatch(false); }; const onSelectMatch = () => { setExcluded && setExcluded(false); setIsMatching && setIsMatching(true); - setOnlyMatch && setOnlyMatch(true); }; const onSelectNegativeMatch = () => { setExcluded && setExcluded(true); setIsMatching && setIsMatching(true); - setOnlyMatch && setOnlyMatch(true); }; const children = ( diff --git a/app/packages/state/src/recoil/aggregations.ts b/app/packages/state/src/recoil/aggregations.ts index 837f45ec7a..f9c7df8c7a 100644 --- a/app/packages/state/src/recoil/aggregations.ts +++ b/app/packages/state/src/recoil/aggregations.ts @@ -203,7 +203,7 @@ export const stringCountResults = selectorFamily({ const isSkeletonPoints = VALID_KEYPOINTS.includes( get(schemaAtoms.field(parent)).embeddedDocType - ) && keys[2] === "points"; + ) && keys.slice(-1)[0] === "points"; if (isSkeletonPoints) { const skeleton = get(selectors.skeleton(parent)); diff --git a/fiftyone/server/samples.py b/fiftyone/server/samples.py index 5918e116a7..6e2d0d0c83 100644 --- a/fiftyone/server/samples.py +++ b/fiftyone/server/samples.py @@ -78,7 +78,6 @@ async def paginate_samples( stages=stages, filters=filters, pagination_data=pagination_data, - count_label_tags=True, extended_stages=extended_stages, sample_filter=sample_filter, reload=reload, diff --git a/fiftyone/server/view.py b/fiftyone/server/view.py index e86c987657..778bf92831 100644 --- a/fiftyone/server/view.py +++ b/fiftyone/server/view.py @@ -77,7 +77,6 @@ def get_view( stages=None, filters=None, pagination_data=False, - count_label_tags=False, extended_stages=None, sample_filter=None, reload=True, @@ -90,8 +89,6 @@ def get_view( stages (None): an optional list of serialized :class:`fiftyone.core.stages.ViewStage` instances filters (None): an optional ``dict`` of App defined filters - count_label_tags (False): whether to includes hidden ``_label_tags`` - counts on sample documents pagination_data (False): whether process samples as pagination data - excludes all :class:`fiftyone.core.fields.DictField` values - filters label fields @@ -134,11 +131,10 @@ def get_view( # omit all dict field values for performance, not needed by grid view = _project_pagination_paths(view) - if filters or extended_stages or count_label_tags: + if filters or extended_stages or pagination_data: view = get_extended_view( view, filters, - count_label_tags=count_label_tags, pagination_data=pagination_data, extended_stages=extended_stages, ) @@ -149,7 +145,6 @@ def get_view( def get_extended_view( view, filters=None, - count_label_tags=False, extended_stages=None, pagination_data=False, ): @@ -158,8 +153,6 @@ def get_extended_view( Args: view: a :class:`fiftyone.core.collections.SampleCollection` filters: an optional ``dict`` of App defined filters - count_label_tags (False): whether to set the hidden ``_label_tags`` - field with counts of tags with respect to all label fields extended_stages (None): extended view stages pagination_data (False): filters label data @@ -187,17 +180,12 @@ def get_extended_view( if label_tags: view = _match_label_tags(view, label_tags) - stages = _make_filter_stages( - view, - filters, - count_label_tags=count_label_tags, - pagination_data=pagination_data, - ) + stages = _make_filter_stages(view, filters) for stage in stages: view = view.add_stage(stage) - if count_label_tags: + if pagination_data: view = _add_labels_tags_counts(view) return view @@ -272,62 +260,62 @@ def _project_pagination_paths(view: foc.SampleCollection): def _make_filter_stages( view, filters, - count_label_tags=False, - pagination_data=False, ): stages = [] queries = [] - for path, field, prefix, args in _iter_paths(view, filters): + for path, label_path, field, args in _iter_paths(view, filters): is_matching = args.get("isMatching", True) - only_matches = args.get("onlyMatch", True) - + path_field = view.get_field(path) is_label_field = _is_label(field) - if is_label_field and issubclass( - field.document_type, (fol.Keypoint, fol.Keypoints) + if ( + is_label_field + and issubclass(field.document_type, (fol.Keypoint, fol.Keypoints)) + and isinstance(path_field, (fof.KeypointsField, fof.ListField)) ): continue - if not is_label_field or (is_matching or only_matches): - queries.append(_make_query(path, view.get_field(path), args)) + if args.get("exclude") and not is_matching: + continue + + queries.append(_make_query(path, path_field, args)) if queries: stages.append(fosg.Match({"$and": queries})) - for path, field, prefix, args in _iter_paths(view, filters): - if not _is_label(field): - continue - + for path, label_path, label_field, args in _iter_paths( + view, filters, labels=True + ): is_matching = args.get("isMatching", True) - only_matches = args.get("onlyMatch", True) - parent = field field = view.get_field(path) - if issubclass(parent.document_type, (fol.Keypoint, fol.Keypoints)): + if issubclass( + label_field.document_type, (fol.Keypoint, fol.Keypoints) + ) and isinstance(field, fof.ListField): expr = _make_keypoint_list_filter(args, view, path, field) if expr is not None: stages.append( fosg.FilterKeypoints( - prefix + parent.name, - only_matches=only_matches, + label_path, + only_matches=True, **expr, ) ) - elif not is_matching and (pagination_data or not count_label_tags): + elif not is_matching: key = field.db_field if field.db_field else field.name expr = _make_scalar_expression(F(key), args, field, is_label=True) if expr is not None: stages.append( fosg.FilterLabels( - prefix + parent.name, + label_path, expr, - only_matches=only_matches, + only_matches=not args.get("exclude", False), ) ) return stages -def _iter_paths(view, filters): +def _iter_paths(view, filters, labels=False): for path in sorted(filters): if path == "tags" or path.startswith("_"): continue @@ -345,6 +333,9 @@ def _iter_paths(view, filters): parent_path = ".".join(parent_path.split(".")[:-1]) parent_field = view.get_field(parent_path) + if labels and not _is_label(parent_field): + continue + yield path, parent_path, parent_field, filters[path] @@ -375,11 +366,12 @@ def _is_label(field): def _make_query(path, field, args): keys = path.split(".") path = ".".join(keys[:-1] + [field.db_field or field.name]) - if isinstance(field, fof.ListField): + if isinstance(field, fof.ListField) and field.field: field = field.field if isinstance( - field, (fof.DateField, fof.DateTimeField, fof.FloatField, fof.IntField) + field, + (fof.DateField, fof.DateTimeField, fof.FloatField, fof.IntField), ): mn, mx = args["range"] if isinstance(field, (fof.DateField, fof.DateTimeField)): @@ -400,7 +392,7 @@ def _make_query(path, field, args): ] } - values = args["values"] + values = args.get("values", None) if isinstance(field, fof.ObjectIdField): values = list(map(lambda v: ObjectId(v), args["values"])) @@ -415,10 +407,10 @@ def _make_query(path, field, args): } if not true and false: - return {path: {"$neq" if args["exclude"] else "$eq": False}} + return {path: {"$ne" if args["exclude"] else "$eq": False}} if true and not false: - return {path: {"$neq" if args["exclude"] else "$eq": True}} + return {path: {"$ne" if args["exclude"] else "$eq": True}} if not true and not false: return { @@ -471,6 +463,10 @@ def _make_scalar_expression(f, args, field, list_field=False, is_label=False): expr = f.is_in(values) exclude = args["exclude"] + if exclude: + # pylint: disable=invalid-unary-operand-type + expr = ~expr + if none and not is_label and not list_field: if exclude: expr &= f.exists() @@ -678,18 +674,19 @@ def _match_label_tags(view: foc.SampleCollection, label_tags): matching = label_tags["isMatching"] expr = lambda exclude, values: {"$nin" if exclude else "$in": values} - view = view.mongo( - [ - { - "$match": { - "$or": [ - {f"{path}.tags": expr(exclude, values)} - for path in label_paths - ] + if not exclude or matching: + view = view.mongo( + [ + { + "$match": { + "$or": [ + {f"{path}.tags": expr(exclude, values)} + for path in label_paths + ] + } } - } - ] - ) + ] + ) if not matching and exclude: view = view.exclude_labels( diff --git a/tests/unittests/server_tests.py b/tests/unittests/server_tests.py index 5d2169b851..88f1ea164c 100644 --- a/tests/unittests/server_tests.py +++ b/tests/unittests/server_tests.py @@ -5,11 +5,15 @@ | `voxel51.com `_ | """ +import math import unittest +import fiftyone as fo import fiftyone.core.dataset as fod import fiftyone.core.fields as fof import fiftyone.core.labels as fol +import fiftyone.core.odm as foo +import fiftyone.core.sample as fos import fiftyone.server.view as fosv from decorators import drop_datasets @@ -17,645 +21,761 @@ class ServerViewTests(unittest.TestCase): @drop_datasets - def test_extended_view_image_label_filters_samples(self): + def test_extended_image_sample(self): + dataset = fod.Dataset("test") + sample = fos.Sample( + filepath="image.png", + predictions=fol.Detections( + detections=[ + fol.Detection( + label="carrot", confidence=0.25, tags=["one", "two"] + ), + fol.Detection( + label="not_carrot", confidence=0.75, tags=["two"] + ), + ] + ), + bool=True, + int=1, + str="str", + list_bool=[True], + list_int=[1, 2], + list_str=["one", "two"], + ) + dataset.add_sample(sample) + filters = { + "id": { + "values": [dataset.first().id], + "exclude": False, + } + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 1) + + filters = { + "id": { + "values": [dataset.first().id], + "exclude": True, + } + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 0) + filters = { "predictions.detections.label": { "values": ["carrot"], "exclude": False, - "onlyMatch": True, "isMatching": False, - "_CLS": "str", }, "predictions.detections.confidence": { "range": [0.5, 1], - "_CLS": "numeric", "exclude": False, - "onlyMatch": True, "isMatching": False, }, } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 0) - dataset = fod.Dataset("test") - dataset.add_sample_field( - "predictions", fof.EmbeddedDocumentField, fol.Detections - ) - - returned = fosv.get_view( - "test", - filters=filters, - count_label_tags=True, - )._pipeline() - - expected = [ - { - "$match": { - "$and": [ - { - "$and": [ - { - "predictions.detections.confidence": { - "$gte": 0.5 - } - }, - { - "predictions.detections.confidence": { - "$lte": 1 - } - }, - ], - }, - {"predictions.detections.label": {"$in": ["carrot"]}}, - ], - }, - }, - {"$addFields": {"_label_tags": []}}, - { - "$addFields": { - "_label_tags": { - "$cond": { - "if": {"$gt": ["$predictions", None]}, - "then": { - "$concatArrays": [ - "$_label_tags", - { - "$reduce": { - "input": "$predictions.detections", - "initialValue": [], - "in": { - "$concatArrays": [ - "$$value", - "$$this.tags", - ], - }, - }, - }, - ], - }, - "else": "$_label_tags", - }, - }, - }, - }, - { - "$addFields": { - "_label_tags": { - "$function": { - "body": "function(items) {let counts = {};items && items.forEach((i) => {counts[i] = 1 + (counts[i] || 0);});return counts;}", - "args": ["$_label_tags"], - "lang": "js", - }, - }, - }, - }, - ] - - self.assertEqual(expected, returned) - - @drop_datasets - def test_extended_view_image_label_filters_aggregations(self): filters = { "predictions.detections.label": { "values": ["carrot"], "exclude": False, - "onlyMatch": True, "isMatching": False, - "_CLS": "str", }, "predictions.detections.confidence": { - "range": [0.5, 1], - "_CLS": "numeric", + "range": [0.0, 0.5], "exclude": False, - "onlyMatch": True, "isMatching": False, }, } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 1) + self.assertEqual(len(view.first().predictions.detections), 1) - dataset = fod.Dataset("test") - dataset.add_sample_field( - "predictions", fof.EmbeddedDocumentField, fol.Detections + filters = { + "list_str": { + "values": ["one"], + "exclude": False, + }, + "list_int": { + "range": [0, 2], + "exclude": False, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 1) + self.assertEqual(len(view.first().list_str), 2) + self.assertEqual(len(view.first().list_int), 2) + + filters = { + "list_str": { + "values": ["empty"], + "exclude": False, + }, + "list_int": { + "range": [0, 2], + "exclude": False, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 0) + + filters = { + "list_str": { + "values": ["one"], + "exclude": False, + }, + "list_int": { + "range": [3, 4], + "exclude": False, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 0) + + filters = { + "list_bool": { + "true": False, + "false": True, + "exclude": False, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 0) + + filters = { + "list_bool": { + "true": False, + "false": True, + "exclude": True, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 1) + + filters = { + "list_bool": { + "true": True, + "false": False, + "exclude": False, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 1) + + view = fosv.get_view("test", pagination_data=True) + (sample,) = list( + foo.aggregate( + foo.get_db_conn()[view._dataset._sample_collection_name], + view._pipeline(), + ) ) + self.assertIn("_label_tags", sample) + self.assertDictEqual(sample["_label_tags"], {"one": 1, "two": 2}) - returned = fosv.get_view( - "test", filters=filters, count_label_tags=False - )._pipeline() - - expected = [ - { - "$match": { - "$and": [ - { - "$and": [ - { - "predictions.detections.confidence": { - "$gte": 0.5 - } - }, - { - "predictions.detections.confidence": { - "$lte": 1 - } - }, - ], - }, - {"predictions.detections.label": {"$in": ["carrot"]}}, - ], - }, - }, - { - "$addFields": { - "predictions.detections": { - "$filter": { - "input": "$predictions.detections", - "cond": { - "$or": [ - { - "$and": [ - { - "$gte": [ - "$$this.confidence", - 0.5, - ] - }, - {"$lte": ["$$this.confidence", 1]}, - ], - }, - {"$in": ["$$this.confidence", []]}, - ], - }, - }, - }, - }, - }, - { - "$addFields": { - "predictions.detections": { - "$filter": { - "input": "$predictions.detections", - "cond": {"$in": ["$$this.label", ["carrot"]]}, - }, - }, - }, - }, - ] - - self.assertEqual(expected, returned) + filters = { + "_label_tags": { + "values": ["two"], + "exclude": False, + "isMatching": True, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 1) + self.assertEqual(len(view.first().predictions.detections), 2) - @drop_datasets - def test_extended_view_video_label_filters_samples(self): filters = { - "frames.detections.detections.index": { - "range": [27, 54], - "_CLS": "numeric", + "_label_tags": { + "values": ["one"], "exclude": False, - "onlyMatch": True, "isMatching": False, }, - "frames.detections.detections.label": { - "values": ["vehicle"], - "exclude": False, - "onlyMatch": True, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 1) + self.assertEqual(len(view.first().predictions.detections), 1) + + view = fosv.get_view("test", pagination_data=True, filters=filters) + (sample,) = list( + foo.aggregate( + foo.get_db_conn()[view._dataset._sample_collection_name], + view._pipeline(), + ) + ) + self.assertIn("_label_tags", sample) + self.assertDictEqual(sample["_label_tags"], {"one": 1, "two": 1}) + + filters = { + "_label_tags": { + "values": ["two"], + "exclude": True, "isMatching": False, - "_CLS": "str", }, } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 1) + self.assertEqual(len(view.first().predictions.detections), 0) - dataset = fod.Dataset("test") - dataset.media_type = "video" - dataset.add_frame_field( - "detections", fof.EmbeddedDocumentField, fol.Detections + filters = { + "_label_tags": { + "values": ["one"], + "exclude": True, + "isMatching": False, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 1) + self.assertEqual( + view.first().predictions.detections[0].label, "not_carrot" ) - returned = fosv.get_view( - "test", - filters=filters, - count_label_tags=True, - )._pipeline()[1:] - - expected = [ - { - "$match": { - "$and": [ - { - "$and": [ - { - "frames.detections.detections.index": { - "$gte": 27 - } - }, - { - "frames.detections.detections.index": { - "$lte": 54 - } - }, - ], - }, - { - "frames.detections.detections.label": { - "$in": ["vehicle"] - } - }, - ], - }, - }, - {"$addFields": {"_label_tags": []}}, - { - "$addFields": { - "_label_tags": { - "$concatArrays": [ - "$_label_tags", - { - "$reduce": { - "input": "$frames", - "initialValue": [], - "in": { - "$concatArrays": [ - "$$value", - { - "$cond": { - "if": { - "$gt": [ - "$$this.detections.detections", - None, - ], - }, - "then": { - "$reduce": { - "input": "$$this.detections.detections", - "initialValue": [], - "in": { - "$concatArrays": [ - "$$value", - "$$this.tags", - ], - }, - }, - }, - "else": [], - }, - }, - ], - }, - }, - }, - ], - }, - }, - }, - { - "$addFields": { - "_label_tags": { - "$function": { - "body": "function(items) {let counts = {};items && items.forEach((i) => {counts[i] = 1 + (counts[i] || 0);});return counts;}", - "args": ["$_label_tags"], - "lang": "js", - }, - }, - }, - }, - ] - - self.assertEqual(expected, returned) + view = fosv.get_view("test", pagination_data=True, filters=filters) + (sample,) = list( + foo.aggregate( + foo.get_db_conn()[view._dataset._sample_collection_name], + view._pipeline(), + ) + ) + self.assertIn("_label_tags", sample) + self.assertDictEqual(sample["_label_tags"], {"two": 1}) + + filters = { + "_label_tags": { + "values": ["one"], + "exclude": True, + "isMatching": True, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 0) @drop_datasets - def test_extended_view_video_label_filters_aggregations(self): + def test_extended_frame_sample(self): + dataset = fod.Dataset("test") + sample = fos.Sample( + filepath="video.mp4", + ) + sample.frames[1] = fo.Frame( + predictions=fol.Detections( + detections=[ + fol.Detection( + label="carrot", confidence=0.25, tags=["one", "two"] + ), + fol.Detection( + label="not_carrot", confidence=0.75, tags=["two"] + ), + ] + ) + ) + dataset.add_sample(sample) + filters = { - "frames.detections.detections.index": { - "range": [27, 54], - "_CLS": "numeric", + "frames.predictions.detections.label": { + "values": ["carrot"], "exclude": False, - "onlyMatch": True, "isMatching": False, }, - "frames.detections.detections.label": { - "values": ["vehicle"], + "frames.predictions.detections.confidence": { + "range": [0.5, 1], "exclude": False, - "onlyMatch": True, "isMatching": False, - "_CLS": "str", }, } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 0) - dataset = fod.Dataset("test") - dataset.media_type = "video" - dataset.add_frame_field( - "detections", fof.EmbeddedDocumentField, fol.Detections + filters = { + "frames.predictions.detections.label": { + "values": ["carrot"], + "exclude": False, + "isMatching": False, + }, + "frames.predictions.detections.confidence": { + "range": [0.0, 0.5], + "exclude": False, + "isMatching": False, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 1) + self.assertEqual(len(view.first().frames[1].predictions.detections), 1) + + view = fosv.get_view("test", pagination_data=True) + (sample,) = list( + foo.aggregate( + foo.get_db_conn()[view._dataset._sample_collection_name], + view._pipeline(), + ) ) + self.assertIn("_label_tags", sample) + self.assertDictEqual(sample["_label_tags"], {"one": 1, "two": 2}) - returned = fosv.get_view( - "test", filters=filters, count_label_tags=False - )._pipeline()[1:] - - expected = [ - { - "$match": { - "$and": [ - { - "$and": [ - { - "frames.detections.detections.index": { - "$gte": 27 - } - }, - { - "frames.detections.detections.index": { - "$lte": 54 - } - }, - ], - }, - { - "frames.detections.detections.label": { - "$in": ["vehicle"] - } - }, - ], - }, - }, - { - "$addFields": { - "frames": { - "$map": { - "input": "$frames", - "as": "frame", - "in": { - "$mergeObjects": [ - "$$frame", - { - "detections": { - "$mergeObjects": [ - "$$frame.detections", - { - "detections": { - "$filter": { - "input": "$$frame.detections.detections", - "cond": { - "$or": [ - { - "$and": [ - { - "$gte": [ - "$$this.index", - 27, - ], - }, - { - "$lte": [ - "$$this.index", - 54, - ], - }, - ], - }, - { - "$in": [ - "$$this.index", - [], - ], - }, - ], - }, - }, - }, - }, - ], - }, - }, - ], - }, - }, - }, - }, - }, - { - "$addFields": { - "frames": { - "$map": { - "input": "$frames", - "as": "frame", - "in": { - "$mergeObjects": [ - "$$frame", - { - "detections": { - "$mergeObjects": [ - "$$frame.detections", - { - "detections": { - "$filter": { - "input": "$$frame.detections.detections", - "cond": { - "$in": [ - "$$this.label", - [ - "vehicle" - ], - ], - }, - }, - }, - }, - ], - }, - }, - ], - }, - }, - }, - }, - }, - ] - - self.assertEqual(expected, returned) + filters = { + "_label_tags": { + "values": ["two"], + "exclude": False, + "isMatching": True, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 1) + self.assertEqual(len(view.first().frames[1].predictions.detections), 2) - @drop_datasets - def test_extended_view_video_match_label_tags_aggregations(self): filters = { "_label_tags": { "values": ["one"], "exclude": False, + "isMatching": False, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 1) + self.assertEqual(len(view.first().frames[1].predictions.detections), 1) + + view = fosv.get_view("test", pagination_data=True, filters=filters) + (sample,) = list( + foo.aggregate( + foo.get_db_conn()[view._dataset._sample_collection_name], + view._pipeline(), + ) + ) + self.assertIn("_label_tags", sample) + self.assertDictEqual(sample["_label_tags"], {"one": 1, "two": 1}) + + filters = { + "_label_tags": { + "values": ["two"], + "exclude": True, + "isMatching": False, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 1) + self.assertEqual(len(view.first().frames[1].predictions.detections), 0) + + filters = { + "_label_tags": { + "values": ["one"], + "exclude": True, + "isMatching": False, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 1) + self.assertEqual( + view.first().frames[1].predictions.detections[0].label, + "not_carrot", + ) + + view = fosv.get_view("test", pagination_data=True, filters=filters) + (sample,) = list( + foo.aggregate( + foo.get_db_conn()[view._dataset._sample_collection_name], + view._pipeline(), + ) + ) + self.assertIn("_label_tags", sample) + self.assertDictEqual(sample["_label_tags"], {"two": 1}) + + filters = { + "_label_tags": { + "values": ["one"], + "exclude": True, "isMatching": True, - } + }, } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 0) + @drop_datasets + def test_extended_dynamic_image_sample(self): dataset = fod.Dataset("test") - dataset.media_type = "video" - dataset.add_frame_field( - "detections", fof.EmbeddedDocumentField, fol.Detections + sample = fos.Sample( + filepath="image.png", + dynamic=fo.DynamicEmbeddedDocument( + predictions=fol.Detections( + detections=[ + fol.Detection( + label="carrot", + confidence=0.25, + tags=["one", "two"], + ), + fol.Detection( + label="not_carrot", confidence=0.75, tags=["two"] + ), + ] + ), + bool=True, + int=1, + str="str", + list_bool=[True], + list_int=[1, 2], + list_str=["one", "two"], + ), + dynamic_list=[ + fo.DynamicEmbeddedDocument( + bool=True, + int=1, + str="str", + list_bool=[True], + list_int=[1, 2], + list_str=["one", "two"], + ) + ], ) + dataset.add_sample(sample) + dataset.add_dynamic_sample_fields() - returned = fosv.get_view( - "test", filters=filters, count_label_tags=True - )._pipeline()[1:] - - expected = [ - { - "$match": { - "$or": [ - {"frames.detections.detections.tags": {"$in": ["one"]}} - ], - }, - }, - {"$addFields": {"_label_tags": []}}, - { - "$addFields": { - "_label_tags": { - "$concatArrays": [ - "$_label_tags", - { - "$reduce": { - "input": "$frames", - "initialValue": [], - "in": { - "$concatArrays": [ - "$$value", - { - "$cond": { - "if": { - "$gt": [ - "$$this.detections.detections", - None, - ], - }, - "then": { - "$reduce": { - "input": "$$this.detections.detections", - "initialValue": [], - "in": { - "$concatArrays": [ - "$$value", - "$$this.tags", - ], - }, - }, - }, - "else": [], - }, - }, - ], - }, - }, - }, - ], - }, - }, - }, - { - "$addFields": { - "_label_tags": { - "$function": { - "body": "function(items) {let counts = {};items && items.forEach((i) => {counts[i] = 1 + (counts[i] || 0);});return counts;}", - "args": ["$_label_tags"], - "lang": "js", - }, - }, - }, - }, - ] - - self.assertEqual(expected, returned) + filters = { + "dynamic.predictions.detections.label": { + "values": ["carrot"], + "exclude": False, + "isMatching": False, + }, + "dynamic.predictions.detections.confidence": { + "range": [0.5, 1], + "exclude": False, + "isMatching": False, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 0) + + filters = { + "dynamic.predictions.detections.label": { + "values": ["carrot"], + "exclude": False, + "isMatching": False, + }, + "dynamic.predictions.detections.confidence": { + "range": [0.0, 0.5], + "exclude": False, + "isMatching": False, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 1) + self.assertEqual(len(view.first().dynamic.predictions.detections), 1) + + filters = { + "dynamic.list_str": { + "values": ["one"], + "exclude": False, + }, + "dynamic.list_int": { + "range": [0, 2], + "exclude": False, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 1) + self.assertEqual(len(view.first().dynamic.list_str), 2) + self.assertEqual(len(view.first().dynamic.list_int), 2) + + filters = { + "dynamic.list_str": { + "values": ["empty"], + "exclude": False, + }, + "dynamic.list_int": { + "range": [0, 2], + "exclude": False, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 0) + + filters = { + "dynamic.list_str": { + "values": ["one"], + "exclude": False, + }, + "dynamic.list_int": { + "range": [3, 4], + "exclude": False, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 0) + + filters = { + "dynamic.list_bool": { + "true": False, + "false": True, + "exclude": False, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 0) + + filters = { + "dynamic.list_bool": { + "true": False, + "false": True, + "exclude": True, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 1) + + filters = { + "dynamic.list_bool": { + "true": True, + "false": False, + "exclude": False, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 1) + + view = fosv.get_view("test", pagination_data=True) + (sample,) = list( + foo.aggregate( + foo.get_db_conn()[view._dataset._sample_collection_name], + view._pipeline(), + ) + ) + self.assertIn("_label_tags", sample) + self.assertDictEqual(sample["_label_tags"], {"one": 1, "two": 2}) + + filters = { + "_label_tags": { + "values": ["two"], + "exclude": False, + "isMatching": True, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 1) + self.assertEqual(len(view.first().dynamic.predictions.detections), 2) - @drop_datasets - def test_extended_view_video_match_label_tags_samples(self): filters = { "_label_tags": { "values": ["one"], "exclude": False, "isMatching": False, - } + }, } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 1) + self.assertEqual(len(view.first().dynamic.predictions.detections), 1) + + view = fosv.get_view("test", pagination_data=True, filters=filters) + (sample,) = list( + foo.aggregate( + foo.get_db_conn()[view._dataset._sample_collection_name], + view._pipeline(), + ) + ) + self.assertIn("_label_tags", sample) + self.assertDictEqual(sample["_label_tags"], {"one": 1, "two": 1}) + filters = { + "_label_tags": { + "values": ["two"], + "exclude": True, + "isMatching": False, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 1) + self.assertEqual(len(view.first().dynamic.predictions.detections), 0) + + filters = { + "_label_tags": { + "values": ["one"], + "exclude": True, + "isMatching": False, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 1) + self.assertEqual( + view.first().dynamic.predictions.detections[0].label, "not_carrot" + ) + + view = fosv.get_view("test", pagination_data=True, filters=filters) + (sample,) = list( + foo.aggregate( + foo.get_db_conn()[view._dataset._sample_collection_name], + view._pipeline(), + ) + ) + self.assertIn("_label_tags", sample) + self.assertDictEqual(sample["_label_tags"], {"two": 1}) + + filters = { + "_label_tags": { + "values": ["one"], + "exclude": True, + "isMatching": True, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 0) + + filters = { + "dynamic_list.bool": { + "true": False, + "false": True, + "exclude": False, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 0) + + filters = { + "dynamic_list.list_bool": { + "true": False, + "false": True, + "exclude": False, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 0) + + filters = { + "dynamic_list.bool": { + "true": True, + "false": False, + "exclude": False, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 1) + + filters = { + "dynamic_list.list_bool": { + "true": True, + "false": False, + "exclude": False, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 1) + + filters = { + "dynamic_list.int": { + "range": [-1, 0], + "exclude": False, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 0) + + filters = { + "dynamic_list.list_int": { + "range": [-1, 0], + "exclude": False, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 0) + + filters = { + "dynamic_list.int": { + "range": [0, 1], + "exclude": False, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 1) + + filters = { + "dynamic_list.list_int": { + "range": [0, 1], + "exclude": False, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 1) + + filters = { + "dynamic_list.int": { + "range": [0, 2], + "exclude": True, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 0) + + filters = { + "dynamic_list.list_int": { + "range": [0, 2], + "exclude": True, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 0) + + @drop_datasets + def test_extended_keypoint_sample(self): dataset = fod.Dataset("test") - dataset.media_type = "video" - dataset.add_frame_field( - "detections", fof.EmbeddedDocumentField, fol.Detections + dataset.default_skeleton = fo.KeypointSkeleton( + labels=["top-left", "center", "bottom-right"], edges=[[0, 1, 2]] ) + sample = fos.Sample( + filepath="video.mp4", + keypoint=fo.Keypoint( + label="keypoint", + points=[[0, 0], [0.5, 0.5], [1, 1]], + confidence=[0, 0.5, 1], + dynamic=["one", "two", "three"], + tags=["keypoint"], + ), + keypoints=fo.Keypoints( + keypoints=[ + fo.Keypoint( + label="keypoint", + points=[[0, 0], [0.5, 0.5], [1, 1]], + confidence=[0, 0.5, 1], + dynamic=["one", "two", "three"], + tags=["keypoint"], + ) + ] + ), + ) + + dataset.add_sample(sample) + dataset.add_dynamic_sample_fields() + dataset.add_dynamic_frame_fields() - returned = fosv.get_view( - "test", filters=filters, count_label_tags=False - )._pipeline()[1:] - - expected = [ - { - "$match": { - "$or": [ - {"frames.detections.detections.tags": {"$in": ["one"]}} - ], - }, - }, - { - "$addFields": { - "frames": { - "$map": { - "input": "$frames", - "as": "frame", - "in": { - "$mergeObjects": [ - "$$frame", - { - "detections": { - "$mergeObjects": [ - "$$frame.detections", - { - "detections": { - "$filter": { - "input": "$$frame.detections.detections", - "cond": { - "$cond": { - "if": { - "$gt": [ - "$$this.tags", - None, - ], - }, - "then": { - "$in": [ - "one", - "$$this.tags", - ], - }, - "else": False, - }, - }, - }, - }, - }, - ], - }, - }, - ], - }, - }, - }, - }, - }, - { - "$match": { - "$expr": { - "$gt": [ - { - "$reduce": { - "input": "$frames", - "initialValue": 0, - "in": { - "$add": [ - "$$value", - { - "$size": { - "$ifNull": [ - "$$this.detections.detections", - [], - ], - }, - }, - ], - }, - }, - }, - 0, - ], - }, - }, - }, - ] - - self.assertEqual(expected, returned) + filters = { + "keypoint.label": { + "values": ["empty"], + "exclude": False, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 0) + + filters = { + "keypoint.label": { + "values": ["keypoint"], + "exclude": True, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 0) + + filters = { + "keypoint.points": { + "values": ["top-left"], + "exclude": False, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 1) + self.assertListEqual(view.first().keypoint.points[0], [0, 0]) + for point in view.first().keypoint.points[1:]: + self.assertTrue(math.isnan(point[0])) + self.assertTrue(math.isnan(point[1])) + + filters = { + "keypoint.points": { + "values": ["top-left"], + "exclude": False, + }, + } + view = fosv.get_view("test", filters=filters) + self.assertEqual(len(view), 1)