Skip to content

Commit

Permalink
fix: set torch to v1.9.0
Browse files Browse the repository at this point in the history
  • Loading branch information
billsioros committed Jul 11, 2023
1 parent 43183e4 commit 2b096c2
Show file tree
Hide file tree
Showing 9 changed files with 70 additions and 71 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,11 @@ repos:
hooks:
- id: yamlfmt
- repo: https://github.com/nbQA-dev/nbQA
rev: 1.3.1
rev: 1.7.0
hooks:
- id: nbqa-black
- id: nbqa-check-ast
- id: nbqa-flake8
- id: nbqa-ruff
- id: nbqa-isort
- id: nbqa-mypy
- repo: https://github.com/roy-ht/pre-commit-jupyter
Expand Down
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## v4.2.2 (2021-08-03)
### Fix
* Call absoluate on vector based content loss `__call__` ([`99fd205`](https://github.com/billsioros/thesis/commit/99fd20593901454b331c71b6ca0b0942e1d1aae0))
* Call absolute on vector based content loss `__call__` ([`99fd205`](https://github.com/billsioros/thesis/commit/99fd20593901454b331c71b6ca0b0942e1d1aae0))

**[See all commits in this version](https://github.com/billsioros/thesis/compare/v4.2.1...v4.2.2)**

Expand Down Expand Up @@ -177,7 +177,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## v3.1.0 (2021-07-15)
### Feature
* Separatelly log/plot `content_loss` ([`fc16be6`](https://github.com/billsioros/thesis/commit/fc16be6142c7435d900ef1538295c9826ede3d60))
* Separately log/plot `content_loss` ([`fc16be6`](https://github.com/billsioros/thesis/commit/fc16be6142c7435d900ef1538295c9826ede3d60))
* Instantiate models per dataset ([`f7bee53`](https://github.com/billsioros/thesis/commit/f7bee5310ed03ea9f3cbee64361d8b26d1ec5798))
* Optionally suppress exceptions on training flow ([`122a055`](https://github.com/billsioros/thesis/commit/122a0556742b85a579576a6cf96b28ac755580d5))
* Separate limits for dataset/surface loading ([`4754d72`](https://github.com/billsioros/thesis/commit/4754d72ec3383192fb5ceb6385ee8a6daeb2f5d3))
Expand Down
3 changes: 1 addition & 2 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ nav:
- Example: src/roughgan.ipynb
- Code Reference:
- CLI:
- Benchmark: src/cli/benchmark.md
- Benchmark: src/cli/benchmark.md
- Contributing:
- Contributing Guidelines: CONTRIBUTING.md
- Code Of Conduct: CODE_OF_CONDUCT.md
Expand Down Expand Up @@ -83,7 +83,6 @@ plugins:
minify_html: true
- mkdocs-jupyter:
ignore_h1_titles: True

markdown_extensions:
- abbr
- admonition
Expand Down
101 changes: 51 additions & 50 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ ray = ">=2.5.1"
click = ">=8.1.3"
rich = ">=13.4.2"
pyinsect = {git = "https://github.com/ggianna/PyINSECT"}
torch = "1.13.1"
torch = "1.9.0"
torchvision = "0.10.0"
kaleido = "0.2.1"

Expand Down Expand Up @@ -145,7 +145,7 @@ paths = ["src/roughgan", "tests"]

[tool.poe.tasks.lint]
help = "Lint your code for errors"
cmd = "poetry run flake8 ."
cmd = "poetry run ruff ."

[tool.poe.tasks.security]
help = "Run security checks on your application"
Expand Down
10 changes: 5 additions & 5 deletions src/roughgan/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

from datetime import datetime

training_callback = None

Expand All @@ -83,18 +82,21 @@ def logging_callback(config, logging_dir):

return config


from roughgan.models import PerceptronGenerator


def get_generator():
return PerceptronGenerator.from_device(device)


from roughgan.models import PerceptronDiscriminator


def get_discriminator(generator):
return PerceptronDiscriminator.from_generator(generator)


from torch.nn import BCELoss

criterion = BCELoss().to(device)
Expand All @@ -114,7 +116,6 @@ def get_discriminator(generator):
training={
"manager": {
"benchmark": True,

"train_epoch": per_epoch,
"log_every_n": 10,
"criterion": {"instance": criterion},
Expand All @@ -136,7 +137,6 @@ def get_discriminator(generator):
},
content_loss={
"type": NGramGraphContentLoss,

},
data={
"loader": functools.partial(
Expand Down Expand Up @@ -165,12 +165,14 @@ def get_discriminator(generator):
def get_generator():
return CNNGenerator.from_device(device)


from roughgan.models import CNNDiscriminator


def get_discriminator(generator):
return CNNDiscriminator.from_generator(generator)


from torch.nn import BCELoss

criterion = BCELoss().to(device)
Expand All @@ -188,7 +190,6 @@ def get_discriminator(generator):
training={
"manager": {
"benchmark": True,

"train_epoch": per_epoch,
"log_every_n": 10,
"criterion": {"instance": criterion},
Expand All @@ -210,7 +211,6 @@ def get_discriminator(generator):
},
content_loss={
"type": ArrayGraph2DContentLoss,

},
data={
"loader": functools.partial(
Expand Down
Loading

0 comments on commit 2b096c2

Please sign in to comment.