Private
Public Access
1
0

Project refactor, add anthropic API support

- Split pkg/cli/cmd.go into new pkg/cmd package
- Split pkg/cli/functions.go into pkg/lmcli/tools package
- Refactor pkg/cli/openai.go to pkg/lmcli/provider/openai

Other changes:

- Made models configurable
- Slight config reorganization
This commit is contained in:
2024-02-22 04:55:38 +00:00
parent 2611663168
commit 0a27b9a8d3
41 changed files with 3029 additions and 1899 deletions

71
pkg/lmcli/config.go Normal file
View File

@@ -0,0 +1,71 @@
package lmcli
import (
"fmt"
"os"
"git.mlow.ca/mlow/lmcli/pkg/util"
"github.com/go-yaml/yaml"
)
// Config is the root lmcli configuration, loaded from config.yaml.
// Every field is a pointer so util.SetStructDefaults can distinguish
// "unset" from zero values and fill in the `default:` struct tags.
type Config struct {
	// Defaults applies to new requests unless overridden per-call.
	Defaults *struct {
		SystemPrompt *string  `yaml:"systemPrompt" default:"You are a helpful assistant."`
		MaxTokens    *int     `yaml:"maxTokens" default:"256"`
		Temperature  *float32 `yaml:"temperature" default:"0.7"`
		Model        *string  `yaml:"model" default:"gpt-4"`
	} `yaml:"defaults"`
	// Conversations controls conversation-management behavior.
	Conversations *struct {
		TitleGenerationModel *string `yaml:"titleGenerationModel" default:"gpt-3.5-turbo"`
	} `yaml:"conversations"`
	// Tools selects which built-in tools the model may invoke.
	Tools *struct {
		EnabledTools *[]string `yaml:"enabledTools"`
	} `yaml:"tools"`
	// OpenAI holds credentials and the model list for the OpenAI provider.
	OpenAI *struct {
		APIKey *string   `yaml:"apiKey" default:"your_key_here"`
		Models *[]string `yaml:"models"`
	} `yaml:"openai"`
	// Anthropic holds credentials and the model list for the Anthropic provider.
	Anthropic *struct {
		APIKey *string   `yaml:"apiKey" default:"your_key_here"`
		Models *[]string `yaml:"models"`
	} `yaml:"anthropic"`
	// Chroma configures syntax highlighting of model output.
	Chroma *struct {
		Style     *string `yaml:"style" default:"onedark"`
		Formatter *string `yaml:"formatter" default:"terminal16m"`
	} `yaml:"chroma"`
}
// NewConfig loads the configuration from configFile, fills any missing
// values from the struct `default:` tags, and writes the result back to
// disk when the file was missing or new defaults were added (backing up
// the previous file first).
func NewConfig(configFile string) (*Config, error) {
	c := &Config{}
	configExists := true
	configBytes, err := os.ReadFile(configFile)
	if os.IsNotExist(err) {
		configExists = false
	} else if err != nil {
		return nil, fmt.Errorf("Could not read config file: %v", err)
	} else if err := yaml.Unmarshal(configBytes, c); err != nil {
		// Previously this error was silently discarded, leaving a
		// default config in place of the user's (possibly mistyped) one.
		return nil, fmt.Errorf("Could not parse config file: %v", err)
	}
	// Reports whether any defaults were applied, i.e. the on-disk
	// config was missing at least one value.
	shouldWriteDefaults := util.SetStructDefaults(c)
	if !configExists || shouldWriteDefaults {
		if configExists {
			fmt.Printf("Saving new defaults to configuration, backing up existing configuration to %s\n", configFile+".bak")
			if err := os.Rename(configFile, configFile+".bak"); err != nil {
				return nil, fmt.Errorf("Could not back up existing configuration: %v", err)
			}
		}
		fmt.Printf("Writing configuration file to %s\n", configFile)
		bytes, err := yaml.Marshal(c)
		if err != nil {
			return nil, fmt.Errorf("Could not marshal configuration: %v", err)
		}
		// os.WriteFile replaces the previous Create+Write, which never
		// closed the file handle.
		if err := os.WriteFile(configFile, bytes, 0644); err != nil {
			return nil, fmt.Errorf("Could not save default configuration: %v", err)
		}
	}
	return c, nil
}

97
pkg/lmcli/lmcli.go Normal file
View File

@@ -0,0 +1,97 @@
package lmcli
import (
"fmt"
"os"
"path/filepath"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// Context bundles the loaded configuration and the conversation store;
// it is constructed once by NewContext and passed to commands.
type Context struct {
	Config Config
	Store  ConversationStore
}
// NewContext loads the configuration and opens the conversation
// database, returning a ready-to-use application context.
func NewContext() (*Context, error) {
	configFile := filepath.Join(configDir(), "config.yaml")
	config, err := NewConfig(configFile)
	if err != nil {
		// Return the error rather than calling Fatal: a constructor
		// that already returns an error should not exit the process.
		return nil, err
	}
	databaseFile := filepath.Join(dataDir(), "conversations.db")
	db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{})
	if err != nil {
		return nil, fmt.Errorf("Error establishing connection to store: %v", err)
	}
	s, err := NewSQLStore(db)
	if err != nil {
		return nil, err
	}
	return &Context{*config, s}, nil
}
// GetCompletionProvider returns a client for whichever configured
// provider lists the given model, or an error if no provider does.
// Nil-checks guard against provider sections absent from the config,
// which previously caused a nil-pointer dereference.
func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionClient, error) {
	if c.Config.Anthropic != nil && c.Config.Anthropic.Models != nil {
		for _, m := range *c.Config.Anthropic.Models {
			if m == model {
				return &anthropic.AnthropicClient{
					APIKey: *c.Config.Anthropic.APIKey,
				}, nil
			}
		}
	}
	if c.Config.OpenAI != nil && c.Config.OpenAI.Models != nil {
		for _, m := range *c.Config.OpenAI.Models {
			if m == model {
				return &openai.OpenAIClient{
					APIKey: *c.Config.OpenAI.APIKey,
				}, nil
			}
		}
	}
	return nil, fmt.Errorf("unknown model: %s", model)
}
// configDir returns (creating it if necessary) lmcli's configuration
// directory, honoring XDG_CONFIG_HOME when set and falling back to
// ~/.config otherwise.
func configDir() string {
	base := os.Getenv("XDG_CONFIG_HOME")
	if base == "" {
		home, _ := os.UserHomeDir()
		base = filepath.Join(home, ".config")
	}
	dir := filepath.Join(base, "lmcli")
	os.MkdirAll(dir, 0755)
	return dir
}
// dataDir returns (creating it if necessary) lmcli's data directory,
// honoring XDG_DATA_HOME when set and falling back to
// ~/.local/share otherwise.
func dataDir() string {
	base := os.Getenv("XDG_DATA_HOME")
	if base == "" {
		home, _ := os.UserHomeDir()
		base = filepath.Join(home, ".local/share")
	}
	dir := filepath.Join(base, "lmcli")
	os.MkdirAll(dir, 0755)
	return dir
}
// Fatal writes a formatted message to stderr and exits with status 1.
// The format string is passed through verbatim; callers supply any
// trailing newline themselves.
func Fatal(format string, args ...any) {
	fmt.Fprintf(os.Stderr, format, args...)
	os.Exit(1)
}

// Warn writes a formatted message to stderr without terminating.
func Warn(format string, args ...any) {
	fmt.Fprintf(os.Stderr, format, args...)
}

View File

@@ -0,0 +1,58 @@
package model
import (
"database/sql"
"time"
)
// MessageRole identifies the author/kind of a conversation message.
type MessageRole string

