@@ -1321,7 +1321,7 @@ def __init__(
1321
1321
domain : str ,
1322
1322
op_type : str ,
1323
1323
inputs : Iterable [Value | None ],
1324
- attributes : Iterable [Attr | RefAttr ] = (),
1324
+ attributes : Iterable [Attr ] = (),
1325
1325
* ,
1326
1326
overload : str = "" ,
1327
1327
num_outputs : int | None = None ,
@@ -1353,7 +1353,7 @@ def __init__(
1353
1353
metadata_props: The metadata properties.
1354
1354
1355
1355
Raises:
1356
- TypeError: If the attributes are not :class:`Attr` or :class:`RefAttr` .
1356
+ TypeError: If the attributes are not :class:`Attr`.
1357
1357
ValueError: If ``num_outputs``, when not ``None``, is not the same as the length of the outputs.
1358
1358
ValueError: If an output value is ``None``, when outputs is specified.
1359
1359
ValueError: If an output value has a producer set already, when outputs is specified.
@@ -1368,13 +1368,13 @@ def __init__(
1368
1368
# Values belong to their defining nodes. The values list is immutable
1369
1369
self ._outputs : tuple [Value , ...] = self ._create_outputs (num_outputs , outputs )
1370
1370
attributes = tuple (attributes )
1371
- if attributes and not isinstance (attributes [0 ], ( Attr , RefAttr ) ):
1371
+ if attributes and not isinstance (attributes [0 ], Attr ):
1372
1372
raise TypeError (
1373
- f"Expected the attributes to be Attr or RefAttr , got { type (attributes [0 ])} . "
1373
+ f"Expected the attributes to be Attr, got { type (attributes [0 ])} . "
1374
1374
"If you are copying the attributes from another node, make sure you call "
1375
1375
"node.attributes.values() because it is a dictionary."
1376
1376
)
1377
- self ._attributes : OrderedDict [str , Attr | RefAttr ] = OrderedDict (
1377
+ self ._attributes : OrderedDict [str , Attr ] = OrderedDict (
1378
1378
(attr .name , attr ) for attr in attributes
1379
1379
)
1380
1380
self ._overload : str = overload
@@ -1633,7 +1633,7 @@ def outputs(self, _: Sequence[Value]) -> None:
1633
1633
raise AttributeError ("outputs is immutable. Please create a new node instead." )
1634
1634
1635
1635
@property
1636
- def attributes (self ) -> OrderedDict [str , Attr | RefAttr ]:
1636
+ def attributes (self ) -> OrderedDict [str , Attr ]:
1637
1637
"""The attributes of the node."""
1638
1638
return self ._attributes
1639
1639
@@ -3106,22 +3106,28 @@ def __repr__(self) -> str:
3106
3106
return f"{ self .__class__ .__name__ } ({ self .domain !r} , { self .name !r} , { self .overload !r} , inputs={ self .inputs !r} , attributes={ self .attributes !r} ), outputs={ self .outputs !r} )"
3107
3107
3108
3108
3109
- class RefAttr (_protocols .ReferenceAttributeProtocol , _display .PrettyPrintable ):
3110
- """Reference attribute."""
3109
+ class Attr (
3110
+ _protocols .AttributeProtocol ,
3111
+ _protocols .ReferenceAttributeProtocol ,
3112
+ _display .PrettyPrintable ,
3113
+ ):
3114
+ """Base class for ONNX attributes or references."""
3111
3115
3112
- __slots__ = ("_name" , "_ref_attr_name" , "_type" , "doc_string" )
3116
+ __slots__ = ("_name" , "_ref_attr_name" , "_type" , "_value" , " doc_string" )
3113
3117
3114
3118
def __init__ (
3115
3119
self ,
3116
3120
name : str ,
3117
- ref_attr_name : str ,
3118
3121
type : _enums .AttributeType ,
3122
+ value : Any ,
3123
+ ref_attr_name : str | None = None ,
3119
3124
* ,
3120
3125
doc_string : str | None = None ,
3121
- ) -> None :
3126
+ ):
3122
3127
self ._name = name
3123
- self ._ref_attr_name = ref_attr_name
3124
3128
self ._type = type
3129
+ self ._value = value
3130
+ self ._ref_attr_name = ref_attr_name
3125
3131
self .doc_string = doc_string
3126
3132
3127
3133
@property
@@ -3132,43 +3138,21 @@ def name(self) -> str:
3132
3138
def name (self , value : str ) -> None :
3133
3139
self ._name = value
3134
3140
3135
- @property
3136
- def ref_attr_name (self ) -> str :
3137
- return self ._ref_attr_name
3138
-
3139
- @ref_attr_name .setter
3140
- def ref_attr_name (self , value : str ) -> None :
3141
- self ._ref_attr_name = value
3142
-
3143
3141
@property
3144
3142
def type (self ) -> _enums .AttributeType :
3145
3143
return self ._type
3146
3144
3147
- @type .setter
3148
- def type (self , value : _enums .AttributeType ) -> None :
3149
- self ._type = value
3150
-
3151
- def __repr__ (self ) -> str :
3152
- return f"{ self .__class__ .__name__ } ({ self ._name !r} , { self ._type !r} , ref_attr_name={ self .ref_attr_name !r} )"
3153
-
3154
-
3155
- class Attr (_protocols .AttributeProtocol , _display .PrettyPrintable ):
3156
- """Base class for ONNX attributes."""
3145
+ @property
3146
+ def value (self ) -> Any :
3147
+ return self ._value
3157
3148
3158
- __slots__ = ("doc_string" , "name" , "type" , "value" )
3149
+ @property
3150
+ def ref_attr_name (self ) -> str | None :
3151
+ return self ._ref_attr_name
3159
3152
3160
- def __init__ (
3161
- self ,
3162
- name : str ,
3163
- type : _enums .AttributeType ,
3164
- value : Any ,
3165
- * ,
3166
- doc_string : str | None = None ,
3167
- ):
3168
- self .name = name
3169
- self .type = type
3170
- self .value = value
3171
- self .doc_string = doc_string
3153
+ def is_ref (self ) -> bool :
3154
+ """Check if this attribute is a reference attribute."""
3155
+ return self .ref_attr_name is not None
3172
3156
3173
3157
def __eq__ (self , other : object ) -> bool :
3174
3158
if not isinstance (other , _protocols .AttributeProtocol ):
@@ -3185,11 +3169,15 @@ def __eq__(self, other: object) -> bool:
3185
3169
return True
3186
3170
3187
3171
def __str__ (self ) -> str :
3172
+ if self .is_ref ():
3173
+ return f"@{ self .ref_attr_name } "
3188
3174
if self .type == _enums .AttributeType .GRAPH :
3189
3175
return textwrap .indent ("\n " + str (self .value ), " " * 4 )
3190
3176
return str (self .value )
3191
3177
3192
3178
def __repr__ (self ) -> str :
3179
+ if self .is_ref ():
3180
+ return f"{ self .__class__ .__name__ } ({ self .name !r} , { self .type !r} , ref_attr_name={ self .ref_attr_name !r} )"
3193
3181
return f"{ self .__class__ .__name__ } ({ self .name !r} , { self .type !r} , { self .value !r} )"
3194
3182
3195
3183
# Well typed getters
@@ -3269,6 +3257,29 @@ def as_graphs(self) -> Sequence[Graph]:
3269
3257
3270
3258
3271
3259
# NOTE: The following functions are just for convenience
3260
+
3261
+
3262
+ def RefAttr (
3263
+ name : str ,
3264
+ ref_attr_name : str ,
3265
+ type : _enums .AttributeType ,
3266
+ doc_string : str | None = None ,
3267
+ ) -> Attr :
3268
+ """Create a reference attribute.
3269
+
3270
+ Args:
3271
+ name: The name of the attribute.
3272
+ type: The type of the attribute.
3273
+ ref_attr_name: The name of the referenced attribute.
3274
+ doc_string: Documentation string.
3275
+
3276
+ Returns:
3277
+ A reference attribute.
3278
+ """
3279
+ # NOTE: The function name is capitalized to maintain API backward compatibility.
3280
+ return Attr (name , type , None , ref_attr_name = ref_attr_name , doc_string = doc_string )
3281
+
3282
+
3272
3283
def AttrFloat32 (name : str , value : float , doc_string : str | None = None ) -> Attr :
3273
3284
"""Create a float attribute."""
3274
3285
# NOTE: The function name is capitalized to maintain API backward compatibility.
0 commit comments