-
Notifications
You must be signed in to change notification settings - Fork 217
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
few-shot random sampling, evaluated with a single batch, few-shot cla…
…ss balanced sampling, achieve almost best performance with only 4 steps and each step is only batch size 5, cost 20 more times inference
- Loading branch information
Showing
19 changed files
with
4,337 additions
and
108 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
DatasetDict({ | ||
train: Dataset({ | ||
features: ['text', 'coarse_label', 'fine_label'], | ||
num_rows: 5452 | ||
}) | ||
test: Dataset({ | ||
features: ['text', 'coarse_label', 'fine_label'], | ||
num_rows: 500 | ||
}) | ||
}) | ||
Train example: {'text': 'How did serfdom develop in and then leave Russia ?', 'coarse_label': 2, 'fine_label': 26} | ||
Test example: {'text': 'How far is it from Denver to Aspen ?', 'coarse_label': 5, 'fine_label': 40} | ||
INFO:core.prompt_builder:Prompt has variables: ['classes'] | ||
INFO:core.prompt_builder:Prompt has variables: ['example', 'schema'] | ||
DEBUG:use_cases.classification.task:output_str: Your output should be formatted as a standard YAML instance with the following schema: | ||
``` | ||
thought: Your reasoning to classify the question to class_name (str) (required) | ||
class_name: class_name (str) (required) | ||
class_index: class_index in range[0, 5] (int) (required) | ||
``` | ||
|
||
-Make sure to always enclose the YAML output in triple backticks (```). Please do not add anything other than valid YAML output! | ||
-Follow the YAML formatting conventions with an indent of 2 spaces. | ||
-Quote the string values properly. | ||
|
||
DEBUG:httpx:load_ssl_context verify=True cert=None trust_env=True http2=False | ||
DEBUG:httpx:load_verify_locations cafile='/Users/liyin/Documents/test/LightRAG/.venv/lib/python3.11/site-packages/certifi/cacert.pem' | ||
INFO:core.prompt_builder:Prompt has variables: ['input', 'output_format_str', 'examples_str', 'input_label', 'task_desc_str'] | ||
data: None, requires_opt: True | ||
data: {'examples_str': Parameter: None with key: examples_str}, requires_opt: True | ||
Registered parameter trainable_prompt_kwargs with value Parameter: {'examples_str': Parameter: None with key: examples_str} with key: None |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from copy import deepcopy | ||
from core.parameter import Parameter | ||
from core.component import Component | ||
|
||
from optimizer.sampler import RandomSampler | ||
|
||
|
||
class Optimizer: | ||
def state_dict(self): | ||
pass | ||
|
||
|
||
r""" | ||
We focus on error fixing, run a batch, get batch based accuracy. | ||
# pass the batch to the LLMOptimizer | ||
# sample a class from the batch.Let an llm to boostra | ||
""" | ||
|
||
|
||
class BootstrapFewShot(Optimizer): | ||
def __init__( | ||
self, | ||
example_parameter: Parameter, | ||
train_dataset, | ||
sampler: Component, | ||
output_processors: Component, | ||
): | ||
super().__init__() | ||
self.example_parameter = deepcopy(example_parameter) | ||
self.train_dataset = train_dataset | ||
self.sampler = sampler | ||
self.output_processors = output_processors | ||
|
||
def step(self, num_shots: int): | ||
examples = self.sampler(num_shots) | ||
if self.output_processors: | ||
examples = self.output_processors(examples) | ||
self.example_parameter.update_value(examples) | ||
return self.example_parameter | ||
|
||
def state_dict(self): | ||
# TODO: need to figure out how really parameters and states are saved and loaded. | ||
return {"example_parameter": self.example_parameter} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import random | ||
from typing import List, Sequence, Optional, Callable, Any | ||
import math | ||
|
||
from core.component import Component | ||
|
||
|
||
class ClassSampler(Component): | ||
def __init__( | ||
self, | ||
dataset, | ||
num_classes: int, | ||
get_data_key_fun: Callable, | ||
): | ||
super().__init__() | ||
self.dataset = dataset | ||
self.num_classes = num_classes | ||
if not get_data_key_fun: | ||
raise ValueError("get_data_key_fun must be provided") | ||
self.get_key = get_data_key_fun | ||
self.class_indices: List[List] = [[] for _ in range(num_classes)] | ||
for i, data in enumerate(dataset): | ||
self.class_indices[self.get_key(data)].append(i) | ||
|
||
def _sample_one_class(self, num_samples: int, class_index: int) -> List[Any]: | ||
incides = random.sample(self.class_indices[class_index], num_samples) | ||
samples = [self.dataset[i] for i in incides] | ||
return samples | ||
|
||
def call(self, shots: int) -> Sequence[str]: | ||
samples = [] | ||
samples_per_class = math.ceil(shots / self.num_classes) | ||
print(f"samples_per_class: {samples_per_class}") | ||
for class_index in range(self.num_classes): | ||
samples.extend(self._sample_one_class(samples_per_class, class_index)) | ||
if len(samples) > shots: | ||
# randomly sample the remaining samples | ||
samples = random.sample(samples, shots) | ||
return samples | ||
|
||
|
||
class RandomSampler(Component): | ||
def __init__(self, dataset, num_shots: Optional[int] = None): | ||
super().__init__() | ||
self.dataset = dataset | ||
self.num_shots = num_shots | ||
|
||
def __call__(self, num_shots: Optional[int] = None) -> Sequence[str]: | ||
if num_shots is None: | ||
num_shots = self.num_shots | ||
if num_shots is None: | ||
raise ValueError("num_shots is not set") | ||
indices = random.sample(range(len(self.dataset)), num_shots) | ||
return [self.dataset[i] for i in indices] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
digraph { | ||
graph [size="12,12"] | ||
node [align=left fontname=monospace fontsize=10 height=0.2 ranksep=0.1 shape=box style=filled] | ||
4936293968 [label=" | ||
(1, 20, 20, 20)" fillcolor=darkolivegreen1] | ||
4934982272 [label=ReluBackward0] | ||
4934982416 -> 4934982272 | ||
4934982416 [label=ConvolutionBackward0] | ||
4932567584 -> 4934982416 | ||
4932567584 [label=ReluBackward0] | ||
4932577376 -> 4932567584 | ||
4932577376 [label=ConvolutionBackward0] | ||
4932567632 -> 4932577376 | ||
4936291088 [label=" | ||
(20, 1, 5, 5)" fillcolor=lightblue] | ||
4936291088 -> 4932567632 | ||
4932567632 [label=AccumulateGrad] | ||
4932579152 -> 4932577376 | ||
4936291184 [label=" | ||
(20)" fillcolor=lightblue] | ||
4936291184 -> 4932579152 | ||
4932579152 [label=AccumulateGrad] | ||
4932567536 -> 4934982416 | ||
4936291376 [label=" | ||
(20, 20, 5, 5)" fillcolor=lightblue] | ||
4936291376 -> 4932567536 | ||
4932567536 [label=AccumulateGrad] | ||
4932568400 -> 4934982416 | ||
4936291472 [label=" | ||
(20)" fillcolor=lightblue] | ||
4936291472 -> 4932568400 | ||
4932568400 [label=AccumulateGrad] | ||
4934982272 -> 4936293968 | ||
} |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Oops, something went wrong.