Skip to content

Commit

Permalink
clean commit
Browse files Browse the repository at this point in the history
  • Loading branch information
nono-Sang committed May 15, 2024
1 parent f5168ae commit a1dfe3f
Showing 1 changed file with 40 additions and 47 deletions.
87 changes: 40 additions & 47 deletions tests/convert_torch_to_of/test_torch2of_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
Install:
pip install pytest
Uasge:
python -m pytest diffusers/tests/torch_to_oflow/test_torch2of_demo.py
python -m pytest test_torch2of_demo.py
"""
import torch
import oneflow as flow
import unittest
import numpy as np
from onediff.infer_compiler import oneflow_compile
from onediff.infer_compiler.transform import transform_mgr, get_mock_cls_name
from onediff.infer_compiler.transform import transform_mgr


class PyTorchModel(torch.nn.Module):
Expand Down Expand Up @@ -37,48 +39,39 @@ def forward(self, x):

def apply_model(self, x):
return self.forward(x)


def test_torch2of_demo():
# Register PyTorch model to OneDiff
cls_key = get_mock_cls_name(PyTorchModel)
transform_mgr.update_class_proxies({cls_key: OneFlowModel})

# Compile PyTorch model to OneFlow
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pytorch_model = PyTorchModel().to(device)
of_model = oneflow_compile(pytorch_model)

# Verify conversion
x = torch.randn(4, 4).to(device)

#### 1. Use apply_model method
y_pt = pytorch_model.apply_model(x)
y_of = of_model.apply_model(x)

#### 2. Use __call__ method
y_pt = pytorch_model(x)
y_of = of_model(x)

print(
f"""
PyTorch output: {type(y_pt)}
{y_pt}
OneFlow output: {type(y_of)}
{y_of}
"""
)

"""Output:
PyTorch output: <class 'torch.Tensor'>
tensor([[ 0.3281, 0.3748, 0.1928, -0.0843],
[-0.3299, 1.1503, 0.0115, -0.2054],
[ 0.1292, -0.1496, 1.1743, -0.4726],
[ 0.2691, 0.2332, 0.1492, 0.5211]], device='cuda:0',
grad_fn=<AddmmBackward0>)
OneFlow output: <class 'torch.Tensor'>
tensor([[ 0.3281, 0.3748, 0.1928, -0.0843],
[-0.3299, 1.1503, 0.0115, -0.2054],
[ 0.1292, -0.1496, 1.1743, -0.4726],
[ 0.2691, 0.2332, 0.1492, 0.5211]], device='cuda:0')
"""


class TestTorch2ofDemo(unittest.TestCase):
def judge_tensor_func(self, y_pt, y_of):
assert type(y_pt) == type(y_of)
assert y_pt.device == y_of.device
y_pt = y_pt.cpu().detach().numpy()
y_of = y_of.cpu().detach().numpy()
assert np.allclose(y_pt, y_of, atol=1e-3, rtol=1e-3)

def test_torch2of_demo(self):
# Register PyTorch model to OneDiff
cls_key = transform_mgr.get_transformed_entity_name(PyTorchModel)
transform_mgr.update_class_proxies({cls_key: OneFlowModel})

# Compile PyTorch model to OneFlow
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pytorch_model = PyTorchModel().to(device)
of_model = oneflow_compile(pytorch_model)

# Verify conversion
x = torch.randn(4, 4).to(device)

#### 1. Use apply_model method
y_pt = pytorch_model.apply_model(x)
y_of = of_model.apply_model(x)
self.judge_tensor_func(y_pt, y_of)

#### 2. Use __call__ method
y_pt = pytorch_model(x)
y_of = of_model(x)
self.judge_tensor_func(y_pt, y_of)


if __name__ == "__main__":
unittest.main()

0 comments on commit a1dfe3f

Please sign in to comment.