Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added execution traces and state updates support for Starknet #64

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions sqa/query/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ def __init__(self, item: Item, fields: FieldSelection, sources: list[ItemSrcQuer

self.sql = (f'SELECT _idx AS idx, block_number, ({weight_exp}) AS weight '
f'FROM read_parquet({self.params.new_variable("file")}) i '
f'SEMI JOIN keys ON i.block_number = keys.block_number AND ')
f'SEMI JOIN keys ON i.block_number = keys.block_number'
f'{" AND " if item.table().primary_key else ""}')

self.sql += join_condition(item.table().primary_key, 'i', 'keys')

Expand Down Expand Up @@ -227,7 +228,7 @@ def __init__(self, item: Item, fields: FieldSelection):
self.table_name = item.table().name
self.projected_columns = ['block_number'] + item.selected_columns(fields)

order = ', '.join(item.table().primary_key)
order = ', '.join(item.table().primary_key) or 'block_number'

self.sql = (f'SELECT block_number, to_json(list({item.project(fields)} ORDER BY {order})) AS data '
f'FROM items '
Expand Down
190 changes: 189 additions & 1 deletion sqa/starknet/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,53 @@ class EventFieldSelection(TypedDict, total=False):
data: bool


class TraceFieldSelection(TypedDict, total=False):
    """Boolean flags for the ``trace`` entry of a field selection.

    The dict is declared with ``total=False``, so every key may be omitted.
    Presumably a truthy flag requests the matching trace column - confirm
    against ``get_selected_fields``.
    """

    # How/by whom the call was made
    invocationType: bool
    callerAddress: bool
    callContractAddress: bool
    callType: bool
    callClassHash: bool
    callEntryPointSelector: bool
    callEntryPointType: bool
    callRevertReason: bool

    # Call payload and result
    calldata: bool
    callResult: bool

    # Position / classification of the trace
    parentIndex: bool
    traceType: bool


class StateUpdateFieldSelection(TypedDict, total=False):
    """Boolean flags for the ``stateUpdate`` entry of a field selection.

    Declared with ``total=False``: every key is optional.
    """

    # Block identity and state roots
    blockNumber: bool
    blockHash: bool
    newRoot: bool
    oldRoot: bool

    # Storage diffs
    storageDiffsAddress: bool
    storageDiffsKeys: bool
    storageDiffsValues: bool

    # Declared classes
    deprecatedDeclaredClasses: bool
    declaredClassesClassHash: bool
    declaredClassesCompiledClassHash: bool

    # Deployed contracts
    deployedContractsAddress: bool
    deployedContractsClassHash: bool

    # Replaced classes
    replacedClassesContractAddress: bool
    replacedClassesClassHash: bool

    # Nonces
    noncesContractAddress: bool
    noncesNonce: bool

class FieldSelection(TypedDict, total=False):
    """Top-level field selection: one optional sub-selection per entity kind."""

    block: BlockFieldSelection
    transaction: TransactionFieldSelection
    event: EventFieldSelection
    trace: TraceFieldSelection
    stateUpdate: StateUpdateFieldSelection


class _FieldSelectionSchema(mm.Schema):
    """Marshmallow schema for ``FieldSelection``.

    Each attribute is produced by ``field_map_schema`` from the TypedDict
    that describes the corresponding sub-selection.
    """

    block = field_map_schema(BlockFieldSelection)
    transaction = field_map_schema(TransactionFieldSelection)
    event = field_map_schema(EventFieldSelection)
    trace = field_map_schema(TraceFieldSelection)
    stateUpdate = field_map_schema(StateUpdateFieldSelection)


class TransactionRequest(TypedDict, total=False):
Expand Down Expand Up @@ -94,10 +131,47 @@ class _EventRequestSchema(mm.Schema):
transaction = mm.fields.Boolean()


class TraceRequest(TypedDict, total=False):
    """One entry of the ``traces`` request list.

    All keys are optional (``total=False``). The list-valued keys carry the
    accepted values for the corresponding trace attribute; ``transaction`` is
    a boolean flag.
    """

    # TODO: potentially filter by calldata, similar to the keys of events
    invocationType: list[str]
    callerAddress: list[str]
    callContractAddress: list[str]
    callClassHash: list[str]
    callType: list[str]
    callEntryPointSelector: list[str]
    callEntryPointType: list[str]
    traceType: list[str]
    transaction: bool


class _TraceRequestSchema(mm.Schema):
    """Marshmallow schema validating one entry of the ``traces`` request list.

    Mirrors the keys of ``TraceRequest``: every filter is a list of strings
    and ``transaction`` is a boolean flag.
    """

    # This has to be an actual marshmallow field. A bare annotation
    # (``invocationType: list[str]``) creates no class attribute, so the key
    # would never be registered on the schema.
    invocationType = mm.fields.List(mm.fields.Str())
    callerAddress = mm.fields.List(mm.fields.Str())
    callContractAddress = mm.fields.List(mm.fields.Str())
    callClassHash = mm.fields.List(mm.fields.Str())
    callType = mm.fields.List(mm.fields.Str())
    callEntryPointSelector = mm.fields.List(mm.fields.Str())
    callEntryPointType = mm.fields.List(mm.fields.Str())
    traceType = mm.fields.List(mm.fields.Str())
    transaction = mm.fields.Boolean()


class StateUpdateRequest(TypedDict, total=False):
    """One entry of the ``stateUpdates`` request list (all keys optional)."""

    # Accepted values for the new / old state root
    newRoot: list[str]
    oldRoot: list[str]


class _StateUpdateRequestSchema(mm.Schema):
    """Marshmallow schema validating one entry of the ``stateUpdates`` request list."""

    newRoot = mm.fields.List(mm.fields.Str())
    oldRoot = mm.fields.List(mm.fields.Str())


class _QuerySchema(BaseQuerySchema):
    """Query schema: the field selection plus one request list per scan kind."""

    # Which fields to return for each entity kind
    fields = mm.fields.Nested(_FieldSelectionSchema())

    # Data requests, each a list of nested request objects
    transactions = mm.fields.List(mm.fields.Nested(_TransactionRequestSchema()))
    events = mm.fields.List(mm.fields.Nested(_EventRequestSchema()))
    traces = mm.fields.List(mm.fields.Nested(_TraceRequestSchema()))
    stateUpdates = mm.fields.List(mm.fields.Nested(_StateUpdateRequestSchema()))


# Module-level instance of the query schema, created once at import time.
QUERY_SCHEMA = _QuerySchema()
Expand Down Expand Up @@ -134,6 +208,43 @@ class _QuerySchema(BaseQuerySchema):
)


# Table descriptor for ``traces``, keyed by (transaction_index, call_index).
# Each weighted column ``c`` is mapped to its companion column ``c_size``.
_traces_table = Table(
    name='traces',
    primary_key=['transaction_index', 'call_index'],
    column_weights={
        column: f'{column}_size'
        for column in (
            'calldata',
            'call_result',
            'call_events_keys',
            'call_events_data',
            'call_events_order',
            'call_messages_payload',
            'call_messages_from_address',
            'call_messages_to_address',
            'call_messages_order',
        )
    }
)


# Table descriptor for ``state_updates``; it declares no primary key columns.
# Each weighted column ``c`` is mapped to its companion column ``c_size``.
_state_updates_table = Table(
    name='state_updates',
    primary_key=[],
    column_weights={
        column: f'{column}_size'
        for column in (
            'storage_diffs_address',
            'storage_diffs_keys',
            'storage_diffs_values',
            'deprecated_declared_classes',
            'declared_classes_class_hash',
            'declared_classes_compiled_class_hash',
            'deployed_contracts_address',
            'deployed_contracts_class_hash',
            'replaced_classes_contract_address',
            'replaced_classes_class_hash',
            'nonces_contract_address',
            'nonces_nonce',
        )
    }
)


class _BlockItem(Item):
def table(self) -> Table:
return _blocks_table
Expand Down Expand Up @@ -223,13 +334,72 @@ def project(self, fields: FieldSelection) -> str:
})


