-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Pull request #5: Quantization Helper & pyproject improvement
Merge in CSID/glaucus from feature/toml-and-utils to main Squashed commit of the following: commit 631df6e10625bdef6580e0da333d0b43704e34f2 Author: Kyle A Logue <kyle.a.logue@aero.org> Date: Thu Feb 8 15:57:03 2024 -0800 Quantization Helper & pyproject improvement * move all configuration into pyproject.toml * add function to adapt quantized weights to non-quantized model * increment to v1.1.4
- Loading branch information
Kyle A Logue
committed
Feb 12, 2024
1 parent
dda82f2
commit 48910a0
Showing
11 changed files
with
204 additions
and
103 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
'''utilities''' | ||
# Copyright 2023 The Aerospace Corporation | ||
# This file is a part of Glaucus | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
|
||
import copy | ||
import re | ||
|
||
|
||
def adapt_glaucus_quantized_weights(state_dict: dict) -> dict: | ||
""" | ||
The pretrained Glaucus models have a quantization layer that shifts the | ||
encoder list positions, so if we create a model w/o quantization we have to | ||
shift those layers slightly to make the pretrained model work. | ||
This function decrements the position of the decoder layers in the state | ||
dict to allow loading from a pre-trained model that was quantization aware. | ||
ie: `fc_decoder._fc.1.weight` becomes `fc_decoder._fc.0.weight` | ||
There will be extra layers remaining, but we can discard them by loading | ||
with `strict=False`. See the README for an example. | ||
Parameters | ||
---------- | ||
state_dict : dict | ||
Torch state dictionary including quantization layers. | ||
Returns | ||
------- | ||
new_state_dict : dict | ||
State dictionary without quantization layers. | ||
""" | ||
new_state_dict = copy.deepcopy(state_dict) | ||
|
||
pattern = r"(fc_decoder._fc.)(\d+)(\.\w+)" # regex pattern | ||
|
||
for key, value in state_dict.items(): | ||
match = re.match(pattern, key) | ||
if match: | ||
extracted_int = int(match.group(2)) | ||
new_key = f"{match.group(1)}{extracted_int-1}{match.group(3)}" | ||
new_state_dict[new_key] = value | ||
return new_state_dict |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
[project] | ||
name = "glaucus" | ||
description = "Glaucus is a PyTorch complex-valued ML autoencoder & RF estimation python module. " | ||
keywords = ["dsp", "ml", "autoencoder", "sigint", "rf"] | ||
classifiers = [ | ||
"License :: OSI Approved :: GNU Lesser General Public License v3 or later (LGPLv3+)", | ||
"Operating System :: OS Independent", | ||
"Programming Language :: Python :: 3", | ||
"Programming Language :: Python :: 3.8", | ||
"Programming Language :: Python :: 3.9", | ||
"Programming Language :: Python :: 3.10", | ||
"Programming Language :: Python :: 3.11", | ||
"Programming Language :: Python :: 3.12", | ||
] | ||
dynamic = ["version", "readme"] | ||
authors = [ | ||
{name = "Kyle Logue", email = "kyle.logue@aero.org"} | ||
] | ||
requires-python = ">=3.8" | ||
dependencies = [ | ||
"torch", # basic ML framework | ||
"lightning", # extensions for PyTorch | ||
"madgrad", # our favorite optimizer | ||
"hypothesis", # best unit testing | ||
] | ||
[project.urls] | ||
repository = "https://github.com/the-aerospace-corporation/glaucus" | ||
|
||
[tool.setuptools] | ||
packages = ["glaucus"] | ||
[tool.setuptools.dynamic] | ||
version = {attr = "glaucus.__version__"} | ||
readme = {file = ["README.md"], content-type = "text/markdown"} | ||
|
||
[build-system] | ||
requires = ["setuptools>=65.0", "setuptools-scm"] | ||
build-backend = "setuptools.build_meta" | ||
|
||
[tool.coverage.run] | ||
branch = true | ||
source = ["glaucus", "tests"] | ||
# -rA captures stdout from all tests and places it after the pytest summary | ||
command_line = "-m pytest -rA --doctest-modules --junitxml=pytest.xml" | ||
|
||
[tool.pytest.ini_options] | ||
addopts = "--doctest-modules" | ||
testpaths = ["glaucus", "tests"] | ||
|
||
[tool.pylint] | ||
[tool.pylint.main] | ||
load-plugins = [ | ||
"pylint.extensions.typing", | ||
"pylint.extensions.docparams", | ||
] | ||
exit-zero = true | ||
[tool.pylint.messages_control] | ||
disable = [ | ||
"logging-not-lazy", | ||
"missing-module-docstring", | ||
"import-error", | ||
"unspecified-encoding", | ||
] | ||
max-line-length = 160 | ||
[tool.pylint.REPORTS] | ||
# omit from the similarity reports | ||
ignore-comments = "yes" | ||
ignore-docstrings = "yes" | ||
ignore-imports = "yes" | ||
ignore-signatures = "yes" | ||
min-similarity-lines = 4 | ||
|
||
[tool.pytype] | ||
inputs = ["glaucus", "tests"] | ||
|
||
[tool.black] | ||
line-length = 160 |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.