llama : add gpt-oss #15091
Conversation
ggerganov commented Aug 5, 2025 • edited
* llama : add attn sinks
* ggml : add attn sinks
* cuda : add attn sinks
* vulkan : add support for sinks in softmax, remove unnecessary return
* ggml : add fused swiglu_oai op (#11)
  * ggml : add fused swiglu_oai op
  * Update ggml/src/ggml-cpu/ops.cpp (Co-authored-by: Georgi Gerganov <[email protected]>)
  * update CUDA impl
  * cont : metal impl
  * add vulkan impl
  * test-backend-ops : more test cases, clean up
  * llama : remove unfused impl
  * remove extra lines

Co-authored-by: Georgi Gerganov <[email protected]>
Co-authored-by: slaren <[email protected]>
joseph777111 commented Aug 6, 2025
@ggerganov Thank you, ggerganov and everyone else for your expedient and awesome work! Will attention sinks be made available to all GGUF'd models? 🤔
CHNtentes commented Aug 6, 2025
Could anyone help explain why use
fat-tire commented Aug 6, 2025
@nachoal it appears there's a lot of new stuff here. At least to me -- but I have not used openai's API before, only local models.

There are two kinds of system prompts -- a "system message" and a "developer message". Also there are two types of tools -- "builtin_tools" (python or browser tools) referenced in the system message, and function tools (described in the developer message). There is a special format for describing the function tools, but I'm guessing MCP would work too. The function tools are called in a separate "commentary" channel, distinct from the normal reply, so different types of output appear in different places in the chat completion.

As an example, instead of the tool call being parsed out for you (the way reasoning already is via `reasoning_text = response.choices[0].message.reasoning_content`, where `response = client.chat.completions.create( ... )`), it looks like right now in llama.cpp, by default, when an assistant tries to use a tool the raw content contains something like:

`<|start|>assistant<|channel|>commentary to=fetch json<|message|>{"url":"https://www.github.com","max_length":5000}`

The expected format is supposed to come after the reasoning like this:

`<|start|>assistant<|channel|>commentary to=functions.get_weather <|constrain|>json<|message|>{"location":"San Francisco"}<|call|>`

So the output looks very close but not exactly right from what I am seeing. It's missing the `<|call|>` token. I'm sure that in the near future a tool call will be fully extracted by llama.cpp and put in the proper place in the response. The Harmony docs also describe "Output by the model which can either be a tool call or a message output." -- so apparently you can get a message OR a tool call, but apparently not both.

The hacky temporary workaround to this bug, to maintain compatibility with other models, would be to come up with a regex expression you could use to pull the JSON tool name and arguments/{output} from the raw content. There's a note in this PR that the tool template stuff is a WIP and tool use is still to come, so I guess it may make the most sense to just wait for this to get fixed unless you're really itching to get tools working.

Anyone that knows more please correct me as I'm just figuring this out myself!
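To make the channel layout described above concrete, here is a minimal Python sketch (not from the PR; it only assumes the Harmony-style markers quoted in this thread) that splits raw output into (channel, message) pairs:

```python
import re
from typing import List, Tuple

def split_channels(raw: str) -> List[Tuple[str, str]]:
    """Split Harmony-style raw output into (channel, message) pairs.

    Assumes segments shaped like the examples quoted in this thread:
        <|start|>assistant<|channel|>NAME ...<|message|>BODY<|call|>
    The closing <|call|>/<|end|> token is treated as optional, since current
    llama.cpp output may omit it.
    """
    pattern = re.compile(
        r"<\|start\|>assistant<\|channel\|>(\w+).*?"
        r"<\|message\|>(.*?)(?:<\|call\|>|<\|end\|>|$)",
        re.DOTALL,
    )
    return [(m.group(1), m.group(2)) for m in pattern.finditer(raw)]

raw = (
    '<|start|>assistant<|channel|>commentary to=functions.get_weather '
    '<|constrain|>json<|message|>{"location":"San Francisco"}<|call|>'
)
print(split_channels(raw))
# [('commentary', '{"location":"San Francisco"}')]
```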
nachoal commented Aug 6, 2025
@fat-tire Appreciate the complete explanation 🙏, I ended up just parsing the strings to try tool calling for now, a bit broken but it works. Thanks!
nai-kon commented Aug 6, 2025
Has `reasoning_effort` not been implemented yet? I'm hosting gpt-oss-20b on llama-server and calling it via the OpenAI API. Here is a quick sample.
```python
from openai import OpenAI

model = OpenAI(api_key="dummy", base_url="http://127.0.0.1:8080")
completion = model.chat.completions.create(
    model="dummy",
    messages=[{"role": "user", "content": "Write fizzbuzz in Python"}],
    reasoning_effort="high",
)
print(completion)
```
fat-tire commented Aug 6, 2025
@nachoal Yup, a simple regex pattern on that output:

`pattern = r".*\<\|start\|\>assistant\<\|channel\|\>commentary to=(?P<toolname>\w+) json\<\|message\|\>(?P<output>.*)"`

gets you two match groups, `toolname` and `output`.
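As an illustration (not part of the original comment), applying that pattern with Python's `re` module to the commentary line quoted earlier in this thread might look like this:

```python
import re

# the pattern from the comment above
pattern = (
    r".*\<\|start\|\>assistant\<\|channel\|\>commentary "
    r"to=(?P<toolname>\w+) json\<\|message\|\>(?P<output>.*)"
)

# sample content taken from the earlier example in this thread
content = (
    '<|start|>assistant<|channel|>commentary to=fetch json'
    '<|message|>{"url":"https://www.github.com","max_length":5000}'
)

m = re.match(pattern, content)
if m:
    print(m.group("toolname"))  # -> fetch
    print(m.group("output"))    # -> {"url":"https://www.github.com","max_length":5000}
```

Note that `\w+` would not match dotted tool names such as `functions.get_weather`, so the pattern may need widening depending on the output.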
uazure commented Aug 6, 2025
Using build: llama-b6098-bin-win-cuda-12.4-x64
slaren commented Aug 6, 2025
@uazure what CPU backend is it loading?
createthis commented Aug 6, 2025 • edited
@nai-kon I noticed the same behavior. I think we should open a defect issue rather than clog this further.
Added in GPT-OSS PR ggml-org/llama.cpp#15091 (Co-authored-by: Xuan-Son Nguyen <[email protected]>)
```python
new_name_gate = self.map_tensor_name(name.replace("gate_up_proj_scales", "gate_proj.weight"))
new_name_up   = self.map_tensor_name(name.replace("gate_up_proj_scales", "up_proj.weight"))
```
Too late, but why was this split? Only adds extra ops on the graph...
The gate_up tensor is organized so that a row of gate is followed by a row of up, i.e. interleaved. While we could rearrange it to the layout expected by the fused op, I think it's easier to just split it into gate and up independently.
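For illustration only (not from the PR), de-interleaving a gate_up matrix laid out this way could look like the following NumPy sketch; the shapes and variable names are assumptions:

```python
import numpy as np

# Hypothetical interleaved gate_up weight: rows alternate gate, up, gate, up, ...
n_ff, n_embd = 4, 3
gate_up = np.arange(2 * n_ff * n_embd).reshape(2 * n_ff, n_embd)

gate = gate_up[0::2]  # even rows -> gate_proj.weight
up   = gate_up[1::2]  # odd rows  -> up_proj.weight

assert gate.shape == up.shape == (n_ff, n_embd)
```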
Ahhh, didn't catch that.
```c
struct ggml_tensor * ggml_swiglu_oai(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b,
        float                 alpha,
        float                 limit) {
    struct ggml_tensor * result = ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU_OAI, false);

    ggml_set_op_params_f32(result, 2, alpha);
    ggml_set_op_params_f32(result, 3, limit);

    return result;
}
```
Technically this is `ggml_swiglu_oai_split`.
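For intuition only, a rough Python sketch of a clamped SwiGLU of this kind follows; the clamping scheme and the role of `alpha`/`limit` are assumptions based on the public gpt-oss reference code, not taken from this PR's ggml kernels:

```python
import numpy as np

def swiglu_oai_sketch(gate: np.ndarray, up: np.ndarray,
                      alpha: float = 1.702, limit: float = 7.0) -> np.ndarray:
    """Rough sketch of a clamped SwiGLU variant (assumed semantics, not the ggml code)."""
    gate = np.minimum(gate, limit)       # clamp the gated half from above
    up = np.clip(up, -limit, limit)      # clamp the linear half to [-limit, limit]
    silu = gate * (1.0 / (1.0 + np.exp(-alpha * gate)))  # sigmoid-weighted gate
    return silu * (up + 1.0)

x = np.linspace(-3.0, 3.0, 7)
print(swiglu_oai_sketch(x, x))
```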
ericcurtin commented Aug 10, 2025 • edited
Does anybody see value in adding a simple chat client to upstream llama.cpp, in C++ or python3, that we can consolidate on, like this: https://github.com/ericcurtin/lm-chat/blob/main/lm-chat.py ? For formats like this new Harmony one it can be hard to find simple reference implementations that are not `from openai_harmony`. I guess there sort of is the html client implementation, but I'm not sure how many people are ready to crack that open, as it's more than just a simple cli chat client.
JohannesGaessler commented Aug 10, 2025
My opinion is that efforts should be focused on the existing web interface of the HTTP server.
ericcurtin commented Aug 10, 2025
A couple of issues with the web interface. It's a UI, so added complexity for a simple reference implementation. It's compressed, which kills all the version control: you can't easily see changes, it's like a mystery blob that's committed from time to time. I think it would be better if we committed both ./tools/server/public/index.html.gz and ./tools/server/public/index.html on changes; at least we could track the changes then.
ngxson commented Aug 10, 2025
index.html... hmm, good luck decoding the diff of the transpiled JS code
ericcurtin commented Aug 10, 2025
My bad, the true sources are there.
CISC commented Aug 14, 2025
@ggerganov This workflow has been stuck in the queue for over a week now, it's impossible to cancel: https://github.com/ggml-org/llama.cpp/actions/runs/16754489544
ngxson commented Aug 14, 2025
@CISC it was created when github was down last week, should be a bug on github side
CISC commented Aug 14, 2025
Yeah, just wondering if it has anything to do with the abnormally long queue times we've been having since, or if it's something else.
marvin-0042 commented Aug 22, 2025
Amazing! Thanks for the great job! Just curious, has anyone done any accuracy test for gpt-oss-20B on non-Nvidia platforms? Thanks!
gpt-oss model support in native MXFP4 format:
- `ggml_add_id` operator in ggml

Usage:
Model collection: https://huggingface.co/collections/ggml-org/gpt-oss-68923b60bee37414546c70bf
Example command:
```shell
llama-server -hf ggml-org/gpt-oss-120b-GGUF -c 0 -fa --jinja --reasoning-format none
# Then, access http://localhost:8080
```

Model card
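As a usage illustration that is not part of the PR description, once the server from the example command above is running, its OpenAI-compatible endpoint could be queried roughly like this (the prompt and the model name are placeholders):

```python
import json
import urllib.request

# Minimal sketch: POST a chat request to llama-server's OpenAI-compatible endpoint.
payload = {
    "model": "gpt-oss-120b",  # placeholder; a single-model server does not require a real name
    "messages": [{"role": "user", "content": "Hello, what can you do?"}],
}
req = urllib.request.Request(
    "http://localhost:8080/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    body = json.load(resp)

print(body["choices"][0]["message"]["content"])
```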
References:
Note to maintainers:
This is an initial implementation with pretty much complete support for the CUDA, Vulkan, Metal and CPU backends. The idea is to merge this quicker than usual, in time for the official release today, and later we can work on polishing any potential problems and missing features.
Next PRs: