Skip to content

Commit 9ef8390

Browse files
gustavocidornelas authored and whoseoyster committed
feat: introduce new step types
1 parent ba68eab commit 9ef8390

File tree

4 files changed

+171
-14
lines changed

4 files changed

+171
-14
lines changed

src/openlayer/lib/integrations/langchain_callback.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -190,18 +190,22 @@ def _process_and_upload_trace(self, root_step: steps.Step) -> None:
190190
and root_step.inputs
191191
and "prompt" in root_step.inputs
192192
):
193-
config.update({"prompt": root_step.inputs["prompt"]})
193+
config.update({"prompt": utils.json_serialize(root_step.inputs["prompt"])})
194194

195195
if tracer._publish:
196196
try:
197197
client = tracer._get_client()
198198
if client:
199+
# Apply final JSON serialization to ensure everything is serializable
200+
serialized_trace_data = utils.json_serialize(trace_data)
201+
serialized_config = utils.json_serialize(config)
202+
199203
client.inference_pipelines.data.stream(
200204
inference_pipeline_id=utils.get_env_variable(
201205
"OPENLAYER_INFERENCE_PIPELINE_ID"
202206
),
203-
rows=[trace_data],
204-
config=config,
207+
rows=[serialized_trace_data],
208+
config=serialized_config,
205209
)
206210
except Exception as err: # pylint: disable=broad-except
207211
tracer.logger.error("Could not stream data to Openlayer %s", err)
@@ -270,6 +274,17 @@ def _convert_langchain_objects(self, obj: Any) -> Any:
270274
if hasattr(obj, "messages"):
271275
return [self._convert_langchain_objects(m) for m in obj.messages]
272276

