Skip to content

Commit

Permalink
Shaokai/release 0.1.2 (#52)
Browse files Browse the repository at this point in the history
* corrected broken links and broken references

* Update setup.cfg

- bump to stable v1 and v0.1.1

* Update pyproject.toml

* Update version.py

* Corrected typo. Added config yamls in setup

* Removed config files that are no longer needed

* changed work from to pull the repo from git

* Added comments to remind people to pay attention to the data folder in the demo notebooks

* fixed pypi typo

* Fixed a bug in create_project. Changed default use_vlm to False. Updated demo notebooks

* removed WIP 3d keypoints

* Fixed one more

* WIP

* enforcing the use of create_project in demo notebooks and modified the test

* 3D supported. Better tests. More flexible identifier

* black and isort

* added dlc to test requirement

* Made test use stronger gpt. Added dummy video

* easier superanimal test

* Better 3D prompt and fixed self-debug

* preventing infinite loop

* better prompt for 3D

* better prompt for 3D

* better prompt

* updates

* fixed serialization

* extension to support animation. Made self-debugging work with bigger output. Allowing to skip code execution in parse result

* better interpolation and corrected x,y,z convention

* incorporated suggestions

---------

Co-authored-by: Mackenzie Mathis <mathis@rowland.harvard.edu>
  • Loading branch information
yeshaokai and MMathisLab authored Aug 7, 2024
1 parent af4a0a5 commit 6860338
Show file tree
Hide file tree
Showing 34 changed files with 609 additions and 331 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install pytest numpy==1.23.5 tables==3.8.0
pip install deeplabcut==3.0.0rc4
pip install pytest
pip install pytest-timeout
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
Expand Down
3 changes: 2 additions & 1 deletion amadeusgpt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@

from amadeusgpt.integration_modules import *
from amadeusgpt.main import AMADEUS
from amadeusgpt.version import VERSION, __version__
from amadeusgpt.project import create_project
from amadeusgpt.version import VERSION, __version__

params = {
"axes.labelsize": 10,
"legend.fontsize": 10,
Expand Down
14 changes: 9 additions & 5 deletions amadeusgpt/analysis_objects/animal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import numpy as np
from numpy import ndarray
from scipy.spatial import ConvexHull

from amadeusgpt.analysis_objects.object import Object


Expand All @@ -27,8 +26,8 @@ class AnimalSeq(Animal):
body center, left, right, above, top are relative to the subset of keypoints.
Attributes
----------
self._coords: arr potentially subset of keypoints
self.wholebody: full set of keypoints. This is important for overlap relationship
self.wholebody: np.ndarray of keypoints of all bodyparts
self.keypoint
"""

def __init__(self, animal_name: str, keypoints: ndarray, keypoint_names: List[str]):
Expand Down Expand Up @@ -95,8 +94,6 @@ def get_path(self, ind):
return mpath.Path(verts, codes)

def get_keypoints(self, average_keypoints=False) -> ndarray:
# the shape should be (n_frames, n_keypoints, 2)
# extending to 3D?
assert (
len(self.keypoints.shape) == 3
), f"keypoints shape is {self.keypoints.shape}"
Expand All @@ -123,8 +120,15 @@ def get_ymin(self):
def get_ymax(self):
return np.nanmax(self.keypoints[..., 1], axis=1)

def get_zmin(self):
return np.nanmin(self.keypoints[..., 2], axis=1)

def get_zmax(self):
return np.nanmax(self.keypoints[..., 2], axis=1)

def get_keypoint_names(self):
return self.keypoint_names


def query_states(self, query: str) -> ndarray:
assert query in [
Expand Down
37 changes: 22 additions & 15 deletions amadeusgpt/analysis_objects/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ class LLM(AnalysisObject):

def __init__(self, config):
self.config = config
self.max_tokens = config.get("max_tokens", 4096)
self.gpt_model = config.get("gpt_model", "gpt-4o-mini")

self.max_tokens = config["llm_info"].get("max_tokens", 4096)
self.gpt_model = config["llm_info"].get("gpt_model", "gpt-4o-mini")
self.keep_last_n_messages = config.get("keep_last_n_messages", 2)

# the list that is actually sent to gpt
Expand Down Expand Up @@ -261,8 +262,8 @@ def speak(self, sandbox: Sandbox, image: np.ndarray):
response = self.connect_gpt(self.context_window, max_tokens=2000)
text = response.choices[0].message.content.strip()

print ('description of the image frame provided')
print (text)
print("description of the image frame provided")
print(text)

pattern = r"```json(.*?)```"
if len(re.findall(pattern, text, re.DOTALL)) == 0:
Expand Down Expand Up @@ -293,24 +294,26 @@ def speak(
task_program_docs = sandbox.get_task_program_docs()

if share_video_file:
video_file_path = sandbox.video_file_paths[0]
identifier = sandbox.identifiers[0]
else:
raise NotImplementedError("This is not implemented yet")

behavior_analysis = sandbox.analysis_dict[video_file_path]
behavior_analysis = sandbox.analysis_dict[identifier]
scene_image = behavior_analysis.visual_manager.get_scene_image()
keypoint_names = behavior_analysis.animal_manager.get_keypoint_names()
object_names = behavior_analysis.object_manager.get_object_names()
animal_names = behavior_analysis.animal_manager.get_animal_names()

animal_names = behavior_analysis.animal_manager.get_animal_names()
use_3d = sandbox.config['keypoint_info'].get('use_3d', False)

self.system_prompt = _get_system_prompt(
core_api_docs,
task_program_docs,
scene_image,
keypoint_names,
object_names,
animal_names,
)
use_3d=use_3d,
)

self.update_history("system", self.system_prompt)

Expand Down Expand Up @@ -338,6 +341,13 @@ def speak(
with open("temp_answer.json", "w") as f:
obj = {}
obj["chain_of_thought"] = text
obj["code"] = function_code
obj["video_file_paths"] = sandbox.video_file_paths
obj["keypoint_file_paths"] = sandbox.keypoint_file_paths
if not isinstance(sandbox.config, dict):
obj["config"] = sandbox.config.to_dict()
else:
obj["config"] = sandbox.config
json.dump(obj, f, indent=4)

return qa_message
Expand All @@ -361,21 +371,18 @@ def speak(self, qa_message):
query = f""" The code that caused error was {code}
And the error message was {error_message}.
All the modules were already imported so you don't need to import them again.
Can you correct the code?
Can you correct the code? Make sure you only write one function which is the updated function.
"""
self.update_history("user", query)
response = self.connect_gpt(self.context_window, max_tokens=700)
response = self.connect_gpt(self.context_window, max_tokens=4096)
text = response.choices[0].message.content.strip()

print(text)

pattern = r"```python(.*?)```"
function_code = re.findall(pattern, text, re.DOTALL)[0]

qa_message.code = function_code

qa_message.chain_of_thought = text

return qa_message

if __name__ == "__main__":
from amadeusgpt.config import Config
Expand Down
2 changes: 2 additions & 0 deletions amadeusgpt/analysis_objects/object.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import matplotlib.path as mpath
import numpy as np

from .base import AnalysisObject


Expand Down Expand Up @@ -141,6 +142,7 @@ def __init__(self, name: str, masks: dict):
_seg: dict = self.masks.get("segmentation")
# this is rle format
from pycocotools import mask as mask_decoder

if "counts" in _seg:
self.segmentation = mask_decoder.decode(_seg)
else:
Expand Down
5 changes: 3 additions & 2 deletions amadeusgpt/analysis_objects/relationship.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,14 @@ def calc_angle_between_2d_coordinate_systems(cs1, cs2):
return np.rad2deg(np.arccos(dot))


def get_pairwise_distance(arr1, arr2):
def get_pairwise_distance(arr1: np.ndarray, arr2: np.ndarray):
# we want to make sure this uses a fast implementation
# (n_frame, n_kpts, 2)
# arr: (n_frame, n_kpts, 2)
assert len(arr1.shape) == 3 and len(arr2.shape) == 3
# pairwise distance (n_frames, n_kpts, n_kpts)
pairwise_distances = np.ones((arr1.shape[0], arr1.shape[1], arr2.shape[1])) * 100000
for frame_id in range(arr1.shape[0]):
# should we use the mean of all keypoints for the distance?
pairwise_distances[frame_id] = cdist(arr1[frame_id], arr2[frame_id])

return pairwise_distances
Expand Down
73 changes: 42 additions & 31 deletions amadeusgpt/analysis_objects/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from matplotlib.figure import Figure
from matplotlib.ticker import FuncFormatter
from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits.mplot3d import Axes3D
from PIL import Image
from scipy.signal import medfilt

Expand Down Expand Up @@ -125,7 +126,8 @@ def draw(self, **kwargs) -> None:

self._draw_seg_objects()
self._draw_roi_objects()
self.axs.imshow(self.scene_frame)
if self.scene_frame is not None:
self.axs.imshow(self.scene_frame)


class KeypointVisualization(MatplotlibVisualization):
Expand Down Expand Up @@ -284,37 +286,46 @@ def _event_plot_trajectory(self, **kwargs):
masked_data = medfilt(masked_data, kernel_size=(k, 1))
if masked_data.shape[0] == 0:
continue
x, y = masked_data[:, 0], masked_data[:, 1]
x = x[x.nonzero()]
y = y[y.nonzero()]
if len(x) < 1:
continue

scatter = self.axs.plot(
x,
y,
label=f"event{event_id}",
color=line_colors[event_id],
alpha=0.5,
)
scatter = self.axs.scatter(
x[0],
y[0],
marker="*",
s=100,
color=line_colors[event_id],
alpha=0.5,
**kwargs,
)
self.axs.scatter(
x[-1],
y[-1],
marker="x",
s=100,
color=line_colors[event_id],
alpha=0.5,
**kwargs,
)
if not kwargs.get("use_3d", False):
x, y = masked_data[:, 0], masked_data[:, 1]
_mask = (x != 0) & (y != 0)

x = x[_mask]
y = y[_mask]
if len(x) < 1:
continue

scatter = self.axs.plot(
x,
y,
label=f"event{event_id}",
color=line_colors[event_id],
alpha=0.5,
)
scatter = self.axs.scatter(
x[0],
y[0],
marker="*",
s=100,
color=line_colors[event_id],
alpha=0.5,
**kwargs,
)
self.axs.scatter(
x[-1],
y[-1],
marker="x",
s=100,
color=line_colors[event_id],
alpha=0.5,
**kwargs,
)
else:
# TODO
# implement 3d event plot
pass

return self.axs

def display(self):
Expand Down
6 changes: 3 additions & 3 deletions amadeusgpt/behavior_analysis/analysis_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@

def create_analysis(identifier: Identifier):

if str(identifier) not in analysis_fac:
analysis_fac[str(identifier)] = AnimalBehaviorAnalysis(identifier)
return analysis_fac[str(identifier)]
if identifier not in analysis_fac:
analysis_fac[identifier] = AnimalBehaviorAnalysis(identifier)
return analysis_fac[identifier]
25 changes: 21 additions & 4 deletions amadeusgpt/behavior_analysis/identifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,34 @@ class Identifier:
Can be more in the future
"""

def __init__(self, config: Config, video_file_path: str, keypoint_file_path: str):
def __init__(
self, config: Config | dict, video_file_path: str, keypoint_file_path: str
):

self.config = config
self.video_file_path = video_file_path
self.keypoint_file_path = keypoint_file_path

def __str__(self):
return os.path.abspath(self.video_file_path)
return f"""------
video_file_path: {self.video_file_path}
keypoint_file_path: {self.keypoint_file_path}
config: {self.config}
------
"""

def __eq__(self, other):
return self.video_file_path == other.video_file_path
if os.path.exists(self.video_file_path):
return os.path.abspath(self.video_file_path) == os.path.abspath(
other.video_file_path
)
else:
return os.path.abspath(self.keypoint_file_path) == os.path.abspath(
other.keypoint_file_path
)

def __hash__(self):
return hash(self.video_file_path)
if os.path.exists(self.video_file_path):
return hash(os.path.abspath(self.video_file_path))
else:
return hash(os.path.abspath(self.keypoint_file_path))
3 changes: 3 additions & 0 deletions amadeusgpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ def __repr__(self):
def __setitem__(self, key, value):
self.data[key] = value

def to_dict(self):
return self.data

def load_config(self):
# Load the YAML config file
if os.path.exists(self.config_file_path):
Expand Down
10 changes: 5 additions & 5 deletions amadeusgpt/integration_modules/embedding/cebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def get_cebra_embedding(self, inputs: np.ndarray, n_dimension=3) -> np.ndarray:
features = inputs.reshape(inputs.shape[0], -1)
features = np.nan_to_num(features)

print ('features shape', features.shape)
print("features shape", features.shape)
cebra_params = dict(
model_architecture="offset10-model",
batch_size=512,
Expand All @@ -34,11 +34,11 @@ def get_cebra_embedding(self, inputs: np.ndarray, n_dimension=3) -> np.ndarray:
verbose=True,
time_offsets=10,
)
print ('got here1')
print("got here1")
model = CEBRA(**cebra_params)
print ('got here2')
print("got here2")
model.fit(features)
print ('got here3')
print("got here3")
embeddings = model.transform(features)
print ('got here4')
print("got here4")
return embeddings
Loading

0 comments on commit 6860338

Please sign in to comment.