Properly support per-model maxTokens/temperature
This commit is contained in:
@@ -19,15 +19,16 @@ import (
|
||||
// Prompt prompts the configured the configured model and streams the response
|
||||
// to stdout. Returns all model reply messages.
|
||||
func Prompt(ctx *lmcli.Context, messages []conversation.Message, callback func(conversation.Message)) (*api.Message, error) {
|
||||
m, _, p, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
|
||||
modelConfig, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p := modelConfig.Client
|
||||
|
||||
params := provider.RequestParameters{
|
||||
Model: m,
|
||||
MaxTokens: *ctx.Config.Defaults.MaxTokens,
|
||||
Temperature: *ctx.Config.Defaults.Temperature,
|
||||
Model: modelConfig.Model,
|
||||
MaxTokens: modelConfig.MaxTokens,
|
||||
Temperature: modelConfig.Temperature,
|
||||
}
|
||||
|
||||
system := ctx.DefaultSystemPrompt()
|
||||
@@ -206,15 +207,16 @@ Example response:
|
||||
},
|
||||
}
|
||||
|
||||
m, _, p, err := ctx.GetModelProvider(
|
||||
modelConfig, err := ctx.GetModelProvider(
|
||||
*ctx.Config.Conversations.TitleGenerationModel, "",
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
p := modelConfig.Client
|
||||
|
||||
requestParams := provider.RequestParameters{
|
||||
Model: m,
|
||||
Model: modelConfig.Model,
|
||||
MaxTokens: 25,
|
||||
}
|
||||
|
||||
|
||||
@@ -139,7 +139,35 @@ func (c *Context) DefaultSystemPrompt() string {
|
||||
return c.Config.Defaults.SystemPrompt
|
||||
}
|
||||
|
||||
func (c *Context) GetModelProvider(model string, provider string) (string, string, provider.ChatCompletionProvider, error) {
|
||||
type ModelConfig struct {
|
||||
Provider string
|
||||
Client provider.ChatCompletionProvider
|
||||
Model string
|
||||
MaxTokens int
|
||||
Temperature float32
|
||||
}
|
||||
|
||||
func (c *Context) fillModelConfig(cfg *ModelConfig, m Model) *ModelConfig {
|
||||
// Set model name
|
||||
cfg.Model = m.Name
|
||||
|
||||
// Set max tokens
|
||||
if m.MaxTokens == nil {
|
||||
cfg.MaxTokens = *c.Config.Defaults.MaxTokens
|
||||
} else {
|
||||
cfg.MaxTokens = *m.MaxTokens
|
||||
}
|
||||
|
||||
// Set temperature
|
||||
if m.Temperature == nil {
|
||||
cfg.Temperature = *c.Config.Defaults.Temperature
|
||||
} else {
|
||||
cfg.Temperature = *m.Temperature
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func (c *Context) GetModelProvider(model string, provider string) (*ModelConfig, error) {
|
||||
parts := strings.Split(model, "@")
|
||||
|
||||
if provider == "" && len(parts) > 1 {
|
||||
@@ -158,6 +186,7 @@ func (c *Context) GetModelProvider(model string, provider string) (string, strin
|
||||
}
|
||||
|
||||
for _, m := range p.Models() {
|
||||
var cfg *ModelConfig
|
||||
if m.Name == model {
|
||||
switch p.Kind {
|
||||
case "anthropic":
|
||||
@@ -165,44 +194,56 @@ func (c *Context) GetModelProvider(model string, provider string) (string, strin
|
||||
if p.BaseURL != "" {
|
||||
url = p.BaseURL
|
||||
}
|
||||
return model, name, &anthropic.AnthropicClient{
|
||||
cfg = &ModelConfig{
|
||||
Client: &anthropic.AnthropicClient{
|
||||
BaseURL: url,
|
||||
APIKey: p.APIKey,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
return c.fillModelConfig(cfg, m), nil
|
||||
case "google":
|
||||
url := "https://generativelanguage.googleapis.com"
|
||||
if p.BaseURL != "" {
|
||||
url = p.BaseURL
|
||||
}
|
||||
return model, name, &google.Client{
|
||||
cfg := &ModelConfig{
|
||||
Client: &google.Client{
|
||||
BaseURL: url,
|
||||
APIKey: p.APIKey,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
return c.fillModelConfig(cfg, m), nil
|
||||
case "ollama":
|
||||
url := "http://localhost:11434/api"
|
||||
if p.BaseURL != "" {
|
||||
url = p.BaseURL
|
||||
}
|
||||
return model, name, &ollama.OllamaClient{
|
||||
cfg := &ModelConfig{
|
||||
Client: &ollama.OllamaClient{
|
||||
BaseURL: url,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
return c.fillModelConfig(cfg, m), nil
|
||||
case "openai":
|
||||
url := "https://api.openai.com"
|
||||
if p.BaseURL != "" {
|
||||
url = p.BaseURL
|
||||
}
|
||||
return model, name, &openai.OpenAIClient{
|
||||
cfg := &ModelConfig{
|
||||
Client: &openai.OpenAIClient{
|
||||
BaseURL: url,
|
||||
APIKey: p.APIKey,
|
||||
Headers: p.Headers,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
return c.fillModelConfig(cfg, m), nil
|
||||
default:
|
||||
return "", "", nil, fmt.Errorf("unknown provider kind: %s", p.Kind)
|
||||
return nil, fmt.Errorf("unknown provider kind: %s", p.Kind)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", "", nil, fmt.Errorf("unknown model: %s", model)
|
||||
return nil, fmt.Errorf("unknown model: %s", model)
|
||||
}
|
||||
|
||||
func Fatal(format string, args ...any) {
|
||||
|
||||
@@ -39,9 +39,9 @@ func NewAppModel(ctx *lmcli.Context, initialConversation *conversation.Conversat
|
||||
|
||||
}
|
||||
|
||||
model, provider, _, _ := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
|
||||
app.Model = model
|
||||
app.ProviderName = provider
|
||||
modelConfig, _ := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
|
||||
app.Model = modelConfig.Model
|
||||
app.ProviderName = modelConfig.Provider
|
||||
app.Agent = ctx.GetAgent(ctx.Config.Defaults.Agent)
|
||||
return app
|
||||
}
|
||||
@@ -257,15 +257,16 @@ func (a *AppModel) Prompt(
|
||||
chatReplyChunks chan provider.Chunk,
|
||||
stopSignal chan struct{},
|
||||
) (*conversation.Message, error) {
|
||||
model, _, p, err := a.Ctx.GetModelProvider(a.Model, a.ProviderName)
|
||||
modelConfig, err := a.Ctx.GetModelProvider(a.Model, a.ProviderName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p := modelConfig.Client
|
||||
|
||||
params := provider.RequestParameters{
|
||||
Model: model,
|
||||
MaxTokens: *a.Ctx.Config.Defaults.MaxTokens,
|
||||
Temperature: *a.Ctx.Config.Defaults.Temperature,
|
||||
Model: modelConfig.Model,
|
||||
MaxTokens: modelConfig.MaxTokens,
|
||||
Temperature: modelConfig.Temperature,
|
||||
}
|
||||
|
||||
if a.Agent != nil {
|
||||
|
||||
Reference in New Issue
Block a user