This repository has been archived by the owner on Jan 5, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 27
add visual search #25
Open
shashikg
wants to merge
14
commits into
brain-score:master
Choose a base branch
from
shashikg:master
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
fce9862
add visual search model for object array
shashikg a9b3caf
Chnages to ModelCommitment to add visual search brain model
shashikg b15f42f
show visual search status
shashikg 13f1bb0
Merge branch 'master' into master
shashikg b73ac40
removed candidate_model dependencies and some minor changes
shashikg c0c14c9
Merge branch 'master' of https://github.com/shashikg/model-tools
shashikg 3a81068
Merge branch 'master' into master
mschrimpf 4e583fa
simplify behavior arbitration
mschrimpf 1f9b095
auto-format
mschrimpf 0a40eb3
add vs
shashikg 9958deb
Merge branch 'master' of https://github.com/brain-score/model-tools i…
shashikg 9d1cf90
Merge branch 'brain-score-master'
shashikg d5f29ac
remove redundant import
shashikg f4cf9b3
visual search - waldo and natural design
shashikg File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
__pycache__/ | ||
.ipynb_checkpoints | ||
build/* | ||
dist/* | ||
model_tools.egg-info/* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,286 @@ | ||
import cv2 | ||
import logging | ||
import numpy as np | ||
from tqdm import tqdm | ||
|
||
from brainscore.model_interface import BrainModel | ||
from brainscore.utils import fullname | ||
|
||
class VisualSearchObjArray(BrainModel): | ||
def __init__(self, identifier, target_model_param, stimuli_model_param): | ||
self.current_task = None | ||
self.identifier = identifier | ||
self.target_model = target_model_param['target_model'] | ||
self.stimuli_model = stimuli_model_param['stimuli_model'] | ||
self.target_layer = target_model_param['target_layer'] | ||
self.stimuli_layer = stimuli_model_param['stimuli_layer'] | ||
self.search_image_size = stimuli_model_param['search_image_size'] | ||
self._logger = logging.getLogger(fullname(self)) | ||
|
||
def start_task(self, task: BrainModel.Task, **kwargs): | ||
self.fix = kwargs['fix'] # fixation map | ||
self.max_fix = kwargs['max_fix'] # maximum allowed fixation excluding the very first fixation | ||
self.data_len = kwargs['data_len'] # Number of stimuli | ||
self.current_task = task | ||
|
||
def look_at(self, stimuli_set): | ||
self.gt_array = [] | ||
gt = stimuli_set[stimuli_set['image_label'] == 'mask'] | ||
gt_paths = list(gt.image_paths.values())[int(gt.index.values[0]):int(gt.index.values[-1] + 1)] | ||
|
||
for i in range(6): | ||
imagename_gt = gt_paths[i] | ||
|
||
gt = cv2.imread(imagename_gt, 0) | ||
gt = cv2.resize(gt, (self.search_image_size, self.search_image_size), interpolation=cv2.INTER_AREA) | ||
retval, gt = cv2.threshold(gt, 125, 255, cv2.THRESH_BINARY) | ||
temp_stim = np.uint8(np.zeros((3 * self.search_image_size, 3 * self.search_image_size))) | ||
temp_stim[self.search_image_size:2 * self.search_image_size, | ||
self.search_image_size:2 * self.search_image_size] = np.copy(gt) | ||
gt = np.copy(temp_stim) | ||
gt = gt / 255 | ||
|
||
self.gt_array.append(gt) | ||
|
||
self.gt_total = np.copy(self.gt_array[0]) | ||
for i in range(1, 6): | ||
self.gt_total += self.gt_array[i] | ||
|
||
self.score = np.zeros((self.data_len, self.max_fix + 1)) | ||
self.data = np.zeros((self.data_len, self.max_fix + 2, 2), dtype=int) | ||
S_data = np.zeros((300, 7, 2), dtype=int) | ||
I_data = np.zeros((300, 1), dtype=int) | ||
|
||
data_cnt = 0 | ||
|
||
target = stimuli_set[stimuli_set['image_label'] == 'target'] | ||
target_features = self.target_model(target, layers=[self.target_layer], stimuli_identifier=False) | ||
if target_features.shape[0] == target_features['neuroid_num'].shape[0]: | ||
target_features = target_features.T | ||
|
||
stimuli = stimuli_set[stimuli_set['image_label'] == 'stimuli'] | ||
stimuli_features = self.stimuli_model(stimuli, layers=[self.stimuli_layer], stimuli_identifier=False) | ||
if stimuli_features.shape[0] == stimuli_features['neuroid_num'].shape[0]: | ||
stimuli_features = stimuli_features.T | ||
|
||
import torch | ||
|
||
for i in tqdm(range(self.data_len), desc="visual search stimuli: "): | ||
op_target = self.unflat(target_features[i:i + 1]) | ||
MMconv = torch.nn.Conv2d(op_target.shape[1], 1, kernel_size=(op_target.shape[2], op_target.shape[3]), | ||
stride=1, bias=False) | ||
MMconv.weight = torch.nn.Parameter(torch.Tensor(op_target)) | ||
|
||
gt_idx = target_features.tar_obj_pos.values[i] | ||
gt = self.gt_array[gt_idx] | ||
|
||
op_stimuli = self.unflat(stimuli_features[i:i + 1]) | ||
out = MMconv(torch.Tensor(op_stimuli)).detach().numpy() | ||
out = out.reshape(out.shape[2:]) | ||
|
||
out = out - np.min(out) | ||
out = out / np.max(out) | ||
out *= 255 | ||
out = np.uint8(out) | ||
out = cv2.resize(out, (self.search_image_size, self.search_image_size), interpolation=cv2.INTER_AREA) | ||
out = cv2.GaussianBlur(out, (7, 7), 3) | ||
|
||
temp_stim = np.uint8(np.zeros((3 * self.search_image_size, 3 * self.search_image_size))) | ||
temp_stim[self.search_image_size:2 * self.search_image_size, | ||
self.search_image_size:2 * self.search_image_size] = np.copy(out) | ||
attn = np.copy(temp_stim * self.gt_total) | ||
|
||
saccade = [] | ||
(x, y) = int(attn.shape[0] / 2), int(attn.shape[1] / 2) | ||
saccade.append((x, y)) | ||
|
||
for k in range(self.max_fix): | ||
(x, y) = np.unravel_index(np.argmax(attn), attn.shape) | ||
|
||
fxn_x, fxn_y = x, y | ||
|
||
fxn_x, fxn_y = max(fxn_x, self.search_image_size), max(fxn_y, self.search_image_size) | ||
fxn_x, fxn_y = min(fxn_x, (attn.shape[0] - self.search_image_size)), min(fxn_y, ( | ||
attn.shape[1] - self.search_image_size)) | ||
|
||
saccade.append((fxn_x, fxn_y)) | ||
|
||
attn, t = self.remove_attn(attn, saccade[-1][0], saccade[-1][1]) | ||
|
||
if (t == gt_idx): | ||
self.score[data_cnt, k + 1] = 1 | ||
data_cnt += 1 | ||
break | ||
|
||
saccade = np.asarray(saccade) | ||
j = saccade.shape[0] | ||
|
||
for k in range(j): | ||
tar_id = self.get_pos(saccade[k, 0], saccade[k, 1], 0) | ||
saccade[k, 0] = self.fix[tar_id][0] | ||
saccade[k, 1] = self.fix[tar_id][1] | ||
|
||
I_data[i, 0] = min(7, j) | ||
S_data[i, :j, 0] = saccade[:, 0].reshape((-1,))[:7] | ||
S_data[i, :j, 1] = saccade[:, 1].reshape((-1,))[:7] | ||
|
||
self.data[:, :7, :] = S_data | ||
self.data[:, 7, :] = I_data | ||
|
||
return (self.score, self.data) | ||
|
||
def remove_attn(self, img, x, y): | ||
t = -1 | ||
for i in range(5, -1, -1): | ||
fxt_place = self.gt_array[i][x, y] | ||
if (fxt_place > 0): | ||
t = i | ||
break | ||
|
||
if (t > -1): | ||
img[self.gt_array[t] == 1] = 0 | ||
|
||
return img, t | ||
|
||
def get_pos(self, x, y, t): | ||
for i in range(5, -1, -1): | ||
fxt_place = self.gt_array[i][int(x), int(y)] | ||
if (fxt_place > 0): | ||
t = i + 1 | ||
break | ||
return t | ||
|
||
def unflat(self, X): | ||
channel_names = ['channel', 'channel_x', 'channel_y'] | ||
assert all(hasattr(X, coord) for coord in channel_names) | ||
shapes = [len(set(X[channel].values)) for channel in channel_names] | ||
X = np.reshape(X.values, [X.shape[0]] + shapes) | ||
X = np.transpose(X, axes=[0, 3, 1, 2]) | ||
return X | ||
|
||
|
||
class VisualSearch(BrainModel): | ||
def __init__(self, identifier, target_model_param, stimuli_model_param): | ||
self.current_task = None | ||
self.identifier = identifier | ||
self.target_model = target_model_param['target_model'] | ||
self.stimuli_model = stimuli_model_param['stimuli_model'] | ||
self.target_layer = target_model_param['target_layer'] | ||
self.stimuli_layer = stimuli_model_param['stimuli_layer'] | ||
self.search_image_size = stimuli_model_param['search_image_size'] | ||
self._logger = logging.getLogger(fullname(self)) | ||
|
||
def start_task(self, task: BrainModel.Task, **kwargs): | ||
self.max_fix = kwargs['max_fix'] # maximum allowed fixation excluding the very first fixation | ||
self.data_len = kwargs['data_len'] # Number of stimuli | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. shouldn't this just be the length of the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, this will be the same. I will change this. Thanks for pointing out. |
||
self.current_task = task | ||
self.ior_size = kwargs['ior_size'] | ||
|
||
def look_at(self, stimuli_set): | ||
self.score = np.zeros((self.data_len, self.max_fix + 1)) | ||
self.data = np.zeros((self.data_len, self.max_fix + 2, 2), dtype=int) | ||
S_data = np.zeros((self.data_len, self.max_fix + 1, 2), dtype=int) | ||
I_data = np.zeros((self.data_len, 1), dtype=int) | ||
|
||
data_cnt = 0 | ||
|
||
target = stimuli_set[stimuli_set['image_label'] == 'target'] | ||
target_features = self.target_model(target, layers=[self.target_layer], stimuli_identifier=False) | ||
if target_features.shape[0] == target_features['neuroid_num'].shape[0]: | ||
target_features = target_features.T | ||
|
||
stimuli = stimuli_set[stimuli_set['image_label'] == 'stimuli'] | ||
stimuli_features = self.stimuli_model(stimuli, layers=[self.stimuli_layer], stimuli_identifier=False) | ||
if stimuli_features.shape[0] == stimuli_features['neuroid_num'].shape[0]: | ||
stimuli_features = stimuli_features.T | ||
|
||
gt = stimuli_set[stimuli_set['image_label'] == 'gt'] | ||
gt_paths = list(gt.image_paths.values())[int(gt.index.values[0]):int(gt.index.values[-1] + 1)] | ||
|
||
import torch | ||
|
||
for i in tqdm(range(self.data_len), desc="visual search stimuli: "): | ||
op_target = self.unflat(target_features[i:i + 1]) | ||
MMconv = torch.nn.Conv2d(op_target.shape[1], 1, kernel_size=(op_target.shape[2], op_target.shape[3]), | ||
stride=1, bias=False) | ||
MMconv.weight = torch.nn.Parameter(torch.Tensor(op_target)) | ||
|
||
imagename_gt = gt_paths[i] | ||
gt = cv2.imread(imagename_gt, 0) | ||
gt = cv2.resize(gt, (self.search_image_size, self.search_image_size), interpolation=cv2.INTER_AREA) | ||
retval, gt = cv2.threshold(gt, 125, 255, cv2.THRESH_BINARY) | ||
temp_stim = np.uint8(np.zeros((3 * self.search_image_size, 3 * self.search_image_size))) | ||
temp_stim[self.search_image_size:2 * self.search_image_size, | ||
self.search_image_size:2 * self.search_image_size] = np.copy(gt) | ||
gt = np.copy(temp_stim) | ||
gt = gt / 255 | ||
|
||
op_stimuli = self.unflat(stimuli_features[i:i + 1]) | ||
out = MMconv(torch.Tensor(op_stimuli)).detach().numpy() | ||
out = out.reshape(out.shape[2:]) | ||
|
||
out = out - np.min(out) | ||
out = out / np.max(out) | ||
out *= 255 | ||
out = np.uint8(out) | ||
out = cv2.resize(out, (self.search_image_size, self.search_image_size), interpolation=cv2.INTER_AREA) | ||
out = cv2.GaussianBlur(out, (7, 7), 3) | ||
|
||
temp_stim = np.uint8(np.zeros((3 * self.search_image_size, 3 * self.search_image_size))) | ||
temp_stim[self.search_image_size:2 * self.search_image_size, | ||
self.search_image_size:2 * self.search_image_size] = np.copy(out) | ||
attn = np.copy(temp_stim) | ||
|
||
saccade = [] | ||
(x, y) = int(attn.shape[0] / 2), int(attn.shape[1] / 2) | ||
saccade.append((x, y)) | ||
|
||
for k in range(self.max_fix): | ||
(x, y) = np.unravel_index(np.argmax(attn), attn.shape) | ||
|
||
fxn_x, fxn_y = x, y | ||
|
||
fxn_x, fxn_y = max(fxn_x, self.search_image_size), max(fxn_y, self.search_image_size) | ||
fxn_x, fxn_y = min(fxn_x, (attn.shape[0] - self.search_image_size)), min(fxn_y, ( | ||
attn.shape[1] - self.search_image_size)) | ||
|
||
saccade.append((fxn_x, fxn_y)) | ||
|
||
attn, t = self.remove_attn(attn, saccade[-1][0], saccade[-1][1], gt) | ||
|
||
if t: | ||
self.score[data_cnt, k + 1] = 1 | ||
data_cnt += 1 | ||
break | ||
|
||
saccade = np.asarray(saccade) | ||
j = saccade.shape[0] | ||
|
||
I_data[i, 0] = min(self.max_fix+1, j) | ||
S_data[i, :j, 0] = saccade[:, 0].reshape((-1,))[:self.max_fix+1] | ||
S_data[i, :j, 1] = saccade[:, 1].reshape((-1,))[:self.max_fix+1] | ||
|
||
self.data[:, :self.max_fix+1, :] = S_data | ||
self.data[:, self.max_fix+1, :] = I_data | ||
|
||
return (self.score, self.data) | ||
|
||
def remove_attn(self, img, x, y, gt): | ||
img[(x - int(self.ior_size/2)):(x + int(self.ior_size/2)), (y - int(self.ior_size/2)):(y + int(self.ior_size/2))] = 0 | ||
|
||
fxt_xtop = x-int(self.ior_size/2) | ||
fxt_ytop = y-int(self.ior_size/2) | ||
fxt_place = gt[fxt_xtop:(fxt_xtop+self.ior_size), fxt_ytop:(fxt_ytop+self.ior_size)] | ||
|
||
if (np.sum(fxt_place)>0): | ||
return img, True | ||
else: | ||
return img, False | ||
|
||
def unflat(self, X): | ||
channel_names = ['channel', 'channel_x', 'channel_y'] | ||
assert all(hasattr(X, coord) for coord in channel_names) | ||
shapes = [len(set(X[channel].values)) for channel in channel_names] | ||
X = np.reshape(X.values, [X.shape[0]] + shapes) | ||
X = np.transpose(X, axes=[0, 3, 1, 2]) | ||
return X |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should these not be the same? Sorry I'm getting confused with these parameters hidden in dictionaries
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No these will not be the same. As I previously mentioned that the input image size will be different for the target_image and stimuli_image. You will need two different ML models for both.