const (
	MessageRoleSystem    MessageRole = "system"
	MessageRoleUser      MessageRole = "user"
	MessageRoleAssistant MessageRole = "assistant"
	// MessageRoleToolCall records a model-requested tool invocation.
	MessageRoleToolCall MessageRole = "tool_call"
	// MessageRoleToolResult records the output of executed tool calls.
	MessageRoleToolResult MessageRole = "tool_result"
)
// Message is a single entry in a conversation, persisted via gorm.
type Message struct {
	ID             uint `gorm:"primaryKey"`
	ConversationID uint `gorm:"foreignKey:ConversationID"`
	Content        string
	Role           MessageRole
	CreatedAt      time.Time
	ToolCalls      ToolCalls   // a json array of tool calls (from the model)
	ToolResults    ToolResults // a json array of tool results
}
// Conversation groups messages. ShortName is a short, human-typable
// identifier assigned on first save (see SQLStore.SaveConversation).
type Conversation struct {
	ID        uint `gorm:"primaryKey"`
	ShortName sql.NullString
	Title     string
}
// RequestParameters carries provider-agnostic completion settings
// passed to every ChatCompletionClient call.
type RequestParameters struct {
	Model        string
	MaxTokens    int
	Temperature  float32
	TopP         float32
	SystemPrompt string
	ToolBag      []Tool // tools the model may invoke during completion
}
// FriendlyRole returns a human-friendly signifier for the message's
// role (e.g. "You" for the user role). Roles without a dedicated label
// are returned as their raw string value.
func (m *MessageRole) FriendlyRole() string {
	switch *m {
	case MessageRoleUser:
		return "You"
	case MessageRoleSystem:
		return "System"
	case MessageRoleAssistant:
		return "Assistant"
	}
	return string(*m)
}

98
pkg/lmcli/model/tool.go Normal file
View File

@@ -0,0 +1,98 @@
package model
import (
"database/sql/driver"
"encoding/json"
"fmt"
)
// Tool describes a function the model may invoke, along with its local
// implementation.
type Tool struct {
	Name        string
	Description string
	Parameters  []ToolParameter
	// Impl executes the tool with the decoded arguments and returns a
	// JSON-encoded CallResult string.
	Impl func(*Tool, map[string]interface{}) (string, error)
}

// ToolParameter describes a single tool argument.
type ToolParameter struct {
	Name        string   `json:"name"`
	Type        string   `json:"type"` // "string", "integer", "boolean"
	Required    bool     `json:"required"`
	Description string   `json:"description"`
	Enum        []string `json:"enum,omitempty"`
}
type ToolCall struct {
ID string `json:"id"`
Name string `json:"name"`
Parameters map[string]interface{} `json:"parameters"`
}
type ToolCalls []ToolCall
func (tc *ToolCalls) Scan(value any) (err error) {
s := value.(string)
if value == nil || s == "" {
*tc = nil
return
}
err = json.Unmarshal([]byte(s), tc)
return
}
func (tc ToolCalls) Value() (driver.Value, error) {
if len(tc) == 0 {
return "", nil
}
jsonBytes, err := json.Marshal(tc)
if err != nil {
return "", fmt.Errorf("Could not marshal ToolCalls to JSON: %v\n", err)
}
return string(jsonBytes), nil
}
// ToolResult is the output of one executed tool call.
type ToolResult struct {
	ToolCallID string `json:"toolCallID"`
	ToolName   string `json:"toolName,omitempty"`
	Result     string `json:"result,omitempty"`
}

// ToolResults is a slice of tool results, stored in the database as a
// single JSON-encoded text column.
type ToolResults []ToolResult

// Scan implements sql.Scanner, decoding a JSON column value into the
// slice. NULL and empty values yield a nil slice.
// Fix: the original asserted value.(string) before the nil check, which
// panicked on NULL columns and on drivers that scan text as []byte.
func (tr *ToolResults) Scan(value any) error {
	if value == nil {
		*tr = nil
		return nil
	}
	var data []byte
	switch v := value.(type) {
	case string:
		data = []byte(v)
	case []byte:
		data = v
	default:
		return fmt.Errorf("unsupported type for ToolResults: %T", value)
	}
	if len(data) == 0 {
		*tr = nil
		return nil
	}
	return json.Unmarshal(data, tr)
}

// Value implements driver.Valuer, encoding the slice as a JSON string.
// An empty slice is stored as the empty string.
func (tr ToolResults) Value() (driver.Value, error) {
	if len(tr) == 0 {
		return "", nil
	}
	jsonBytes, err := json.Marshal([]ToolResult(tr))
	if err != nil {
		return "", fmt.Errorf("Could not marshal ToolResults to JSON: %v", err)
	}
	return string(jsonBytes), nil
}
// CallResult is the outcome of a tool invocation; its JSON encoding is
// returned to the model as the tool's output.
type CallResult struct {
	Message string `json:"message"`
	Result  any    `json:"result,omitempty"`
}

// ToJson serializes the result, substituting the message "success" when
// none was provided.
func (r CallResult) ToJson() (string, error) {
	if r.Message == "" {
		// When message not supplied, assume success
		r.Message = "success"
	}
	encoded, err := json.Marshal(r)
	if err != nil {
		return "", fmt.Errorf("Could not marshal CallResult to JSON: %v\n", err)
	}
	return string(encoded), nil
}

View File

@@ -0,0 +1,322 @@
package anthropic
import (
	"bufio"
	"bytes"
	"encoding/json"
	"encoding/xml"
	"fmt"
	"io"
	"net/http"
	"strings"

	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
	"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
)
// AnthropicClient talks to the Anthropic messages API.
type AnthropicClient struct {
	APIKey string
}

// Message is one entry in an Anthropic messages request.
type Message struct {
	Role            string `json:"role"`
	OriginalContent string `json:"content"`
}

// Request is the Anthropic /v1/messages request body.
type Request struct {
	Model         string    `json:"model"`
	Messages      []Message `json:"messages"`
	System        string    `json:"system,omitempty"`
	MaxTokens     int       `json:"max_tokens,omitempty"`
	StopSequences []string  `json:"stop_sequences,omitempty"`
	Stream        bool      `json:"stream,omitempty"`
	Temperature   float32   `json:"temperature,omitempty"`
	//TopP float32 `json:"top_p,omitempty"`
	//TopK float32 `json:"top_k,omitempty"`
}

// OriginalContent is one content block of an Anthropic response.
type OriginalContent struct {
	Type string `json:"type"`
	Text string `json:"text"`
}

// Response is the Anthropic /v1/messages response body.
type Response struct {
	Id              string            `json:"id"`
	Type            string            `json:"type"`
	Role            string            `json:"role"`
	OriginalContent []OriginalContent `json:"content"`
}

