231 lines
4.9 KiB
Go
231 lines
4.9 KiB
Go
package lmcli
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"git.mlow.ca/mlow/lmcli/pkg/agents"
|
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
|
"git.mlow.ca/mlow/lmcli/pkg/api/provider/anthropic"
|
|
"git.mlow.ca/mlow/lmcli/pkg/api/provider/google"
|
|
"git.mlow.ca/mlow/lmcli/pkg/api/provider/ollama"
|
|
"git.mlow.ca/mlow/lmcli/pkg/api/provider/openai"
|
|
"git.mlow.ca/mlow/lmcli/pkg/util"
|
|
"git.mlow.ca/mlow/lmcli/pkg/util/tty"
|
|
"gorm.io/driver/sqlite"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
type Agent struct {
|
|
Name string
|
|
SystemPrompt string
|
|
Toolbox []api.ToolSpec
|
|
}
|
|
|
|
type Context struct {
|
|
// high level app configuration, may be mutated at runtime
|
|
Config Config
|
|
Store ConversationStore
|
|
Chroma *tty.ChromaHighlighter
|
|
}
|
|
|
|
func NewContext() (*Context, error) {
|
|
configFile := filepath.Join(configDir(), "config.yaml")
|
|
config, err := NewConfig(configFile)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
databaseFile := filepath.Join(dataDir(), "conversations.db")
|
|
db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{
|
|
//Logger: logger.Default.LogMode(logger.Info),
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Error establishing connection to store: %v", err)
|
|
}
|
|
store, err := NewSQLStore(db)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style)
|
|
|
|
return &Context{*config, store, chroma}, nil
|
|
}
|
|
|
|
func (c *Context) GetModels() (models []string) {
|
|
modelCounts := make(map[string]int)
|
|
for _, p := range c.Config.Providers {
|
|
name := p.Kind
|
|
if p.Name != "" {
|
|
name = p.Name
|
|
}
|
|
|
|
for _, m := range p.Models {
|
|
modelCounts[m]++
|
|
models = append(models, fmt.Sprintf("%s@%s", m, name))
|
|
}
|
|
}
|
|
|
|
for m, c := range modelCounts {
|
|
if c == 1 {
|
|
models = append(models, m)
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (c *Context) GetAgents() (agents []string) {
|
|
for _, p := range c.Config.Agents {
|
|
agents = append(agents, p.Name)
|
|
}
|
|
return
|
|
}
|
|
|
|
func (c *Context) GetAgent(name string) *Agent {
|
|
if name == "" {
|
|
return nil
|
|
}
|
|
|
|
for _, a := range c.Config.Agents {
|
|
if name != a.Name {
|
|
continue
|
|
}
|
|
|
|
var enabledTools []api.ToolSpec
|
|
for _, toolName := range a.Tools {
|
|
tool, ok := agents.AvailableTools[toolName]
|
|
if ok {
|
|
enabledTools = append(enabledTools, tool)
|
|
}
|
|
}
|
|
|
|
return &Agent{
|
|
Name: a.Name,
|
|
SystemPrompt: a.SystemPrompt,
|
|
Toolbox: enabledTools,
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *Context) DefaultSystemPrompt() string {
|
|
if c.Config.Defaults.SystemPromptFile != "" {
|
|
content, err := util.ReadFileContents(c.Config.Defaults.SystemPromptFile)
|
|
if err != nil {
|
|
Fatal("Could not read file contents at %s: %v\n", c.Config.Defaults.SystemPromptFile, err)
|
|
}
|
|
return content
|
|
}
|
|
return c.Config.Defaults.SystemPrompt
|
|
}
|
|
|
|
func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProvider, error) {
|
|
parts := strings.Split(model, "@")
|
|
|
|
var provider string
|
|
if len(parts) > 1 {
|
|
model = parts[0]
|
|
provider = parts[1]
|
|
}
|
|
|
|
for _, p := range c.Config.Providers {
|
|
name := p.Kind
|
|
if p.Name != "" {
|
|
name = p.Name
|
|
}
|
|
|
|
if provider != "" && name != provider {
|
|
continue
|
|
}
|
|
|
|
for _, m := range p.Models {
|
|
if m == model {
|
|
switch p.Kind {
|
|
case "anthropic":
|
|
url := "https://api.anthropic.com"
|
|
if p.BaseURL != "" {
|
|
url = p.BaseURL
|
|
}
|
|
return model, &anthropic.AnthropicClient{
|
|
BaseURL: url,
|
|
APIKey: p.APIKey,
|
|
}, nil
|
|
case "google":
|
|
url := "https://generativelanguage.googleapis.com"
|
|
if p.BaseURL != "" {
|
|
url = p.BaseURL
|
|
}
|
|
return model, &google.Client{
|
|
BaseURL: url,
|
|
APIKey: p.APIKey,
|
|
}, nil
|
|
case "ollama":
|
|
url := "http://localhost:11434/api"
|
|
if p.BaseURL != "" {
|
|
url = p.BaseURL
|
|
}
|
|
return model, &ollama.OllamaClient{
|
|
BaseURL: url,
|
|
}, nil
|
|
case "openai":
|
|
url := "https://api.openai.com"
|
|
if p.BaseURL != "" {
|
|
url = p.BaseURL
|
|
}
|
|
return model, &openai.OpenAIClient{
|
|
BaseURL: url,
|
|
APIKey: p.APIKey,
|
|
Headers: p.Headers,
|
|
}, nil
|
|
default:
|
|
return "", nil, fmt.Errorf("unknown provider kind: %s", p.Kind)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return "", nil, fmt.Errorf("unknown model: %s", model)
|
|
}
|
|
|
|
// configDir returns the lmcli configuration directory, creating it if needed.
//
// Following the XDG Base Directory specification, this is
// $XDG_CONFIG_HOME/lmcli when XDG_CONFIG_HOME is set to an absolute path.
// Otherwise it is $HOME/.config/lmcli; the spec says relative values must be
// ignored.
//
// Failures to find the home directory or to create the directory are
// reported on stderr but are not fatal. Later reads or writes inside the
// directory will surface a more specific error.
func configDir() string {
	base := os.Getenv("XDG_CONFIG_HOME")
	if base == "" || !filepath.IsAbs(base) {
		home, err := os.UserHomeDir()
		if err != nil {
			fmt.Fprintf(os.Stderr, "Could not determine home directory: %v\n", err)
		}
		base = filepath.Join(home, ".config")
	}

	dir := filepath.Join(base, "lmcli")
	if err := os.MkdirAll(dir, 0755); err != nil {
		fmt.Fprintf(os.Stderr, "Could not create directory %s: %v\n", dir, err)
	}
	return dir
}
|
|
|
|
// dataDir returns the lmcli data directory, creating it if needed.
//
// Following the XDG Base Directory specification, this is
// $XDG_DATA_HOME/lmcli when XDG_DATA_HOME is set to an absolute path.
// Otherwise it is $HOME/.local/share/lmcli; the spec says relative values
// must be ignored.
//
// Failures to find the home directory or to create the directory are
// reported on stderr but are not fatal. Later reads or writes inside the
// directory will surface a more specific error.
func dataDir() string {
	base := os.Getenv("XDG_DATA_HOME")
	if base == "" || !filepath.IsAbs(base) {
		home, err := os.UserHomeDir()
		if err != nil {
			fmt.Fprintf(os.Stderr, "Could not determine home directory: %v\n", err)
		}
		base = filepath.Join(home, ".local", "share")
	}

	dir := filepath.Join(base, "lmcli")
	if err := os.MkdirAll(dir, 0755); err != nil {
		fmt.Fprintf(os.Stderr, "Could not create directory %s: %v\n", dir, err)
	}
	return dir
}
|
|
|
|
// Fatal writes a printf-style formatted message to standard error and then
// terminates the process with exit status 1. No trailing newline is added;
// the format string must supply one if wanted.
func Fatal(format string, args ...any) {
	msg := fmt.Sprintf(format, args...)
	fmt.Fprint(os.Stderr, msg)
	os.Exit(1)
}
|
|
|
|
// Warn writes a printf-style formatted message to standard error and returns;
// it does not exit. No trailing newline is added.
func Warn(format string, args ...any) {
	os.Stderr.WriteString(fmt.Sprintf(format, args...))
}
|