Add example extractors #145

Closed
wants to merge 32 commits into from

Changes from all commits (32 commits)
9906bc4
adding more minimal extractors, and replacing them inside rust tests
yenicelik Nov 3, 2023
11e0009
modified develop.md a bit
yenicelik Nov 3, 2023
9bba21a
made lists empty to valid yaml
yenicelik Nov 3, 2023
8eddb43
added pythonpath in dockerfile. lingua tests also running
yenicelik Nov 5, 2023
888bed3
will proceed on desktop
yenicelik Nov 6, 2023
7e0620e
will proceed on desktop
yenicelik Nov 6, 2023
21944e6
devcontainer is ready to develop
yenicelik Nov 6, 2023
a34e43c
devcontainer is ready to develop
yenicelik Nov 6, 2023
73d2721
cleaning up formatting and some comments.
yenicelik Nov 16, 2023
42255ab
cleaning up formatting and some comments.
yenicelik Nov 16, 2023
39bd55f
merged with main
yenicelik Nov 16, 2023
3d82175
added volume for huggingface models to devcontainer
yenicelik Nov 16, 2023
d27c07a
Merge remote-tracking branch 'remote/main' into david/lightweight-ide…
yenicelik Nov 16, 2023
e7d9c67
Merge remote-tracking branch 'remote/main' into david/lightweight-ide…
yenicelik Nov 19, 2023
6a3336b
hash extractor seems to work, . instead of :
yenicelik Nov 19, 2023
15b8dae
WIP adding pdf to markdown extractor, and a language detector
yenicelik Nov 19, 2023
76987e7
language extractor, will proceed with pdf->markdown
yenicelik Nov 20, 2023
af673e1
removed lingua from test python dependencies, goal is to spawn a dock…
yenicelik Nov 20, 2023
bd15293
will proceed with invoice extractor in a separate branch, i ll push t…
yenicelik Nov 20, 2023
42c393c
Merge remote-tracking branch 'remote/main' into david/invoice-extractor
yenicelik Nov 20, 2023
e2581ba
check if merged with master
yenicelik Nov 20, 2023
252185f
merge with main
yenicelik Nov 21, 2023
d4ea4ac
Merge remote-tracking branch 'upstream/main' into david/invoice-extra…
yenicelik Nov 21, 2023
38bc2e2
extractor is working, will document the process and start to improve it
yenicelik Nov 21, 2023
923dc05
revert extractors to main state
yenicelik Nov 22, 2023
eedd855
removed redundant query_embed functions, and removed pip requirements…
yenicelik Nov 22, 2023
11cb8fe
Merge remote-tracking branch 'upstream/main' into david/invoice-extra…
yenicelik Nov 22, 2023
fed0518
moved devcontainer to a different PR
yenicelik Nov 22, 2023
e76970c
cleaned up identity hash embedding
yenicelik Nov 22, 2023
9dbe6e1
cleaned up language extractor
yenicelik Nov 22, 2023
b4084f4
cleaned up simple invoice parser
yenicelik Nov 22, 2023
191ad85
remove markdown extractors that are not working
yenicelik Nov 22, 2023
1 change: 1 addition & 0 deletions .gitignore
@@ -15,3 +15,4 @@ build
upload.pdf
indexify_extractor_sdk.egg-info
extractors/indexify_extractors.egg-info/
data/
40 changes: 40 additions & 0 deletions extractors/identity_hash_embedding.py
@@ -0,0 +1,40 @@
import hashlib
import numpy as np
from typing import List

from indexify_extractor_sdk import (ExtractorSchema, EmbeddingSchema)
from indexify_extractor_sdk.base_embedding import (BaseEmbeddingExtractor, EmbeddingInputParams)

class IdentityHashEmbedding(BaseEmbeddingExtractor):
    """
    Implements a hash extractor, which can be used to find duplicates within the dataset.
    It hashes the text into bytes and interprets these as a numpy array.

    This could be extended with locality-sensitive hashing to also account for small perturbations in the input bytes.

    This is effectively an identity mapping (with a large enough sample size n there will be collisions, but this is highly unlikely).
    """

    def __init__(self):
        super(IdentityHashEmbedding, self).__init__(max_context_length=128)

    def extract_embeddings(self, texts: List[str]) -> List[List[float]]:
        return [self._embed(text) for text in texts]

    def extract_query_embeddings(self, query: str) -> List[float]:
        return self._embed(query)

    def schemas(self) -> ExtractorSchema:
        input_params = EmbeddingInputParams()
        return ExtractorSchema(
            input_params=input_params.model_dump_json(),
            embedding_schemas={
                "embedding": EmbeddingSchema(distance_metric="cosine", dim=32)
            },
        )

    def _embed(self, text) -> List[float]:
        # A SHA-256 digest is 32 bytes, which yields a 32-dimensional int8 vector.
        model = hashlib.sha256()
        model.update(bytes(text, 'utf-8'))
        out = model.digest()
        return np.frombuffer(out, dtype=np.int8).tolist()
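
A minimal usage sketch of the extractor above (assuming BaseEmbeddingExtractor needs no constructor arguments beyond max_context_length, as the class suggests): identical texts hash to identical 32-dimensional vectors, so exact duplicates can be found by comparing embeddings.

from identity_hash_embedding import IdentityHashEmbedding

extractor = IdentityHashEmbedding()
a, b, c = extractor.extract_embeddings(["hello", "hello", "world"])

assert len(a) == 32  # SHA-256 digest: 32 bytes -> 32 int8 values
assert a == b        # identical texts collide by construction
assert a != c        # different texts almost surely differ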
8 changes: 8 additions & 0 deletions extractors/identity_hash_embedding.yaml
@@ -0,0 +1,8 @@
name: yenicelik/identity-hash-extractor
version: 1
description: "Embedding that extracts the Sha256 hash of whatever bytestring was inputted. This is useful to quickly check for duplicates."
module: identity_hash_embedding:IdentityHashEmbedding
gpu: false
python_dependencies:
- numpy
system_dependencies:
56 changes: 56 additions & 0 deletions extractors/language_extractor.py
@@ -0,0 +1,56 @@
import json
from pydantic import BaseModel
from typing import List, Literal

from indexify_extractor_sdk import (
    Extractor,
    Feature,
    ExtractorSchema,
    Content,
)
from lingua import LanguageDetectorBuilder


class LanguageExtractionInputParams(BaseModel):
    overlap: int = 0
    text_splitter: Literal["char", "token", "recursive", "new_line"] = "new_line"

class LanguageExtractor(Extractor):
    """
    Extractor class for detecting the language of the given content.
    """

    def __init__(self):
        super().__init__()
        self._model = LanguageDetectorBuilder.from_all_languages().build()

    def extract(
        self, content: List[Content], params: LanguageExtractionInputParams
    ) -> List[List[Content]]:
        content_texts = [c.data.decode("utf-8") for c in content]
        out = []
        for i, x in enumerate(content):
            language = self._model.detect_language_of(content_texts[i])
            confidence = self._model.compute_language_confidence(content_texts[i], language)
            # TODO: Could be modified depending on the database we have
            data = {"language": language.name, "score": str(confidence)}
            out.append(
                [Content.from_text(
                    text=content_texts[i],
                    feature=Feature.metadata(value=data, name="language"),
                )]
            )
        return out

    def schemas(self) -> ExtractorSchema:
        """
        Returns a list of options for indexing.
        """
        input_params = LanguageExtractionInputParams()
        # TODO: If it's metadata, how do we extract things?
        # This extractor does not return any embedding, only a dictionary!
        return ExtractorSchema(
            embedding_schemas={},
            input_params=json.dumps(input_params.model_json_schema()),
        )
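
A quick usage sketch of the language extractor (assuming Content.from_text accepts a bare string and lingua-language-detector is installed; building the detector from all languages loads every language model, so construction is slow):

from indexify_extractor_sdk import Content
from language_extractor import LanguageExtractor, LanguageExtractionInputParams

extractor = LanguageExtractor()
results = extractor.extract(
    [Content.from_text(text="Der schnelle braune Fuchs.")],
    LanguageExtractionInputParams(),
)
# Each input yields one Content item carrying a "language" metadata
# feature, e.g. {"language": "GERMAN", "score": "0.98..."}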

8 changes: 8 additions & 0 deletions extractors/language_extractor.yaml
@@ -0,0 +1,8 @@
name: diptanu/language-extractor
version: 1
description: "Classifies text into a human language"
module: language_extractor:LanguageExtractor
gpu: false
python_dependencies:
- lingua-language-detector
system_dependencies:
File renamed without changes.
4 changes: 2 additions & 2 deletions extractors/pdf_embedder.py
@@ -32,7 +32,7 @@ def extract(
        for (i, (chunk, embeddings)) in enumerate(zip(chunks, embeddings)):
            embeddings_list.append(Embeddings(content_id=c.id, text=chunk, embeddings=embeddings, metadata=json.dumps({"page": i})))
        return embeddings_list

    def extract_query_embeddings(self, query: str) -> List[float]:
        return self._model.embed_query(query)

@@ -44,4 +44,4 @@ def info(self) -> ExtractorInfo:
            input_params=input_params,
            output_schema=EmbeddingSchema(distance="cosine", dim=384),
        )

134 changes: 0 additions & 134 deletions extractors/setup.py

This file was deleted.

95 changes: 95 additions & 0 deletions extractors/simple_invoice_parser.py
@@ -0,0 +1,95 @@
from io import BytesIO
import json
import re
import requests
import timeit
from typing import List, Literal

from PIL import Image
from pdf2image import convert_from_bytes
from pydantic import BaseModel
import pytesseract
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel

from indexify_extractor_sdk import (
    Extractor,
    Feature,
    ExtractorSchema,
    Content,
)


class SimpleInvoiceParserInputParams(BaseModel):
    # No input except the file itself
    ...

class SimpleInvoiceParserExtractor(Extractor):
    def __init__(self):
        super().__init__()
        self.processor = DonutProcessor.from_pretrained("to-be/donut-base-finetuned-invoices")
        self.model = VisionEncoderDecoderModel.from_pretrained("to-be/donut-base-finetuned-invoices")
        # TODO: Is this for example how we would pick it up? Probably the model would still need to be defined by the user, i.e. how it should be used
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def _process_document(self, image):
        # prepare encoder inputs
        pixel_values = self.processor(image, return_tensors="pt").pixel_values

        # prepare decoder inputs
        task_prompt = "<s_cord-v2>"
        decoder_input_ids = self.processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids

        # generate answer
        outputs = self.model.generate(
            pixel_values.to(self.device),
            decoder_input_ids=decoder_input_ids.to(self.device),
            max_length=self.model.decoder.config.max_position_embeddings,
            early_stopping=True,
            pad_token_id=self.processor.tokenizer.pad_token_id,
            eos_token_id=self.processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

        # postprocess
        sequence = self.processor.batch_decode(outputs.sequences)[0]
        sequence = sequence.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
        return self.processor.token2json(sequence), image

    def extract(
        self, content: List[Content], params: SimpleInvoiceParserInputParams
    ) -> List[List[Content]]:
        content_filebytes = [c.data for c in content]

        # TODO: Right now it only looks at the first page! We should probably flatten it and run it for each page!
        images = [convert_from_bytes(x)[0].convert("RGB") for x in content_filebytes]

        out = []
        for i, x in enumerate(content):
            data = self._process_document(images[i])[0]  # Index 1 holds the image, which we ignore in this case
            out.append(
                [Content.from_text(
                    text="",  # TODO: Diptanu, what do we do for PDFs? Do you want to save the raw bytes too? I feel like this is unnecessary. Also, I felt like these would be stored in a database _before_ processing, not after
                    feature=Feature.metadata(value=data, name="invoice_simple_donut"),
                )]
            )
        return out

    def schemas(self) -> ExtractorSchema:
        """
        Returns a list of options for indexing.
        """
        input_params = SimpleInvoiceParserInputParams()
        # TODO: If it's metadata, how do we extract things?
        # This extractor does not return any embedding, only a dictionary!
        return ExtractorSchema(
            embedding_schemas={},
            input_params=json.dumps(input_params.model_json_schema()),
        )
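
A rough usage sketch of the invoice parser (assuming a local invoice.pdf and the system dependencies listed in the YAML below, notably poppler-utils for pdf2image; the Donut weights are downloaded from Hugging Face on first use):

from pdf2image import convert_from_bytes
from simple_invoice_parser import SimpleInvoiceParserExtractor

extractor = SimpleInvoiceParserExtractor()

# Render only the first page, mirroring what extract() does internally.
with open("invoice.pdf", "rb") as f:
    image = convert_from_bytes(f.read())[0].convert("RGB")

data, _ = extractor._process_document(image)
print(data)  # dict of invoice fields decoded via processor.token2json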

19 changes: 19 additions & 0 deletions extractors/simple_invoice_parser.yaml
@@ -0,0 +1,19 @@
name: yenicelik/simple-invoice-parser
version: 1
description: "Parses an invoice using the to-be/donut-base-finetuned-invoices huggingface model"
module: simple_invoice_parser:SimpleInvoiceParserExtractor
gpu: true
python_dependencies:
- transformers
- torch
- torchvision
- Pillow
- pytesseract
- timm
- sentencepiece
- donut-python
- pdf2image
system_dependencies:
- tesseract-ocr
- libtesseract-dev
- poppler-utils