Add a `--max-seq-len` option that caps the context length.

The option is parsed into AppArgs::maxSeqLen (0 means "no cap") and passed to
Transformer::loadSpecFromFile. That function stores the sequence length read
from the model file in TransformerSpec::origSeqLen and lowers
TransformerSpec::seqLen when the cap is smaller. Both fields are
`unsigned int`, so the log lines that print them use `%u`; the rope cache
size is a `size_t` and is printed with `%zu`.

NOTE(review): leading whitespace and blank context lines were reconstructed
from the hunk line counts; if context does not match the base tree, apply
with `git apply --ignore-whitespace`.

diff --git a/src/app.cpp b/src/app.cpp
index f4a3b4b..ebd3381 100644
--- a/src/app.cpp
+++ b/src/app.cpp
@@ -41,6 +41,7 @@ AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) {
     args.steps = 0;
     args.seed = (unsigned long long)time(NULL);
     args.chatTemplateType = TEMPLATE_UNKNOWN;
+    args.maxSeqLen = 0; // 0 = do not cap the model's sequence length
     args.useDiscForKvCache = false;
 
     int i = 1;
@@ -99,6 +100,8 @@ AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) {
             args.seed = atoll(value);
         } else if (strcmp(name, "--chat-template") == 0) {
             args.chatTemplateType = parseChatTemplateType(value);
+        } else if (strcmp(name, "--max-seq-len") == 0) {
+            args.maxSeqLen = (unsigned int)atoi(value);
         } else if (strcmp(name, "--kv-cache-storage") == 0) {
             args.useDiscForKvCache = strcmp(value, "disc") == 0;
         } else {
@@ -128,7 +131,7 @@ void App::run(AppArgs* args, void (*program)(Inference* inference, SocketPool* s
     SocketPool* socketPool = SocketPool::connect(args->nWorkers, args->workerHosts, args->workerPorts);
     unsigned int nSlices = args->nWorkers + 1;
 
-    TransformerSpec spec = Transformer::loadSpecFromFile(args->modelPath, nSlices, args->weightsFloatType, args->bufferFloatType);
+    TransformerSpec spec = Transformer::loadSpecFromFile(args->modelPath, nSlices, args->maxSeqLen, args->weightsFloatType, args->bufferFloatType);
     TransformerArch arch = TransformerArchFactory::create(&spec);
     Tokenizer tokenizer(args->tokenizerPath, spec.vocabSize);
 
diff --git a/src/app.hpp b/src/app.hpp
index e7f239c..ef3b372 100644
--- a/src/app.hpp
+++ b/src/app.hpp
@@ -34,6 +34,7 @@ class AppArgs {
     bool benchmark;
     unsigned long long seed;
     ChatTemplateType chatTemplateType;
+    unsigned int maxSeqLen; // Value of `--max-seq-len`; 0 = no cap
 
     // worker
     int port;
diff --git a/src/apps/dllama/dllama.cpp b/src/apps/dllama/dllama.cpp
index 7d83823..45cd9b9 100644
--- a/src/apps/dllama/dllama.cpp
+++ b/src/apps/dllama/dllama.cpp
@@ -160,7 +160,7 @@ class Chat {
             int nInputTokens;
             tokenizer->encode((char*)inputPrompt.c_str(), inputTokens, &nInputTokens, true, false);
 
-            pos_t userPromptEndPos = (pos_t)std::min(spec->seqLen, (int)pos + nInputTokens - 1);
+            pos_t userPromptEndPos = (pos_t)std::min(spec->seqLen, pos + nInputTokens - 1);
             for (pos_t i = 0; pos < userPromptEndPos; pos++, i++) {
                 inference->infer(inputTokens[i], pos);
                 token = inputTokens[i + 1];
diff --git a/src/commands.cpp b/src/commands.cpp
index 26cb1c6..ed1c5af 100644
--- a/src/commands.cpp
+++ b/src/commands.cpp
@@ -134,7 +134,7 @@ LlamaRopeCommand::LlamaRopeCommand(RopeSlice *slice) {
 
     size_t cacheBytes = slice->seqLen * slice->sliceDim * sizeof(float);
     cache = (float*)newBuffer(cacheBytes);
-    printf("🕒 ropeCache: %ld kB\n", cacheBytes / 1024);
+    printf("🕒 ropeCacheSize: %zu kB\n", cacheBytes / 1024); // cacheBytes is a size_t
 
     for (pos_t pos = 0; pos < slice->seqLen; pos++) {
         for (unsigned int i = slice->kvDimStart; i < slice->qDimEnd; i += 2) {
diff --git a/src/transformer.cpp b/src/transformer.cpp
index 92040a4..6583815 100644
--- a/src/transformer.cpp
+++ b/src/transformer.cpp
@@ -9,7 +9,7 @@
 
 #define IS_ROOT_SLICE(sliceIndex) (sliceIndex == 0)
 
-TransformerSpec Transformer::loadSpecFromFile(const char* path, const unsigned int nSlices, FloatType weightsFloatType, FloatType bufferFloatType) {
+TransformerSpec Transformer::loadSpecFromFile(const char* path, const unsigned int nSlices, const unsigned int maxSeqLen, FloatType weightsFloatType, FloatType bufferFloatType) {
     TransformerSpec spec;
     memset(&spec, 0, sizeof(TransformerSpec));
     spec.hiddenAct = SILU;
@@ -95,6 +95,10 @@ TransformerSpec Transformer::loadSpecFromFile(const char* path, const unsigned i
         }
     }
 
+    spec.origSeqLen = spec.seqLen; // keep the pre-cap length for logging
+    if (maxSeqLen > 0 && spec.seqLen > maxSeqLen) { // 0 disables the cap; never raises seqLen
+        spec.seqLen = maxSeqLen;
+    }
     spec.headSize = spec.dim / spec.nHeads;
     spec.kvDim = (spec.dim * spec.nKvHeads) / spec.nHeads;
     spec.weightsFloatType = weightsFloatType;
@@ -131,6 +135,9 @@ TransformerSpec Transformer::loadSpecFromFile(const char* path, const unsigned i
         printf("💡 nActiveExperts: %d\n", spec.nActiveExperts);
     }
     printf("💡 vocabSize: %d\n", spec.vocabSize);
+    if (spec.seqLen != spec.origSeqLen) {
+        printf("💡 origSeqLen: %u\n", spec.origSeqLen);
+    }
-    printf("💡 seqLen: %d\n", spec.seqLen);
+    printf("💡 seqLen: %u\n", spec.seqLen);
     printf("💡 nSlices: %d\n", spec.nSlices);
     printf("💡 ropeTheta: %.1f\n", spec.ropeTheta);
diff --git a/src/transformer.hpp b/src/transformer.hpp
index d7b296e..e1a2849 100644
--- a/src/transformer.hpp
+++ b/src/transformer.hpp
@@ -71,7 +71,8 @@ struct TransformerSpec {
     int nKvHeads;
     int nExperts;
     int nActiveExperts;
-    int seqLen;
+    unsigned int origSeqLen; // Original model context length (before any cap)
+    unsigned int seqLen; // Context length in effect, capped by the `--max-seq-len` argument
     int hiddenDim;
     TransformerHiddenAct hiddenAct;
     int kvDim;
@@ -197,7 +198,7 @@ class Transformer {
 
     ~Transformer();
 
-    static TransformerSpec loadSpecFromFile(const char* path, const unsigned int nSlices, FloatType weightsFloatType, FloatType bufferFloatType);
+    static TransformerSpec loadSpecFromFile(const char* path, const unsigned int nSlices, const unsigned int maxSeqLen, FloatType weightsFloatType, FloatType bufferFloatType);
     static Transformer loadRootFromFile(const char* path, TransformerSpec* spec, TransformerConfig* config, SocketPool* socketPool);
     static Transformer loadRoot(char* data, TransformerSpec* spec, TransformerConfig* config, SocketPool* socketPool);
     static Transformer loadSlice(TransformerSpec* spec, TransformerConfig* config, Socket* socket);