Skip to content

Commit

Permalink
Fix bug in reload mode equality check. Better equality conversion for…
Browse files Browse the repository at this point in the history
… state variables (#8385)

* Add code

* Add deep equality

* add changeset

* Add code

* add changeset

* Update gradio/utils.py

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* Add code

* Add code

* add code

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
  • Loading branch information
3 people authored May 29, 2024
1 parent e738e26 commit 97ac79b
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 3 deletions.
5 changes: 5 additions & 0 deletions .changeset/ripe-tools-jam.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"gradio": patch
---

fix:Fix bug in reload mode equality check. Better equality conversion for state variables
10 changes: 9 additions & 1 deletion gradio/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1726,7 +1726,15 @@ async def postprocess_data(

if block.stateful:
if not utils.is_update(predictions[i]):
if block._id not in state or state[block._id] != predictions[i]:
has_change_event = False
for dep in state.blocks_config.fns.values():
if block._id in [t[0] for t in dep.targets if t[1] == "change"]:
has_change_event = True
break
if has_change_event and (
block._id not in state
or not utils.deep_equal(state[block._id], predictions[i])
):
changed_state_ids.append(block._id)
state[block._id] = predictions[i]
output.append(None)
Expand Down
30 changes: 28 additions & 2 deletions gradio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import anyio
import gradio_client.utils as client_utils
import httpx
import orjson
from gradio_client.documentation import document
from typing_extensions import ParamSpec

Expand Down Expand Up @@ -290,6 +291,29 @@ def iter_py_files() -> Iterator[Path]:
time.sleep(0.05)


def deep_equal(a: Any, b: Any) -> bool:
"""
Deep equality check for component values.
Prefer orjson for performance and compatibility with numpy arrays/dataframes/torch tensors.
If objects are not serializable by orjson, fall back to regular equality check.
"""

def _serialize(a: Any) -> bytes:
return orjson.dumps(
a,
option=orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_PASSTHROUGH_DATETIME,
)

try:
return _serialize(a) == _serialize(b)
except TypeError:
try:
return a == b
except Exception:
return False


def reassign_keys(old_blocks: Blocks, new_blocks: Blocks):
from gradio.blocks import BlockContext

Expand All @@ -310,8 +334,10 @@ def reassign_context_keys(
old_block.__class__ == new_block.__class__
and old_block is not None
and old_block.key not in assigned_keys
and json.dumps(getattr(old_block, "value", None))
== json.dumps(getattr(new_block, "value", None))
and deep_equal(
getattr(old_block, "value", None),
getattr(new_block, "value", None),
)
):
new_block.key = old_block.key
else:
Expand Down

0 comments on commit 97ac79b

Please sign in to comment.