Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a compatibility issue, add requirements.txt, and automate tests. #10

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ If you use this code in your research please consider citing

- Hardware: PC with Tesla-V100.
- Software: *CUDA >= 10.0*, *Anaconda3*, *pytorch >= 1.0.0*
- The file requirements.txt was generated after running tests in AWS SageMaker using a ml.g4dn.xlarge instance with the kernel conda_pytorch_39

### Download Dataset

Expand All @@ -20,7 +21,14 @@ If you use this code in your research please consider citing
Please merge the dataset and the label into the same folder

### Evaluate DRT
#### You can evaluate the model automatically by running the script below.

```
chmod +x test.sh
./test.sh
```

#### Manual test
The pre-trained models are provided- [Clipart](https://drive.google.com/file/d/1mh1jpUWQrginSACZvZDmtyYeh-TZUxBS/view?usp=sharing), [Infograph](https://drive.google.com/file/d/16zmGRRnXwsTMgj2-RKhwWdaOLXkozXMl/view?usp=sharing), [Painting](https://drive.google.com/file/d/15YhOjPjuutHrcK-m511OERu_4vIVYArD/view?usp=sharing), [Quickdraw](https://drive.google.com/file/d/1O4JwTDudqT1aj2VfFxgU1ld7bk0Hlcth/view?usp=sharing), [Real](https://drive.google.com/file/d/1ygMj4nJU74qywMbdq2DvQyyZZHngBD-3/view?usp=sharing), [Sketch](https://drive.google.com/file/d/1FVNy6OVkptKCL6rp7SqRlrZ5aYM-77vy/view?usp=sharing). Here we use 'Clipart' as an example. If you want to test other domains, all you need to do is just to replace the name of the dataset.

```
Expand Down
6 changes: 3 additions & 3 deletions drt.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,21 +92,21 @@ def print_options(save_path, opt):

data_transforms = {
src_path: transforms.Compose([
transforms.Scale(256),
transforms.Resize(256),
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
trg_path: transforms.Compose([
transforms.Scale(256),
transforms.Resize(256),
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
val_path: transforms.Compose([
transforms.Scale(256),
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
Expand Down
Loading