Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stable cfg #34

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions benchmarks/bench_processors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import mlx.core as mx
import numpy as np
import torch

from outlines.processors import OutlinesLogitsProcessor


class HalvingLogitsProcessor(OutlinesLogitsProcessor):
    """Minimal logits processor for benchmarking: divides the logits by two."""

    def process_logits(self, input_ids, logits):
        """Return ``logits / 2``; ``input_ids`` is accepted but not used."""
        halved = logits / 2
        return halved


class LogitsProcessorBenchmark:
    """Benchmark ``HalvingLogitsProcessor`` once per supported array library."""

    # "mlx" is only benchmarked when ``mx.metal.is_available()`` reports True.
    params = ["torch", "numpy"] + (["mlx"] if mx.metal.is_available() else [])

    def setup(self, array_library):
        """Build the processor and random inputs for ``array_library``.

        Sets ``self.logits_processor``, ``self.logits`` (4 x 30000, float) and
        ``self.input_ids`` (4 x 2048, integers in ``[0, 30000)``).

        Raises
        ------
        ValueError
            If ``array_library`` is not "torch", "numpy" or "mlx".
        """
        self.logits_processor = HalvingLogitsProcessor()

        # logits: (batch, vocab) floats; input_ids: (batch, seq_len) ints
        batch, vocab, seq_len = 4, 30000, 2048

        if array_library == "torch":
            self.logits = torch.rand((batch, vocab), dtype=torch.float)
            self.input_ids = torch.randint(
                low=0, high=vocab, size=(batch, seq_len), dtype=torch.int
            )
            return

        if array_library == "numpy":
            self.logits = np.random.rand(batch, vocab).astype(np.float32)
            self.input_ids = np.random.randint(
                low=0, high=vocab, size=(batch, seq_len)
            )
            return

        if array_library == "mlx":
            self.logits = mx.random.uniform(
                low=-1e9, high=1e9, shape=(batch, vocab), dtype=mx.float32
            )
            self.input_ids = mx.random.randint(
                low=0, high=vocab, shape=(batch, seq_len), dtype=mx.int32
            )
            return

        raise ValueError

    def time_logits_processor(self, array_library):
        """Timed body: apply the processor to the inputs created in ``setup``."""
        self.logits_processor(self.input_ids, self.logits)
1 change: 1 addition & 0 deletions outlines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import outlines.generate
import outlines.grammars
import outlines.models
import outlines.processors
import outlines.types
from outlines.base import vectorize
from outlines.caching import clear_cache, disable_cache, get_cache
Expand Down
23 changes: 18 additions & 5 deletions outlines/fsm/guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
make_byte_level_fsm,
make_deterministic_fsm,
)
from outlines.fsm.parsing import PartialLark, terminals_to_fsms

if TYPE_CHECKING:
from outlines.models.tokenizer import Tokenizer
Expand Down Expand Up @@ -256,15 +257,23 @@ def __init__(self, cfg_string: str, tokenizer):
self.cfg_string = cfg_string
self.tokenizer = tokenizer

