Skip to content

Commit

Permalink
updated mending
Browse files Browse the repository at this point in the history
  • Loading branch information
pudumagico committed Jul 30, 2024
1 parent 1afd2c1 commit 942dbb2
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 245 deletions.
3 changes: 1 addition & 2 deletions direct_VQA.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
import io
from litellm import completion
import argparse
from utils import *

# Set environment variables for API keys
# os.environ["OPENAI_API_KEY"] = "sk-None-jFbxGP93PDk2Yk1vo63ST3BlbkFJQtK4HzVLPQ17LsRxkjaI"
os.environ["OPENAI_API_KEY"] = ""

def encode_image(image_path):
with open(image_path, "rb") as image_file:
Expand Down
13 changes: 8 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def main(args):

extended_theory = incumbent_theory + '\n' + response
syntax_check, syntax_error = check_asp_syntax(extended_theory)
print(syntax_error)
if syntax_check:
semantic_check, semantic_error = run_asp_code(
extended_theory, incumbent_examples)
Expand All @@ -155,7 +156,9 @@ def main(args):
while mend_retries < max_mend_retries:

if state_mending:
print('state mending')
state_atoms = run_asp_code_with_states(extended_theory, incumbent_examples)
print(state_atoms)
mended_rule = mend_semantics_with_states(
response, semantic_error, examples[current_example]['answer'], incumbent_theory, preprompt, model, state_atoms)
else:
Expand Down Expand Up @@ -281,7 +284,7 @@ def main(args):
parser.add_argument("--max_retries", type=int, default=1,
help="Maximum number of retries")

parser.add_argument("--mend_retries", type=int, default=0,
parser.add_argument("--mend_retries", type=int, default=1,
help="Maximum number of mending retries")

parser.add_argument("--learning_examples", type=int,
Expand All @@ -293,10 +296,10 @@ def main(args):
parser.add_argument("--model", type=str,
default="gpt-4-1106-preview", help="LLM model to be used")

parser.add_argument("--strategy", type=str, default="len",
parser.add_argument("--strategy", type=str, default="pred",
help="Strategy used to sample examples")

parser.add_argument("--sample_sz", type=int, default=2,
parser.add_argument("--sample_sz", type=int, default=10,
help="Sample size for the strategy selected")

parser.add_argument("--regressive_test", default=True, type=bool,
Expand All @@ -305,13 +308,13 @@ def main(args):
parser.add_argument("--representation", type=str, default="flat",
help="Representation to be used")

parser.add_argument("--remove_random", type=int, default=10,
parser.add_argument("--remove_random", type=int, default=0,
help="Remove a percentage of random lines from the perfect theory to use as initial theory.")

parser.add_argument("--remove_predicate", type=str, default=None,
help="Remove any rule where the predicated selected appears in the perfect theory and use this as initial theory.")

parser.add_argument("--state_mending", type=bool, default=False,
parser.add_argument("--state_mending", type=bool, default=True,
help="Use semantic mending with states of the program.")

parser.add_argument("--batch_theory", type=str, default="",
Expand Down
Loading

0 comments on commit 942dbb2

Please sign in to comment.