class _TraceScan(Scan):
    """Scan over the ``traces`` table, fed by the ``traces`` request list."""

    def table(self) -> Table:
        return _traces_table

    def request_name(self) -> str:
        return 'traces'

    def where(self, req: TraceRequest) -> Iterable[Expression | None]:
        """Yield one ``field_in`` filter per supported request key, in a fixed order."""
        # (table column, request key) pairs; the order matches the order of emitted filters
        column_to_key = (
            ('caller_address', 'callerAddress'),
            ('call_contract_address', 'callContractAddress'),
            ('call_type', 'callType'),
            ('call_class_hash', 'callClassHash'),
            ('call_entry_point_selector', 'callEntryPointSelector'),
            ('call_entry_point_type', 'callEntryPointType'),
            ('trace_type', 'traceType'),
            ('invocation_type', 'invocationType'),
        )
        for column, request_key in column_to_key:
            yield field_in(column, req.get(request_key))


class _TraceItem(Item):
    """Item describing rows of the ``traces`` table."""

    def table(self) -> Table:
        return _traces_table

    def name(self) -> str:
        return 'traces'

    def get_selected_fields(self, fields: FieldSelection) -> list[str]:
        """Resolve the ``trace`` sub-selection via the module-level ``get_selected_fields``."""
        selection = fields.get('trace')
        # Primary-key columns of the traces table; presumably always included - confirm in helper
        key_columns = ['transaction_index', 'call_index']
        return get_selected_fields(selection, key_columns)