self.parser = Lark(
self.parser = PartialLark(
cfg_string,
parser="lalr",
lexer="contextual",
propagate_positions=False,
maybe_placeholders=False,
regex=True,
deterministic=True,
import_paths=[grammars.GRAMMAR_PATH],
# TODO: old options, not sure we need them, investigate
# propagate_positions=False,
# maybe_placeholders=False,

# TODO: old PartialLark options, investigate
# start="file_input",
)

self.regex_fsm = terminals_to_fsms(self.parser)
self.generation = ""

"""
self.terminal_regexps = dict()
for terminal in self.parser.terminals:
if terminal.pattern is not None:
Expand All @@ -279,6 +288,7 @@ def __init__(self, cfg_string: str, tokenizer):
self.check_last = False
self.proposal_last: List[int] = []
self.regex_fsm_last: RegexGuide
"""

self.start_state = 0
self.final_state = -1
Expand Down Expand Up @@ -316,6 +326,9 @@ def get_next_instruction(self, state: int) -> Instruction:
A list that contains the tokens to mask.

"""

import pdb;pdb.set_trace()

if self.is_final_state(state):
return Write([self.tokenizer.eos_token_id])

Expand Down
24 changes: 6 additions & 18 deletions outlines/generate/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@

from outlines.fsm.guide import CFGGuide
from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
from outlines.models import OpenAI
from outlines.models.llamacpp import LlamaCpp
from outlines.models.mlxlm import MLXLM
from outlines.models.vllm import VLLM
from outlines.models import MLXLM, VLLM, LlamaCpp, OpenAI, Transformers
from outlines.samplers import Sampler, multinomial


Expand Down Expand Up @@ -36,25 +33,16 @@ def cfg(model, cfg_str: str, sampler: Sampler = multinomial()) -> SequenceGenera

@cfg.register(MLXLM)
@cfg.register(VLLM)
def cfg_unimplemented(
model,
cfg_str: str,
sampler: Sampler = multinomial(),
):
raise NotImplementedError(
f"The CFG Logits processor is not available for {type(model)}."
)


@cfg.register(LlamaCpp)
def cfg_llamacpp(
model: LlamaCpp,
@cfg.register(Transformers)
def cfg_unified(
model,
cfg_str: str,
sampler: Sampler = multinomial(),
):
from outlines.integrations.llamacpp import CFGLogitsProcessor
from outlines.processors import CFGLogitsProcessor

logits_processor = CFGLogitsProcessor(cfg_str, model.model)
logits_processor = CFGLogitsProcessor(cfg_str, tokenizer=model.tokenizer)
return SequenceGeneratorAdapter(model, logits_processor, sampler)


Expand Down
1 change: 1 addition & 0 deletions outlines/grammars.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ def read_grammar(grammar_file_name, base_grammar_path=GRAMMAR_PATH):

# Each name below holds the result of read_grammar() for the given .lark file
# (presumably evaluated once at import time -- these appear to be module-level).
arithmetic = read_grammar("arithmetic.lark")
json = read_grammar("json.lark")
sql_select = read_grammar("sql_select.lark")
7 changes: 4 additions & 3 deletions outlines/grammars/common.lark
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,12 @@ SIGNED_FLOAT: ["+"|"-"] FLOAT
NUMBER: FLOAT | INT
SIGNED_NUMBER: ["+"|"-"] NUMBER

//
// TODO: Working escaped_string
//
UNESCAPED_STRING: /\"[^"]*\"/

// based on `outlines/fsm/json_schema.py`
// An inner string character is either
//   * any character except `"`, `\` and the control ranges
//     U+0000-U+001F / U+007F-U+009F, or
//   * an escaped quote or backslash (`\"` or `\\`).
// NOTE: a /regexp/ literal is not a Python string literal, so every regex
// escape is written with a single backslash. Doubling them again turns
// `\x00-\x1F` into literal characters and accidental ranges (e.g. `0-\`),
// which would reject digits and uppercase letters.
ESCAPED_STRING_INNER: /([^"\\\x00-\x1F\x7F-\x9F]|\\["\\])/
ESCAPED_STRING: "\"" ESCAPED_STRING_INNER* "\""



//
Expand Down
203 changes: 203 additions & 0 deletions outlines/grammars/sql_select.lark
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
// Adapted from https://github.com/zbrookle/sql_to_ibis
// License for https://github.com/zbrookle/sql_to_ibis follows
//BSD 3-Clause License
//
//Copyright (c) 2011-2022, Open source contributors.
//
//Redistribution and use in source and binary forms, with or without
//modification, are permitted provided that the following conditions are met:
//
//* Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
//* Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
//* Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
//THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
//AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
//IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
//DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
//FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
//DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
//SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
//CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
//OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
//OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


// Entry point: a whole input is one set expression (tree node aliased `final`).
start: set_expr -> final

// A single query, or two set expressions combined with UNION / INTERSECT /
// EXCEPT. DISTINCT is optional in the *_distinct alternatives; ALL selects the
// *_all alternatives. No INTERSECT ALL alternative is defined.
set_expr: query_expr
| set_expr "UNION"i ["DISTINCT"i] set_expr -> union_distinct
| set_expr "UNION"i "ALL"i set_expr -> union_all
| set_expr "INTERSECT"i ["DISTINCT"i] set_expr -> intersect_distinct
| set_expr "EXCEPT"i ["DISTINCT"i] set_expr -> except_distinct
| set_expr "EXCEPT"i "ALL"i set_expr -> except_all

// A SELECT optionally followed by ORDER BY and by LIMIT [OFFSET].
query_expr: select [ "ORDER"i "BY"i (order_by_expr ",")* order_by_expr] [ "LIMIT"i limit_count [ "OFFSET"i skip_rows ] ]

// SELECT ... FROM ... is mandatory; WHERE, GROUP BY, HAVING and WINDOW are
// each optional and must appear in this order.
select: "SELECT"i [SELECT_CONSTRAINT] [(select_expr ",")*] select_expr "FROM"i [(from_expr ",")*] from_expr [ "WHERE"i where_expr ] [ "GROUP"i "BY"i [(groupby_expr ",")*] groupby_expr ] [ "HAVING"i having_expr] [ "WINDOW"i window_expr ]

// Clause bodies referenced from `select` and `query_expr`.
where_expr: bool_expression

// One projected expression with an optional alias (AS keyword optional).
// `.0` assigns this rule an explicit priority of 0.
select_expr.0: expression_math [ [ "AS"i ] alias ] -> select_expression

?from_expr: from_item -> from_expression

order_by_expr: order -> order_by_expression

having_expr: bool_expression

groupby_expr: expression -> group_by

// Left-recursive, comma-separated list of `name AS window_definition` entries.
window_expr: [window_expr ","] _window_name "AS"i ( window_definition )

// A FROM source: an optionally aliased table name, a join, a cross join,
// or a parenthesised subquery.
from_item: name [ [ "AS"i ] alias ] -> table
| join -> join
| cross_join -> cross_join_expression
| subquery

// Parenthesised query or join, with an optional alias.
subquery: ( "(" (query_expr | join | cross_join) ")" ) [ [ "AS"i ] alias ]

cross_join: from_item "CROSS"i "JOIN"i from_item
join: from_item JOIN_EXPR from_item [ "ON"i bool_expression ] -> join_expression

// The join keywords are matched as a single terminal (priority 5); whitespace
// between them is matched explicitly with WS inside the terminal.
// NOTE(review): the `"OUTER"i?` alternative of JOIN_TYPE can match the empty
// string -- confirm this is intended.
JOIN_EXPR.5: (JOIN_TYPE WS)? "JOIN"i
JOIN_TYPE: "INNER"i | "OUTER"i? | JOIN_DIRECTION (WS "OUTER"i)? | JOIN_DIRECTION
JOIN_DIRECTION: "FULL"i | "LEFT"i | "RIGHT"i

// Additive level of arithmetic, plus the CASE, CAST, aggregate, RANK /
// DENSE_RANK and COALESCE forms. In the aggregate alternative only the closing
// ")" is written here because the AGGREGATION terminal already includes the
// function name and its opening parenthesis.
?expression_math: expression_product
| expression_math "+" expression_product -> expression_add
| expression_math "-" expression_product -> expression_sub
| "CASE"i (when_then)+ "ELSE"i expression_math "END"i -> case_expression
| "CAST"i "(" expression_math "AS"i TYPENAME ")" -> as_type
| "CAST"i "(" literal "AS"i TYPENAME ")" -> literal_cast
| AGGREGATION expression_math ")" [window_form] -> sql_aggregation
| "RANK"i "(" ")" window_form -> rank_expression
| "DENSE_RANK"i "(" ")" window_form -> dense_rank_expression
| "COALESCE"i "(" [(expression_math ",")*] expression_math ")" -> coalesce_expression

// OVER ( [PARTITION BY ...] [ORDER BY ... [frame clause]] )
window_form: "OVER"i "(" ["PARTITION"i "BY"i (partition_by ",")* partition_by] ["ORDER"i "BY"i (order ",")* order [ row_range_clause ] ] ")"

partition_by: expression_math

// Frame clause: ROWS or RANGE followed either by BETWEEN <bound> AND <bound>
// or by a single PRECEDING bound.
row_range_clause: ( ROWS | RANGE ) frame_extent
frame_extent: frame_between | frame_preceding
frame_between: "BETWEEN"i frame_bound "AND"i frame_bound
frame_bound: frame_preceding | frame_following | "CURRENT"i "ROW"i
frame_preceding: UNBOUNDED PRECEDING | integer_ PRECEDING
frame_following: UNBOUNDED FOLLOWING | integer_ FOLLOWING
// Case-insensitive keyword terminals used by the frame rules above.
RANGE: "RANGE"i
ROWS: "ROWS"i
UNBOUNDED: "UNBOUNDED"i
PRECEDING: "PRECEDING"i
FOLLOWING: "FOLLOWING"i

// One WHEN <condition> THEN <value> arm of a CASE expression.
when_then: "WHEN"i bool_expression "THEN"i expression_math
// Sort key: ascending when ASC is given or omitted, descending with DESC.
order: expression_math ["ASC"i] -> order_asc
| expression_math "DESC"i -> order_desc

// Optionally qualified column name. The first alternative of `expression`
// below is also aliased to `column_name`.
column_name: [name "."] name
// Multiplicative level of arithmetic.
?expression_product: expression_parens
| expression_product "*" expression_parens -> expression_mul
| expression_product "/" expression_parens -> expression_div

// A plain expression, or a parenthesised binary operation. The parenthesised
// forms use the same aliases as the unparenthesised operators.
?expression_parens: expression
| "(" expression_parens "*" expression ")" -> expression_mul
| "(" expression_parens "/" expression ")" -> expression_div
| "(" expression_parens "+" expression ")" -> expression_add
| "(" expression_parens "-" expression ")" -> expression_sub

// Atom: an optionally qualified column reference (or `*`), or a literal.
?expression: [name "."] (name | STAR) -> column_name
| literal

// Optional ALL / DISTINCT modifier after SELECT (terminal priority 9).
SELECT_CONSTRAINT.9: "ALL"i | "DISTINCT"i
// Type names accepted as the target of CAST(... AS <type>), case-insensitive.
TYPENAME: "object"i
| "varchar"i
| "integer"i
| "int16"i
| "smallint"i
| "int32"i
| "int64"i
| "int"i
| "bigint"i
| "float16"i
| "float32"i
| "float64"i
| "float"i
| "bool"i
| "datetime64"i
| "timestamp"i
| "time"i
| "date"i
| "category"i
| "string"i
// Aggregate function name together with its opening parenthesis (priority 8).
// NOTE(review): inside a terminal the literals `"count("i "distinct"i` are
// concatenated, so this appears to match `count(distinct` only when no
// whitespace separates the two parts -- confirm.
AGGREGATION.8: ("sum("i | "avg("i | "min("i | "max("i | "count("i "distinct"i | "count("i)
// Small wrapper rules around `name` and `integer_`.
alias: name -> alias_string
_window_name: name
limit_count: integer_ -> limit_count
skip_rows: integer_
// Boolean expressions: comparisons combined left-recursively with AND / OR.
// AND and OR sit in the same rule, so the grammar itself gives neither
// operator precedence over the other.
bool_expression: bool_parentheses
| bool_expression "AND"i bool_parentheses -> bool_and
| bool_expression "OR"i bool_parentheses -> bool_or
// A bare comparison, or a parenthesised AND / OR whose right-hand operand is a
// single comparison.
bool_parentheses: comparison_type
| "(" bool_expression "AND"i comparison_type ")" -> bool_and
| "(" bool_expression "OR"i comparison_type ")" -> bool_or
// All supported predicate forms.
comparison_type: equals | not_equals | greater_than | less_than | greater_than_or_equal
| less_than_or_equal | between | in_expr | not_in_expr | subquery_in | is_null | is_not_null
equals: expression_math "=" expression_math
is_null: expression_math "is"i "null"i
is_not_null: expression_math "is"i "not"i "null"i
not_equals: expression_math ("<>" | "!=") expression_math
greater_than: expression_math ">" expression_math
less_than: expression_math "<" expression_math
greater_than_or_equal: expression_math ">=" expression_math
less_than_or_equal: expression_math "<=" expression_math
between: expression_math "BETWEEN"i expression_math "AND"i expression_math
in_expr: expression_math "IN"i "(" [expression_math ","]* expression_math ")"
subquery_in: expression_math "IN"i subquery
not_in_expr: expression_math "NOT"i "IN"i "(" [expression_math ","]* expression_math ")"
// Literal values: booleans, numbers, single-quoted strings ('' is the empty
// string) and timestamp expressions.
?literal: boolean -> bool
| number_expr -> number
| /'([^']|\s)+'|''/ -> string
| timestamp_expression -> timestamp_expression
boolean: "true"i -> true
| "false"i -> false
?number_expr: product

?product: NUMBER

// Positive integer without a leading zero; "0" itself does not match.
integer_: /[1-9][0-9]*/
STAR: "*"
// NOTE(review): window_definition has an empty body -- looks like a
// placeholder; confirm.
window_definition:
timestamp_expression: "NOW"i "(" ")" -> datetime_now
| "TODAY"i "(" ")" -> date_today
| "TIMESTAMP"i "(" "'" date "'" "," "'" time "'" ")" -> custom_timestamp

// Date as four digits, two digits, two digits separated by "-"; time as three
// two-digit fields separated by ":". Only the digit counts are checked.
date: YEAR "-" MONTH "-" DAY
YEAR: /[0-9]{4}/
MONTH: /[0-9]{2}/
DAY: /[0-9]{2}/
time: HOURS ":" MINUTES ":" SECONDS
HOURS: /[0-9]{2}/
MINUTES: /[0-9]{2}/
SECONDS: /[0-9]{2}/
// Identifier: a bare CNAME or a double-quoted ESCAPED_STRING.
name: CNAME | ESCAPED_STRING



// Terminals taken from the `common` grammar.
// NOTE(review): WS_INLINE is imported but not referenced by any rule in this
// grammar -- confirm whether it is needed.
%import common.ESCAPED_STRING
%import common.CNAME
%import common.NUMBER
%import common.WS
%import common.SQL_COMMENT
%import common.WS_INLINE

// Whitespace and SQL comments are skipped between tokens.
%ignore WS
%ignore SQL_COMMENT
Loading
Loading