diff --git a/.circleci/config.yml b/.circleci/config.yml index a388b54d..2657ed55 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -179,7 +179,20 @@ jobs: shell: bash -leo pipefail command: | make test-nb-customization - + test-nb-pretraining: + executor: python-executor + steps: + - checkout + # Download and cache dependencies + - restore_cache: + keys: + - v1-dependencies-{{ checksum "poetry.lock" }} + - install_poetry + - run: + name: run test-nb-pretraining + shell: bash -leo pipefail + command: | + make test-nb-pretraining workflows: version: 2 CI-tabnet: @@ -200,8 +213,14 @@ workflows: requires: - install - test-nb-multi-task: - requires: - - install + requires: + - install + - test-nb-customization: + requires: + - install + - test-nb-pretraining: + requires: + - install - lint-code: requires: - install diff --git a/Makefile b/Makefile index 953bfcf1..2fed8dd9 100644 --- a/Makefile +++ b/Makefile @@ -81,27 +81,31 @@ doc: build ## Build and generate docs test-nb-census: ## run census income tests using notebooks $(MAKE) _run_notebook NB_FILE="./census_example.ipynb" -.PHONY: test-obfuscator +.PHONY: test-nb-census test-nb-forest: ## run census income tests using notebooks $(MAKE) _run_notebook NB_FILE="./forest_example.ipynb" -.PHONY: test-obfuscator +.PHONY: test-nb-forest test-nb-regression: ## run regression example tests using notebooks $(MAKE) _run_notebook NB_FILE="./regression_example.ipynb" -.PHONY: test-obfuscator +.PHONY: test-nb-regression test-nb-multi-regression: ## run multi regression example tests using notebooks $(MAKE) _run_notebook NB_FILE="./multi_regression_example.ipynb" -.PHONY: test-obfuscator +.PHONY: test-nb-multi-regression test-nb-multi-task: ## run multi task classification example tests using notebooks $(MAKE) _run_notebook NB_FILE="./multi_task_example.ipynb" -.PHONY: test-obfuscator +.PHONY: test-nb-multi-task test-nb-customization: ## run customization example tests using notebooks $(MAKE) _run_notebook NB_FILE="./customizing_example.ipynb" -.PHONY: test-obfuscator +.PHONY: test-nb-customization + +test-nb-pretraining: ## run customization example tests using notebooks + $(MAKE) _run_notebook NB_FILE="./pretraining_example.ipynb" +.PHONY: test-nb-pretraining help: ## Display help @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' diff --git a/pytorch_tabnet/pretraining.py b/pytorch_tabnet/pretraining.py index 9c849277..d4ca49e9 100644 --- a/pytorch_tabnet/pretraining.py +++ b/pytorch_tabnet/pretraining.py @@ -40,6 +40,7 @@ def update_fit_params( ): self.updated_weights = weights filter_weights(self.updated_weights) + self.preds_mapper = None def fit( self,