Private
Public Access
1
0

Properly support per-model maxTokens/temperature

This commit is contained in:
2025-06-25 07:49:29 +00:00
parent 259648f699
commit 3cd897d494
3 changed files with 76 additions and 32 deletions

View File

@@ -19,15 +19,16 @@ import (
// Prompt prompts the configured the configured model and streams the response // Prompt prompts the configured the configured model and streams the response
// to stdout. Returns all model reply messages. // to stdout. Returns all model reply messages.
func Prompt(ctx *lmcli.Context, messages []conversation.Message, callback func(conversation.Message)) (*api.Message, error) { func Prompt(ctx *lmcli.Context, messages []conversation.Message, callback func(conversation.Message)) (*api.Message, error) {
m, _, p, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "") modelConfig, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
p := modelConfig.Client
params := provider.RequestParameters{ params := provider.RequestParameters{
Model: m, Model: modelConfig.Model,
MaxTokens: *ctx.Config.Defaults.MaxTokens, MaxTokens: modelConfig.MaxTokens,
Temperature: *ctx.Config.Defaults.Temperature, Temperature: modelConfig.Temperature,
} }
system := ctx.DefaultSystemPrompt() system := ctx.DefaultSystemPrompt()
@@ -206,15 +207,16 @@ Example response:
}, },
} }
m, _, p, err := ctx.GetModelProvider( modelConfig, err := ctx.GetModelProvider(
*ctx.Config.Conversations.TitleGenerationModel, "", *ctx.Config.Conversations.TitleGenerationModel, "",
) )
if err != nil { if err != nil {
return "", err return "", err
} }
p := modelConfig.Client
requestParams := provider.RequestParameters{ requestParams := provider.RequestParameters{
Model: m, Model: modelConfig.Model,
MaxTokens: 25, MaxTokens: 25,
} }

View File

@@ -139,7 +139,35 @@ func (c *Context) DefaultSystemPrompt() string {
return c.Config.Defaults.SystemPrompt return c.Config.Defaults.SystemPrompt
} }
func (c *Context) GetModelProvider(model string, provider string) (string, string, provider.ChatCompletionProvider, error) { type ModelConfig struct {
Provider string
Client provider.ChatCompletionProvider
Model string
MaxTokens int
Temperature float32
}
func (c *Context) fillModelConfig(cfg *ModelConfig, m Model) *ModelConfig {
// Set model name
cfg.Model = m.Name
// Set max tokens
if m.MaxTokens == nil {
cfg.MaxTokens = *c.Config.Defaults.MaxTokens
} else {
cfg.MaxTokens = *m.MaxTokens
}
// Set temperature
if m.Temperature == nil {
cfg.Temperature = *c.Config.Defaults.Temperature
} else {
cfg.Temperature = *m.Temperature
}
return cfg
}
func (c *Context) GetModelProvider(model string, provider string) (*ModelConfig, error) {
parts := strings.Split(model, "@") parts := strings.Split(model, "@")
if provider == "" && len(parts) > 1 { if provider == "" && len(parts) > 1 {
@@ -158,6 +186,7 @@ func (c *Context) GetModelProvider(model string, provider string) (string, strin
} }
for _, m := range p.Models() { for _, m := range p.Models() {
var cfg *ModelConfig
if m.Name == model { if m.Name == model {
switch p.Kind { switch p.Kind {
case "anthropic": case "anthropic":
@@ -165,44 +194,56 @@ func (c *Context) GetModelProvider(model string, provider string) (string, strin
if p.BaseURL != "" { if p.BaseURL != "" {
url = p.BaseURL url = p.BaseURL
} }
return model, name, &anthropic.AnthropicClient{ cfg = &ModelConfig{
BaseURL: url, Client: &anthropic.AnthropicClient{
APIKey: p.APIKey, BaseURL: url,
}, nil APIKey: p.APIKey,
},
}
return c.fillModelConfig(cfg, m), nil
case "google": case "google":
url := "https://generativelanguage.googleapis.com" url := "https://generativelanguage.googleapis.com"
if p.BaseURL != "" { if p.BaseURL != "" {
url = p.BaseURL url = p.BaseURL
} }
return model, name, &google.Client{ cfg := &ModelConfig{
BaseURL: url, Client: &google.Client{
APIKey: p.APIKey, BaseURL: url,
}, nil APIKey: p.APIKey,
},
}
return c.fillModelConfig(cfg, m), nil
case "ollama": case "ollama":
url := "http://localhost:11434/api" url := "http://localhost:11434/api"
if p.BaseURL != "" { if p.BaseURL != "" {
url = p.BaseURL url = p.BaseURL
} }
return model, name, &ollama.OllamaClient{ cfg := &ModelConfig{
BaseURL: url, Client: &ollama.OllamaClient{
}, nil BaseURL: url,
},
}
return c.fillModelConfig(cfg, m), nil
case "openai": case "openai":
url := "https://api.openai.com" url := "https://api.openai.com"
if p.BaseURL != "" { if p.BaseURL != "" {
url = p.BaseURL url = p.BaseURL
} }
return model, name, &openai.OpenAIClient{ cfg := &ModelConfig{
BaseURL: url, Client: &openai.OpenAIClient{
APIKey: p.APIKey, BaseURL: url,
Headers: p.Headers, APIKey: p.APIKey,
}, nil Headers: p.Headers,
},
}
return c.fillModelConfig(cfg, m), nil
default: default:
return "", "", nil, fmt.Errorf("unknown provider kind: %s", p.Kind) return nil, fmt.Errorf("unknown provider kind: %s", p.Kind)
} }
} }
} }
} }
return "", "", nil, fmt.Errorf("unknown model: %s", model) return nil, fmt.Errorf("unknown model: %s", model)
} }
func Fatal(format string, args ...any) { func Fatal(format string, args ...any) {

View File

@@ -39,9 +39,9 @@ func NewAppModel(ctx *lmcli.Context, initialConversation *conversation.Conversat
} }
model, provider, _, _ := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "") modelConfig, _ := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
app.Model = model app.Model = modelConfig.Model
app.ProviderName = provider app.ProviderName = modelConfig.Provider
app.Agent = ctx.GetAgent(ctx.Config.Defaults.Agent) app.Agent = ctx.GetAgent(ctx.Config.Defaults.Agent)
return app return app
} }
@@ -257,15 +257,16 @@ func (a *AppModel) Prompt(
chatReplyChunks chan provider.Chunk, chatReplyChunks chan provider.Chunk,
stopSignal chan struct{}, stopSignal chan struct{},
) (*conversation.Message, error) { ) (*conversation.Message, error) {
model, _, p, err := a.Ctx.GetModelProvider(a.Model, a.ProviderName) modelConfig, err := a.Ctx.GetModelProvider(a.Model, a.ProviderName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
p := modelConfig.Client
params := provider.RequestParameters{ params := provider.RequestParameters{
Model: model, Model: modelConfig.Model,
MaxTokens: *a.Ctx.Config.Defaults.MaxTokens, MaxTokens: modelConfig.MaxTokens,
Temperature: *a.Ctx.Config.Defaults.Temperature, Temperature: modelConfig.Temperature,
} }
if a.Agent != nil { if a.Agent != nil {