Private
Public Access
1
0
Files
lmcli/pkg/lmcli/lmcli.go
Matt Low 259648f699 Slight lmcli package refactor
- Moved utility functions from lmcli.go to tools.go

This is preparing for `lmcli` package API refactoring
2025-06-25 07:18:57 +00:00

216 lines
4.9 KiB
Go

package lmcli
import (
"fmt"
"log"
"os"
"path/filepath"
"strings"
"time"
"git.mlow.ca/mlow/lmcli/pkg/agents"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/provider"
"git.mlow.ca/mlow/lmcli/pkg/provider/anthropic"
"git.mlow.ca/mlow/lmcli/pkg/provider/google"
"git.mlow.ca/mlow/lmcli/pkg/provider/ollama"
"git.mlow.ca/mlow/lmcli/pkg/provider/openai"
"git.mlow.ca/mlow/lmcli/pkg/util"
"git.mlow.ca/mlow/lmcli/pkg/util/tty"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// Agent is a resolved agent as returned by Context.GetAgent: a name, a system
// prompt, and the tool specs enabled for it.
type Agent struct {
// Name is the agent's name, copied from its configuration entry.
Name string
// SystemPrompt is the agent's system prompt, copied from its configuration entry.
SystemPrompt string
// Toolbox holds the specs of the agent's configured tools that were found
// in agents.AvailableTools.
Toolbox []api.ToolSpec
}
// Context bundles the application state assembled by NewContext: the loaded
// configuration, the conversation repository, and the Chroma highlighter.
type Context struct {
// Config is the high level app configuration; it may be mutated at runtime.
Config Config
// Conversations is the repository NewContext obtains from conversation.NewRepo.
Conversations conversation.Repo
// Chroma is the highlighter NewContext creates for "markdown" using the
// configured formatter and style.
Chroma tty.ChromaHighlighter
}
// NewContext builds the application Context.
//
// It loads config.yaml from configDir(), obtains the GORM log writer for
// database.log in dataDir() from createOrOpenAppend, opens the SQLite database
// conversations.db in dataDir(), wraps it with conversation.NewRepo, and
// creates a Chroma highlighter from the configured formatter and style.
//
// Any failure is returned as an error. Errors from opening the log file or the
// database are wrapped with %w, so callers can inspect the underlying cause
// with errors.Is / errors.As.
func NewContext() (*Context, error) {
	configFile := filepath.Join(configDir(), "config.yaml")
	config, err := NewConfig(configFile)
	if err != nil {
		return nil, err
	}

	databaseFile := filepath.Join(dataDir(), "conversations.db")
	// NOTE(review): the log file handle is never closed on the failure paths
	// below; its concrete type is not visible here, so this is left as-is.
	gormLogFile, err := createOrOpenAppend(filepath.Join(dataDir(), "database.log"))
	if err != nil {
		return nil, fmt.Errorf("Could not open database log file: %w", err)
	}

	db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{
		Logger: logger.New(log.New(gormLogFile, "\n", log.LstdFlags), logger.Config{
			SlowThreshold:             200 * time.Millisecond,
			LogLevel:                  logger.Info,
			IgnoreRecordNotFoundError: false,
			Colorful:                  true,
		}),
	})
	if err != nil {
		return nil, fmt.Errorf("Error establishing connection to store: %w", err)
	}

	repo, err := conversation.NewRepo(db)
	if err != nil {
		return nil, err
	}

	// Initialize chroma.
	// NOTE(review): Formatter and Style are dereferenced unconditionally; this
	// assumes NewConfig always populates them — confirm.
	chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style)

	// Keyed fields keep this literal correct if Context gains or reorders fields.
	return &Context{
		Config:        *config,
		Conversations: repo,
		Chroma:        *chroma,
	}, nil
}
// GetModels lists the model identifiers available from the configured
// providers.
//
// Every model of every provider is listed as "model@provider", where the
// provider part is the provider's Name, or its Kind when Name is empty. After
// those, each model name that occurs exactly once across all providers is
// also listed on its own, in the order it was first seen, so the result is
// the same on every call for the same configuration.
//
// The result is nil when no provider lists any model.
func (c *Context) GetModels() (models []string) {
	counts := make(map[string]int)
	// firstSeen records distinct model names in encounter order; ranging over
	// the counts map directly would yield them in random order.
	var firstSeen []string

	for _, p := range c.Config.Providers {
		providerName := p.Kind
		if p.Name != "" {
			providerName = p.Name
		}
		for _, m := range p.Models() {
			if counts[m.Name] == 0 {
				firstSeen = append(firstSeen, m.Name)
			}
			counts[m.Name]++
			models = append(models, fmt.Sprintf("%s@%s", m.Name, providerName))
		}
	}

	// A bare name is only unambiguous when a single provider entry offers it.
	for _, modelName := range firstSeen {
		if counts[modelName] == 1 {
			models = append(models, modelName)
		}
	}
	return models
}
// GetAgents returns the Name of every agent in c.Config.Agents, in the order
// range yields them. The result is nil when there are no agents.
func (c *Context) GetAgents() (agents []string) {
	for _, agent := range c.Config.Agents {
		agents = append(agents, agent.Name)
	}
	return agents
}
// GetAgent resolves a configured agent by name.
//
// It returns nil when name is empty, is the literal "none", or matches no
// entry in c.Config.Agents. Otherwise it returns the first matching entry as
// an Agent whose Toolbox is filled from agents.AvailableTools; tool names that
// are not present in that table are skipped.
func (c *Context) GetAgent(name string) *Agent {
	switch name {
	case "", "none":
		return nil
	}

	for _, cfg := range c.Config.Agents {
		if cfg.Name == name {
			var toolbox []api.ToolSpec
			for _, toolName := range cfg.Tools {
				if spec, found := agents.AvailableTools[toolName]; found {
					toolbox = append(toolbox, spec)
				}
			}
			return &Agent{
				Name:         cfg.Name,
				SystemPrompt: cfg.SystemPrompt,
				Toolbox:      toolbox,
			}
		}
	}

	return nil
}
// DefaultSystemPrompt returns the default system prompt.
//
// When c.Config.Defaults.SystemPromptFile is set, the prompt is the content
// util.ReadFileContents returns for that path; a read failure is reported via
// Fatal. Otherwise c.Config.Defaults.SystemPrompt is returned.
func (c *Context) DefaultSystemPrompt() string {
	path := c.Config.Defaults.SystemPromptFile
	if path == "" {
		return c.Config.Defaults.SystemPrompt
	}

	content, err := util.ReadFileContents(path)
	if err != nil {
		Fatal("Could not read file contents at %s: %v\n", path, err)
	}
	return content
}
// GetModelProvider resolves a model name to a configured provider and builds
// a client for it.
//
// model is either a bare model name or "model@provider". When provider is
// empty and model contains "@", the text after the last "@" becomes the
// provider name and the text before it the model name. Splitting at the last
// "@" matches the "model@provider" form GetModels produces and keeps any
// earlier "@" as part of the model name instead of discarding it. When
// provider is non-empty, model is used verbatim.
//
// Providers are tried in the order range yields c.Config.Providers. A provider
// is identified by its Name, or by its Kind when Name is empty. If provider is
// non-empty, only providers with that identifier are considered. The first
// considered provider whose Models() contains the model is used, with its
// BaseURL falling back to a per-kind default when unset.
//
// It returns the resolved model name, the resolved provider identifier, and
// the client. An error is returned when no considered provider lists the
// model, or when the matching provider has an unrecognized Kind.
func (c *Context) GetModelProvider(model string, provider string) (string, string, provider.ChatCompletionProvider, error) {
	if provider == "" {
		if at := strings.LastIndex(model, "@"); at >= 0 {
			model, provider = model[:at], model[at+1:]
		}
	}

	for _, p := range c.Config.Providers {
		name := p.Kind
		if p.Name != "" {
			name = p.Name
		}

		if provider != "" && name != provider {
			continue
		}

		for _, m := range p.Models() {
			if m.Name != model {
				continue
			}

			// baseURL returns the provider's configured URL, or fallback
			// when none is configured.
			baseURL := func(fallback string) string {
				if p.BaseURL != "" {
					return p.BaseURL
				}
				return fallback
			}

			switch p.Kind {
			case "anthropic":
				return model, name, &anthropic.AnthropicClient{
					BaseURL: baseURL("https://api.anthropic.com"),
					APIKey:  p.APIKey,
				}, nil
			case "google":
				return model, name, &google.Client{
					BaseURL: baseURL("https://generativelanguage.googleapis.com"),
					APIKey:  p.APIKey,
				}, nil
			case "ollama":
				return model, name, &ollama.OllamaClient{
					BaseURL: baseURL("http://localhost:11434/api"),
				}, nil
			case "openai":
				return model, name, &openai.OpenAIClient{
					BaseURL: baseURL("https://api.openai.com"),
					APIKey:  p.APIKey,
					Headers: p.Headers,
				}, nil
			default:
				return "", "", nil, fmt.Errorf("unknown provider kind: %s", p.Kind)
			}
		}
	}

	return "", "", nil, fmt.Errorf("unknown model: %s", model)
}
// Fatal writes the formatted message to standard error and then terminates
// the process with exit status 1. It does not return.
func Fatal(format string, args ...any) {
	msg := fmt.Sprintf(format, args...)
	os.Stderr.WriteString(msg)
	os.Exit(1)
}
// Warn writes the formatted message to standard error. No newline is added;
// callers include one in format when they want it.
func Warn(format string, args ...any) {
	msg := fmt.Sprintf(format, args...)
	os.Stderr.WriteString(msg)
}