Skip to content

Commit

Permalink
make subscription methods more consistent
Browse files Browse the repository at this point in the history
  • Loading branch information
jmorganca authored Nov 26, 2024
2 parents 6c44bb2 + d4c3897 commit 758a1d2
Showing 1 changed file with 38 additions and 2 deletions.
40 changes: 38 additions & 2 deletions ollama/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,32 @@

class SubscriptableBaseModel(BaseModel):
def __getitem__(self, key: str) -> Any:
return getattr(self, key)
"""
>>> msg = Message(role='user')
>>> msg['role']
'user'
>>> msg = Message(role='user')
>>> msg['nonexistent']
Traceback (most recent call last):
KeyError: 'nonexistent'
"""
if key in self:
return getattr(self, key)

raise KeyError(key)

def __setitem__(self, key: str, value: Any) -> None:
"""
>>> msg = Message(role='user')
>>> msg['role'] = 'assistant'
>>> msg['role']
'assistant'
>>> tool_call = Message.ToolCall(function=Message.ToolCall.Function(name='foo', arguments={}))
>>> msg = Message(role='user', content='hello')
>>> msg['tool_calls'] = [tool_call]
>>> msg['tool_calls'][0]['function']['name']
'foo'
"""
setattr(self, key, value)

def __contains__(self, key: str) -> bool:
Expand Down Expand Up @@ -61,7 +84,20 @@ def __contains__(self, key: str) -> bool:
return False

def get(self, key: str, default: Any = None) -> Any:
return getattr(self, key, default)
"""
>>> msg = Message(role='user')
>>> msg.get('role')
'user'
>>> msg = Message(role='user')
>>> msg.get('nonexistent')
>>> msg = Message(role='user')
>>> msg.get('nonexistent', 'default')
'default'
>>> msg = Message(role='user', tool_calls=[ Message.ToolCall(function=Message.ToolCall.Function(name='foo', arguments={}))])
>>> msg.get('tool_calls')[0]['function']['name']
'foo'
"""
return self[key] if key in self else default


class Options(SubscriptableBaseModel):
Expand Down

0 comments on commit 758a1d2

Please sign in to comment.