tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git

commit 4e4ad90f926c4acfc131449d3e4a1561506a1cb9
parent b4db016217a94ccc1060c1876aa6531948ed811a
Author: Aristide Tossou <atossou@mozilla.com>
Date:   Wed,  8 Oct 2025 14:27:11 +0000

Bug 1990534 - Support Logit Bias Sampling for llama.cpp. r=padenot

Differential Revision: https://phabricator.services.mozilla.com/D267616
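For context: a logit-bias sampler adds a fixed offset to the logits of chosen
tokens before the rest of the sampler chain runs, letting callers boost or
suppress specific tokens (a bias of -INFINITY effectively bans one). Below is
a minimal sketch against the upstream llama.cpp C API that this patch links at
runtime; it assumes an already-loaded llama_model* model, and the token ids
and bias values are placeholders:

    #include <math.h>   // INFINITY
    #include "llama.h"

    // The sampler needs the vocabulary size to size its internal buffers.
    const llama_vocab* vocab = llama_model_get_vocab(model);
    const int32_t n_vocab = llama_vocab_n_tokens(vocab);

    // Each entry pairs a token id with an additive logit offset.
    llama_logit_bias biases[] = {
        {/* token */ 42, /* bias */ 5.0f},       // make token 42 more likely
        {/* token */ 7,  /* bias */ -INFINITY},  // ban token 7 outright
    };

    // Put the bias stage ahead of the final selection stage in the chain.
    llama_sampler* chain =
        llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(
        chain, llama_sampler_init_logit_bias(n_vocab, 2, biases));
    llama_sampler_chain_add(chain, llama_sampler_init_greedy());

    // ... sample with llama_sampler_sample(chain, ctx, -1), then release
    // the chain with llama_sampler_free(chain).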

Diffstat:
M config/external/mozinference/mozinference.symbols        |  2 ++
M toolkit/components/ml/backends/llama/LlamaBackend.cpp    | 43 +++++++++++++++++++++++++++++++++----------
M toolkit/components/ml/backends/llama/LlamaBackend.h      |  3 ++-
M toolkit/components/ml/backends/llama/LlamaRuntimeLinker.h |  4 ++++
4 files changed, 41 insertions(+), 11 deletions(-)

diff --git a/config/external/mozinference/mozinference.symbols b/config/external/mozinference/mozinference.symbols
@@ -21,11 +21,13 @@ llama_sampler_init_temp
 llama_sampler_init_dist
 llama_sampler_init_top_k
 llama_sampler_init_top_p
+llama_sampler_init_logit_bias
 llama_sampler_free
 llama_sampler_sample
 llama_memory_clear
 llama_get_memory
 llama_model_get_vocab
+llama_vocab_n_tokens
 llama_tokenize
 llama_n_ctx
 llama_batch_get_one
diff --git a/toolkit/components/ml/backends/llama/LlamaBackend.cpp b/toolkit/components/ml/backends/llama/LlamaBackend.cpp
@@ -339,7 +339,8 @@ ChatMessageResult LlamaBackend::FormatChat(
 }
 
 LlamaBackend::SamplerResult LlamaBackend::InitializeSampler(
-    const mozilla::dom::Sequence<LlamaSamplerConfig>& aSamplers) {
+    const mozilla::dom::Sequence<LlamaSamplerConfig>& aSamplers,
+    const llama_vocab* vocab) {
   LOGV("Entered {}", __PRETTY_FUNCTION__);
   MOZ_ASSERT(mLib, "No shared library pointer in InitializeSampler, fix this");
 
@@ -360,6 +361,8 @@ LlamaBackend::SamplerResult LlamaBackend::InitializeSampler(
         mLib->llama_sampler_init_greedy());
   }
 
+  auto n_vocab = mLib->llama_vocab_n_tokens(vocab);
+
   for (const auto& samplerConfig : aSamplers) {
     llama_sampler* samplerElement = nullptr;
 
@@ -385,6 +388,18 @@ LlamaBackend::SamplerResult LlamaBackend::InitializeSampler(
             samplerConfig.mMinKeep);
         break;
 
+      case LlamaSamplerType::Logit_bias: {
+        nsTArray<llama_logit_bias> logitBias;
+        logitBias.SetCapacity(samplerConfig.mLogitBias.Length());
+        for (const auto& val : samplerConfig.mLogitBias) {
+          logitBias.AppendElement(llama_logit_bias{val.mToken, val.mBias});
+        }
+
+        samplerElement = mLib->llama_sampler_init_logit_bias(
+            n_vocab, samplerConfig.mLogitBias.Length(), logitBias.Elements());
+        break;
+      }
+
       default:
         auto msg =
             nsFmtCString(FMT_STRING("{}: Unimplemented sampler type"),
@@ -420,15 +435,6 @@ ResultStatus LlamaBackend::Generate(
     }
   });
 
-  auto samplerResult = InitializeSampler(aOptions.mSamplers);
-
-  if (samplerResult.isErr()) {
-    LOGE("{}", samplerResult.inspectErr().mMessage);
-    return mozilla::Err(samplerResult.inspectErr());
-  }
-
-  auto sampler = samplerResult.unwrap();
-
   if (!mModel) {
     auto msg = nsFmtCString(FMT_STRING("{}: error: Model not loaded"),
                             __PRETTY_FUNCTION__);
@@ -439,6 +445,23 @@ ResultStatus LlamaBackend::Generate(
 
   // Just a non-owned pointer to existing data, so fast to get each time
   const llama_vocab* vocab = mLib->llama_model_get_vocab(mModel.get());
+  if (!vocab) {
+    auto msg =
+        nsFmtCString(FMT_STRING("{}: error: Unable to get model vocabulary."),
+                     __PRETTY_FUNCTION__);
+    LOGE("{}", msg);
+    return mozilla::Err(Error{msg});
+  }
+
+  auto samplerResult = InitializeSampler(aOptions.mSamplers, vocab);
+
+  if (samplerResult.isErr()) {
+    LOGE("{}", samplerResult.inspectErr().mMessage);
+    return mozilla::Err(samplerResult.inspectErr());
+  }
+
+  auto sampler = samplerResult.unwrap();
+
   const size_t estimatedNumPromptTokens = aOptions.mPrompt.Length() + 1;
   LOGD("{} Estimated tokenization size is {} {}", __PRETTY_FUNCTION__,
        estimatedNumPromptTokens, mModelGeneralName);
diff --git a/toolkit/components/ml/backends/llama/LlamaBackend.h b/toolkit/components/ml/backends/llama/LlamaBackend.h
@@ -123,7 +123,8 @@ class LlamaBackend {
 
  private:
   SamplerResult InitializeSampler(
-      const mozilla::dom::Sequence<LlamaSamplerConfig>& aSamplers);
+      const mozilla::dom::Sequence<LlamaSamplerConfig>& aSamplers,
+      const llama_vocab* vocab);
 
   // Pointer to the dynamically loaded llama library
   LlamaLibWrapper* mLib = nullptr;
diff --git a/toolkit/components/ml/backends/llama/LlamaRuntimeLinker.h b/toolkit/components/ml/backends/llama/LlamaRuntimeLinker.h
@@ -50,10 +50,14 @@ namespace mozilla::llama {
   X(struct llama_sampler*, llama_sampler_init_top_k, (int32_t k))            \
   X(struct llama_sampler*, llama_sampler_init_top_p,                         \
     (float p, size_t min_keep))                                              \
+  X(struct llama_sampler*, llama_sampler_init_logit_bias,                    \
+    (int32_t n_vocab, int32_t n_logit_bias,                                  \
+     const llama_logit_bias* logit_bias))                                    \
   X(void, llama_memory_clear, (llama_memory_t mem, bool data))               \
   X(llama_memory_t, llama_get_memory, (const struct llama_context* ctx))     \
   X(const struct llama_vocab*, llama_model_get_vocab,                        \
     (const struct llama_model* model))                                       \
+  X(int32_t, llama_vocab_n_tokens, (const struct llama_vocab* vocab))        \
   X(int32_t, llama_tokenize,                                                 \
     (const struct llama_vocab* vocab, const char* text, int32_t text_len,    \
      llama_token* tokens, int32_t n_tokens_max, bool add_special,            \
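A note on the LlamaRuntimeLinker.h hunk: the file declares every dynamically
loaded llama.cpp entry point through an X-macro table, so a single
X(ret, name, args) row is expanded once into a function-pointer member on
LlamaLibWrapper and once into the symbol lookup that fills it in; together
with the mozinference.symbols export list, adding two rows is all the linker
plumbing the new sampler needs. A hypothetical, simplified expansion to
illustrate the pattern (the LLAMA_FUNCS, DECLARE_MEMBER, and LOAD_SYMBOL
names below are illustrative, not Mozilla's actual macros):

    #include <dlfcn.h>   // dlsym
    #include "llama.h"   // llama_* types

    // Hypothetical two-entry table in the style of LlamaRuntimeLinker.h.
    #define LLAMA_FUNCS(X)                                                  \
      X(struct llama_sampler*, llama_sampler_init_logit_bias,               \
        (int32_t n_vocab, int32_t n_logit_bias,                             \
         const llama_logit_bias* logit_bias))                               \
      X(int32_t, llama_vocab_n_tokens, (const struct llama_vocab* vocab))

    // Expansion 1: one function-pointer member per row, null until loaded.
    #define DECLARE_MEMBER(ret, name, args) ret(*name) args = nullptr;
    struct LlamaLibWrapper {
      LLAMA_FUNCS(DECLARE_MEMBER)
    };

    // Expansion 2: resolve each symbol from the shared-library handle
    // (used inside the loader, where `lib` and `handle` are in scope).
    #define LOAD_SYMBOL(ret, name, args) \
      lib->name = reinterpret_cast<ret(*) args>(dlsym(handle, #name));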