Compare commits
No commits in common. "aeeb7bb7f79ccc5a1380b47a54e4b33559381101" and "08a202733244dc400a2e2d9ca0eb6df82404c522" have entirely different histories.
aeeb7bb7f7
...
08a2027332
1
go.mod
1
go.mod
@ -9,6 +9,7 @@ require (
|
|||||||
github.com/charmbracelet/lipgloss v0.10.0
|
github.com/charmbracelet/lipgloss v0.10.0
|
||||||
github.com/go-yaml/yaml v2.1.0+incompatible
|
github.com/go-yaml/yaml v2.1.0+incompatible
|
||||||
github.com/muesli/reflow v0.3.0
|
github.com/muesli/reflow v0.3.0
|
||||||
|
github.com/sashabaranov/go-openai v1.17.7
|
||||||
github.com/spf13/cobra v1.8.0
|
github.com/spf13/cobra v1.8.0
|
||||||
github.com/sqids/sqids-go v0.4.1
|
github.com/sqids/sqids-go v0.4.1
|
||||||
gopkg.in/yaml.v2 v2.2.2
|
gopkg.in/yaml.v2 v2.2.2
|
||||||
|
2
go.sum
2
go.sum
@ -61,6 +61,8 @@ github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUc
|
|||||||
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
|
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
|
||||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||||
|
github.com/sashabaranov/go-openai v1.17.7 h1:MPcAwlwbeo7ZmhQczoOgZBHtIBY1TfZqsdx6+/ndloM=
|
||||||
|
github.com/sashabaranov/go-openai v1.17.7/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
|
||||||
github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0=
|
github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0=
|
||||||
github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho=
|
github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho=
|
||||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||||
|
@ -33,6 +33,5 @@ func ChatCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
applyPromptFlags(ctx, cmd)
|
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
@ -8,6 +8,10 @@ import (
|
|||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
systemPromptFile string
|
||||||
|
)
|
||||||
|
|
||||||
func RootCmd(ctx *lmcli.Context) *cobra.Command {
|
func RootCmd(ctx *lmcli.Context) *cobra.Command {
|
||||||
var root = &cobra.Command{
|
var root = &cobra.Command{
|
||||||
Use: "lmcli <command> [flags]",
|
Use: "lmcli <command> [flags]",
|
||||||
@ -19,43 +23,58 @@ func RootCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
chatCmd := ChatCmd(ctx)
|
||||||
|
continueCmd := ContinueCmd(ctx)
|
||||||
|
cloneCmd := CloneCmd(ctx)
|
||||||
|
editCmd := EditCmd(ctx)
|
||||||
|
listCmd := ListCmd(ctx)
|
||||||
|
newCmd := NewCmd(ctx)
|
||||||
|
promptCmd := PromptCmd(ctx)
|
||||||
|
renameCmd := RenameCmd(ctx)
|
||||||
|
replyCmd := ReplyCmd(ctx)
|
||||||
|
retryCmd := RetryCmd(ctx)
|
||||||
|
rmCmd := RemoveCmd(ctx)
|
||||||
|
viewCmd := ViewCmd(ctx)
|
||||||
|
|
||||||
|
inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd, retryCmd, continueCmd, editCmd}
|
||||||
|
for _, cmd := range inputCmds {
|
||||||
|
cmd.Flags().StringVar(ctx.Config.Defaults.Model, "model", *ctx.Config.Defaults.Model, "Which model to use")
|
||||||
|
cmd.RegisterFlagCompletionFunc("model", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) {
|
||||||
|
return ctx.GetModels(), cobra.ShellCompDirectiveDefault
|
||||||
|
})
|
||||||
|
cmd.Flags().IntVar(ctx.Config.Defaults.MaxTokens, "length", *ctx.Config.Defaults.MaxTokens, "Maximum response tokens")
|
||||||
|
cmd.Flags().StringVar(ctx.Config.Defaults.SystemPrompt, "system-prompt", *ctx.Config.Defaults.SystemPrompt, "System prompt")
|
||||||
|
cmd.Flags().StringVar(&systemPromptFile, "system-prompt-file", "", "A path to a file containing the system prompt")
|
||||||
|
cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file")
|
||||||
|
}
|
||||||
|
|
||||||
root.AddCommand(
|
root.AddCommand(
|
||||||
ChatCmd(ctx),
|
chatCmd,
|
||||||
ContinueCmd(ctx),
|
cloneCmd,
|
||||||
CloneCmd(ctx),
|
continueCmd,
|
||||||
EditCmd(ctx),
|
editCmd,
|
||||||
ListCmd(ctx),
|
listCmd,
|
||||||
NewCmd(ctx),
|
newCmd,
|
||||||
PromptCmd(ctx),
|
promptCmd,
|
||||||
RenameCmd(ctx),
|
renameCmd,
|
||||||
ReplyCmd(ctx),
|
replyCmd,
|
||||||
RetryCmd(ctx),
|
retryCmd,
|
||||||
RemoveCmd(ctx),
|
rmCmd,
|
||||||
ViewCmd(ctx),
|
viewCmd,
|
||||||
)
|
)
|
||||||
|
|
||||||
return root
|
return root
|
||||||
}
|
}
|
||||||
|
|
||||||
func applyPromptFlags(ctx *lmcli.Context, cmd *cobra.Command) {
|
func getSystemPrompt(ctx *lmcli.Context) string {
|
||||||
f := cmd.Flags()
|
if systemPromptFile != "" {
|
||||||
|
content, err := util.ReadFileContents(systemPromptFile)
|
||||||
f.StringVarP(
|
if err != nil {
|
||||||
ctx.Config.Defaults.Model,
|
lmcli.Fatal("Could not read file contents at %s: %v\n", systemPromptFile, err)
|
||||||
"model", "m",
|
}
|
||||||
*ctx.Config.Defaults.Model,
|
return content
|
||||||
"The model to generate a response with",
|
}
|
||||||
)
|
return *ctx.Config.Defaults.SystemPrompt
|
||||||
cmd.RegisterFlagCompletionFunc("model", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) {
|
|
||||||
return ctx.GetModels(), cobra.ShellCompDirectiveDefault
|
|
||||||
})
|
|
||||||
|
|
||||||
f.IntVar(ctx.Config.Defaults.MaxTokens, "max-length", *ctx.Config.Defaults.MaxTokens, "Maximum response tokens")
|
|
||||||
f.Float32VarP(ctx.Config.Defaults.Temperature, "temperature", "t", *ctx.Config.Defaults.Temperature, "Sampling temperature")
|
|
||||||
|
|
||||||
f.StringVar(ctx.Config.Defaults.SystemPrompt, "system-prompt", *ctx.Config.Defaults.SystemPrompt, "System prompt")
|
|
||||||
f.StringVar(&ctx.SystemPromptFile, "system-prompt-file", "", "A path to a file containing the system prompt")
|
|
||||||
cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// inputFromArgsOrEditor returns either the provided input from the args slice
|
// inputFromArgsOrEditor returns either the provided input from the args slice
|
||||||
|
@ -44,7 +44,7 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
fmt.Print(lastMessage.Content)
|
fmt.Print(lastMessage.Content)
|
||||||
|
|
||||||
// Submit the LLM request, allowing it to continue the last message
|
// Submit the LLM request, allowing it to continue the last message
|
||||||
continuedOutput, err := cmdutil.Prompt(ctx, messages, nil)
|
continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error fetching LLM response: %v", err)
|
return fmt.Errorf("error fetching LLM response: %v", err)
|
||||||
}
|
}
|
||||||
@ -68,6 +68,5 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
applyPromptFlags(ctx, cmd)
|
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
@ -30,7 +30,7 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
{
|
{
|
||||||
ConversationID: conversation.ID,
|
ConversationID: conversation.ID,
|
||||||
Role: model.MessageRoleSystem,
|
Role: model.MessageRoleSystem,
|
||||||
Content: ctx.GetSystemPrompt(),
|
Content: getSystemPrompt(ctx),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ConversationID: conversation.ID,
|
ConversationID: conversation.ID,
|
||||||
@ -56,6 +56,5 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
applyPromptFlags(ctx, cmd)
|
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
@ -23,7 +23,7 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
messages := []model.Message{
|
messages := []model.Message{
|
||||||
{
|
{
|
||||||
Role: model.MessageRoleSystem,
|
Role: model.MessageRoleSystem,
|
||||||
Content: ctx.GetSystemPrompt(),
|
Content: getSystemPrompt(ctx),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Role: model.MessageRoleUser,
|
Role: model.MessageRoleUser,
|
||||||
@ -31,14 +31,12 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := cmdutil.Prompt(ctx, messages, nil)
|
_, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Error fetching LLM response: %v", err)
|
return fmt.Errorf("Error fetching LLM response: %v", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
applyPromptFlags(ctx, cmd)
|
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
@ -45,7 +45,5 @@ func ReplyCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
applyPromptFlags(ctx, cmd)
|
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
@ -54,7 +54,5 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
applyPromptFlags(ctx, cmd)
|
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
@ -13,9 +13,9 @@ import (
|
|||||||
"github.com/charmbracelet/lipgloss"
|
"github.com/charmbracelet/lipgloss"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Prompt prompts the configured the configured model and streams the response
|
// fetchAndShowCompletion prompts the LLM with the given messages and streams
|
||||||
// to stdout. Returns all model reply messages.
|
// the response to stdout. Returns all model reply messages.
|
||||||
func Prompt(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
|
func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
|
||||||
content := make(chan string) // receives the reponse from LLM
|
content := make(chan string) // receives the reponse from LLM
|
||||||
defer close(content)
|
defer close(content)
|
||||||
|
|
||||||
@ -46,7 +46,7 @@ func Prompt(ctx *lmcli.Context, messages []model.Message, callback func(model.Me
|
|||||||
err = nil
|
err = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return response, err
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// lookupConversation either returns the conversation found by the
|
// lookupConversation either returns the conversation found by the
|
||||||
@ -109,7 +109,7 @@ func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = Prompt(ctx, allMessages, replyCallback)
|
_, err = FetchAndShowCompletion(ctx, allMessages, replyCallback)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lmcli.Fatal("Error fetching LLM response: %v\n", err)
|
lmcli.Fatal("Error fetching LLM response: %v\n", err)
|
||||||
}
|
}
|
||||||
|
@ -19,14 +19,16 @@ type Config struct {
|
|||||||
TitleGenerationModel *string `yaml:"titleGenerationModel" default:"gpt-3.5-turbo"`
|
TitleGenerationModel *string `yaml:"titleGenerationModel" default:"gpt-3.5-turbo"`
|
||||||
} `yaml:"conversations"`
|
} `yaml:"conversations"`
|
||||||
Tools *struct {
|
Tools *struct {
|
||||||
EnabledTools []string `yaml:"enabledTools"`
|
EnabledTools *[]string `yaml:"enabledTools"`
|
||||||
} `yaml:"tools"`
|
} `yaml:"tools"`
|
||||||
Providers []*struct {
|
OpenAI *struct {
|
||||||
Kind *string `yaml:"kind"`
|
APIKey *string `yaml:"apiKey" default:"your_key_here"`
|
||||||
BaseURL *string `yaml:"baseUrl"`
|
|
||||||
APIKey *string `yaml:"apiKey"`
|
|
||||||
Models *[]string `yaml:"models"`
|
Models *[]string `yaml:"models"`
|
||||||
} `yaml:"providers"`
|
} `yaml:"openai"`
|
||||||
|
Anthropic *struct {
|
||||||
|
APIKey *string `yaml:"apiKey" default:"your_key_here"`
|
||||||
|
Models *[]string `yaml:"models"`
|
||||||
|
} `yaml:"anthropic"`
|
||||||
Chroma *struct {
|
Chroma *struct {
|
||||||
Style *string `yaml:"style" default:"onedark"`
|
Style *string `yaml:"style" default:"onedark"`
|
||||||
Formatter *string `yaml:"formatter" default:"terminal16m"`
|
Formatter *string `yaml:"formatter" default:"terminal16m"`
|
||||||
|
@ -10,20 +10,17 @@ import (
|
|||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/util"
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/util/tty"
|
"git.mlow.ca/mlow/lmcli/pkg/util/tty"
|
||||||
"gorm.io/driver/sqlite"
|
"gorm.io/driver/sqlite"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Context struct {
|
type Context struct {
|
||||||
Config *Config // may be updated at runtime
|
Config *Config
|
||||||
Store ConversationStore
|
Store ConversationStore
|
||||||
|
|
||||||
Chroma *tty.ChromaHighlighter
|
Chroma *tty.ChromaHighlighter
|
||||||
EnabledTools []model.Tool
|
EnabledTools []model.Tool
|
||||||
|
|
||||||
SystemPromptFile string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewContext() (*Context, error) {
|
func NewContext() (*Context, error) {
|
||||||
@ -46,70 +43,46 @@ func NewContext() (*Context, error) {
|
|||||||
chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style)
|
chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style)
|
||||||
|
|
||||||
var enabledTools []model.Tool
|
var enabledTools []model.Tool
|
||||||
for _, toolName := range config.Tools.EnabledTools {
|
for _, toolName := range *config.Tools.EnabledTools {
|
||||||
tool, ok := tools.AvailableTools[toolName]
|
tool, ok := tools.AvailableTools[toolName]
|
||||||
if ok {
|
if ok {
|
||||||
enabledTools = append(enabledTools, tool)
|
enabledTools = append(enabledTools, tool)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Context{config, store, chroma, enabledTools, ""}, nil
|
return &Context{config, store, chroma, enabledTools}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) GetModels() (models []string) {
|
func (c *Context) GetModels() (models []string) {
|
||||||
for _, p := range c.Config.Providers {
|
for _, m := range *c.Config.Anthropic.Models {
|
||||||
for _, m := range *p.Models {
|
|
||||||
models = append(models, m)
|
models = append(models, m)
|
||||||
}
|
}
|
||||||
|
for _, m := range *c.Config.OpenAI.Models {
|
||||||
|
models = append(models, m)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionClient, error) {
|
func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionClient, error) {
|
||||||
for _, p := range c.Config.Providers {
|
for _, m := range *c.Config.Anthropic.Models {
|
||||||
for _, m := range *p.Models {
|
|
||||||
if m == model {
|
if m == model {
|
||||||
switch *p.Kind {
|
|
||||||
case "anthropic":
|
|
||||||
url := "https://api.anthropic.com/v1"
|
|
||||||
if p.BaseURL != nil {
|
|
||||||
url = *p.BaseURL
|
|
||||||
}
|
|
||||||
anthropic := &anthropic.AnthropicClient{
|
anthropic := &anthropic.AnthropicClient{
|
||||||
BaseURL: url,
|
APIKey: *c.Config.Anthropic.APIKey,
|
||||||
APIKey: *p.APIKey,
|
|
||||||
}
|
}
|
||||||
return anthropic, nil
|
return anthropic, nil
|
||||||
case "openai":
|
|
||||||
url := "https://api.openai.com/v1"
|
|
||||||
if p.BaseURL != nil {
|
|
||||||
url = *p.BaseURL
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
for _, m := range *c.Config.OpenAI.Models {
|
||||||
|
if m == model {
|
||||||
openai := &openai.OpenAIClient{
|
openai := &openai.OpenAIClient{
|
||||||
BaseURL: url,
|
APIKey: *c.Config.OpenAI.APIKey,
|
||||||
APIKey: *p.APIKey,
|
|
||||||
}
|
}
|
||||||
return openai, nil
|
return openai, nil
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unknown provider kind: %s", *p.Kind)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("unknown model: %s", model)
|
return nil, fmt.Errorf("unknown model: %s", model)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) GetSystemPrompt() string {
|
|
||||||
if c.SystemPromptFile != "" {
|
|
||||||
content, err := util.ReadFileContents(c.SystemPromptFile)
|
|
||||||
if err != nil {
|
|
||||||
Fatal("Could not read file contents at %s: %v\n", c.SystemPromptFile, err)
|
|
||||||
}
|
|
||||||
return content
|
|
||||||
}
|
|
||||||
return *c.Config.Defaults.SystemPrompt
|
|
||||||
}
|
|
||||||
|
|
||||||
func configDir() string {
|
func configDir() string {
|
||||||
var configDir string
|
var configDir string
|
||||||
|
|
||||||
|
@ -33,11 +33,11 @@ type Conversation struct {
|
|||||||
|
|
||||||
type RequestParameters struct {
|
type RequestParameters struct {
|
||||||
Model string
|
Model string
|
||||||
|
|
||||||
MaxTokens int
|
MaxTokens int
|
||||||
Temperature float32
|
Temperature float32
|
||||||
TopP float32
|
TopP float32
|
||||||
|
|
||||||
|
SystemPrompt string
|
||||||
ToolBag []Tool
|
ToolBag []Tool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -15,10 +15,48 @@ import (
|
|||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type AnthropicClient struct {
|
||||||
|
APIKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Message struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Request struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Messages []Message `json:"messages"`
|
||||||
|
System string `json:"system,omitempty"`
|
||||||
|
MaxTokens int `json:"max_tokens,omitempty"`
|
||||||
|
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||||
|
Stream bool `json:"stream,omitempty"`
|
||||||
|
Temperature float32 `json:"temperature,omitempty"`
|
||||||
|
//TopP float32 `json:"top_p,omitempty"`
|
||||||
|
//TopK float32 `json:"top_k,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OriginalContent struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Text string `json:"text"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Response struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content []OriginalContent `json:"content"`
|
||||||
|
StopReason string `json:"stop_reason"`
|
||||||
|
StopSequence string `json:"stop_sequence"`
|
||||||
|
}
|
||||||
|
|
||||||
|
const FUNCTION_STOP_SEQUENCE = "</function_calls>"
|
||||||
|
|
||||||
func buildRequest(params model.RequestParameters, messages []model.Message) Request {
|
func buildRequest(params model.RequestParameters, messages []model.Message) Request {
|
||||||
requestBody := Request{
|
requestBody := Request{
|
||||||
Model: params.Model,
|
Model: params.Model,
|
||||||
Messages: make([]Message, len(messages)),
|
Messages: make([]Message, len(messages)),
|
||||||
|
System: params.SystemPrompt,
|
||||||
MaxTokens: params.MaxTokens,
|
MaxTokens: params.MaxTokens,
|
||||||
Temperature: params.Temperature,
|
Temperature: params.Temperature,
|
||||||
Stream: false,
|
Stream: false,
|
||||||
@ -80,12 +118,14 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
|
|||||||
}
|
}
|
||||||
|
|
||||||
func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Response, error) {
|
func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Response, error) {
|
||||||
|
url := "https://api.anthropic.com/v1/messages"
|
||||||
|
|
||||||
jsonBody, err := json.Marshal(r)
|
jsonBody, err := json.Marshal(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to marshal request body: %v", err)
|
return nil, fmt.Errorf("failed to marshal request body: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/messages", bytes.NewBuffer(jsonBody))
|
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonBody))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create HTTP request: %v", err)
|
return nil, fmt.Errorf("failed to create HTTP request: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -9,8 +9,6 @@ import (
|
|||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
const FUNCTION_STOP_SEQUENCE = "</function_calls>"
|
|
||||||
|
|
||||||
const TOOL_PREAMBLE = `You have access to the following tools when replying.
|
const TOOL_PREAMBLE = `You have access to the following tools when replying.
|
||||||
|
|
||||||
You may call them like this:
|
You may call them like this:
|
||||||
|
@ -1,38 +0,0 @@
|
|||||||
package anthropic
|
|
||||||
|
|
||||||
type AnthropicClient struct {
|
|
||||||
BaseURL string
|
|
||||||
APIKey string
|
|
||||||
}
|
|
||||||
|
|
||||||
type Message struct {
|
|
||||||
Role string `json:"role"`
|
|
||||||
Content string `json:"content"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Request struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
Messages []Message `json:"messages"`
|
|
||||||
System string `json:"system,omitempty"`
|
|
||||||
MaxTokens int `json:"max_tokens,omitempty"`
|
|
||||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
|
||||||
Stream bool `json:"stream,omitempty"`
|
|
||||||
Temperature float32 `json:"temperature,omitempty"`
|
|
||||||
//TopP float32 `json:"top_p,omitempty"`
|
|
||||||
//TopK float32 `json:"top_k,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type OriginalContent struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Text string `json:"text"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Response struct {
|
|
||||||
Id string `json:"id"`
|
|
||||||
Type string `json:"type"`
|
|
||||||
Role string `json:"role"`
|
|
||||||
Content []OriginalContent `json:"content"`
|
|
||||||
StopReason string `json:"stop_reason"`
|
|
||||||
StopSequence string `json:"stop_sequence"`
|
|
||||||
}
|
|
||||||
|
|
@ -1,30 +1,45 @@
|
|||||||
package openai
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||||
|
openai "github.com/sashabaranov/go-openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
func convertTools(tools []model.Tool) []Tool {
|
type OpenAIClient struct {
|
||||||
openaiTools := make([]Tool, len(tools))
|
APIKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAIToolParameters struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Properties map[string]OpenAIToolParameter `json:"properties,omitempty"`
|
||||||
|
Required []string `json:"required,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAIToolParameter struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Enum []string `json:"enum,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertTools(tools []model.Tool) []openai.Tool {
|
||||||
|
openaiTools := make([]openai.Tool, len(tools))
|
||||||
for i, tool := range tools {
|
for i, tool := range tools {
|
||||||
openaiTools[i].Type = "function"
|
openaiTools[i].Type = "function"
|
||||||
|
|
||||||
params := make(map[string]ToolParameter)
|
params := make(map[string]OpenAIToolParameter)
|
||||||
var required []string
|
var required []string
|
||||||
|
|
||||||
for _, param := range tool.Parameters {
|
for _, param := range tool.Parameters {
|
||||||
params[param.Name] = ToolParameter{
|
params[param.Name] = OpenAIToolParameter{
|
||||||
Type: param.Type,
|
Type: param.Type,
|
||||||
Description: param.Description,
|
Description: param.Description,
|
||||||
Enum: param.Enum,
|
Enum: param.Enum,
|
||||||
@ -34,10 +49,10 @@ func convertTools(tools []model.Tool) []Tool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
openaiTools[i].Function = FunctionDefinition{
|
openaiTools[i].Function = openai.FunctionDefinition{
|
||||||
Name: tool.Name,
|
Name: tool.Name,
|
||||||
Description: tool.Description,
|
Description: tool.Description,
|
||||||
Parameters: ToolParameters{
|
Parameters: OpenAIToolParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: params,
|
Properties: params,
|
||||||
Required: required,
|
Required: required,
|
||||||
@ -47,8 +62,8 @@ func convertTools(tools []model.Tool) []Tool {
|
|||||||
return openaiTools
|
return openaiTools
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertToolCallToOpenAI(toolCalls []model.ToolCall) []ToolCall {
|
func convertToolCallToOpenAI(toolCalls []model.ToolCall) []openai.ToolCall {
|
||||||
converted := make([]ToolCall, len(toolCalls))
|
converted := make([]openai.ToolCall, len(toolCalls))
|
||||||
for i, call := range toolCalls {
|
for i, call := range toolCalls {
|
||||||
converted[i].Type = "function"
|
converted[i].Type = "function"
|
||||||
converted[i].ID = call.ID
|
converted[i].ID = call.ID
|
||||||
@ -60,7 +75,7 @@ func convertToolCallToOpenAI(toolCalls []model.ToolCall) []ToolCall {
|
|||||||
return converted
|
return converted
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertToolCallToAPI(toolCalls []ToolCall) []model.ToolCall {
|
func convertToolCallToAPI(toolCalls []openai.ToolCall) []model.ToolCall {
|
||||||
converted := make([]model.ToolCall, len(toolCalls))
|
converted := make([]model.ToolCall, len(toolCalls))
|
||||||
for i, call := range toolCalls {
|
for i, call := range toolCalls {
|
||||||
converted[i].ID = call.ID
|
converted[i].ID = call.ID
|
||||||
@ -71,15 +86,16 @@ func convertToolCallToAPI(toolCalls []ToolCall) []model.ToolCall {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func createChatCompletionRequest(
|
func createChatCompletionRequest(
|
||||||
|
c *OpenAIClient,
|
||||||
params model.RequestParameters,
|
params model.RequestParameters,
|
||||||
messages []model.Message,
|
messages []model.Message,
|
||||||
) ChatCompletionRequest {
|
) openai.ChatCompletionRequest {
|
||||||
requestMessages := make([]ChatCompletionMessage, 0, len(messages))
|
requestMessages := make([]openai.ChatCompletionMessage, 0, len(messages))
|
||||||
|
|
||||||
for _, m := range messages {
|
for _, m := range messages {
|
||||||
switch m.Role {
|
switch m.Role {
|
||||||
case "tool_call":
|
case "tool_call":
|
||||||
message := ChatCompletionMessage{}
|
message := openai.ChatCompletionMessage{}
|
||||||
message.Role = "assistant"
|
message.Role = "assistant"
|
||||||
message.Content = m.Content
|
message.Content = m.Content
|
||||||
message.ToolCalls = convertToolCallToOpenAI(m.ToolCalls)
|
message.ToolCalls = convertToolCallToOpenAI(m.ToolCalls)
|
||||||
@ -87,21 +103,21 @@ func createChatCompletionRequest(
|
|||||||
case "tool_result":
|
case "tool_result":
|
||||||
// expand tool_result messages' results into multiple openAI messages
|
// expand tool_result messages' results into multiple openAI messages
|
||||||
for _, result := range m.ToolResults {
|
for _, result := range m.ToolResults {
|
||||||
message := ChatCompletionMessage{}
|
message := openai.ChatCompletionMessage{}
|
||||||
message.Role = "tool"
|
message.Role = "tool"
|
||||||
message.Content = result.Result
|
message.Content = result.Result
|
||||||
message.ToolCallID = result.ToolCallID
|
message.ToolCallID = result.ToolCallID
|
||||||
requestMessages = append(requestMessages, message)
|
requestMessages = append(requestMessages, message)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
message := ChatCompletionMessage{}
|
message := openai.ChatCompletionMessage{}
|
||||||
message.Role = string(m.Role)
|
message.Role = string(m.Role)
|
||||||
message.Content = m.Content
|
message.Content = m.Content
|
||||||
requestMessages = append(requestMessages, message)
|
requestMessages = append(requestMessages, message)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
request := ChatCompletionRequest{
|
request := openai.ChatCompletionRequest{
|
||||||
Model: params.Model,
|
Model: params.Model,
|
||||||
MaxTokens: params.MaxTokens,
|
MaxTokens: params.MaxTokens,
|
||||||
Temperature: params.Temperature,
|
Temperature: params.Temperature,
|
||||||
@ -120,7 +136,7 @@ func createChatCompletionRequest(
|
|||||||
func handleToolCalls(
|
func handleToolCalls(
|
||||||
params model.RequestParameters,
|
params model.RequestParameters,
|
||||||
content string,
|
content string,
|
||||||
toolCalls []ToolCall,
|
toolCalls []openai.ToolCall,
|
||||||
callback provider.ReplyCallback,
|
callback provider.ReplyCallback,
|
||||||
messages []model.Message,
|
messages []model.Message,
|
||||||
) ([]model.Message, error) {
|
) ([]model.Message, error) {
|
||||||
@ -161,21 +177,6 @@ func handleToolCalls(
|
|||||||
return messages, nil
|
return messages, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *OpenAIClient) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
|
||||||
|
|
||||||
client := &http.Client{}
|
|
||||||
resp, err := client.Do(req.WithContext(ctx))
|
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
|
||||||
bytes, _ := io.ReadAll(resp.Body)
|
|
||||||
return resp, fmt.Errorf("%v", string(bytes))
|
|
||||||
}
|
|
||||||
|
|
||||||
return resp, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *OpenAIClient) CreateChatCompletion(
|
func (c *OpenAIClient) CreateChatCompletion(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
params model.RequestParameters,
|
params model.RequestParameters,
|
||||||
@ -186,30 +187,14 @@ func (c *OpenAIClient) CreateChatCompletion(
|
|||||||
return "", fmt.Errorf("Can't create completion from no messages")
|
return "", fmt.Errorf("Can't create completion from no messages")
|
||||||
}
|
}
|
||||||
|
|
||||||
req := createChatCompletionRequest(params, messages)
|
client := openai.NewClient(c.APIKey)
|
||||||
jsonData, err := json.Marshal(req)
|
req := createChatCompletionRequest(c, params, messages)
|
||||||
|
resp, err := client.CreateChatCompletion(ctx, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
|
choice := resp.Choices[0]
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := c.sendRequest(ctx, httpReq)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
var completionResp ChatCompletionResponse
|
|
||||||
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
choice := completionResp.Choices[0]
|
|
||||||
|
|
||||||
var content string
|
var content string
|
||||||
lastMessage := messages[len(messages)-1]
|
lastMessage := messages[len(messages)-1]
|
||||||
@ -251,60 +236,36 @@ func (c *OpenAIClient) CreateChatCompletionStream(
|
|||||||
return "", fmt.Errorf("Can't create completion from no messages")
|
return "", fmt.Errorf("Can't create completion from no messages")
|
||||||
}
|
}
|
||||||
|
|
||||||
req := createChatCompletionRequest(params, messages)
|
client := openai.NewClient(c.APIKey)
|
||||||
req.Stream = true
|
req := createChatCompletionRequest(c, params, messages)
|
||||||
|
|
||||||
jsonData, err := json.Marshal(req)
|
stream, err := client.CreateChatCompletionStream(ctx, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
defer stream.Close()
|
||||||
httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := c.sendRequest(ctx, httpReq)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
content := strings.Builder{}
|
content := strings.Builder{}
|
||||||
toolCalls := []ToolCall{}
|
toolCalls := []openai.ToolCall{}
|
||||||
|
|
||||||
lastMessage := messages[len(messages)-1]
|
lastMessage := messages[len(messages)-1]
|
||||||
if lastMessage.Role.IsAssistant() {
|
if lastMessage.Role.IsAssistant() {
|
||||||
content.WriteString(lastMessage.Content)
|
content.WriteString(lastMessage.Content)
|
||||||
}
|
}
|
||||||
|
|
||||||
reader := bufio.NewReader(resp.Body)
|
// Iterate stream segments
|
||||||
for {
|
for {
|
||||||
line, err := reader.ReadBytes('\n')
|
response, e := stream.Recv()
|
||||||
if err != nil {
|
if errors.Is(e, io.EOF) {
|
||||||
if err == io.EOF {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
line = bytes.TrimSpace(line)
|
|
||||||
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
line = bytes.TrimPrefix(line, []byte("data: "))
|
|
||||||
if bytes.Equal(line, []byte("[DONE]")) {
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
var streamResp ChatCompletionStreamResponse
|
if e != nil {
|
||||||
err = json.Unmarshal(line, &streamResp)
|
err = e
|
||||||
if err != nil {
|
break
|
||||||
return "", err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
delta := streamResp.Choices[0].Delta
|
delta := response.Choices[0].Delta
|
||||||
if len(delta.ToolCalls) > 0 {
|
if len(delta.ToolCalls) > 0 {
|
||||||
// Construct streamed tool_call arguments
|
// Construct streamed tool_call arguments
|
||||||
for _, tc := range delta.ToolCalls {
|
for _, tc := range delta.ToolCalls {
|
||||||
@ -317,8 +278,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
|
|||||||
toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
|
toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
} else {
|
||||||
if len(delta.Content) > 0 {
|
|
||||||
output <- delta.Content
|
output <- delta.Content
|
||||||
content.WriteString(delta.Content)
|
content.WriteString(delta.Content)
|
||||||
}
|
}
|
||||||
@ -341,5 +301,5 @@ func (c *OpenAIClient) CreateChatCompletionStream(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return content.String(), nil
|
return content.String(), err
|
||||||
}
|
}
|
||||||
|
@ -1,71 +0,0 @@
|
|||||||
package openai
|
|
||||||
|
|
||||||
type OpenAIClient struct {
|
|
||||||
APIKey string
|
|
||||||
BaseURL string
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatCompletionMessage struct {
|
|
||||||
Role string `json:"role"`
|
|
||||||
Content string `json:"content,omitempty"`
|
|
||||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
|
||||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolCall struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
ID string `json:"id"`
|
|
||||||
Index *int `json:"index,omitempty"`
|
|
||||||
Function FunctionDefinition `json:"function"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type FunctionDefinition struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
Parameters ToolParameters `json:"parameters"`
|
|
||||||
Arguments string `json:"arguments,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolParameters struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Properties map[string]ToolParameter `json:"properties,omitempty"`
|
|
||||||
Required []string `json:"required,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolParameter struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
Enum []string `json:"enum,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Tool struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Function FunctionDefinition `json:"function"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatCompletionRequest struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
MaxTokens int `json:"max_tokens,omitempty"`
|
|
||||||
Temperature float32 `json:"temperature,omitempty"`
|
|
||||||
Messages []ChatCompletionMessage `json:"messages"`
|
|
||||||
N int `json:"n"`
|
|
||||||
Tools []Tool `json:"tools,omitempty"`
|
|
||||||
ToolChoice string `json:"tool_choice,omitempty"`
|
|
||||||
Stream bool `json:"stream,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatCompletionChoice struct {
|
|
||||||
Message ChatCompletionMessage `json:"message"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatCompletionResponse struct {
|
|
||||||
Choices []ChatCompletionChoice `json:"choices"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatCompletionStreamChoice struct {
|
|
||||||
Delta ChatCompletionMessage `json:"delta"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatCompletionStreamResponse struct {
|
|
||||||
Choices []ChatCompletionStreamChoice `json:"choices"`
|
|
||||||
}
|
|
@ -115,14 +115,6 @@ func newChatModel(tui *model) chatModel {
|
|||||||
)),
|
)),
|
||||||
}
|
}
|
||||||
|
|
||||||
system := tui.ctx.GetSystemPrompt()
|
|
||||||
if system != "" {
|
|
||||||
m.messages = []models.Message{{
|
|
||||||
Role: models.MessageRoleSystem,
|
|
||||||
Content: system,
|
|
||||||
}}
|
|
||||||
}
|
|
||||||
|
|
||||||
m.input.Focus()
|
m.input.Focus()
|
||||||
m.input.MaxHeight = 0
|
m.input.MaxHeight = 0
|
||||||
m.input.CharLimit = 0
|
m.input.CharLimit = 0
|
||||||
@ -185,18 +177,17 @@ func (m *chatModel) handleInput(msg tea.KeyMsg) (bool, tea.Cmd) {
|
|||||||
|
|
||||||
switch msg.String() {
|
switch msg.String() {
|
||||||
case "esc":
|
case "esc":
|
||||||
if m.waitingForReply {
|
|
||||||
m.stopSignal <- struct{}{}
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
return true, func() tea.Msg {
|
return true, func() tea.Msg {
|
||||||
return msgStateChange(stateConversations)
|
return msgChangeState(stateConversations)
|
||||||
}
|
}
|
||||||
case "ctrl+c":
|
case "ctrl+c":
|
||||||
if m.waitingForReply {
|
if m.waitingForReply {
|
||||||
m.stopSignal <- struct{}{}
|
m.stopSignal <- struct{}{}
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
return true, func() tea.Msg {
|
||||||
|
return msgChangeState(stateConversations)
|
||||||
|
}
|
||||||
case "ctrl+p":
|
case "ctrl+p":
|
||||||
m.persistence = !m.persistence
|
m.persistence = !m.persistence
|
||||||
return true, nil
|
return true, nil
|
||||||
@ -236,12 +227,10 @@ func (m *chatModel) handleResize(width, height int) {
|
|||||||
func (m chatModel) Update(msg tea.Msg) (chatModel, tea.Cmd) {
|
func (m chatModel) Update(msg tea.Msg) (chatModel, tea.Cmd) {
|
||||||
var cmds []tea.Cmd
|
var cmds []tea.Cmd
|
||||||
switch msg := msg.(type) {
|
switch msg := msg.(type) {
|
||||||
case msgStateEnter:
|
case msgChangeState:
|
||||||
if m.opts.convShortname != "" && m.conversation.ShortName.String != m.opts.convShortname {
|
if m.opts.convShortname != "" && m.conversation.ShortName.String != m.opts.convShortname {
|
||||||
cmds = append(cmds, m.loadConversation(m.opts.convShortname))
|
cmds = append(cmds, m.loadConversation(m.opts.convShortname))
|
||||||
}
|
}
|
||||||
m.rebuildMessageCache()
|
|
||||||
m.updateContent()
|
|
||||||
case tea.WindowSizeMsg:
|
case tea.WindowSizeMsg:
|
||||||
m.handleResize(msg.Width, msg.Height)
|
m.handleResize(msg.Width, msg.Height)
|
||||||
case msgTempfileEditorClosed:
|
case msgTempfileEditorClosed:
|
||||||
@ -265,8 +254,7 @@ func (m chatModel) Update(msg tea.Msg) (chatModel, tea.Cmd) {
|
|||||||
cmds = append(cmds, m.loadMessages(m.conversation))
|
cmds = append(cmds, m.loadMessages(m.conversation))
|
||||||
case msgMessagesLoaded:
|
case msgMessagesLoaded:
|
||||||
m.selectedMessage = len(msg) - 1
|
m.selectedMessage = len(msg) - 1
|
||||||
m.messages = msg
|
m.setMessages(msg)
|
||||||
m.rebuildMessageCache()
|
|
||||||
m.updateContent()
|
m.updateContent()
|
||||||
m.content.GotoBottom()
|
m.content.GotoBottom()
|
||||||
case msgResponseChunk:
|
case msgResponseChunk:
|
||||||
@ -370,12 +358,10 @@ func (m chatModel) Update(msg tea.Msg) (chatModel, tea.Cmd) {
|
|||||||
fixedHeight := height(m.views.header) + height(m.views.error) + height(m.views.footer)
|
fixedHeight := height(m.views.header) + height(m.views.error) + height(m.views.footer)
|
||||||
|
|
||||||
// calculate clamped input height to accomodate input text
|
// calculate clamped input height to accomodate input text
|
||||||
// minimum 4 lines, maximum half of content area
|
|
||||||
newHeight := max(4, min((m.height-fixedHeight-1)/2, m.input.LineCount()))
|
newHeight := max(4, min((m.height-fixedHeight-1)/2, m.input.LineCount()))
|
||||||
m.input.SetHeight(newHeight)
|
m.input.SetHeight(newHeight)
|
||||||
m.views.input = m.input.View()
|
m.views.input = m.input.View()
|
||||||
|
|
||||||
// remaining height towards content
|
|
||||||
m.content.Height = m.height - fixedHeight - height(m.views.input)
|
m.content.Height = m.height - fixedHeight - height(m.views.input)
|
||||||
m.views.content = m.content.View()
|
m.views.content = m.content.View()
|
||||||
}
|
}
|
||||||
@ -715,6 +701,11 @@ func (m *chatModel) footerView() string {
|
|||||||
return footerStyle.Width(m.width).Render(footer)
|
return footerStyle.Width(m.width).Render(footer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *chatModel) setMessages(messages []models.Message) {
|
||||||
|
m.messages = messages
|
||||||
|
m.rebuildMessageCache()
|
||||||
|
}
|
||||||
|
|
||||||
func (m *chatModel) setMessage(i int, msg models.Message) {
|
func (m *chatModel) setMessage(i int, msg models.Message) {
|
||||||
if i >= len(m.messages) {
|
if i >= len(m.messages) {
|
||||||
panic("i out of range")
|
panic("i out of range")
|
||||||
|
@ -115,7 +115,7 @@ func (m *conversationsModel) handleResize(width, height int) {
|
|||||||
func (m conversationsModel) Update(msg tea.Msg) (conversationsModel, tea.Cmd) {
|
func (m conversationsModel) Update(msg tea.Msg) (conversationsModel, tea.Cmd) {
|
||||||
var cmds []tea.Cmd
|
var cmds []tea.Cmd
|
||||||
switch msg := msg.(type) {
|
switch msg := msg.(type) {
|
||||||
case msgStateChange:
|
case msgChangeState:
|
||||||
cmds = append(cmds, m.loadConversations())
|
cmds = append(cmds, m.loadConversations())
|
||||||
m.content.SetContent(m.renderConversationList())
|
m.content.SetContent(m.renderConversationList())
|
||||||
case tea.WindowSizeMsg:
|
case tea.WindowSizeMsg:
|
||||||
|
@ -35,10 +35,8 @@ type views struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type (
|
type (
|
||||||
// send to change the current state
|
// send to change the current app state
|
||||||
msgStateChange state
|
msgChangeState state
|
||||||
// sent to a state when it is entered
|
|
||||||
msgStateEnter struct{}
|
|
||||||
// sent when an error occurs
|
// sent when an error occurs
|
||||||
msgError error
|
msgError error
|
||||||
)
|
)
|
||||||
@ -83,7 +81,7 @@ func (m model) Init() tea.Cmd {
|
|||||||
m.conversations.Init(),
|
m.conversations.Init(),
|
||||||
m.chat.Init(),
|
m.chat.Init(),
|
||||||
func() tea.Msg {
|
func() tea.Msg {
|
||||||
return msgStateChange(m.state)
|
return msgChangeState(m.state)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@ -126,20 +124,18 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||||||
if handled {
|
if handled {
|
||||||
return m, cmd
|
return m, cmd
|
||||||
}
|
}
|
||||||
case msgStateChange:
|
case msgChangeState:
|
||||||
m.state = state(msg)
|
switch msg {
|
||||||
switch m.state {
|
|
||||||
case stateChat:
|
case stateChat:
|
||||||
m.chat.handleResize(m.width, m.height)
|
m.chat.handleResize(m.width, m.height)
|
||||||
case stateConversations:
|
case stateConversations:
|
||||||
m.conversations.handleResize(m.width, m.height)
|
m.conversations.handleResize(m.width, m.height)
|
||||||
}
|
}
|
||||||
return m, func() tea.Msg { return msgStateEnter(struct{}{}) }
|
m.state = state(msg)
|
||||||
case msgConversationSelected:
|
case msgConversationSelected:
|
||||||
// passed up through conversation list model
|
|
||||||
m.opts.convShortname = msg.ShortName.String
|
m.opts.convShortname = msg.ShortName.String
|
||||||
cmds = append(cmds, func() tea.Msg {
|
cmds = append(cmds, func() tea.Msg {
|
||||||
return msgStateChange(stateChat)
|
return msgChangeState(stateChat)
|
||||||
})
|
})
|
||||||
case tea.WindowSizeMsg:
|
case tea.WindowSizeMsg:
|
||||||
m.width, m.height = msg.Width, msg.Height
|
m.width, m.height = msg.Width, msg.Height
|
||||||
|
Loading…
Reference in New Issue
Block a user