Skip to content

Commit

Permalink
Fix lightning numeric bounds (#5125)
Browse files Browse the repository at this point in the history
* improve float tests

* lightning _first fix

* None handling
  • Loading branch information
benjaminpkane authored Nov 15, 2024
1 parent 7e0f101 commit e1e0a81
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 12 deletions.
21 changes: 16 additions & 5 deletions fiftyone/server/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,8 +394,15 @@ def _first(
if sort:
pipeline.append({"$match": {path: {"$ne": None}}})

pipeline.append({"$sort": {path: sort}})

return pipeline + [
{"$group": {"_id": {"$min" if sort == 1 else "$max": f"${path}"}}}
{
"$group": {
"_id": None,
"value": {"$min" if sort == 1 else "$max": f"${path}"},
}
}
]


Expand Down Expand Up @@ -458,7 +465,11 @@ def _match_arrays(dataset: fo.Dataset, path: str, is_frame_field: bool):

def _parse_result(data):
if data and data[0]:
return data[0].get("_id", None)
value = data[0]
if value.get("value", None) is not None:
return value["value"]

return value.get("_id", None)

return None

Expand All @@ -468,13 +479,13 @@ def _unwind(dataset: fo.Dataset, path: str, is_frame_field: bool):
path = None
pipeline = []

prefix = ""
if is_frame_field:
path = keys[0]
keys = keys[1:]
prefix = "frames."

for key in keys:
path = ".".join([path, key]) if path else key
field = dataset.get_field(path)
field = dataset.get_field(f"{prefix}{path}")
while isinstance(field, fof.ListField):
pipeline.append({"$unwind": f"${path}"})
field = field.field
Expand Down
14 changes: 7 additions & 7 deletions tests/unittests/lightning_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,13 +488,13 @@ async def test_floats(self, dataset: fo.Dataset):
dataset,
dict(
float=-1.0,
float_list=[-1.0],
float_list=[0.0, -1.0],
inf=-1.0,
inf_list=[-1.0],
inf_list=[0.0, -1.0],
nan=-1.0,
nan_list=[-1.0],
nan_list=[0.0, -1.0],
ninf=-1.0,
ninf_list=[-1.0],
ninf_list=[0.0, -1.0],
),
dict(
float=0.0,
Expand All @@ -508,13 +508,13 @@ async def test_floats(self, dataset: fo.Dataset):
),
dict(
float=1.0,
float_list=[1.0],
float_list=[0.0, 1.0],
inf=1.0,
inf_list=[1.0],
nan=1.0,
nan_list=[1.0],
nan_list=[0.0, 1.0],
ninf=1.0,
ninf_list=[1.0],
ninf_list=[0.0, 1.0],
),
)

Expand Down

0 comments on commit e1e0a81

Please sign in to comment.