Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shaokai/fix create project bug #51

Closed
wants to merge 16 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions amadeusgpt/analysis_objects/animal.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ class AnimalSeq(Animal):
body center, left, right, above, top are relative to the subset of keypoints.
Attributes
----------
self._coords: arr potentially subset of keypoints
self.wholebody: full set of keypoints. This is important for overlap relationship
self.wholebody: np.ndarray of keypoints of all bodyparts
self.keypoint
"""

def __init__(self, animal_name: str, keypoints: ndarray, keypoint_names: List[str]):
Expand Down
6 changes: 4 additions & 2 deletions amadeusgpt/analysis_objects/relationship.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,15 @@ def calc_angle_between_2d_coordinate_systems(cs1, cs2):
return np.rad2deg(np.arccos(dot))


def get_pairwise_distance(arr1, arr2):
def get_pairwise_distance(arr1:np.ndarray,
arr2:np.ndarray):
# we want to make sure this uses a fast implementation
# (n_frame, n_kpts, 2)
# arr: (n_frame, n_kpts, 2)
assert len(arr1.shape) == 3 and len(arr2.shape) == 3
# pariwise distance (n_frames, n_kpts, n_kpts)
pairwise_distances = np.ones((arr1.shape[0], arr1.shape[1], arr2.shape[1])) * 100000
for frame_id in range(arr1.shape[0]):
# should we use the mean of all keypooints for the distance?
pairwise_distances[frame_id] = cdist(arr1[frame_id], arr2[frame_id])

return pairwise_distances
Expand Down
3 changes: 1 addition & 2 deletions amadeusgpt/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from amadeusgpt.config import Config
from amadeusgpt.programs.sandbox import Sandbox
import yaml
##########
# all these are providing the customized classes for the code execution
##########
Expand All @@ -21,7 +20,7 @@


class AMADEUS:
def __init__(self, config: Config | dict, use_vlm=True):
def __init__(self, config: Config | dict, use_vlm=False):
self.config = config
### fields that decide the behavior of the application
self.use_self_debug = True
Expand Down
4 changes: 2 additions & 2 deletions amadeusgpt/managers/animal_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ def configure_animal_from_meta(self, meta_info):
self.superanimal_name = None

def init_pose(self):

if not os.path.exists(self.keypoint_file_path):
# no need to initialize here
return
Expand All @@ -111,6 +110,7 @@ def init_pose(self):
elif self.keypoint_file_path.endswith(".json"):
# could be coco format
all_keypoints = self._process_keypoint_file_from_json()

for individual_id in range(self.n_individuals):
animal_name = f"animal_{individual_id}"
# by default, we initialize all animals with the same keypoints and all the keypoint names
Expand All @@ -128,7 +128,7 @@ def init_pose(self):
self.config["keypoint_info"]["head_orientation_keypoints"]
)

self.animals.append(animalseq)
self.animals.append(animalseq)

def _process_keypoint_file_from_h5(self) -> ndarray:
df = pd.read_hdf(self.keypoint_file_path)
Expand Down
4 changes: 3 additions & 1 deletion amadeusgpt/managers/object_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ def __init__(
self.animal_manager = animal_manager
self.roi_objects = []
self.seg_objects = []
self.load_from_disk = self.config["object_info"]["load_objects_from_disk"]

self.load_from_disk = self.config["object_info"].get("load_objects_from_disk", False)

if self.load_from_disk:
self.load_objects_from_disk()
else:
Expand Down
2 changes: 1 addition & 1 deletion amadeusgpt/managers/visual_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
self.object_manager = object_manager

def get_scene_image(self):
scene_frame_index = self.config["video_info"]["scene_frame_number"]
scene_frame_index = self.config["video_info"].get("scene_frame_number", 1)
cap = cv2.VideoCapture(self.video_file_path)
cap.set(cv2.CAP_PROP_POS_FRAMES, scene_frame_index)
ret, frame = cap.read()
Expand Down
9 changes: 8 additions & 1 deletion amadeusgpt/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ def create_project(data_folder,
"temperature": 0.0,
"keep_last_n_messages": 2
},
"object_info": {
"load_objects_from_disk": False,
"use_grid_objects": False
},
"keypoint_info": {},
"result_info" :{},
"video_info": {}
}
# save the dictionary config to yaml

Expand All @@ -31,7 +38,7 @@ def create_project(data_folder,
yaml.dump(config, f)

print (f"Project created at {result_folder}. Results will be saved to {result_folder}")
print (f"The project will load video files (*.{video_suffix}) and optionally keypoint files from {data_folder}")
print (f"The project will load video files (*{video_suffix}) and optionally keypoint files from {data_folder}")
print (f"A copy of the project config file is saved at {file_path}")
pprint.pprint(config)

Expand Down
2 changes: 1 addition & 1 deletion notebooks/EPM_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
"# repo_root/examples/EPM\n",
"config['data_info']['data_folder'] = amadeus_root / config['data_info']['data_folder']\n",
"\n",
"amadeus = AMADEUS(config)\n",
"amadeus = AMADEUS(config, use_vlm=True)\n",
"video_file_paths = amadeus.get_video_file_paths()\n",
"print (video_file_paths) "
]
Expand Down
2 changes: 1 addition & 1 deletion notebooks/Horse_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
"# repo_root/examples/Horse\n",
"config['data_info']['data_folder'] = amadeus_root / config['data_info']['data_folder'] \n",
"\n",
"amadeus = AMADEUS(config)\n",
"amadeus = AMADEUS(config, use_vlm = True)\n",
"video_file_paths = amadeus.get_video_file_paths()\n",
"print (video_file_paths) "
]
Expand Down
2 changes: 1 addition & 1 deletion notebooks/MABe_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
"# repo_root/examples/MABe\n",
"config['data_info']['data_folder'] = amadeus_root / config['data_info']['data_folder']\n",
"\n",
"amadeus = AMADEUS(config)\n",
"amadeus = AMADEUS(config, use_vlm=True)\n",
"video_file_paths = amadeus.get_video_file_paths()\n",
"print (video_file_paths)"
]
Expand Down
2 changes: 1 addition & 1 deletion notebooks/MausHaus_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
"# repo_root/examples/MausHaus\n",
"config['data_info']['data_folder'] = amadeus_root / config['data_info']['data_folder']\n",
"\n",
"amadeus = AMADEUS(config)\n",
"amadeus = AMADEUS(config, use_vlm = True)\n",
"video_file_paths = amadeus.get_video_file_paths()\n",
"print (video_file_paths) "
]
Expand Down
2 changes: 1 addition & 1 deletion notebooks/custom_mouse_video.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
"\n",
"config[\"scene_frame_number\"] = scene_frame_number\n",
"\n",
"amadeus = AMADEUS(config)\n",
"amadeus = AMADEUS(config, use_vlm = True)\n",
"video_file_paths = amadeus.get_video_file_paths()\n",
"print (video_file_paths)"
]
Expand Down
4 changes: 2 additions & 2 deletions tests/test_project_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

# Create a project

data_folder = "temp_data_folder"
data_folder = "examples/EPM/"
result_folder = "temp_result_folder"

config = create_project(data_folder, result_folder)
Expand All @@ -18,4 +18,4 @@

# query = "Plot the trajectory of the animal using the animal center and color it by time"
# qa_message = amadeus.step(query)
# parse_result(amadeus, qa_message)
# parse_result(amadeus, qa_message)
Loading