Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interface to fetch entries in primitive types from DataPack #900

Merged
merged 13 commits into from
Jan 3, 2023
6 changes: 3 additions & 3 deletions forte/data/base_pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,13 +791,13 @@ def get_ids_by_creator(self, component: str) -> Set[int]:
return entry_set

def is_created_by(
self, entry: Entry, components: Union[str, Iterable[str]]
self, entry_tid: int, components: Union[str, Iterable[str]]
Pushkar-Bhuse marked this conversation as resolved.
Show resolved Hide resolved
) -> bool:
"""
Check if the entry is created by any of the provided components.

Args:
entry: The entry to check.
entry_tid: `tid` of the entry to check.
components: The list of component names.

Returns:
Expand All @@ -807,7 +807,7 @@ def is_created_by(
components = [components]

for c in components:
if entry.tid in self._creation_records[c]:
if entry_tid in self._creation_records[c]:
break
else:
# The entry not created by any of these components.
Expand Down
51 changes: 39 additions & 12 deletions forte/data/data_pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -1474,6 +1474,7 @@ def get( # type: ignore
range_annotation: Optional[Union[Annotation, AudioAnnotation]] = None,
components: Optional[Union[str, Iterable[str]]] = None,
include_sub_type: bool = True,
get_raw: bool = False,
) -> Iterable[EntryType]:
r"""This function is used to get data from a data pack with various
methods.
Expand Down Expand Up @@ -1546,6 +1547,8 @@ def get( # type: ignore
entries generated by any component.
include_sub_type: whether to consider the sub types of
the provided entry type. Default `True`.
get_raw: boolean to indicate if the entry should be returned in
its primitive form as opposed to an object. False by default

Yields:
Each `Entry` found using this method.
Expand All @@ -1566,11 +1569,25 @@ def require_annotations(entry_class=Annotation) -> bool:
# If we don't have any annotations but the items to check requires them,
# then we simply yield from an empty list.
if (
len(self.annotations) == 0
len(
mylibrar marked this conversation as resolved.
Show resolved Hide resolved
list(
self._data_store.all_entries(
"forte.data.ontology.top.Annotation"
)
)
)
== 0
and isinstance(range_annotation, Annotation)
and require_annotations(Annotation)
) or (
len(self.audio_annotations) == 0
len(
list(
self._data_store.all_entries(
"forte.data.ontology.top.AudioAnnotation"
)
)
)
== 0
and isinstance(range_annotation, AudioAnnotation)
and require_annotations(AudioAnnotation)
):
Expand Down Expand Up @@ -1604,21 +1621,31 @@ def require_annotations(entry_class=Annotation) -> bool:
range_span=range_annotation # type: ignore
and (range_annotation.begin, range_annotation.end),
):
entry: Entry = self.get_entry(tid=entry_data[TID_INDEX])

# Filter by components
if components is not None:
if not self.is_created_by(entry, components):
if not self.is_created_by(
entry_data[TID_INDEX], components
):
continue

