Skip to content

Commit

Permalink
improve tensor extractor prints
Browse files Browse the repository at this point in the history
  • Loading branch information
Louis-Dupont committed Jun 21, 2023
1 parent 4520128 commit f1d4112
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 12 deletions.
10 changes: 5 additions & 5 deletions src/data_gradients/batch_processors/adapters/dataset_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from data_gradients.batch_processors.adapters.tensor_extractor import get_tensor_extractor_options
from data_gradients.config.data.data_config import DataConfig
from data_gradients.config.data.questions import Question
from data_gradients.config.data.questions import Question, text_to_yellow

SupportedData = Union[Tuple, List, Mapping, Tuple, List]

Expand Down Expand Up @@ -43,8 +43,8 @@ def _get_images_extractor(self, data: SupportedData) -> Callable[[SupportedData]

# Otherwise, we ask the user how to map data -> image
if isinstance(data, (Tuple, List, Mapping, Tuple, List)):
description, options = get_tensor_extractor_options(data, object_name="Image(s)")
question = Question(question="Which tensor represents your Image(s) ?", options=options)
description, options = get_tensor_extractor_options(data)
question = Question(question=f"Which tensor represents your {text_to_yellow('Image(s)')} ?", options=options)
return self.data_config.get_images_extractor(question=question, hint=description)

raise NotImplementedError(
Expand All @@ -65,8 +65,8 @@ def _get_labels_extractor(self, data: SupportedData) -> Callable[[SupportedData]

# Otherwise, we ask the user how to map data -> labels
if isinstance(data, (Tuple, List, Mapping, Tuple, List)):
description, options = get_tensor_extractor_options(data, object_name="Labels(s)")
question = Question(question="Which tensor represents your Label(s) ?", options=options)
description, options = get_tensor_extractor_options(data)
question = Question(question=f"Which tensor represents your {text_to_yellow('Label(s)')} ?", options=options)
return self.data_config.get_labels_extractor(question=question, hint=description)

raise NotImplementedError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,18 @@
from numpy import ndarray


def get_tensor_extractor_options(objs: Any, object_name: str) -> Tuple[str, Dict[str, str]]:
def get_tensor_extractor_options(objs: Any) -> Tuple[str, Dict[str, str]]:
"""Extract out of objs all the potential fields of type [torch.Tensor, np.ndarray, PIL.Image], and then
asks the user to input which of the above keys mapping is the right one in order to retrieve the correct data (either images or labels).
:param object_name: Name of the object you want to extract ("image", "label", ...)
:param objs: Dictionary following the pattern: {"path.to.object: object_type": "path.to.object"}
"""
objects_mapping: List[Tuple[str, str]] = [] # Placeholder for list of (path.to.object, object_type)
nested_object_mapping = extract_object_mapping(objs, current_path="", objects_mapping=objects_mapping)
description = "This is how your data is structured: \n"
description += f"data = {json.dumps(nested_object_mapping, indent=4)}"

options = {f"- {object_name} = data{path_to_object}: {object_type}": path_to_object for path_to_object, object_type in objects_mapping}
options = {f"data{path_to_object}: {object_type}": path_to_object for path_to_object, object_type in objects_mapping}
return description, options


Expand Down
16 changes: 12 additions & 4 deletions src/data_gradients/config/data/questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@
from typing import Dict, Any, Optional, List


def text_to_blue(text: str) -> str:
return f"\033[34;1m{text}\033[0m"


def text_to_yellow(text: str):
return f"\033[33;1m{text}\033[0m"


@dataclass
class Question:
"""Model a Question with its options
Expand Down Expand Up @@ -35,7 +43,7 @@ def ask_user(main_question: str, options: List[str], optional_description: str =
"""
numbers_to_chose_from = range(len(options))

options_formatted = "\n".join([f"[{number}] {option_description}" for number, option_description in zip(numbers_to_chose_from, options)])
options_formatted = "\n".join([f"[{text_to_blue(number)}] | {option_description}" for number, option_description in zip(numbers_to_chose_from, options)])

user_answer = None
while user_answer not in numbers_to_chose_from:
Expand All @@ -49,15 +57,15 @@ def ask_user(main_question: str, options: List[str], optional_description: str =
print("")

try:
user_answer = input("Your selection (Enter the corresponding number) >>> ")
user_answer = input(f"Your selection (Enter the {text_to_blue('corresponding number')}) >>> ")
user_answer = int(user_answer)
except Exception:
user_answer = None

if user_answer not in numbers_to_chose_from:
print(f'Oops! "{user_answer}" is not a valid choice. Let\'s try again.')
print(f'Oops! "{text_to_blue(str(user_answer))}" is not a valid choice. Let\'s try again.')

selected_option = options[user_answer]
print(f"Great! You chose: {selected_option}\n")
print(f"Great! You chose: {text_to_yellow(selected_option)}\n")

return selected_option

0 comments on commit f1d4112

Please sign in to comment.