Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion keras/src/backend/common/name_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ def __exit__(self, *args, **kwargs):
name_scope_stack = global_state.get_global_attribute(
"name_scope_stack"
)
name_scope_stack.pop()
if name_scope_stack:
name_scope_stack.pop()


def current_path():
Expand Down
85 changes: 85 additions & 0 deletions keras/src/backend/common/name_scope_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import threading

from keras.src import testing
from keras.src.backend.common import global_state
from keras.src.backend.common.name_scope import current_path
from keras.src.backend.common.name_scope import name_scope

Expand Down Expand Up @@ -46,3 +49,85 @@ def test_override_parent(self):
current_path(), "absolute/path/middle/inner"
)
self.assertEqual(current_path(), "outer")

def test_exit_with_none_stack(self):
"""Test that __exit__ handles None name_scope_stack gracefully."""
# Create a name_scope instance
scope = name_scope("test")
# Enter the scope normally
scope.__enter__()

# Simulate the scenario where global state is cleared
# (e.g., in a different thread)
global_state.set_global_attribute("name_scope_stack", None)

# Exit should not raise an AttributeError
scope.__exit__()

# Clean up: reset the stack
global_state.set_global_attribute("name_scope_stack", [])

def test_exit_with_empty_stack(self):
"""Test that __exit__ handles empty name_scope_stack gracefully."""
# Create a name_scope instance
scope = name_scope("test")
# Enter the scope normally
scope.__enter__()

# Simulate the scenario where the stack is cleared
name_scope_stack = global_state.get_global_attribute("name_scope_stack")
name_scope_stack.clear()

# Exit should not raise an IndexError
scope.__exit__()

# Verify stack is still empty
name_scope_stack = global_state.get_global_attribute(
"name_scope_stack", default=[]
)
self.assertEqual(len(name_scope_stack), 0)

def test_multithreaded_name_scope(self):
"""Test name_scope in multithreaded environment."""
results = []

def thread_function(thread_id):
# Each thread should have its own name_scope_stack
with name_scope(f"thread_{thread_id}"):
path = current_path()
results.append(path)
# Verify we get the expected path
self.assertEqual(path, f"thread_{thread_id}")

# Create and start multiple threads
threads = []
for i in range(5):
thread = threading.Thread(target=thread_function, args=(i,))
threads.append(thread)
thread.start()

# Wait for all threads to complete
for thread in threads:
thread.join()

# Verify all threads executed successfully
self.assertEqual(len(results), 5)

def test_exit_without_pop_on_exit(self):
"""Test that __exit__ respects _pop_on_exit flag."""
# Create a name_scope but don't enter it
scope = name_scope("test")
# _pop_on_exit should be False
self.assertFalse(scope._pop_on_exit)

# Set up a stack manually
global_state.set_global_attribute("name_scope_stack", [scope])

scope.__exit__()

# Verify the stack still contains the scope
name_scope_stack = global_state.get_global_attribute("name_scope_stack")
self.assertEqual(len(name_scope_stack), 1)

# Clean up
global_state.set_global_attribute("name_scope_stack", [])