// FUNCTION_STOP_SEQUENCE halts generation once the model closes a
// tool-call block, so the client can execute the calls itself.
const FUNCTION_STOP_SEQUENCE = "</function_calls>"
// buildRequest converts generic request parameters and messages into an
// Anthropic messages-API request body. A leading system message is
// lifted into the request's System field, and tool definitions (if any)
// are appended to the system prompt as an XML description.
func buildRequest(params model.RequestParameters, messages []model.Message) Request {
	requestBody := Request{
		Model:       params.Model,
		Messages:    make([]Message, len(messages)),
		System:      params.SystemPrompt,
		MaxTokens:   params.MaxTokens,
		Temperature: params.Temperature,
		Stream:      false,
		// Stop at the closing function_calls tag so we can execute the
		// requested tools and continue the conversation ourselves.
		StopSequences: []string{
			FUNCTION_STOP_SEQUENCE,
			"\n\nHuman:",
		},
	}
	startIdx := 0
	// Guard the messages[0] access: the original indexed unconditionally
	// and panicked on an empty slice.
	if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem {
		// The system message travels out-of-band; shrink the request's
		// message slice by one to account for it.
		requestBody.System = messages[0].Content
		requestBody.Messages = requestBody.Messages[:len(messages)-1]
		startIdx = 1
	}
	if len(params.ToolBag) > 0 {
		if len(requestBody.System) > 0 {
			// add a divider between existing system prompt and tools
			requestBody.System += "\n\n---\n\n"
		}
		requestBody.System += buildToolsSystemPrompt(params.ToolBag)
	}
	for i, msg := range messages[startIdx:] {
		message := &requestBody.Messages[i]
		switch msg.Role {
		case model.MessageRoleToolCall:
			message.Role = "assistant"
			message.OriginalContent = msg.Content
		case model.MessageRoleToolResult:
			// Tool results are rendered back to the model as an XML
			// <function_results> block in a user message.
			xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults)
			xmlString, err := xmlFuncResults.XMLString()
			if err != nil {
				panic("Could not serialize []ToolResult to XMLFunctionResults")
			}
			message.Role = "user"
			message.OriginalContent = xmlString
		default:
			message.Role = string(msg.Role)
			message.OriginalContent = msg.Content
		}
	}
	return requestBody
}
// sendRequest marshals r and POSTs it to the Anthropic messages
// endpoint, returning the raw HTTP response. Non-2xx responses are
// drained and returned as an error (previously they were passed
// through and silently decoded as empty bodies by the callers).
func sendRequest(c *AnthropicClient, r Request) (*http.Response, error) {
	url := "https://api.anthropic.com/v1/messages"
	jsonBody, err := json.Marshal(r)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request body: %v", err)
	}
	req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonBody))
	if err != nil {
		return nil, fmt.Errorf("failed to create HTTP request: %v", err)
	}
	req.Header.Set("x-api-key", c.APIKey)
	req.Header.Set("anthropic-version", "2023-06-01")
	req.Header.Set("content-type", "application/json")
	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send HTTP request: %v", err)
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		body, _ := io.ReadAll(resp.Body)
		resp.Body.Close()
		return nil, fmt.Errorf("request returned non-success status %d: %s", resp.StatusCode, body)
	}
	return resp, nil
}
// CreateChatCompletion requests a (non-streamed) completion for the
// given messages. Each returned content block is appended to replies as
// an assistant message, and the concatenated text is returned.
// NOTE(review): replies is dereferenced unconditionally here, unlike
// the OpenAI client which nil-checks it — confirm callers always pass a
// non-nil pointer.
func (c *AnthropicClient) CreateChatCompletion(
	params model.RequestParameters,
	messages []model.Message,
	replies *[]model.Message,
) (string, error) {
	request := buildRequest(params, messages)
	resp, err := sendRequest(c, request)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	var response Response
	err = json.NewDecoder(resp.Body).Decode(&response)
	if err != nil {
		return "", fmt.Errorf("failed to decode response: %v", err)
	}
	// Accumulate all text blocks into the single returned string.
	sb := strings.Builder{}
	for _, content := range response.OriginalContent {
		var reply model.Message
		switch content.Type {
		case "text":
			reply = model.Message{
				Role:    model.MessageRoleAssistant,
				Content: content.Text,
			}
			sb.WriteString(reply.Content)
		default:
			return "", fmt.Errorf("unsupported message type: %s", content.Type)
		}
		*replies = append(*replies, reply)
	}
	return sb.String(), nil
}
// CreateChatCompletionStream requests a streamed completion, emitting
// text deltas on output as they arrive. When the stream halts at
// FUNCTION_STOP_SEQUENCE, the accumulated <function_calls> block is
// parsed, the requested tools are executed, and the function recurses
// with the tool call/result messages appended — so the final return
// value is the model's follow-up answer.
func (c *AnthropicClient) CreateChatCompletionStream(
	params model.RequestParameters,
	messages []model.Message,
	replies *[]model.Message,
	output chan<- string,
) (string, error) {
	request := buildRequest(params, messages)
	request.Stream = true
	resp, err := sendRequest(c, request)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	// The SSE stream is parsed line-by-line: bare JSON objects are
	// top-level errors; "data:" lines carry events.
	scanner := bufio.NewScanner(resp.Body)
	sb := strings.Builder{}
	for scanner.Scan() {
		line := scanner.Text()
		line = strings.TrimSpace(line)
		if len(line) == 0 {
			continue
		}
		if line[0] == '{' {
			// A bare JSON object (not an SSE "data:" line) is an
			// out-of-band error payload.
			var event map[string]interface{}
			err := json.Unmarshal([]byte(line), &event)
			if err != nil {
				return "", fmt.Errorf("failed to unmarshal event data '%s': %v", line, err)
			}
			eventType, ok := event["type"].(string)
			if !ok {
				return "", fmt.Errorf("invalid event: %s", line)
			}
			switch eventType {
			case "error":
				return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
			default:
				return sb.String(), fmt.Errorf("unknown event type: %s", eventType)
			}
		} else if strings.HasPrefix(line, "data:") {
			data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
			var event map[string]interface{}
			err := json.Unmarshal([]byte(data), &event)
			if err != nil {
				return "", fmt.Errorf("failed to unmarshal event data: %v", err)
			}
			eventType, ok := event["type"].(string)
			if !ok {
				return "", fmt.Errorf("invalid event type")
			}
			switch eventType {
			case "message_start":
				// noop
			case "ping":
				// write an empty string to signal start of text
				output <- ""
			case "content_block_start":
				// ignore?
			case "content_block_delta":
				delta, ok := event["delta"].(map[string]interface{})
				if !ok {
					return "", fmt.Errorf("invalid content block delta")
				}
				text, ok := delta["text"].(string)
				if !ok {
					return "", fmt.Errorf("invalid text delta")
				}
				sb.WriteString(text)
				output <- text
			case "content_block_stop":
				// ignore?
			case "message_delta":
				delta, ok := event["delta"].(map[string]interface{})
				if !ok {
					return "", fmt.Errorf("invalid message delta")
				}
				stopReason, ok := delta["stop_reason"].(string)
				if ok && stopReason == "stop_sequence" {
					stopSequence, ok := delta["stop_sequence"].(string)
					if ok && stopSequence == FUNCTION_STOP_SEQUENCE {
						content := sb.String()
						start := strings.Index(content, "<function_calls>")
						if start == -1 {
							return content, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
						}
						// NOTE(review): funcCallXml is built but never
						// read afterwards (the Unmarshal below uses
						// sb.String()) — looks like dead code; confirm.
						funcCallXml := content[start:]
						funcCallXml += FUNCTION_STOP_SEQUENCE
						// Re-append the stop sequence the API consumed,
						// so the transcript contains well-formed XML.
						sb.WriteString(FUNCTION_STOP_SEQUENCE)
						output <- FUNCTION_STOP_SEQUENCE
						// Extract function calls
						var functionCalls XMLFunctionCalls
						err := xml.Unmarshal([]byte(sb.String()), &functionCalls)
						if err != nil {
							return "", fmt.Errorf("failed to unmarshal function_calls: %v", err)
						}
						// Execute function calls
						toolCall := model.Message{
							Role:      model.MessageRoleToolCall,
							Content:   content,
							ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
						}
						toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
						if err != nil {
							return "", err
						}
						toolReply := model.Message{
							Role:        model.MessageRoleToolResult,
							ToolResults: toolResults,
						}
						if replies != nil {
							*replies = append(append(*replies, toolCall), toolReply)
						}
						// Recurse into CreateChatCompletionStream with the tool call replies
						// added to the original messages
						messages = append(append(messages, toolCall), toolReply)
						return c.CreateChatCompletionStream(params, messages, replies, output)
					}
				}
			case "message_stop":
				// return the completed message
				// NOTE(review): *replies is dereferenced here without
				// the nil guard used in the tool-call branch above —
				// confirm callers never pass nil.
				reply := model.Message{
					Role:    model.MessageRoleAssistant,
					Content: sb.String(),
				}
				*replies = append(*replies, reply)
				return sb.String(), nil
			case "error":
				return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
			default:
				fmt.Printf("\nUnrecognized event: %s\n", data)
			}
		}
	}
	if err := scanner.Err(); err != nil {
		return "", fmt.Errorf("failed to read response body: %v", err)
	}
	return "", fmt.Errorf("unexpected end of stream")
}

View File

@@ -0,0 +1,182 @@
package anthropic
import (
"bytes"
"strings"
"text/template"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
)
// TOOL_PREAMBLE introduces the XML tool-calling protocol in the system
// prompt; the XML tool descriptions are appended after it.
const TOOL_PREAMBLE = `In this environment you have access to a set of tools which may assist you in fulfilling user requests.
You may call them like this:
<function_calls>
<invoke>
<tool_name>$TOOL_NAME</tool_name>
<parameters>
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
...
</parameters>
</invoke>
</function_calls>
Here are the tools available:`

// XMLTools is the <tools> element rendered into the system prompt.
type XMLTools struct {
	XMLName          struct{}             `xml:"tools"`
	ToolDescriptions []XMLToolDescription `xml:"tool_description"`
}

