Compare commits

..

8 Commits

Author SHA1 Message Date
42c3297e54 Make Conversation a pointer reference on Message
Instead of a value, which led to some odd handling of conversation
references.

Also fixed some formatting and removed an unnecessary (and probably
broken) setting of ConversationID in a call to
`cmdutil.HandleConversationReply`
2024-06-09 18:51:44 +00:00
a22119f738 Better handling of newly saved conversations
When a new conversation is created in the chat view's
`persistConversation`, we now set `rootMessages` appropriately.
2024-06-09 18:51:44 +00:00
a2c860252f Refactor pkg/lmcli/provider
Moved `ChatCompletionClient` to `pkg/api`, moved individual
providers to `pkg/api/provider`
2024-06-09 18:31:43 +00:00
d2d946b776 Wrap chunk content in a Chunk type
Preparing to include additional information with each chunk (e.g. token
count)
2024-06-09 18:31:43 +00:00
c963747066 Store fixes
We were taking double pointers (`**T`) in some areas, and
we were not setting foreign references correctly in `StartConversation`
and `Reply`.
2024-06-09 18:31:40 +00:00
e334d9fc4f Remove forgotten printf 2024-06-09 16:19:22 +00:00
c1ead83939 Rename shared.State to shared.Shared 2024-06-09 16:19:19 +00:00
c9e92e186e Chat view cleanup
Replace `waitingForReply` and the `status` string with the `state`
variable.
2024-06-09 16:19:17 +00:00
22 changed files with 182 additions and 159 deletions

View File

