fix(llama): Overhaul use of sampling module for llama.cpp changes
The changes here track the big llama.cpp sampling overhaul landed in
ggml-org/llama.cpp#9294.

The sampling functionality is now split into a base interface
(llama_sampler) and a generation implementation (gpt_sampler), and the
changes here follow that split. Since the sampling.h/sampling.cpp code
uses C++ STL headers, the sampling_ext.[h|cpp] wrapper is maintained so
that Go can reach the functionality through a pure-C interface.
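
For illustration, a minimal sketch of the cgo boundary the wrapper
preserves. This is not the real llama/llama.go preamble (the include
path and build flags are assumptions); the point is only the shape of
the pattern: cgo can bind C symbols only, so the C++ gpt_sampler is
reached exclusively through the extern "C" functions declared in
sampling_ext.h.

package llama

/*
#include "sampling_ext.h"
*/
import "C"

// Go holds only an opaque pointer to the C-visible struct llama_sampler;
// the underlying C++ gpt_sampler never crosses the cgo boundary.
type SamplingContext struct {
    c *C.struct_llama_sampler
}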

Branch: IBMGraniteArchitectureSupport

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
gabe-l-hart committed Oct 14, 2024
1 parent bcdae7c commit 1b70cde
Showing 4 changed files with 33 additions and 37 deletions.
20 changes: 8 additions & 12 deletions llama/llama.go
@@ -445,7 +445,7 @@ func NewLlavaImageEmbed(llamaContext *Context, clipContext *ClipContext, data []
 // sampling
 // TODO: this is a temporary wrapper to allow calling C++ code from CGo
 type SamplingContext struct {
-    c *C.struct_llama_sampling_context
+    c *C.struct_llama_sampler
 }
 
 type SamplingParams struct {
@@ -467,7 +467,8 @@ type SamplingParams struct {
     Grammar string
 }
 
-func NewSamplingContext(params SamplingParams) *SamplingContext {
+func NewSamplingContext(model *Model, params SamplingParams) *SamplingContext {
+
     var cparams C.struct_llama_sampling_cparams
     cparams.top_k = C.int32_t(params.TopK)
     cparams.top_p = C.float(params.TopP)
@@ -489,7 +490,7 @@ func NewSamplingContext(params SamplingParams) *SamplingContext {
     defer C.free(unsafe.Pointer(grammar))
 
     cparams.grammar = grammar
-    context := &SamplingContext{c: C.llama_sampling_cinit(&cparams)}
+    context := &SamplingContext{c: C.llama_sampling_cinit(model.c, &cparams)}
     runtime.SetFinalizer(context, func(s *SamplingContext) { C.llama_sampling_cfree(s.c) })
 
     return context
@@ -499,15 +500,10 @@ func (s *SamplingContext) Reset() {
     C.llama_sampling_creset(s.c)
 }
 
-func (s *SamplingContext) Sample(ctxMain *Context, ctxConfig *Context, idx int) int {
-    // TODO (jmorganca): handle nil for all args
-    if ctxConfig == nil {
-        return int(C.llama_sampling_csample(s.c, ctxMain.c, nil, C.int(idx)))
-    }
-
-    return int(C.llama_sampling_csample(s.c, ctxMain.c, ctxConfig.c, C.int(idx)))
+func (s *SamplingContext) Sample(ctxMain *Context, idx int) int {
+    return int(C.llama_sampling_csample(s.c, ctxMain.c, C.int(idx)))
 }
 
-func (s *SamplingContext) Accept(ctxMain *Context, id int, applyGrammar bool) {
-    C.llama_sampling_caccept(s.c, ctxMain.c, C.llama_token(id), C.bool(applyGrammar))
+func (s *SamplingContext) Accept(id int, applyGrammar bool) {
+    C.llama_sampling_caccept(s.c, C.llama_token(id), C.bool(applyGrammar))
 }
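
Taken together, a call site against the new Go API looks roughly like
the following. This is a sketch only: the import path is assumed from
the repository layout, model/context setup is elided, and the parameter
values and the sampleOne helper are hypothetical, not part of this
change.

package main

import (
    // Import path assumed; adjust to wherever this llama package lives.
    "github.com/ollama/ollama/llama"
)

// sampleOne draws one token. model is a loaded *llama.Model and lc a
// *llama.Context, both initialized elsewhere; iBatch is the sequence's
// index into the current batch, as in runner.go.
func sampleOne(model *llama.Model, lc *llama.Context, iBatch int) int {
    sc := llama.NewSamplingContext(model, llama.SamplingParams{
        TopK: 40,
        TopP: 0.9,
    })
    // Sample now takes only the main context and the batch index; the
    // old ctxConfig argument is gone. Accept no longer needs a context.
    token := sc.Sample(lc, iBatch)
    sc.Accept(token, true)
    return token
}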
8 changes: 4 additions & 4 deletions llama/runner/runner.go
@@ -126,10 +126,10 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 
     var sc *llama.SamplingContext
     if params.samplingParams != nil {
-        sc = llama.NewSamplingContext(*params.samplingParams)
+        sc = llama.NewSamplingContext(s.model, *params.samplingParams)
         for _, input := range inputs {
             if input.embed == nil {
-                sc.Accept(s.lc, input.token, false)
+                sc.Accept(input.token, false)
             }
         }
     }
@@ -429,8 +429,8 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
         }
 
         // sample a token
-        token := seq.samplingCtx.Sample(s.lc, nil, seq.iBatch)
-        seq.samplingCtx.Accept(s.lc, token, true)
+        token := seq.samplingCtx.Sample(s.lc, seq.iBatch)
+        seq.samplingCtx.Accept(token, true)
         piece := s.model.TokenToPiece(token)
 
         seq.numPredicted++
28 changes: 14 additions & 14 deletions llama/sampling_ext.cpp
@@ -2,14 +2,15 @@
 #include "sampling.h"
 #include "sampling_ext.h"
 
-struct llama_sampling_context *llama_sampling_cinit(struct llama_sampling_cparams *params)
+struct llama_sampler *llama_sampling_cinit(
+    const struct llama_model *model, struct llama_sampling_cparams *params)
 {
-    llama_sampling_params sparams;
+    gpt_sampler_params sparams;
     sparams.top_k = params->top_k;
     sparams.top_p = params->top_p;
     sparams.min_p = params->min_p;
     sparams.tfs_z = params->tfs_z;
-    sparams.typical_p = params->typical_p;
+    sparams.typ_p = params->typical_p;
     sparams.temp = params->temp;
     sparams.penalty_last_n = params->penalty_last_n;
     sparams.penalty_repeat = params->penalty_repeat;
@@ -21,33 +22,32 @@ struct llama_sampling_context *llama_sampling_cinit(struct llama_sampling_cparam
     sparams.penalize_nl = params->penalize_nl;
     sparams.seed = params->seed;
     sparams.grammar = params->grammar;
-    return llama_sampling_init(sparams);
+    return (llama_sampler*)gpt_sampler_init(model, sparams);
 }
 
-void llama_sampling_cfree(struct llama_sampling_context *ctx)
+void llama_sampling_cfree(struct llama_sampler *sampler)
 {
-    llama_sampling_free(ctx);
+    gpt_sampler_free((gpt_sampler*)sampler);
 }
 
-void llama_sampling_creset(struct llama_sampling_context *ctx)
+void llama_sampling_creset(struct llama_sampler *sampler)
 {
-    llama_sampling_reset(ctx);
+    gpt_sampler_reset((gpt_sampler*)sampler);
 }
 
 llama_token llama_sampling_csample(
-    struct llama_sampling_context *ctx_sampling,
+    struct llama_sampler *sampler,
     struct llama_context *ctx_main,
-    struct llama_context *ctx_cfg,
     int idx)
 {
-    return llama_sampling_sample(ctx_sampling, ctx_main, ctx_cfg, idx);
+    // TODO (ggoodhart): Do we need to support grammar_first?
+    return gpt_sampler_sample((gpt_sampler*)sampler, ctx_main, idx);
 }
 
 void llama_sampling_caccept(
-    struct llama_sampling_context *ctx_sampling,
-    struct llama_context *ctx_main,
+    struct llama_sampler *sampler,
     llama_token id,
     bool apply_grammar)
 {
-    llama_sampling_accept(ctx_sampling, ctx_main, id, apply_grammar);
+    gpt_sampler_accept((gpt_sampler*)sampler, id, apply_grammar);
 }
14 changes: 7 additions & 7 deletions llama/sampling_ext.h
@@ -29,19 +29,19 @@ extern "C"
         char *grammar;
     };
 
-    struct llama_sampling_context *llama_sampling_cinit(struct llama_sampling_cparams *params);
-    void llama_sampling_cfree(struct llama_sampling_context *ctx);
-    void llama_sampling_creset(struct llama_sampling_context *ctx);
+    struct llama_sampler *llama_sampling_cinit(
+        const struct llama_model *model,
+        struct llama_sampling_cparams *params);
+    void llama_sampling_cfree(struct llama_sampler *sampler);
+    void llama_sampling_creset(struct llama_sampler *sampler);
 
     llama_token llama_sampling_csample(
-        struct llama_sampling_context *ctx_sampling,
+        struct llama_sampler *sampler,
         struct llama_context *ctx_main,
-        struct llama_context *ctx_cfg,
         int idx);
 
     void llama_sampling_caccept(
-        struct llama_sampling_context *ctx_sampling,
-        struct llama_context *ctx_main,
+        struct llama_sampler *sampler,
         llama_token id,
         bool apply_grammar);
 