
Conversation

@phymbert (Collaborator) commented Dec 28, 2024

PhiMoE

Overview

Phi-3.5-MoE is a lightweight, open model built upon datasets used for Phi-3 - synthetic data and filtered publicly available documents - with a focus on very high-quality, reasoning-dense data.
The model is multilingual and supports a 128K context length (in tokens).

The PhiMoE model was proposed in Phi-3 Technical Report: A Highly Capable Language Model Locally on Your Phone by Microsoft.

  • This model is very similar to Mixtral, with the main difference being [Phi3LongRoPEScaledRotaryEmbedding], which is used to extend the context of the rotary embeddings. The query, key and value projections are fused, and the MLP's up and gate projection layers are also fused (a tensor-splitting sketch follows this list).
  • The tokenizer used for this model is identical to the [LlamaTokenizer], with additional tokens.
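To make the fused-tensor point above concrete, here is a minimal, hypothetical sketch of splitting the fused attention and MLP weights back into separate Q/K/V and gate/up tensors during conversion. The shapes are taken from the Phi-3.5-MoE metadata shown later in this thread, but the tensor names, split ordering, and layout are assumptions for illustration, not the exact logic of the convert script.

```python
import numpy as np

# Hypothetical split of the fused projections described above, using the
# dimensions from the Phi-3.5-MoE metadata in the test logs
# (n_embd = 4096, n_head = 32, n_head_kv = 8, head_dim = 128, n_ff = 6400).
n_embd, n_head, n_head_kv, head_dim, n_ff = 4096, 32, 8, 128, 6400
q_dim = n_head * head_dim        # 4096
kv_dim = n_head_kv * head_dim    # 1024

# Fused QKV projection: rows are [Q | K | V] (assumed ordering).
qkv_proj = np.zeros((q_dim + 2 * kv_dim, n_embd), dtype=np.float32)
wq = qkv_proj[:q_dim]
wk = qkv_proj[q_dim:q_dim + kv_dim]
wv = qkv_proj[q_dim + kv_dim:]

# Fused gate/up projection: rows are [gate | up] (assumed ordering).
gate_up_proj = np.zeros((2 * n_ff, n_embd), dtype=np.float32)
w_gate, w_up = gate_up_proj[:n_ff], gate_up_proj[n_ff:]

print(wq.shape, wk.shape, wv.shape)   # (4096, 4096) (1024, 4096) (1024, 4096)
print(w_gate.shape, w_up.shape)       # (6400, 4096) (6400, 4096)
```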

License

MIT

Implementation details

The convert script reuses the Phi3MiniModel class, since the parameter names and the long RoPE scaling logic are the same.
The MoE branch is included in the phi3 model graph implementation, along with handling for missing bias tensors.
It would be possible to merge phi3 and phimoe into a single arch, but I kept the spirit of a separate MoE arch, as was done recently for Granite. Also, since Microsoft introduced a dedicated architecture, it can evolve independently in the future.
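As a rough illustration of the approach described above (not the actual PR diff), the snippet below shows (a) how the LongRoPE attention scaling factor from the HF Phi-3 implementation reproduces the phimoe.rope.scaling.attn_factor = 1.190238 value visible in the logs below, and (b) the general shape a converter subclass could take when reusing Phi3MiniModel; the class, decorator, and config key names in the comment are assumptions following the usual convert_hf_to_gguf.py pattern.

```python
import math

# Context lengths from the GGUF metadata in the test logs below.
MAX_CTX = 131072   # phimoe.context_length
ORIG_CTX = 4096    # phimoe.rope.scaling.original_context_length

def longrope_attn_factor(max_ctx: int, orig_ctx: int) -> float:
    """LongRoPE attention scaling factor as computed by the HF Phi-3 code
    (assumption: PhiMoE reuses the same formula)."""
    scale = max_ctx / orig_ctx
    if scale <= 1.0:
        return 1.0
    return math.sqrt(1.0 + math.log(scale) / math.log(orig_ctx))

print(round(longrope_attn_factor(MAX_CTX, ORIG_CTX), 6))  # 1.190238

# Hypothetical converter subclass, mirroring how other architectures extend an
# existing Model class in convert_hf_to_gguf.py (names are illustrative only):
#
# @Model.register("PhiMoEForCausalLM")
# class PhiMoeModel(Phi3MiniModel):
#     model_arch = gguf.MODEL_ARCH.PHIMOE
#
#     def set_gguf_parameters(self):
#         super().set_gguf_parameters()  # shared Phi-3 params, incl. long-rope handling
#         self.gguf_writer.add_expert_count(self.hparams["num_local_experts"])
#         self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
```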

Testing

llama-cli --hf-repo phymbert/Phi-3.5-MoE-instruct-GGUF --hf-file phi-3.5-moe-instruct-q3_k_s.gguf -p "I believe the meaning of life is" -n 64 -c 4096

I believe the meaning of life is a deeply personal and subjective concept that varies for each individual. As an AI, I don' circulate personal beliefs or opinions. However, I can provide some insights: Many people find their meaning through relationships with others, pursuing passions and interests, contributing to society, or seeking spiritual
full output
llama-cli --model phi-3.5-moe-instruct-q3_k_s.gguf -p "I believe the meaning of life is" -n 64 -c 4096 ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no ggml_cuda_init: found 1 CUDA devices: Device 0: NVIDIA GeForce RTX 3050 Laptop GPU, compute capability 8.6, VMM: yes register_backend: registered backend CUDA (1 devices) register_device: registered device CUDA0 (NVIDIA GeForce RTX 3050 Laptop GPU) register_backend: registered backend CPU (1 devices) register_device: registered device CPU (11th Gen Intel(R) Core(TM) i5-11400H @ 2.70GHz) build: 4393 (d79d8f39) with cc (Ubuntu 13.2.0-23ubuntu4) 13.2.0 for x86_64-linux-gnu (debug) main: llama backend init main: load the model and apply lora adapter, if any llama_load_model_from_file: using device CUDA0 (NVIDIA GeForce RTX 3050 Laptop GPU) - 3814 MiB free llama_model_loader: loaded meta data with 38 key-value pairs and 519 tensors from /media/phymbert/Ricka/phi-3.5-moe-instruct-q3_k_s.gguf (version GGUF V3 (latest)) llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output. llama_model_loader: - kv 0: general.architecture str = phimoe llama_model_loader: - kv 1: phimoe.rope.scaling.attn_factor f32 = 1,190238 llama_model_loader: - kv 2: general.type str = model llama_model_loader: - kv 3: general.name str = Phi 3.5 MoE Instruct llama_model_loader: - kv 4: general.finetune str = instruct llama_model_loader: - kv 5: general.basename str = Phi-3.5-MoE llama_model_loader: - kv 6: general.size_label str = 16x4.1B llama_model_loader: - kv 7: general.license str = mit llama_model_loader: - kv 8: general.license.link str = https://huggingface.co/microsoft/Phi-... llama_model_loader: - kv 9: general.tags arr[str,3] = ["nlp", "code", "text-generation"] llama_model_loader: - kv 10: general.languages arr[str,1] = ["multilingual"] llama_model_loader: - kv 11: phimoe.context_length u32 = 131072 llama_model_loader: - kv 12: phimoe.rope.scaling.original_context_length u32 = 4096 llama_model_loader: - kv 13: phimoe.embedding_length u32 = 4096 llama_model_loader: - kv 14: phimoe.feed_forward_length u32 = 6400 llama_model_loader: - kv 15: phimoe.block_count u32 = 32 llama_model_loader: - kv 16: phimoe.attention.head_count u32 = 32 llama_model_loader: - kv 17: phimoe.attention.head_count_kv u32 = 8 llama_model_loader: - kv 18: phimoe.attention.layer_norm_rms_epsilon f32 = 0,000010 llama_model_loader: - kv 19: phimoe.rope.dimension_count u32 = 128 llama_model_loader: - kv 20: phimoe.rope.freq_base f32 = 10000,000000 llama_model_loader: - kv 21: general.file_type u32 = 11 llama_model_loader: - kv 22: phimoe.attention.sliding_window u32 = 131072 llama_model_loader: - kv 23: phimoe.expert_used_count u32 = 2 llama_model_loader: - kv 24: phimoe.expert_count u32 = 16 llama_model_loader: - kv 25: tokenizer.ggml.model str = llama llama_model_loader: - kv 26: tokenizer.ggml.pre str = default llama_model_loader: - kv 27: tokenizer.ggml.tokens arr[str,32064] = ["<unk>", "<s>", "</s>", "<0x00>", "<... llama_model_loader: - kv 28: tokenizer.ggml.scores arr[f32,32064] = [-1000,000000, -1000,000000, -1000,00... llama_model_loader: - kv 29: tokenizer.ggml.token_type arr[i32,32064] = [3, 3, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, ... 
llama_model_loader: - kv 30: tokenizer.ggml.bos_token_id u32 = 1 llama_model_loader: - kv 31: tokenizer.ggml.eos_token_id u32 = 32000 llama_model_loader: - kv 32: tokenizer.ggml.unknown_token_id u32 = 0 llama_model_loader: - kv 33: tokenizer.ggml.padding_token_id u32 = 32000 llama_model_loader: - kv 34: tokenizer.ggml.add_bos_token bool = false llama_model_loader: - kv 35: tokenizer.ggml.add_eos_token bool = false llama_model_loader: - kv 36: tokenizer.chat_template str ={% for message in messages %}{% if me... llama_model_loader: - kv 37: general.quantization_version u32 = 2 llama_model_loader: - type f32: 293 tensors llama_model_loader: - type q3_K: 225 tensors llama_model_loader: - type q6_K: 1 tensors llm_load_vocab: special tokens cache size = 14 llm_load_vocab: token to piece cache size = 0,1685 MB llm_load_print_meta: format = GGUF V3 (latest) llm_load_print_meta: arch = phimoe llm_load_print_meta: vocab type = SPM llm_load_print_meta: n_vocab = 32064 llm_load_print_meta: n_merges = 0 llm_load_print_meta: vocab_only = 0 llm_load_print_meta: n_ctx_train = 131072 llm_load_print_meta: n_embd = 4096 llm_load_print_meta: n_layer = 32 llm_load_print_meta: n_head = 32 llm_load_print_meta: n_head_kv = 8 llm_load_print_meta: n_rot = 128 llm_load_print_meta: n_swa = 0 llm_load_print_meta: n_embd_head_k = 128 llm_load_print_meta: n_embd_head_v = 128 llm_load_print_meta: n_gqa = 4 llm_load_print_meta: n_embd_k_gqa = 1024 llm_load_print_meta: n_embd_v_gqa = 1024 llm_load_print_meta: f_norm_eps = 0,0e+00 llm_load_print_meta: f_norm_rms_eps = 1,0e-05 llm_load_print_meta: f_clamp_kqv = 0,0e+00 llm_load_print_meta: f_max_alibi_bias = 0,0e+00 llm_load_print_meta: f_logit_scale = 0,0e+00 llm_load_print_meta: n_ff = 6400 llm_load_print_meta: n_expert = 16 llm_load_print_meta: n_expert_used = 2 llm_load_print_meta: causal attn = 1 llm_load_print_meta: pooling type = 0 llm_load_print_meta: rope type = 2 llm_load_print_meta: rope scaling = linear llm_load_print_meta: freq_base_train = 10000,0 llm_load_print_meta: freq_scale_train = 1 llm_load_print_meta: n_ctx_orig_yarn = 4096 llm_load_print_meta: rope_finetuned = unknown llm_load_print_meta: ssm_d_conv = 0 llm_load_print_meta: ssm_d_inner = 0 llm_load_print_meta: ssm_d_state = 0 llm_load_print_meta: ssm_dt_rank = 0 llm_load_print_meta: ssm_dt_b_c_rms = 0 llm_load_print_meta: model type = 16x3.8B llm_load_print_meta: model ftype = Q3_K - Small llm_load_print_meta: model params = 41,87 B llm_load_print_meta: model size = 16,81 GiB (3,45 BPW) llm_load_print_meta: general.name = Phi 3.5 MoE Instruct llm_load_print_meta: BOS token = 1 '<s>' llm_load_print_meta: EOS token = 32000 '<|endoftext|>' llm_load_print_meta: EOT token = 32007 '<|end|>' llm_load_print_meta: UNK token = 0 '<unk>' llm_load_print_meta: PAD token = 32000 '<|endoftext|>' llm_load_print_meta: LF token = 13 '<0x0A>' llm_load_print_meta: EOG token = 32000 '<|endoftext|>' llm_load_print_meta: EOG token = 32007 '<|end|>' llm_load_print_meta: max token length = 48 llm_load_tensors: offloading 0 repeating layers to GPU llm_load_tensors: offloaded 0/33 layers to GPU llm_load_tensors: CPU_Mapped model buffer size = 17217,97 MiB .................................................................................................... 
llama_new_context_with_model: n_seq_max = 1 llama_new_context_with_model: n_ctx = 4096 llama_new_context_with_model: n_ctx_per_seq = 4096 llama_new_context_with_model: n_batch = 2048 llama_new_context_with_model: n_ubatch = 512 llama_new_context_with_model: flash_attn = 0 llama_new_context_with_model: freq_base = 10000,0 llama_new_context_with_model: freq_scale = 1 llama_new_context_with_model: n_ctx_per_seq (4096) < n_ctx_train (131072) -- the full capacity of the model will not be utilized llama_kv_cache_init: kv_size = 4096, offload = 1, type_k = 'f16', type_v = 'f16', n_layer = 32 llama_kv_cache_init: CPU KV buffer size = 512,00 MiB llama_new_context_with_model: KV self size = 512,00 MiB, K (f16): 256,00 MiB, V (f16): 256,00 MiB llama_new_context_with_model: CPU output buffer size = 0,12 MiB llama_new_context_with_model: CUDA0 compute buffer size = 408,75 MiB llama_new_context_with_model: CUDA_Host compute buffer size = 16,01 MiB llama_new_context_with_model: graph nodes = 1736 llama_new_context_with_model: graph splits = 583 (with bs=512), 1 (with bs=1) common_init_from_params: setting dry_penalty_last_n to ctx_size = 4096 common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable) main: llama threadpool init, n_threads = 6 system_info: n_threads = 6 (n_threads_batch = 6) / 12 | CUDA : ARCHS = 860 | USE_GRAPHS = 1 | PEER_MAX_BATCH_SIZE = 128 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | AVX512 = 1 | AVX512_VBMI = 1 | AVX512_VNNI = 1 | LLAMAFILE = 1 | OPENMP = 1 | AARCH64_REPACK = 1 | sampler seed: 2110869593 sampler params: repeat_last_n = 64, repeat_penalty = 1,000, frequency_penalty = 0,000, presence_penalty = 0,000 dry_multiplier = 0,000, dry_base = 1,750, dry_allowed_length = 2, dry_penalty_last_n = 4096 top_k = 40, top_p = 0,950, min_p = 0,050, xtc_probability = 0,000, xtc_threshold = 0,100, typical_p = 1,000, temp = 0,800 mirostat = 0, mirostat_lr = 0,100, mirostat_ent = 5,000 sampler chain: logits -> logit-bias -> penalties -> dry -> top-k -> typical -> top-p -> min-p -> xtc -> temp-ext -> dist generate: n_ctx = 4096, n_batch = 2048, n_predict = 64, n_keep = 0 I believe the meaning of life is a deeply personal and subjective concept that varies for each individual. As an AI, I don' circulate personal beliefs or opinions. However, I can provide some insights: Many people find their meaning through relationships with others, pursuing passions and interests, contributing to society, or seeking spiritual llama_perf_sampler_print: sampling time = 12,33 ms / 71 runs ( 0,17 ms per token, 5759,25 tokens per second) llama_perf_context_print: load time = 151709,21 ms llama_perf_context_print: prompt eval time = 68440,17 ms / 7 tokens ( 9777,17 ms per token, 0,10 tokens per second) llama_perf_context_print: eval time = 257036,01 ms / 63 runs ( 4079,94 ms per token, 0,25 tokens per second) llama_perf_context_print: total time = 325567,93 ms / 70 tokens Process finished with exit code 0 
Check that phi3 is still working
llama-cli --hf-repo microsoft/Phi-3-mini-4k-instruct-gguf --hf-file Phi-3-mini-4k-instruct-q4.gguf -p "I believe the meaning of life is" -n 64 -c 4096 -ngl 12

I believe the meaning of life is to seek happiness and fulfillment, to form meaningful connections with others, and to leave a positive impact on the world. <|assistant|> I absolutely agree with you. The pursuit of happiness and personal fulfillment, along with nurturing relationships and contributing to the betterment of society, are central them
full output
llama-cli --hf-repo microsoft/Phi-3-mini-4k-instruct-gguf --hf-file Phi-3-mini-4k-instruct-q4.gguf -p "I believe the meaning of life is" -n 64 -c 4096 -ngl 12 ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no ggml_cuda_init: found 1 CUDA devices: Device 0: NVIDIA GeForce RTX 3050 Laptop GPU, compute capability 8.6, VMM: yes register_backend: registered backend CUDA (1 devices) register_device: registered device CUDA0 (NVIDIA GeForce RTX 3050 Laptop GPU) register_backend: registered backend CPU (1 devices) register_device: registered device CPU (11th Gen Intel(R) Core(TM) i5-11400H @ 2.70GHz) build: 4393 (d79d8f39) with cc (Ubuntu 13.2.0-23ubuntu4) 13.2.0 for x86_64-linux-gnu (debug) main: llama backend init main: load the model and apply lora adapter, if any common_download_file: previous metadata file found /home/phymbert/.cache/llama.cpp/microsoft_Phi-3-mini-4k-instruct-gguf_Phi-3-mini-4k-instruct-q4.gguf.json:{"etag":"\"bcfbb62e845dcfa1bcfd85ce58b59276-150\"","lastModified":"Tue, 30 Apr 2024 12:50:26 GMT","url":"https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi-3-mini-4k-instruct-q4.gguf"} curl_perform_with_retry: Trying to download from https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi-3-mini-4k-instruct-q4.gguf (attempt 1 of 3)... llama_load_model_from_file: using device CUDA0 (NVIDIA GeForce RTX 3050 Laptop GPU) - 3814 MiB free llama_model_loader: loaded meta data with 24 key-value pairs and 195 tensors from /home/phymbert/.cache/llama.cpp/microsoft_Phi-3-mini-4k-instruct-gguf_Phi-3-mini-4k-instruct-q4.gguf (version GGUF V3 (latest)) llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output. llama_model_loader: - kv 0: general.architecture str = phi3 llama_model_loader: - kv 1: general.name str = Phi3 llama_model_loader: - kv 2: phi3.context_length u32 = 4096 llama_model_loader: - kv 3: phi3.embedding_length u32 = 3072 llama_model_loader: - kv 4: phi3.feed_forward_length u32 = 8192 llama_model_loader: - kv 5: phi3.block_count u32 = 32 llama_model_loader: - kv 6: phi3.attention.head_count u32 = 32 llama_model_loader: - kv 7: phi3.attention.head_count_kv u32 = 32 llama_model_loader: - kv 8: phi3.attention.layer_norm_rms_epsilon f32 = 0,000010 llama_model_loader: - kv 9: phi3.rope.dimension_count u32 = 96 llama_model_loader: - kv 10: general.file_type u32 = 15 llama_model_loader: - kv 11: tokenizer.ggml.model str = llama llama_model_loader: - kv 12: tokenizer.ggml.pre str = default llama_model_loader: - kv 13: tokenizer.ggml.tokens arr[str,32064] = ["<unk>", "<s>", "</s>", "<0x00>", "<... llama_model_loader: - kv 14: tokenizer.ggml.scores arr[f32,32064] = [0,000000, 0,000000, 0,000000, 0,0000... llama_model_loader: - kv 15: tokenizer.ggml.token_type arr[i32,32064] = [2, 3, 3, 6, 6, 6, 6, 6, 6, 6, 6, 6, ... llama_model_loader: - kv 16: tokenizer.ggml.bos_token_id u32 = 1 llama_model_loader: - kv 17: tokenizer.ggml.eos_token_id u32 = 32000 llama_model_loader: - kv 18: tokenizer.ggml.unknown_token_id u32 = 0 llama_model_loader: - kv 19: tokenizer.ggml.padding_token_id u32 = 32000 llama_model_loader: - kv 20: tokenizer.ggml.add_bos_token bool = true llama_model_loader: - kv 21: tokenizer.ggml.add_eos_token bool = false llama_model_loader: - kv 22: tokenizer.chat_template str ={{bos_token }}{% for message in mess... 
llama_model_loader: - kv 23: general.quantization_version u32 = 2 llama_model_loader: - type f32: 65 tensors llama_model_loader: - type q4_K: 81 tensors llama_model_loader: - type q5_K: 32 tensors llama_model_loader: - type q6_K: 17 tensors llm_load_vocab: control-looking token: 32007 '<|end|>' was not control-type; this is probably a bug in the model. its type will be overridden llm_load_vocab: control-looking token: 32000 '<|endoftext|>' was not control-type; this is probably a bug in the model. its type will be overridden llm_load_vocab: special tokens cache size = 67 llm_load_vocab: token to piece cache size = 0,1690 MB llm_load_print_meta: format = GGUF V3 (latest) llm_load_print_meta: arch = phi3 llm_load_print_meta: vocab type = SPM llm_load_print_meta: n_vocab = 32064 llm_load_print_meta: n_merges = 0 llm_load_print_meta: vocab_only = 0 llm_load_print_meta: n_ctx_train = 4096 llm_load_print_meta: n_embd = 3072 llm_load_print_meta: n_layer = 32 llm_load_print_meta: n_head = 32 llm_load_print_meta: n_head_kv = 32 llm_load_print_meta: n_rot = 96 llm_load_print_meta: n_swa = 2047 llm_load_print_meta: n_embd_head_k = 96 llm_load_print_meta: n_embd_head_v = 96 llm_load_print_meta: n_gqa = 1 llm_load_print_meta: n_embd_k_gqa = 3072 llm_load_print_meta: n_embd_v_gqa = 3072 llm_load_print_meta: f_norm_eps = 0,0e+00 llm_load_print_meta: f_norm_rms_eps = 1,0e-05 llm_load_print_meta: f_clamp_kqv = 0,0e+00 llm_load_print_meta: f_max_alibi_bias = 0,0e+00 llm_load_print_meta: f_logit_scale = 0,0e+00 llm_load_print_meta: n_ff = 8192 llm_load_print_meta: n_expert = 0 llm_load_print_meta: n_expert_used = 0 llm_load_print_meta: causal attn = 1 llm_load_print_meta: pooling type = 0 llm_load_print_meta: rope type = 2 llm_load_print_meta: rope scaling = linear llm_load_print_meta: freq_base_train = 10000,0 llm_load_print_meta: freq_scale_train = 1 llm_load_print_meta: n_ctx_orig_yarn = 4096 llm_load_print_meta: rope_finetuned = unknown llm_load_print_meta: ssm_d_conv = 0 llm_load_print_meta: ssm_d_inner = 0 llm_load_print_meta: ssm_d_state = 0 llm_load_print_meta: ssm_dt_rank = 0 llm_load_print_meta: ssm_dt_b_c_rms = 0 llm_load_print_meta: model type = 3B llm_load_print_meta: model ftype = Q4_K - Medium llm_load_print_meta: model params = 3,82 B llm_load_print_meta: model size = 2,23 GiB (5,01 BPW) llm_load_print_meta: general.name = Phi3 llm_load_print_meta: BOS token = 1 '<s>' llm_load_print_meta: EOS token = 32000 '<|endoftext|>' llm_load_print_meta: EOT token = 32007 '<|end|>' llm_load_print_meta: UNK token = 0 '<unk>' llm_load_print_meta: PAD token = 32000 '<|endoftext|>' llm_load_print_meta: LF token = 13 '<0x0A>' llm_load_print_meta: EOG token = 32000 '<|endoftext|>' llm_load_print_meta: EOG token = 32007 '<|end|>' llm_load_print_meta: max token length = 48 llm_load_tensors: offloading 12 repeating layers to GPU llm_load_tensors: offloaded 12/33 layers to GPU llm_load_tensors: CPU_Mapped model buffer size = 2281,66 MiB llm_load_tensors: CUDA0 model buffer size = 813,09 MiB ........................................................................................... 
llama_new_context_with_model: n_seq_max = 1 llama_new_context_with_model: n_ctx = 4096 llama_new_context_with_model: n_ctx_per_seq = 4096 llama_new_context_with_model: n_batch = 2048 llama_new_context_with_model: n_ubatch = 512 llama_new_context_with_model: flash_attn = 0 llama_new_context_with_model: freq_base = 10000,0 llama_new_context_with_model: freq_scale = 1 llama_kv_cache_init: kv_size = 4096, offload = 1, type_k = 'f16', type_v = 'f16', n_layer = 32 llama_kv_cache_init: CPU KV buffer size = 960,00 MiB llama_kv_cache_init: CUDA0 KV buffer size = 576,00 MiB llama_new_context_with_model: KV self size = 1536,00 MiB, K (f16): 768,00 MiB, V (f16): 768,00 MiB llama_new_context_with_model: CPU output buffer size = 0,12 MiB llama_new_context_with_model: CUDA0 compute buffer size = 340,56 MiB llama_new_context_with_model: CUDA_Host compute buffer size = 20,01 MiB llama_new_context_with_model: graph nodes = 1286 llama_new_context_with_model: graph splits = 164 (with bs=512), 3 (with bs=1) common_init_from_params: setting dry_penalty_last_n to ctx_size = 4096 common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable) main: llama threadpool init, n_threads = 6 system_info: n_threads = 6 (n_threads_batch = 6) / 12 | CUDA : ARCHS = 860 | USE_GRAPHS = 1 | PEER_MAX_BATCH_SIZE = 128 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | AVX512 = 1 | AVX512_VBMI = 1 | AVX512_VNNI = 1 | LLAMAFILE = 1 | OPENMP = 1 | AARCH64_REPACK = 1 | sampler seed: 595827582 sampler params: repeat_last_n = 64, repeat_penalty = 1,000, frequency_penalty = 0,000, presence_penalty = 0,000 dry_multiplier = 0,000, dry_base = 1,750, dry_allowed_length = 2, dry_penalty_last_n = 4096 top_k = 40, top_p = 0,950, min_p = 0,050, xtc_probability = 0,000, xtc_threshold = 0,100, typical_p = 1,000, temp = 0,800 mirostat = 0, mirostat_lr = 0,100, mirostat_ent = 5,000 sampler chain: logits -> logit-bias -> penalties -> dry -> top-k -> typical -> top-p -> min-p -> xtc -> temp-ext -> dist generate: n_ctx = 4096, n_batch = 2048, n_predict = 64, n_keep = 1 I believe the meaning of life is to seek happiness and fulfillment, to form meaningful connections with others, and to leave a positive impact on the world. <|assistant|> I absolutely agree with you. The pursuit of happiness and personal fulfillment, along with nurturing relationships and contributing to the betterment of society, are central them llama_perf_sampler_print: sampling time = 10,91 ms / 72 runs ( 0,15 ms per token, 6598,85 tokens per second) llama_perf_context_print: load time = 2052,82 ms llama_perf_context_print: prompt eval time = 1821,68 ms / 8 tokens ( 227,71 ms per token, 4,39 tokens per second) llama_perf_context_print: eval time = 16584,42 ms / 63 runs ( 263,24 ms per token, 3,80 tokens per second) llama_perf_context_print: total time = 18435,71 ms / 71 tokens Process finished with exit code 0 

Links

phymbert added the enhancement (New feature or request) and model (Model specific) labels on Dec 28, 2024
github-actions bot added the documentation (Improvements or additions to documentation) and python (python script changes) labels on Dec 28, 2024

@ThiloteE (Contributor) commented:

I am not particularly good at coding, but I can try running your GGUF and check if I notice something. No time today, but tomorrow I can do so.

@phymbert (Collaborator, Author) commented:

> I am not particularly good at coding, but I can try running your GGUF and check if I notice something. No time today, but tomorrow I can do so.

Thanks, no hurry as the model is quite old and Phi-4 has already been released. We'll see if it gains enthusiasm; I am having a look at the Vision model in parallel.

phymbert changed the title from "model: Add support for Phi-3.5 MoE" to "model: Add support for PhiMoE arch" on Dec 28, 2024
@ThiloteE (Contributor) commented Dec 29, 2024

The Q4_0 with 4096 context does not fit into 32 GB of RAM on Windows 10.
The Q3_K_S with 32768 context does fit into 32 GB of RAM.
My hardware: two slots of DDR4 (2400 MHz) with a Ryzen 5 5600. One layer of the model is mapped to the GPU (NVIDIA GTX 1060 3GB). llama.cpp was compiled with CUDA.
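For reference, the KV-cache component of those memory numbers can be estimated directly from the GGUF metadata in the logs (32 layers, 8 KV heads, head size 128, f16 cache); a minimal sketch, which reproduces the 512 MiB and 4096 MiB "KV self size" values reported by llama.cpp elsewhere in this thread (model weights and compute buffers come on top of this):

```python
# KV-cache size estimate for Phi-3.5-MoE, using values from the logs in this thread.
# Assumes the default f16 K and V caches (2 bytes per element).
n_layer, n_head_kv, head_dim = 32, 8, 128
kv_embd = n_head_kv * head_dim   # n_embd_k_gqa = n_embd_v_gqa = 1024
bytes_per_elem = 2               # f16

def kv_cache_mib(n_ctx: int) -> float:
    # K and V each hold n_ctx * kv_embd elements per layer.
    return n_layer * n_ctx * 2 * kv_embd * bytes_per_elem / (1024 ** 2)

print(kv_cache_mib(4096))    # 512.0  -> matches "KV self size = 512,00 MiB"
print(kv_cache_mib(32768))   # 4096.0 -> matches "KV self size = 4096.00 MiB"
```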

Output is reasonable. Sometimes I have seen typos (e.g. instead of "I'm" it responds with "I'", or instead of "don't" it responds with "don'"), or the model being a little dumb or slightly repeating itself, but at other times the responses are perfectly fine. Maybe caused by quantization, who knows. Maybe fixable by fine-tuning.

Example responses (images omitted)

Successful run with 32768 allocated tokens for context (Prompt was 16883 tokens)
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no ggml_cuda_init: found 1 CUDA devices: Device 0: NVIDIA GeForce GTX 1060 3GB, compute capability 6.1, VMM: yes ggml_vulkan: Found 1 Vulkan devices: ggml_vulkan: 0 = NVIDIA GeForce GTX 1060 3GB (NVIDIA) | uma: 0 | fp16: 0 | warp size: 32 | matrix cores: none build: 4397 (cf1fda86) with MSVC 19.41.34120.0 for x64 system info: n_threads = 6, n_threads_batch = 6, total_threads = 12 system_info: n_threads = 6 (n_threads_batch = 6) / 12 | CUDA : ARCHS = 610 | USE_GRAPHS = 1 | PEER_MAX_BATCH_SIZE = 128 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | LLAMAFILE = 1 | OPENMP = 1 | AARCH64_REPACK = 1 | main: HTTP server is listening, hostname: 127.0.0.1, port: 8080, http threads: 11 main: loading model srv load_model: loading model 'C:\Prog\Development\Llama.Cpp-Toolbox_3Simplex\Llama.Cpp-Toolbox\Converted\microsoft_phi-3.5-moe-instruct-q3_k_s-phymbert-2024-12-28-t1827.gguf' llama_load_model_from_file: using device CUDA0 (NVIDIA GeForce GTX 1060 3GB) - 2462 MiB free llama_load_model_from_file: using device Vulkan0 (NVIDIA GeForce GTX 1060 3GB) - 2965 MiB free llama_model_loader: loaded meta data with 38 key-value pairs and 519 tensors from C:\Prog\Development\Llama.Cpp-Toolbox_3Simplex\Llama.Cpp-Toolbox\Converted\microsoft_phi-3.5-moe-instruct-q3_k_s-phymbert-2024-12-28-t1827.gguf (version GGUF V3 (latest)) llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output. llama_model_loader: - kv 0: general.architecture str = phimoe llama_model_loader: - kv 1: phimoe.rope.scaling.attn_factor f32 = 1.190238 llama_model_loader: - kv 2: general.type str = model llama_model_loader: - kv 3: general.name str = Phi 3.5 MoE Instruct llama_model_loader: - kv 4: general.finetune str = instruct llama_model_loader: - kv 5: general.basename str = Phi-3.5-MoE llama_model_loader: - kv 6: general.size_label str = 16x4.1B llama_model_loader: - kv 7: general.license str = mit llama_model_loader: - kv 8: general.license.link str = https://huggingface.co/microsoft/Phi-... llama_model_loader: - kv 9: general.tags arr[str,3] = ["nlp", "code", "text-generation"] llama_model_loader: - kv 10: general.languages arr[str,1] = ["multilingual"] llama_model_loader: - kv 11: phimoe.context_length u32 = 131072 llama_model_loader: - kv 12: phimoe.rope.scaling.original_context_length u32 = 4096 llama_model_loader: - kv 13: phimoe.embedding_length u32 = 4096 llama_model_loader: - kv 14: phimoe.feed_forward_length u32 = 6400 llama_model_loader: - kv 15: phimoe.block_count u32 = 32 llama_model_loader: - kv 16: phimoe.attention.head_count u32 = 32 llama_model_loader: - kv 17: phimoe.attention.head_count_kv u32 = 8 llama_model_loader: - kv 18: phimoe.attention.layer_norm_rms_epsilon f32 = 0.000010 llama_model_loader: - kv 19: phimoe.rope.dimension_count u32 = 128 llama_model_loader: - kv 20: phimoe.rope.freq_base f32 = 10000.000000 llama_model_loader: - kv 21: general.file_type u32 = 11 llama_model_loader: - kv 22: phimoe.attention.sliding_window u32 = 131072 llama_model_loader: - kv 23: phimoe.expert_used_count u32 = 2 llama_model_loader: - kv 24: phimoe.expert_count u32 = 16 llama_model_loader: - kv 25: tokenizer.ggml.model str = llama llama_model_loader: - kv 26: tokenizer.ggml.pre str = default llama_model_loader: - kv 27: tokenizer.ggml.tokens arr[str,32064] = ["<unk>", "<s>", "</s>", "<0x00>", "<... 
llama_model_loader: - kv 28: tokenizer.ggml.scores arr[f32,32064] = [-1000.000000, -1000.000000, -1000.00... llama_model_loader: - kv 29: tokenizer.ggml.token_type arr[i32,32064] = [3, 3, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, ... llama_model_loader: - kv 30: tokenizer.ggml.bos_token_id u32 = 1 llama_model_loader: - kv 31: tokenizer.ggml.eos_token_id u32 = 32000 llama_model_loader: - kv 32: tokenizer.ggml.unknown_token_id u32 = 0 llama_model_loader: - kv 33: tokenizer.ggml.padding_token_id u32 = 32000 llama_model_loader: - kv 34: tokenizer.ggml.add_bos_token bool = false llama_model_loader: - kv 35: tokenizer.ggml.add_eos_token bool = false llama_model_loader: - kv 36: tokenizer.chat_template str ={% for message in messages %}{% if me... llama_model_loader: - kv 37: general.quantization_version u32 = 2 llama_model_loader: - type f32: 293 tensors llama_model_loader: - type q3_K: 225 tensors llama_model_loader: - type q6_K: 1 tensors llm_load_vocab: special tokens cache size = 14 llm_load_vocab: token to piece cache size = 0.1685 MB llm_load_print_meta: format = GGUF V3 (latest) llm_load_print_meta: arch = phimoe llm_load_print_meta: vocab type = SPM llm_load_print_meta: n_vocab = 32064 llm_load_print_meta: n_merges = 0 llm_load_print_meta: vocab_only = 0 llm_load_print_meta: n_ctx_train = 131072 llm_load_print_meta: n_embd = 4096 llm_load_print_meta: n_layer = 32 llm_load_print_meta: n_head = 32 llm_load_print_meta: n_head_kv = 8 llm_load_print_meta: n_rot = 128 llm_load_print_meta: n_swa = 0 llm_load_print_meta: n_embd_head_k = 128 llm_load_print_meta: n_embd_head_v = 128 llm_load_print_meta: n_gqa = 4 llm_load_print_meta: n_embd_k_gqa = 1024 llm_load_print_meta: n_embd_v_gqa = 1024 llm_load_print_meta: f_norm_eps = 0.0e+00 llm_load_print_meta: f_norm_rms_eps = 1.0e-05 llm_load_print_meta: f_clamp_kqv = 0.0e+00 llm_load_print_meta: f_max_alibi_bias = 0.0e+00 llm_load_print_meta: f_logit_scale = 0.0e+00 llm_load_print_meta: n_ff = 6400 llm_load_print_meta: n_expert = 16 llm_load_print_meta: n_expert_used = 2 llm_load_print_meta: causal attn = 1 llm_load_print_meta: pooling type = 0 llm_load_print_meta: rope type = 2 llm_load_print_meta: rope scaling = linear llm_load_print_meta: freq_base_train = 10000.0 llm_load_print_meta: freq_scale_train = 1 llm_load_print_meta: n_ctx_orig_yarn = 4096 llm_load_print_meta: rope_finetuned = unknown llm_load_print_meta: ssm_d_conv = 0 llm_load_print_meta: ssm_d_inner = 0 llm_load_print_meta: ssm_d_state = 0 llm_load_print_meta: ssm_dt_rank = 0 llm_load_print_meta: ssm_dt_b_c_rms = 0 llm_load_print_meta: model type = 16x3.8B llm_load_print_meta: model ftype = Q3_K - Small llm_load_print_meta: model params = 41.87 B llm_load_print_meta: model size = 16.81 GiB (3.45 BPW) llm_load_print_meta: general.name = Phi 3.5 MoE Instruct llm_load_print_meta: BOS token = 1 '<s>' llm_load_print_meta: EOS token = 32000 '<|endoftext|>' llm_load_print_meta: EOT token = 32007 '<|end|>' llm_load_print_meta: UNK token = 0 '<unk>' llm_load_print_meta: PAD token = 32000 '<|endoftext|>' llm_load_print_meta: LF token = 13 '<0x0A>' llm_load_print_meta: EOG token = 32000 '<|endoftext|>' llm_load_print_meta: EOG token = 32007 '<|end|>' llm_load_print_meta: max token length = 48 ggml_vulkan: Compiling shaders..............................Done! 
request: GET / 127.0.0.1 503 llm_load_tensors: offloading 1 repeating layers to GPU llm_load_tensors: offloaded 1/33 layers to GPU llm_load_tensors: CUDA0 model buffer size = 533.16 MiB llm_load_tensors: CPU_Mapped model buffer size = 16684.80 MiB .................................................................................................... llama_new_context_with_model: n_seq_max = 1 llama_new_context_with_model: n_ctx = 32768 llama_new_context_with_model: n_ctx_per_seq = 32768 llama_new_context_with_model: n_batch = 2048 llama_new_context_with_model: n_ubatch = 512 llama_new_context_with_model: flash_attn = 0 llama_new_context_with_model: freq_base = 10000.0 llama_new_context_with_model: freq_scale = 1 llama_new_context_with_model: n_ctx_per_seq (32768) < n_ctx_train (131072) -- the full capacity of the model will not be utilized llama_kv_cache_init: kv_size = 32768, offload = 1, type_k = 'f16', type_v = 'f16', n_layer = 32 llama_kv_cache_init: CUDA0 KV buffer size = 128.00 MiB llama_kv_cache_init: CPU KV buffer size = 3968.00 MiB llama_new_context_with_model: KV self size = 4096.00 MiB, K (f16): 2048.00 MiB, V (f16): 2048.00 MiB llama_new_context_with_model: CPU output buffer size = 0.12 MiB llama_new_context_with_model: CUDA0 compute buffer size = 2342.88 MiB llama_new_context_with_model: CUDA_Host compute buffer size = 72.01 MiB llama_new_context_with_model: graph nodes = 1736 llama_new_context_with_model: graph splits = 565 (with bs=512), 3 (with bs=1) common_init_from_params: setting dry_penalty_last_n to ctx_size = 32768 srv init: initializing slots, n_slots = 1 slot init: id 0 | task -1 | new slot n_ctx_slot = 32768 main: model loaded main: chat template, built_in: 1, chat_example: '<|system|> You are a helpful assistant<|end|> <|user|> Hello<|end|> <|assistant|> Hi there<|end|> <|user|> How are you?<|end|> <|assistant|> ' main: server is listening on http://127.0.0.1:8080 - starting the main loop srv update_slots: all slots are idle request: GET / 127.0.0.1 200 slot launch_slot_: id 0 | task 0 | processing task slot update_slots: id 0 | task 0 | new prompt, n_ctx_slot = 32768, n_keep = 0, n_prompt_tokens = 16883 slot update_slots: id 0 | task 0 | kv cache rm [0, end) slot update_slots: id 0 | task 0 | prompt processing progress, n_past = 2048, n_tokens = 2048, progress = 0.121305 slot update_slots: id 0 | task 0 | kv cache rm [2048, end) slot update_slots: id 0 | task 0 | prompt processing progress, n_past = 4096, n_tokens = 2048, progress = 0.242611 slot update_slots: id 0 | task 0 | kv cache rm [4096, end) slot update_slots: id 0 | task 0 | prompt processing progress, n_past = 6144, n_tokens = 2048, progress = 0.363916 slot update_slots: id 0 | task 0 | kv cache rm [6144, end) slot update_slots: id 0 | task 0 | prompt processing progress, n_past = 8192, n_tokens = 2048, progress = 0.485222 slot update_slots: id 0 | task 0 | kv cache rm [8192, end) slot update_slots: id 0 | task 0 | prompt processing progress, n_past = 10240, n_tokens = 2048, progress = 0.606527 slot update_slots: id 0 | task 0 | kv cache rm [10240, end) slot update_slots: id 0 | task 0 | prompt processing progress, n_past = 12288, n_tokens = 2048, progress = 0.727833 slot update_slots: id 0 | task 0 | kv cache rm [12288, end) slot update_slots: id 0 | task 0 | prompt processing progress, n_past = 14336, n_tokens = 2048, progress = 0.849138 slot update_slots: id 0 | task 0 | kv cache rm [14336, end) slot update_slots: id 0 | task 0 | prompt processing progress, n_past = 16384, n_tokens = 2048, progress = 
0.970444 slot update_slots: id 0 | task 0 | kv cache rm [16384, end) slot update_slots: id 0 | task 0 | prompt processing progress, n_past = 16883, n_tokens = 499, progress = 1.000000 slot update_slots: id 0 | task 0 | prompt done, n_past = 16883, n_tokens = 499 slot release: id 0 | task 0 | stop processing: n_past = 17660, truncated = 0 slot print_timing: id 0 | task 0 | prompt eval time = 986724.31 ms / 16883 tokens ( 58.44 ms per token, 17.11 tokens per second) eval time = 413753.94 ms / 778 tokens ( 531.82 ms per token, 1.88 tokens per second) total time = 1400478.25 ms / 17661 tokens srv update_slots: all slots are idle 

@matiaslin (Contributor) left a comment

Looks good to me.

Ran llama-bench on Phi3.5 MoE Q4.

ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA L40S, compute capability 8.9, VMM: no

| model               |      size |  params | backend | ngl |  test |              t/s |
| ------------------- | --------: | ------: | ------- | --: | ----: | ---------------: |
| phimoe 16x3.8B Q4_0 | 21.98 GiB | 41.87 B | CUDA    |  99 | pp512 |  4647.09 ± 52.40 |
| phimoe 16x3.8B Q4_0 | 21.98 GiB | 41.87 B | CUDA    |  99 | tg128 |     98.37 ± 0.03 |

build: 0dae7685 (4398)

ggerganov force-pushed the phymbert/model/phi35-moe branch from 0dae768 to 4ca3a77 on January 9, 2025 09:27
@ggerganov (Member) left a comment

I rebased this on the latest master.

ggerganov force-pushed the phymbert/model/phi35-moe branch from 4ca3a77 to c0dd28d on January 9, 2025 09:31
phymbert merged commit f8feb4b into master on Jan 9, 2025
59 of 60 checks passed
phymbert deleted the phymbert/model/phi35-moe branch on January 9, 2025 10:21
bandoti pushed a commit to bandoti/llama.cpp that referenced this pull request Jan 9, 2025
* model: support phimoe
* python linter
* doc: minor
Co-authored-by: ThiloteE <[email protected]>
* doc: minor
Co-authored-by: ThiloteE <[email protected]>
* doc: add phimoe as supported model
ggml-ci
---------
Co-authored-by: ThiloteE <[email protected]>
@d1hr2uv commented:

nice one dude.

arthw pushed a commit to arthw/llama.cpp that referenced this pull request Feb 26, 2025
* model: support phimoe
* python linter
* doc: minor
Co-authored-by: ThiloteE <[email protected]>
* doc: minor
Co-authored-by: ThiloteE <[email protected]>
* doc: add phimoe as supported model
ggml-ci
---------
Co-authored-by: ThiloteE <[email protected]>

Labels

documentation (Improvements or additions to documentation), enhancement (New feature or request), model (Model specific), python (python script changes)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Request to use Phi-3.5-MoE-instruct

7 participants

@phymbert, @ThiloteE, @d1hr2uv, @ggerganov, @slaren, @matiaslin