Skip to content

Commit

Permalink
Stave Database Reader (#351)
Browse files Browse the repository at this point in the history
* Add draft stave reader and tests.

* Add ontology in stave db reader tests.

* allow no-prefix check during generation

* dynamic

* Correct a few parameter calls.

* More test debugging.

* Fix test error

* fix mypy pylint

* mypy and pylint.

* pylint error

* mypy errors.

Co-authored-by: hector.liu <hector.liu@petuum.com>
  • Loading branch information
hunterhector and hector.liu authored Dec 29, 2020
1 parent 6fc056a commit 900cf93
Show file tree
Hide file tree
Showing 6 changed files with 309 additions and 31 deletions.
36 changes: 21 additions & 15 deletions forte/data/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
import tarfile
import urllib.request
import zipfile
from typing import List, Optional, overload
from os import PathLike
from typing import List, Optional, overload, Union

import jsonpickle

from forte.utils.types import PathLike
from forte.utils.utils_io import maybe_create_dir

__all__ = [
Expand All @@ -37,26 +37,29 @@
# pylint: disable=unused-argument,function-redefined,missing-docstring

@overload
def maybe_download(urls: List[str], path: PathLike,
def maybe_download(urls: List[str], path: Union[str, PathLike],
filenames: Optional[List[str]] = None,
extract: bool = False) -> List[str]: ...


@overload
def maybe_download(urls: str, path: PathLike, filenames: Optional[str] = None,
def maybe_download(urls: str, path: Union[str, PathLike],
filenames: Optional[str] = None,
extract: bool = False) -> str: ...


def maybe_download(urls, path, filenames=None, extract=False):
def maybe_download(urls: Union[List[str], str], path: Union[str, PathLike],
filenames: Union[List[str], str, None] = None,
extract: bool = False):
r"""Downloads a set of files.
Args:
urls: A (list of) URLs to download files.
path (str): The destination path to save the files.
path: The destination path to save the files.
filenames: A (list of) strings of the file names. If given,
must have the same length with :attr:`urls`. If `None`,
filenames are extracted from :attr:`urls`.
extract (bool): Whether to extract compressed files.
extract: Whether to extract compressed files.
Returns:
A list of paths to the downloaded files.
Expand Down Expand Up @@ -116,7 +119,7 @@ def maybe_download(urls, path, filenames=None, extract=False):
# pylint: enable=unused-argument,function-redefined,missing-docstring


def _download(url: str, filename: str, path: str) -> str:
def _download(url: str, filename: str, path: Union[PathLike, str]) -> str:
def _progress_hook(count, block_size, total_size):
percent = float(count * block_size) / float(total_size) * 100.
sys.stdout.write(f'\r>> Downloading {filename} {percent:.1f}%')
Expand All @@ -141,7 +144,8 @@ def _extract_google_drive_file_id(url: str) -> str:
return file_id


def _download_from_google_drive(url: str, filename: str, path: str) -> str:
def _download_from_google_drive(url: str, filename: str,
path: Union[str, PathLike]) -> str:
r"""Adapted from `https://github.com/saurabhshri/gdrive-downloader`
"""

Expand Down Expand Up @@ -183,12 +187,14 @@ def _get_confirm_token(response):


def deserialize(string: str):
r"""Deserialize a pack from a string.
"""
Deserialize a pack from a string.
Args:
string: The raw string to deserialize from.
Returns:
"""
pack = jsonpickle.decode(string)
# Need to assign the pack manager to the pack to control it after reading
# the raw data.
# pylint: disable=protected-access
# pack._pack_manager = pack_manager
# pack_manager.set_remapped_pack_id(pack)
return pack
10 changes: 8 additions & 2 deletions forte/data/multi_pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,10 @@ class MultiPack(BasePack[Entry, MultiPackLink, MultiPackGroup]):
def __init__(self, pack_name: Optional[str] = None):
super().__init__(pack_name)

# Store the global ids.
# Store the pack ids of the subpacks. Note that these are UUIDs so
# they should be globally non-conflicting.
self._pack_ref: List[int] = []
# Store the reverse mapping from global id to the pack index.
# Store the reverse mapping from pack id to the pack index.
self._inverse_pack_ref: Dict[int, int] = {}

# Store the pack names.
Expand Down Expand Up @@ -258,11 +259,16 @@ def get_pack(self, name: str) -> DataPack:
"""
return self._packs[self._name_index[name]]

def pack_ids(self) -> List[int]:
return self._pack_ref

@property
def packs(self) -> List[DataPack]:
"""
Get the list of Data packs that in the order of added.
Note that please do not use this
Returns: List of data packs contained in this multi-pack.
"""
Expand Down
37 changes: 25 additions & 12 deletions forte/data/readers/deserialize_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import os
from abc import ABC, abstractmethod

from typing import Iterator, List, Any
from typing import Iterator, List, Any, Union

from forte.common.exception import ProcessExecutionException
from forte.data.data_pack import DataPack
Expand All @@ -27,6 +27,7 @@
'RecursiveDirectoryDeserializeReader',
'DirPackReader',
'MultiPackDirectoryReader',
'MultiPackDeserializerBase',
]


Expand Down Expand Up @@ -110,28 +111,35 @@ class MultiPackDeserializerBase(MultiPackReader):
information.
"""

def _collect(self) -> Iterator[Any]: # type: ignore
def _collect(
self, *args: Any, **kwargs: Any) -> Iterator[Any]:
"""
This collect actually do not need any data source, it directly reads
the data from the configurations.
Returns:
"""
for s in self._get_multipack_content():
for s in self._get_multipack_content(*args, **kwargs):
yield s

def _parse_pack(self, multi_pack_str: str) -> Iterator[MultiPack]:
# pylint: disable=protected-access
m_pack: MultiPack = deserialize(multi_pack_str)

for pid in m_pack._pack_ref:
pack: DataPack = deserialize(self._get_pack_content(pid))
m_pack._packs.append(pack)
for pid in m_pack.pack_ids():
p_content = self._get_pack_content(pid)
pack: DataPack
if isinstance(p_content, str):
pack = deserialize(p_content)
else:
pack = p_content
# Only in deserialization we can do this.
m_pack.packs.append(pack)
yield m_pack

@abstractmethod
def _get_multipack_content(self) -> Iterator[str]:
def _get_multipack_content(self, *args: Any, **kwargs: Any
) -> Iterator[str]:
"""
Implementation of this method should be responsible for yielding
the raw content of the multi packs.
Expand All @@ -142,7 +150,7 @@ def _get_multipack_content(self) -> Iterator[str]:
raise NotImplementedError

@abstractmethod
def _get_pack_content(self, pack_id: int) -> str:
def _get_pack_content(self, pack_id: int) -> Union[str, DataPack]:
"""
Implementation of this method should be responsible for returning the
raw string of the data pack from the pack id.
Expand All @@ -151,6 +159,9 @@ def _get_pack_content(self, pack_id: int) -> str:
pack_id: representing the id of the data pack.
Returns:
The content of this data pack. You can either:
- return the raw data pack string.
- return the data pack as parsed DataPack object.
"""
raise NotImplementedError
Expand All @@ -165,7 +176,7 @@ class MultiPackDirectoryReader(MultiPackDeserializerBase):
a directory too (they can be the same directory).
"""

def _get_multipack_content(self) -> Iterator[str]:
def _get_multipack_content(self) -> Iterator[str]: # type: ignore
# pylint: disable=protected-access
for f in os.listdir(self.configs.multi_pack_dir):
if f.endswith(self.configs.pack_suffix):
Expand All @@ -180,11 +191,13 @@ def _get_pack_content(self, pack_id: int) -> str:

@classmethod
def default_configs(cls):
return {
config = super().default_configs()
config.update({
"multi_pack_dir": None,
"data_pack_dir": None,
"pack_suffix": '.json'
}
})
return config


# A short name for this class.
Expand Down
153 changes: 153 additions & 0 deletions forte/data/readers/stave_readers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Copyright 2020 The Forte Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This class contains readers to read from the Stave annotation tool.
The Stave annotation tool can be found here: https://github.com/asyml/stave
"""

import sqlite3
from typing import Iterator, Dict

from forte.common import Resources, ProcessorConfigError
from forte.common.configuration import Config
from forte.data.data_pack import DataPack
from forte.data.data_utils import deserialize
from forte.data.readers.base_reader import PackReader
from forte.data.readers.deserialize_reader import MultiPackDeserializerBase

__all__ = [
"StaveMultiDocSqlReader",
"StaveDataPackSqlReader",
]


def load_all_datapacks(conn, pack_table_name: str, pack_col: int) -> Dict[
int, DataPack]:
"""
Load all the datapacks from the table given a sqlite connection to the
Stave database.
Args:
conn: The sqlite database connection.
pack_table_name: The name of the pack table.
pack_col: The column number to retrieve the actual pack content.
Returns:
A dictionary contains all the datapacks.
"""
c = conn.cursor()
data_packs: Dict[int, DataPack] = {}
for val in c.execute(
f'SELECT * FROM {pack_table_name}'):
pack: DataPack = deserialize(val[pack_col])
# Currently assume we do not have access to the id in the database,
# once we update all Stave db format, we can add the real id.
data_packs[pack.pack_id] = pack
return data_packs


class StaveMultiDocSqlReader(MultiPackDeserializerBase):
"""
This reader reads multi packs from Stave's database schema.
Stave is a annotation interface built on Forte's format:
- https://github.com/asyml/stave
"""

def initialize(self, resources: Resources, configs: Config):
# pylint: disable=attribute-defined-outside-init
super().initialize(resources, configs)

if not configs.stave_db_path:
raise ProcessorConfigError(
'The database path to stave is not specified.')

self.conn = sqlite3.connect(configs.stave_db_path)
self.data_packs: Dict[int, DataPack] = load_all_datapacks(
self.conn, configs.datapack_table, configs.pack_content_col)

def _get_multipack_content(self) -> Iterator[str]: # type: ignore
c = self.conn.cursor()
for value in c.execute(
f'SELECT textPack FROM {self.configs.multipack_table}'):
yield value[0]

def _get_pack_content(self, pack_id: int) -> DataPack:
return self.data_packs[pack_id]

@classmethod
def default_configs(cls):
config = super().default_configs()
config.update({
"stave_db_path": None,
"multipack_table": 'nlpviewer_backend_crossdoc',
"multipack_content_col": 2,
"multipack_project_key_col": 3,
"datapack_table": 'nlpviewer_backend_document',
"pack_content_col": 2,
"project_table": None,
"project_to_read": None,
})
return config


class StaveDataPackSqlReader(PackReader):
def initialize(self, resources: Resources, configs: Config):
super().initialize(resources, configs)

if not configs.stave_db_path:
raise ProcessorConfigError(
'The database path to stave is not specified.')

if not configs.datapack_table:
raise ProcessorConfigError(
'The table name that stores the data pack is not stored.')

def _collect(self) -> Iterator[str]: # type: ignore
# pylint: disable=attribute-defined-outside-init
self.conn = sqlite3.connect(self.configs.stave_db_path)
c = self.conn.cursor()

pack: str = self.configs.datapack_table
project: str = self.configs.project_table

if self.configs.target_project_name is None:
# Read all documents in the database.
query = f'SELECT * FROM {pack}'
else:
# Read the specific project.
query = f'SELECT textPack FROM {pack}, {project} ' \
f'WHERE {pack}.project_id = {project}.id ' \
f'AND {project}.name = "{self.configs.target_project_name}"'

for value in c.execute(query):
yield value[0]

def _parse_pack(self, pack_str: str) -> Iterator[DataPack]:
yield deserialize(pack_str)

@classmethod
def default_configs(cls):
config = super().default_configs()
config.update({
"stave_db_path": None,
"datapack_table": 'nlpviewer_backend_document',
"pack_content_col": 2,
"project_table": "nlpviewer_backend_project",
"target_project_name": None,
})
return config
6 changes: 4 additions & 2 deletions forte/utils/utils_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@

import sys

from typing import Union

def maybe_create_dir(dirname: str) -> bool:

def maybe_create_dir(dirname: Union[str, os.PathLike]) -> bool:
r"""Creates directory if it does not exist.
Args:
dirname (str): Path to the directory.
dirname: Path to the directory.
Returns:
bool: Whether a new directory is created.
Expand Down
Loading

0 comments on commit 900cf93

Please sign in to comment.