few-shot random sampling, evaluated with a single batch; few-shot class-balanced sampling achieves almost the best performance with only 4 steps, each with a batch size of only 5, at the cost of 20x more inference
liyin2015 committed May 26, 2024
1 parent 0a8951e commit 3bc86f2
Showing 19 changed files with 4,337 additions and 108 deletions.
31 changes: 31 additions & 0 deletions class_sampler_optimizer.py
@@ -0,0 +1,31 @@
DatasetDict({
train: Dataset({
features: ['text', 'coarse_label', 'fine_label'],
num_rows: 5452
})
test: Dataset({
features: ['text', 'coarse_label', 'fine_label'],
num_rows: 500
})
})
Train example: {'text': 'How did serfdom develop in and then leave Russia ?', 'coarse_label': 2, 'fine_label': 26}
Test example: {'text': 'How far is it from Denver to Aspen ?', 'coarse_label': 5, 'fine_label': 40}
INFO:core.prompt_builder:Prompt has variables: ['classes']
INFO:core.prompt_builder:Prompt has variables: ['example', 'schema']
DEBUG:use_cases.classification.task:output_str: Your output should be formatted as a standard YAML instance with the following schema:
```
thought: Your reasoning to classify the question to class_name (str) (required)
class_name: class_name (str) (required)
class_index: class_index in range[0, 5] (int) (required)
```

-Make sure to always enclose the YAML output in triple backticks (```). Please do not add anything other than valid YAML output!
-Follow the YAML formatting conventions with an indent of 2 spaces.
-Quote the string values properly.
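For illustration, a valid model response under this schema might look like the following (hypothetical content; mapping coarse label 5 to the class name NUM is an assumption based on the standard TREC label set):

```
thought: "The question asks for a distance, which is a numeric value."
class_name: "NUM"
class_index: 5
```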

DEBUG:httpx:load_ssl_context verify=True cert=None trust_env=True http2=False
DEBUG:httpx:load_verify_locations cafile='/Users/liyin/Documents/test/LightRAG/.venv/lib/python3.11/site-packages/certifi/cacert.pem'
INFO:core.prompt_builder:Prompt has variables: ['input', 'output_format_str', 'examples_str', 'input_label', 'task_desc_str']
data: None, requires_opt: True
data: {'examples_str': Parameter: None with key: examples_str}, requires_opt: True
Registered parameter trainable_prompt_kwargs with value Parameter: {'examples_str': Parameter: None with key: examples_str} with key: None
3,669 changes: 3,669 additions & 0 deletions class_sampler_optimizer.txt

Large diffs are not rendered by default.

68 changes: 66 additions & 2 deletions core/component.py
@@ -42,6 +42,25 @@ def _addindent(s_, numSpaces):
return s


# def fun_to_component(fun):
# r"""Decorator to convert a function to a Component class."""
# class_name = fun.__name__.capitalize()

# class ComponentClass(Component):
# def __init__(self, *args, **kwargs):
# super().__init__()
# self.fun = fun

# def __call__(self, *args, **kwargs):
# return self.fun(*args, **kwargs)

# def _extra_repr(self) -> str:
# return super()._extra_repr() + f"fun={self.fun}"

# ComponentClass.__name__ = class_name
# return ComponentClass()


def _call_unimplemented(self, *input: Any) -> None:
r"""
Define the call method for the component.
@@ -68,7 +87,7 @@ class Component:
(2) All components can run locally or call APIs. 'Component' has to deal with API calls, so we need to support retries and rate limits.
"""

- _version: int = 0
+ _version: float = 0.1  # version of the component
# TODO: the type of module, is it OrderedDict or just Dict?
_components: Dict[str, Optional["Component"]]
_execution_graph: List[str] = [] # This will store the graph of execution.
@@ -131,7 +150,7 @@ def parameters(self, recursive: bool = True) -> Iterable[Parameter]:
>>> for param in model.parameters():
>>> print(param)
"""
- for name, param in self.named_parameters():
+ for name, param in self.named_parameters(recurse=recursive):
yield param

def _named_members(
@@ -265,6 +284,19 @@ async def acall(self, *args, **kwargs):
pass

def add_component(self, name: str, component: Optional["Component"]) -> None:
r"Add a child component to the current component."
if not isinstance(component, Component) and component is not None:
raise TypeError(
f"component should be an instance of Component, but got {type(component)}"
)
if not isinstance(name, str):
raise TypeError(f"name should be a string, but got {type(name)}")
elif hasattr(self, name) and name not in self._components:
raise KeyError(f"attribute '{name}' already exists")
elif "." in name:
raise ValueError('component name can\'t contain "."')
elif name == "":
raise ValueError('component name can\'t be empty string ""')
self._components[name] = component

def register_subcomponent(
@@ -328,10 +360,22 @@ def named_components(
if module is None:
continue
submodule_prefix = prefix + ("." if prefix else "") + name
print(f"module: {module} ")
yield from module.named_components(
memo, submodule_prefix, remove_duplicate
)

def _save_to_state_dict(self, destination, prefix):
r"""Saves the state of the component to a dictionary.
Args:
destination (Dict[str, Any]): the dictionary to which the state is saved.
prefix (str): a prefix to add to the keys in the state_dict.
"""
for name, param in self._parameters.items():
if param is not None:
destination[prefix + name] = param

def state_dict(
self, destination: Optional[Dict[str, Any]] = None, prefix: Optional[str] = ""
) -> Dict[str, Any]:
@@ -355,8 +399,13 @@ def state_dict(
destination = OrderedDict()
destination._metadata = OrderedDict()
local_metadata = dict(version=self._version)
# TODO: handle local metadata when needed
if hasattr(self, "_metadata"):
destination._metadata[prefix[:-1]] = local_metadata

# save its own state
self._save_to_state_dict(destination, prefix=prefix)
# save the state of all subcomponents
for name, component in self._components.items():
if component is not None:
component.state_dict(
@@ -694,3 +743,18 @@ def update(self, components: Mapping[str, Component]) -> None:
self[m[0]] = m[1] # type: ignore[assignment]

# remove forward altogether to fall back on Module's _forward_unimplemented


class FunComponent(Component):
def __init__(self, fun: Callable):
super().__init__()
self.fun = fun

def call(self, *args, **kwargs):
return self.fun(*args, **kwargs)


def fun_to_component(fun):
r"""Decorator to convert a function to a Component class."""
class_name = fun.__name__.capitalize() + "Component"
return type(class_name, (FunComponent,), {})(fun)
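A minimal usage sketch of the new `fun_to_component` decorator (the function `add_one` is hypothetical, and `.call()` is used directly since `FunComponent` defines `call`):

```
from core.component import fun_to_component

@fun_to_component
def add_one(x: int) -> int:
    return x + 1

# fun_to_component returns an *instance* of a dynamically created
# FunComponent subclass, so add_one is now a Component, not a function.
print(add_one.call(1))  # 2
```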
7 changes: 6 additions & 1 deletion core/generator.py
@@ -30,6 +30,9 @@ def __init__(
# args for the prompt
template: str = DEFAULT_LIGHTRAG_SYSTEM_PROMPT,
preset_prompt_kwargs: Optional[Dict] = None, # manage the prompt kwargs
trainable_params: Optional[
List[str]
] = None, # the trainable parameters in the prompt
output_processors: Optional[Component] = None,
) -> None:
r"""The default prompt is set to the DEFAULT_LIGHTRAG_SYSTEM_PROMPT. It has the following variables:
@@ -52,7 +55,9 @@
# init the model client
self.model_client = model_client()
self.system_prompt = Prompt(
- template=template, preset_prompt_kwargs=preset_prompt_kwargs
+ template=template,
+ preset_prompt_kwargs=preset_prompt_kwargs,
+ trainable_params=trainable_params,
)

self.output_processors = output_processors
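A hedged sketch of how the new `trainable_params` argument flows from the Generator into its system prompt; the client class and template constant are placeholders, not part of this diff:

```
from core.generator import Generator

generator = Generator(
    model_client=SomeModelClient,        # hypothetical model client class
    template=CLASSIFICATION_TEMPLATE,    # hypothetical template string
    preset_prompt_kwargs={"task_desc_str": "Classify the question."},
    trainable_params=["examples_str"],   # marks examples_str as optimizable
)
```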
9 changes: 7 additions & 2 deletions core/parameter.py
@@ -15,9 +15,14 @@ class Parameter:
"""

- def __init__(self, data: Any, requires_opt: bool = True):
+ def __init__(self, key: str, data: Any, requires_opt: bool = True):
self.data = data
self.key = key
self.requires_opt = requires_opt
print(f"data: {data}, requires_opt: {requires_opt}")

def update_value(self, data: Any):
self.data = data

def __repr__(self):
return f"Parameter containing:\n{self.data}"
return f"Parameter: {self.data} with key: {self.key}"
19 changes: 14 additions & 5 deletions core/prompt_builder.py
@@ -73,7 +73,7 @@ def __init__(
preset_prompt_kwargs: Optional[Dict] = {}, # preload the parameters
trainable_params: Optional[
List[str]
- ] = LIGHTRAG_DEFAULT_PROMPT_TRAINABLE_PARAMS,
+ ] = [], # the prompt variables that are trainable; by default, all are passed to an optimizer
):

super().__init__()
@@ -94,6 +94,7 @@ def __init__(
# ensure all trainable_params are in the prompt_variables
# Start of the trainable parameters#
# self.trainable_prompt_kwargs: Parameter = {}
# a parameter should always have a key and a value; the key should be global state
self._trainable_prompt_kwargs: Dict[str, str] = {}
for param in trainable_params:
if param not in self.prompt_variables:
@@ -105,8 +106,14 @@ def __init__(
data = preset_prompt_kwargs[param]
else:
data = None
- self._trainable_prompt_kwargs[param] = data
- self.trainable_prompt_kwargs = Parameter(self._trainable_prompt_kwargs)
+ self._trainable_prompt_kwargs[param] = Parameter(key=param, data=data)

# self._trainable_prompt_kwargs[param] = Parameter(data)
self.trainable_prompt_kwargs = (
Parameter(key=None, data=self._trainable_prompt_kwargs)
if self._trainable_prompt_kwargs
else None
)
# End of the trainable parameters#

# an optimizer will optimize the trainable parameters, and
@@ -180,8 +187,8 @@ def _extra_repr(self) -> str:
s += f", preset_prompt_kwargs: {self.preset_prompt_kwargs}"
if self.prompt_variables:
s += f", prompt_variables: {self.prompt_variables}"
- if self.trainable_prompt_kwargs:
-     s += f", trainable_prompt_kwargs: {self.trainable_prompt_kwargs}"
+ if self._trainable_prompt_kwargs:
+     s += f", trainable_prompt_kwargs: {self._trainable_prompt_kwargs}"
return s


@@ -203,6 +210,8 @@ def _extra_repr(self) -> str:
for name, param in named_params:
print(f"{name}: {param}")

print(f"prompt_variables: {prompt._trainable_prompt_kwargs}")

# EXAMPLES_TEMPLATE = r"""
# {% if examples %}
# {% for example in examples %}
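Putting the prompt changes together, a hedged sketch of declaring a trainable prompt variable; the template string is hypothetical (templates are Jinja-style, as in the commented EXAMPLES_TEMPLATE above):

```
from core.prompt_builder import Prompt

prompt = Prompt(
    template="{{task_desc_str}}\n{{examples_str}}\n{{input}}",  # hypothetical
    preset_prompt_kwargs={"task_desc_str": "Classify the question."},
    trainable_params=["examples_str"],
)
# prompt.trainable_prompt_kwargs is a Parameter whose data is a dict of
# per-variable Parameters, matching the "Registered parameter
# trainable_prompt_kwargs ..." log line earlier in this commit.
```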
Empty file added optimizer/__init__.py
Empty file added optimizer/llm_optimizer.py
45 changes: 45 additions & 0 deletions optimizer/optimizer.py
@@ -0,0 +1,45 @@
from copy import deepcopy
from core.parameter import Parameter
from core.component import Component

from optimizer.sampler import RandomSampler


class Optimizer:
def state_dict(self):
pass


r"""
We focus on error fixing, run a batch, get batch based accuracy.
# pass the batch to the LLMOptimizer
# sample a class from the batch.Let an llm to boostra
"""


class BootstrapFewShot(Optimizer):
def __init__(
self,
example_parameter: Parameter,
train_dataset,
sampler: Component,
output_processors: Component,
):
super().__init__()
self.example_parameter = deepcopy(example_parameter)
self.train_dataset = train_dataset
self.sampler = sampler
self.output_processors = output_processors

def step(self, num_shots: int):
examples = self.sampler(num_shots)
if self.output_processors:
examples = self.output_processors(examples)
self.example_parameter.update_value(examples)
return self.example_parameter

def state_dict(self):
# TODO: figure out how parameters and states are actually saved and loaded.
return {"example_parameter": self.example_parameter}
54 changes: 54 additions & 0 deletions optimizer/sampler.py
@@ -0,0 +1,54 @@
import random
from typing import List, Sequence, Optional, Callable, Any
import math

from core.component import Component


class ClassSampler(Component):
def __init__(
self,
dataset,
num_classes: int,
get_data_key_fun: Callable,
):
super().__init__()
self.dataset = dataset
self.num_classes = num_classes
if not get_data_key_fun:
raise ValueError("get_data_key_fun must be provided")
self.get_key = get_data_key_fun
self.class_indices: List[List] = [[] for _ in range(num_classes)]
for i, data in enumerate(dataset):
self.class_indices[self.get_key(data)].append(i)

def _sample_one_class(self, num_samples: int, class_index: int) -> List[Any]:
indices = random.sample(self.class_indices[class_index], num_samples)
samples = [self.dataset[i] for i in indices]
return samples

def call(self, shots: int) -> Sequence[Any]:
samples = []
samples_per_class = math.ceil(shots / self.num_classes)
print(f"samples_per_class: {samples_per_class}")
for class_index in range(self.num_classes):
samples.extend(self._sample_one_class(samples_per_class, class_index))
if len(samples) > shots:
# randomly down-sample to exactly `shots` examples
samples = random.sample(samples, shots)
return samples


class RandomSampler(Component):
def __init__(self, dataset, num_shots: Optional[int] = None):
super().__init__()
self.dataset = dataset
self.num_shots = num_shots

def __call__(self, num_shots: Optional[int] = None) -> Sequence[Any]:
if num_shots is None:
num_shots = self.num_shots
if num_shots is None:
raise ValueError("num_shots is not set")
indices = random.sample(range(len(self.dataset)), num_shots)
return [self.dataset[i] for i in indices]
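A self-contained toy run of the two samplers; the data is fabricated purely to illustrate the class-balancing arithmetic:

```
from optimizer.sampler import ClassSampler, RandomSampler

# 12 fake rows spread over 3 classes.
data = [{"text": f"q{i}", "coarse_label": i % 3} for i in range(12)]

random_sampler = RandomSampler(data, num_shots=4)
print(random_sampler())  # 4 rows sampled uniformly at random

balanced = ClassSampler(data, num_classes=3,
                        get_data_key_fun=lambda d: d["coarse_label"])
# shots=4 over 3 classes: ceil(4/3) = 2 per class = 6 candidates,
# then randomly down-sampled to the requested 4.
print(balanced.call(4))
```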
34 changes: 34 additions & 0 deletions torchviz_output
@@ -0,0 +1,34 @@
digraph {
graph [size="12,12"]
node [align=left fontname=monospace fontsize=10 height=0.2 ranksep=0.1 shape=box style=filled]
4936293968 [label="
(1, 20, 20, 20)" fillcolor=darkolivegreen1]
4934982272 [label=ReluBackward0]
4934982416 -> 4934982272
4934982416 [label=ConvolutionBackward0]
4932567584 -> 4934982416
4932567584 [label=ReluBackward0]
4932577376 -> 4932567584
4932577376 [label=ConvolutionBackward0]
4932567632 -> 4932577376
4936291088 [label="
(20, 1, 5, 5)" fillcolor=lightblue]
4936291088 -> 4932567632
4932567632 [label=AccumulateGrad]
4932579152 -> 4932577376
4936291184 [label="
(20)" fillcolor=lightblue]
4936291184 -> 4932579152
4932579152 [label=AccumulateGrad]
4932567536 -> 4934982416
4936291376 [label="
(20, 20, 5, 5)" fillcolor=lightblue]
4936291376 -> 4932567536
4932567536 [label=AccumulateGrad]
4932568400 -> 4934982416
4936291472 [label="
(20)" fillcolor=lightblue]
4936291472 -> 4932568400
4932568400 [label=AccumulateGrad]
4934982272 -> 4936293968
}
Binary file added torchviz_output.png
