Skip to content

Commit

Permalink
Merge pull request #68 from lincc-frameworks/gmerz/refactor
Browse files Browse the repository at this point in the history
Add flattening code
  • Loading branch information
grantmerz authored Nov 27, 2023
2 parents 700c3bd + 7d393bb commit 8814c21
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 1 deletion.
116 changes: 116 additions & 0 deletions src/deepdisc/data_format/flatten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@

# Setup detectron2 logger
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

# import some common libraries
import numpy as np
import os, json, cv2, random
import deepdisc
from deepdisc.data_format.image_readers import DC2ImageReader

#PATH = deepdisc.__path__[0]

def flatten_dc2(ddicts):
    """Cut a postage stamp around each usable object and flatten it into one row.

    Every large cutout is read at most once. For each annotation a 128x128
    stamp centred on its bounding box is cropped from the cutout, flattened,
    and followed by 12 metadata values. Segmentation maps are ignored.

    Parameters
    ----------
    ddicts : list[dict]
        The metadata dictionaries for large cutouts with multiple objects.
        Each dict must provide ``"filename"``, ``"height"`` and
        ``"annotations"``. ``"width"`` is used for the horizontal bounds check
        when present; otherwise ``"height"`` is used, i.e. the cutout is
        treated as square. Each annotation must provide ``"bbox"``
        (x, y, w, h), ``"mag_i"``, ``"category_id"``, ``"redshift"`` and
        ``"obj_id"``.

    Returns
    -------
    flattened_data : list[np.ndarray]
        One 1-D array per kept object: the flattened stamp followed by
        ``[128, 128, row_index, bbox_x, bbox_y, w, h, 1, category_id,
        redshift, obj_id, mag_i]``, where ``bbox_x``/``bbox_y`` are the
        bounding-box corner expressed in stamp coordinates.
        For 6-channel images each row has 98316 columns
        (6x128x128 + 12 metadata values). Empty list if no object qualifies.

    Notes
    -----
    An object is skipped when its stamp would extend past any edge of the
    cutout or when its ``mag_i`` is greater than 25.3.
    """
    stamp_size = 128
    half_stamp = stamp_size // 2
    mag_limit = 25.3  # annotations with mag_i above this value are dropped

    # Built on first use, so inputs without any usable object never need it.
    image_reader = None
    stamps = []
    metadatas = []

    for d in ddicts:
        filename = d["filename"]
        height = d["height"]
        # The x-extent must be checked against the width, not the height.
        width = d.get("width", height)
        image = None  # channel-first cutout, loaded lazily and reused

        for a in d["annotations"]:
            x, y, w, h = a["bbox"][:4]

            # Top-left corner of the stamp centred on the bounding box.
            xnew = x + w // 2 - half_stamp
            ynew = y + h // 2 - half_stamp

            if (
                xnew < 0
                or ynew < 0
                or xnew + stamp_size > width
                or ynew + stamp_size > height
                or a["mag_i"] > mag_limit
            ):
                continue

            if image is None:
                if image_reader is None:
                    image_reader = DC2ImageReader(norm="raw")
                # The stamp slicing below expects channels on the first axis.
                image = np.transpose(image_reader(filename), axes=(2, 0, 1))

            stamp = image[:, ynew:ynew + stamp_size, xnew:xnew + stamp_size]
            stamps.append(stamp.flatten())

            row_index = len(metadatas)
            metadatas.append(
                [
                    stamp_size,
                    stamp_size,
                    row_index,
                    x - xnew,  # bbox corner relative to the stamp origin
                    y - ynew,
                    w,
                    h,
                    1,
                    a["category_id"],
                    a["redshift"],
                    a["obj_id"],
                    a["mag_i"],
                ]
            )

    # Convert all metadata rows together so every row shares one dtype.
    metadatas = np.array(metadatas)

    return [
        np.concatenate((stamp, metadata))
        for stamp, metadata in zip(stamps, metadatas)
    ]


def get_test_image_path(d, root_dir=None):
    """Build the ``.npy`` image path from the "filename" key of a metadata dict.

    Parameters
    ----------
    d : dict
        The metadata dictionary; only ``d["filename"]`` is read.
    root_dir : str, optional
        Directory the filename is joined onto. When omitted, the directory
        two levels above the ``deepdisc`` package directory
        (``deepdisc.__path__[0]``) is used.

    Returns
    -------
    fn : str
        The filepath to the stored image: everything in the filename before
        the first ``".fits"`` joined onto ``root_dir``, with ``".npy"``
        appended. Ideally, this should just return the "filename" key,
        but if the user moves the images around or saves in a different
        format, it can save the time to rename those keys in the metadata
        dictionaries. Note that an absolute filename overrides ``root_dir``
        (``os.path.join`` semantics).
    """
    if root_dir is None:
        # Resolve the default root here rather than through a module-level
        # constant, so the function works without any global being defined.
        import deepdisc

        root_dir = os.path.dirname(os.path.dirname(deepdisc.__path__[0]))

    stem = d["filename"].split(".fits")[0]
    return os.path.join(root_dir, stem) + ".npy"





7 changes: 6 additions & 1 deletion tests/deepdisc/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,9 @@ def dc2_test_data_dir():

@pytest.fixture
def dc2_single_test_file(dc2_test_data_dir):
    """Return the path of the single ``.npy`` test image.

    The file name is joined onto the directory supplied by the
    ``dc2_test_data_dir`` fixture.
    """
    return path.join(dc2_test_data_dir, "3828_2,2_12_images.npy")


@pytest.fixture
def dc2_single_test_dict(dc2_test_data_dir):
    """Return the path of the single-entry metadata JSON used by the tests.

    The file name is joined onto the directory supplied by the
    ``dc2_test_data_dir`` fixture.
    """
    json_name = "single_test.json"
    return path.join(dc2_test_data_dir, json_name)
16 changes: 16 additions & 0 deletions tests/deepdisc/data_format/test_flatten_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from deepdisc.data_format.file_io import get_data_from_json
from deepdisc.data_format.flatten import flatten_dc2
import os
import pytest


def test_flatten_shape(dc2_single_test_dict):
    """Check that flattening the test metadata yields rows of the expected length."""
    # 6 channels of 128x128 pixels plus 12 metadata values = 98316 columns.
    expected_row_length = 6 * 128 * 128 + 12

    rows = flatten_dc2(get_data_from_json(dc2_single_test_dict))

    assert len(rows) > 0
    assert len(rows[0]) == expected_row_length




0 comments on commit 8814c21

Please sign in to comment.