// XMLToolDescription describes one tool within <tools>.
type XMLToolDescription struct {
	ToolName    string             `xml:"tool_name"`
	Description string             `xml:"description"`
	Parameters  []XMLToolParameter `xml:"parameters>parameter"`
}

// XMLToolParameter describes one parameter of a tool.
type XMLToolParameter struct {
	Name        string `xml:"name"`
	Type        string `xml:"type"`
	Description string `xml:"description"`
}

// XMLFunctionCalls models the <function_calls> block emitted by the
// model when it wants to invoke tools.
type XMLFunctionCalls struct {
	XMLName struct{}            `xml:"function_calls"`
	Invoke  []XMLFunctionInvoke `xml:"invoke"`
}

// XMLFunctionInvoke is a single <invoke> element.
type XMLFunctionInvoke struct {
	ToolName   string                      `xml:"tool_name"`
	Parameters XMLFunctionInvokeParameters `xml:"parameters"`
}

// XMLFunctionInvokeParameters captures the raw inner XML of
// <parameters>; it is parsed by parseFunctionParametersXML.
type XMLFunctionInvokeParameters struct {
	String string `xml:",innerxml"`
}

// XMLFunctionResults is the <function_results> block sent back to the
// model after tools are executed.
type XMLFunctionResults struct {
	XMLName struct{}            `xml:"function_results"`
	Result  []XMLFunctionResult `xml:"result"`
}

// XMLFunctionResult is a single tool's output within <function_results>.
type XMLFunctionResult struct {
	ToolName string `xml:"tool_name"`
	Stdout   string `xml:"stdout"`
}
// parseFunctionParametersXML accepts the raw inner XML of a
// <parameters> element (one `<name>value</name>` pair per line) and
// returns a map of parameter name to value. Lines without both an
// opening and a closing tag are skipped.
func parseFunctionParametersXML(params string) map[string]interface{} {
	parsed := make(map[string]interface{})
	for _, line := range strings.Split(params, "\n") {
		// Locate the end of the opening tag and the start of the
		// closing tag on this line.
		open := strings.Index(line, ">")
		end := strings.Index(line, "</")
		if open == -1 || end == -1 {
			continue
		}
		// Name sits between the leading "<" and the first ">"; the
		// value spans from after ">" up to "</".
		parsed[line[1:open]] = line[open+1 : end]
	}
	return parsed
}
// convertToolsToXMLTools maps generic tool definitions onto the XML
// structures embedded in the Anthropic tool-use system prompt.
func convertToolsToXMLTools(tools []model.Tool) XMLTools {
	descriptions := make([]XMLToolDescription, len(tools))
	for i, tool := range tools {
		desc := XMLToolDescription{
			ToolName:    tool.Name,
			Description: tool.Description,
			Parameters:  make([]XMLToolParameter, len(tool.Parameters)),
		}
		for j, param := range tool.Parameters {
			desc.Parameters[j] = XMLToolParameter{
				Name:        param.Name,
				Description: param.Description,
				Type:        param.Type,
			}
		}
		descriptions[i] = desc
	}
	return XMLTools{ToolDescriptions: descriptions}
}
// convertXMLFunctionCallsToToolCalls translates a parsed
// <function_calls> block into generic ToolCall values.
func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []model.ToolCall {
	calls := make([]model.ToolCall, len(functionCalls.Invoke))
	for i, invoke := range functionCalls.Invoke {
		calls[i] = model.ToolCall{
			Name:       invoke.ToolName,
			Parameters: parseFunctionParametersXML(invoke.Parameters.String),
		}
	}
	return calls
}
// convertToolResultsToXMLFunctionResult packages executed tool outputs
// as the <function_results> block sent back to the model.
func convertToolResultsToXMLFunctionResult(toolResults []model.ToolResult) XMLFunctionResults {
	results := make([]XMLFunctionResult, len(toolResults))
	for i, res := range toolResults {
		results[i] = XMLFunctionResult{
			ToolName: res.ToolName,
			Stdout:   res.Result,
		}
	}
	return XMLFunctionResults{Result: results}
}
// buildToolsSystemPrompt renders the system-prompt section that teaches
// the model how to invoke the available tools.
// Panics if the XML template fails to render — a programming error in
// the static template rather than a runtime condition.
func buildToolsSystemPrompt(tools []model.Tool) string {
	xmlTools := convertToolsToXMLTools(tools)
	xmlToolsString, err := xmlTools.XMLString()
	if err != nil {
		panic("Could not serialize []model.Tool to XMLTools")
	}
	return TOOL_PREAMBLE + "\n" + xmlToolsString + "\n"
}
// xmlToolsTemplate is compiled once at package init rather than on
// every XMLString call; the template body is static, so a parse
// failure would be a programming error caught immediately at startup.
var xmlToolsTemplate = template.Must(template.New("tools").Parse(`<tools>
{{range .ToolDescriptions}}<tool_description>
<tool_name>{{.ToolName}}</tool_name>
<description>
{{.Description}}
</description>
<parameters>
{{range .Parameters}}<parameter>
<name>{{.Name}}</name>
<type>{{.Type}}</type>
<description>{{.Description}}</description>
</parameter>
{{end}}</parameters>
</tool_description>
{{end}}</tools>`))

// XMLString renders the tool list as the XML fragment embedded in the
// tools system prompt.
func (x XMLTools) XMLString() (string, error) {
	var buf bytes.Buffer
	if err := xmlToolsTemplate.Execute(&buf, x); err != nil {
		return "", err
	}
	return buf.String(), nil
}
// xmlFunctionResultsTemplate is compiled once at package init rather
// than on every XMLString call; the template body is static.
var xmlFunctionResultsTemplate = template.Must(template.New("function_results").Parse(`<function_results>
{{range .Result}}<result>
<tool_name>{{.ToolName}}</tool_name>
<stdout>{{.Stdout}}</stdout>
</result>
{{end}}</function_results>`))

// XMLString renders executed tool outputs as the <function_results>
// fragment returned to the model.
func (x XMLFunctionResults) XMLString() (string, error) {
	var buf bytes.Buffer
	if err := xmlFunctionResultsTemplate.Execute(&buf, x); err != nil {
		return "", err
	}
	return buf.String(), nil
}

View File

@@ -0,0 +1,270 @@
package openai
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
openai "github.com/sashabaranov/go-openai"
)
// OpenAIClient talks to the OpenAI chat completions API.
type OpenAIClient struct {
	APIKey string
}

// OpenAIToolParameters is the JSON-schema "parameters" object of a
// function tool definition.
type OpenAIToolParameters struct {
	Type       string                         `json:"type"`
	Properties map[string]OpenAIToolParameter `json:"properties,omitempty"`
	Required   []string                       `json:"required,omitempty"`
}

