Unable to load FeedForward32Policy -- multiple values for keyword argument 'net_arch' #857

Open
glolichen opened this issue Aug 6, 2024 · 1 comment
Labels: bug (Something isn't working)

Bug description

An imitation.policies.base.FeedForward32Policy saved with policy.save() cannot be loaded with imitation.policies.base.FeedForward32Policy.load(). Loading raises a TypeError about multiple values for the keyword argument 'net_arch' (traceback below).

Steps to reproduce

Train a policy using imitation.algorithms.bc.BC, then save the trained policy.

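from imitation.algorithms import bc

# env, transitions, and rng are assumed to be defined earlier in the script.
bc_trainer = bc.BC(
    observation_space=env.observation_space,
    action_space=env.action_space,
    demonstrations=transitions,
    rng=rng,
)
bc_trainer.train(n_epochs=1)
bc_trainer.policy.save("policy.zip")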

This should save a FeedForward32Policy. Then, load the policy.

imitation.policies.base.FeedForward32Policy.load("policy.zip")

This raised the following exception:

  File "/home/jayden/Desktop/Programs/aithermostat/venv/lib/python3.8/site-packages/imitation/policies/base.py", line 104, in __init__
    super().__init__(*args, **kwargs, net_arch=[32, 32])
TypeError: __init__() got multiple values for keyword argument 'net_arch'
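
For context, the same failure pattern can be reproduced without imitation or Stable-Baselines3 installed. The sketch below assumes (not verified against every SB3 version) that save() stores the policy's constructor arguments, including net_arch, and that load() rebuilds the object by calling cls(**saved_kwargs). BasePolicy, Hardcoded32Policy, constructor_kwargs, and load_policy are hypothetical stand-ins, not the real library classes or methods.

```python
class BasePolicy:
    """Hypothetical stand-in for the parent policy class."""

    def __init__(self, observation_space, action_space, net_arch=None):
        self.observation_space = observation_space
        self.action_space = action_space
        self.net_arch = net_arch

    def constructor_kwargs(self):
        # Assumed behaviour of save(): record everything needed to rebuild
        # the policy, including net_arch.
        return {
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "net_arch": self.net_arch,
        }


class Hardcoded32Policy(BasePolicy):
    """Mirrors the call pattern shown in the traceback (base.py, line 104)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, net_arch=[32, 32])


def load_policy(cls, saved_kwargs):
    # Assumed behaviour of load(): rebuild the object from the saved kwargs.
    return cls(**saved_kwargs)


# Constructing the policy directly works, because kwargs has no net_arch yet.
policy = Hardcoded32Policy("obs_space", "act_space")
saved = policy.constructor_kwargs()

# Reloading fails: saved already contains net_arch, and __init__ adds it again.
try:
    load_policy(Hardcoded32Policy, saved)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(error_message)
```

If the assumption about save() and load() holds, this would explain why training and saving succeed while only the load step fails.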

Environment

  • Operating system and version: Linux
  • Python version: 3.8.19
  • Output of pip freeze --all:
absl-py==2.1.0
aiohappyeyeballs==2.3.4
aiohttp==3.10.0
aiosignal==1.3.1
ale-py==0.8.1
alembic==1.13.2
async-timeout==4.0.3
attrs==23.2.0
AutoROM==0.6.1
AutoROM.accept-rom-license==0.6.1
cachetools==5.4.0
certifi==2024.7.4
charset-normalizer==3.3.2
click==8.1.7
cloudpickle==3.0.0
colorama==0.4.6
colorlog==6.8.2
contourpy==1.1.1
cycler==0.12.1
datasets==2.20.0
dill==0.3.8
docopt==0.6.2
Farama-Notifications==0.0.4
filelock==3.15.4
fonttools==4.53.1
frozenlist==1.4.1
fsspec==2024.5.0
gitdb==4.0.11
GitPython==3.1.43
google-auth==2.32.0
google-auth-oauthlib==1.0.0
greenlet==3.0.3
grpcio==1.65.2
gymnasium==0.29.1
huggingface-hub==0.24.5
huggingface-sb3==3.0
idna==3.7
imitation==1.0.0
importlib_metadata==8.2.0
importlib_resources==6.4.0
Jinja2==3.1.4
joblib==1.4.2
jsonpickle==3.2.2
kiwisolver==1.4.5
Mako==1.3.5
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.7.5
mdurl==0.1.2
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
munch==4.0.0
networkx==3.1
numpy==1.24.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.6.20
nvidia-nvtx-cu12==12.1.105
oauthlib==3.2.2
opencv-python==4.10.0.84
optuna==3.6.1
packaging==24.1
pandas==2.0.3
pillow==10.4.0
pip==23.0.1
protobuf==5.27.3
psutil==6.0.0
py-cpuinfo==9.0.0
pyarrow==17.0.0
pyarrow-hotfix==0.6
pyasn1==0.6.0
pyasn1_modules==0.4.0
pygame==2.6.0
Pygments==2.18.0
pyparsing==3.1.2
python-dateutil==2.9.0.post0
pytz==2024.1
PyYAML==6.0.1
requests==2.32.3
requests-oauthlib==2.0.0
rich==13.7.1
rsa==4.9
sacred==0.8.5
scikit-learn==1.3.2
scipy==1.10.1
seals==0.2.1
setuptools==56.0.0
Shimmy==1.3.0
six==1.16.0
smmap==5.0.1
SQLAlchemy==2.0.31
stable_baselines3==2.3.2
sympy==1.13.1
tensorboard==2.14.0
tensorboard-data-server==0.7.2
threadpoolctl==3.5.0
torch==2.4.0
tqdm==4.66.4
triton==3.0.0
typing_extensions==4.12.2
tzdata==2024.1
urllib3==2.2.2
wasabi==1.1.3
Werkzeug==3.0.3
wheel==0.43.0
wrapt==1.16.0
xxhash==3.4.1
yarl==1.9.4
zipp==3.19.2

glolichen added the bug label on Aug 6, 2024

glolichen commented Aug 6, 2024

A temporary workaround I used is to remove the net_arch keyword argument at src/imitation/policies/base.py:104, changing

super().__init__(*args, **kwargs, net_arch=[32, 32])

to

super().__init__(*args, **kwargs)

This is not an intelligent fix, but it does appear to resolve the above issue. This change is made in #858.
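
One possible alternative, sketched below, keeps [32, 32] as the default for newly constructed policies while still accepting a net_arch passed in by load(). It uses dict.setdefault so the default is only applied when the caller did not supply one. This is an untested idea, not necessarily what #858 does; BasePolicy and FeedForward32PolicySketch are hypothetical stand-ins for the real classes.

```python
class BasePolicy:
    """Hypothetical stand-in for the parent policy class."""

    def __init__(self, observation_space, action_space, net_arch=None):
        self.observation_space = observation_space
        self.action_space = action_space
        self.net_arch = net_arch


class FeedForward32PolicySketch(BasePolicy):
    """Sketch: default to a 32x32 architecture without clobbering saved kwargs."""

    def __init__(self, *args, **kwargs):
        # Only fill in net_arch when it was not passed as a keyword argument.
        kwargs.setdefault("net_arch", [32, 32])
        super().__init__(*args, **kwargs)


# Fresh construction: the default architecture is applied.
fresh = FeedForward32PolicySketch("obs_space", "act_space")

# Reconstruction from saved kwargs that already include net_arch: no conflict.
reloaded = FeedForward32PolicySketch(
    observation_space="obs_space",
    action_space="act_space",
    net_arch=[32, 32],
)

print(fresh.net_arch, reloaded.net_arch)
```

A caveat with this approach: if net_arch were ever passed positionally rather than by keyword, setdefault would still add a second value and the same TypeError would come back.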