277+
# Handle Pydantic model instances
278+
if hasattr(obj, "model_dump") and callable(getattr(obj, "model_dump")):
279+
try:
280+
return self._convert_langchain_objects(obj.model_dump())
281+
except Exception:
282+
pass
283+
284+
# Handle Pydantic model classes/metaclasses (type objects)
285+
if isinstance(obj, type):
286+
return str(obj.__name__ if hasattr(obj, "__name__") else obj)
287+
273288
# Handle other LangChain objects with common attributes
274289
if hasattr(obj, "dict") and callable(getattr(obj, "dict")):
275290
# Many LangChain objects have a dict() method
@@ -556,6 +571,7 @@ def _handle_chain_start(
556571
metadata={
557572
"tags": tags,
558573
"serialized": serialized,
574+
"is_chain": True,
559575
**(metadata or {}),
560576
**kwargs,
561577
},
@@ -637,14 +653,16 @@ def _handle_tool_start(
637653
run_id=run_id,
638654
parent_run_id=parent_run_id,
639655
name=tool_name,
640-
step_type=enums.StepType.USER_CALL,
656+
step_type=enums.StepType.TOOL,
641657
inputs=tool_input,
642658
metadata={
643659
"tags": tags,
644660
"serialized": serialized,
645661
**(metadata or {}),
646662
**kwargs,
647663
},
664+
function_name=tool_name,
665+
arguments=tool_input,
648666
)
649667

650668
def _handle_tool_end(
@@ -690,13 +708,16 @@ def _handle_agent_action(
690708
run_id=run_id,
691709
parent_run_id=parent_run_id,
692710
name=f"Agent Tool: {action.tool}",
693-
step_type=enums.StepType.USER_CALL,
711+
step_type=enums.StepType.AGENT,
694712
inputs={
695713
"tool": action.tool,
696714
"tool_input": action.tool_input,
697715
"log": action.log,
698716
},
699717
metadata={"agent_action": True, **kwargs},
718+
tool=action.tool,
719+
action=action,
720+
agent_type="langchain_agent",
700721
)
701722

702723
def _handle_agent_finish(
@@ -740,7 +761,7 @@ def _handle_retriever_start(
740761
run_id=run_id,
741762
parent_run_id=parent_run_id,
742763
name=retriever_name,
743-
step_type=enums.StepType.USER_CALL,
764+
step_type=enums.StepType.RETRIEVER,
744765
inputs={"query": query},
745766
metadata={
746767
"tags": tags,
@@ -775,6 +796,11 @@ def _handle_retriever_end(
775796
if current_trace:
776797
current_trace.update_metadata(context=doc_contents)
777798

799+
# Update the step with RetrieverStep-specific attributes
800+
step = self.steps[run_id]
801+
if isinstance(step, steps.RetrieverStep):
802+
step.documents = doc_contents
803+
778804
self._end_step(
779805
run_id=run_id,
780806
parent_run_id=parent_run_id,
@@ -1146,19 +1172,23 @@ def _process_and_upload_async_trace(self, trace: traces.Trace) -> None:
11461172
and root_step.inputs
11471173
and "prompt" in root_step.inputs
11481174
):
1149-
config.update({"prompt": root_step.inputs["prompt"]})
1175+
config.update({"prompt": utils.json_serialize(root_step.inputs["prompt"])})
11501176

11511177
# Upload to Openlayer
11521178
if tracer._publish:
11531179
try:
11541180
client = tracer._get_client()
11551181
if client:
1182+
# Apply final JSON serialization to ensure everything is serializable
1183+
serialized_trace_data = utils.json_serialize(trace_data)
1184+
serialized_config = utils.json_serialize(config)
1185+
11561186
client.inference_pipelines.data.stream(
11571187
inference_pipeline_id=utils.get_env_variable(
11581188
"OPENLAYER_INFERENCE_PIPELINE_ID"
11591189
),
1160-
rows=[trace_data],
1161-
config=config,
1190+
rows=[serialized_trace_data],
1191+
config=serialized_config,
11621192
)
11631193
except Exception as err:
11641194
tracer.logger.error("Could not stream data to Openlayer %s", err)

src/openlayer/lib/tracing/enums.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,7 @@
66
class StepType(enum.Enum):
    """Kinds of steps that can appear in a trace.

    Each member's value is the string emitted under the ``"type"`` key of a
    step's dictionary representation.
    """

    USER_CALL = "user_call"
    CHAT_COMPLETION = "chat_completion"
    AGENT = "agent"
    RETRIEVER = "retriever"
    TOOL = "tool"
    HANDOFF = "handoff"

src/openlayer/lib/tracing/steps.py

Lines changed: 119 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import time
44
import uuid
5-
from typing import Any, Dict, Optional
5+
from typing import Any, Dict, List, Optional
66

77
from .. import utils
88
from . import enums
@@ -54,10 +54,10 @@ def to_dict(self) -> Dict[str, Any]:
5454
"name": self.name,
5555
"id": str(self.id),
5656
"type": self.step_type.value,
57-
"inputs": self.inputs,
58-
"output": self.output,
59-
"groundTruth": self.ground_truth,
60-
"metadata": self.metadata,
57+
"inputs": utils.json_serialize(self.inputs),
58+
"output": utils.json_serialize(self.output),
59+
"groundTruth": utils.json_serialize(self.ground_truth),
60+
"metadata": utils.json_serialize(self.metadata),
6161
"steps": [nested_step.to_dict() for nested_step in self.steps],
6262
"latency": self.latency,
6363
"startTime": self.start_time,
@@ -119,6 +119,116 @@ def to_dict(self) -> Dict[str, Any]:
119119
return step_dict
120120

121121

122+
class AgentStep(Step):
    """Agent step represents an agent in the trace.

    On top of the base ``Step`` fields it carries three agent-specific
    attributes, all initialised to ``None`` and meant to be assigned after
    construction:

    * ``tool`` -- name of the tool associated with the agent action.
    * ``action`` -- the agent action itself; may be an arbitrary object.
    * ``agent_type`` -- free-form label for the kind of agent.
    """

    def __init__(
        self,
        name: str,
        inputs: Optional[Any] = None,
        output: Optional[Any] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Create an agent step.

        Args:
            name: Display name of the step.
            inputs: Inputs recorded for the step.
            output: Output recorded for the step.
            metadata: Extra metadata attached to the step.
        """
        super().__init__(name=name, inputs=inputs, output=output, metadata=metadata)
        # Set after super().__init__ so this subclass's type is the one kept.
        self.step_type = enums.StepType.AGENT
        self.tool: Optional[str] = None
        self.action: Any = None
        self.agent_type: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Dictionary representation of the AgentStep.

        Returns:
            The base step dictionary extended with ``tool``, ``action`` and
            ``agentType`` keys.
        """
        step_dict = super().to_dict()
        step_dict.update(
            {
                "tool": self.tool,
                # ``action`` may hold a non-JSON-native object, so serialize
                # it the same way the base class serializes inputs/output.
                "action": utils.json_serialize(self.action),
                "agentType": self.agent_type,
            }
        )
        return step_dict
149+
150+
151+
class RetrieverStep(Step):
    """Retriever step represents a retriever in the trace.

    On top of the base ``Step`` fields it carries ``documents``, the list of
    retrieved items, initialised to ``None`` and meant to be assigned once
    retrieval has finished.
    """

    def __init__(
        self,
        name: str,
        inputs: Optional[Any] = None,
        output: Optional[Any] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Create a retriever step.

        Args:
            name: Display name of the step.
            inputs: Inputs recorded for the step.
            output: Output recorded for the step.
            metadata: Extra metadata attached to the step.
        """
        super().__init__(name=name, inputs=inputs, output=output, metadata=metadata)
        # Set after super().__init__ so this subclass's type is the one kept.
        self.step_type = enums.StepType.RETRIEVER
        self.documents: Optional[List[Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Dictionary representation of the RetrieverStep.

        Returns:
            The base step dictionary extended with a ``documents`` key.
        """
        step_dict = super().to_dict()
        step_dict.update(
            {
                # Documents can be arbitrary objects; serialize them the same
                # way the base class serializes inputs/output.
                "documents": utils.json_serialize(self.documents),
            }
        )
        return step_dict
174+
175+
176+
class ToolStep(Step):
    """Tool step represents a tool in the trace.

    On top of the base ``Step`` fields it carries two tool-specific
    attributes, both initialised to ``None`` and meant to be assigned after
    construction:

    * ``function_name`` -- name of the function/tool that was invoked.
    * ``arguments`` -- arguments the tool was invoked with.
    """

    def __init__(
        self,
        name: str,
        inputs: Optional[Any] = None,
        output: Optional[Any] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Create a tool step.

        Args:
            name: Display name of the step.
            inputs: Inputs recorded for the step.
            output: Output recorded for the step.
            metadata: Extra metadata attached to the step.
        """
        super().__init__(name=name, inputs=inputs, output=output, metadata=metadata)
        # Set after super().__init__ so this subclass's type is the one kept.
        self.step_type = enums.StepType.TOOL
        self.function_name: Optional[str] = None
        self.arguments: Any = None

    def to_dict(self) -> Dict[str, Any]:
        """Dictionary representation of the ToolStep.

        Returns:
            The base step dictionary extended with ``functionName`` and
            ``arguments`` keys.
        """
        step_dict = super().to_dict()
        step_dict.update(
            {
                "functionName": self.function_name,
                # Arguments can be arbitrary objects; serialize them the same
                # way the base class serializes inputs/output.
                "arguments": utils.json_serialize(self.arguments),
            }
        )
        return step_dict
201+
202+
203+
class HandoffStep(Step):
    """Handoff step represents a handoff in the trace.

    On top of the base ``Step`` fields it carries three handoff-specific
    attributes, all initialised to ``None`` and meant to be assigned after
    construction:

    * ``from_component`` -- identifier of the component handing off.
    * ``to_component`` -- identifier of the component receiving the handoff.
    * ``handoff_data`` -- payload passed along with the handoff.
    """

    def __init__(
        self,
        name: str,
        inputs: Optional[Any] = None,
        output: Optional[Any] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Create a handoff step.

        Args:
            name: Display name of the step.
            inputs: Inputs recorded for the step.
            output: Output recorded for the step.
            metadata: Extra metadata attached to the step.
        """
        super().__init__(name=name, inputs=inputs, output=output, metadata=metadata)
        # Set after super().__init__ so this subclass's type is the one kept.
        self.step_type = enums.StepType.HANDOFF
        self.from_component: Optional[str] = None
        self.to_component: Optional[str] = None
        self.handoff_data: Any = None

    def to_dict(self) -> Dict[str, Any]:
        """Dictionary representation of the HandoffStep.

        Returns:
            The base step dictionary extended with ``fromComponent``,
            ``toComponent`` and ``handoffData`` keys.
        """
        step_dict = super().to_dict()
        step_dict.update(
            {
                "fromComponent": self.from_component,
                "toComponent": self.to_component,
                # The payload can be an arbitrary object; serialize it the
                # same way the base class serializes inputs/output.
                "handoffData": utils.json_serialize(self.handoff_data),
            }
        )
        return step_dict
230+
231+
122232
# ----------------------------- Factory function ----------------------------- #
123233
def step_factory(step_type: enums.StepType, *args, **kwargs) -> Step:
124234
"""Factory function to create a step based on the step_type."""
@@ -127,5 +237,9 @@ def step_factory(step_type: enums.StepType, *args, **kwargs) -> Step:
127237
step_type_mapping = {
128238
enums.StepType.USER_CALL: UserCallStep,
129239
enums.StepType.CHAT_COMPLETION: ChatCompletionStep,
240+
enums.StepType.AGENT: AgentStep,
241+
enums.StepType.RETRIEVER: RetrieverStep,
242+
enums.StepType.TOOL: ToolStep,
243+
enums.StepType.HANDOFF: HandoffStep,
130244
}
131245
return step_type_mapping[step_type](*args, **kwargs)

src/openlayer/lib/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,15 @@ def json_serialize(data):
4848
return [json_serialize(item) for item in data]
4949
elif isinstance(data, tuple):
5050
return tuple(json_serialize(item) for item in data)
51+
elif isinstance(data, type):
52+
# Handle model classes/metaclasses
53+
return str(data.__name__ if hasattr(data, "__name__") else data)
54+
elif hasattr(data, "model_dump") and callable(getattr(data, "model_dump")):
55+
# Handle Pydantic model instances
56+
try:
57+
return json_serialize(data.model_dump())
58+
except Exception:
59+
return str(data)
5160
else:
5261
# Fallback: Convert to string if not serializable
5362
try:

0 commit comments

Comments
 (0)