// OpenAIToolParameter is a single property within the schema.
type OpenAIToolParameter struct {
	Type        string   `json:"type"`
	Description string   `json:"description"`
	Enum        []string `json:"enum,omitempty"`
}
// convertTools translates generic tool definitions into OpenAI
// function-tool definitions carrying a JSON-schema parameter object.
func convertTools(tools []model.Tool) []openai.Tool {
	converted := make([]openai.Tool, len(tools))
	for i, tool := range tools {
		properties := make(map[string]OpenAIToolParameter, len(tool.Parameters))
		var required []string
		for _, param := range tool.Parameters {
			properties[param.Name] = OpenAIToolParameter{
				Type:        param.Type,
				Description: param.Description,
				Enum:        param.Enum,
			}
			if param.Required {
				required = append(required, param.Name)
			}
		}
		converted[i].Type = "function"
		converted[i].Function = openai.FunctionDefinition{
			Name:        tool.Name,
			Description: tool.Description,
			Parameters: OpenAIToolParameters{
				Type:       "object",
				Properties: properties,
				Required:   required,
			},
		}
	}
	return converted
}
// convertToolCallToOpenAI maps generic tool calls onto OpenAI's
// function-call wire format, JSON-encoding each call's parameters.
func convertToolCallToOpenAI(toolCalls []model.ToolCall) []openai.ToolCall {
	converted := make([]openai.ToolCall, len(toolCalls))
	for i, call := range toolCalls {
		converted[i].Type = "function"
		converted[i].ID = call.ID
		converted[i].Function.Name = call.Name
		// Parameters originate from a decoded JSON map, so marshaling
		// cannot fail in practice.
		args, _ := json.Marshal(call.Parameters)
		converted[i].Function.Arguments = string(args)
	}
	return converted
}
// convertToolCallToAPI maps OpenAI tool calls back onto the generic
// representation, decoding each call's JSON argument payload.
func convertToolCallToAPI(toolCalls []openai.ToolCall) []model.ToolCall {
	out := make([]model.ToolCall, len(toolCalls))
	for i, call := range toolCalls {
		tc := model.ToolCall{ID: call.ID, Name: call.Function.Name}
		// Best-effort decode: invalid JSON from the model leaves
		// Parameters nil rather than failing the whole conversion.
		json.Unmarshal([]byte(call.Function.Arguments), &tc.Parameters)
		out[i] = tc
	}
	return out
}
// createChatCompletionRequest converts generic request parameters and
// conversation messages into an OpenAI chat-completion request.
// NOTE(review): the client parameter c is currently unused — kept for
// signature stability; confirm before removing.
func createChatCompletionRequest(
	c *OpenAIClient,
	params model.RequestParameters,
	messages []model.Message,
) openai.ChatCompletionRequest {
	requestMessages := make([]openai.ChatCompletionMessage, 0, len(messages))
	for _, m := range messages {
		switch m.Role {
		case "tool_call":
			// Tool-call messages are replayed as assistant messages
			// carrying the structured tool_calls field.
			message := openai.ChatCompletionMessage{}
			message.Role = "assistant"
			message.Content = m.Content
			message.ToolCalls = convertToolCallToOpenAI(m.ToolCalls)
			requestMessages = append(requestMessages, message)
		case "tool_result":
			// expand tool_result messages' results into multiple openAI messages
			for _, result := range m.ToolResults {
				message := openai.ChatCompletionMessage{}
				message.Role = "tool"
				message.Content = result.Result
				message.ToolCallID = result.ToolCallID
				requestMessages = append(requestMessages, message)
			}
		default:
			message := openai.ChatCompletionMessage{}
			message.Role = string(m.Role)
			message.Content = m.Content
			requestMessages = append(requestMessages, message)
		}
	}
	request := openai.ChatCompletionRequest{
		Model:       params.Model,
		MaxTokens:   params.MaxTokens,
		Temperature: params.Temperature,
		Messages:    requestMessages,
		N:           1, // limit responses to 1 "choice". we use choices[0] to reference it
	}
	if len(params.ToolBag) > 0 {
		request.Tools = convertTools(params.ToolBag)
		request.ToolChoice = "auto"
	}
	return request
}
// handleToolCalls executes the tool calls requested by the model and
// packages both the request and the results as conversation messages
// (a tool_call message followed by a tool_result message).
func handleToolCalls(
	params model.RequestParameters,
	content string,
	toolCalls []openai.ToolCall,
) ([]model.Message, error) {
	callMessage := model.Message{
		Role:      model.MessageRoleToolCall,
		Content:   content,
		ToolCalls: convertToolCallToAPI(toolCalls),
	}
	results, err := tools.ExecuteToolCalls(callMessage.ToolCalls, params.ToolBag)
	if err != nil {
		return nil, err
	}
	resultMessage := model.Message{
		Role:        model.MessageRoleToolResult,
		ToolResults: results,
	}
	return []model.Message{callMessage, resultMessage}, nil
}
// CreateChatCompletion requests a (non-streamed) completion. If the
// model responds with tool calls, they are executed and the function
// recurses with the tool messages appended, so the returned string is
// always the final user-facing answer.
func (c *OpenAIClient) CreateChatCompletion(
	params model.RequestParameters,
	messages []model.Message,
	replies *[]model.Message,
) (string, error) {
	client := openai.NewClient(c.APIKey)
	req := createChatCompletionRequest(c, params, messages)
	resp, err := client.CreateChatCompletion(context.Background(), req)
	if err != nil {
		return "", err
	}
	choice := resp.Choices[0]
	toolCalls := choice.Message.ToolCalls
	if len(toolCalls) > 0 {
		results, err := handleToolCalls(params, choice.Message.Content, toolCalls)
		if err != nil {
			return "", err
		}
		// NOTE(review): *replies is appended to here without the nil
		// guard used below — a nil replies would panic; confirm.
		if results != nil {
			*replies = append(*replies, results...)
		}
		// Recurse into CreateChatCompletion with the tool call replies
		messages = append(messages, results...)
		return c.CreateChatCompletion(params, messages, replies)
	}
	if replies != nil {
		*replies = append(*replies, model.Message{
			Role:    model.MessageRoleAssistant,
			Content: choice.Message.Content,
		})
	}
	// Return the user-facing message.
	return choice.Message.Content, nil
}
// CreateChatCompletionStream requests a streamed completion, emitting
// content deltas on output as they arrive. Streamed tool-call fragments
// are reassembled by index; once the stream ends, any accumulated tool
// calls are executed and the function recurses with the tool messages
// appended, so the returned string is the final user-facing answer.
func (c *OpenAIClient) CreateChatCompletionStream(
	params model.RequestParameters,
	messages []model.Message,
	replies *[]model.Message,
	output chan<- string,
) (string, error) {
	client := openai.NewClient(c.APIKey)
	req := createChatCompletionRequest(c, params, messages)
	stream, err := client.CreateChatCompletionStream(context.Background(), req)
	if err != nil {
		return "", err
	}
	defer stream.Close()
	content := strings.Builder{}
	toolCalls := []openai.ToolCall{}
	// Iterate stream segments
	for {
		response, e := stream.Recv()
		if errors.Is(e, io.EOF) {
			break
		}
		if e != nil {
			err = e
			break
		}
		delta := response.Choices[0].Delta
		if len(delta.ToolCalls) > 0 {
			// Construct streamed tool_call arguments
			for _, tc := range delta.ToolCalls {
				if tc.Index == nil {
					return "", fmt.Errorf("Unexpected nil index for streamed tool call.")
				}
				// First fragment for an index starts a new call; later
				// fragments append to that call's argument string.
				if len(toolCalls) <= *tc.Index {
					toolCalls = append(toolCalls, tc)
				} else {
					toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
				}
			}
		} else {
			output <- delta.Content
			content.WriteString(delta.Content)
		}
	}
	if len(toolCalls) > 0 {
		results, err := handleToolCalls(params, content.String(), toolCalls)
		if err != nil {
			return content.String(), err
		}
		// NOTE(review): *replies is appended to here without the nil
		// guard used below — a nil replies would panic; confirm.
		if results != nil {
			*replies = append(*replies, results...)
		}
		// Recurse into CreateChatCompletionStream with the tool call replies
		messages = append(messages, results...)
		return c.CreateChatCompletionStream(params, messages, replies, output)
	}
	if replies != nil {
		*replies = append(*replies, model.Message{
			Role:    model.MessageRoleAssistant,
			Content: content.String(),
		})
	}
	return content.String(), err
}

View File

@@ -0,0 +1,23 @@
package provider
import "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
// ChatCompletionClient is implemented by each model provider backend
// (e.g. OpenAI, Anthropic).
type ChatCompletionClient interface {
	// CreateChatCompletion requests a response to the provided messages.
	// Replies are appended to the given replies struct, and the
	// complete user-facing response is returned as a string.
	CreateChatCompletion(
		params model.RequestParameters,
		messages []model.Message,
		replies *[]model.Message,
	) (string, error)
	// Like CreateChatCompletion, except the response is streamed via
	// the output channel as it's received.
	CreateChatCompletionStream(
		params model.RequestParameters,
		messages []model.Message,
		replies *[]model.Message,
		output chan<- string,
	) (string, error)
}

121
pkg/lmcli/store.go Normal file
View File

