Private
Public Access
1
0
Files
lmcli/pkg/lmcli/lmcli.go
2024-06-01 01:38:45 +00:00

192 lines
4.4 KiB
Go

package lmcli
import (
"fmt"
"os"
"path/filepath"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/google"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/ollama"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
"git.mlow.ca/mlow/lmcli/pkg/util"
"git.mlow.ca/mlow/lmcli/pkg/util/tty"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// Context holds the state that NewContext assembles for lmcli: the loaded
// configuration, the conversation store, the syntax highlighter, the enabled
// tools and an optional system prompt file override.
type Context struct {
	// Config is the configuration loaded from config.yaml; it may be updated
	// at runtime.
	Config *Config // may be updated at runtime

	// Store persists conversations (NewContext creates it with NewSQLStore
	// over the conversations.db SQLite database).
	Store ConversationStore

	// Chroma highlights markdown output using the configured Chroma
	// formatter and style.
	Chroma *tty.ChromaHighlighter

	// EnabledTools are the entries of tools.AvailableTools whose names are
	// listed in the configuration's enabled tools.
	EnabledTools []model.Tool

	// SystemPromptFile, when non-empty, names a file whose contents
	// GetSystemPrompt returns instead of the configured default prompt.
	SystemPromptFile string
}
// NewContext assembles a Context from the user's environment. It:
//   - loads config.yaml from configDir(),
//   - opens the conversations.db SQLite database in dataDir() through gorm and
//     wraps it with NewSQLStore,
//   - builds a markdown Chroma highlighter from the configured formatter and
//     style, and
//   - collects the tools named in config.Tools.EnabledTools. Names not present
//     in tools.AvailableTools are skipped silently.
//
// Every failure is returned as a wrapped error so the caller decides how to
// report it. The returned Context has SystemPromptFile left empty.
func NewContext() (*Context, error) {
	configFile := filepath.Join(configDir(), "config.yaml")
	config, err := NewConfig(configFile)
	if err != nil {
		// Return instead of exiting: the signature promises an error, and the
		// store failures below are reported the same way.
		return nil, fmt.Errorf("loading config: %w", err)
	}

	databaseFile := filepath.Join(dataDir(), "conversations.db")
	db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{
		// Enable for verbose gorm logging (needs gorm.io/gorm/logger):
		//Logger: logger.Default.LogMode(logger.Info),
	})
	if err != nil {
		return nil, fmt.Errorf("establishing connection to store %s: %w", databaseFile, err)
	}
	store, err := NewSQLStore(db)
	if err != nil {
		return nil, fmt.Errorf("initializing conversation store: %w", err)
	}

	chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style)

	var enabledTools []model.Tool
	for _, toolName := range config.Tools.EnabledTools {
		if tool, ok := tools.AvailableTools[toolName]; ok {
			enabledTools = append(enabledTools, tool)
		}
	}

	// Keyed fields keep this literal correct if Context gains or reorders
	// fields; SystemPromptFile is deliberately left at its zero value.
	return &Context{
		Config:       config,
		Store:        store,
		Chroma:       chroma,
		EnabledTools: enabledTools,
	}, nil
}
// GetModels lists the models available across all configured providers.
//
// Every model entry is first listed provider-qualified as
// "<provider name>/<model>", in provider iteration order. Then come the bare
// model names that exactly one entry across all providers lists, in the order
// they were first encountered. A bare name is therefore only offered when it
// is unambiguous.
func (c *Context) GetModels() (models []string) {
	modelCounts := make(map[string]int)
	// firstSeen records each distinct model in encounter order. Ranging over
	// modelCounts instead would make the order of bare names vary between runs.
	var firstSeen []string
	for _, p := range c.Config.Providers {
		for _, m := range *p.Models {
			if modelCounts[m] == 0 {
				firstSeen = append(firstSeen, m)
			}
			modelCounts[m]++
			models = append(models, *p.Name+"/"+m)
		}
	}
	for _, m := range firstSeen {
		if modelCounts[m] == 1 {
			models = append(models, m)
		}
	}
	return models
}
// GetModelProvider resolves model to the configured provider that lists it.
// It returns the bare model name together with a ChatCompletionClient for
// that provider.
//
// model may be provider-qualified as "<provider name>/<model>", the form
// GetModels produces. In that case only the provider with that name is
// searched. Only the first "/" separates the provider name, so model names
// that themselves contain "/" survive the round trip.
//
// If the qualified reading matches nothing, the whole string is retried as a
// bare model name. A bare name is searched across all providers in iteration
// order, and the first provider listing it wins.
//
// Each provider kind uses its default endpoint unless the provider configures
// a base URL. An error is returned for an unknown provider kind or when no
// provider lists the model.
func (c *Context) GetModelProvider(model string) (string, provider.ChatCompletionClient, error) {
	// lookup finds modelName among the configured providers and builds its
	// client. When providerName is non-empty, only that provider is searched.
	// It returns a nil client and a nil error when nothing matches.
	lookup := func(providerName, modelName string) (provider.ChatCompletionClient, error) {
		for _, p := range c.Config.Providers {
			if providerName != "" && *p.Name != providerName {
				continue
			}
			for _, m := range *p.Models {
				if m != modelName {
					continue
				}
				// A configured base URL overrides the kind's default endpoint.
				endpoint := func(fallback string) string {
					if p.BaseURL != nil {
						return *p.BaseURL
					}
					return fallback
				}
				switch *p.Kind {
				case "anthropic":
					return &anthropic.AnthropicClient{
						BaseURL: endpoint("https://api.anthropic.com/v1"),
						APIKey:  *p.APIKey,
					}, nil
				case "google":
					return &google.Client{
						BaseURL: endpoint("https://generativelanguage.googleapis.com"),
						APIKey:  *p.APIKey,
					}, nil
				case "ollama":
					return &ollama.OllamaClient{
						BaseURL: endpoint("http://localhost:11434/api"),
					}, nil
				case "openai":
					return &openai.OpenAIClient{
						BaseURL: endpoint("https://api.openai.com/v1"),
						APIKey:  *p.APIKey,
					}, nil
				default:
					return nil, fmt.Errorf("unknown provider kind: %s", *p.Kind)
				}
			}
		}
		return nil, nil
	}

	if providerName, modelName, qualified := strings.Cut(model, "/"); qualified {
		client, err := lookup(providerName, modelName)
		if err != nil {
			return "", nil, err
		}
		if client != nil {
			return modelName, client, nil
		}
	}

	client, err := lookup("", model)
	if err != nil {
		return "", nil, err
	}
	if client == nil {
		// Report the caller's input, not a stripped suffix of it.
		return "", nil, fmt.Errorf("unknown model: %s", model)
	}
	return model, client, nil
}
// GetSystemPrompt returns the system prompt for this session.
//
// When SystemPromptFile is set, that file's contents (read with
// util.ReadFileContents) take precedence, and a read failure terminates the
// process via Fatal. Otherwise the configured Defaults.SystemPrompt is
// returned.
func (c *Context) GetSystemPrompt() string {
	promptPath := c.SystemPromptFile
	if promptPath == "" {
		return *c.Config.Defaults.SystemPrompt
	}

	prompt, readErr := util.ReadFileContents(promptPath)
	if readErr != nil {
		Fatal("Could not read file contents at %s: %v\n", promptPath, readErr)
	}
	return prompt
}
// configDir returns the lmcli configuration directory, creating it if needed.
//
// It is $XDG_CONFIG_HOME/lmcli when XDG_CONFIG_HOME holds an absolute path.
// The XDG Base Directory specification says relative values are invalid and
// must be ignored, so those fall through to the default, ~/.config/lmcli.
func configDir() string {
	base := os.Getenv("XDG_CONFIG_HOME")
	if !filepath.IsAbs(base) {
		// If the home directory cannot be determined, home is "" and the
		// result is relative to the working directory. This is a best-effort
		// fallback rather than a hard failure.
		home, _ := os.UserHomeDir()
		base = filepath.Join(home, ".config")
	}
	dir := filepath.Join(base, "lmcli")
	// Best effort: any failure is left to surface when files inside dir are
	// opened.
	_ = os.MkdirAll(dir, 0755)
	return dir
}
// dataDir returns the lmcli data directory, creating it if needed.
//
// It is $XDG_DATA_HOME/lmcli when XDG_DATA_HOME holds an absolute path. The
// XDG Base Directory specification says relative values are invalid and must
// be ignored, so those fall through to the default, ~/.local/share/lmcli.
func dataDir() string {
	base := os.Getenv("XDG_DATA_HOME")
	if !filepath.IsAbs(base) {
		// If the home directory cannot be determined, home is "" and the
		// result is relative to the working directory. This is a best-effort
		// fallback rather than a hard failure.
		home, _ := os.UserHomeDir()
		base = filepath.Join(home, ".local", "share")
	}
	dir := filepath.Join(base, "lmcli")
	// Best effort: any failure is left to surface when files inside dir are
	// opened.
	_ = os.MkdirAll(dir, 0755)
	return dir
}
// Fatal formats its arguments printf-style, writes the result to standard
// error and terminates the process with exit status 1. No newline is
// appended, so format should end with one if desired.
func Fatal(format string, args ...any) {
	message := fmt.Sprintf(format, args...)
	os.Stderr.WriteString(message)
	os.Exit(1)
}
// Warn formats its arguments printf-style and writes the result to standard
// error, then returns. No newline is appended.
func Warn(format string, args ...any) {
	os.Stderr.WriteString(fmt.Sprintf(format, args...))
}