Properly support per-model maxTokens/temperature
This commit is contained in:
@@ -19,15 +19,16 @@ import (
|
|||||||
// Prompt prompts the configured the configured model and streams the response
|
// Prompt prompts the configured the configured model and streams the response
|
||||||
// to stdout. Returns all model reply messages.
|
// to stdout. Returns all model reply messages.
|
||||||
func Prompt(ctx *lmcli.Context, messages []conversation.Message, callback func(conversation.Message)) (*api.Message, error) {
|
func Prompt(ctx *lmcli.Context, messages []conversation.Message, callback func(conversation.Message)) (*api.Message, error) {
|
||||||
m, _, p, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
|
modelConfig, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
p := modelConfig.Client
|
||||||
|
|
||||||
params := provider.RequestParameters{
|
params := provider.RequestParameters{
|
||||||
Model: m,
|
Model: modelConfig.Model,
|
||||||
MaxTokens: *ctx.Config.Defaults.MaxTokens,
|
MaxTokens: modelConfig.MaxTokens,
|
||||||
Temperature: *ctx.Config.Defaults.Temperature,
|
Temperature: modelConfig.Temperature,
|
||||||
}
|
}
|
||||||
|
|
||||||
system := ctx.DefaultSystemPrompt()
|
system := ctx.DefaultSystemPrompt()
|
||||||
@@ -206,15 +207,16 @@ Example response:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m, _, p, err := ctx.GetModelProvider(
|
modelConfig, err := ctx.GetModelProvider(
|
||||||
*ctx.Config.Conversations.TitleGenerationModel, "",
|
*ctx.Config.Conversations.TitleGenerationModel, "",
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
p := modelConfig.Client
|
||||||
|
|
||||||
requestParams := provider.RequestParameters{
|
requestParams := provider.RequestParameters{
|
||||||
Model: m,
|
Model: modelConfig.Model,
|
||||||
MaxTokens: 25,
|
MaxTokens: 25,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -139,7 +139,35 @@ func (c *Context) DefaultSystemPrompt() string {
|
|||||||
return c.Config.Defaults.SystemPrompt
|
return c.Config.Defaults.SystemPrompt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) GetModelProvider(model string, provider string) (string, string, provider.ChatCompletionProvider, error) {
|
type ModelConfig struct {
|
||||||
|
Provider string
|
||||||
|
Client provider.ChatCompletionProvider
|
||||||
|
Model string
|
||||||
|
MaxTokens int
|
||||||
|
Temperature float32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Context) fillModelConfig(cfg *ModelConfig, m Model) *ModelConfig {
|
||||||
|
// Set model name
|
||||||
|
cfg.Model = m.Name
|
||||||
|
|
||||||
|
// Set max tokens
|
||||||
|
if m.MaxTokens == nil {
|
||||||
|
cfg.MaxTokens = *c.Config.Defaults.MaxTokens
|
||||||
|
} else {
|
||||||
|
cfg.MaxTokens = *m.MaxTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set temperature
|
||||||
|
if m.Temperature == nil {
|
||||||
|
cfg.Temperature = *c.Config.Defaults.Temperature
|
||||||
|
} else {
|
||||||
|
cfg.Temperature = *m.Temperature
|
||||||
|
}
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Context) GetModelProvider(model string, provider string) (*ModelConfig, error) {
|
||||||
parts := strings.Split(model, "@")
|
parts := strings.Split(model, "@")
|
||||||
|
|
||||||
if provider == "" && len(parts) > 1 {
|
if provider == "" && len(parts) > 1 {
|
||||||
@@ -158,6 +186,7 @@ func (c *Context) GetModelProvider(model string, provider string) (string, strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, m := range p.Models() {
|
for _, m := range p.Models() {
|
||||||
|
var cfg *ModelConfig
|
||||||
if m.Name == model {
|
if m.Name == model {
|
||||||
switch p.Kind {
|
switch p.Kind {
|
||||||
case "anthropic":
|
case "anthropic":
|
||||||
@@ -165,44 +194,56 @@ func (c *Context) GetModelProvider(model string, provider string) (string, strin
|
|||||||
if p.BaseURL != "" {
|
if p.BaseURL != "" {
|
||||||
url = p.BaseURL
|
url = p.BaseURL
|
||||||
}
|
}
|
||||||
return model, name, &anthropic.AnthropicClient{
|
cfg = &ModelConfig{
|
||||||
|
Client: &anthropic.AnthropicClient{
|
||||||
BaseURL: url,
|
BaseURL: url,
|
||||||
APIKey: p.APIKey,
|
APIKey: p.APIKey,
|
||||||
}, nil
|
},
|
||||||
|
}
|
||||||
|
return c.fillModelConfig(cfg, m), nil
|
||||||
case "google":
|
case "google":
|
||||||
url := "https://generativelanguage.googleapis.com"
|
url := "https://generativelanguage.googleapis.com"
|
||||||
if p.BaseURL != "" {
|
if p.BaseURL != "" {
|
||||||
url = p.BaseURL
|
url = p.BaseURL
|
||||||
}
|
}
|
||||||
return model, name, &google.Client{
|
cfg := &ModelConfig{
|
||||||
|
Client: &google.Client{
|
||||||
BaseURL: url,
|
BaseURL: url,
|
||||||
APIKey: p.APIKey,
|
APIKey: p.APIKey,
|
||||||
}, nil
|
},
|
||||||
|
}
|
||||||
|
return c.fillModelConfig(cfg, m), nil
|
||||||
case "ollama":
|
case "ollama":
|
||||||
url := "http://localhost:11434/api"
|
url := "http://localhost:11434/api"
|
||||||
if p.BaseURL != "" {
|
if p.BaseURL != "" {
|
||||||
url = p.BaseURL
|
url = p.BaseURL
|
||||||
}
|
}
|
||||||
return model, name, &ollama.OllamaClient{
|
cfg := &ModelConfig{
|
||||||
|
Client: &ollama.OllamaClient{
|
||||||
BaseURL: url,
|
BaseURL: url,
|
||||||
}, nil
|
},
|
||||||
|
}
|
||||||
|
return c.fillModelConfig(cfg, m), nil
|
||||||
case "openai":
|
case "openai":
|
||||||
url := "https://api.openai.com"
|
url := "https://api.openai.com"
|
||||||
if p.BaseURL != "" {
|
if p.BaseURL != "" {
|
||||||
url = p.BaseURL
|
url = p.BaseURL
|
||||||
}
|
}
|
||||||
return model, name, &openai.OpenAIClient{
|
cfg := &ModelConfig{
|
||||||
|
Client: &openai.OpenAIClient{
|
||||||
BaseURL: url,
|
BaseURL: url,
|
||||||
APIKey: p.APIKey,
|
APIKey: p.APIKey,
|
||||||
Headers: p.Headers,
|
Headers: p.Headers,
|
||||||
}, nil
|
},
|
||||||
|
}
|
||||||
|
return c.fillModelConfig(cfg, m), nil
|
||||||
default:
|
default:
|
||||||
return "", "", nil, fmt.Errorf("unknown provider kind: %s", p.Kind)
|
return nil, fmt.Errorf("unknown provider kind: %s", p.Kind)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return "", "", nil, fmt.Errorf("unknown model: %s", model)
|
return nil, fmt.Errorf("unknown model: %s", model)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Fatal(format string, args ...any) {
|
func Fatal(format string, args ...any) {
|
||||||
|
|||||||
@@ -39,9 +39,9 @@ func NewAppModel(ctx *lmcli.Context, initialConversation *conversation.Conversat
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
model, provider, _, _ := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
|
modelConfig, _ := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
|
||||||
app.Model = model
|
app.Model = modelConfig.Model
|
||||||
app.ProviderName = provider
|
app.ProviderName = modelConfig.Provider
|
||||||
app.Agent = ctx.GetAgent(ctx.Config.Defaults.Agent)
|
app.Agent = ctx.GetAgent(ctx.Config.Defaults.Agent)
|
||||||
return app
|
return app
|
||||||
}
|
}
|
||||||
@@ -257,15 +257,16 @@ func (a *AppModel) Prompt(
|
|||||||
chatReplyChunks chan provider.Chunk,
|
chatReplyChunks chan provider.Chunk,
|
||||||
stopSignal chan struct{},
|
stopSignal chan struct{},
|
||||||
) (*conversation.Message, error) {
|
) (*conversation.Message, error) {
|
||||||
model, _, p, err := a.Ctx.GetModelProvider(a.Model, a.ProviderName)
|
modelConfig, err := a.Ctx.GetModelProvider(a.Model, a.ProviderName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
p := modelConfig.Client
|
||||||
|
|
||||||
params := provider.RequestParameters{
|
params := provider.RequestParameters{
|
||||||
Model: model,
|
Model: modelConfig.Model,
|
||||||
MaxTokens: *a.Ctx.Config.Defaults.MaxTokens,
|
MaxTokens: modelConfig.MaxTokens,
|
||||||
Temperature: *a.Ctx.Config.Defaults.Temperature,
|
Temperature: modelConfig.Temperature,
|
||||||
}
|
}
|
||||||
|
|
||||||
if a.Agent != nil {
|
if a.Agent != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user