[Refact] Remove the ConcretizedCallable and Embedding class #278

Open
jpmoutinho opened this issue Sep 2, 2024 · 0 comments
Labels
feature New feature or request Refactor


jpmoutinho commented Sep 2, 2024

Describe the feature

The ConcretizedCallable and Embedding classes were added as a prototype for handling parameter expressions. They have since been adapted into Qadence 2 Platforms, so we may no longer need to keep them in PyQ.

The ConcretizedCallable itself can serve as a torch-specific expression that is evaluated on a values dict. Renamed to Expr, it could take the following form:

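```python
from __future__ import annotations

import torch

DEFAULT_MATRIX_DTYPE = torch.cdouble  # assumed to mirror pyqtorch's default matrix dtype


class Expr:
    def __init__(
        self,
        call_name: str,
        abstract_args: list[str | float | int | complex | Expr],
        device: torch.device | str = "cpu",
        dtype: torch.dtype = DEFAULT_MATRIX_DTYPE,
    ) -> None:
        # Look up the torch function to apply, e.g. "sin" -> torch.sin
        self.fn_to_call = getattr(torch, call_name)
        self.abstract_args = abstract_args
        self._device = device
        self._dtype = dtype

    def evaluate(self, values: dict[str, torch.Tensor] | None = None) -> torch.Tensor:
        values = {} if values is None else values
        args = []
        for symbol_or_numeric in self.abstract_args:
            if isinstance(symbol_or_numeric, Expr):
                # Sub-expression: evaluate recursively with the same values dict
                args.append(symbol_or_numeric(values))
            elif isinstance(symbol_or_numeric, (float, int, complex)):
                # Numeric literal: wrap it as a tensor on the configured device/dtype
                args.append(
                    torch.tensor(symbol_or_numeric, device=self._device, dtype=self._dtype)
                )
            elif isinstance(symbol_or_numeric, str):
                # Symbol: look up its value by name
                args.append(values[symbol_or_numeric])
            else:
                raise TypeError(f"Unsupported argument type: {type(symbol_or_numeric)}")
        return self.fn_to_call(*args)

    def __call__(self, values: dict[str, torch.Tensor] | None = None) -> torch.Tensor:
        return self.evaluate(values)

    # (...) all the __add__ etc. definitions (...)
```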

If all parameters could also accept an Expr instance instead of only a str, we would probably no longer need the Embedding class, since the values dict would simply be passed down the Expr tree.
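Here is a minimal, self-contained sketch of how the elided operator definitions and the recursive evaluation could fit together. It is an illustration under assumptions, not the PyQ implementation: device/dtype handling is dropped, the torch function names used per operator ("add", "sub", "mul", "div", "pow") are our choice, and sin/log are hypothetical stand-ins for the proposed pyq.sin and pyq.log. The reflected operators (__radd__, __rmul__, __rtruediv__, __rpow__, ...) are what let expressions that start with a float or a str, such as "z" ** ..., build an Expr instead of raising a TypeError.

```python
from __future__ import annotations

import torch


class Expr:
    """Minimal expression tree: a torch function name plus its (possibly nested) arguments."""

    def __init__(self, call_name: str, abstract_args: list) -> None:
        self.fn_to_call = getattr(torch, call_name)
        self.abstract_args = abstract_args

    def __call__(self, values: dict[str, torch.Tensor] | None = None) -> torch.Tensor:
        values = {} if values is None else values
        args = []
        for arg in self.abstract_args:
            if isinstance(arg, Expr):
                args.append(arg(values))  # the same values dict is passed down the tree
            elif isinstance(arg, str):
                args.append(values[arg])  # symbol: look up by name
            else:
                args.append(torch.tensor(arg))  # numeric literal
        return self.fn_to_call(*args)

    # Each operator builds a new node instead of computing anything eagerly.
    def __add__(self, other):
        return Expr("add", [self, other])

    def __radd__(self, other):
        return Expr("add", [other, self])

    def __sub__(self, other):
        return Expr("sub", [self, other])

    def __rsub__(self, other):
        return Expr("sub", [other, self])

    def __mul__(self, other):
        return Expr("mul", [self, other])

    def __rmul__(self, other):
        return Expr("mul", [other, self])

    def __truediv__(self, other):
        return Expr("div", [self, other])

    def __rtruediv__(self, other):
        return Expr("div", [other, self])

    def __pow__(self, other):
        return Expr("pow", [self, other])

    def __rpow__(self, other):
        return Expr("pow", [other, self])


# Hypothetical helpers standing in for the proposed pyq.sin / pyq.log.
def sin(arg) -> Expr:
    return Expr("sin", [arg])


def log(arg) -> Expr:
    return Expr("log", [arg])


expr = "z" ** log(1.0 / (1.0 + (2.0 * sin("x")) + "y"))
values = {"x": torch.tensor(1.0), "y": torch.tensor(-1.0), "z": torch.tensor(2.0)}
print(expr(values))  # tensor(0.6971)
```

Nothing is computed until expr(values) is called, and the single values dict reaches every leaf, which is exactly the job Embedding does today.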

For example, the Scale forward method currently reads:

```python
def forward(
    self,
    state: Tensor,
    values: dict[str, Tensor] | ParameterDict = dict(),
    embedding: Embedding | None = None,
) -> State:

    if embedding is not None:
        values = embedding(values)

    scale = (
        values[self.param_name]
        if isinstance(self.param_name, str)
        else self.param_name
    )

    return scale * self.operations[0].forward(state, values)
```

Instead, it could be:

```python
def forward(
    self,
    state: Tensor,
    values: dict[str, Tensor] | ParameterDict = dict(),
) -> State:

    if isinstance(self.param_name, str):
        scale = values[self.param_name]
    elif isinstance(self.param_name, Expr):
        scale = self.param_name(values)
    else:
        scale = self.param_name

    return scale * self.operations[0].forward(state, values)
```
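The three-way dispatch in the proposed forward can be exercised in isolation. The sketch pulls it into a hypothetical resolve_param helper and, so that it runs without the proposed class, treats any callable that takes the values dict as an Expr (the real check would be isinstance(param, Expr)). The state vector and the double_theta expression are made up for illustration.

```python
from __future__ import annotations

import torch


def resolve_param(param, values: dict[str, torch.Tensor]):
    """Hypothetical helper mirroring the str / Expr / numeric branches of the proposed forward."""
    if isinstance(param, str):
        return values[param]  # symbol: direct lookup in the values dict
    if callable(param):
        return param(values)  # Expr stand-in: evaluate it on the full values dict
    return param  # numeric constant, used as-is


def double_theta(values: dict[str, torch.Tensor]) -> torch.Tensor:
    # Stand-in for an Expr equivalent to 2.0 * "theta" (hypothetical).
    return 2.0 * values["theta"]


state = torch.tensor([1.0, 0.0])
values = {"theta": torch.tensor(0.5)}

for param in ("theta", double_theta, 3.0):
    print(resolve_param(param, values) * state)
```

No embedding argument appears anywhere: the values dict alone is enough for every branch.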

(and similarly for the tensor method). We could then write code like:

```python
import torch
import pyqtorch as pyq

# Create an expression
expr = "z" ** pyq.log(1.0 / (1.0 + (2.0 * pyq.sin("x")) + "y"))

# Pass the expression to a block parameter
op = pyq.Scale(pyq.I(0), expr)

values = {"x": torch.tensor(1.0), "y": torch.tensor(-1.0), "z": torch.tensor(2.0)}

matrix = op.tensor(values)

matrix[..., 0].real
```

Output:

```
tensor([[0.6971, 0.0000],
        [0.0000, 0.6971]])
```
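The expected diagonal value can be checked by hand: sin(1) ≈ 0.8415, so the log argument is 1 / (1 + 1.6829 − 1) ≈ 0.5942, its log is ≈ −0.5205, and 2^(−0.5205) ≈ 0.6971. The same computation with plain torch operations, with no Expr involved, is a quick sanity check for whatever the final implementation returns:

```python
import torch

x, y, z = torch.tensor(1.0), torch.tensor(-1.0), torch.tensor(2.0)

# The same expression, evaluated eagerly with plain torch operations.
scale = z ** torch.log(1.0 / (1.0 + 2.0 * torch.sin(x) + y))
print(round(scale.item(), 4))  # 0.6971

# Scale(I(0), expr) should multiply the 2x2 identity by this scalar.
print(scale * torch.eye(2))
```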

Currently, the only other use of the Embedding class is to re-embed the time values in the time-dependent Hamiltonian evolution. If that Hamiltonian were instead built from Expr instances that take some parameter "t" as one of their inputs, we could probably just evaluate them with updated values of "t".
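As an illustration of evaluating with updated values of "t", here is a sketch that assumes a piecewise-constant (midpoint-rule) time discretisation, which is not necessarily how PyQ's time-dependent evolution works. The coefficient function is a plain-function stand-in for an Expr such as "omega" * sin("t"), and coeff and evolve are hypothetical names. Between steps, only the "t" entry of the values dict changes.

```python
import math

import torch

Z = torch.tensor([[1.0, 0.0], [0.0, -1.0]], dtype=torch.complex128)


def coeff(values: dict[str, torch.Tensor]) -> torch.Tensor:
    # Stand-in for an Expr like "omega" * sin("t") (hypothetical).
    return values["omega"] * torch.sin(values["t"])


def evolve(values: dict[str, torch.Tensor], duration: float, n_steps: int) -> torch.Tensor:
    """Piecewise-constant evolution under H(t) = coeff(t) * Z, sampled at step midpoints."""
    dt = duration / n_steps
    u = torch.eye(2, dtype=torch.complex128)
    for k in range(n_steps):
        # Only "t" changes between steps; the other values are passed through untouched.
        step_values = {**values, "t": torch.tensor((k + 0.5) * dt, dtype=torch.float64)}
        h = coeff(step_values) * Z
        u = torch.linalg.matrix_exp(-1j * dt * h) @ u
    return u


values = {"omega": torch.tensor(1.0, dtype=torch.float64)}
u = evolve(values, duration=1.0, n_steps=1000)

# H(t) commutes with itself at all times, so the exact phase is omega * (1 - cos(T)).
theta = 1.0 - math.cos(1.0)
print(u[0, 0].item(), complex(math.cos(theta), -math.sin(theta)))
```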

It should be implemented because

No response

Additional context

No response

Would you like to work on this issue?

None
