Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hotfix: release 0.0.6 #29

Merged
merged 5 commits into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 5 additions & 13 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,15 +1,7 @@
repos:
- repo: https://github.com/ambv/black
rev: 22.6.0
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.1
hooks:
- id: black
- repo: https://github.com/PyCQA/flake8
rev: 4.0.1
hooks:
- id: flake8
args: [--exclude=tests/*]
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
args: ["--profile", "black"]
- id: ruff
args: [ --fix ]
- id: ruff-format
7 changes: 7 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Change Log

## 0.0.6 (2024-02-15)

* Use OpenAI v1.12.0.
* Update OpenAI API calls.
* Fix default value in greedy filter.
* Update tests.

## 0.0.5 (2024-02-10)

* Make Rasa an optional package.
Expand Down
11 changes: 8 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ pip install -e .
Once you have installed all dependencies you are ready to go with:
```python
from nl2ltl import translate
from nl2ltl.engines.rasa.core import RasaEngine
from nl2ltl.engines.gpt.core import GPTEngine, Models
from nl2ltl.filters.simple_filters import BasicFilter
from nl2ltl.engines.utils import pretty

engine = RasaEngine()
engine = GPTEngine()
filter = BasicFilter()
utterance = "Eventually send me a Slack after receiving a Gmail"

Expand All @@ -65,7 +65,8 @@ For instance, Rasa requires a `.tar.gz` format trained model in the
- [x] [Rasa](https://rasa.com/) intents/entities classifier (to use Rasa, please install it with `pip install -e ".[rasa]"`)
- [ ] [Watson Assistant](https://www.ibm.com/products/watson-assistant) intents/entities classifier -- Planned

To use GPT models you need to have the OPEN_API_KEY set as environment variable. To set it:
**NOTE**: To use OpenAI GPT models don't forget to add the `OPEN_API_KEY` environment
variable with:
```bash
export OPENAI_API_KEY=your_api_key
```
Expand Down Expand Up @@ -118,7 +119,11 @@ ltl_formulas = translate(utterance, engine=my_engine, filter=my_filter)
Contributions are welcome! Here's how to set up the development environment:
- set up your preferred virtualenv environment
- clone the repo: `git clone https://github.com/IBM/nl2ltl.git && cd nl2ltl`
- install dependencies: `pip install -e .`
- install dev dependencies: `pip install -e ".[dev]"`
- install pre-commit: `pre-commit install`
- sign-off your commits using the `-s` flag in the commit message to be compliant with
the [DCO](https://developercertificate.org/)

## Tests

Expand Down
15 changes: 9 additions & 6 deletions nl2ltl/engines/gpt/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,23 @@

"""
import json
import os
from enum import Enum
from pathlib import Path
from typing import Dict, Set

import openai
from openai import OpenAI
from pylogics.syntax.base import Formula

from nl2ltl.engines.base import Engine
from nl2ltl.engines.gpt import ENGINE_ROOT
from nl2ltl.engines.gpt.output import GPTOutput, parse_gpt_output, parse_gpt_result
from nl2ltl.filters.base import Filter

openai.api_key = os.getenv("OPENAI_API_KEY")
try:
client = OpenAI()
except Exception:
client = None

engine_root = ENGINE_ROOT
DATA_DIR = engine_root / "data"
PROMPT_PATH = engine_root / DATA_DIR / "prompt.json"
Expand Down Expand Up @@ -75,7 +78,7 @@ def _check_consistency(self) -> None:

def __check_openai_version(self):
"""Check that the GPT tool is at the right version."""
is_right_version = openai.__version__ == "1.12.0"
is_right_version = client._version == "1.12.0"
if not is_right_version:
raise Exception(
"OpenAI needs to be at version 1.12.0. "
Expand Down Expand Up @@ -149,7 +152,7 @@ def _process_utterance(
query = f"NL: {utterance}\n"
messages = [{"role": "user", "content": prompt + query}]
if operation_mode == OperationModes.CHAT.value:
prediction = openai.chat.completions.create(
prediction = client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
Expand All @@ -160,7 +163,7 @@ def _process_utterance(
stop=["\n\n"],
)
else:
prediction = openai.completions.create(
prediction = client.completions.create(
model=model,
prompt=messages[0]["content"],
temperature=temperature,
Expand Down
10 changes: 4 additions & 6 deletions nl2ltl/engines/gpt/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ def pattern(self) -> str:
Match,
re.search(
"PATTERN: (.*)\n",
self.output["choices"][0]["message"]["content"],
self.output.choices[0].message.content,
),
).group(1)
)
else:
return str(
cast(
Match,
re.search("PATTERN: (.*)\n", self.output["choices"][0]["text"]),
re.search("PATTERN: (.*)\n", self.output.choices[0].text),
).group(1)
)

Expand All @@ -61,15 +61,13 @@ def entities(self) -> Tuple[str]:
return tuple(
cast(
Match,
re.search("SYMBOLS: (.*)", self.output["choices"][0]["message"]["content"]),
re.search("SYMBOLS: (.*)", self.output.choices[0].message.content),
)
.group(1)
.split(", ")
)
else:
return tuple(
cast(Match, re.search("SYMBOLS: (.*)", self.output["choices"][0]["text"])).group(1).split(", ")
)
return tuple(cast(Match, re.search("SYMBOLS: (.*)", self.output.choices[0].text)).group(1).split(", "))


def parse_gpt_output(gpt_output: dict, operation_mode: str) -> GPTOutput:
Expand Down
3 changes: 2 additions & 1 deletion nl2ltl/filters/simple_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from pylogics.syntax.base import Formula

from nl2ltl.declare.base import Template
from nl2ltl.filters.base import Filter
from nl2ltl.filters.utils.conflicts import conflicts
from nl2ltl.filters.utils.subsumptions import subsumptions
Expand Down Expand Up @@ -44,7 +45,7 @@ def enforce(output: Dict[Formula, float], entities: Dict[str, float], **kwargs)
"""
result_set = set()

highest_scoring_formula = max(output, key=output.get)
highest_scoring_formula = max(output, key=output.get, default=Template)
formula_conflicts = conflicts(highest_scoring_formula)
formula_subsumptions = subsumptions(highest_scoring_formula)

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "nl2ltl"
version = "0.0.5"
version = "0.0.6"
license = {file = "LICENSE"}
authors = [
{ name = "Francesco Fuggitti", email = "francesco.fuggitti@gmail.com" },
Expand Down Expand Up @@ -34,7 +34,7 @@ classifiers = [

dependencies = [
"pylogics",
"openai"
"openai==1.12.0"
]

[project.optional-dependencies]
Expand Down
8 changes: 4 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@

class UtterancesFixtures:
utterances = [
"whenever I get a Slack, send a Gmail",
"Invite Sales employees to Thursday's meeting",
"If a new Eventbrite is created, alert me through Slack",
"send me a Slack whenever I get a Gmail",
"whenever I get a Slack, send a Gmail.",
"Invite Sales employees.",
"If a new Eventbrite is created, alert me through Slack.",
"send me a Slack whenever I get a Gmail.",
]
4 changes: 2 additions & 2 deletions tests/test_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

from nl2ltl import translate
from nl2ltl.engines.gpt.core import GPTEngine, Models
from nl2ltl.engines.gpt.core import GPTEngine
from nl2ltl.filters.simple_filters import BasicFilter, GreedyFilter

from .conftest import UtterancesFixtures
Expand All @@ -18,7 +18,7 @@ def setup_class(cls):
"""Setup any state specific to the execution of the given class (which
usually contains tests).
"""
cls.gpt_engine = GPTEngine(model=Models.GPT35_INSTRUCT.value)
cls.gpt_engine = GPTEngine()
cls.basic_filter = BasicFilter()
cls.greedy_filter = GreedyFilter()

Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ commands = ruff check .

[testenv:ruff-check-apply]
skip_install = True
deps = ruff==0.1.9r
deps = ruff==0.1.9
commands = ruff check --fix --show-fixes .

[testenv:ruff-format]
Expand Down
Loading