Commit

Merge pull request #486 from NiklasGustafsson/main
Release preparation. Updating version number.
NiklasGustafsson authored Dec 3, 2021
2 parents be6a38f + fe95202 commit 7dc140b
Showing 3 changed files with 12 additions and 9 deletions.
9 changes: 9 additions & 0 deletions RELEASENOTES.md
@@ -1,14 +1,23 @@
## TorchSharp Release Notes

Releases, starting with 9/2/2021, are listed with the most recent release at the top.

## NuGet Version 0.95.4

__API Changes:__

Added OneCycleLR and CyclicLR schedulers
Added DisposeScopeManager and torch.NewDisposeScope() to facilitate a new solution for managing disposing of tensors with fewer usings.
Added Tensor.set_()
Added 'copy' argument to Tensor.to()

__Fixed Bugs:__

#476 BatchNorm does not expose bias,weight,running_mean,running_var
#475 Loading Module that's on CUDA
#372 Module.save moves Module to CPU
#468 How to set Conv2d kernel_size=(2,300)
#450 Smoother disposing

### NuGet Version 0.95.3

2 changes: 1 addition & 1 deletion build/BranchInfo.props
@@ -2,7 +2,7 @@
<PropertyGroup>
<MajorVersion>0</MajorVersion>
<MinorVersion>95</MinorVersion>
-<PatchVersion>3</PatchVersion>
+<PatchVersion>4</PatchVersion>
</PropertyGroup>

</Project>
10 changes: 2 additions & 8 deletions docfx/articles/saveload.md
@@ -21,7 +21,7 @@ Python pickling is intimately coupled to Python and its runtime object model. It

In order to share models between .NET applications, Python pickling is not at all necessary, and even for moving model state from Python to .NET, it is overkill. The state of a model is a simple dictionary where the keys are strings and the values are tensors.

-Therefore, TorchSharp in its current form, implements its own very simple model serialization format, which allows models originating in either .NET or Python to be loaded using .NET, as long as the model was saved using the special format.
+Therefore, TorchSharp, in its current form, implements its own very simple model serialization format, which allows models originating in either .NET or Python to be loaded using .NET, as long as the model was saved using the special format.

The MNIST and AdversarialExampleGeneration examples in this repo rely on saving and restoring model state -- the latter example relies on a pre-trained model from MNIST.

@@ -35,20 +35,14 @@ In C#, saving a model looks like this:
model.save("model_weights.dat");
```

-It's important to note that calling 'save' will move the model to the CPU, where it remains after the call. If you need to continue to use the model after saving it, you will have to explicitly move it back:
-
-```C#
-model.to(Device.CUDA);
-```
-
And loading it again is done by:

```C#
model = [...];
model.load("model_weights.dat");
```

-The model should be created on the CPU before loading weights, then moved to the target device.
+For efficient memory management, the model should be created on the CPU before loading weights, then moved to the target device.
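
As a rough illustration of the load order described above, here is a minimal sketch. It assumes a single `Linear` layer as a stand-in for a real model, and it assumes that `torch.cuda.is_available()` and `Module.to(DeviceType)` are available in the TorchSharp version in use; the layer sizes are arbitrary, and exact overloads may differ between releases.

```C#
using TorchSharp;

// Illustration only: a single Linear layer stands in for a real model.
var model = torch.nn.Linear(10, 5);
model.save("model_weights.dat");

// Recreate the same architecture on the CPU, then load the saved weights into it.
var restored = torch.nn.Linear(10, 5);
restored.load("model_weights.dat");

// Move to the GPU only after loading, and only if one is available.
if (torch.cuda.is_available())
    restored.to(DeviceType.CUDA);
```

Loading before moving means the weights are read into CPU tensors first, so the target device receives the populated model in a single transfer.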

><br/>It is __critical__ that all submodules and buffers in a custom module or composed by a Sequential object have exactly the same name in the original and target models, since that is how persisted tensors are associated with the model into which they are loaded.<br/><br/>The CustomModule 'RegisterComponents' will automatically find all fields that are either modules or tensors, register the former as modules, and the latter as buffers. It registers all of these using the name of the field, just like the PyTorch Module base class does.<br/><br/>