@ -1,4 +1,4 @@
package provider package api
import ( import (
"context" "context"
@ -8,6 +8,10 @@ import (
type ReplyCallback func(model.Message) type ReplyCallback func(model.Message)
type Chunk struct {
Content string
}
type ChatCompletionClient interface { type ChatCompletionClient interface {
// CreateChatCompletion requests a response to the provided messages. // CreateChatCompletion requests a response to the provided messages.
// Replies are appended to the given replies struct, and the // Replies are appended to the given replies struct, and the
@ -26,6 +30,6 @@ type ChatCompletionClient interface {
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callback ReplyCallback, callback ReplyCallback,
output chan<- string, output chan<- Chunk,
) (string, error) ) (string, error)
} }

View File

@ -10,8 +10,8 @@ import (
"net/http" "net/http"
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools" "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
) )
@ -107,7 +107,7 @@ func (c *AnthropicClient) CreateChatCompletion(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callback provider.ReplyCallback, callback api.ReplyCallback,
) (string, error) { ) (string, error) {
if len(messages) == 0 { if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages") return "", fmt.Errorf("Can't create completion from no messages")
@ -160,8 +160,8 @@ func (c *AnthropicClient) CreateChatCompletionStream(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callback provider.ReplyCallback, callback api.ReplyCallback,
output chan<- string, output chan<- api.Chunk,
) (string, error) { ) (string, error) {
if len(messages) == 0 { if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages") return "", fmt.Errorf("Can't create completion from no messages")
@ -242,7 +242,9 @@ func (c *AnthropicClient) CreateChatCompletionStream(
return "", fmt.Errorf("invalid text delta") return "", fmt.Errorf("invalid text delta")
} }
sb.WriteString(text) sb.WriteString(text)
output <- text output <- api.Chunk{
Content: text,
}
case "content_block_stop": case "content_block_stop":
// ignore? // ignore?
case "message_delta": case "message_delta":
@ -262,7 +264,9 @@ func (c *AnthropicClient) CreateChatCompletionStream(
} }
sb.WriteString(FUNCTION_STOP_SEQUENCE) sb.WriteString(FUNCTION_STOP_SEQUENCE)
output <- FUNCTION_STOP_SEQUENCE output <- api.Chunk{
Content: FUNCTION_STOP_SEQUENCE,
}
funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE

View File

@ -10,8 +10,8 @@ import (
"net/http" "net/http"
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools" "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
) )
@ -187,7 +187,7 @@ func handleToolCalls(
params model.RequestParameters, params model.RequestParameters,
content string, content string,
toolCalls []model.ToolCall, toolCalls []model.ToolCall,
callback provider.ReplyCallback, callback api.ReplyCallback,
messages []model.Message, messages []model.Message,
) ([]model.Message, error) { ) ([]model.Message, error) {
lastMessage := messages[len(messages)-1] lastMessage := messages[len(messages)-1]
@ -245,7 +245,7 @@ func (c *Client) CreateChatCompletion(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callback provider.ReplyCallback, callback api.ReplyCallback,
) (string, error) { ) (string, error) {
if len(messages) == 0 { if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages") return "", fmt.Errorf("Can't create completion from no messages")
@ -325,8 +325,8 @@ func (c *Client) CreateChatCompletionStream(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callback provider.ReplyCallback, callback api.ReplyCallback,
output chan<- string, output chan<- api.Chunk,
) (string, error) { ) (string, error) {
if len(messages) == 0 { if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages") return "", fmt.Errorf("Can't create completion from no messages")
@ -393,7 +393,9 @@ func (c *Client) CreateChatCompletionStream(
if part.FunctionCall != nil { if part.FunctionCall != nil {
toolCalls = append(toolCalls, *part.FunctionCall) toolCalls = append(toolCalls, *part.FunctionCall)
} else if part.Text != "" { } else if part.Text != "" {
output <- part.Text output <- api.Chunk {
Content: part.Text,
}
content.WriteString(part.Text) content.WriteString(part.Text)
} }
} }

View File

@ -10,8 +10,8 @@ import (
"net/http" "net/http"
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
) )
type OllamaClient struct { type OllamaClient struct {
@ -85,7 +85,7 @@ func (c *OllamaClient) CreateChatCompletion(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callback provider.ReplyCallback, callback api.ReplyCallback,
) (string, error) { ) (string, error) {
if len(messages) == 0 { if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages") return "", fmt.Errorf("Can't create completion from no messages")
@ -117,9 +117,6 @@ func (c *OllamaClient) CreateChatCompletion(
} }
content := completionResp.Message.Content content := completionResp.Message.Content
fmt.Println(content)
if callback != nil { if callback != nil {
callback(model.Message{ callback(model.Message{
Role: model.MessageRoleAssistant, Role: model.MessageRoleAssistant,
@ -134,8 +131,8 @@ func (c *OllamaClient) CreateChatCompletionStream(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callback provider.ReplyCallback, callback api.ReplyCallback,
output chan<- string, output chan<- api.Chunk,
) (string, error) { ) (string, error) {
if len(messages) == 0 { if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages") return "", fmt.Errorf("Can't create completion from no messages")
@ -184,7 +181,9 @@ func (c *OllamaClient) CreateChatCompletionStream(
} }
if len(streamResp.Message.Content) > 0 { if len(streamResp.Message.Content) > 0 {
output <- streamResp.Message.Content output <- api.Chunk{
Content: streamResp.Message.Content,
}
content.WriteString(streamResp.Message.Content) content.WriteString(streamResp.Message.Content)
} }
} }

View File

@ -10,8 +10,8 @@ import (
"net/http" "net/http"
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools" "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
) )
@ -121,7 +121,7 @@ func handleToolCalls(
params model.RequestParameters, params model.RequestParameters,
content string, content string,
toolCalls []ToolCall, toolCalls []ToolCall,
callback provider.ReplyCallback, callback api.ReplyCallback,
messages []model.Message, messages []model.Message,
) ([]model.Message, error) { ) ([]model.Message, error) {
lastMessage := messages[len(messages)-1] lastMessage := messages[len(messages)-1]
@ -180,7 +180,7 @@ func (c *OpenAIClient) CreateChatCompletion(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callback provider.ReplyCallback, callback api.ReplyCallback,
) (string, error) { ) (string, error) {
if len(messages) == 0 { if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages") return "", fmt.Errorf("Can't create completion from no messages")
@ -244,8 +244,8 @@ func (c *OpenAIClient) CreateChatCompletionStream(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callback provider.ReplyCallback, callback api.ReplyCallback,
output chan<- string, output chan<- api.Chunk,
) (string, error) { ) (string, error) {
if len(messages) == 0 { if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages") return "", fmt.Errorf("Can't create completion from no messages")
@ -319,7 +319,9 @@ func (c *OpenAIClient) CreateChatCompletionStream(
} }
} }
if len(delta.Content) > 0 { if len(delta.Content) > 0 {
output <- delta.Content output <- api.Chunk {
Content: delta.Content,
}
content.WriteString(delta.Content) content.WriteString(delta.Content)
} }
} }

View File

@ -31,9 +31,8 @@ func ReplyCmd(ctx *lmcli.Context) *cobra.Command {
} }
cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{ cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{
ConversationID: conversation.ID, Role: model.MessageRoleUser,
Role: model.MessageRoleUser, Content: reply,
Content: reply,
}) })
return nil return nil
}, },

View File

@ -8,6 +8,7 @@ import (
"strings" "strings"
"time" "time"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/util" "git.mlow.ca/mlow/lmcli/pkg/util"
@ -17,7 +18,7 @@ import (
// Prompt prompts the configured the configured model and streams the response // Prompt prompts the configured the configured model and streams the response
// to stdout. Returns all model reply messages. // to stdout. Returns all model reply messages.
func Prompt(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) { func Prompt(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
content := make(chan string) // receives the reponse from LLM content := make(chan api.Chunk) // receives the reponse from LLM
defer close(content) defer close(content)
// render all content received over the channel // render all content received over the channel
@ -251,7 +252,7 @@ func ShowWaitAnimation(signal chan any) {
// chunked) content is received on the channel, the waiting animation is // chunked) content is received on the channel, the waiting animation is
// replaced by the content. // replaced by the content.
// Blocks until the channel is closed. // Blocks until the channel is closed.
func ShowDelayedContent(content <-chan string) { func ShowDelayedContent(content <-chan api.Chunk) {
waitSignal := make(chan any) waitSignal := make(chan any)
go ShowWaitAnimation(waitSignal) go ShowWaitAnimation(waitSignal)
@ -264,7 +265,7 @@ func ShowDelayedContent(content <-chan string) {
<-waitSignal <-waitSignal
firstChunk = false firstChunk = false
} }
fmt.Print(chunk) fmt.Print(chunk.Content)
} }
} }

View File

@ -6,12 +6,12 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/api/provider/anthropic"
"git.mlow.ca/mlow/lmcli/pkg/api/provider/google"
"git.mlow.ca/mlow/lmcli/pkg/api/provider/ollama"
"git.mlow.ca/mlow/lmcli/pkg/api/provider/openai"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/google"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/ollama"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools" "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
"git.mlow.ca/mlow/lmcli/pkg/util" "git.mlow.ca/mlow/lmcli/pkg/util"
"git.mlow.ca/mlow/lmcli/pkg/util/tty" "git.mlow.ca/mlow/lmcli/pkg/util/tty"
@ -79,7 +79,7 @@ func (c *Context) GetModels() (models []string) {
return return
} }
func (c *Context) GetModelProvider(model string) (string, provider.ChatCompletionClient, error) { func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionClient, error) {
parts := strings.Split(model, "/") parts := strings.Split(model, "/")
var provider string var provider string

View File

@ -16,17 +16,17 @@ const (
) )
type Message struct { type Message struct {
ID uint `gorm:"primaryKey"` ID uint `gorm:"primaryKey"`
ConversationID uint `gorm:"index"` ConversationID *uint `gorm:"index"`
Conversation Conversation `gorm:"foreignKey:ConversationID"` Conversation *Conversation `gorm:"foreignKey:ConversationID"`
Content string Content string
Role MessageRole Role MessageRole
CreatedAt time.Time CreatedAt time.Time
ToolCalls ToolCalls // a json array of tool calls (from the model) ToolCalls ToolCalls // a json array of tool calls (from the model)
ToolResults ToolResults // a json array of tool results ToolResults ToolResults // a json array of tool results
ParentID *uint ParentID *uint
Parent *Message `gorm:"foreignKey:ParentID"` Parent *Message `gorm:"foreignKey:ParentID"`
Replies []Message `gorm:"foreignKey:ParentID"` Replies []Message `gorm:"foreignKey:ParentID"`
SelectedReplyID *uint SelectedReplyID *uint
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"` SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
@ -37,7 +37,7 @@ type Conversation struct {
ShortName sql.NullString ShortName sql.NullString
Title string Title string
SelectedRootID *uint SelectedRootID *uint
SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"` SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"`
} }
type RequestParameters struct { type RequestParameters struct {

View File

@ -58,24 +58,28 @@ func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
return &SQLStore{db, _sqids}, nil return &SQLStore{db, _sqids}, nil
} }
func (s *SQLStore) saveNewConversation(c *model.Conversation) error { func (s *SQLStore) createConversation() (*model.Conversation, error) {
// Save the new conversation // Create the new conversation
err := s.db.Save(&c).Error c := &model.Conversation{}
err := s.db.Save(c).Error
if err != nil { if err != nil {
return err return nil, err
} }
// Generate and save its "short name" // Generate and save its "short name"
shortName, _ := s.sqids.Encode([]uint64{uint64(c.ID)}) shortName, _ := s.sqids.Encode([]uint64{uint64(c.ID)})
c.ShortName = sql.NullString{String: shortName, Valid: true} c.ShortName = sql.NullString{String: shortName, Valid: true}
return s.UpdateConversation(c) err = s.db.Updates(c).Error
if err != nil {
return nil, err
}
return c, nil
} }
func (s *SQLStore) UpdateConversation(c *model.Conversation) error { func (s *SQLStore) UpdateConversation(c *model.Conversation) error {
if c == nil || c.ID == 0 { if c == nil || c.ID == 0 {
return fmt.Errorf("Conversation is nil or invalid (missing ID)") return fmt.Errorf("Conversation is nil or invalid (missing ID)")
} }
return s.db.Updates(&c).Error return s.db.Updates(c).Error
} }
func (s *SQLStore) DeleteConversation(c *model.Conversation) error { func (s *SQLStore) DeleteConversation(c *model.Conversation) error {
@ -84,7 +88,7 @@ func (s *SQLStore) DeleteConversation(c *model.Conversation) error {
if err != nil { if err != nil {
return err return err
} }
return s.db.Delete(&c).Error return s.db.Delete(c).Error
} }
func (s *SQLStore) DeleteMessage(message *model.Message, prune bool) error { func (s *SQLStore) DeleteMessage(message *model.Message, prune bool) error {
@ -96,7 +100,7 @@ func (s *SQLStore) UpdateMessage(m *model.Message) error {
if m == nil || m.ID == 0 { if m == nil || m.ID == 0 {
return fmt.Errorf("Message is nil or invalid (missing ID)") return fmt.Errorf("Message is nil or invalid (missing ID)")
} }
return s.db.Updates(&m).Error return s.db.Updates(m).Error
} }
func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string { func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
@ -149,14 +153,13 @@ func (s *SQLStore) StartConversation(messages ...model.Message) (*model.Conversa
} }
// Create new conversation // Create new conversation
conversation := &model.Conversation{} conversation, err := s.createConversation()
err := s.saveNewConversation(conversation)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
// Create first message // Create first message
messages[0].ConversationID = conversation.ID messages[0].Conversation = conversation
err = s.db.Create(&messages[0]).Error err = s.db.Create(&messages[0]).Error
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -187,19 +190,18 @@ func (s *SQLStore) CloneConversation(toClone model.Conversation) (*model.Convers
return nil, 0, err return nil, 0, err
} }
clone := &model.Conversation{ clone, err := s.createConversation()
Title: toClone.Title + " - Clone", if err != nil {
}
if err := s.saveNewConversation(clone); err != nil {
return nil, 0, fmt.Errorf("Could not create clone: %s", err) return nil, 0, fmt.Errorf("Could not create clone: %s", err)
} }
clone.Title = toClone.Title + " - Clone"
var errors []error var errors []error
var messageCnt uint = 0 var messageCnt uint = 0
for _, root := range rootMessages { for _, root := range rootMessages {
messageCnt++ messageCnt++
newRoot := root newRoot := root
newRoot.ConversationID = clone.ID newRoot.ConversationID = &clone.ID
cloned, count, err := s.CloneBranch(newRoot) cloned, count, err := s.CloneBranch(newRoot)
if err != nil { if err != nil {
@ -230,9 +232,10 @@ func (s *SQLStore) Reply(to *model.Message, messages ...model.Message) ([]model.
err := s.db.Transaction(func(tx *gorm.DB) error { err := s.db.Transaction(func(tx *gorm.DB) error {
currentParent := to currentParent := to
for i := range messages { for i := range messages {
parent := currentParent
message := messages[i] message := messages[i]
message.ConversationID = currentParent.ConversationID message.Parent = parent
message.ParentID = &currentParent.ID message.Conversation = parent.Conversation
message.ID = 0 message.ID = 0
message.CreatedAt = time.Time{} message.CreatedAt = time.Time{}
@ -241,9 +244,9 @@ func (s *SQLStore) Reply(to *model.Message, messages ...model.Message) ([]model.
} }
// update parent selected reply // update parent selected reply
currentParent.Replies = append(currentParent.Replies, message) parent.Replies = append(parent.Replies, message)
currentParent.SelectedReply = &message parent.SelectedReply = &message
if err := tx.Model(currentParent).Update("selected_reply_id", message.ID).Error; err != nil { if err := tx.Model(parent).Update("selected_reply_id", message.ID).Error; err != nil {
return err return err
} }

View File

@ -9,7 +9,7 @@ type Values struct {
ConvShortname string ConvShortname string
} }
type State struct { type Shared struct {
Ctx *lmcli.Context Ctx *lmcli.Context
Values *Values Values *Values
Width int Width int

View File

@ -18,7 +18,7 @@ import (
// Application model // Application model
type Model struct { type Model struct {
shared.State shared.Shared
state shared.View state shared.View
chat chat.Model chat chat.Model
@ -27,15 +27,15 @@ type Model struct {
func initialModel(ctx *lmcli.Context, values shared.Values) Model { func initialModel(ctx *lmcli.Context, values shared.Values) Model {
m := Model{ m := Model{
State: shared.State{ Shared: shared.Shared{
Ctx: ctx, Ctx: ctx,
Values: &values, Values: &values,
}, },
} }
m.state = shared.StateChat m.state = shared.StateChat
m.chat = chat.Chat(m.State) m.chat = chat.Chat(m.Shared)
m.conversations = conversations.Conversations(m.State) m.conversations = conversations.Conversations(m.Shared)
return m return m
} }

View File

@ -3,6 +3,7 @@ package chat
import ( import (
"time" "time"
"git.mlow.ca/mlow/lmcli/pkg/api"
models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared" "git.mlow.ca/mlow/lmcli/pkg/tui/shared"
"github.com/charmbracelet/bubbles/cursor" "github.com/charmbracelet/bubbles/cursor"
@ -13,24 +14,10 @@ import (
"github.com/charmbracelet/lipgloss" "github.com/charmbracelet/lipgloss"
) )
type focusState int
const (
focusInput focusState = iota
focusMessages
)
type editorTarget int
const (
input editorTarget = iota
selectedMessage
)
// custom tea.Msg types // custom tea.Msg types
type ( type (
// sent on each chunk received from LLM // sent on each chunk received from LLM
msgResponseChunk string msgResponseChunk api.Chunk
// sent when response is finished being received // sent when response is finished being received
msgResponseEnd string msgResponseEnd string
// a special case of common.MsgError that stops the response waiting animation // a special case of common.MsgError that stops the response waiting animation
@ -48,6 +35,7 @@ type (
msgMessagesLoaded []models.Message msgMessagesLoaded []models.Message
// sent when the conversation has been persisted, triggers a reload of contents // sent when the conversation has been persisted, triggers a reload of contents
msgConversationPersisted struct { msgConversationPersisted struct {
isNew bool
conversation *models.Conversation conversation *models.Conversation
messages []models.Message messages []models.Message
} }
@ -61,26 +49,47 @@ type (
msgMessageCloned *models.Message msgMessageCloned *models.Message
) )
type focusState int
const (
focusInput focusState = iota
focusMessages
)
type editorTarget int
const (
input editorTarget = iota
selectedMessage
)
type state int
const (
idle state = iota
loading
pendingResponse
)
type Model struct { type Model struct {
shared.State shared.Shared
shared.Sections shared.Sections
// app state // app state
state state // current overall status of the view
conversation *models.Conversation conversation *models.Conversation
rootMessages []models.Message rootMessages []models.Message
messages []models.Message messages []models.Message
selectedMessage int selectedMessage int
waitingForReply bool
editorTarget editorTarget editorTarget editorTarget
stopSignal chan struct{} stopSignal chan struct{}
replyChan chan models.Message replyChan chan models.Message
replyChunkChan chan string replyChunkChan chan api.Chunk
persistence bool // whether we will save new messages in the conversation persistence bool // whether we will save new messages in the conversation
// ui state // ui state
focus focusState focus focusState
wrap bool // whether message content is wrapped to viewport width wrap bool // whether message content is wrapped to viewport width
status string // a general status message
showToolResults bool // whether tool calls and results are shown showToolResults bool // whether tool calls and results are shown
messageCache []string // cache of syntax highlighted and wrapped message content messageCache []string // cache of syntax highlighted and wrapped message content
messageOffsets []int messageOffsets []int
@ -97,16 +106,17 @@ type Model struct {
elapsed time.Duration elapsed time.Duration
} }
func Chat(state shared.State) Model { func Chat(shared shared.Shared) Model {
m := Model{ m := Model{
State: state, Shared: shared,
state: idle,
conversation: &models.Conversation{}, conversation: &models.Conversation{},
persistence: true, persistence: true,
stopSignal: make(chan struct{}), stopSignal: make(chan struct{}),
replyChan: make(chan models.Message), replyChan: make(chan models.Message),
replyChunkChan: make(chan string), replyChunkChan: make(chan api.Chunk),
wrap: true, wrap: true,
selectedMessage: -1, selectedMessage: -1,
@ -132,7 +142,7 @@ func Chat(state shared.State) Model {
m.replyCursor.SetChar(" ") m.replyCursor.SetChar(" ")
m.replyCursor.Focus() m.replyCursor.Focus()
system := state.Ctx.GetSystemPrompt() system := shared.Ctx.GetSystemPrompt()
if system != "" { if system != "" {
m.messages = []models.Message{{ m.messages = []models.Message{{
Role: models.MessageRoleSystem, Role: models.MessageRoleSystem,
@ -150,8 +160,6 @@ func Chat(state shared.State) Model {
m.input.FocusedStyle.Base = inputFocusedStyle m.input.FocusedStyle.Base = inputFocusedStyle
m.input.BlurredStyle.Base = inputBlurredStyle m.input.BlurredStyle.Base = inputBlurredStyle
m.waitingForReply = false
m.status = "Press ctrl+s to send"
return m return m
} }

View File

@ -53,14 +53,14 @@ func (m *Model) loadConversation(shortname string) tea.Cmd {
if shortname == "" { if shortname == "" {
return nil return nil
} }
c, err := m.State.Ctx.Store.ConversationByShortName(shortname) c, err := m.Shared.Ctx.Store.ConversationByShortName(shortname)
if err != nil { if err != nil {
return shared.MsgError(fmt.Errorf("Could not lookup conversation: %v", err)) return shared.MsgError(fmt.Errorf("Could not lookup conversation: %v", err))
} }
if c.ID == 0 { if c.ID == 0 {
return shared.MsgError(fmt.Errorf("Conversation not found: %s", shortname)) return shared.MsgError(fmt.Errorf("Conversation not found: %s", shortname))
} }
rootMessages, err := m.State.Ctx.Store.RootMessages(c.ID) rootMessages, err := m.Shared.Ctx.Store.RootMessages(c.ID)
if err != nil { if err != nil {
return shared.MsgError(fmt.Errorf("Could not load conversation root messages: %v\n", err)) return shared.MsgError(fmt.Errorf("Could not load conversation root messages: %v\n", err))
} }
@ -70,7 +70,7 @@ func (m *Model) loadConversation(shortname string) tea.Cmd {
func (m *Model) loadConversationMessages() tea.Cmd { func (m *Model) loadConversationMessages() tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
messages, err := m.State.Ctx.Store.PathToLeaf(m.conversation.SelectedRoot) messages, err := m.Shared.Ctx.Store.PathToLeaf(m.conversation.SelectedRoot)
if err != nil { if err != nil {
return shared.MsgError(fmt.Errorf("Could not load conversation messages: %v\n", err)) return shared.MsgError(fmt.Errorf("Could not load conversation messages: %v\n", err))
} }
@ -80,7 +80,7 @@ func (m *Model) loadConversationMessages() tea.Cmd {
func (m *Model) generateConversationTitle() tea.Cmd { func (m *Model) generateConversationTitle() tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
title, err := cmdutil.GenerateTitle(m.State.Ctx, m.messages) title, err := cmdutil.GenerateTitle(m.Shared.Ctx, m.messages)
if err != nil { if err != nil {
return shared.MsgError(err) return shared.MsgError(err)
} }
@ -90,7 +90,7 @@ func (m *Model) generateConversationTitle() tea.Cmd {
func (m *Model) updateConversationTitle(conversation *models.Conversation) tea.Cmd { func (m *Model) updateConversationTitle(conversation *models.Conversation) tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
err := m.State.Ctx.Store.UpdateConversation(conversation) err := m.Shared.Ctx.Store.UpdateConversation(conversation)
if err != nil { if err != nil {
return shared.WrapError(err) return shared.WrapError(err)
} }
@ -110,10 +110,10 @@ func (m *Model) cloneMessage(message models.Message, selected bool) tea.Cmd {
if selected { if selected {
if msg.Parent == nil { if msg.Parent == nil {
msg.Conversation.SelectedRoot = msg msg.Conversation.SelectedRoot = msg
err = m.State.Ctx.Store.UpdateConversation(&msg.Conversation) err = m.Shared.Ctx.Store.UpdateConversation(msg.Conversation)
} else { } else {
msg.Parent.SelectedReply = msg msg.Parent.SelectedReply = msg
err = m.State.Ctx.Store.UpdateMessage(msg.Parent) err = m.Shared.Ctx.Store.UpdateMessage(msg.Parent)
} }
if err != nil { if err != nil {
return shared.WrapError(fmt.Errorf("Could not update selected message: %v", err)) return shared.WrapError(fmt.Errorf("Could not update selected message: %v", err))
@ -125,7 +125,7 @@ func (m *Model) cloneMessage(message models.Message, selected bool) tea.Cmd {
func (m *Model) updateMessageContent(message *models.Message) tea.Cmd { func (m *Model) updateMessageContent(message *models.Message) tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
err := m.State.Ctx.Store.UpdateMessage(message) err := m.Shared.Ctx.Store.UpdateMessage(message)
if err != nil { if err != nil {
return shared.WrapError(fmt.Errorf("Could not update message: %v", err)) return shared.WrapError(fmt.Errorf("Could not update message: %v", err))
} }
@ -170,7 +170,7 @@ func (m *Model) cycleSelectedRoot(conv *models.Conversation, dir MessageCycleDir
} }
conv.SelectedRoot = nextRoot conv.SelectedRoot = nextRoot
err = m.State.Ctx.Store.UpdateConversation(conv) err = m.Shared.Ctx.Store.UpdateConversation(conv)
if err != nil { if err != nil {
return shared.WrapError(fmt.Errorf("Could not update conversation SelectedRoot: %v", err)) return shared.WrapError(fmt.Errorf("Could not update conversation SelectedRoot: %v", err))
} }
@ -190,7 +190,7 @@ func (m *Model) cycleSelectedReply(message *models.Message, dir MessageCycleDire
} }
message.SelectedReply = nextReply message.SelectedReply = nextReply
err = m.State.Ctx.Store.UpdateMessage(message) err = m.Shared.Ctx.Store.UpdateMessage(message)
if err != nil { if err != nil {
return shared.WrapError(fmt.Errorf("Could not update message SelectedReply: %v", err)) return shared.WrapError(fmt.Errorf("Could not update message SelectedReply: %v", err))
} }
@ -203,14 +203,14 @@ func (m *Model) persistConversation() tea.Cmd {
messages := m.messages messages := m.messages
var err error var err error
if m.conversation.ID == 0 { if conversation.ID == 0 {
return func() tea.Msg { return func() tea.Msg {
// Start a new conversation with all messages so far // Start a new conversation with all messages so far
conversation, messages, err = m.State.Ctx.Store.StartConversation(messages...) conversation, messages, err = m.Shared.Ctx.Store.StartConversation(messages...)
if err != nil { if err != nil {
return shared.MsgError(fmt.Errorf("Could not start new conversation: %v", err)) return shared.MsgError(fmt.Errorf("Could not start new conversation: %v", err))
} }
return msgConversationPersisted{conversation, messages} return msgConversationPersisted{true, conversation, messages}
} }
} }
@ -219,7 +219,7 @@ func (m *Model) persistConversation() tea.Cmd {
for i := range messages { for i := range messages {
if messages[i].ID > 0 { if messages[i].ID > 0 {
// message has an ID, update its contents // message has an ID, update its contents
err := m.State.Ctx.Store.UpdateMessage(&messages[i]) err := m.Shared.Ctx.Store.UpdateMessage(&messages[i])
if err != nil { if err != nil {
return shared.MsgError(err) return shared.MsgError(err)
} }
@ -228,7 +228,7 @@ func (m *Model) persistConversation() tea.Cmd {
continue continue
} }
// messages is new, so add it as a reply to previous message // messages is new, so add it as a reply to previous message
saved, err := m.State.Ctx.Store.Reply(&messages[i-1], messages[i]) saved, err := m.Shared.Ctx.Store.Reply(&messages[i-1], messages[i])
if err != nil { if err != nil {
return shared.MsgError(err) return shared.MsgError(err)
} }
@ -239,30 +239,29 @@ func (m *Model) persistConversation() tea.Cmd {
return fmt.Errorf("Error: no messages to reply to") return fmt.Errorf("Error: no messages to reply to")
} }
} }
return msgConversationPersisted{conversation, messages} return msgConversationPersisted{false, conversation, messages}
} }
} }
func (m *Model) promptLLM() tea.Cmd { func (m *Model) promptLLM() tea.Cmd {
m.waitingForReply = true m.state = pendingResponse
m.replyCursor.Blink = false m.replyCursor.Blink = false
m.status = "Press ctrl+c to cancel"
m.tokenCount = 0 m.tokenCount = 0
m.startTime = time.Now() m.startTime = time.Now()
m.elapsed = 0 m.elapsed = 0
return func() tea.Msg { return func() tea.Msg {
model, provider, err := m.State.Ctx.GetModelProvider(*m.State.Ctx.Config.Defaults.Model) model, provider, err := m.Shared.Ctx.GetModelProvider(*m.Shared.Ctx.Config.Defaults.Model)
if err != nil { if err != nil {
return shared.MsgError(err) return shared.MsgError(err)
} }
requestParams := models.RequestParameters{ requestParams := models.RequestParameters{
Model: model, Model: model,
MaxTokens: *m.State.Ctx.Config.Defaults.MaxTokens, MaxTokens: *m.Shared.Ctx.Config.Defaults.MaxTokens,
Temperature: *m.State.Ctx.Config.Defaults.Temperature, Temperature: *m.Shared.Ctx.Config.Defaults.Temperature,
ToolBag: m.State.Ctx.EnabledTools, ToolBag: m.Shared.Ctx.EnabledTools,
} }
replyHandler := func(msg models.Message) { replyHandler := func(msg models.Message) {

View File

@ -33,7 +33,7 @@ func (m *Model) HandleInput(msg tea.KeyMsg) (bool, tea.Cmd) {
switch msg.String() { switch msg.String() {
case "esc": case "esc":
if m.waitingForReply { if m.state == pendingResponse {
m.stopSignal <- struct{}{} m.stopSignal <- struct{}{}
return true, nil return true, nil
} }
@ -41,7 +41,7 @@ func (m *Model) HandleInput(msg tea.KeyMsg) (bool, tea.Cmd) {
return shared.MsgViewChange(shared.StateConversations) return shared.MsgViewChange(shared.StateConversations)
} }
case "ctrl+c": case "ctrl+c":
if m.waitingForReply { if m.state == pendingResponse {
m.stopSignal <- struct{}{} m.stopSignal <- struct{}{}
return true, nil return true, nil
} }
@ -112,15 +112,14 @@ func (m *Model) handleMessagesKey(msg tea.KeyMsg) (bool, tea.Cmd) {
return cmd != nil, cmd return cmd != nil, cmd
case "ctrl+r": case "ctrl+r":
// resubmit the conversation with all messages up until and including the selected message // resubmit the conversation with all messages up until and including the selected message
if m.waitingForReply || len(m.messages) == 0 { if m.state == idle && m.selectedMessage < len(m.messages) {
return true, nil m.messages = m.messages[:m.selectedMessage+1]
m.messageCache = m.messageCache[:m.selectedMessage+1]
cmd := m.promptLLM()
m.updateContent()
m.content.GotoBottom()
return true, cmd
} }
m.messages = m.messages[:m.selectedMessage+1]
m.messageCache = m.messageCache[:m.selectedMessage+1]
cmd := m.promptLLM()
m.updateContent()
m.content.GotoBottom()
return true, cmd
} }
return false, nil return false, nil
} }
@ -141,8 +140,8 @@ func (m *Model) handleInputKey(msg tea.KeyMsg) (bool, tea.Cmd) {
m.input.Blur() m.input.Blur()
return true, nil return true, nil
case "ctrl+s": case "ctrl+s":
// TODO: call a "handleSend" function with returns a tea.Cmd // TODO: call a "handleSend" function which returns a tea.Cmd
if m.waitingForReply { if m.state != idle {
return false, nil return false, nil
} }

View File

@ -42,11 +42,11 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
// wake up spinners and cursors // wake up spinners and cursors
cmds = append(cmds, cursor.Blink, m.spinner.Tick) cmds = append(cmds, cursor.Blink, m.spinner.Tick)
if m.State.Values.ConvShortname != "" { if m.Shared.Values.ConvShortname != "" {
// (re)load conversation contents // (re)load conversation contents
cmds = append(cmds, m.loadConversation(m.State.Values.ConvShortname)) cmds = append(cmds, m.loadConversation(m.Shared.Values.ConvShortname))
if m.conversation.ShortName.String != m.State.Values.ConvShortname { if m.conversation.ShortName.String != m.Shared.Values.ConvShortname {
// clear existing messages if we're loading a new conversation // clear existing messages if we're loading a new conversation
m.messages = []models.Message{} m.messages = []models.Message{}
m.selectedMessage = 0 m.selectedMessage = 0
@ -90,20 +90,19 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
case msgResponseChunk: case msgResponseChunk:
cmds = append(cmds, m.waitForResponseChunk()) // wait for the next chunk cmds = append(cmds, m.waitForResponseChunk()) // wait for the next chunk
chunk := string(msg) if msg.Content == "" {
if chunk == "" {
break break
} }
last := len(m.messages) - 1 last := len(m.messages) - 1
if last >= 0 && m.messages[last].Role.IsAssistant() { if last >= 0 && m.messages[last].Role.IsAssistant() {
// append chunk to existing message // append chunk to existing message
m.setMessageContents(last, m.messages[last].Content+chunk) m.setMessageContents(last, m.messages[last].Content+msg.Content)
} else { } else {
// use chunk in new message // use chunk in new message
m.addMessage(models.Message{ m.addMessage(models.Message{
Role: models.MessageRoleAssistant, Role: models.MessageRoleAssistant,
Content: chunk, Content: msg.Content,
}) })
} }
m.updateContent() m.updateContent()
@ -142,18 +141,16 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
m.updateContent() m.updateContent()
case msgResponseEnd: case msgResponseEnd:
m.waitingForReply = false m.state = idle
last := len(m.messages) - 1 last := len(m.messages) - 1
if last < 0 { if last < 0 {
panic("Unexpected empty messages handling msgResponseEnd") panic("Unexpected empty messages handling msgResponseEnd")
} }
m.setMessageContents(last, strings.TrimSpace(m.messages[last].Content)) m.setMessageContents(last, strings.TrimSpace(m.messages[last].Content))
m.updateContent() m.updateContent()
m.status = "Press ctrl+s to send"
case msgResponseError: case msgResponseError:
m.waitingForReply = false m.state = idle
m.status = "Press ctrl+s to send" m.Shared.Err = error(msg)
m.State.Err = error(msg)
m.updateContent() m.updateContent()
case msgConversationTitleGenerated: case msgConversationTitleGenerated:
title := string(msg) title := string(msg)
@ -162,18 +159,21 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
cmds = append(cmds, m.updateConversationTitle(m.conversation)) cmds = append(cmds, m.updateConversationTitle(m.conversation))
} }
case cursor.BlinkMsg: case cursor.BlinkMsg:
if m.waitingForReply { if m.state == pendingResponse {
// ensure we show the updated "wait for response" cursor blink state // ensure we show the updated "wait for response" cursor blink state
m.updateContent() m.updateContent()
} }
case msgConversationPersisted: case msgConversationPersisted:
m.conversation = msg.conversation m.conversation = msg.conversation
m.messages = msg.messages m.messages = msg.messages
if msg.isNew {
m.rootMessages = []models.Message{m.messages[0]}
}
m.rebuildMessageCache() m.rebuildMessageCache()
m.updateContent() m.updateContent()
case msgMessageCloned: case msgMessageCloned:
if msg.Parent == nil { if msg.Parent == nil {
m.conversation = &msg.Conversation m.conversation = msg.Conversation
m.rootMessages = append(m.rootMessages, *msg) m.rootMessages = append(m.rootMessages, *msg)
} }
cmds = append(cmds, m.loadConversationMessages()) cmds = append(cmds, m.loadConversationMessages())

View File

@ -131,7 +131,7 @@ func (m *Model) renderMessage(i int) string {
sb := &strings.Builder{} sb := &strings.Builder{}
sb.Grow(len(msg.Content) * 2) sb.Grow(len(msg.Content) * 2)
if msg.Content != "" { if msg.Content != "" {
err := m.State.Ctx.Chroma.Highlight(sb, msg.Content) err := m.Shared.Ctx.Chroma.Highlight(sb, msg.Content)
if err != nil { if err != nil {
sb.Reset() sb.Reset()
sb.WriteString(msg.Content) sb.WriteString(msg.Content)
@ -139,7 +139,7 @@ func (m *Model) renderMessage(i int) string {
} }
// Show the assistant's cursor // Show the assistant's cursor
if m.waitingForReply && i == len(m.messages)-1 && msg.Role == models.MessageRoleAssistant { if m.state == pendingResponse && i == len(m.messages)-1 && msg.Role == models.MessageRoleAssistant {
sb.WriteString(m.replyCursor.View()) sb.WriteString(m.replyCursor.View())
} }
@ -195,7 +195,7 @@ func (m *Model) renderMessage(i int) string {
if msg.Content != "" { if msg.Content != "" {
sb.WriteString("\n\n") sb.WriteString("\n\n")
} }
_ = m.State.Ctx.Chroma.HighlightLang(sb, toolString, "yaml") _ = m.Shared.Ctx.Chroma.HighlightLang(sb, toolString, "yaml")
} }
content := strings.TrimRight(sb.String(), "\n") content := strings.TrimRight(sb.String(), "\n")
@ -237,7 +237,7 @@ func (m *Model) conversationMessagesView() string {
lineCnt += lipgloss.Height(heading) lineCnt += lipgloss.Height(heading)
var rendered string var rendered string
if m.waitingForReply && i == len(m.messages)-1 { if m.state == pendingResponse && i == len(m.messages)-1 {
// do a direct render of final (assistant) message to handle the // do a direct render of final (assistant) message to handle the
// assistant cursor blink // assistant cursor blink
rendered = m.renderMessage(i) rendered = m.renderMessage(i)
@ -251,7 +251,7 @@ func (m *Model) conversationMessagesView() string {
} }
// Render a placeholder for the incoming assistant reply // Render a placeholder for the incoming assistant reply
if m.waitingForReply && (len(m.messages) == 0 || m.messages[len(m.messages)-1].Role != models.MessageRoleAssistant) { if m.state == pendingResponse && (len(m.messages) == 0 || m.messages[len(m.messages)-1].Role != models.MessageRoleAssistant) {
heading := m.renderMessageHeading(-1, &models.Message{ heading := m.renderMessageHeading(-1, &models.Message{
Role: models.MessageRoleAssistant, Role: models.MessageRoleAssistant,
}) })
@ -289,9 +289,12 @@ func (m *Model) footerView() string {
saving = savingStyle.Foreground(lipgloss.Color("1")).Render("❌💾") saving = savingStyle.Foreground(lipgloss.Color("1")).Render("❌💾")
} }
status := m.status var status string
if m.waitingForReply { switch m.state {
status += m.spinner.View() case pendingResponse:
status = "Press ctrl+c to cancel" + m.spinner.View()
default:
status = "Press ctrl+s to send"
} }
leftSegments := []string{ leftSegments := []string{
@ -305,7 +308,7 @@ func (m *Model) footerView() string {
rightSegments = append(rightSegments, segmentStyle.Render(throughput)) rightSegments = append(rightSegments, segmentStyle.Render(throughput))
} }
model := fmt.Sprintf("Model: %s", *m.State.Ctx.Config.Defaults.Model) model := fmt.Sprintf("Model: %s", *m.Shared.Ctx.Config.Defaults.Model)
rightSegments = append(rightSegments, segmentStyle.Render(model)) rightSegments = append(rightSegments, segmentStyle.Render(model))
left := strings.Join(leftSegments, segmentSeparator) left := strings.Join(leftSegments, segmentSeparator)

View File

@ -28,7 +28,7 @@ type (
) )
type Model struct { type Model struct {
shared.State shared.Shared
shared.Sections shared.Sections
conversations []loadedConversation conversations []loadedConversation
@ -38,9 +38,9 @@ type Model struct {
content viewport.Model content viewport.Model
} }
func Conversations(state shared.State) Model { func Conversations(shared shared.Shared) Model {
m := Model{ m := Model{
State: state, Shared: shared,
content: viewport.New(0, 0), content: viewport.New(0, 0),
} }
return m return m
@ -155,7 +155,7 @@ func (m *Model) loadConversations() tea.Cmd {
loaded := make([]loadedConversation, len(messages)) loaded := make([]loadedConversation, len(messages))
for i, m := range messages { for i, m := range messages {
loaded[i].lastReply = m loaded[i].lastReply = m
loaded[i].conv = m.Conversation loaded[i].conv = *m.Conversation
} }
return msgConversationsLoaded(loaded) return msgConversationsLoaded(loaded)