# Filter out incompatible audio span comparison for Links and Groups
if (
issubclass(entry_type_, (Link, Group))
and isinstance(range_annotation, AudioAnnotation)
and not self._index.in_audio_span(
entry, range_annotation.span
entry: Union[Entry, Dict[str, Any]]
if get_raw:
entry = self._data_store.transform_data_store_entry(
entry_data
)
):
continue
else:
entry = self.get_entry(tid=entry_data[TID_INDEX])

# Filter out incompatible audio span comparison for Links and Groups
if (
mylibrar marked this conversation as resolved.
Show resolved Hide resolved
issubclass(entry_type_, (Link, Group))
and isinstance(range_annotation, AudioAnnotation)
and not self._index.in_audio_span(
entry, range_annotation.span
)
):
continue

yield entry # type: ignore
except ValueError:
Expand Down
123 changes: 121 additions & 2 deletions forte/data/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ def __getstate__(self):
state.pop("_DataStore__tid_ref_dict")
state.pop("_DataStore__tid_idx_dict")
state.pop("_DataStore__deletion_count")
state.pop("_type_attributes", None)
state["entries"] = state.pop("_DataStore__elements")
for _, v in state["fields"].items():
if constants.PARENT_CLASS_KEY in v:
Expand Down Expand Up @@ -806,9 +807,22 @@ def fetch_entry_type_data(
# ie. NoneType.
if attr_class is None:
attr_class = type(None)
attr_args = get_args(attr_info.type)
if len(attr_args) == 0:
raw_attr_args = get_args(attr_info.type)
if len(raw_attr_args) == 0:
attr_args = tuple([attr_info.type])
else:
attr_args = ()
for args in raw_attr_args:
# This is the case when we have a multidimensional
# type attribute like List[Tuple[int, int]]. In this
# case get_args will return a tuple of tuples that
# looks like ((Tuple, int, int),). We thus convert
# this into a single dimensional tuple -
# (Tuple, int, int).
if isinstance(args, tuple):
attr_args += args
else:
attr_args += (args,)

# Prior to Python 3.7, fetching generic type
# aliases resulted in actual type objects whereas from
Expand Down Expand Up @@ -1315,6 +1329,111 @@ def _get_existing_ann_entry_tid(self, entry: List[Any]):
"getting entry id for annotation-like entry."
)

def get_attribute_positions(self, type_name: str) -> Dict[str, int]:
r"""This function returns a dictionary where the key represents
the attributes of the entry of type ``type_name`` and value
represents the index of the position where this attribute is
stored in the data store entry of this type.
For example:

.. code-block:: python

positions = data_store.get_attribute_positions(
"ft.onto.base_ontology.Document"
)

# positions = {
# "begin": 2,
# "end": 3,
# "payload_idx": 4,
# "document_class": 5,
# "sentiment": 6,
# "classifications": 7
# }

Args:
type_name (str): The fully qualified type name of a type.

Returns:
A dictionary indicating the attributes of an entry of type
``type_name`` and their respective positions in a data store
entry.
"""
type_data = self._get_type_info(type_name)

positions: Dict[str, int] = {}
for attr, val in type_data[constants.ATTR_INFO_KEY].items():
positions[attr] = val[constants.ATTR_INDEX_KEY]

return positions

def transform_data_store_entry(self, entry: List[Any]) -> Dict:
r"""
This method converts a raw data store entry into a format more easily
understandable to users. Data Store entries are stored as lists and
are not very easily understandable. This method converts ``DataStore``
entries from a list format to a dictionary based format where the key
is the names of the attributes of an entry and the value is the values
corresponding attributes in the data store entry.
For example:

.. code-block:: python

# Entry of type 'ft.onto.base_ontology.Sentence'
Pushkar-Bhuse marked this conversation as resolved.
Show resolved Hide resolved
data_store_entry = [
171792711812874531962213686690228233530,
'ft.onto.base_ontology.Sentence',
0,
164,
0,
'-',
0,
{},
{},
{}
]

transformed_entry = pack.transform_data_store_entry(
data_store_entry
)

# transformed_entry = {
# 'begin': 0,
# 'end': 164,
# 'payload_idx': 0,
# 'speaker': '-',
# 'part_id': 0,
# 'sentiment': {},
# 'classification': {},
# 'classifications': {},
# 'tid': 171792711812874531962213686690228233530,
# 'type': 'ft.onto.base_ontology.Sentence'}
# }


Args:
entry: A list representing a valid data store entry

Returns:
a dictionary representing the the input data store entry
"""

attribute_positions = self.get_attribute_positions(
entry[constants.ENTRY_TYPE_INDEX]
)

# We now convert the entry from data store format (list) to user
# representation format (dict) to make the contents of the entry more
# understandable.
user_rep: Dict[str, Any] = {}
for attr, pos in attribute_positions.items():
user_rep[attr] = entry[pos]

user_rep["tid"] = entry[constants.TID_INDEX]
user_rep["type"] = entry[constants.ENTRY_TYPE_INDEX]

return user_rep

def set_attribute(self, tid: int, attr_name: str, attr_value: Any):
r"""This function locates the entry data with ``tid`` and sets its
``attr_name`` with `attr_value`. It first finds ``attr_id`` according
Expand Down
28 changes: 20 additions & 8 deletions forte/data/multi_pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import logging

from pathlib import Path
from typing import Dict, List, Union, Iterator, Optional, Type, Any, Tuple
from typing import Dict, List, Union, Iterator, Optional, Type, Any, Tuple, cast

import jsonpickle

Expand Down Expand Up @@ -802,7 +802,8 @@ def get( # type: ignore
self,
entry_type: Union[str, Type[EntryType]],
components: Optional[Union[str, List[str]]] = None,
include_sub_type=True,
include_sub_type: bool = True,
get_raw: bool = False,
) -> Iterator[EntryType]:
"""Get entries of ``entry_type`` from this multi pack.

Expand All @@ -827,6 +828,8 @@ def get( # type: ignore
any component will be returned.
include_sub_type: whether to return the sub types of the
queried `entry_type`. True by default.
get_raw: boolean to indicate if the entry should be returned in
its primitive form as opposed to an object. False by default

Returns:
An iterator of the entries matching the arguments, following
Expand Down Expand Up @@ -855,16 +858,25 @@ def get( # type: ignore
type_name=get_full_module_name(entry_type_),
include_sub_type=include_sub_type,
):
entry: Entry = self._entry_converter.get_entry_object(
tid=entry_data[TID_INDEX],
pack=self,
type_name=entry_data[ENTRY_TYPE_INDEX],
)
# Filter by components
if components is not None:
if not self.is_created_by(entry, components):
if not self.is_created_by(
entry_data[TID_INDEX], components
):
continue

entry: Union[Entry, Dict[str, Any]]

if get_raw:
data_store = cast(DataStore, self._data_store)
entry = data_store.transform_data_store_entry(entry_data)
else:
entry = self._entry_converter.get_entry_object(
tid=entry_data[TID_INDEX],
pack=self,
type_name=entry_data[ENTRY_TYPE_INDEX],
)

yield entry # type: ignore
except ValueError:
# type_name does not exist in DataStore
Expand Down
20 changes: 7 additions & 13 deletions forte/data/ontology/top.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
BaseGroup,
MultiEntry,
EntryType,
FList,
)
from forte.data.span import Span
from forte.utils.utils import get_full_module_name
Expand Down Expand Up @@ -317,7 +316,7 @@ class Group(BaseGroup[Entry]):
a "coreference group" is a group of coreferential entities. Each group will
store a set of members, no duplications allowed.
"""
members: FList[Entry]
members: List[int]
member_type: str

MemberType = Entry
Expand Down Expand Up @@ -345,7 +344,7 @@ def add_member(self, member: Entry):
f"The members of {type(self)} should be "
f"instances of {self.MemberType}, but got {type(member)}"
)
self.members.append(member)
self.members.append(member.tid)

def get_members(self) -> List[Entry]:
r"""Get the member entries in the group. The function will retrieve
Expand All @@ -365,7 +364,7 @@ def get_members(self) -> List[Entry]:
member_entries = []
if self.members is not None:
for m in self.members:
member_entries.append(m)
member_entries.append(self.pack.get_entry(m))
return member_entries


Expand Down Expand Up @@ -515,7 +514,7 @@ class MultiPackGroup(MultiEntry, BaseGroup[Entry]):
of members.
"""
member_type: str
members: Optional[FList[Entry]]
members: List[Tuple[int, int]]
mylibrar marked this conversation as resolved.
Show resolved Hide resolved

MemberType = Entry

Expand All @@ -527,7 +526,6 @@ def __init__(
# in data store and must thus be in a primitive form.
self.member_type = get_full_module_name(self.MemberType)
super().__init__(pack)

if members is not None:
self.add_members(members)

Expand All @@ -537,17 +535,13 @@ def add_member(self, member: Entry):
f"The members of {type(self)} should be "
f"instances of {self.MemberType}, but got {type(member)}"
)
if self.members is None:
self.members = cast(FList, [member])
else:
self.members.append(member)
self.members.append((member.pack_id, member.tid))

def get_members(self) -> List[Entry]:
members = []
if self.members is not None:
member_data = self.members
for m in member_data:
members.append(m)
for pack_idx, member_tid in self.members:
members.append(self.pack.get_subentry(pack_idx, member_tid))
return members


Expand Down
Loading