commit 4e4ad90f926c4acfc131449d3e4a1561506a1cb9
parent b4db016217a94ccc1060c1876aa6531948ed811a
Author: Aristide Tossou <atossou@mozilla.com>
Date: Wed, 8 Oct 2025 14:27:11 +0000
Bug 1990534 - Support Logit Bias Sampling for llama.cpp. r=padenot
Differential Revision: https://phabricator.services.mozilla.com/D267616
Diffstat:
4 files changed, 41 insertions(+), 11 deletions(-)
diff --git a/config/external/mozinference/mozinference.symbols b/config/external/mozinference/mozinference.symbols
@@ -21,11 +21,13 @@ llama_sampler_init_temp
llama_sampler_init_dist
llama_sampler_init_top_k
llama_sampler_init_top_p
+llama_sampler_init_logit_bias
llama_sampler_free
llama_sampler_sample
llama_memory_clear
llama_get_memory
llama_model_get_vocab
+llama_vocab_n_tokens
llama_tokenize
llama_n_ctx
llama_batch_get_one
diff --git a/toolkit/components/ml/backends/llama/LlamaBackend.cpp b/toolkit/components/ml/backends/llama/LlamaBackend.cpp
@@ -339,7 +339,8 @@ ChatMessageResult LlamaBackend::FormatChat(
}
LlamaBackend::SamplerResult LlamaBackend::InitializeSampler(
- const mozilla::dom::Sequence<LlamaSamplerConfig>& aSamplers) {
+ const mozilla::dom::Sequence<LlamaSamplerConfig>& aSamplers,
+ const llama_vocab* vocab) {
LOGV("Entered {}", __PRETTY_FUNCTION__);
MOZ_ASSERT(mLib, "No shared library pointer in InitializeSampler, fix this");
@@ -360,6 +361,8 @@ LlamaBackend::SamplerResult LlamaBackend::InitializeSampler(
mLib->llama_sampler_init_greedy());
}
+ auto n_vocab = mLib->llama_vocab_n_tokens(vocab);
+
for (const auto& samplerConfig : aSamplers) {
llama_sampler* samplerElement = nullptr;
@@ -385,6 +388,18 @@ LlamaBackend::SamplerResult LlamaBackend::InitializeSampler(
samplerConfig.mMinKeep);
break;
+ case LlamaSamplerType::Logit_bias: {
+ nsTArray<llama_logit_bias> logitBias;
+ logitBias.SetCapacity(samplerConfig.mLogitBias.Length());
+ for (const auto& val : samplerConfig.mLogitBias) {
+ logitBias.AppendElement(llama_logit_bias{val.mToken, val.mBias});
+ }
+
+ samplerElement = mLib->llama_sampler_init_logit_bias(
+ n_vocab, samplerConfig.mLogitBias.Length(), logitBias.Elements());
+ break;
+ }
+
default:
auto msg = nsFmtCString(FMT_STRING("{}: Unimplemented sampler type"),
@@ -420,15 +435,6 @@ ResultStatus LlamaBackend::Generate(
}
});
- auto samplerResult = InitializeSampler(aOptions.mSamplers);
-
- if (samplerResult.isErr()) {
- LOGE("{}", samplerResult.inspectErr().mMessage);
- return mozilla::Err(samplerResult.inspectErr());
- }
-
- auto sampler = samplerResult.unwrap();
-
if (!mModel) {
auto msg = nsFmtCString(FMT_STRING("{}: error: Model not loaded"),
__PRETTY_FUNCTION__);
@@ -439,6 +445,23 @@ ResultStatus LlamaBackend::Generate(
// Just a non-owned pointer to existing data, so fast to get each time
const llama_vocab* vocab = mLib->llama_model_get_vocab(mModel.get());
+ if (!vocab) {
+ auto msg =
+ nsFmtCString(FMT_STRING("{}: error: Unable to get model vocabulary."),
+ __PRETTY_FUNCTION__);
+ LOGE("{}", msg);
+ return mozilla::Err(Error{msg});
+ }
+
+ auto samplerResult = InitializeSampler(aOptions.mSamplers, vocab);
+
+ if (samplerResult.isErr()) {
+ LOGE("{}", samplerResult.inspectErr().mMessage);
+ return mozilla::Err(samplerResult.inspectErr());
+ }
+
+ auto sampler = samplerResult.unwrap();
+
const size_t estimatedNumPromptTokens = aOptions.mPrompt.Length() + 1;
LOGD("{} Estimated tokenization size is {} {}", __PRETTY_FUNCTION__,
estimatedNumPromptTokens, mModelGeneralName);
diff --git a/toolkit/components/ml/backends/llama/LlamaBackend.h b/toolkit/components/ml/backends/llama/LlamaBackend.h
@@ -123,7 +123,8 @@ class LlamaBackend {
private:
SamplerResult InitializeSampler(
- const mozilla::dom::Sequence<LlamaSamplerConfig>& aSamplers);
+ const mozilla::dom::Sequence<LlamaSamplerConfig>& aSamplers,
+ const llama_vocab* vocab);
// Pointer to the dynamically loaded llama library
LlamaLibWrapper* mLib = nullptr;
diff --git a/toolkit/components/ml/backends/llama/LlamaRuntimeLinker.h b/toolkit/components/ml/backends/llama/LlamaRuntimeLinker.h
@@ -50,10 +50,14 @@ namespace mozilla::llama {
X(struct llama_sampler*, llama_sampler_init_top_k, (int32_t k)) \
X(struct llama_sampler*, llama_sampler_init_top_p, \
(float p, size_t min_keep)) \
+ X(struct llama_sampler*, llama_sampler_init_logit_bias, \
+ (int32_t n_vocab, int32_t n_logit_bias, \
+ const llama_logit_bias* logit_bias)) \
X(void, llama_memory_clear, (llama_memory_t mem, bool data)) \
X(llama_memory_t, llama_get_memory, (const struct llama_context* ctx)) \
X(const struct llama_vocab*, llama_model_get_vocab, \
(const struct llama_model* model)) \
+ X(int32_t, llama_vocab_n_tokens, (const struct llama_vocab* vocab)) \
X(int32_t, llama_tokenize, \
(const struct llama_vocab* vocab, const char* text, int32_t text_len, \
llama_token* tokens, int32_t n_tokens_max, bool add_special, \