class _StateUpdateScan(Scan):
    """Scan over the ``state_updates`` table, fed by the ``stateUpdates`` request list."""

    def table(self) -> Table:
        return _state_updates_table

    def request_name(self) -> str:
        return 'stateUpdates'

    def where(self, req: StateUpdateRequest) -> Iterable[Expression | None]:
        """Yield the ``new_root`` filter followed by the ``old_root`` filter."""
        for column, request_key in (('new_root', 'newRoot'), ('old_root', 'oldRoot')):
            yield field_in(column, req.get(request_key))


class _StateUpdateItem(Item):
    """Item describing rows of the ``state_updates`` table."""

    def table(self) -> Table:
        return _state_updates_table

    def name(self) -> str:
        return 'state_updates'

    def get_selected_fields(self, fields: FieldSelection) -> list[str]:
        """Resolve the ``stateUpdate`` sub-selection via the module-level ``get_selected_fields``."""
        selection = fields.get('stateUpdate')
        return get_selected_fields(selection)


def _build_model() -> Model:
tx_scan = _TxScan()
event_scan = _EventScan()
trace_scan = _TraceScan()
state_update_scan = _StateUpdateScan()

block_item = _BlockItem()
tx_item = _TxItem()
event_item = _EventItem()
trace_item = _TraceItem()
state_update_item = _StateUpdateItem()

event_item.sources.extend([
event_scan,
Expand All @@ -248,10 +418,28 @@ def _build_model() -> Model:
scan=event_scan,
include_flag_name='transaction',
scan_columns=['transaction_index']
),
RefRel(
scan=trace_scan,
include_flag_name='transaction',
scan_columns=['transaction_index']
)
])

return [tx_scan, event_scan, block_item, tx_item, event_item]
trace_item.sources.extend([
trace_scan,
JoinRel(
scan=tx_scan,
include_flag_name='traces',
query='SELECT * FROM traces i, s WHERE '
'i.block_number = s.block_number AND '
'i.transaction_index = s.transaction_index'
)
])

state_update_item.sources.extend([state_update_scan])

return [tx_scan, event_scan, trace_scan, state_update_scan, block_item, tx_item, event_item, trace_item, state_update_item]


# Model assembled once at import time by _build_model().
MODEL = _build_model()
53 changes: 45 additions & 8 deletions sqa/starknet/writer/ingest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio
from collections import defaultdict
import logging
from typing import AsyncIterator, Optional, cast

from sqa.eth.ingest.ingest import _run_subtasks
from sqa.starknet.writer.model import EventPage, Block, Event, WriterBlock, WriterEvent, WriterTransaction
from sqa.starknet.writer.model import BlockStateUpdate, EventPage, Block, Event, Trace, WriterBlock, WriterBlockStateUpdate, WriterEvent, WriterTrace, WriterTransaction
from sqa.util.rpc.client import RpcClient


Expand Down Expand Up @@ -100,8 +101,6 @@ async def _fetch_starknet_stride(self, from_block: int, to_block: int) -> list[W

await _run_subtasks(self._starknet_stride_subtasks(blocks))

# TODO: if self._with_traces and not self._with_receipts:_txs_with_missing_status

# NOTE: This code was moved here to separate retrieving raw data as it is on the node from preparing data to comply with the archive format
strides: list[WriterBlock] = self.make_writer_ready_blocks(blocks)
LOG.debug('stride is ready', extra=extra)
Expand All @@ -113,7 +112,11 @@ def _starknet_stride_subtasks(self, blocks: list[Block]):
else:
yield self._fetch_starknet_logs(blocks)

# TODO: _fetch_traces
if self._with_traces:
yield self._fetch_starknet_traces(blocks)

if self._with_statediffs:
yield self._fetch_starknet_state_updates(blocks)

async def _fetch_starknet_logs(self, blocks: list[Block]) -> None:
priority = blocks[0]['block_number']
Expand Down Expand Up @@ -154,6 +157,29 @@ async def _fetch_starknet_logs(self, blocks: list[Block]) -> None:
async def _fetch_starknet_receipts(self, blocks: list[Block]) -> None:
    """Placeholder: receipt fetching is not implemented for Starknet.

    Raises:
        NotImplementedError: always.
    """
    # TODO: https://docs.alchemy.com/reference/starknet-gettransactionreceipt for block for tx
    message = 'receipts for starknet not implemented'
    raise NotImplementedError(message)

async def _fetch_starknet_traces(self, blocks: list[Block]) -> None:
    """Fetch transaction traces for each block and store them under ``block['traces']``.

    Blocks are processed one at a time; each RPC call uses the block number
    as its priority.
    """
    for blk in blocks:
        height = blk['block_number']
        fetched: list[Trace] = await self._rpc.call(
            'starknet_traceBlockTransactions',
            params=[{"block_number": height}],
            priority=height,
        )
        blk['traces'] = fetched

async def _fetch_starknet_state_updates(self, blocks: list[Block]) -> None:
    """Fetch the state update of each block and store it under ``block['state_update']``.

    Blocks are processed one at a time; each RPC call uses the block number
    as its priority.
    """
    for blk in blocks:
        height = blk['block_number']
        fetched: BlockStateUpdate = await self._rpc.call(
            'starknet_getStateUpdate',
            params=[{"block_number": height}],
            priority=height,
        )
        blk['state_update'] = fetched

async def _detect_special_chains(self) -> None:
self._is_starknet = True

Expand All @@ -166,8 +192,6 @@ async def _fetch_starknet_blocks(self, from_block: int, to_block: int) -> list[B
priority=from_block
)

# TODO: validate tx root?

return blocks

async def _get_chain_height(self) -> int:
Expand All @@ -180,7 +204,6 @@ def make_writer_ready_blocks(blocks: list[Block]) -> list[WriterBlock]:
stride: list[WriterBlock] = cast(list[WriterBlock], blocks) # cast ahead for less mypy problems
# NOTE: This function transforms raw RPC node objects into Writer objects with all the extra fields needed for writing to the table
transaction_hash_to_index = {}
event_index = {}
for block in stride:
block['number'] = block['block_number']
block['hash'] = block['block_hash']
Expand All @@ -193,14 +216,28 @@ def make_writer_ready_blocks(blocks: list[Block]) -> list[WriterBlock]:
tx['block_number'] = block['block_number']

transaction_hash_to_index[tx['transaction_hash']] = tx['transaction_index']
event_index[tx['transaction_hash']] = 0

# could be done with one dict, but I thought it's nicer with two
for block in stride:
event_index: dict[str, int] = defaultdict(lambda: 0)
block['writer_events'] = cast(list[WriterEvent], block['events'])
for event in block['writer_events']:
event['transaction_index'] = transaction_hash_to_index[event['transaction_hash']]
event['event_index'] = event_index[event['transaction_hash']]
event_index[event['transaction_hash']] += 1

for block in stride:
if 'traces' not in block:
continue
block['writer_traces'] = cast(list[WriterTrace], block['traces'])
for trace in block['writer_traces']:
trace['transaction_index'] = transaction_hash_to_index[trace['transaction_hash']]
trace['block_number'] = block['block_number']

for block in stride:
if 'state_update' not in block:
continue
block['writer_state_update'] = cast(list[WriterBlockStateUpdate], block['state_update'])
block['writer_state_update']['block_number'] = block['block_number']

return stride
Loading