
Phi 3 medium/small support #7439

Closed
bartowski1182 opened this issue May 21, 2024 · 38 comments
Labels
enhancement (New feature or request), stale

Comments

@bartowski1182
Contributor

bartowski1182 commented May 21, 2024

2 new models released from Microsoft:

https://huggingface.co/microsoft/Phi-3-medium-4k-instruct/

https://huggingface.co/microsoft/Phi-3-small-8k-instruct/

Medium uses Phi3ForCausalLM and converts without issue, but when trying to generate it fails with an invalid tensor shape:

llama_model_load: error loading model: check_tensor_dims: tensor 'blk.0.attn_qkv.weight' has wrong shape; expected 5120, 15360, got 5120, 7680, 1, 1
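For context on where those two numbers likely come from, here is a minimal arithmetic sketch of the fused QKV shape, using the head counts reported in the model metadata later in this thread (n_embd = 5120, n_head = 40, n_head_kv = 10); the variable names are mine, not llama.cpp's:

# Rough arithmetic behind the two numbers in the error above (illustrative only).
n_embd, n_head, n_head_kv = 5120, 40, 10
head_dim = n_embd // n_head                   # 128

mha_rows = 3 * n_embd                         # 15360: Q, K and V all full width (no GQA)
gqa_rows = n_embd + 2 * n_head_kv * head_dim  # 7680: K and V shrunk to 10 KV heads

print(mha_rows, gqa_rows)                     # 15360 7680

In other words, the mismatch looks like one side of the check assuming full multi-head K/V while the converted tensor reflects the grouped-query layout that medium introduces.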

Small, on the other hand, uses a new architecture tag, 'Phi3SmallForCausalLM'.

bartowski1182 added the enhancement label on May 21, 2024
@dillfrescott

Happens with the 128k variants as well. I tried both!

@bartowski1182
Contributor Author

Whoops, thanks, I forgot and got lazy with pasting the links lol

@ggerganov
Owner

Try all models using #7225 and report any issues

@bartowski1182
Contributor Author

Building now, will report any updates

I'll also try running the created quants with those changes just to see if they work (I need these changes for imatrix anyway, of course)

@qnixsynapse
Contributor

Will llama.cpp work with blocksparse attention? These models seem to implement it.

@CyberTimon

Do we also get vision support for Phi-3-Vision? I don't know how much this diverges from other archs like LLaVA.

@ggerganov
Owner

ggerganov commented May 21, 2024

> Will llama.cpp work with blocksparse attention? These models seem to implement it.

Could you remind me what blocksparse attention was?

Edit: No, there is no API for that atm.

@qnixsynapse
Contributor

@ggerganov https://facebookresearch.github.io/xformers/tutorials/blocksparse.html

@ggerganov
Owner

@qnixsynapse Is this technique actually used in practice? I don't see how one would choose the attention mask in a reasonable way without the LLM "forgetting" important bits from the context.

@qnixsynapse
Contributor

Normally I would not prefer it; however, I came across this, which caught my attention.

@tristandruyen
Contributor

tristandruyen commented May 21, 2024

I tried it with #7225 using the 128k variants:

microsoft/Phi-3-medium-128k-instruct:

  • bf16 gguf conversion works
  • basic inference works (with bf16)
  • imatrix gen seems to work (still running, but already saves checkpoints)
  • quanting works
./llama-server --chat-template phi3 -m ../../models/Phi-3-medium-128k-instruct-iMat-GGUF/phi-3-medium-128k-instruct-bf16.gguf &

curl http://localhost:8080/v1/chat/completions \
                          -H "Content-Type: application/json" \
                          -d '{
            "messages": [
              {
                "role": "system",
                "content": "You are a helpful assistant."
              },
              {
                "role": "user",
                "content": "What is 2+2 ?"
              }
            ]
         }'
         
{"choices":[{"finish_reason":"stop","index":0,"message":{"content":" The result of 2+2 is 4.","role":"assistant"}}],"created":1716312775,"model":"model_name","object":"chat.completion","usage":{"completion_tokens":12,"prompt_tokens":18,"total_tokens":30},"id":"chatcmpl-KvOpfd64IzSt8DR7h3smbyP7PyVc6xPG"}

microsoft/Phi-3-small-128k-instruct:

bf16 gguf creation still fails with:

INFO:hf-to-gguf:Loading model: Phi-3-small-128k-instruct
Traceback (most recent call last):
File "/home/tristand/ai/tools/llama.cpp/convert-hf-to-gguf.py", line 2585, in <module>
  main()
File "/home/tristand/ai/tools/llama.cpp/convert-hf-to-gguf.py", line 2563, in main
  model_class = Model.from_model_architecture(hparams["architectures"][0])
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tristand/ai/tools/llama.cpp/convert-hf-to-gguf.py", line 370, in from_model_architecture
  raise NotImplementedError(f'Architecture {arch!r} not supported!') from None
NotImplementedError: Architecture 'Phi3SmallForCausalLM' not supported!

I tried the dumb fix for the 'Phi3SmallForCausalLM not supported' error:

diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 06c89e23..9d6f861a 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -1685,7 +1685,7 @@ class Phi2Model(Model):
         self.gguf_writer.add_add_bos_token(False)


-@Model.register("Phi3ForCausalLM")
+@Model.register("Phi3ForCausalLM", "Phi3SmallForCausalLM")
 class Phi3MiniModel(Model):
     model_arch = gguf.MODEL_ARCH.PHI3

And now it fails with:

INFO:hf-to-gguf:Set model parameters
Traceback (most recent call last):
  File "/home/tristand/ai/tools/llama.cpp-fix-phi/convert-hf-to-gguf.py", line 2585, in <module>
    main()
  File "/home/tristand/ai/tools/llama.cpp-fix-phi/convert-hf-to-gguf.py", line 2567, in main
    model_instance.set_gguf_parameters()
  File "/home/tristand/ai/tools/llama.cpp-fix-phi/convert-hf-to-gguf.py", line 1791, in set_gguf_parameters
    rms_eps = self.find_hparam(["rms_norm_eps"])
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tristand/ai/tools/llama.cpp-fix-phi/convert-hf-to-gguf.py", line 113, in find_hparam
    raise KeyError(f"could not find any of: {keys}")
KeyError: "could not find any of: ['rms_norm_eps']"
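Not a fix, but a quick way to see which epsilon-style hparam Phi3SmallForCausalLM actually ships, assuming a local checkout of the model (the path and whatever keys come back are just what the config.json contains; nothing here is the real mapping):

# Hypothetical sanity check: list the epsilon-like keys the HF config contains,
# since find_hparam(["rms_norm_eps"]) clearly finds nothing for this architecture.
import json
from pathlib import Path

config = json.loads(Path("Phi-3-small-128k-instruct/config.json").read_text())
print({k: v for k, v in config.items() if "eps" in k.lower()})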

@bartowski1182
Contributor Author

Using that PR, imatrix works, which should imply that generation will work too. GGUFs created the old way don't work, so any floating around out there that were made without the PR will be broken.

@Galunid
Collaborator

Galunid commented May 21, 2024

It seems to kinda work for Phi3-medium-128k with #7225, but it breaks along the way with partial offloading (see last line)

Phi3-medium-128k run
$ ./main -m models-local/Phi-3-medium-128k-instruct/ggml-model-Q4_K_M.gguf -i --in-prefix "<|user|>\n" --in-suffix "<|end|>\n<|assistant|>" --interactive-first -ngl 15
Log start
main: build = 2977 (92711138)
main: built with cc (GCC) 14.1.1 20240507 for x86_64-pc-linux-gnu
main: seed  = 1716314002
llama_model_loader: loaded meta data with 27 key-value pairs and 245 tensors from models-local/Phi-3-medium-128k-instruct/ggml-model-Q4_K_M.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = phi3
llama_model_loader: - kv   1:                               general.name str              = Phi3
llama_model_loader: - kv   2:                        phi3.context_length u32              = 131072
llama_model_loader: - kv   3:  phi3.rope.scaling.original_context_length u32              = 4096
llama_model_loader: - kv   4:                      phi3.embedding_length u32              = 5120
llama_model_loader: - kv   5:                   phi3.feed_forward_length u32              = 17920
llama_model_loader: - kv   6:                           phi3.block_count u32              = 40
llama_model_loader: - kv   7:                  phi3.attention.head_count u32              = 40
llama_model_loader: - kv   8:               phi3.attention.head_count_kv u32              = 10
llama_model_loader: - kv   9:      phi3.attention.layer_norm_rms_epsilon f32              = 0,000010
llama_model_loader: - kv  10:                  phi3.rope.dimension_count u32              = 128
llama_model_loader: - kv  11:                        phi3.rope.freq_base f32              = 10000,000000
llama_model_loader: - kv  12:                          general.file_type u32              = 15
llama_model_loader: - kv  13:              phi3.rope.scaling.attn_factor f32              = 1,190238
llama_model_loader: - kv  14:                       tokenizer.ggml.model str              = llama
llama_model_loader: - kv  15:                         tokenizer.ggml.pre str              = default
llama_model_loader: - kv  16:                      tokenizer.ggml.tokens arr[str,32064]   = ["<unk>", "<s>", "</s>", "<0x00>", "<...
llama_model_loader: - kv  17:                      tokenizer.ggml.scores arr[f32,32064]   = [-1000,000000, -1000,000000, -1000,00...
llama_model_loader: - kv  18:                  tokenizer.ggml.token_type arr[i32,32064]   = [3, 3, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...
llama_model_loader: - kv  19:                tokenizer.ggml.bos_token_id u32              = 1
llama_model_loader: - kv  20:                tokenizer.ggml.eos_token_id u32              = 32000
llama_model_loader: - kv  21:            tokenizer.ggml.unknown_token_id u32              = 0
llama_model_loader: - kv  22:            tokenizer.ggml.padding_token_id u32              = 32000
llama_model_loader: - kv  23:               tokenizer.ggml.add_bos_token bool             = false
llama_model_loader: - kv  24:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  25:                    tokenizer.chat_template str              = {% for message in messages %}{% if (m...
llama_model_loader: - kv  26:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:   83 tensors
llama_model_loader: - type q4_K:  101 tensors
llama_model_loader: - type q5_K:   40 tensors
llama_model_loader: - type q6_K:   21 tensors
llm_load_vocab: special tokens definition check successful ( 323/32064 ).
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = phi3
llm_load_print_meta: vocab type       = SPM
llm_load_print_meta: n_vocab          = 32064
llm_load_print_meta: n_merges         = 0
llm_load_print_meta: n_ctx_train      = 131072
llm_load_print_meta: n_embd           = 5120
llm_load_print_meta: n_head           = 40
llm_load_print_meta: n_head_kv        = 10
llm_load_print_meta: n_layer          = 40
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 4
llm_load_print_meta: n_embd_k_gqa     = 1280
llm_load_print_meta: n_embd_v_gqa     = 1280
llm_load_print_meta: f_norm_eps       = 0,0e+00
llm_load_print_meta: f_norm_rms_eps   = 1,0e-05
llm_load_print_meta: f_clamp_kqv      = 0,0e+00
llm_load_print_meta: f_max_alibi_bias = 0,0e+00
llm_load_print_meta: f_logit_scale    = 0,0e+00
llm_load_print_meta: n_ff             = 17920
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 2
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000,0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx  = 4096
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: model type       = ?B
llm_load_print_meta: model ftype      = Q4_K - Medium
llm_load_print_meta: model params     = 13,96 B
llm_load_print_meta: model size       = 7,98 GiB (4,91 BPW) 
llm_load_print_meta: general.name     = Phi3
llm_load_print_meta: BOS token        = 1 '<s>'
llm_load_print_meta: EOS token        = 32000 '<|endoftext|>'
llm_load_print_meta: UNK token        = 0 '<unk>'
llm_load_print_meta: PAD token        = 32000 '<|endoftext|>'
llm_load_print_meta: LF token         = 13 '<0x0A>'
llm_load_print_meta: EOT token        = 32007 '<|end|>'
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:   no
ggml_cuda_init: CUDA_USE_TENSOR_CORES: yes
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 2060, compute capability 7.5, VMM: yes
llm_load_tensors: ggml ctx size =    0,28 MiB
llm_load_tensors: offloading 15 repeating layers to GPU
llm_load_tensors: offloaded 15/41 layers to GPU
llm_load_tensors:        CPU buffer size =  8169,25 MiB
llm_load_tensors:      CUDA0 buffer size =  3016,11 MiB
...............................................................................................
llama_new_context_with_model: n_ctx      = 512
llama_new_context_with_model: n_batch    = 512
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 10000,0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:  CUDA_Host KV buffer size =    62,50 MiB
llama_kv_cache_init:      CUDA0 KV buffer size =    37,50 MiB
llama_new_context_with_model: KV self size  =  100,00 MiB, K (f16):   50,00 MiB, V (f16):   50,00 MiB
llama_new_context_with_model:  CUDA_Host  output buffer size =     0,12 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =   231,06 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =    20,98 MiB
llama_new_context_with_model: graph nodes  = 1606
llama_new_context_with_model: graph splits = 274

system_info: n_threads = 6 / 12 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 | 
main: interactive mode on.
Input prefix: '<|user|>\n'
Input suffix: '<|end|>\n<|assistant|>'
sampling: 
	repeat_last_n = 64, repeat_penalty = 1,000, frequency_penalty = 0,000, presence_penalty = 0,000
	top_k = 40, tfs_z = 1,000, top_p = 0,950, min_p = 0,050, typical_p = 1,000, temp = 0,800
	mirostat = 0, mirostat_lr = 0,100, mirostat_ent = 5,000
sampling order: 
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temperature 
generate: n_ctx = 512, n_batch = 2048, n_predict = -1, n_keep = 0


== Running in interactive mode. ==
 - Press Ctrl+C to interject at any time.
 - Press Return to return control to LLaMa.
 - To return control without starting a new line, end your input with '/'.
 - If you want to submit another line, end your input with '\'.

<s><|user|>\nWho is the president of Poland?
<|end|>\n<|assistant|> As of my knowledge cutoff in 2e, the president of Poland is Andrzej Duda. He has been in office since August 6, 2015.<|end|>
<|user|>\nWrite a short articles on why dolphins are better pets than llamas
<|end|>\n<|assistant|> Title: Dolphins: The Superior Companions


Dolphins have long captivated the human imagination with their intelligence, playful behavior, and social nature. However, while some people may fantasize about dolphins as pets, it's crucial to acknowledge the significant differences between dolphins and more traditional pets like llamas. This article aims to explore these differences and argue why dolphins, despite their challenges, could be considered better companions.


First and forem least, dolphins are undeniably intelligent creatures. Scientists have discovered that dolphins possess self-awareness, problem-solving skills, and the ability to communicate using complex sounds and body language. These attributes make dolphins exceptionally engaging and interactive companions, unlike llamas, which, while intelligent in their ways, do not exhibit the same level of cognitive abilities.


Secondly, dolphins are known for their playfulness and affectionate nature. They often engage in activities that display joy, such as leaping from the water and riding waves. While llamas are known to be friendly and can develop bonds with humans, they do not exhibit the same level of playfulness as dolphins.


Thirdly, dolphins' social nature sets them apart from llamas. Dolphins typically live in groups called pods, which means they thrive in a social environment. They engage in cooperative behaviors such as hunting, protecting each other, and raising their young together. Llamas, on the other hand, are more independent animals that usually prefer to keep to themselves or travel in smaller groups.


Despite these reasons, it is essential to note the challenges of owning dolphins as pets. Dolphins have a long lifespan, often reaching over 20 years, and require a large, aquatic habitat to thrive. Furthermore, dolphins are protected under the aqua,ers's,�000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000<|user|>\n

@slaren
Collaborator

slaren commented May 21, 2024

Try increasing the context, you are using only 512 so there are probably shifts happening.

@Galunid
Collaborator

Galunid commented May 21, 2024

That did the trick

@duynt575

Tried with https://github.com/ggerganov/llama.cpp/pull/7225 and it worked, but only with this version (PR 7225):

  • creating f16 worked
  • creating imatrix worked
  • creating quants worked (tested with q4_0 , iq3_s and iq1_s)
  • inferencing worked

But if you use main from the latest version to run GGUF files created with this version, it will show this error:

llama_model_load: error loading model: done_getting_tensors: wrong number of tensors; expected 245, got 243

@CrispStrobe
Contributor

CrispStrobe commented May 21, 2024

With what model did this work for you? With convert-hf-to-gguf.py? Probably medium? Or, if small, how did you get around tokenizer.model?

@sruPL

sruPL commented May 21, 2024

Can anyone post a working f16?

@bartowski1182
Contributor Author

So far so good on the 4k GGUF, it's able to respond to queries which is good enough for me lol

uploaded here:

https://huggingface.co/bartowski/Phi-3-medium-4k-instruct-GGUF

@jamyl

jamyl commented May 21, 2024

It's loading now and working great with bartowski/Phi-3-medium-4k-instruct-GGUF/Phi-3-medium-4k-instruct-Q4_K_S.gguf on LM-Studio.
Thanks

@CrispStrobe
Contributor

Looks great with medium, but small seems to need more work. In a first test, I can override some gguf_parameters and just use self._set_vocab_qwen(), and it will convert and quantize, but then it won't run and instead throws "llama_model_load: error loading model: check_tensor_dims: tensor 'output.weight' not found".
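For anyone who wants to reproduce that experiment, here is a rough sketch of the kind of override described, meant to sit inside convert-hf-to-gguf.py next to Phi3MiniModel (the class name is made up, and as noted above the resulting GGUF still fails to load):

# Illustrative only, not a working Phi-3-small port.
@Model.register("Phi3SmallForCausalLM")
class Phi3SmallModel(Phi3MiniModel):
    model_arch = gguf.MODEL_ARCH.PHI3

    def set_vocab(self):
        # Phi-3-small ships no tokenizer.model, so fall back to the Qwen-style BPE path
        self._set_vocab_qwen()

    # set_gguf_parameters() would also need overrides for the hparams that differ
    # from mini/medium (e.g. whatever replaces the missing rms_norm_eps).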

@dillfrescott

https://0.0g.gg/?8b6aa2a822f73b75#6dSFckfnxCttPKUX7rX4b35WEdt6woLdK65DTpSWSZ4w

Here is an issue I've been running into. The link is a paste of the model just completely imploding in on itself from a basic word problem.

@bartowski1182
Contributor Author

bartowski1182 commented May 21, 2024

Btw I think if you're using something like LM studio you aren't getting the right performance

It fails the tokenizer test of 3333+7777, but using the PR ./main gets it right

Likely need to wait for merge and version bump

@dillfrescott

I am using a bleeding-edge llama.cpp commit and it's doing that, which is odd...

@MoonRide303

Something is wrong with the Phi-3-medium-4k-instruct output; I am often getting a weird "B:" out of the blue.


launched via:

server -v -ngl 99 -m Phi-3-medium-4k-instruct-Q6_K.gguf -c 4096 --chat-template chatml


using current master (9b3d833).

@tristandruyen
Contributor

tristandruyen commented May 22, 2024

@MoonRide303

> Something is wrong with the Phi-3-medium-4k-instruct output; I am often getting a weird "B:" out of the blue.

AFAIK the --chat-template parameter is not used for the server web GUI, as the GUI uses the /completions endpoint internally and --chat-template only applies to the /v1/chat/completions endpoint; you need to set the right template manually in the Prompt template section of the form.
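To make the difference concrete, here is a small sketch against a local llama-server, with the phi3 markers written out by hand (the same markers used as --in-prefix/--in-suffix elsewhere in this thread). The endpoint and field names reflect my understanding of the server API, so treat them as assumptions:

# /v1/chat/completions applies the server-side chat template for you;
# /completion takes a raw prompt, so the phi3 template must be added manually.
import requests

question = "What is 2+2 ?"

chat = requests.post("http://localhost:8080/v1/chat/completions",
                     json={"messages": [{"role": "user", "content": question}]})
print(chat.json()["choices"][0]["message"]["content"])

prompt = f"<|user|>\n{question}<|end|>\n<|assistant|>\n"
raw = requests.post("http://localhost:8080/completion",
                    json={"prompt": prompt, "n_predict": 64})
print(raw.json()["content"])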

@ThatcherC
Contributor

@MoonRide303 @tristandruyen Just FYI: it looks like that --chat-template issue is being worked on in #7432, which is great because I ran into the same problem with Phi 3 mini! The --chat-template fix worked in my case, but I'm looking forward to the real fix in #7449 being merged.

@tristandruyen
Contributor

tristandruyen commented May 22, 2024

> @MoonRide303 @tristandruyen Just FYI: it looks like that --chat-template issue is being worked on in #7432, which is great because I ran into the same problem with Phi 3 mini! The --chat-template fix worked in my case, but I'm looking forward to the real fix in #7449 being merged.

I'm actually the one working on the --chat-template issue in #7449. However, it seems like @MoonRide303's issue is more related to the web UI not using any model-specific templates, not that it's using the wrong one.

The fix I'm working on in #7449 aims to improve the auto-detection of the phi3 template, so users won't have to explicitly specify it using the --chat-template flag. This fix will ensure that llama.cpp automatically detects and uses the appropriate template for the model.

It's important to note that the behavior of the endpoints and the web UI will remain unchanged after my fix is merged. The web UI will still not use any model-specific template; only the auto-detection process will be more reliable.
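For readers wondering what the auto-detection actually does: llama.cpp picks one of its built-in templates by looking for characteristic markers in the model's embedded tokenizer.chat_template string (the key visible in the GGUF dump earlier in this thread). The real logic lives in C++ and the exact markers checked by #7449 may differ; the snippet below is only a Python illustration of the idea:

# Illustration of substring-based template detection, not llama.cpp's actual code.
def guess_template(jinja_template: str) -> str:
    if "<|assistant|>" in jinja_template and "<|end|>" in jinja_template:
        return "phi3"
    if "<|im_start|>" in jinja_template:
        return "chatml"
    return "unknown"

print(guess_template("{% for message in messages %}<|user|>\n{{ message['content'] }}<|end|>\n<|assistant|>\n{% endfor %}"))
# -> phi3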

@linxihui

> Will llama.cpp work with blocksparse attention? These models seem to implement it.
>
> Could you remind me what blocksparse attention was?
>
> Edit: No, there is no API for that atm.

@ggerganov thanks for your interest in supporting phi-3-small.

I am the author of the blocksparse attention in phi-3-small. I am not very familiar with ollama, but I can help explain the details.

The kernel is implemented in Triton, but you can find the code that generates the dense version of the attention mask here:
https://github.com/linxihui/vllm/blob/eb16d9a382f273c3ed62e4264a42a24f6ba53568/vllm/attention/ops/blocksparse_attention/utils.py#L187C1-L210C57

There is also a vllm paged attention version:
https://github.com/linxihui/vllm/blob/main/csrc/attention/attention_kernels.cu#L224-L236

I tested other models with ollama on my Mac; it is super responsive and cool. I hope I can have our phi-3-small model on my Mac as well!

@ggerganov
Owner

@linxihui Thanks for the information. Is my understanding correct that the vllm implementation skips non-attended blocks (i.e. full of -inf) which makes the computation faster? Do you have an estimate of the performance gain if that is the case? This does not lead to lower memory usage, correct?

If my understanding is correct, then I think we can easily support this in the Metal backend as the block-skipping logic is already there

@linxihui

linxihui commented May 23, 2024

> @linxihui Thanks for the information. Is my understanding correct that the vllm implementation skips non-attended blocks (i.e. full of -inf) which makes the computation faster? Do you have an estimate of the performance gain if that is the case? This does not lead to lower memory usage, correct?
>
> If my understanding is correct, then I think we can easily support this in the Metal backend as the block-skipping logic is already there

@ggerganov

  • The current implementation in vllm doesn't improve memory, that's correct. This is because the block table in vllm isn't per head (it is per token). Our attention may attend to all tokens (smaller vertical stride) but on different heads. To save memory we'd need to change a very big portion of core vllm, which we don't have the time for. But theoretically, you could save memory with a proper implementation.

  • Latency/throughput-wise it is faster, as it skips blocks both in prefill and decoding. The theoretical FLOPs are 1/vert_stride of dense attention at very large lengths. In the vllm paged attention we skip the compute of qk and attn*v; the filling of -inf is only for the normalization of the logits, so it should be faster as well. The end-to-end benefit can only be observed at large lengths, when attention occupies a big proportion of the time compared to other ops and the loading of model weights. E.g., if the model were entirely blocksparse attention, it could be more than 4x faster end-to-end in prefill at 100k context length.

  • Yes, the logic should be easy to implement in the existing prefill and decoding code. But make sure you pay close attention to the head sliding part.
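A rough numpy sketch of the block-level mask pattern described above (local blocks plus every vert_stride-th earlier block, with a per-head offset for the head sliding). Parameter names and the exact offset convention are my guesses; the real mask generation is in the vllm code linked above:

import numpy as np

def block_mask(n_blocks, local_blocks, vert_stride, head_idx=0, homo_head=False):
    # True where a query block attends a key block
    q = np.arange(n_blocks)[:, None]
    k = np.arange(n_blocks)[None, :]
    causal = k <= q
    local = (q - k) < local_blocks                       # the most recent blocks
    offset = 0 if homo_head else head_idx % vert_stride  # "head sliding"
    strided = (k + offset) % vert_stride == 0            # every vert_stride-th block
    return causal & (local | strided)

# At long context roughly 1/vert_stride of the dense blocks survive,
# matching the FLOP estimate given above.
print(block_mask(n_blocks=16, local_blocks=2, vert_stride=8, head_idx=3).astype(int))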

@0wwafa

0wwafa commented Jun 30, 2024

still can't convert phi-3-small :( Phi3SmallForCausalLM unsupported :(

@0wwafa

0wwafa commented Jun 30, 2024

@github-actions
Contributor

This issue was closed because it has been inactive for 14 days since being marked as stale.

@JamesClarke7283

Please, can you reopen this? We need Phi-3 small.

@vladfaust

I agree; 4k context is simply not enough.

Galunid removed the stale label on Aug 20, 2024
Galunid reopened this on Aug 20, 2024
@vladfaust

vladfaust commented Aug 21, 2024

FYI: Microsoft has just released the Phi-3.5 models, with the mini version having 128k context. See https://huggingface.co/collections/microsoft/phi-3-6626e15e9585a200d2d761e3. It doesn't have GGUF quants yet, because... because of this issue. Let's get to it! 💪

Edit: Just tested https://huggingface.co/bartowski/Phi-3.5-mini-instruct-GGUF with a context size of 8192; it works well, which fits my use case.

github-actions bot added the stale label on Sep 21, 2024
Contributor

github-actions bot commented Oct 6, 2024

This issue was closed because it has been inactive for 14 days since being marked as stale.

github-actions bot closed this as completed on Oct 6, 2024