@@ -0,0 +1,121 @@
package lmcli
import (
"database/sql"
"errors"
"fmt"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
sqids "github.com/sqids/sqids-go"
"gorm.io/gorm"
)
// ConversationStore persists conversations and their messages.
type ConversationStore interface {
	Conversations() ([]model.Conversation, error)
	ConversationByShortName(shortName string) (*model.Conversation, error)
	ConversationShortNameCompletions(search string) []string
	SaveConversation(conversation *model.Conversation) error
	DeleteConversation(conversation *model.Conversation) error
	Messages(conversation *model.Conversation) ([]model.Message, error)
	LastMessage(conversation *model.Conversation) (*model.Message, error)
	SaveMessage(message *model.Message) error
	DeleteMessage(message *model.Message) error
	UpdateMessage(message *model.Message) error
}
// SQLStore is a ConversationStore backed by a gorm-managed SQL
// database. Conversation short names are generated with sqids from the
// conversation's numeric ID.
type SQLStore struct {
	db    *gorm.DB
	sqids *sqids.Sqids
}

// NewSQLStore runs schema migrations and returns a ready-to-use store.
func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
	models := []any{
		&model.Conversation{},
		&model.Message{},
	}
	for _, x := range models {
		err := db.AutoMigrate(x)
		if err != nil {
			return nil, fmt.Errorf("Could not perform database migrations: %v", err)
		}
	}
	// Previously the error from sqids.New was discarded, which could
	// leave the store with a nil encoder.
	_sqids, err := sqids.New(sqids.Options{MinLength: 4})
	if err != nil {
		return nil, fmt.Errorf("Could not initialize sqids: %v", err)
	}
	return &SQLStore{db, _sqids}, nil
}
// SaveConversation persists the conversation, assigning a sqids-based
// ShortName (derived from the row's ID) on first save.
func (s *SQLStore) SaveConversation(conversation *model.Conversation) error {
	if err := s.db.Save(&conversation).Error; err != nil {
		return err
	}
	if conversation.ShortName.Valid {
		return nil
	}
	// First save: the ID is now assigned, so derive the short name
	// from it and persist again.
	shortName, _ := s.sqids.Encode([]uint64{uint64(conversation.ID)})
	conversation.ShortName = sql.NullString{String: shortName, Valid: true}
	return s.db.Save(&conversation).Error
}
// DeleteConversation removes the conversation and all of its messages.
func (s *SQLStore) DeleteConversation(conversation *model.Conversation) error {
	// Delete child messages first; abort if that fails so we don't
	// delete the conversation and orphan its messages. The original
	// ignored this error.
	err := s.db.Where("conversation_id = ?", conversation.ID).Delete(&model.Message{}).Error
	if err != nil {
		return err
	}
	return s.db.Delete(&conversation).Error
}
// SaveMessage inserts a new message row.
func (s *SQLStore) SaveMessage(message *model.Message) error {
	return s.db.Create(message).Error
}

// DeleteMessage removes the message row.
func (s *SQLStore) DeleteMessage(message *model.Message) error {
	return s.db.Delete(&message).Error
}

// UpdateMessage writes the message's non-zero fields back to its row.
func (s *SQLStore) UpdateMessage(message *model.Message) error {
	return s.db.Updates(&message).Error
}
// Conversations returns all stored conversations.
func (s *SQLStore) Conversations() ([]model.Conversation, error) {
	var conversations []model.Conversation
	err := s.db.Find(&conversations).Error
	return conversations, err
}
// ConversationShortNameCompletions returns shell-completion candidates
// of the form "shortName<TAB>title" for conversations whose short name
// starts with the given prefix. An empty prefix matches everything.
func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
	conversations, _ := s.Conversations() // ignore error for completions
	var completions []string
	for _, c := range conversations {
		if shortName != "" && !strings.HasPrefix(c.ShortName.String, shortName) {
			continue
		}
		completions = append(completions, fmt.Sprintf("%s\t%s", c.ShortName.String, c.Title))
	}
	return completions
}
// ConversationByShortName looks up a conversation by its short name.
// NOTE(review): gorm's Find does not report "record not found" — when
// no row matches, this returns a zero-valued Conversation with a nil
// error; callers presumably check conversation.ID — confirm, or
// consider First instead.
func (s *SQLStore) ConversationByShortName(shortName string) (*model.Conversation, error) {
	if shortName == "" {
		return nil, errors.New("shortName is empty")
	}
	var conversation model.Conversation
	err := s.db.Where("short_name = ?", shortName).Find(&conversation).Error
	return &conversation, err
}
// Messages returns all messages belonging to the conversation.
func (s *SQLStore) Messages(conversation *model.Conversation) ([]model.Message, error) {
	var messages []model.Message
	err := s.db.Where("conversation_id = ?", conversation.ID).Find(&messages).Error
	return messages, err
}

// LastMessage returns the conversation's most recent message.
func (s *SQLStore) LastMessage(conversation *model.Conversation) (*model.Message, error) {
	var message model.Message
	err := s.db.Where("conversation_id = ?", conversation.ID).Last(&message).Error
	return &message, err
}

View File

@@ -0,0 +1,114 @@
package tools
import (
"fmt"
"os"
"strings"
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
)
// FILE_INSERT_LINES_DESCRIPTION documents the file_insert_lines tool
// for the model.
const FILE_INSERT_LINES_DESCRIPTION = `Insert lines into a file, must specify path.
Make sure your inserts match the flow and indentation of surrounding content.`

// FileInsertLinesTool inserts content before a given line of a file
// under the current working directory.
var FileInsertLinesTool = model.Tool{
	Name:        "file_insert_lines",
	Description: FILE_INSERT_LINES_DESCRIPTION,
	Parameters: []model.ToolParameter{
		{
			Name:        "path",
			Type:        "string",
			Description: "Path of the file to be modified, relative to the current working directory.",
			Required:    true,
		},
		{
			Name:        "position",
			Type:        "integer",
			Description: `Which line to insert content *before*.`,
			Required:    true,
		},
		{
			Name:        "content",
			Type:        "string",
			Description: `The content to insert.`,
			Required:    true,
		},
	},
	Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
		tmp, ok := args["path"]
		if !ok {
			// Fix: the original message referenced "write_file" instead
			// of this tool's actual name.
			return "", fmt.Errorf("path parameter to file_insert_lines was not included.")
		}
		path, ok := tmp.(string)
		if !ok {
			return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
		}
		var position int
		tmp, ok = args["position"]
		if ok {
			// JSON numbers decode as float64, hence the float assertion.
			tmp, ok := tmp.(float64)
			if !ok {
				return "", fmt.Errorf("Invalid position in function arguments: %v", tmp)
			}
			position = int(tmp)
		}
		var content string
		tmp, ok = args["content"]
		if ok {
			content, ok = tmp.(string)
			if !ok {
				return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
			}
		}
		result := fileInsertLines(path, position, content)
		ret, err := result.ToJson()
		if err != nil {
			return "", fmt.Errorf("Could not serialize result: %v", err)
		}
		return ret, nil
	},
}
// fileInsertLines inserts content (split into lines) into the file at path,
// immediately before the 1-based line number position. The file is created if
// it does not exist. Returns the file's new content on success, or a
// human-readable message on failure.
func fileInsertLines(path string, position int, content string) model.CallResult {
	ok, reason := toolutil.IsPathWithinCWD(path)
	if !ok {
		return model.CallResult{Message: reason}
	}
	// Read the existing file's content, creating the file if it's absent.
	data, err := os.ReadFile(path)
	if err != nil {
		if !os.IsNotExist(err) {
			return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
		}
		_, err = os.Create(path)
		if err != nil {
			return model.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}
		}
		data = []byte{}
	}
	if position < 1 {
		return model.CallResult{Message: "position cannot be less than 1"}
	}
	lines := strings.Split(string(data), "\n")
	// Guard against slicing past the end of the file, which would panic.
	// position == len(lines)+1 is allowed: it appends after the last line.
	if position > len(lines)+1 {
		return model.CallResult{Message: fmt.Sprintf("position (%d) is beyond the end of the file (%d lines)", position, len(lines))}
	}
	contentLines := strings.Split(strings.Trim(content, "\n"), "\n")
	// Assemble into a fresh slice to avoid any backing-array aliasing between
	// the before/after views of `lines`.
	merged := make([]string, 0, len(lines)+len(contentLines))
	merged = append(merged, lines[:position-1]...)
	merged = append(merged, contentLines...)
	merged = append(merged, lines[position-1:]...)
	newContent := strings.Join(merged, "\n")
	// Write the joined lines back to the file
	err = os.WriteFile(path, []byte(newContent), 0644)
	if err != nil {
		return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
	}
	return model.CallResult{Result: newContent}
}

