Compare commits


6 Commits

SHA1        Message                                   Date

aeeb7bb7f7  tui: Add --system-prompt handling         2024-05-07 08:19:45 +00:00
            And some state handling changes
2b38db7db7  Update command flag handling              2024-05-07 08:18:48 +00:00
            `lmcli chat` now supports common prompt flags (model, length, system prompt, etc)
8e4ff90ab4  Multiple provider configuration           2024-05-05 08:15:17 +00:00
            Add support for having multiple openai or anthropic compatible providers accessible via different baseUrls
bdaf6204f6  Add openai response error handling        2024-05-05 07:32:35 +00:00
1b9a8f319c  Split anthropic types out to types.go     2024-04-29 06:16:41 +00:00
ffe9d299ef  Remove go-openai                          2024-04-29 06:14:36 +00:00
21 changed files with 353 additions and 217 deletions

go.mod
View File

@@ -9,7 +9,6 @@ require (
 	github.com/charmbracelet/lipgloss v0.10.0
 	github.com/go-yaml/yaml v2.1.0+incompatible
 	github.com/muesli/reflow v0.3.0
-	github.com/sashabaranov/go-openai v1.17.7
 	github.com/spf13/cobra v1.8.0
 	github.com/sqids/sqids-go v0.4.1
 	gopkg.in/yaml.v2 v2.2.2

go.sum
View File

@@ -61,8 +61,6 @@ github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUc
 github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
 github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
 github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
-github.com/sashabaranov/go-openai v1.17.7 h1:MPcAwlwbeo7ZmhQczoOgZBHtIBY1TfZqsdx6+/ndloM=
-github.com/sashabaranov/go-openai v1.17.7/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
 github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0=
 github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho=
 github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=

View File

@@ -33,5 +33,6 @@ func ChatCmd(ctx *lmcli.Context) *cobra.Command {
 			return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
 		},
 	}
+	applyPromptFlags(ctx, cmd)
 	return cmd
 }

View File

@@ -8,10 +8,6 @@ import (
 	"github.com/spf13/cobra"
 )

-var (
-	systemPromptFile string
-)
-
 func RootCmd(ctx *lmcli.Context) *cobra.Command {
 	var root = &cobra.Command{
 		Use: "lmcli <command> [flags]",
@@ -23,58 +19,43 @@ func RootCmd(ctx *lmcli.Context) *cobra.Command {
 		},
 	}

-	chatCmd := ChatCmd(ctx)
-	continueCmd := ContinueCmd(ctx)
-	cloneCmd := CloneCmd(ctx)
-	editCmd := EditCmd(ctx)
-	listCmd := ListCmd(ctx)
-	newCmd := NewCmd(ctx)
-	promptCmd := PromptCmd(ctx)
-	renameCmd := RenameCmd(ctx)
-	replyCmd := ReplyCmd(ctx)
-	retryCmd := RetryCmd(ctx)
-	rmCmd := RemoveCmd(ctx)
-	viewCmd := ViewCmd(ctx)
-
-	inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd, retryCmd, continueCmd, editCmd}
-	for _, cmd := range inputCmds {
-		cmd.Flags().StringVar(ctx.Config.Defaults.Model, "model", *ctx.Config.Defaults.Model, "Which model to use")
-		cmd.RegisterFlagCompletionFunc("model", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) {
-			return ctx.GetModels(), cobra.ShellCompDirectiveDefault
-		})
-		cmd.Flags().IntVar(ctx.Config.Defaults.MaxTokens, "length", *ctx.Config.Defaults.MaxTokens, "Maximum response tokens")
-		cmd.Flags().StringVar(ctx.Config.Defaults.SystemPrompt, "system-prompt", *ctx.Config.Defaults.SystemPrompt, "System prompt")
-		cmd.Flags().StringVar(&systemPromptFile, "system-prompt-file", "", "A path to a file containing the system prompt")
-		cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file")
-	}
-
 	root.AddCommand(
-		chatCmd,
-		cloneCmd,
-		continueCmd,
-		editCmd,
-		listCmd,
-		newCmd,
-		promptCmd,
-		renameCmd,
-		replyCmd,
-		retryCmd,
-		rmCmd,
-		viewCmd,
+		ChatCmd(ctx),
+		ContinueCmd(ctx),
+		CloneCmd(ctx),
+		EditCmd(ctx),
+		ListCmd(ctx),
+		NewCmd(ctx),
+		PromptCmd(ctx),
+		RenameCmd(ctx),
+		ReplyCmd(ctx),
+		RetryCmd(ctx),
+		RemoveCmd(ctx),
+		ViewCmd(ctx),
 	)

 	return root
 }

-func getSystemPrompt(ctx *lmcli.Context) string {
-	if systemPromptFile != "" {
-		content, err := util.ReadFileContents(systemPromptFile)
-		if err != nil {
-			lmcli.Fatal("Could not read file contents at %s: %v\n", systemPromptFile, err)
-		}
-		return content
-	}
-	return *ctx.Config.Defaults.SystemPrompt
+func applyPromptFlags(ctx *lmcli.Context, cmd *cobra.Command) {
+	f := cmd.Flags()
+	f.StringVarP(
+		ctx.Config.Defaults.Model,
+		"model", "m",
+		*ctx.Config.Defaults.Model,
+		"The model to generate a response with",
+	)
+	cmd.RegisterFlagCompletionFunc("model", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) {
+		return ctx.GetModels(), cobra.ShellCompDirectiveDefault
+	})
+	f.IntVar(ctx.Config.Defaults.MaxTokens, "max-length", *ctx.Config.Defaults.MaxTokens, "Maximum response tokens")
+	f.Float32VarP(ctx.Config.Defaults.Temperature, "temperature", "t", *ctx.Config.Defaults.Temperature, "Sampling temperature")
+	f.StringVar(ctx.Config.Defaults.SystemPrompt, "system-prompt", *ctx.Config.Defaults.SystemPrompt, "System prompt")
+	f.StringVar(&ctx.SystemPromptFile, "system-prompt-file", "", "A path to a file containing the system prompt")
+	cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file")
 }

 // inputFromArgsOrEditor returns either the provided input from the args slice
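
To illustrate the refactor (this sketch is not part of the diff): any new subcommand can now opt into the shared prompt flags with a single call. The `SummarizeCmd` below is hypothetical:

    func SummarizeCmd(ctx *lmcli.Context) *cobra.Command {
        cmd := &cobra.Command{
            Use: "summarize <conversation>",
            RunE: func(cmd *cobra.Command, args []string) error {
                // By the time RunE runs, any --model/--max-length/--temperature/
                // --system-prompt overrides have been written into ctx.Config.Defaults.
                return nil
            },
        }
        // binds -m/--model, --max-length, -t/--temperature,
        // --system-prompt and --system-prompt-file
        applyPromptFlags(ctx, cmd)
        return cmd
    }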

View File

@@ -44,7 +44,7 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
 			fmt.Print(lastMessage.Content)

 			// Submit the LLM request, allowing it to continue the last message
-			continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
+			continuedOutput, err := cmdutil.Prompt(ctx, messages, nil)
 			if err != nil {
 				return fmt.Errorf("error fetching LLM response: %v", err)
 			}
@@ -68,5 +68,6 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
 			return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
 		},
 	}
+	applyPromptFlags(ctx, cmd)
 	return cmd
 }

View File

@@ -28,14 +28,14 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
 			messages := []model.Message{
 				{
 					ConversationID: conversation.ID,
 					Role:           model.MessageRoleSystem,
-					Content:        getSystemPrompt(ctx),
+					Content:        ctx.GetSystemPrompt(),
 				},
 				{
 					ConversationID: conversation.ID,
 					Role:           model.MessageRoleUser,
 					Content:        messageContents,
 				},
 			}
@@ -56,5 +56,6 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
 			},
 		}
+	applyPromptFlags(ctx, cmd)
 	return cmd
 }

View File

@@ -22,21 +22,23 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
 			messages := []model.Message{
 				{
 					Role:    model.MessageRoleSystem,
-					Content: getSystemPrompt(ctx),
+					Content: ctx.GetSystemPrompt(),
 				},
 				{
 					Role:    model.MessageRoleUser,
 					Content: message,
 				},
 			}

-			_, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
+			_, err := cmdutil.Prompt(ctx, messages, nil)
 			if err != nil {
 				return fmt.Errorf("Error fetching LLM response: %v", err)
 			}

 			return nil
 		},
 	}

+	applyPromptFlags(ctx, cmd)
 	return cmd
 }

View File

@@ -45,5 +45,7 @@ func ReplyCmd(ctx *lmcli.Context) *cobra.Command {
 			return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
 		},
 	}

+	applyPromptFlags(ctx, cmd)
 	return cmd
 }

View File

@@ -54,5 +54,7 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
 			return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
 		},
 	}

+	applyPromptFlags(ctx, cmd)
 	return cmd
 }

View File

@@ -13,9 +13,9 @@ import (
 	"github.com/charmbracelet/lipgloss"
 )

-// fetchAndShowCompletion prompts the LLM with the given messages and streams
-// the response to stdout. Returns all model reply messages.
-func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
+// Prompt prompts the configured model and streams the response to stdout.
+// Returns all model reply messages.
+func Prompt(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
 	content := make(chan string) // receives the response from the LLM
 	defer close(content)
@@ -46,7 +46,7 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
 			err = nil
 		}
 	}

-	return response, nil
+	return response, err
 }

 // lookupConversation either returns the conversation found by the
@@ -109,7 +109,7 @@ func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist
 		}
 	}

-	_, err = FetchAndShowCompletion(ctx, allMessages, replyCallback)
+	_, err = Prompt(ctx, allMessages, replyCallback)
 	if err != nil {
 		lmcli.Fatal("Error fetching LLM response: %v\n", err)
 	}

View File

@@ -19,16 +19,14 @@ type Config struct {
 		TitleGenerationModel *string `yaml:"titleGenerationModel" default:"gpt-3.5-turbo"`
 	} `yaml:"conversations"`
 	Tools *struct {
-		EnabledTools *[]string `yaml:"enabledTools"`
+		EnabledTools []string `yaml:"enabledTools"`
 	} `yaml:"tools"`
-	OpenAI *struct {
-		APIKey *string   `yaml:"apiKey" default:"your_key_here"`
-		Models *[]string `yaml:"models"`
-	} `yaml:"openai"`
-	Anthropic *struct {
-		APIKey *string   `yaml:"apiKey" default:"your_key_here"`
-		Models *[]string `yaml:"models"`
-	} `yaml:"anthropic"`
+	Providers []*struct {
+		Kind    *string   `yaml:"kind"`
+		BaseURL *string   `yaml:"baseUrl"`
+		APIKey  *string   `yaml:"apiKey"`
+		Models  *[]string `yaml:"models"`
+	} `yaml:"providers"`
 	Chroma *struct {
 		Style     *string `yaml:"style" default:"onedark"`
 		Formatter *string `yaml:"formatter" default:"terminal16m"`
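
To illustrate the new schema (example only, not part of the diff; keys come from the struct tags above, values are made up), a config can now declare any number of providers:

    providers:
      - kind: openai
        apiKey: sk-...
        models:
          - gpt-4
          - gpt-3.5-turbo
      - kind: openai
        # baseUrl lets a second entry point at any OpenAI-compatible server
        baseUrl: http://localhost:8080/v1
        apiKey: unused
        models:
          - llama3
      - kind: anthropic
        apiKey: sk-ant-...
        models:
          - claude-3-opus-20240229

When baseUrl is omitted, GetCompletionProvider (below) falls back to the official endpoint for the provider kind.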

View File

@ -10,17 +10,20 @@ import (
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic" "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai" "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools" "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
"git.mlow.ca/mlow/lmcli/pkg/util"
"git.mlow.ca/mlow/lmcli/pkg/util/tty" "git.mlow.ca/mlow/lmcli/pkg/util/tty"
"gorm.io/driver/sqlite" "gorm.io/driver/sqlite"
"gorm.io/gorm" "gorm.io/gorm"
) )
type Context struct { type Context struct {
Config *Config Config *Config // may be updated at runtime
Store ConversationStore Store ConversationStore
Chroma *tty.ChromaHighlighter Chroma *tty.ChromaHighlighter
EnabledTools []model.Tool EnabledTools []model.Tool
SystemPromptFile string
} }
func NewContext() (*Context, error) { func NewContext() (*Context, error) {
@ -43,46 +46,70 @@ func NewContext() (*Context, error) {
chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style) chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style)
var enabledTools []model.Tool var enabledTools []model.Tool
for _, toolName := range *config.Tools.EnabledTools { for _, toolName := range config.Tools.EnabledTools {
tool, ok := tools.AvailableTools[toolName] tool, ok := tools.AvailableTools[toolName]
if ok { if ok {
enabledTools = append(enabledTools, tool) enabledTools = append(enabledTools, tool)
} }
} }
return &Context{config, store, chroma, enabledTools}, nil return &Context{config, store, chroma, enabledTools, ""}, nil
} }
func (c *Context) GetModels() (models []string) { func (c *Context) GetModels() (models []string) {
for _, m := range *c.Config.Anthropic.Models { for _, p := range c.Config.Providers {
models = append(models, m) for _, m := range *p.Models {
} models = append(models, m)
for _, m := range *c.Config.OpenAI.Models { }
models = append(models, m)
} }
return return
} }
func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionClient, error) { func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionClient, error) {
for _, m := range *c.Config.Anthropic.Models { for _, p := range c.Config.Providers {
if m == model { for _, m := range *p.Models {
anthropic := &anthropic.AnthropicClient{ if m == model {
APIKey: *c.Config.Anthropic.APIKey, switch *p.Kind {
case "anthropic":
url := "https://api.anthropic.com/v1"
if p.BaseURL != nil {
url = *p.BaseURL
}
anthropic := &anthropic.AnthropicClient{
BaseURL: url,
APIKey: *p.APIKey,
}
return anthropic, nil
case "openai":
url := "https://api.openai.com/v1"
if p.BaseURL != nil {
url = *p.BaseURL
}
openai := &openai.OpenAIClient{
BaseURL: url,
APIKey: *p.APIKey,
}
return openai, nil
default:
return nil, fmt.Errorf("unknown provider kind: %s", *p.Kind)
}
} }
return anthropic, nil
}
}
for _, m := range *c.Config.OpenAI.Models {
if m == model {
openai := &openai.OpenAIClient{
APIKey: *c.Config.OpenAI.APIKey,
}
return openai, nil
} }
} }
return nil, fmt.Errorf("unknown model: %s", model) return nil, fmt.Errorf("unknown model: %s", model)
} }
func (c *Context) GetSystemPrompt() string {
if c.SystemPromptFile != "" {
content, err := util.ReadFileContents(c.SystemPromptFile)
if err != nil {
Fatal("Could not read file contents at %s: %v\n", c.SystemPromptFile, err)
}
return content
}
return *c.Config.Defaults.SystemPrompt
}
func configDir() string { func configDir() string {
var configDir string var configDir string
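
A rough usage sketch (mine, not from the diff) of how callers can now resolve a provider purely by model name:

    client, err := ctx.GetCompletionProvider(*ctx.Config.Defaults.Model)
    if err != nil {
        // "unknown model: ..." or "unknown provider kind: ..."
        lmcli.Fatal("%v\n", err)
    }
    // client is a provider.ChatCompletionClient backed by the matching
    // openai or anthropic implementation, pointed at that provider's BaseURL.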

View File

@@ -32,13 +32,13 @@ type Conversation struct {
 }

 type RequestParameters struct {
 	Model string

 	MaxTokens   int
 	Temperature float32
 	TopP        float32

-	SystemPrompt string
 	ToolBag      []Tool
 }

 func (m *MessageRole) IsAssistant() bool {

View File

@@ -15,48 +15,10 @@ import (
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
 )

-type AnthropicClient struct {
-	APIKey string
-}
-
-type Message struct {
-	Role    string `json:"role"`
-	Content string `json:"content"`
-}
-
-type Request struct {
-	Model         string    `json:"model"`
-	Messages      []Message `json:"messages"`
-	System        string    `json:"system,omitempty"`
-	MaxTokens     int       `json:"max_tokens,omitempty"`
-	StopSequences []string  `json:"stop_sequences,omitempty"`
-	Stream        bool      `json:"stream,omitempty"`
-	Temperature   float32   `json:"temperature,omitempty"`
-	//TopP        float32   `json:"top_p,omitempty"`
-	//TopK        float32   `json:"top_k,omitempty"`
-}
-
-type OriginalContent struct {
-	Type string `json:"type"`
-	Text string `json:"text"`
-}
-
-type Response struct {
-	Id           string            `json:"id"`
-	Type         string            `json:"type"`
-	Role         string            `json:"role"`
-	Content      []OriginalContent `json:"content"`
-	StopReason   string            `json:"stop_reason"`
-	StopSequence string            `json:"stop_sequence"`
-}
-
-const FUNCTION_STOP_SEQUENCE = "</function_calls>"
-
 func buildRequest(params model.RequestParameters, messages []model.Message) Request {
 	requestBody := Request{
 		Model:       params.Model,
 		Messages:    make([]Message, len(messages)),
-		System:      params.SystemPrompt,
 		MaxTokens:   params.MaxTokens,
 		Temperature: params.Temperature,
 		Stream:      false,
@@ -118,14 +80,12 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
 }

 func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Response, error) {
-	url := "https://api.anthropic.com/v1/messages"
-
 	jsonBody, err := json.Marshal(r)
 	if err != nil {
 		return nil, fmt.Errorf("failed to marshal request body: %v", err)
 	}

-	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonBody))
+	req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/messages", bytes.NewBuffer(jsonBody))
 	if err != nil {
 		return nil, fmt.Errorf("failed to create HTTP request: %v", err)
 	}

View File

@@ -9,6 +9,8 @@ import (
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
 )

+const FUNCTION_STOP_SEQUENCE = "</function_calls>"
+
 const TOOL_PREAMBLE = `You have access to the following tools when replying.

 You may call them like this:

View File

@@ -0,0 +1,38 @@
+package anthropic
+
+type AnthropicClient struct {
+	BaseURL string
+	APIKey  string
+}
+
+type Message struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+type Request struct {
+	Model         string    `json:"model"`
+	Messages      []Message `json:"messages"`
+	System        string    `json:"system,omitempty"`
+	MaxTokens     int       `json:"max_tokens,omitempty"`
+	StopSequences []string  `json:"stop_sequences,omitempty"`
+	Stream        bool      `json:"stream,omitempty"`
+	Temperature   float32   `json:"temperature,omitempty"`
+	//TopP        float32   `json:"top_p,omitempty"`
+	//TopK        float32   `json:"top_k,omitempty"`
+}
+
+type OriginalContent struct {
+	Type string `json:"type"`
+	Text string `json:"text"`
+}
+
+type Response struct {
+	Id           string            `json:"id"`
+	Type         string            `json:"type"`
+	Role         string            `json:"role"`
+	Content      []OriginalContent `json:"content"`
+	StopReason   string            `json:"stop_reason"`
+	StopSequence string            `json:"stop_sequence"`
+}
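
For reference (not part of the diff), a Request built from these types marshals to the Anthropic Messages API shape, roughly:

    {
      "model": "claude-3-opus-20240229",
      "messages": [{"role": "user", "content": "Hello"}],
      "system": "You are a helpful assistant.",
      "max_tokens": 1024
    }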

View File

@@ -1,45 +1,30 @@
 package openai

 import (
+	"bufio"
+	"bytes"
 	"context"
 	"encoding/json"
-	"errors"
 	"fmt"
 	"io"
+	"net/http"
 	"strings"

 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
-
-	openai "github.com/sashabaranov/go-openai"
 )

-type OpenAIClient struct {
-	APIKey string
-}
-
-type OpenAIToolParameters struct {
-	Type       string                         `json:"type"`
-	Properties map[string]OpenAIToolParameter `json:"properties,omitempty"`
-	Required   []string                       `json:"required,omitempty"`
-}
-
-type OpenAIToolParameter struct {
-	Type        string   `json:"type"`
-	Description string   `json:"description"`
-	Enum        []string `json:"enum,omitempty"`
-}
-
-func convertTools(tools []model.Tool) []openai.Tool {
-	openaiTools := make([]openai.Tool, len(tools))
+func convertTools(tools []model.Tool) []Tool {
+	openaiTools := make([]Tool, len(tools))
 	for i, tool := range tools {
 		openaiTools[i].Type = "function"

-		params := make(map[string]OpenAIToolParameter)
+		params := make(map[string]ToolParameter)
 		var required []string

 		for _, param := range tool.Parameters {
-			params[param.Name] = OpenAIToolParameter{
+			params[param.Name] = ToolParameter{
 				Type:        param.Type,
 				Description: param.Description,
 				Enum:        param.Enum,
@@ -49,10 +34,10 @@ func convertTools(tools []model.Tool) []openai.Tool {
 			}
 		}

-		openaiTools[i].Function = openai.FunctionDefinition{
+		openaiTools[i].Function = FunctionDefinition{
 			Name:        tool.Name,
 			Description: tool.Description,
-			Parameters: OpenAIToolParameters{
+			Parameters: ToolParameters{
 				Type:       "object",
 				Properties: params,
 				Required:   required,
@@ -62,8 +47,8 @@ func convertTools(tools []model.Tool) []openai.Tool {
 	return openaiTools
 }

-func convertToolCallToOpenAI(toolCalls []model.ToolCall) []openai.ToolCall {
-	converted := make([]openai.ToolCall, len(toolCalls))
+func convertToolCallToOpenAI(toolCalls []model.ToolCall) []ToolCall {
+	converted := make([]ToolCall, len(toolCalls))
 	for i, call := range toolCalls {
 		converted[i].Type = "function"
 		converted[i].ID = call.ID
@@ -75,7 +60,7 @@ func convertToolCallToOpenAI(toolCalls []model.ToolCall) []openai.ToolCall {
 	return converted
 }

-func convertToolCallToAPI(toolCalls []openai.ToolCall) []model.ToolCall {
+func convertToolCallToAPI(toolCalls []ToolCall) []model.ToolCall {
 	converted := make([]model.ToolCall, len(toolCalls))
 	for i, call := range toolCalls {
 		converted[i].ID = call.ID
@@ -86,16 +71,15 @@ func convertToolCallToAPI(toolCalls []openai.ToolCall) []model.ToolCall {
 }

 func createChatCompletionRequest(
-	c *OpenAIClient,
 	params model.RequestParameters,
 	messages []model.Message,
-) openai.ChatCompletionRequest {
-	requestMessages := make([]openai.ChatCompletionMessage, 0, len(messages))
+) ChatCompletionRequest {
+	requestMessages := make([]ChatCompletionMessage, 0, len(messages))

 	for _, m := range messages {
 		switch m.Role {
 		case "tool_call":
-			message := openai.ChatCompletionMessage{}
+			message := ChatCompletionMessage{}
 			message.Role = "assistant"
 			message.Content = m.Content
 			message.ToolCalls = convertToolCallToOpenAI(m.ToolCalls)
@@ -103,21 +87,21 @@ func createChatCompletionRequest(
 		case "tool_result":
 			// expand tool_result messages' results into multiple openAI messages
 			for _, result := range m.ToolResults {
-				message := openai.ChatCompletionMessage{}
+				message := ChatCompletionMessage{}
 				message.Role = "tool"
 				message.Content = result.Result
 				message.ToolCallID = result.ToolCallID
 				requestMessages = append(requestMessages, message)
 			}
 		default:
-			message := openai.ChatCompletionMessage{}
+			message := ChatCompletionMessage{}
 			message.Role = string(m.Role)
 			message.Content = m.Content
 			requestMessages = append(requestMessages, message)
 		}
 	}

-	request := openai.ChatCompletionRequest{
+	request := ChatCompletionRequest{
 		Model:       params.Model,
 		MaxTokens:   params.MaxTokens,
 		Temperature: params.Temperature,
@@ -136,7 +120,7 @@ func createChatCompletionRequest(
 func handleToolCalls(
 	params model.RequestParameters,
 	content string,
-	toolCalls []openai.ToolCall,
+	toolCalls []ToolCall,
 	callback provider.ReplyCallback,
 	messages []model.Message,
 ) ([]model.Message, error) {
@@ -177,6 +161,21 @@ func handleToolCalls(
 	return messages, nil
 }

+func (c *OpenAIClient) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Authorization", "Bearer "+c.APIKey)
+
+	client := &http.Client{}
+	resp, err := client.Do(req.WithContext(ctx))
+	if err != nil {
+		return nil, err
+	}
+	if resp.StatusCode != 200 {
+		bytes, _ := io.ReadAll(resp.Body)
+		return resp, fmt.Errorf("%v", string(bytes))
+	}
+	return resp, err
+}
+
 func (c *OpenAIClient) CreateChatCompletion(
 	ctx context.Context,
 	params model.RequestParameters,
@@ -187,14 +186,30 @@ func (c *OpenAIClient) CreateChatCompletion(
 		return "", fmt.Errorf("Can't create completion from no messages")
 	}

-	client := openai.NewClient(c.APIKey)
-	req := createChatCompletionRequest(c, params, messages)
-	resp, err := client.CreateChatCompletion(ctx, req)
+	req := createChatCompletionRequest(params, messages)
+	jsonData, err := json.Marshal(req)
 	if err != nil {
 		return "", err
 	}

-	choice := resp.Choices[0]
+	httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
+	if err != nil {
+		return "", err
+	}
+
+	resp, err := c.sendRequest(ctx, httpReq)
+	if err != nil {
+		return "", err
+	}
+	defer resp.Body.Close()
+
+	var completionResp ChatCompletionResponse
+	err = json.NewDecoder(resp.Body).Decode(&completionResp)
+	if err != nil {
+		return "", err
+	}
+
+	choice := completionResp.Choices[0]

 	var content string
 	lastMessage := messages[len(messages)-1]
@@ -236,36 +251,60 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 		return "", fmt.Errorf("Can't create completion from no messages")
 	}

-	client := openai.NewClient(c.APIKey)
-	req := createChatCompletionRequest(c, params, messages)
-	stream, err := client.CreateChatCompletionStream(ctx, req)
+	req := createChatCompletionRequest(params, messages)
+	req.Stream = true
+
+	jsonData, err := json.Marshal(req)
 	if err != nil {
 		return "", err
 	}
-	defer stream.Close()
+
+	httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
+	if err != nil {
+		return "", err
+	}
+
+	resp, err := c.sendRequest(ctx, httpReq)
+	if err != nil {
+		return "", err
+	}
+	defer resp.Body.Close()

 	content := strings.Builder{}
-	toolCalls := []openai.ToolCall{}
+	toolCalls := []ToolCall{}

 	lastMessage := messages[len(messages)-1]
 	if lastMessage.Role.IsAssistant() {
 		content.WriteString(lastMessage.Content)
 	}

-	// Iterate stream segments
+	reader := bufio.NewReader(resp.Body)
 	for {
-		response, e := stream.Recv()
-		if errors.Is(e, io.EOF) {
+		line, err := reader.ReadBytes('\n')
+		if err != nil {
+			if err == io.EOF {
+				break
+			}
+			return "", err
+		}
+
+		line = bytes.TrimSpace(line)
+		if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
+			continue
+		}
+		line = bytes.TrimPrefix(line, []byte("data: "))
+		if bytes.Equal(line, []byte("[DONE]")) {
 			break
 		}
-		if e != nil {
-			err = e
-			break
+
+		var streamResp ChatCompletionStreamResponse
+		err = json.Unmarshal(line, &streamResp)
+		if err != nil {
+			return "", err
 		}

-		delta := response.Choices[0].Delta
+		delta := streamResp.Choices[0].Delta
 		if len(delta.ToolCalls) > 0 {
 			// Construct streamed tool_call arguments
 			for _, tc := range delta.ToolCalls {
@@ -278,7 +317,8 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 				toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
 			}
 		}
-	} else {
+
+		if len(delta.Content) > 0 {
 			output <- delta.Content
 			content.WriteString(delta.Content)
 		}
@@ -301,5 +341,5 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 		}
 	}

-	return content.String(), err
+	return content.String(), nil
 }
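
The new streaming loop above parses OpenAI-style server-sent events by hand: each chunk arrives on a `data: ` line and the stream ends with `data: [DONE]`. Illustrative wire format (values made up):

    data: {"choices":[{"delta":{"content":"Hel"}}]}

    data: {"choices":[{"delta":{"content":"lo"}}]}

    data: [DONE]

Events are separated by blank lines; the loop skips blanks and any line not prefixed with `data: `.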

View File

@@ -0,0 +1,71 @@
+package openai
+
+type OpenAIClient struct {
+	APIKey  string
+	BaseURL string
+}
+
+type ChatCompletionMessage struct {
+	Role       string     `json:"role"`
+	Content    string     `json:"content,omitempty"`
+	ToolCalls  []ToolCall `json:"tool_calls,omitempty"`
+	ToolCallID string     `json:"tool_call_id,omitempty"`
+}
+
+type ToolCall struct {
+	Type     string             `json:"type"`
+	ID       string             `json:"id"`
+	Index    *int               `json:"index,omitempty"`
+	Function FunctionDefinition `json:"function"`
+}
+
+type FunctionDefinition struct {
+	Name        string         `json:"name"`
+	Description string         `json:"description"`
+	Parameters  ToolParameters `json:"parameters"`
+	Arguments   string         `json:"arguments,omitempty"`
+}
+
+type ToolParameters struct {
+	Type       string                   `json:"type"`
+	Properties map[string]ToolParameter `json:"properties,omitempty"`
+	Required   []string                 `json:"required,omitempty"`
+}
+
+type ToolParameter struct {
+	Type        string   `json:"type"`
+	Description string   `json:"description"`
+	Enum        []string `json:"enum,omitempty"`
+}
+
+type Tool struct {
+	Type     string             `json:"type"`
+	Function FunctionDefinition `json:"function"`
+}
+
+type ChatCompletionRequest struct {
+	Model       string                  `json:"model"`
+	MaxTokens   int                     `json:"max_tokens,omitempty"`
+	Temperature float32                 `json:"temperature,omitempty"`
+	Messages    []ChatCompletionMessage `json:"messages"`
+	N           int                     `json:"n"`
+	Tools       []Tool                  `json:"tools,omitempty"`
+	ToolChoice  string                  `json:"tool_choice,omitempty"`
+	Stream      bool                    `json:"stream,omitempty"`
+}
+
+type ChatCompletionChoice struct {
+	Message ChatCompletionMessage `json:"message"`
+}
+
+type ChatCompletionResponse struct {
+	Choices []ChatCompletionChoice `json:"choices"`
+}
+
+type ChatCompletionStreamChoice struct {
+	Delta ChatCompletionMessage `json:"delta"`
+}
+
+type ChatCompletionStreamResponse struct {
+	Choices []ChatCompletionStreamChoice `json:"choices"`
+}
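
As a quick sketch (not part of the diff), a request built from these types marshals to the JSON an OpenAI-compatible /chat/completions endpoint expects; the values here are hypothetical:

    // inside package openai
    req := ChatCompletionRequest{
        Model:       "gpt-4",
        MaxTokens:   1024,
        Temperature: 0.7,
        N:           1,
        Messages: []ChatCompletionMessage{
            {Role: "system", Content: "You are a helpful assistant."},
            {Role: "user", Content: "Hello!"},
        },
    }
    data, _ := json.Marshal(req)
    // {"model":"gpt-4","max_tokens":1024,"temperature":0.7,"messages":[...],"n":1}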

View File

@@ -115,6 +115,14 @@ func newChatModel(tui *model) chatModel {
 		)),
 	}

+	system := tui.ctx.GetSystemPrompt()
+	if system != "" {
+		m.messages = []models.Message{{
+			Role:    models.MessageRoleSystem,
+			Content: system,
+		}}
+	}
+
 	m.input.Focus()
 	m.input.MaxHeight = 0
 	m.input.CharLimit = 0
@@ -177,16 +185,17 @@ func (m *chatModel) handleInput(msg tea.KeyMsg) (bool, tea.Cmd) {
 	switch msg.String() {
 	case "esc":
-		return true, func() tea.Msg {
-			return msgChangeState(stateConversations)
-		}
-	case "ctrl+c":
 		if m.waitingForReply {
 			m.stopSignal <- struct{}{}
 			return true, nil
 		}
 		return true, func() tea.Msg {
-			return msgChangeState(stateConversations)
+			return msgStateChange(stateConversations)
 		}
+	case "ctrl+c":
+		if m.waitingForReply {
+			m.stopSignal <- struct{}{}
+			return true, nil
+		}
 	case "ctrl+p":
 		m.persistence = !m.persistence
@@ -227,10 +236,12 @@ func (m *chatModel) handleResize(width, height int) {
 func (m chatModel) Update(msg tea.Msg) (chatModel, tea.Cmd) {
 	var cmds []tea.Cmd
 	switch msg := msg.(type) {
-	case msgChangeState:
+	case msgStateEnter:
 		if m.opts.convShortname != "" && m.conversation.ShortName.String != m.opts.convShortname {
 			cmds = append(cmds, m.loadConversation(m.opts.convShortname))
 		}
+		m.rebuildMessageCache()
+		m.updateContent()
 	case tea.WindowSizeMsg:
 		m.handleResize(msg.Width, msg.Height)
 	case msgTempfileEditorClosed:
@@ -254,7 +265,8 @@ func (m chatModel) Update(msg tea.Msg) (chatModel, tea.Cmd) {
 		cmds = append(cmds, m.loadMessages(m.conversation))
 	case msgMessagesLoaded:
 		m.selectedMessage = len(msg) - 1
-		m.setMessages(msg)
+		m.messages = msg
+		m.rebuildMessageCache()
 		m.updateContent()
 		m.content.GotoBottom()
 	case msgResponseChunk:
@@ -358,10 +370,12 @@ func (m chatModel) Update(msg tea.Msg) (chatModel, tea.Cmd) {
 	fixedHeight := height(m.views.header) + height(m.views.error) + height(m.views.footer)

 	// calculate clamped input height to accommodate input text
+	// minimum 4 lines, maximum half of content area
 	newHeight := max(4, min((m.height-fixedHeight-1)/2, m.input.LineCount()))
 	m.input.SetHeight(newHeight)
 	m.views.input = m.input.View()

+	// remaining height towards content
 	m.content.Height = m.height - fixedHeight - height(m.views.input)
 	m.views.content = m.content.View()
 }
@@ -701,11 +715,6 @@ func (m *chatModel) footerView() string {
 	return footerStyle.Width(m.width).Render(footer)
 }

-func (m *chatModel) setMessages(messages []models.Message) {
-	m.messages = messages
-	m.rebuildMessageCache()
-}
-
 func (m *chatModel) setMessage(i int, msg models.Message) {
 	if i >= len(m.messages) {
 		panic("i out of range")

View File

@@ -115,7 +115,7 @@ func (m *conversationsModel) handleResize(width, height int) {
 func (m conversationsModel) Update(msg tea.Msg) (conversationsModel, tea.Cmd) {
 	var cmds []tea.Cmd
 	switch msg := msg.(type) {
-	case msgChangeState:
+	case msgStateChange:
 		cmds = append(cmds, m.loadConversations())
 		m.content.SetContent(m.renderConversationList())
 	case tea.WindowSizeMsg:

View File

@@ -35,8 +35,10 @@ type views struct {
 }

 type (
-	// send to change the current app state
-	msgChangeState state
+	// send to change the current state
+	msgStateChange state
+	// sent to a state when it is entered
+	msgStateEnter struct{}
 	// sent when an error occurs
 	msgError error
 )
@@ -81,7 +83,7 @@ func (m model) Init() tea.Cmd {
 		m.conversations.Init(),
 		m.chat.Init(),
 		func() tea.Msg {
-			return msgChangeState(m.state)
+			return msgStateChange(m.state)
 		},
 	)
 }
@@ -124,18 +126,20 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 		if handled {
 			return m, cmd
 		}
-	case msgChangeState:
-		switch msg {
+	case msgStateChange:
+		m.state = state(msg)
+		switch m.state {
 		case stateChat:
 			m.chat.handleResize(m.width, m.height)
 		case stateConversations:
 			m.conversations.handleResize(m.width, m.height)
 		}
-		m.state = state(msg)
+		return m, func() tea.Msg { return msgStateEnter(struct{}{}) }
 	case msgConversationSelected:
+		// passed up through conversation list model
 		m.opts.convShortname = msg.ShortName.String
 		cmds = append(cmds, func() tea.Msg {
-			return msgChangeState(stateChat)
+			return msgStateChange(stateChat)
 		})
 	case tea.WindowSizeMsg:
 		m.width, m.height = msg.Width, msg.Height
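
Net effect of the state-handling changes above: a transition now happens in two phases. msgStateChange updates m.state and resizes the destination view, then the root model emits msgStateEnter so the entering view can refresh itself; the chat view uses msgStateEnter to (re)load the selected conversation and rebuild its message cache.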