Skip to content

Commit

Permalink
check
Browse files Browse the repository at this point in the history
  • Loading branch information
iefode committed Jun 11, 2024
1 parent e349418 commit 05048ff
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ def get_greedy_with_repetition_penalty() -> GenerationConfig:
generation_config.repetition_penalty = 2.0
return generation_config

def get_greedy_with_penalties() -> GenerationConfig:
generation_config = GenerationConfig()
generation_config.num_return_sequences = 1
generation_config.presence_penalty = 2.0
generation_config.frequence_penalty = 0.2
return generation_config


def get_beam_search() -> GenerationConfig:
generation_config = GenerationConfig()
Expand Down Expand Up @@ -96,13 +103,15 @@ def get_multinomial_temperature_and_frequence_penalty() -> GenerationConfig:
generation_config.do_sample = True
generation_config.temperature = 0.8
generation_config.frequence_penalty = 0.5
generation_config.num_return_sequences = 1
return generation_config

def get_multinomial_temperature_and_presence_penalty() -> GenerationConfig:
generation_config = GenerationConfig()
generation_config.do_sample = True
generation_config.temperature = 0.8
generation_config.presence_penalty = 0.1
generation_config.num_return_sequences = 1
return generation_config

def get_test_dataset() -> Tuple[List[str], List[GenerationConfig]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@

from common import run_test_pipeline, get_models_list, get_model_and_tokenizer, save_ov_model_from_optimum, \
generate_and_compare_with_reference_text, get_greedy, get_beam_search, get_multinomial_temperature, \
get_greedy_with_penalties, \
get_multinomial_temperature_and_top_k, get_multinomial_temperature_and_top_p, \
get_multinomial_temperature_top_p_and_top_k, DEFAULT_SCHEDULER_CONFIG, get_greedy_with_repetition_penalty, \
get_multinomial_all_parameters, get_multinomial_temperature_and_num_return_sequence, \
get_multinomial_temperature_and_frequence_penalty, get_multinomial_temperature_and_presence_penalty, \
generate_and_compare_with_hf, get_multinomial_temperature_and_repetition_penalty, get_scheduler_config


Expand Down Expand Up @@ -111,12 +113,21 @@ class RandomSamplingTestStruct:
prompts=["Tell me something about UAE"],
ref_texts=[
[
" and how it's not like we're all in the same boat right now lol (or even close) 😂😁! Just curious :) If",
"? You are my country... so what does our military do here?? What am i missing out on?? And why don't u tell us?",
'?\nThe U.S government has been doing quite well with foreign-made aircraft for many years under US administration....and they have very good reasons',
'? I think that is a bit of an anomaly, but you might want to ask yourself this question: Where can some young people from Dubai or Bahrain'
"?\nUAE is the country with a population of 3 million people living between 4pm - 8am daily. So yeah i'm pretty sure we",
'?\nThe name, emirate or not... there are many interesting places to stay while visiting here that you can do very well as an individual',
" and how it's different from Saudi Arabia but in my eyes I don't care if they have freedom fighters! /s Just kidding :) You're welcome",
'?\nThe U.A.. Is a small island nation full-fledged on top of Dubai lol Edit: forgot Arabic means "small" where',
]
]),
RandomSamplingTestStruct(generation_config=get_multinomial_temperature_and_presence_penalty(),
prompts=["What is OpenVINO?"],
ref_texts=[ ["\n\nOpenVINO is a software development platform developed by OpenVINO, Inc., which uses a RESTful API for server-side web applications"] ]),
RandomSamplingTestStruct(generation_config=get_multinomial_temperature_and_frequence_penalty(),
prompts=["What is OpenVINO?"],
ref_texts=[ ["\n\nOpenVINO is a software development platform developed by OpenVINO, Inc., which offers the Linux-based platform. OpenVINO's"] ]),
RandomSamplingTestStruct(generation_config=get_greedy_with_penalties(),
prompts=["What is OpenVINO?"],
ref_texts=[ ["\nOpenVINO is a software that allows users to create and manage their own virtual machines. It's designed for use with Windows, Mac OS X"] ]),
]


Expand All @@ -129,8 +140,9 @@ class RandomSamplingTestStruct:
"multinomial_temperature_and_repetition_penalty",
"multinomial_temperature_and_num_return_sequence",
"multinomial_all_parameters",
"multinomial_temperature_and_presence_penalty",
"multinomial_temperature_and_frequence_penalty"])
"multinomial_temperature_and_presence_penalty",
"multinomial_temperature_and_frequence_penalty",
"greedy_with_penalties"])
def test_individual_generation_configs_random(tmp_path, test_struct: RandomSamplingTestStruct):
generation_config = test_struct.generation_config

Expand Down

0 comments on commit 05048ff

Please sign in to comment.