View File

@@ -0,0 +1,133 @@
package tools
import (
"fmt"
"os"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
)
const FILE_REPLACE_LINES_DESCRIPTION = `Replace or remove a range of lines within a file, must specify path.

Useful for re-writing snippets/blocks of code or entire functions.

Plan your edits carefully and ensure any new content matches the flow and indentation of surrounding text.`

// FileReplaceLinesTool replaces (or, when content is omitted, removes) an
// inclusive 1-based line range in a file. path and start_line are required;
// end_line defaults to the end of the file and content defaults to removal.
var FileReplaceLinesTool = model.Tool{
	Name:        "file_replace_lines",
	Description: FILE_REPLACE_LINES_DESCRIPTION,
	Parameters: []model.ToolParameter{
		{
			Name:        "path",
			Type:        "string",
			Description: "Path of the file to be modified, relative to the current working directory.",
			Required:    true,
		},
		{
			Name:        "start_line",
			Type:        "integer",
			Description: `Line number which specifies the start of the replacement range (inclusive).`,
			Required:    true,
		},
		{
			Name:        "end_line",
			Type:        "integer",
			Description: `Line number which specifies the end of the replacement range (inclusive). If unset, range extends to end of file.`,
		},
		{
			Name:        "content",
			Type:        "string",
			Description: `Content to replace specified range. Omit to remove the specified range.`,
		},
	},
	Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
		// path is required.
		tmp, ok := args["path"]
		if !ok {
			return "", fmt.Errorf("path parameter to file_replace_lines was not included.")
		}
		path, ok := tmp.(string)
		if !ok {
			return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
		}
		// start_line is required; JSON numbers decode as float64.
		tmp, ok = args["start_line"]
		if !ok {
			return "", fmt.Errorf("start_line parameter to file_replace_lines was not included.")
		}
		start, ok := tmp.(float64)
		if !ok {
			return "", fmt.Errorf("Invalid start_line in function arguments: %v", tmp)
		}
		startLine := int(start)
		// end_line is optional; 0 means "to end of file".
		var endLine int
		tmp, ok = args["end_line"]
		if ok {
			end, ok := tmp.(float64)
			if !ok {
				return "", fmt.Errorf("Invalid end_line in function arguments: %v", tmp)
			}
			endLine = int(end)
		}
		// content is optional; empty means "remove the range".
		var content string
		tmp, ok = args["content"]
		if ok {
			content, ok = tmp.(string)
			if !ok {
				return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
			}
		}
		result := fileReplaceLines(path, startLine, endLine, content)
		ret, err := result.ToJson()
		if err != nil {
			return "", fmt.Errorf("Could not serialize result: %v", err)
		}
		return ret, nil
	},
}
// fileReplaceLines replaces the inclusive 1-based line range
// [startLine, endLine] of the file at path with content. endLine == 0 extends
// the range to the end of the file; empty content removes the range. Returns
// the file's new content on success, or a human-readable message on failure.
func fileReplaceLines(path string, startLine int, endLine int, content string) model.CallResult {
	ok, reason := toolutil.IsPathWithinCWD(path)
	if !ok {
		return model.CallResult{Message: reason}
	}
	// Read the existing file's content, creating the file if it's absent.
	data, err := os.ReadFile(path)
	if err != nil {
		if !os.IsNotExist(err) {
			return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
		}
		_, err = os.Create(path)
		if err != nil {
			return model.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}
		}
		data = []byte{}
	}
	if startLine < 1 {
		return model.CallResult{Message: "start_line cannot be less than 1"}
	}
	lines := strings.Split(string(data), "\n")
	// Guard against slicing past the end of the file, which would panic.
	if startLine > len(lines) {
		return model.CallResult{Message: fmt.Sprintf("start_line (%d) is beyond the end of the file (%d lines)", startLine, len(lines))}
	}
	if endLine == 0 || endLine > len(lines) {
		endLine = len(lines)
	}
	// An inverted range would silently duplicate lines; reject it instead.
	if endLine < startLine {
		return model.CallResult{Message: "end_line cannot be before start_line"}
	}
	// Empty content means "remove the range" (per the tool description); a
	// plain Split of "" would instead insert a single blank line.
	var contentLines []string
	if content != "" {
		contentLines = strings.Split(strings.Trim(content, "\n"), "\n")
	}
	before := lines[:startLine-1]
	after := lines[endLine:]
	lines = append(before, append(contentLines, after...)...)
	newContent := strings.Join(lines, "\n")
	// Write the joined lines back to the file
	err = os.WriteFile(path, []byte(newContent), 0644)
	if err != nil {
		return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
	}
	return model.CallResult{Result: newContent}
}

100
pkg/lmcli/tools/read_dir.go Normal file
View File

@@ -0,0 +1,100 @@
package tools
import (
"fmt"
"os"
"path/filepath"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
)
const READ_DIR_DESCRIPTION = `Return the contents of the CWD (current working directory).

Example result:
{
  "message": "success",
  "result": [
    {"name": "a_file.txt", "type": "file", "size": 123},
    {"name": "a_directory/", "type": "dir", "size": 11},
    ...
  ]
}

For files, size represents the size of the file, in bytes.
For directories, size represents the number of entries in that directory.`

// ReadDirTool lists the contents of the current working directory, or of a
// directory relative to it when relative_dir is provided.
var ReadDirTool = model.Tool{
	Name:        "read_dir",
	Description: READ_DIR_DESCRIPTION,
	Parameters: []model.ToolParameter{
		{
			Name:        "relative_dir",
			Type:        "string",
			Description: "If set, read the contents of a directory relative to the current one.",
		},
	},
	Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
		// relative_dir is optional; default is the CWD itself.
		relativeDir := ""
		if raw, present := args["relative_dir"]; present {
			s, isString := raw.(string)
			if !isString {
				return "", fmt.Errorf("Invalid relative_dir in function arguments: %v", raw)
			}
			relativeDir = s
		}
		res := readDir(relativeDir)
		serialized, err := res.ToJson()
		if err != nil {
			return "", fmt.Errorf("Could not serialize result: %v", err)
		}
		return serialized, nil
	},
}
// readDir lists the non-hidden entries of the directory at path (defaulting
// to "."), returning for each a name, a type ("file" or "dir"), and a size:
// bytes for files, entry count for directories.
func readDir(path string) model.CallResult {
	if path == "" {
		path = "."
	}
	ok, reason := toolutil.IsPathWithinCWD(path)
	if !ok {
		return model.CallResult{Message: reason}
	}
	files, err := os.ReadDir(path)
	if err != nil {
		return model.CallResult{
			Message: err.Error(),
		}
	}
	var dirContents []map[string]interface{}
	for _, f := range files {
		name := f.Name()
		if strings.HasPrefix(name, ".") {
			// skip hidden files
			continue
		}
		info, err := f.Info()
		if err != nil {
			// The entry may have disappeared between ReadDir and Info; skip
			// it rather than dereferencing a nil FileInfo below.
			continue
		}
		entryType := "file"
		size := info.Size()
		if info.IsDir() {
			name += "/"
			entryType = "dir"
			// Best-effort: an unreadable subdirectory just reports size 0.
			subdirfiles, _ := os.ReadDir(filepath.Join(".", path, info.Name()))
			size = int64(len(subdirfiles))
		}
		dirContents = append(dirContents, map[string]interface{}{
			"name": name,
			"type": entryType,
			"size": size,
		})
	}
	return model.CallResult{Result: dirContents}
}

View File

@@ -0,0 +1,71 @@
package tools
import (
"fmt"
"os"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
)
const READ_FILE_DESCRIPTION = `Read the contents of a text file relative to the current working directory.

Each line of the returned content is prefixed with its line number and a tab (\t).

Example result:
{
  "message": "success",
  "result": "1\tthe contents\n2\tof the file\n"
}`

