Skip to content

Commit

Permalink
arning regarding the dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
ion-g-ion committed Feb 7, 2024
1 parent 1617745 commit 1ba3fba
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 2 deletions.
3 changes: 2 additions & 1 deletion examples/basic_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The TT approximation of a given tensor is $\\mathsf{x}_{i_1i_2...i_d} \\approx \\sum\\limits_{r_1,...,r_{d-1}=1}^{R_1,...,R_{d-1}} \\mathsf{g}^{(1)}_{1i_1r_1}\\cdots\\mathsf{g}^{(d)}_{r_{d-1}i_d1} $. Using the constructor `torchtt.TT()` a full tensor can be decomposed in the TT format."
"The TT approximation of a given tensor is $\\mathsf{x}_{i_1i_2...i_d} \\approx \\sum\\limits_{r_1,...,r_{d-1}=1}^{R_1,...,R_{d-1}} \\mathsf{g}^{(1)}_{1i_1r_1}\\cdots\\mathsf{g}^{(d)}_{r_{d-1}i_d1} $. Using the constructor `torchtt.TT()` a full tensor can be decomposed in the TT format.\n",
"The `dtype` of the input will be passed to the TT cores."
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion examples/basic_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# We now create a 4d `torch.tensor` which we will use later
tens_full = tn.reshape(tn.arange(32*16*8*10, dtype = tn.float64),[32,16,8,10])

# The TT approximation of a given tensor is $\mathsf{x}_{i_1i_2...i_d} \approx \sum\limits_{r_1,...,r_{d-1}=1}^{R_1,...,R_{d-1}} \mathsf{g}^{(1)}_{1i_1r_1}\cdots\mathsf{g}^{(d)}_{r_{d-1}i_d1} $. Using the constructor `torchtt.TT()` a full tensor can be decomposed in the TT format.
# The TT approximation of a given tensor is $\mathsf{x}_{i_1i_2...i_d} \approx \sum\limits_{r_1,...,r_{d-1}=1}^{R_1,...,R_{d-1}} \mathsf{g}^{(1)}_{1i_1r_1}\cdots\mathsf{g}^{(d)}_{r_{d-1}i_d1} $. Using the constructor `torchtt.TT()` a full tensor can be decomposed in the TT format. The `dtype` of the input will be passed to the TT cores.
tens_tt = tntt.TT(tens_full)

# The newly instantiated object contains the cores as a list, the mode sizes and the rank.
Expand Down
2 changes: 2 additions & 0 deletions torchtt/_tt_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ def __init__(self, source, shape=None, eps=1e-10, rmax=sys.maxsize):
This class implements basic operators such as `+,-,*,/,@,**` (add, subtract, elementwise multiplication, elementwise division, matrix vector product and Kronecker product) between TT instances.
The `examples\` folder server as a tutorial for all the possibilities of the toolbox.
Be aware of the dtype of the inputs since it can affect the accuracy. Recommended is to work with float64.
Examples:
.. code-block:: python
Expand Down

0 comments on commit 1ba3fba

Please sign in to comment.