Skip to content

Commit

Permalink
[CLI-941] fix floating point issue (#2257)
Browse files Browse the repository at this point in the history
  • Loading branch information
dannygoldstein authored Jun 7, 2021
1 parent a09a81e commit fa47ec9
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
24 changes: 23 additions & 1 deletion tests/wandb_run_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

import os
import sys
import numpy as np
import platform
import pytest
import yaml
import wandb
from wandb import wandb_sdk
from wandb.proto.wandb_internal_pb2 import RunPreemptingRecord
Expand Down Expand Up @@ -60,6 +61,27 @@ def test_run_pub_history(fake_run, record_q, records_util):
# TODO(jhr): check history vals


@pytest.mark.skipif(
platform.system() == "Windows", reason="numpy.float128 does not exist on windows"
)
def test_numpy_high_precision_float_downcasting(fake_run, record_q, records_util):
# CLI: GH2255
run = fake_run()
run.log(dict(this=np.float128(0.0)))
r = records_util(record_q)
assert len(r.records) == 1
assert len(r.summary) == 0
history = r.history
assert len(history) == 1

found = False
for item in history[0].item:
if item.key == "this":
assert item.value_json == "0.0"
found = True
assert found


def test_log_code_settings(live_mock_server, test_settings):
with open("test.py", "w") as f:
f.write('print("test")')
Expand Down
7 changes: 7 additions & 0 deletions wandb/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,13 @@ def json_friendly(obj):
obj = obj.item()
if isinstance(obj, float) and math.isnan(obj):
obj = None
elif isinstance(obj, np.generic) and obj.dtype.kind == "f":
# obj is a numpy float with precision greater than that of native python float
# (i.e., float96 or float128). in this case obj.item() does not return a native
# python float to avoid loss of precision, so we need to explicitly cast this
# down to a 64bit float
obj = float(obj)

elif isinstance(obj, bytes):
obj = obj.decode("utf-8")
elif isinstance(obj, (datetime, date)):
Expand Down

0 comments on commit fa47ec9

Please sign in to comment.