From 3fb4393b9c3bca419ce3e3f7beab35b8741530a2 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 23 Sep 2024 11:43:19 +0100 Subject: [PATCH] Update [ghstack-poisoned] --- .github/workflows/build-wheels-windows.yml | 8 ++++++-- .github/workflows/wheels-windows.yml | 6 +++--- tensordict/base.py | 10 ++++++++++ 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/.github/workflows/build-wheels-windows.yml b/.github/workflows/build-wheels-windows.yml index 876f3bddb..fd76d0a5a 100644 --- a/.github/workflows/build-wheels-windows.yml +++ b/.github/workflows/build-wheels-windows.yml @@ -32,10 +32,12 @@ jobs: matrix: include: - repository: pytorch/tensordict + pre-script: "" + env-script: .github/scripts/version_script.bat post-script: "python packaging/wheel/relocate.py" smoke-test-script: test/smoke_test.py package-name: tensordict - name: pytorch/tensordict + name: ${{ matrix.repository }} uses: pytorch/test-infra/.github/workflows/build_wheels_windows.yml@main with: repository: ${{ matrix.repository }} @@ -43,7 +45,9 @@ jobs: test-infra-repository: pytorch/test-infra test-infra-ref: main build-matrix: ${{ needs.generate-matrix.outputs.matrix }} + pre-script: ${{ matrix.pre-script }} + env-script: ${{ matrix.env-script }} + post-script: ${{ matrix.post-script }} package-name: ${{ matrix.package-name }} smoke-test-script: ${{ matrix.smoke-test-script }} trigger-event: ${{ github.event_name }} - env-script: .github/scripts/version_script.bat diff --git a/.github/workflows/wheels-windows.yml b/.github/workflows/wheels-windows.yml index 884913f52..79ae4e979 100644 --- a/.github/workflows/wheels-windows.yml +++ b/.github/workflows/wheels-windows.yml @@ -36,12 +36,12 @@ jobs: python3 -mpip install wheel TENSORDICT_BUILD_VERSION=0.5.0 python3 setup.py bdist_wheel - name: Upload wheel for the test-wheel job - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: tensordict-win-${{ matrix.python_version[0] }}.whl path: dist/tensordict-*.whl - name: Upload wheel for download - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: tensordict-batch.whl path: dist/*.whl @@ -72,7 +72,7 @@ jobs: run: | python3 -mpip install numpy pytest pytest-cov codecov unittest-xml-reporting pillow>=4.1.1 scipy av networkx expecttest pyyaml - name: Download built wheels - uses: actions/download-artifact@v2 + uses: actions/download-artifact@v3 with: name: tensordict-win-${{ matrix.python_version }}.whl path: wheels diff --git a/tensordict/base.py b/tensordict/base.py index 1431ed511..9a691e4c1 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -1495,6 +1495,16 @@ def grad(self): """Returns a tensordict containing the .grad attributes of the leaf tensors.""" return self._grad() + @grad.setter + def grad(self, grad): + def set_grad(x, grad): + if x.grad is None: + x.grad = grad + else: + x.grad.copy_(grad) + + self._fast_apply(set_grad, grad) + def zero_grad(self, set_to_none: bool = True) -> T: """Zeros all the gradients of the TensorDict recursively.