diff --git a/.gitignore b/.gitignore index dc7ddbf5ed..cccdb95ebf 100644 --- a/.gitignore +++ b/.gitignore @@ -228,3 +228,4 @@ eva/* blog.md tests/integration_tests/short/*.db +test/third_party_tests/*.db diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000..8874ba4db3 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,11 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v3.4.0 + hooks: + - id: check-docstring-first + + - repo: https://github.com/ambv/black + rev: 24.2.0 # Use the latest stable version of Black + hooks: + - id: black + language_version: python3 \ No newline at end of file diff --git a/benchmark/text_summarization/text_summarization_with_evadb.py b/benchmark/text_summarization/text_summarization_with_evadb.py index 26214b8ad9..efa8f45790 100644 --- a/benchmark/text_summarization/text_summarization_with_evadb.py +++ b/benchmark/text_summarization/text_summarization_with_evadb.py @@ -5,35 +5,38 @@ cursor.query("DROP TABLE IF EXISTS cnn_news_test;").df() -cursor.query(""" +cursor.query( + """ CREATE TABLE IF NOT EXISTS cnn_news_test( id TEXT(128), article TEXT(4096), highlights TEXT(1024) - );""").df() -cursor.load('./cnn_news_test.csv', 'cnn_news_test', format="CSV").df() + );""" +).df() +cursor.load("./cnn_news_test.csv", "cnn_news_test", format="CSV").df() cursor.query("DROP FUNCTION IF EXISTS TextSummarizer;").df() -cursor.query("""CREATE UDF IF NOT EXISTS TextSummarizer +cursor.query( + """CREATE UDF IF NOT EXISTS TextSummarizer TYPE HuggingFace TASK 'summarization' MODEL 'sshleifer/distilbart-cnn-12-6' MIN_LENGTH 5 - MAX_LENGTH 100;""").df() + MAX_LENGTH 100;""" +).df() cursor.query("DROP TABLE IF EXISTS cnn_news_summary;").df() cursor._evadb.config.update_value("executor", "batch_mem_size", 300000) -cursor._evadb.config.update_value("executor", "gpu_ids", [0,1]) +cursor._evadb.config.update_value("executor", "gpu_ids", [0, 1]) cursor._evadb.config.update_value("experimental", "ray", True) start_time = time.perf_counter() -cursor.query(""" +cursor.query( + """ CREATE TABLE IF NOT EXISTS cnn_news_summary AS - SELECT TextSummarizer(article) FROM cnn_news_test;""").df() + SELECT TextSummarizer(article) FROM cnn_news_test;""" +).df() end_time = time.perf_counter() print(f"{end_time-start_time:.2f} seconds") - - - diff --git a/docs/conf.py b/docs/conf.py index 95e1eb908a..2daa2b6cec 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -31,7 +31,7 @@ "sphinx_copybutton", "sphinx.ext.doctest", "sphinx.ext.coverage", -# "sphinx.ext.autosectionlabel", + # "sphinx.ext.autosectionlabel", "sphinx.ext.autosummary", "sphinx.ext.autodoc", "sphinx.ext.autodoc.typehints", @@ -90,11 +90,17 @@ # General information about the project. project = "EvaDB" copyright = str(date.today().year) + ", EvaDB." -author = u"EvaDB" +author = "EvaDB" # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "README.md", "images/reference/README.md"] +exclude_patterns = [ + "_build", + "Thumbs.db", + ".DS_Store", + "README.md", + "images/reference/README.md", +] # The name of the Pygments (syntax highlighting) style to use. @@ -129,12 +135,12 @@ "color-background-secondary": "#fff", "color-sidebar-background-border": "none", "font-stack": "Inter, Arial, sans-serif", - "font-stack--monospace": "Fira Code, Courier, monospace" + "font-stack--monospace": "Fira Code, Courier, monospace", }, "dark_css_variables": { "color-background-secondary": "#000", "font-stack": "Inter, Arial, sans-serif", - "font-stack--monospace": "Fira Code, Courier, monospace" + "font-stack--monospace": "Fira Code, Courier, monospace", }, # Add important announcement here "announcement": "
", @@ -151,18 +157,19 @@ # Adding custom css file html_static_path = ["_static"] html_css_files = [ - "custom.css", + "custom.css", "algolia.css", "https://cdn.jsdelivr.net/npm/@docsearch/css@3", "https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/fontawesome.min.css", "https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/solid.min.css", - "https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/brands.min.css" + "https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/brands.min.css", ] # Check link: https://stackoverflow.com/questions/14492743/have-sphinx-report-broken-links/14735060#14735060 nitpicky = True # BUG: https://stackoverflow.com/questions/11417221/sphinx-autodoc-gives-warning-pyclass-reference-target-not-found-type-warning -nitpick_ignore_regex = [('py:class', r'.*')] +nitpick_ignore_regex = [("py:class", r".*")] + # -- Initialize Sphinx ---------------------------------------------- def setup(app): @@ -173,5 +180,8 @@ def setup(app): ) # Custom JS app.add_js_file("js/top-navigation.js", defer="defer") - app.add_js_file("https://cdn.jsdelivr.net/npm/@docsearch/js@3.3.3/dist/umd/index.js",defer="defer") - app.add_js_file("js/algolia.js",defer="defer") + app.add_js_file( + "https://cdn.jsdelivr.net/npm/@docsearch/js@3.3.3/dist/umd/index.js", + defer="defer", + ) + app.add_js_file("js/algolia.js", defer="defer") diff --git a/docs/extensions/mps/__init__.py b/docs/extensions/mps/__init__.py index 425c6c72cd..d647410019 100644 --- a/docs/extensions/mps/__init__.py +++ b/docs/extensions/mps/__init__.py @@ -1,7 +1,7 @@ -''' +""" Sphinx extensions for the MPS documentation. See -''' +""" import re from collections import defaultdict @@ -17,7 +17,7 @@ from . import designs -versionlabels['deprecatedstarting'] = "Deprecated starting with version %s" +versionlabels["deprecatedstarting"] = "Deprecated starting with version %s" admonitionlabels.update( aka="Also known as", bibref="Related publication", @@ -36,75 +36,88 @@ similars="Similar terms", specific="In the MPS", topics="Topic", - topicss="Topics"), + topicss="Topics", +), class MpsDomain(Domain): - label = 'MPS' - name = 'mps' + label = "MPS" + name = "mps" class MpsDirective(Directive): @classmethod def add_to_app(cls, app): - if hasattr(cls, 'name'): + if hasattr(cls, "name"): name = cls.name - elif hasattr(cls, 'node_class') and cls.node_class is not None: + elif hasattr(cls, "node_class") and cls.node_class is not None: name = cls.node_class.__name__ else: return - if hasattr(cls, 'node_class') and hasattr(cls, 'visit'): - app.add_node(cls.node_class, html=cls.visit, latex=cls.visit, - text=cls.visit, man=cls.visit) - if hasattr(cls, 'domain'): + if hasattr(cls, "node_class") and hasattr(cls, "visit"): + app.add_node( + cls.node_class, + html=cls.visit, + latex=cls.visit, + text=cls.visit, + man=cls.visit, + ) + if hasattr(cls, "domain"): app.add_directive_to_domain(cls.domain, name, cls) else: app.add_directive(name, cls) class MpsPrefixDirective(MpsDirective): - domain = 'mps' - name = 'prefix' + domain = "mps" + name = "prefix" has_content = True def run(self): targetid = self.content[0] self.state.document.mps_tag_prefix = targetid - targetnode = nodes.target('', '', ids=[targetid]) + targetnode = nodes.target("", "", ids=[targetid]) return [targetnode] def mps_tag_role(name, rawtext, text, lineno, inliner, options={}, content=[]): try: - targetid = '.'.join([inliner.document.mps_tag_prefix, text]) + targetid = ".".join([inliner.document.mps_tag_prefix, text]) except AttributeError: - return [], [inliner.document.reporter.warning - ('mps:tag without mps:prefix', line=lineno)] + return [], [ + inliner.document.reporter.warning("mps:tag without mps:prefix", line=lineno) + ] if len(text) == 0: - return [], [inliner.document.reporter.error - ('missing argument for mps:tag', line=lineno)] - targetnode = nodes.target('', '', ids=[targetid]) - tag = '.{}:'.format(text) - refnode = nodes.reference('', '', refid=targetid, classes=['mpstag'], - *[nodes.Text(tag)]) + return [], [ + inliner.document.reporter.error("missing argument for mps:tag", line=lineno) + ] + targetnode = nodes.target("", "", ids=[targetid]) + tag = ".{}:".format(text) + refnode = nodes.reference( + "", "", refid=targetid, classes=["mpstag"], *[nodes.Text(tag)] + ) return [targetnode, refnode], [] def mps_ref_role(name, rawtext, text, lineno, inliner, options={}, content=[]): textnode = nodes.Text(text) - if text.startswith('.'): + if text.startswith("."): # Tag is relative to current prefix and so points elsewhere # in this document, so create reference node. try: targetid = inliner.document.mps_tag_prefix + text - refnode = nodes.reference('', '', refid=targetid, *[textnode]) + refnode = nodes.reference("", "", refid=targetid, *[textnode]) except AttributeError: - return [textnode], [inliner.document.reporter.warning - (':mps:ref without mps:prefix', line=lineno)] + return [textnode], [ + inliner.document.reporter.warning( + ":mps:ref without mps:prefix", line=lineno + ) + ] else: # Tag is absolute: need to create pending_xref node. - refnode = addnodes.pending_xref('', refdomain='std', reftarget=text, - reftype='view') + refnode = addnodes.pending_xref( + "", refdomain="std", reftarget=text, reftype="view" + ) refnode += textnode return [refnode], [] @@ -114,7 +127,7 @@ class Admonition(nodes.Admonition, nodes.Element): def visit_admonition_node(self, node): - name = type(node).__name__ + ('s' if node.plural else '') + name = type(node).__name__ + ("s" if node.plural else "") self.visit_admonition(node, name=name) @@ -130,9 +143,11 @@ class AdmonitionDirective(MpsDirective, BaseAdmonition): class PluralDirective(AdmonitionDirective): def run(self): ad = super(PluralDirective, self).run() - refs = sum(1 for node in ad[0][0] - if isinstance(node, (addnodes.pending_xref, - nodes.Referential))) + refs = sum( + 1 + for node in ad[0][0] + if isinstance(node, (addnodes.pending_xref, nodes.Referential)) + ) if refs > 1: ad[0].plural = True return ad @@ -187,10 +202,12 @@ class NoteDirective(AdmonitionDirective): def run(self): ad = super(NoteDirective, self).run() - if (isinstance(ad[0][0], nodes.enumerated_list) - and sum(1 for _ in ad[0][0].traverse(nodes.list_item)) > 1 - or isinstance(ad[0][0], nodes.footnote) - and sum(1 for _ in ad[0].traverse(nodes.footnote)) > 1): + if ( + isinstance(ad[0][0], nodes.enumerated_list) + and sum(1 for _ in ad[0][0].traverse(nodes.list_item)) > 1 + or isinstance(ad[0][0], nodes.footnote) + and sum(1 for _ in ad[0].traverse(nodes.footnote)) > 1 + ): ad[0].plural = True return ad @@ -232,7 +249,7 @@ class specific(Admonition): class SpecificDirective(AdmonitionDirective): - domain = 'mps' + domain = "mps" node_class = specific @@ -264,7 +281,7 @@ class GlossaryTransform(transforms.Transform): see_only_ids = set() xref_ids = defaultdict(list) default_priority = 999 - sense_re = re.compile(r'(.*)\s+(\([0-9]+\))$', re.S) + sense_re = re.compile(r"(.*)\s+(\([0-9]+\))$", re.S) def superscript_children(self, target): """ @@ -292,8 +309,8 @@ def apply(self): # Change parenthesized sense numbers to superscripts in # cross-references to glossary entries. for target in self.document.traverse(addnodes.pending_xref): - if target['reftype'] == 'term': - ids = self.xref_ids['term-{}'.format(target['reftarget'])] + if target["reftype"] == "term": + ids = self.xref_ids["term-{}".format(target["reftarget"])] ids.append((target.source, target.line)) if len(target) == 1 and isinstance(target[0], nodes.emphasis): target[0][:] = list(self.superscript_children(target[0])) @@ -303,40 +320,46 @@ def apply(self): ids = set() for c in target: if isinstance(c, nodes.term): - ids = set(c['ids']) - if (isinstance(c, nodes.definition) - and len(c) == 1 - and isinstance(c[0], see)): + ids = set(c["ids"]) + if ( + isinstance(c, nodes.definition) + and len(c) == 1 + and isinstance(c[0], see) + ): self.see_only_ids |= ids # Add cross-reference targets for plurals. - objects = self.document.settings.env.domaindata['std']['objects'] - endings = [(l, l + 's') for l in 'abcedfghijklmnopqrtuvwxz'] - endings.extend([ - ('ss', 'sses'), - ('ing', 'ed'), - ('y', 'ies'), - ('e', 'ed'), - ('', 'ed'), - ]) + objects = self.document.settings.env.domaindata["std"]["objects"] + endings = [(l, l + "s") for l in "abcedfghijklmnopqrtuvwxz"] + endings.extend( + [ + ("ss", "sses"), + ("ing", "ed"), + ("y", "ies"), + ("e", "ed"), + ("", "ed"), + ] + ) for (name, fullname), value in list(objects.items()): - if name != 'term': + if name != "term": continue m = self.sense_re.match(fullname) if m: old_fullname = m.group(1) - sense = ' ' + m.group(2) + sense = " " + m.group(2) else: old_fullname = fullname - sense = '' + sense = "" if any(old_fullname.endswith(e) for _, e in endings): continue for old_ending, new_ending in endings: if not old_fullname.endswith(old_ending): continue - new_fullname = '{}{}{}'.format( - old_fullname[:len(old_fullname) - len(old_ending)], - new_ending, sense) + new_fullname = "{}{}{}".format( + old_fullname[: len(old_fullname) - len(old_ending)], + new_ending, + sense, + ) new_key = name, new_fullname if new_key not in objects: objects[new_key] = value @@ -352,17 +375,16 @@ def warn_indirect_terms(cls, app, exception): if not exception: for i in cls.see_only_ids: for doc, line in cls.xref_ids[i]: - print('{}:{}: WARNING: cross-reference to {}.' - .format(doc, line, i)) + print("{}:{}: WARNING: cross-reference to {}.".format(doc, line, i)) def setup(app): designs.convert_updated(app) app.add_domain(MpsDomain) - app.add_role_to_domain('mps', 'tag', mps_tag_role) - app.add_role_to_domain('mps', 'ref', mps_ref_role) + app.add_role_to_domain("mps", "tag", mps_tag_role) + app.add_role_to_domain("mps", "ref", mps_ref_role) app.add_transform(GlossaryTransform) - app.connect('build-finished', GlossaryTransform.warn_indirect_terms) + app.connect("build-finished", GlossaryTransform.warn_indirect_terms) for g in globals().values(): if isclass(g) and issubclass(g, MpsDirective): g.add_to_app(app) diff --git a/docs/extensions/mps/designs.py b/docs/extensions/mps/designs.py index 9f7a3b4e29..9d49a398b7 100644 --- a/docs/extensions/mps/designs.py +++ b/docs/extensions/mps/designs.py @@ -15,7 +15,7 @@ import re import shutil -TYPES = ''' +TYPES = """ AccessSet Accumulation Addr Align AllocFrame AllocPattern AP Arg Arena Attr Bool BootBlock BT Buffer BufferMode Byte Chain Chunk @@ -27,40 +27,42 @@ Size Space SplayNode SplayTree StackContext Thread Trace TraceId TraceSet TraceStartWhy TraceState ULongest VM Word ZoneSet -''' +""" -mode = re.compile(r'\.\. mode: .*\n') -prefix = re.compile(r'^:Tag: ([a-z][a-z.0-9-]*[a-z0-9])$', re.MULTILINE) +mode = re.compile(r"\.\. mode: .*\n") +prefix = re.compile(r"^:Tag: ([a-z][a-z.0-9-]*[a-z0-9])$", re.MULTILINE) rst_tag = re.compile( - r'^:(?:Author|Date|Status|Revision|Copyright|Organization|Format|Index ' - r'terms|Readership):.*?$\n', - re.MULTILINE | re.IGNORECASE) -mps_tag = re.compile(r'_`\.([a-z][A-Za-z.0-9_-]*[A-Za-z0-9])`:') -mps_ref = re.compile(r'`(\.[a-z][A-Za-z.0-9_-]*[A-Za-z0-9])`_(?: )?') -funcdef = re.compile(r'^``([^`]*\([^`]*\))``$', re.MULTILINE) -macrodef = re.compile(r'^``([A-Z][A-Z0-9_]+)``$', re.MULTILINE) -macro = re.compile(r'``([A-Z][A-Z0-9_]+)``(?: )?') -typedef = re.compile(r'^``typedef ([^`]*)``$', re.MULTILINE) -func = re.compile(r'``([A-Za-z][A-Za-z0-9_]+\(\))``') + r"^:(?:Author|Date|Status|Revision|Copyright|Organization|Format|Index " + r"terms|Readership):.*?$\n", + re.MULTILINE | re.IGNORECASE, +) +mps_tag = re.compile(r"_`\.([a-z][A-Za-z.0-9_-]*[A-Za-z0-9])`:") +mps_ref = re.compile(r"`(\.[a-z][A-Za-z.0-9_-]*[A-Za-z0-9])`_(?: )?") +funcdef = re.compile(r"^``([^`]*\([^`]*\))``$", re.MULTILINE) +macrodef = re.compile(r"^``([A-Z][A-Z0-9_]+)``$", re.MULTILINE) +macro = re.compile(r"``([A-Z][A-Z0-9_]+)``(?: )?") +typedef = re.compile(r"^``typedef ([^`]*)``$", re.MULTILINE) +func = re.compile(r"``([A-Za-z][A-Za-z0-9_]+\(\))``") typename = re.compile( - r'``({0}|[A-Z][A-Za-z0-9_]*(?:Class|Struct|Method)|mps_[a-z_]+_[stu])``(' - r'?: )?' - .format('|'.join(map(re.escape, TYPES.split())))) -design_ref = re.compile(r'^( *\.\. _design\.mps\.(?:[^:\n]+): (?:[^#:\n]+))$', - re.MULTILINE) + r"``({0}|[A-Z][A-Za-z0-9_]*(?:Class|Struct|Method)|mps_[a-z_]+_[stu])``(" + r"?: )?".format("|".join(map(re.escape, TYPES.split()))) +) +design_ref = re.compile( + r"^( *\.\. _design\.mps\.(?:[^:\n]+): (?:[^#:\n]+))$", re.MULTILINE +) design_frag_ref = re.compile( - r'^( *\.\. _design\.mps\.([^:\n]+)\.([^:\n]+): (?:[^#:\n]+))#(.+)$', - re.MULTILINE) -history = re.compile(r'^Document History\n.*', - re.MULTILINE | re.IGNORECASE | re.DOTALL) + r"^( *\.\. _design\.mps\.([^:\n]+)\.([^:\n]+): (?:[^#:\n]+))#(.+)$", re.MULTILINE +) +history = re.compile(r"^Document History\n.*", re.MULTILINE | re.IGNORECASE | re.DOTALL) # Strip section numbering -secnum = re.compile(r'^(?:[0-9]+|[A-Z])\.\s+(.*)$\n(([=`:.\'"~^_*+#-])\3+)$', - re.MULTILINE) +secnum = re.compile( + r'^(?:[0-9]+|[A-Z])\.\s+(.*)$\n(([=`:.\'"~^_*+#-])\3+)$', re.MULTILINE +) def secnum_sub(m): - return m.group(1) + '\n' + m.group(3) * len(m.group(1)) + return m.group(1) + "\n" + m.group(3) * len(m.group(1)) # Convert Ravenbrook style citations into MPS Manual style citations. @@ -71,7 +73,7 @@ def secnum_sub(m): # .. [THVV_1995] Tom Van Vleck. 1995. "`Structure Marking # `__". citation = re.compile( - r''' + r""" ^\.\.\s+(?P\[[^\n\]]+\])\s* "(?P[^"]+?)"\s* ;\s*(?P<author>[^;]+?)\s* @@ -79,70 +81,71 @@ def secnum_sub(m): ;\s*(?P<date>[0-9-]+)\s* (?:;\s*<\s*(?P<url>[^>]*?)\s*>\s*)? \. - ''', - re.VERBOSE | re.MULTILINE | re.IGNORECASE | re.DOTALL + """, + re.VERBOSE | re.MULTILINE | re.IGNORECASE | re.DOTALL, ) def citation_sub(m): - groups = {k: re.sub(r'\s+', ' ', v) for k, v in m.groupdict().items() if v} - fmt = '.. {ref} {author}.' - if 'organization' in groups: - fmt += ' {organization}.' - fmt += ' {date}.' - if 'url' in groups: + groups = {k: re.sub(r"\s+", " ", v) for k, v in m.groupdict().items() if v} + fmt = ".. {ref} {author}." + if "organization" in groups: + fmt += " {organization}." + fmt += " {date}." + if "url" in groups: fmt += ' "`{title} <{url}>`__".' else: fmt += ' "{title}".' return fmt.format(**groups) -index = re.compile(r'^:Index\s+terms:(.*$\n(?:[ \t]+.*$\n)*)', - re.MULTILINE | re.IGNORECASE) +index = re.compile( + r"^:Index\s+terms:(.*$\n(?:[ \t]+.*$\n)*)", re.MULTILINE | re.IGNORECASE +) # <http://sphinx-doc.org/markup/misc.html#directive-index> -index_term = re.compile(r'^\s*(\w+):\s*(.*?)\s*$', re.MULTILINE) +index_term = re.compile(r"^\s*(\w+):\s*(.*?)\s*$", re.MULTILINE) def index_sub(m): - s = '\n.. index::\n' + s = "\n.. index::\n" for term in index_term.finditer(m.group(1)): - s += ' %s: %s\n' % (term.group(1), term.group(2)) - s += '\n' + s += " %s: %s\n" % (term.group(1), term.group(2)) + s += "\n" return s def convert_file(name, source, dest): - s = open(source, 'rb').read().decode('utf-8') + s = open(source, "rb").read().decode("utf-8") # We want the index directive to go right at the start, so that it leads # to the whole document. m = index.search(s) if m: - s = index_sub(m) + '.. _design-{0}:\n\n'.format(name) + s - s = mode.sub(r'', s) - s = prefix.sub(r'.. mps:prefix:: \1', s) - s = rst_tag.sub(r'', s) - s = mps_tag.sub(r':mps:tag:`\1`', s) - s = mps_ref.sub(r':mps:ref:`\1`', s) - s = typedef.sub(r'.. c:type:: \1', s) - s = funcdef.sub(r'.. c:function:: \1', s) - s = macrodef.sub(r'.. c:macro:: \1', s) - s = typename.sub(r':c:type:`\1`', s) - s = func.sub(r':c:func:`\1`', s) - s = macro.sub(r':c:macro:`\1`', s) + s = index_sub(m) + ".. _design-{0}:\n\n".format(name) + s + s = mode.sub(r"", s) + s = prefix.sub(r".. mps:prefix:: \1", s) + s = rst_tag.sub(r"", s) + s = mps_tag.sub(r":mps:tag:`\1`", s) + s = mps_ref.sub(r":mps:ref:`\1`", s) + s = typedef.sub(r".. c:type:: \1", s) + s = funcdef.sub(r".. c:function:: \1", s) + s = macrodef.sub(r".. c:macro:: \1", s) + s = typename.sub(r":c:type:`\1`", s) + s = func.sub(r":c:func:`\1`", s) + s = macro.sub(r":c:macro:`\1`", s) s = secnum.sub(secnum_sub, s) s = citation.sub(citation_sub, s) - s = design_ref.sub(r'\1.html', s) - s = design_frag_ref.sub(r'\1.html#design.mps.\2.\3', s) - s = history.sub('', s) + s = design_ref.sub(r"\1.html", s) + s = design_frag_ref.sub(r"\1.html#design.mps.\2.\3", s) + s = history.sub("", s) # Don't try to format all the quoted code blocks as C. - s = '.. highlight:: none\n\n' + s + s = ".. highlight:: none\n\n" + s try: os.makedirs(os.path.dirname(dest)) except: pass - with open(dest, 'wb') as out: - out.write(s.encode('utf-8')) + with open(dest, "wb") as out: + out.write(s.encode("utf-8")) def newer(src, target): @@ -150,23 +153,25 @@ def newer(src, target): target, False otherwise. """ - return (not os.path.isfile(target) - or os.path.getmtime(target) < os.path.getmtime(src) - or os.path.getmtime(target) < os.path.getmtime(__file__)) + return ( + not os.path.isfile(target) + or os.path.getmtime(target) < os.path.getmtime(src) + or os.path.getmtime(target) < os.path.getmtime(__file__) + ) # Mini-make def convert_updated(app): # app.info(bold('converting MPS design documents')) - for design in glob.iglob('../design/*.txt'): + for design in glob.iglob("../design/*.txt"): name = os.path.splitext(os.path.basename(design))[0] - if name == 'index': + if name == "index": continue - converted = 'source/design/%s.rst' % name + converted = "source/design/%s.rst" % name if newer(design, converted): - app.info('converting design %s' % name) + app.info("converting design %s" % name) convert_file(name, design, converted) - for diagram in glob.iglob('../design/*.svg'): - target = os.path.join('source/design/', os.path.basename(diagram)) + for diagram in glob.iglob("../design/*.svg"): + target = os.path.join("source/design/", os.path.basename(diagram)) if newer(diagram, target): shutil.copyfile(diagram, target) diff --git a/evadb/evadb_cmd_client.py b/evadb/evadb_cmd_client.py index 2f15f075d5..f0263858f5 100644 --- a/evadb/evadb_cmd_client.py +++ b/evadb/evadb_cmd_client.py @@ -19,9 +19,8 @@ from evadb.utils.logging_manager import logger -""" -To allow running evadb_server from any location -""" +# == To allow running evadb_server from any location == + THIS_DIR = dirname(__file__) EvaDB_CODE_DIR = abspath(join(THIS_DIR, "..")) sys.path.append(EvaDB_CODE_DIR) diff --git a/evadb/evadb_server.py b/evadb/evadb_server.py index 0b8a15c175..79fd6b0a13 100644 --- a/evadb/evadb_server.py +++ b/evadb/evadb_server.py @@ -20,9 +20,8 @@ from psutil import process_iter -""" -To allow running evadb_server from any location -""" +# == To allow running evadb_server from any location == + THIS_DIR = dirname(__file__) EvaDB_CODE_DIR = abspath(join(THIS_DIR, "..")) sys.path.append(EvaDB_CODE_DIR) diff --git a/evadb/functions/chatgpt.py b/evadb/functions/chatgpt.py index bf0d338689..3476ebde3c 100644 --- a/evadb/functions/chatgpt.py +++ b/evadb/functions/chatgpt.py @@ -153,9 +153,11 @@ def completion_with_backoff(**kwargs): def_sys_prompt_message = { "role": "system", - "content": prompt - if prompt is not None - else "You are a helpful assistant that accomplishes user tasks.", + "content": ( + prompt + if prompt is not None + else "You are a helpful assistant that accomplishes user tasks." + ), } params["messages"].append(def_sys_prompt_message) diff --git a/evadb/functions/forecast.py b/evadb/functions/forecast.py index 6041e6b499..73ddff09e0 100644 --- a/evadb/functions/forecast.py +++ b/evadb/functions/forecast.py @@ -107,9 +107,11 @@ def forward(self, data) -> pd.DataFrame: columns={ "unique_id": self.id_column_rename, "ds": self.time_column_rename, - self.model_name - if self.library == "statsforecast" - else self.model_name + "-median": self.predict_column_rename, + ( + self.model_name + if self.library == "statsforecast" + else self.model_name + "-median" + ): self.predict_column_rename, self.model_name + "-lo-" + str(self.conf): self.predict_column_rename diff --git a/evadb/optimizer/rules/rules.py b/evadb/optimizer/rules/rules.py index cb9ff32742..6c0f6125a6 100644 --- a/evadb/optimizer/rules/rules.py +++ b/evadb/optimizer/rules/rules.py @@ -1306,9 +1306,7 @@ def apply(self, before: LogicalVectorIndexScan, context: OptimizerContext): yield after -""" -Rules to optimize Ray. -""" +# == Rules to optimize Ray. == def get_ray_env_dict(): diff --git a/evadb/parser/drop_object_statement.py b/evadb/parser/drop_object_statement.py index f88e2ead92..c7751234df 100644 --- a/evadb/parser/drop_object_statement.py +++ b/evadb/parser/drop_object_statement.py @@ -17,7 +17,6 @@ class DropObjectStatement(AbstractStatement): - """Drop Object Statement constructed after parsing the input query Attributes: diff --git a/evadb/third_party/databases/mariadb/mariadb_handler.py b/evadb/third_party/databases/mariadb/mariadb_handler.py index 9c3b1f1b77..be00eeca9f 100644 --- a/evadb/third_party/databases/mariadb/mariadb_handler.py +++ b/evadb/third_party/databases/mariadb/mariadb_handler.py @@ -23,7 +23,6 @@ class MariaDbHandler(DBHandler): - """ Class for implementing the Maria DB handler as a backend store for EvaDB. diff --git a/evadb/third_party/databases/snowflake/snowflake_handler.py b/evadb/third_party/databases/snowflake/snowflake_handler.py index 0b5ba4553d..8df50d6a3c 100644 --- a/evadb/third_party/databases/snowflake/snowflake_handler.py +++ b/evadb/third_party/databases/snowflake/snowflake_handler.py @@ -25,7 +25,6 @@ class SnowFlakeDbHandler(DBHandler): - """ Class for implementing the SnowFlake DB handler as a backend store for EvaDB. @@ -172,7 +171,7 @@ def _snowflake_to_python_types(self, snowflake_type: str): "BINARY": bytes, "DATE": datetime.date, "TIME": datetime.time, - "TIMESTAMP": datetime.datetime + "TIMESTAMP": datetime.datetime, # Add more mappings as needed } diff --git a/evadb/third_party/databases/sqlite/sqlite_handler.py b/evadb/third_party/databases/sqlite/sqlite_handler.py index ae75cba36f..0c332a5bb9 100644 --- a/evadb/third_party/databases/sqlite/sqlite_handler.py +++ b/evadb/third_party/databases/sqlite/sqlite_handler.py @@ -125,9 +125,11 @@ def _fetch_results_as_df(self, cursor): res = cursor.fetchall() res_df = pd.DataFrame( res, - columns=[desc[0].lower() for desc in cursor.description] - if cursor.description - else [], + columns=( + [desc[0].lower() for desc in cursor.description] + if cursor.description + else [] + ), ) return res_df except sqlite3.ProgrammingError as e: diff --git a/evadb/third_party/huggingface/create.py b/evadb/third_party/huggingface/create.py index f49488c9f6..ab99375fae 100644 --- a/evadb/third_party/huggingface/create.py +++ b/evadb/third_party/huggingface/create.py @@ -28,10 +28,9 @@ ) from evadb.utils.generic_utils import try_to_import_transformers -""" -We currently support the following tasks from HuggingFace. -Each task is mapped to the type of input it expects. -""" +# == We currently support the following tasks from HuggingFace. == +# == Each task is mapped to the type of input it expects. == + INPUT_TYPE_FOR_SUPPORTED_TASKS = { "audio-classification": HFInputTypes.AUDIO, "automatic-speech-recognition": HFInputTypes.AUDIO, diff --git a/script/data/download_file.py b/script/data/download_file.py index b301178333..495694027d 100644 --- a/script/data/download_file.py +++ b/script/data/download_file.py @@ -6,18 +6,16 @@ # map file names to their corresponding google drive ids file_id_map = { - # datasets - "bddtest" : "1XDkcJ0eh7ov1r5pm7AsVfCahTQaAJ9sn", - + "bddtest": "1XDkcJ0eh7ov1r5pm7AsVfCahTQaAJ9sn", # models - "vehicle_make_predictor" : "1pM3FFlSMWhZ4LYpdL2tNRvofUKzifzZe" - + "vehicle_make_predictor": "1pM3FFlSMWhZ4LYpdL2tNRvofUKzifzZe", } + def download_file_from_google_drive(file_name, destination): """ - Downloads a zip file from google drive. Assumes the file has open access. + Downloads a zip file from google drive. Assumes the file has open access. Args: file_name: name of the file to download destination: path to save the file to @@ -30,14 +28,15 @@ def download_file_from_google_drive(file_name, destination): session = requests.Session() - response = session.get(URL, params = { 'id' : id }, stream = True) + response = session.get(URL, params={"id": id}, stream=True) token = get_confirm_token(response) if token: - params = { 'id' : id, 'confirm' : token } - response = session.get(URL, params = params, stream = True) + params = {"id": id, "confirm": token} + response = session.get(URL, params=params, stream=True) + + save_response_content(response, destination) - save_response_content(response, destination) def get_confirm_token(response): """ @@ -47,26 +46,28 @@ def get_confirm_token(response): """ for key, value in response.cookies.items(): - if key.startswith('download_warning'): + if key.startswith("download_warning"): return value return None + def save_response_content(response, destination): """ - Writes the content of the response to the destination. + Writes the content of the response to the destination. Args: response: response object from the request destination: path to save the file to """ - + CHUNK_SIZE = 32768 with open(destination, "wb") as f: for chunk in tqdm(response.iter_content(CHUNK_SIZE)): - if chunk: # filter out keep-alive new chunks + if chunk: # filter out keep-alive new chunks f.write(chunk) + if __name__ == "__main__": file_name = sys.argv[1] destination = os.path.join(os.getcwd(), file_name + ".zip") diff --git a/script/docs/catalog_plotter.py b/script/docs/catalog_plotter.py index 77c8505fd2..2be12d8735 100644 --- a/script/docs/catalog_plotter.py +++ b/script/docs/catalog_plotter.py @@ -232,9 +232,7 @@ def format_col_str(col): suffix = ( "(FK)" if col.name in fk_col_names - else "(PK)" - if col.name in pk_col_names - else "" + else "(PK)" if col.name in pk_col_names else "" ) if show_datatypes: return "- %s : %s" % (col.name + suffix, format_col_type(col)) @@ -248,9 +246,11 @@ def format_name(obj_name, format_dict): return '<FONT COLOR="{color}" POINT-SIZE="{size}">{bld}{it}{name}{e_it}{e_bld}</FONT>'.format( name=obj_name, color=format_dict.get("color") if "color" in format_dict else "initial", - size=float(format_dict["fontsize"]) - if "fontsize" in format_dict - else "initial", + size=( + float(format_dict["fontsize"]) + if "fontsize" in format_dict + else "initial" + ), it="<I>" if format_dict.get("italics") else "", e_it="</I>" if format_dict.get("italics") else "", bld="<B>" if format_dict.get("bold") else "", diff --git a/script/formatting/formatter.py b/script/formatting/formatter.py index 433de9c48c..12fd2ec06c 100755 --- a/script/formatting/formatter.py +++ b/script/formatting/formatter.py @@ -29,12 +29,14 @@ background_loop = asyncio.new_event_loop() + def background(f): def wrapped(*args, **kwargs): return background_loop.run_in_executor(None, f, *args, **kwargs) return wrapped + # ============================================== # CONFIGURATION # ============================================== @@ -210,7 +212,7 @@ def format_file(file_path, add_header, strip_header, format_code): fd.write(new_file_data) elif format_code: - #LOG.info("Formatting File : " + file_path) + # LOG.info("Formatting File : " + file_path) # ISORT isort_command = f"{ISORT_BINARY} --profile black {file_path}" os.system(isort_command) @@ -227,24 +229,28 @@ def format_file(file_path, add_header, strip_header, format_code): # PYLINT pylint_command = f"{PYLINT_BINARY} --spelling-private-dict-file {ignored_words_file} --rcfile={PYLINTRC} {file_path}" - #LOG.warning(pylint_command) - #ret_val = os.system(pylint_command) - #if ret_val: + # LOG.warning(pylint_command) + # ret_val = os.system(pylint_command) + # if ret_val: # sys.exit(1) # CHECK FOR INVALID WORDS (like print) - with open(file_path, 'r') as file: + with open(file_path, "r") as file: for line_num, line in enumerate(file, start=1): - if file_path not in IGNORE_PRINT_FILES and ' print(' in line: - LOG.warning(f"print() found in {file_path}, line {line_num}: {line.strip()}") - sys.exit(1) + if file_path not in IGNORE_PRINT_FILES and " print(" in line: + LOG.warning( + f"print() found in {file_path}, line {line_num}: {line.strip()}" + ) + sys.exit(1) # END WITH fd.close() + # END FORMAT__FILE(FILE_NAME) + # check the notebooks def check_notebook_format(notebook_file): # print(notebook_file) @@ -264,16 +270,18 @@ def check_notebook_format(notebook_file): # Check that all cells have a valid cell type (code, markdown, or raw) for cell in nb.cells: - if cell.cell_type not in ['code', 'markdown', 'raw']: - LOG.error(f"ERROR: Notebook {notebook_file} contains an invalid cell type: {cell.cell_type}") + if cell.cell_type not in ["code", "markdown", "raw"]: + LOG.error( + f"ERROR: Notebook {notebook_file} contains an invalid cell type: {cell.cell_type}" + ) sys.exit(1) # Check that all code cells have a non-empty source code for cell in nb.cells: - if cell.cell_type == 'code' and not cell.source.strip(): + if cell.cell_type == "code" and not cell.source.strip(): LOG.error(f"ERROR: Notebook {notebook_file} contains an empty code cell") sys.exit(1) - + # Check for "print(response)" # too harsh replaxing it # for cell in nb.cells: @@ -284,7 +292,7 @@ def check_notebook_format(notebook_file): # Check for "Colab link" contains_colab_link = False for cell in nb.cells: - if cell.cell_type == 'markdown' and 'colab' in cell.source: + if cell.cell_type == "markdown" and "colab" in cell.source: # Check if colab link is correct # notebook_file_name must match colab link if notebook_file_name in cell.source: @@ -292,7 +300,9 @@ def check_notebook_format(notebook_file): break if contains_colab_link is False: - LOG.error(f"ERROR: Notebook {notebook_file} does not contain correct Colab link -- update the link.") + LOG.error( + f"ERROR: Notebook {notebook_file} does not contain correct Colab link -- update the link." + ) sys.exit(1) return True @@ -303,6 +313,7 @@ def check_notebook_format(notebook_file): import enchant from enchant.checker import SpellChecker + chkr = SpellChecker("en_US") # Check spelling @@ -312,7 +323,9 @@ def check_notebook_format(notebook_file): chkr.set_text(cell.source) for err in chkr: if err.word not in ignored_words: - LOG.warning(f"WARNING: Notebook {notebook_file} contains the misspelled word: {err.word}") + LOG.warning( + f"WARNING: Notebook {notebook_file} contains the misspelled word: {err.word}" + ) # format all the files in the dir passed as argument @@ -335,9 +348,10 @@ def format_dir(dir_path, add_header, strip_header, format_code): # END ADD_HEADERS_DIR(DIR_PATH) + @background def check_file(file): - #print(file) + # print(file) valid = False # only format the default directories file_path = str(Path(file).absolute()) @@ -352,6 +366,7 @@ def check_file(file): format_file(file, True, False, False) format_file(file, False, False, True) + # ============================================== # Main Function # ============================================== @@ -388,7 +403,9 @@ def check_file(file): parser.add_argument( "-d", "--dir-name", help="directory containing files to be acted on" ) - parser.add_argument("-k", "--spell-check", help="enable spelling check (off by default)") + parser.add_argument( + "-k", "--spell-check", help="enable spelling check (off by default)" + ) args = parser.parse_args() @@ -412,22 +429,18 @@ def check_file(file): ) elif args.dir_name: LOG.info("Scanning directory " + "".join(args.dir_name)) - format_dir( - args.dir_name, args.add_header, args.strip_header, args.format_code - ) + format_dir(args.dir_name, args.add_header, args.strip_header, args.format_code) # BY DEFAULT, WE FIX THE MODIFIED FILES else: # LOG.info("Default fix modified files") MERGEBASE = subprocess.check_output( - "git merge-base origin/staging HEAD", - shell=True, - universal_newlines=True + "git merge-base origin/staging HEAD", shell=True, universal_newlines=True ).rstrip() files = ( subprocess.check_output( f"git diff --name-only --diff-filter=ACRM {MERGEBASE} -- '*.py'", shell=True, - universal_newlines=True + universal_newlines=True, ) .rstrip() .split("\n") @@ -438,7 +451,7 @@ def check_file(file): # CHECK ALL THE NOTEBOOKS - # Iterate over all files in the directory + # Iterate over all files in the directory # and check if they are Jupyter notebooks for file in os.listdir(EvaDB_NOTEBOOKS_DIR): if file.endswith(".ipynb"): @@ -451,39 +464,47 @@ def check_file(file): # GO OVER ALL DOCS # Install aspell # apt-get install aspell - - #LOG.info("ASPELL") - for elem in Path(EvaDB_DOCS_DIR).rglob('*.*'): + + # LOG.info("ASPELL") + for elem in Path(EvaDB_DOCS_DIR).rglob("*.*"): if elem.suffix == ".rst" or elem.suffix == ".yml": - os.system(f"aspell --lang=en --personal='{ignored_words_file}' check {elem}") + os.system( + f"aspell --lang=en --personal='{ignored_words_file}' check {elem}" + ) - os.system(f"aspell --lang=en --personal='{ignored_words_file}' check 'README.md'") + os.system( + f"aspell --lang=en --personal='{ignored_words_file}' check 'README.md'" + ) # CODESPELL - #LOG.info("Codespell") - subprocess.check_output(""" codespell "evadb/*.py" """, - shell=True, - universal_newlines=True) - subprocess.check_output(""" codespell "evadb/*/*.py" """, - shell=True, - universal_newlines=True) - subprocess.check_output(""" codespell "docs/source/*/*.rst" """, - shell=True, - universal_newlines=True) - subprocess.check_output(""" codespell "docs/source/*.rst" """, - shell=True, - universal_newlines=True) - subprocess.check_output(""" codespell "*.md" """, - shell=True, - universal_newlines=True) - subprocess.check_output(""" codespell "evadb/*.md" """, - shell=True, - universal_newlines=True) - - for elem in Path(EvaDB_SRC_DIR).rglob('*.*'): + # LOG.info("Codespell") + subprocess.check_output( + """ codespell "evadb/*.py" """, shell=True, universal_newlines=True + ) + subprocess.check_output( + """ codespell "evadb/*/*.py" """, shell=True, universal_newlines=True + ) + subprocess.check_output( + """ codespell "docs/source/*/*.rst" """, shell=True, universal_newlines=True + ) + subprocess.check_output( + """ codespell "docs/source/*.rst" """, shell=True, universal_newlines=True + ) + subprocess.check_output( + """ codespell "*.md" """, shell=True, universal_newlines=True + ) + subprocess.check_output( + """ codespell "evadb/*.md" """, shell=True, universal_newlines=True + ) + + for elem in Path(EvaDB_SRC_DIR).rglob("*.*"): if elem.suffix == ".py": - os.system(f"aspell --lang=en --personal='{ignored_words_file}' check {elem}") + os.system( + f"aspell --lang=en --personal='{ignored_words_file}' check {elem}" + ) - for elem in Path(EvaDB_TEST_DIR).rglob('*.*'): + for elem in Path(EvaDB_TEST_DIR).rglob("*.*"): if elem.suffix == ".py": - os.system(f"aspell --lang=en --personal='{ignored_words_file}' check {elem}") \ No newline at end of file + os.system( + f"aspell --lang=en --personal='{ignored_words_file}' check {elem}" + ) diff --git a/script/formatting/validator.py b/script/formatting/validator.py index 465129fef5..1ce798a0ca 100644 --- a/script/formatting/validator.py +++ b/script/formatting/validator.py @@ -32,14 +32,12 @@ EXIT_SUCCESS = 0 EXIT_FAILURE = -1 -VALIDATOR_PATTERNS = [re.compile(patterns) for patterns in [ - r"print" -] -] +VALIDATOR_PATTERNS = [re.compile(patterns) for patterns in [r"print"]] CODE_SOURCE_DIR = os.path.abspath(os.path.dirname(__file__)) -EvaDB_DIR = functools.reduce(os.path.join, - [CODE_SOURCE_DIR, os.path.pardir, os.path.pardir]) +EvaDB_DIR = functools.reduce( + os.path.join, [CODE_SOURCE_DIR, os.path.pardir, os.path.pardir] +) EvaDB_SRC_DIR = os.path.join(EvaDB_DIR, "eva") EvaDB_TEST_DIR = os.path.join(EvaDB_DIR, "test") @@ -55,13 +53,13 @@ def contains_commented_out_code(line): line = line.lstrip() - if 'utf-8' in line: + if "utf-8" in line: return False - if not line.startswith('#'): + if not line.startswith("#"): return False - line = line.lstrip(' \t\v\n#').strip() + line = line.lstrip(" \t\v\n#").strip() # Regex for checking function definition; for, with loops; # continue and break @@ -69,24 +67,23 @@ def contains_commented_out_code(line): r"def .+\)[\s]*[->]*[\s]*[a-zA-Z_]*[a-zA-Z0-9_]*:$", r"with .+ as [a-zA-Z_][a-zA-Z0-9_]*:$", r"for [a-zA-Z_][a-zA-Z0-9_]* in .+:$", - r'continue$', r'break$' + r"continue$", + r"break$", ] for regex in regex_list: if re.search(regex, line): return True - symbol_list = list('[]{}=%') +\ - ['print', 'break', - 'import ', 'elif '] + symbol_list = list("[]{}=%") + ["print", "break", "import ", "elif "] for symbol in symbol_list: if symbol in line: return True # Handle return statements in a specific way - if 'return' in line: - if len(line.split(' ')) >= 2: + if "return" in line: + if len(line.split(" ")) >= 2: return False else: return True @@ -101,33 +98,39 @@ def validate_file(file): LOG.info("ERROR: " + file + " isn't a file") sys.exit(EXIT_FAILURE) - if not file.endswith('.py'): + if not file.endswith(".py"): return True code_validation = True line_number = 1 commented_code = False - with open(file, 'r') as opened_file: + with open(file, "r") as opened_file: for line in opened_file: # Check if the line has commented code - if line.lstrip().startswith('#'): + if line.lstrip().startswith("#"): commented_code = contains_commented_out_code(line) if commented_code: - LOG.info("Commented code " - + "in file " + file - + " Line {}: {}".format(line_number, line.strip())) + LOG.info( + "Commented code " + + "in file " + + file + + " Line {}: {}".format(line_number, line.strip()) + ) # Search for a pattern, and report hits for validator_pattern in VALIDATOR_PATTERNS: if validator_pattern.search(line): code_validation = False - LOG.info("Unacceptable pattern:" - + validator_pattern.pattern.strip() - + " in file " + file - + " Line {}: {}".format(line_number, line.strip())) + LOG.info( + "Unacceptable pattern:" + + validator_pattern.pattern.strip() + + " in file " + + file + + " Line {}: {}".format(line_number, line.strip()) + ) line_number += 1 @@ -150,14 +153,15 @@ def validate_directory(directory_list): return code_validation -if __name__ == '__main__': +if __name__ == "__main__": PARSER = argparse.ArgumentParser( description="Perform source code validation on EvaDB." ) - PARSER.add_argument("--files", nargs="*", - help="Provide a list of specific files to validate") + PARSER.add_argument( + "--files", nargs="*", help="Provide a list of specific files to validate" + ) ARGS = PARSER.parse_args() diff --git a/script/releasing/releaser.py b/script/releasing/releaser.py index 9b5a2f0179..23e109dbb8 100755 --- a/script/releasing/releaser.py +++ b/script/releasing/releaser.py @@ -296,7 +296,9 @@ def bump_up_version(next_version): run_command("git add . -u") run_command("git commit -m '[BUMP]: " + NEXT_RELEASE + "'") run_command("git push --set-upstream origin bump-" + NEXT_RELEASE) - run_command(f"gh pr create -B staging -H bump-{NEXT_RELEASE} --title 'Bump Version to {NEXT_RELEASE}' --body 'Bump Version to {NEXT_RELEASE}'") + run_command( + f"gh pr create -B staging -H bump-{NEXT_RELEASE} --title 'Bump Version to {NEXT_RELEASE}' --body 'Bump Version to {NEXT_RELEASE}'" + ) # ============================================== @@ -453,4 +455,3 @@ def bump_up_version(next_version): # BUMP UP VERSION bump_up_version(next_version) - diff --git a/script/test/test.sh b/script/test/test.sh old mode 100644 new mode 100755 diff --git a/setup.py b/setup.py index e3d211ece8..3a635be0cc 100644 --- a/setup.py +++ b/setup.py @@ -45,7 +45,7 @@ def read(path, encoding="utf-8"): minimal_requirements = [ "numpy>=1.19.5", - "pandas>=2.1.0", # DataFrame.map is available after v2.1.0 + "pandas>=2.1.0", # DataFrame.map is available after v2.1.0 "sqlalchemy>=2.0.0", "sqlalchemy-utils>=0.36.6", "lark>=1.0.0", @@ -134,9 +134,7 @@ def read(path, encoding="utf-8"): "neuralforecast", # MODEL TRAIN AND FINE TUNING ] -imagegen_libs = [ - "replicate" -] +imagegen_libs = ["replicate"] ### NEEDED FOR DEVELOPER TESTING ONLY @@ -184,7 +182,15 @@ def read(path, encoding="utf-8"): "forecasting": forecasting_libs, "hackernews": hackernews_libs, # everything except ray, qdrant, ludwig and postgres. The first three fail on pyhton 3.11. - "dev": dev_libs + vision_libs + document_libs + function_libs + notebook_libs + forecasting_libs + sklearn_libs + imagegen_libs + xgboost_libs + "dev": dev_libs + + vision_libs + + document_libs + + function_libs + + notebook_libs + + forecasting_libs + + sklearn_libs + + imagegen_libs + + xgboost_libs, } setup( diff --git a/test/integration_tests/long/test_optimizer_rules.py b/test/integration_tests/long/test_optimizer_rules.py index eb515a19a2..2f78bad02e 100644 --- a/test/integration_tests/long/test_optimizer_rules.py +++ b/test/integration_tests/long/test_optimizer_rules.py @@ -160,9 +160,11 @@ def _check_reorder(cost_func): # reordering if first predicate has higher cost _check_reorder( - lambda name: MagicMock(cost=10) - if name == "DummyMultiObjectDetector" - else MagicMock(cost=5) + lambda name: ( + MagicMock(cost=10) + if name == "DummyMultiObjectDetector" + else MagicMock(cost=5) + ) ) # reordering if first predicate has no cost @@ -194,16 +196,20 @@ def _check_no_reorder(cost_func): # no reordering if first predicate has lower cost _check_no_reorder( - lambda name: MagicMock(cost=10) - if name == "DummyMultiObjectDetector" - else MagicMock(cost=5) + lambda name: ( + MagicMock(cost=10) + if name == "DummyMultiObjectDetector" + else MagicMock(cost=5) + ) ) # no reordering if both predicates have same cost _check_no_reorder( - lambda name: MagicMock(cost=5) - if name == "DummyMultiObjectDetector" - else MagicMock(cost=5) + lambda name: ( + MagicMock(cost=5) + if name == "DummyMultiObjectDetector" + else MagicMock(cost=5) + ) ) # no reordering if default cost is used for one Function diff --git a/test/integration_tests/long/test_similarity.py b/test/integration_tests/long/test_similarity.py index 2a8d52cf8d..d647e1df8f 100644 --- a/test/integration_tests/long/test_similarity.py +++ b/test/integration_tests/long/test_similarity.py @@ -147,9 +147,9 @@ def setUp(self): self.original_weaviate_env = os.environ.get("WEAVIATE_API_URL") os.environ["WEAVIATE_API_KEY"] = "NM4adxLmhtJDF1dPXDiNhEGTN7hhGDpymmO0" - os.environ[ - "WEAVIATE_API_URL" - ] = "https://cs6422-test2-zn83syib.weaviate.network" + os.environ["WEAVIATE_API_URL"] = ( + "https://cs6422-test2-zn83syib.weaviate.network" + ) def tearDown(self): shutdown_ray() diff --git a/test/unit_tests/optimizer/rules/test_rules.py b/test/unit_tests/optimizer/rules/test_rules.py index 18f8dc51d4..41467f841d 100644 --- a/test/unit_tests/optimizer/rules/test_rules.py +++ b/test/unit_tests/optimizer/rules/test_rules.py @@ -214,9 +214,11 @@ def test_supported_rules(self): LogicalDeleteToPhysical(), LogicalLoadToPhysical(), LogicalGetToSeqScan(), - LogicalProjectToRayPhysical() - if ray_enabled_and_installed - else LogicalProjectToPhysical(), + ( + LogicalProjectToRayPhysical() + if ray_enabled_and_installed + else LogicalProjectToPhysical() + ), LogicalProjectNoTableToPhysical(), LogicalDerivedGetToPhysical(), LogicalUnionToPhysical(), @@ -228,9 +230,11 @@ def test_supported_rules(self): LogicalFunctionScanToPhysical(), LogicalJoinToPhysicalHashJoin(), LogicalFilterToPhysical(), - LogicalApplyAndMergeToRayPhysical() - if ray_enabled_and_installed - else LogicalApplyAndMergeToPhysical(), + ( + LogicalApplyAndMergeToRayPhysical() + if ray_enabled_and_installed + else LogicalApplyAndMergeToPhysical() + ), LogicalShowToPhysical(), LogicalExplainToPhysical(), LogicalCreateIndexToVectorIndex(),