// ReadFileTool returns a file's contents with each line prefixed by its
// 1-based line number and a tab.
var ReadFileTool = model.Tool{
	Name:        "read_file",
	Description: READ_FILE_DESCRIPTION,
	Parameters: []model.ToolParameter{
		{
			Name:        "path",
			Type:        "string",
			Description: "Path to a file within the current working directory to read.",
			Required:    true,
		},
	},
	Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
		rawPath, present := args["path"]
		if !present {
			return "", fmt.Errorf("Path parameter to read_file was not included.")
		}
		path, isString := rawPath.(string)
		if !isString {
			return "", fmt.Errorf("Invalid path in function arguments: %v", rawPath)
		}
		res := readFile(path)
		serialized, err := res.ToJson()
		if err != nil {
			return "", fmt.Errorf("Could not serialize result: %v", err)
		}
		return serialized, nil
	},
}
// readFile returns the contents of the file at path with each line prefixed
// by its 1-based line number and a tab, as documented in
// READ_FILE_DESCRIPTION.
func readFile(path string) model.CallResult {
	ok, reason := toolutil.IsPathWithinCWD(path)
	if !ok {
		return model.CallResult{Message: reason}
	}
	data, err := os.ReadFile(path)
	if err != nil {
		return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
	}
	lines := strings.Split(string(data), "\n")
	// A file ending in a newline yields a trailing empty element from Split;
	// drop it so we don't emit a phantom numbered line past the end of file.
	if len(lines) > 0 && lines[len(lines)-1] == "" {
		lines = lines[:len(lines)-1]
	}
	content := strings.Builder{}
	for i, line := range lines {
		content.WriteString(fmt.Sprintf("%d\t%s\n", i+1, line))
	}
	return model.CallResult{
		Result: content.String(),
	}
}

51
pkg/lmcli/tools/tools.go Normal file
View File

@@ -0,0 +1,51 @@
package tools
import (
"fmt"
"os"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
)
// AvailableTools maps each tool's public name to its implementation.
// Presumably tools are activated by name via the `enabledTools` config list —
// verify against the config consumer.
var AvailableTools map[string]model.Tool = map[string]model.Tool{
	"read_dir":           ReadDirTool,
	"read_file":          ReadFileTool,
	"write_file":         WriteFileTool,
	"file_insert_lines":  FileInsertLinesTool,
	"file_replace_lines": FileReplaceLinesTool,
}
// ExecuteToolCalls runs each requested tool call against the tools in toolBag
// and collects their results. It fails fast: the first unknown tool or tool
// execution error aborts the whole batch with a nil result slice.
func ExecuteToolCalls(toolCalls []model.ToolCall, toolBag []model.Tool) ([]model.ToolResult, error) {
	var toolResults []model.ToolResult
	for _, toolCall := range toolCalls {
		var tool *model.Tool
		for i := range toolBag {
			if toolBag[i].Name == toolCall.Name {
				// Take the address of the slice element, not of the range
				// loop variable (which is shared across iterations before
				// Go 1.22).
				tool = &toolBag[i]
				break
			}
		}
		if tool == nil {
			return nil, fmt.Errorf("Requested tool '%s' does not exist. Hallucination?", toolCall.Name)
		}
		// TODO: ability to silence this
		fmt.Fprintf(os.Stderr, "\nINFO: Executing tool '%s' with args %s\n", toolCall.Name, toolCall.Parameters)
		// Execute the tool
		result, err := tool.Impl(tool, toolCall.Parameters)
		if err != nil {
			// This can happen if the model missed or supplied invalid tool args
			return nil, fmt.Errorf("Tool '%s' error: %v", toolCall.Name, err)
		}
		toolResult := model.ToolResult{
			ToolCallID: toolCall.ID,
			ToolName:   toolCall.Name,
			Result:     result,
		}
		toolResults = append(toolResults, toolResult)
	}
	return toolResults, nil
}

View File

@@ -0,0 +1,67 @@
package util
import (
"fmt"
"os"
"path/filepath"
"strings"
)
// IsPathContained attempts to verify whether `path` is the same as or
// contained within `directory`. It is deliberately over-cautious: it may
// return false even when `path` IS contained within `directory` but the two
// use different casing on a case-insensitive filesystem.
// This is ultimately to attempt to stop an LLM from going outside of where I
// tell it to. Additional layers of security should be considered: run in a
// VM/container.
func IsPathContained(directory string, path string) (bool, error) {
	// Resolve `path` to an absolute path first.
	path, err := filepath.Abs(path)
	if err != nil {
		return false, err
	}
	// Only resolve symlinks when the path actually exists; a not-yet-created
	// path is checked against its unresolved absolute form instead.
	_, err = os.Stat(path)
	if err != nil {
		if !os.IsNotExist(err) {
			return false, fmt.Errorf("Could not stat path: %v", err)
		}
	} else {
		path, err = filepath.EvalSymlinks(path)
		if err != nil {
			return false, err
		}
	}
	// The directory is assumed to exist: resolve it fully.
	directory, err = filepath.Abs(directory)
	if err != nil {
		return false, err
	}
	directory, err = filepath.EvalSymlinks(directory)
	if err != nil {
		return false, err
	}
	// Case-insensitive checks: `path` must equal `directory` or start with
	// `directory` followed by a path separator (prevents /foo matching /foobar).
	if !strings.EqualFold(path, directory) &&
		!strings.HasPrefix(strings.ToLower(path), strings.ToLower(directory)+string(os.PathSeparator)) {
		return false, nil
	}
	return true, nil
}
// IsPathWithinCWD reports whether path lies within the current working
// directory. On a negative or failed check it returns false together with a
// human-readable reason.
func IsPathWithinCWD(path string) (bool, string) {
	cwd, err := os.Getwd()
	if err != nil {
		return false, "Failed to determine current working directory"
	}
	contained, err := IsPathContained(cwd, path)
	if err != nil {
		return false, fmt.Sprintf("Could not determine whether path '%s' is within the current working directory: %s", path, err.Error())
	}
	if !contained {
		return false, fmt.Sprintf("Path '%s' is not within the current working directory", path)
	}
	return true, ""
}

View File

@@ -0,0 +1,71 @@
package tools
import (
"fmt"
"os"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
)
const WRITE_FILE_DESCRIPTION = `Write the provided contents to a file relative to the current working directory.

Example result:
{
  "message": "success"
}`

// WriteFileTool writes the given content to a file inside the current working
// directory, replacing any existing content.
var WriteFileTool = model.Tool{
	Name:        "write_file",
	Description: WRITE_FILE_DESCRIPTION,
	Parameters: []model.ToolParameter{
		{
			Name:        "path",
			Type:        "string",
			Description: "Path to a file within the current working directory to write to.",
			Required:    true,
		},
		{
			Name:        "content",
			Type:        "string",
			Description: "The content to write to the file. Overwrites any existing content!",
			Required:    true,
		},
	},
	Impl: func(t *model.Tool, args map[string]interface{}) (string, error) {
		rawPath, present := args["path"]
		if !present {
			return "", fmt.Errorf("Path parameter to write_file was not included.")
		}
		path, isString := rawPath.(string)
		if !isString {
			return "", fmt.Errorf("Invalid path in function arguments: %v", rawPath)
		}
		rawContent, present := args["content"]
		if !present {
			return "", fmt.Errorf("Content parameter to write_file was not included.")
		}
		content, isString := rawContent.(string)
		if !isString {
			return "", fmt.Errorf("Invalid content in function arguments: %v", rawContent)
		}
		res := writeFile(path, content)
		serialized, err := res.ToJson()
		if err != nil {
			return "", fmt.Errorf("Could not serialize result: %v", err)
		}
		return serialized, nil
	},
}
// writeFile overwrites the file at path with content, after confirming the
// path stays within the current working directory.
func writeFile(path string, content string) model.CallResult {
	if ok, reason := toolutil.IsPathWithinCWD(path); !ok {
		return model.CallResult{Message: reason}
	}
	if err := os.WriteFile(path, []byte(content), 0644); err != nil {
		return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
	}
	return model.CallResult{}
}