pydantic-ai-graph - simplify public generics #539

Merged: 3 commits on Jan 2, 2025

dmontagu
Contributor

This is a refactoring of the work in #528 attempting to reduce the number of user-facing generic parameters.

There are some things that we can probably adjust/remove, such as the introduction of both GraphRunner and GraphRun. But there are some benefits to this, as it provides a way to get nearly the same API we currently have without BaseNode being aware of the input type, and without forcing BaseNode subclasses to have a fixed signature.

Some of the less consequential changes in here (such as moving/duplicating typevars) were done because I was hitting cyclic import errors and/or spurious type errors (I think I may have run into some pyright bugs that I no longer know how to reproduce).

I'm happy to rework the implementation based on feedback if there's interest, but I don't really want to go through adding lots of tests and documenting everything if this line of implementation is going to be rejected, so @samuelcolvin maybe we can discuss synchronously.

@dmontagu
Contributor Author

Note that I say "user-facing generic parameters" because, for example, while GraphRunner has a ParamSpec as one of its parameters, users should never need to annotate it manually; they would just pass a node class to the get_runner method of Graph. This is important because it's hard to manually specify a ParamSpec for a node, but it's not so hard to manually specify the parameters for the Graph type itself, which may be desirable as a way to ensure the nodes being added satisfy constraints on their run-end types, rather than letting that be inferred by the type system.

cloudflare-workers-and-pages bot commented Dec 23, 2024

Deploying pydantic-ai with Cloudflare Pages

Latest commit: c8790d9
Status: ✅  Deploy successful!
Preview URL: https://afb81913.pydantic-ai.pages.dev
Branch Preview URL: https://dmontagu-graph-refactor.pydantic-ai.pages.dev

Comment on lines -26 to -33
class Snapshot:
    """Snapshot of a graph."""

    last_node_id: str
    next_node_id: str
    start_ts: datetime
    duration: float
    state: bytes | None = None
Contributor Author

I replaced this with non-serialized history events mostly because I think people will want to interact with the deserialized history, so storing it as raw bytes seems unfortunate. I think we need to ensure both Step and EndEvent are serializable (and therefore all BaseNode instances and all RunEndT values need to be serializable). I'm not sure of the best way to achieve that, though.
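
To make the trade-off concrete, here is a rough sketch of keeping history as plain objects and serializing only on demand. `StepRecord`, `EndRecord`, their fields, and `dump_history` are hypothetical; they are not the Step and EndEvent types from this PR, and JSON via `dataclasses.asdict` is just one possible approach, not a decision made here.

```python
from __future__ import annotations

import json
from dataclasses import asdict, dataclass
from datetime import datetime, timezone
from typing import Any


@dataclass
class StepRecord:
    """Hypothetical history entry describing one node run."""

    node_id: str
    start_ts: datetime
    duration: float


@dataclass
class EndRecord:
    """Hypothetical history entry recording the run's final result."""

    result: Any
    ts: datetime


def _encode(value: Any) -> str:
    # json.dumps calls this only for objects it cannot encode natively.
    if isinstance(value, datetime):
        return value.isoformat()
    raise TypeError(f'cannot serialize {type(value).__name__}')


def dump_history(history: list[StepRecord | EndRecord]) -> str:
    # Serialize on demand; the in-memory history stays as ordinary objects.
    rows = [{'kind': type(event).__name__, **asdict(event)} for event in history]
    return json.dumps(rows, default=_encode)


now = datetime(2024, 12, 23, tzinfo=timezone.utc)
history = [StepRecord('Fetch', now, 0.25), EndRecord(42, now)]
```

Callers can read `history[0].node_id` directly, while `dump_history(history)` produces the stored form. The sketch also shows the open problem: anything placed in `result` must itself be serializable.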

Member

@samuelcolvin samuelcolvin left a comment

overall I think this looks good.


NodeInputT = TypeVar('NodeInputT', default=Any)
GraphOutputT = TypeVar('GraphOutputT', default=Any)
RunEndT = TypeVar('RunEndT', default=None)
Member

I think this should not have a default.

Contributor Author

If we get rid of the default, then it needs to come before StateT everywhere. I'm okay with that, but just pointing it out. I will make that change.

Contributor Author

@dmontagu dmontagu Dec 26, 2024

Actually, it's awkward because parameters with defaults need to go after parameters without defaults, which means that unless we drop the default value of StateT for Graph, we'd have to put RunEndT before StateT on Graph. But then you have to make a decision of whether you want consistency in parameter ordering on BaseNode, or convenience when omitting default parameters:

  • If we keep the ordering of parameters on BaseNode as it currently is, then Graph has the parameters in one order and BaseNode has it in the other, which feels like a recipe for confusion.
  • If we change the ordering of parameters on BaseNode, then to explicitly specify the StateT type you need to specify the NodeRunEndT, but explicitly specifying the StateT will be a lot more common than explicitly specifying the node run end type, so that feels pretty unfortunate.

Given the above, I now feel like we should either:

  • keep a default of None for the RunEndT on Graph, or
  • drop the default of None for the StateT on Graph.

I'm okay with either, lmk what you think. (I'm also okay with the first two bullets' suggestions, if you really prefer, but I think they are worse than making it so the graph typevars either both have defaults or both don't, and preserving parameter ordering).
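
A small sketch of the ordering constraint being discussed, under stated assumptions: the class bodies are empty placeholders, the default of None on NodeRunEndT is assumed purely for illustration, and on Python versions before 3.13 the `default=` argument requires the `typing_extensions` backport.

```python
import sys
from typing import Generic

if sys.version_info >= (3, 13):
    from typing import TypeVar
else:
    # Backport of TypeVar defaults (PEP 696) for older interpreters.
    from typing_extensions import TypeVar

StateT = TypeVar('StateT', default=None)
RunEndT = TypeVar('RunEndT', default=None)
NodeRunEndT = TypeVar('NodeRunEndT', default=None)  # default assumed for this sketch


class Graph(Generic[StateT, RunEndT]):
    """Both parameters have defaults, so the order can stay (state, run end)."""


class BaseNode(Generic[StateT, NodeRunEndT]):
    """Same order as Graph: StateT can be given alone, the end type stays implicit."""


# If RunEndT lost its default while StateT kept one, PEP 696 would no longer
# allow Generic[StateT, RunEndT]; Graph would have to be declared as
# Generic[RunEndT, StateT], the opposite order from BaseNode.
```

With both defaults in place, an annotation like `Graph[MyState]` (for some hypothetical `MyState`) specifies only the state type, which is the common case described above.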

Contributor Author

@dmontagu dmontagu Dec 26, 2024

I'll note that, having tried the refactor, I think things get a little more complicated if we remove the default. I'm also not sure that a default of None is that crazy: it provides a way to end the run without returning a meaningful value, which may be more common than you'd expect if it becomes common to create agents that make mutating tool calls before finishing rather than returning values that are consumed downstream.


def __init__(self, input_data: NodeInputT) -> None:
    self.input_data = input_data
_node_id: ClassVar[str | None] = None
Member

Does this have to be private? I get that it can be, but I think it'll be clearer that it's part of the public API if it's public.

Contributor Author

It can be part of the public API, but the problem is that in lots of places we were accessing it directly rather than via get_id(), which returns the actual string and needs to be used everywhere (otherwise you get None values and therefore non-unique ids). I changed it to private to make that mistake less likely, but I understand your point that it now looks like users aren't supposed to set it. Maybe we can call it node_id_override instead? Or we can just keep it as node_id; basically, I just don't want anyone to make the mistake of accessing that value instead of calling node.get_id(). Maybe we can set it in the metaclass if it doesn't have an explicitly-set value?

Contributor Author

I wonder if it wouldn't be better to just drop the attribute entirely and have people override get_id if they want to manually change the value. That would ensure it's not possible to use the wrong API, for the (in my opinion) small cost of requiring a little more verbosity if you want to override the value. Given that I expect we'll generally discourage overriding the value, that seems okay to me.

Contributor Author

I'm imagining adding a diff that looks like this:

     @classmethod
     @cache
     def get_id(cls) -> str:
-        return cls._node_id or cls.__name__
+        """Get the ID of the node.
+
+        You can override this if you want to serialize/deserialize nodes using a specific ID.
+        """
+        return cls.__name__
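
For illustration, a self-contained sketch of what overriding would look like under the proposed diff. The stripped-down `BaseNode` here contains only `get_id`, and `FetchPage`, `StoreResult`, and the id string are hypothetical.

```python
from functools import cache


class BaseNode:
    """Stripped-down stand-in holding only the get_id method from the diff."""

    @classmethod
    @cache
    def get_id(cls) -> str:
        # Default id: the class name, cached per class.
        return cls.__name__


class FetchPage(BaseNode):
    """Hypothetical node that keeps the default id."""


class StoreResult(BaseNode):
    """Hypothetical node that pins a custom id by overriding get_id."""

    @classmethod
    def get_id(cls) -> str:
        return 'store-result-v1'


ids = [node.get_id() for node in (FetchPage, StoreResult)]
```

Because there is no attribute left to read, every caller has to go through `get_id()`, which is the mistake-proofing the comment above is after.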

Contributor Author

I removed _node_id for now. I didn't add that docstring, since we don't have node serialization/deserialization yet and I'm not sure that's how it will work. I can undo this change if preferred, though.



@dataclass
class GraphRunner(Generic[RunSignatureT, StateT, RunEndT]):
Member

I think this looks good, but I wonder if we could find a better name?

Contributor Author

I had GraphExecutor before, not sure if that's better or worse. I agree it sort of feels like there should be a better name out there, but I worry that a better name would require more refactors in order to be appropriate. (I mean, maybe that would be overall a good thing, though I'm hoping to avoid it.)



@dataclass
class GraphRun(Generic[StateT, RunEndT]):
Member

are we sure we need this? Maybe we could make it private if we're not sure?

Contributor Author

@dmontagu dmontagu Dec 24, 2024

It's currently the only thing that exposes the step API, which I feel like you want for unit testing. Note that I think we want to expose a way to do unit testing that goes beyond just calling node.run specifically for the sake of testing callbacks (once we add support for them). Maybe there are other reasons as well, that's less clear to me though.

But I think we could rework so GraphRun is private and there is a public function that users can call that basically does the same thing as pydantic_ai_graph.graph.GraphRun.step with callback execution, but doesn't require access to that class.

Contributor Author

To be clear I definitely don't think we need to have this class, it could be replaced with function calls operating on state and history directly, but:

  • My comment above was a response to the suggestion of making it private — I think the functionality for stepping through a graph (with callbacks) should be public. (But that doesn't mean it needs to be a class.)
  • I think having a class will make it easier to add features over time (in particular callbacks, though maybe other things). But we could drop the class for now and/or make it private as a way to keep users from depending on APIs we may want to break in the near future.
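
To show why a step-level API helps in unit tests, here is a deliberately simplified sketch. It is synchronous and has no callbacks; `Run`, `Node`, `End`, `Increment`, and `Check` are hypothetical stand-ins and do not reflect the real GraphRun.step signature.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any


@dataclass
class End:
    """Hypothetical marker for the end of a run, carrying the result."""

    data: Any


class Node:
    """Hypothetical node base: run() returns the next node or End."""

    def run(self, state: dict[str, int]) -> Node | End:
        raise NotImplementedError


class Increment(Node):
    def run(self, state: dict[str, int]) -> Node | End:
        state['count'] += 1
        return Check()


class Check(Node):
    def run(self, state: dict[str, int]) -> Node | End:
        if state['count'] >= 3:
            return End(state['count'])
        return Increment()


@dataclass
class Run:
    """Holds state and history; step() executes exactly one node."""

    state: dict[str, int]
    history: list[str] = field(default_factory=list)

    def step(self, node: Node) -> Node | End:
        # Record the node, then run it; callbacks would hook in around here.
        self.history.append(type(node).__name__)
        return node.run(self.state)


run = Run(state={'count': 0})
next_node = run.step(Increment())
```

A test can assert on `next_node`, `run.state`, and `run.history` after a single step, which goes beyond calling `node.run` directly because the bookkeeping (and, later, callbacks) is exercised too.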

@dmontagu
Contributor Author

I addressed some of the comments above; I've left unresolved the ones that I wasn't confident enough to resolve without further discussion (or at least further insistence on your part).

@samuelcolvin samuelcolvin merged commit b9155e8 into graph Jan 2, 2025
15 checks passed
@samuelcolvin samuelcolvin deleted the dmontagu/graph-refactor branch January 2, 2025 18:17