Skip to content

Commit

Permalink
improve flow.load for better error message (#10138)
Browse files Browse the repository at this point in the history
https://github.com/Oneflow-Inc/diffusers/issues/174#issuecomment-1500797463
这种情况提供更准确的报错信息 (Provide a more accurate error message for this case)

---------

Signed-off-by: daquexian <daquexian566@gmail.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Apr 23, 2023
1 parent 93ae4de commit 57b0d52
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 35 deletions.
70 changes: 41 additions & 29 deletions python/oneflow/framework/check_point_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
MAP_LOCATION: TypeAlias = Optional[
Union[Callable[[Tensor, str], Tensor], flow.device, str, flow.placement]
]
FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes], Path]
FILE_LIKE: TypeAlias = Union[os.PathLike, BinaryIO, IO[bytes], Path]


class _opener(object):
Expand Down Expand Up @@ -107,7 +107,7 @@ def _open_file_like(path_or_buffer, mode):


def _is_path(path_or_buffer):
return isinstance(path_or_buffer, str) or isinstance(path_or_buffer, Path)
return isinstance(path_or_buffer, Path)


def _check_seekable(f) -> bool:
Expand Down Expand Up @@ -475,9 +475,8 @@ def tensor_pickling_context(


def is_oneflow_pickle_file(path: FILE_LIKE, support_pytorch_format: bool) -> bool:
if _is_path(path):
if not path.is_file():
return False
if _is_path(path) and not path.is_file():
return False
try:
with _open_file_like(path, "rb") as f:
content = pickle.load(f)
Expand Down Expand Up @@ -517,20 +516,27 @@ def load_from_oneflow_single_file(
def is_file_and_support_pytorch_format(
path: FILE_LIKE, support_pytorch_format: bool
) -> bool:
if _is_path(path):
return path.is_file() and support_pytorch_format
return support_pytorch_format
if not support_pytorch_format:
return False
if _is_path(path) and not path.is_file():
return False
try:
with flow.mock_torch.disable():
import torch

content = torch.load(path, map_location="cpu")
return True, (content,)
except:
return False


@load_if(is_file_and_support_pytorch_format)
def load_from_pytorch_file(
path: FILE_LIKE, global_src_rank, map_location: MAP_LOCATION,
path: FILE_LIKE, global_src_rank, map_location: MAP_LOCATION, torch_obj: Any = None
):
with flow.mock_torch.disable():
import torch

if global_src_rank is None or global_src_rank == flow.env.get_rank():
torch_obj = torch.load(path, map_location="cpu")
if torch_obj is not None:
with flow.mock_torch.disable():
import torch

def torch_tensor_to_flow(x):
if isinstance(x, torch.Tensor):
Expand All @@ -539,17 +545,17 @@ def torch_tensor_to_flow(x):
return x

flow_obj = ArgsTree(torch_obj).map_leaf(torch_tensor_to_flow)
else:
flow_obj = None
if global_src_rank is not None:
flow_obj = flow.utils.global_view.to_global(
flow_obj,
placement=flow.placement("cpu", [global_src_rank]),
sbp=flow.sbp.broadcast,
warn_on_non_tensor_leaf=False,
)
flow_obj = _map_location(flow_obj, map_location)
return flow_obj
else:
flow_obj = None
if global_src_rank is not None:
flow_obj = flow.utils.global_view.to_global(
flow_obj,
placement=flow.placement("cpu", [global_src_rank]),
sbp=flow.sbp.broadcast,
warn_on_non_tensor_leaf=False,
)
flow_obj = _map_location(flow_obj, map_location)
return flow_obj


def is_dir_and_has_pickle_file(path: FILE_LIKE, support_pytorch_format: bool) -> bool:
Expand Down Expand Up @@ -585,7 +591,7 @@ def load_from_oneflow_pickle_dir(


def load(
path: FILE_LIKE,
path: Union[FILE_LIKE, str],
global_src_rank: Optional[int] = None,
map_location: MAP_LOCATION = None,
*,
Expand All @@ -611,7 +617,7 @@ def load(
The loaded object
"""
if isinstance(path, str):
path: Path = Path(path)
path = Path(path)
rank = flow.env.get_rank()
if global_src_rank is None or global_src_rank == rank:
for i, (condition, load) in enumerate(load_methods):
Expand All @@ -621,7 +627,11 @@ def load(
_broadcast_py_object(i, global_src_rank)
break
else:
raise NotImplementedError("No valid load method found for {}".format(path))
if _is_path(path):
err_msg = f'Cannot load file "{path}"'
else:
err_msg = "Cannot load the data"
raise ValueError(err_msg)
else:
i = _broadcast_py_object(None, global_src_rank)
load = load_methods[i][1]
Expand Down Expand Up @@ -687,6 +697,9 @@ def save(
save_as_external_data (bool): useful only if path_or_buffer is a string or
os.PathLike object containing a file name
"""
if isinstance(path_or_buffer, str):
path_or_buffer = Path(path_or_buffer)

if isinstance(obj, graph_util.Graph):
if not _is_path(path_or_buffer):
raise ValueError(
Expand All @@ -703,7 +716,6 @@ def save(
pickled_bytes = pickle.dumps(obj)

if _is_path(path_or_buffer) and save_as_external_data:
path_or_buffer: Path = Path(path_or_buffer)
path_or_buffer.mkdir(exist_ok=True)
path_or_buffer = path_or_buffer / PICKLE_FILENAME

Expand Down
7 changes: 5 additions & 2 deletions python/oneflow/mock_torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,6 @@ def __init__(self):
self.globals = globals
self.lazy = _importer.lazy
self.verbose = _importer.verbose
self.from_cli = _importer.from_cli
_importer._disable(globals)

def __enter__(self):
Expand All @@ -305,7 +304,11 @@ def __enter__(self):
def __exit__(self, exception_type, exception_value, traceback):
if self.enable:
_importer._enable(
self.globals, self.lazy, self.verbose, from_cli=self.from_cli
# When re-enabling mock torch, from_cli should always be False
self.globals,
self.lazy,
self.verbose,
from_cli=False,
)


Expand Down
20 changes: 16 additions & 4 deletions python/oneflow/test/exceptions/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,28 @@
import torch


@flow.unittest.skip_unless_1n1d()
class TestSaveLoad(flow.unittest.TestCase):
@flow.unittest.skip_unless_1n1d()
def test_support_pytorch_with_global_src_rank(test_case):
conv_torch = torch.nn.Conv2d(3, 3, 3)
conv_flow = flow.nn.Conv2d(3, 3, 3)
with tempfile.NamedTemporaryFile() as f:
torch.save(conv_torch.state_dict(), f.name)
with test_case.assertRaises(NotImplementedError) as ctx:
conv_flow.load_state_dict(flow.load(f.name, support_pytorch=False))
test_case.assertTrue("No valid load method found" in str(ctx.exception))
with test_case.assertRaises(ValueError) as ctx:
conv_flow.load_state_dict(
flow.load(f.name, support_pytorch_format=False)
)
test_case.assertTrue("Cannot load file" in str(ctx.exception))

def test_load_invalid_file(test_case):
f = tempfile.NamedTemporaryFile()
f.write(b"invalid file")
f.flush()
with test_case.assertRaises(ValueError) as ctx:
flow.load(f.name)
test_case.assertTrue("Cannot load file" in str(ctx.exception))

f.close()


if __name__ == "__main__":
Expand Down

0 comments on commit 57b0d52

Please sign in to comment.