Compare commits

..

No commits in common. "42c3297e54e2e89381d8b0b8a9855824c10099a8" and "45df957a062d962e42a2ab7992e7c7ac62edc2b8" have entirely different histories.

22 changed files with 159 additions and 182 deletions

View File

@ -31,6 +31,7 @@ func ReplyCmd(ctx *lmcli.Context) *cobra.Command {
} }
cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{ cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{
ConversationID: conversation.ID,
Role: model.MessageRoleUser, Role: model.MessageRoleUser,
Content: reply, Content: reply,
}) })

View File

@ -8,7 +8,6 @@ import (
"strings" "strings"
"time" "time"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/util" "git.mlow.ca/mlow/lmcli/pkg/util"
@ -18,7 +17,7 @@ import (
// Prompt prompts the configured the configured model and streams the response // Prompt prompts the configured the configured model and streams the response
// to stdout. Returns all model reply messages. // to stdout. Returns all model reply messages.
func Prompt(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) { func Prompt(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
content := make(chan api.Chunk) // receives the reponse from LLM content := make(chan string) // receives the reponse from LLM
defer close(content) defer close(content)
// render all content received over the channel // render all content received over the channel
@ -252,7 +251,7 @@ func ShowWaitAnimation(signal chan any) {
// chunked) content is received on the channel, the waiting animation is // chunked) content is received on the channel, the waiting animation is
// replaced by the content. // replaced by the content.
// Blocks until the channel is closed. // Blocks until the channel is closed.
func ShowDelayedContent(content <-chan api.Chunk) { func ShowDelayedContent(content <-chan string) {
waitSignal := make(chan any) waitSignal := make(chan any)
go ShowWaitAnimation(waitSignal) go ShowWaitAnimation(waitSignal)
@ -265,7 +264,7 @@ func ShowDelayedContent(content <-chan api.Chunk) {
<-waitSignal <-waitSignal
firstChunk = false firstChunk = false
} }
fmt.Print(chunk.Content) fmt.Print(chunk)
} }
} }

View File

@ -6,12 +6,12 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/api/provider/anthropic"
"git.mlow.ca/mlow/lmcli/pkg/api/provider/google"
"git.mlow.ca/mlow/lmcli/pkg/api/provider/ollama"
"git.mlow.ca/mlow/lmcli/pkg/api/provider/openai"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/google"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/ollama"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools" "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
"git.mlow.ca/mlow/lmcli/pkg/util" "git.mlow.ca/mlow/lmcli/pkg/util"
"git.mlow.ca/mlow/lmcli/pkg/util/tty" "git.mlow.ca/mlow/lmcli/pkg/util/tty"
@ -79,7 +79,7 @@ func (c *Context) GetModels() (models []string) {
return return
} }
func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionClient, error) { func (c *Context) GetModelProvider(model string) (string, provider.ChatCompletionClient, error) {
parts := strings.Split(model, "/") parts := strings.Split(model, "/")
var provider string var provider string

View File

@ -17,8 +17,8 @@ const (
type Message struct { type Message struct {
ID uint `gorm:"primaryKey"` ID uint `gorm:"primaryKey"`
ConversationID *uint `gorm:"index"` ConversationID uint `gorm:"index"`
Conversation *Conversation `gorm:"foreignKey:ConversationID"` Conversation Conversation `gorm:"foreignKey:ConversationID"`
Content string Content string
Role MessageRole Role MessageRole
CreatedAt time.Time CreatedAt time.Time

View File

@ -10,8 +10,8 @@ import (
"net/http" "net/http"
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools" "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
) )
@ -107,7 +107,7 @@ func (c *AnthropicClient) CreateChatCompletion(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callback api.ReplyCallback, callback provider.ReplyCallback,
) (string, error) { ) (string, error) {
if len(messages) == 0 { if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages") return "", fmt.Errorf("Can't create completion from no messages")
@ -160,8 +160,8 @@ func (c *AnthropicClient) CreateChatCompletionStream(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callback api.ReplyCallback, callback provider.ReplyCallback,
output chan<- api.Chunk, output chan<- string,
) (string, error) { ) (string, error) {
if len(messages) == 0 { if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages") return "", fmt.Errorf("Can't create completion from no messages")
@ -242,9 +242,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
return "", fmt.Errorf("invalid text delta") return "", fmt.Errorf("invalid text delta")
} }
sb.WriteString(text) sb.WriteString(text)
output <- api.Chunk{ output <- text
Content: text,
}
case "content_block_stop": case "content_block_stop":
// ignore? // ignore?
case "message_delta": case "message_delta":
@ -264,9 +262,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
} }
sb.WriteString(FUNCTION_STOP_SEQUENCE) sb.WriteString(FUNCTION_STOP_SEQUENCE)
output <- api.Chunk{ output <- FUNCTION_STOP_SEQUENCE
Content: FUNCTION_STOP_SEQUENCE,
}
funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE

View File

@ -10,8 +10,8 @@ import (
"net/http" "net/http"
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools" "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
) )
@ -187,7 +187,7 @@ func handleToolCalls(
params model.RequestParameters, params model.RequestParameters,
content string, content string,
toolCalls []model.ToolCall, toolCalls []model.ToolCall,
callback api.ReplyCallback, callback provider.ReplyCallback,
messages []model.Message, messages []model.Message,
) ([]model.Message, error) { ) ([]model.Message, error) {
lastMessage := messages[len(messages)-1] lastMessage := messages[len(messages)-1]
@ -245,7 +245,7 @@ func (c *Client) CreateChatCompletion(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callback api.ReplyCallback, callback provider.ReplyCallback,
) (string, error) { ) (string, error) {
if len(messages) == 0 { if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages") return "", fmt.Errorf("Can't create completion from no messages")
@ -325,8 +325,8 @@ func (c *Client) CreateChatCompletionStream(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callback api.ReplyCallback, callback provider.ReplyCallback,
output chan<- api.Chunk, output chan<- string,
) (string, error) { ) (string, error) {
if len(messages) == 0 { if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages") return "", fmt.Errorf("Can't create completion from no messages")
@ -393,9 +393,7 @@ func (c *Client) CreateChatCompletionStream(
if part.FunctionCall != nil { if part.FunctionCall != nil {
toolCalls = append(toolCalls, *part.FunctionCall) toolCalls = append(toolCalls, *part.FunctionCall)
} else if part.Text != "" { } else if part.Text != "" {
output <- api.Chunk { output <- part.Text
Content: part.Text,
}
content.WriteString(part.Text) content.WriteString(part.Text)
} }
} }

View File

@ -10,8 +10,8 @@ import (
"net/http" "net/http"
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
) )
type OllamaClient struct { type OllamaClient struct {
@ -85,7 +85,7 @@ func (c *OllamaClient) CreateChatCompletion(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callback api.ReplyCallback, callback provider.ReplyCallback,
) (string, error) { ) (string, error) {
if len(messages) == 0 { if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages") return "", fmt.Errorf("Can't create completion from no messages")
@ -117,6 +117,9 @@ func (c *OllamaClient) CreateChatCompletion(
} }
content := completionResp.Message.Content content := completionResp.Message.Content
fmt.Println(content)
if callback != nil { if callback != nil {
callback(model.Message{ callback(model.Message{
Role: model.MessageRoleAssistant, Role: model.MessageRoleAssistant,
@ -131,8 +134,8 @@ func (c *OllamaClient) CreateChatCompletionStream(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callback api.ReplyCallback, callback provider.ReplyCallback,
output chan<- api.Chunk, output chan<- string,
) (string, error) { ) (string, error) {
if len(messages) == 0 { if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages") return "", fmt.Errorf("Can't create completion from no messages")
@ -181,9 +184,7 @@ func (c *OllamaClient) CreateChatCompletionStream(
} }
if len(streamResp.Message.Content) > 0 { if len(streamResp.Message.Content) > 0 {
output <- api.Chunk{ output <- streamResp.Message.Content
Content: streamResp.Message.Content,
}
content.WriteString(streamResp.Message.Content) content.WriteString(streamResp.Message.Content)
} }
} }

View File

@ -10,8 +10,8 @@ import (
"net/http" "net/http"
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools" "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
) )
@ -121,7 +121,7 @@ func handleToolCalls(
params model.RequestParameters, params model.RequestParameters,
content string, content string,
toolCalls []ToolCall, toolCalls []ToolCall,
callback api.ReplyCallback, callback provider.ReplyCallback,
messages []model.Message, messages []model.Message,
) ([]model.Message, error) { ) ([]model.Message, error) {
lastMessage := messages[len(messages)-1] lastMessage := messages[len(messages)-1]
@ -180,7 +180,7 @@ func (c *OpenAIClient) CreateChatCompletion(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callback api.ReplyCallback, callback provider.ReplyCallback,
) (string, error) { ) (string, error) {
if len(messages) == 0 { if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages") return "", fmt.Errorf("Can't create completion from no messages")
@ -244,8 +244,8 @@ func (c *OpenAIClient) CreateChatCompletionStream(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callback api.ReplyCallback, callback provider.ReplyCallback,
output chan<- api.Chunk, output chan<- string,
) (string, error) { ) (string, error) {
if len(messages) == 0 { if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages") return "", fmt.Errorf("Can't create completion from no messages")
@ -319,9 +319,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
} }
} }
if len(delta.Content) > 0 { if len(delta.Content) > 0 {
output <- api.Chunk { output <- delta.Content
Content: delta.Content,
}
content.WriteString(delta.Content) content.WriteString(delta.Content)
} }
} }

View File

@ -1,4 +1,4 @@
package api package provider
import ( import (
"context" "context"
@ -8,10 +8,6 @@ import (
type ReplyCallback func(model.Message) type ReplyCallback func(model.Message)
type Chunk struct {
Content string
}
type ChatCompletionClient interface { type ChatCompletionClient interface {
// CreateChatCompletion requests a response to the provided messages. // CreateChatCompletion requests a response to the provided messages.
// Replies are appended to the given replies struct, and the // Replies are appended to the given replies struct, and the
@ -30,6 +26,6 @@ type ChatCompletionClient interface {
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callback ReplyCallback, callback ReplyCallback,
output chan<- Chunk, output chan<- string,
) (string, error) ) (string, error)
} }

View File

@ -58,28 +58,24 @@ func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
return &SQLStore{db, _sqids}, nil return &SQLStore{db, _sqids}, nil
} }
func (s *SQLStore) createConversation() (*model.Conversation, error) { func (s *SQLStore) saveNewConversation(c *model.Conversation) error {
// Create the new conversation // Save the new conversation
c := &model.Conversation{} err := s.db.Save(&c).Error
err := s.db.Save(c).Error
if err != nil { if err != nil {
return nil, err return err
} }
// Generate and save its "short name" // Generate and save its "short name"
shortName, _ := s.sqids.Encode([]uint64{uint64(c.ID)}) shortName, _ := s.sqids.Encode([]uint64{uint64(c.ID)})
c.ShortName = sql.NullString{String: shortName, Valid: true} c.ShortName = sql.NullString{String: shortName, Valid: true}
err = s.db.Updates(c).Error return s.UpdateConversation(c)
if err != nil {
return nil, err
}
return c, nil
} }
func (s *SQLStore) UpdateConversation(c *model.Conversation) error { func (s *SQLStore) UpdateConversation(c *model.Conversation) error {
if c == nil || c.ID == 0 { if c == nil || c.ID == 0 {
return fmt.Errorf("Conversation is nil or invalid (missing ID)") return fmt.Errorf("Conversation is nil or invalid (missing ID)")
} }
return s.db.Updates(c).Error return s.db.Updates(&c).Error
} }
func (s *SQLStore) DeleteConversation(c *model.Conversation) error { func (s *SQLStore) DeleteConversation(c *model.Conversation) error {
@ -88,7 +84,7 @@ func (s *SQLStore) DeleteConversation(c *model.Conversation) error {
if err != nil { if err != nil {
return err return err
} }
return s.db.Delete(c).Error return s.db.Delete(&c).Error
} }
func (s *SQLStore) DeleteMessage(message *model.Message, prune bool) error { func (s *SQLStore) DeleteMessage(message *model.Message, prune bool) error {
@ -100,7 +96,7 @@ func (s *SQLStore) UpdateMessage(m *model.Message) error {
if m == nil || m.ID == 0 { if m == nil || m.ID == 0 {
return fmt.Errorf("Message is nil or invalid (missing ID)") return fmt.Errorf("Message is nil or invalid (missing ID)")
} }
return s.db.Updates(m).Error return s.db.Updates(&m).Error
} }
func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string { func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
@ -153,13 +149,14 @@ func (s *SQLStore) StartConversation(messages ...model.Message) (*model.Conversa
} }
// Create new conversation // Create new conversation
conversation, err := s.createConversation() conversation := &model.Conversation{}
err := s.saveNewConversation(conversation)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
// Create first message // Create first message
messages[0].Conversation = conversation messages[0].ConversationID = conversation.ID
err = s.db.Create(&messages[0]).Error err = s.db.Create(&messages[0]).Error
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -190,18 +187,19 @@ func (s *SQLStore) CloneConversation(toClone model.Conversation) (*model.Convers
return nil, 0, err return nil, 0, err
} }
clone, err := s.createConversation() clone := &model.Conversation{
if err != nil { Title: toClone.Title + " - Clone",
}
if err := s.saveNewConversation(clone); err != nil {
return nil, 0, fmt.Errorf("Could not create clone: %s", err) return nil, 0, fmt.Errorf("Could not create clone: %s", err)
} }
clone.Title = toClone.Title + " - Clone"
var errors []error var errors []error
var messageCnt uint = 0 var messageCnt uint = 0
for _, root := range rootMessages { for _, root := range rootMessages {
messageCnt++ messageCnt++
newRoot := root newRoot := root
newRoot.ConversationID = &clone.ID newRoot.ConversationID = clone.ID
cloned, count, err := s.CloneBranch(newRoot) cloned, count, err := s.CloneBranch(newRoot)
if err != nil { if err != nil {
@ -232,10 +230,9 @@ func (s *SQLStore) Reply(to *model.Message, messages ...model.Message) ([]model.
err := s.db.Transaction(func(tx *gorm.DB) error { err := s.db.Transaction(func(tx *gorm.DB) error {
currentParent := to currentParent := to
for i := range messages { for i := range messages {
parent := currentParent
message := messages[i] message := messages[i]
message.Parent = parent message.ConversationID = currentParent.ConversationID
message.Conversation = parent.Conversation message.ParentID = &currentParent.ID
message.ID = 0 message.ID = 0
message.CreatedAt = time.Time{} message.CreatedAt = time.Time{}
@ -244,9 +241,9 @@ func (s *SQLStore) Reply(to *model.Message, messages ...model.Message) ([]model.
} }
// update parent selected reply // update parent selected reply
parent.Replies = append(parent.Replies, message) currentParent.Replies = append(currentParent.Replies, message)
parent.SelectedReply = &message currentParent.SelectedReply = &message
if err := tx.Model(parent).Update("selected_reply_id", message.ID).Error; err != nil { if err := tx.Model(currentParent).Update("selected_reply_id", message.ID).Error; err != nil {
return err return err
} }

View File

@ -9,7 +9,7 @@ type Values struct {
ConvShortname string ConvShortname string
} }
type Shared struct { type State struct {
Ctx *lmcli.Context Ctx *lmcli.Context
Values *Values Values *Values
Width int Width int

View File

@ -18,7 +18,7 @@ import (
// Application model // Application model
type Model struct { type Model struct {
shared.Shared shared.State
state shared.View state shared.View
chat chat.Model chat chat.Model
@ -27,15 +27,15 @@ type Model struct {
func initialModel(ctx *lmcli.Context, values shared.Values) Model { func initialModel(ctx *lmcli.Context, values shared.Values) Model {
m := Model{ m := Model{
Shared: shared.Shared{ State: shared.State{
Ctx: ctx, Ctx: ctx,
Values: &values, Values: &values,
}, },
} }
m.state = shared.StateChat m.state = shared.StateChat
m.chat = chat.Chat(m.Shared) m.chat = chat.Chat(m.State)
m.conversations = conversations.Conversations(m.Shared) m.conversations = conversations.Conversations(m.State)
return m return m
} }

View File

@ -3,7 +3,6 @@ package chat
import ( import (
"time" "time"
"git.mlow.ca/mlow/lmcli/pkg/api"
models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared" "git.mlow.ca/mlow/lmcli/pkg/tui/shared"
"github.com/charmbracelet/bubbles/cursor" "github.com/charmbracelet/bubbles/cursor"
@ -14,10 +13,24 @@ import (
"github.com/charmbracelet/lipgloss" "github.com/charmbracelet/lipgloss"
) )
type focusState int
const (
focusInput focusState = iota
focusMessages
)
type editorTarget int
const (
input editorTarget = iota
selectedMessage
)
// custom tea.Msg types // custom tea.Msg types
type ( type (
// sent on each chunk received from LLM // sent on each chunk received from LLM
msgResponseChunk api.Chunk msgResponseChunk string
// sent when response is finished being received // sent when response is finished being received
msgResponseEnd string msgResponseEnd string
// a special case of common.MsgError that stops the response waiting animation // a special case of common.MsgError that stops the response waiting animation
@ -35,7 +48,6 @@ type (
msgMessagesLoaded []models.Message msgMessagesLoaded []models.Message
// sent when the conversation has been persisted, triggers a reload of contents // sent when the conversation has been persisted, triggers a reload of contents
msgConversationPersisted struct { msgConversationPersisted struct {
isNew bool
conversation *models.Conversation conversation *models.Conversation
messages []models.Message messages []models.Message
} }
@ -49,47 +61,26 @@ type (
msgMessageCloned *models.Message msgMessageCloned *models.Message
) )
type focusState int
const (
focusInput focusState = iota
focusMessages
)
type editorTarget int
const (
input editorTarget = iota
selectedMessage
)
type state int
const (
idle state = iota
loading
pendingResponse
)
type Model struct { type Model struct {
shared.Shared shared.State
shared.Sections shared.Sections
// app state // app state
state state // current overall status of the view
conversation *models.Conversation conversation *models.Conversation
rootMessages []models.Message rootMessages []models.Message
messages []models.Message messages []models.Message
selectedMessage int selectedMessage int
waitingForReply bool
editorTarget editorTarget editorTarget editorTarget
stopSignal chan struct{} stopSignal chan struct{}
replyChan chan models.Message replyChan chan models.Message
replyChunkChan chan api.Chunk replyChunkChan chan string
persistence bool // whether we will save new messages in the conversation persistence bool // whether we will save new messages in the conversation
// ui state // ui state
focus focusState focus focusState
wrap bool // whether message content is wrapped to viewport width wrap bool // whether message content is wrapped to viewport width
status string // a general status message
showToolResults bool // whether tool calls and results are shown showToolResults bool // whether tool calls and results are shown
messageCache []string // cache of syntax highlighted and wrapped message content messageCache []string // cache of syntax highlighted and wrapped message content
messageOffsets []int messageOffsets []int
@ -106,17 +97,16 @@ type Model struct {
elapsed time.Duration elapsed time.Duration
} }
func Chat(shared shared.Shared) Model { func Chat(state shared.State) Model {
m := Model{ m := Model{
Shared: shared, State: state,
state: idle,
conversation: &models.Conversation{}, conversation: &models.Conversation{},
persistence: true, persistence: true,
stopSignal: make(chan struct{}), stopSignal: make(chan struct{}),
replyChan: make(chan models.Message), replyChan: make(chan models.Message),
replyChunkChan: make(chan api.Chunk), replyChunkChan: make(chan string),
wrap: true, wrap: true,
selectedMessage: -1, selectedMessage: -1,
@ -142,7 +132,7 @@ func Chat(shared shared.Shared) Model {
m.replyCursor.SetChar(" ") m.replyCursor.SetChar(" ")
m.replyCursor.Focus() m.replyCursor.Focus()
system := shared.Ctx.GetSystemPrompt() system := state.Ctx.GetSystemPrompt()
if system != "" { if system != "" {
m.messages = []models.Message{{ m.messages = []models.Message{{
Role: models.MessageRoleSystem, Role: models.MessageRoleSystem,
@ -160,6 +150,8 @@ func Chat(shared shared.Shared) Model {
m.input.FocusedStyle.Base = inputFocusedStyle m.input.FocusedStyle.Base = inputFocusedStyle
m.input.BlurredStyle.Base = inputBlurredStyle m.input.BlurredStyle.Base = inputBlurredStyle
m.waitingForReply = false
m.status = "Press ctrl+s to send"
return m return m
} }

View File

@ -53,14 +53,14 @@ func (m *Model) loadConversation(shortname string) tea.Cmd {
if shortname == "" { if shortname == "" {
return nil return nil
} }
c, err := m.Shared.Ctx.Store.ConversationByShortName(shortname) c, err := m.State.Ctx.Store.ConversationByShortName(shortname)
if err != nil { if err != nil {
return shared.MsgError(fmt.Errorf("Could not lookup conversation: %v", err)) return shared.MsgError(fmt.Errorf("Could not lookup conversation: %v", err))
} }
if c.ID == 0 { if c.ID == 0 {
return shared.MsgError(fmt.Errorf("Conversation not found: %s", shortname)) return shared.MsgError(fmt.Errorf("Conversation not found: %s", shortname))
} }
rootMessages, err := m.Shared.Ctx.Store.RootMessages(c.ID) rootMessages, err := m.State.Ctx.Store.RootMessages(c.ID)
if err != nil { if err != nil {
return shared.MsgError(fmt.Errorf("Could not load conversation root messages: %v\n", err)) return shared.MsgError(fmt.Errorf("Could not load conversation root messages: %v\n", err))
} }
@ -70,7 +70,7 @@ func (m *Model) loadConversation(shortname string) tea.Cmd {
func (m *Model) loadConversationMessages() tea.Cmd { func (m *Model) loadConversationMessages() tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
messages, err := m.Shared.Ctx.Store.PathToLeaf(m.conversation.SelectedRoot) messages, err := m.State.Ctx.Store.PathToLeaf(m.conversation.SelectedRoot)
if err != nil { if err != nil {
return shared.MsgError(fmt.Errorf("Could not load conversation messages: %v\n", err)) return shared.MsgError(fmt.Errorf("Could not load conversation messages: %v\n", err))
} }
@ -80,7 +80,7 @@ func (m *Model) loadConversationMessages() tea.Cmd {
func (m *Model) generateConversationTitle() tea.Cmd { func (m *Model) generateConversationTitle() tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
title, err := cmdutil.GenerateTitle(m.Shared.Ctx, m.messages) title, err := cmdutil.GenerateTitle(m.State.Ctx, m.messages)
if err != nil { if err != nil {
return shared.MsgError(err) return shared.MsgError(err)
} }
@ -90,7 +90,7 @@ func (m *Model) generateConversationTitle() tea.Cmd {
func (m *Model) updateConversationTitle(conversation *models.Conversation) tea.Cmd { func (m *Model) updateConversationTitle(conversation *models.Conversation) tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
err := m.Shared.Ctx.Store.UpdateConversation(conversation) err := m.State.Ctx.Store.UpdateConversation(conversation)
if err != nil { if err != nil {
return shared.WrapError(err) return shared.WrapError(err)
} }
@ -110,10 +110,10 @@ func (m *Model) cloneMessage(message models.Message, selected bool) tea.Cmd {
if selected { if selected {
if msg.Parent == nil { if msg.Parent == nil {
msg.Conversation.SelectedRoot = msg msg.Conversation.SelectedRoot = msg
err = m.Shared.Ctx.Store.UpdateConversation(msg.Conversation) err = m.State.Ctx.Store.UpdateConversation(&msg.Conversation)
} else { } else {
msg.Parent.SelectedReply = msg msg.Parent.SelectedReply = msg
err = m.Shared.Ctx.Store.UpdateMessage(msg.Parent) err = m.State.Ctx.Store.UpdateMessage(msg.Parent)
} }
if err != nil { if err != nil {
return shared.WrapError(fmt.Errorf("Could not update selected message: %v", err)) return shared.WrapError(fmt.Errorf("Could not update selected message: %v", err))
@ -125,7 +125,7 @@ func (m *Model) cloneMessage(message models.Message, selected bool) tea.Cmd {
func (m *Model) updateMessageContent(message *models.Message) tea.Cmd { func (m *Model) updateMessageContent(message *models.Message) tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
err := m.Shared.Ctx.Store.UpdateMessage(message) err := m.State.Ctx.Store.UpdateMessage(message)
if err != nil { if err != nil {
return shared.WrapError(fmt.Errorf("Could not update message: %v", err)) return shared.WrapError(fmt.Errorf("Could not update message: %v", err))
} }
@ -170,7 +170,7 @@ func (m *Model) cycleSelectedRoot(conv *models.Conversation, dir MessageCycleDir
} }
conv.SelectedRoot = nextRoot conv.SelectedRoot = nextRoot
err = m.Shared.Ctx.Store.UpdateConversation(conv) err = m.State.Ctx.Store.UpdateConversation(conv)
if err != nil { if err != nil {
return shared.WrapError(fmt.Errorf("Could not update conversation SelectedRoot: %v", err)) return shared.WrapError(fmt.Errorf("Could not update conversation SelectedRoot: %v", err))
} }
@ -190,7 +190,7 @@ func (m *Model) cycleSelectedReply(message *models.Message, dir MessageCycleDire
} }
message.SelectedReply = nextReply message.SelectedReply = nextReply
err = m.Shared.Ctx.Store.UpdateMessage(message) err = m.State.Ctx.Store.UpdateMessage(message)
if err != nil { if err != nil {
return shared.WrapError(fmt.Errorf("Could not update message SelectedReply: %v", err)) return shared.WrapError(fmt.Errorf("Could not update message SelectedReply: %v", err))
} }
@ -203,14 +203,14 @@ func (m *Model) persistConversation() tea.Cmd {
messages := m.messages messages := m.messages
var err error var err error
if conversation.ID == 0 { if m.conversation.ID == 0 {
return func() tea.Msg { return func() tea.Msg {
// Start a new conversation with all messages so far // Start a new conversation with all messages so far
conversation, messages, err = m.Shared.Ctx.Store.StartConversation(messages...) conversation, messages, err = m.State.Ctx.Store.StartConversation(messages...)
if err != nil { if err != nil {
return shared.MsgError(fmt.Errorf("Could not start new conversation: %v", err)) return shared.MsgError(fmt.Errorf("Could not start new conversation: %v", err))
} }
return msgConversationPersisted{true, conversation, messages} return msgConversationPersisted{conversation, messages}
} }
} }
@ -219,7 +219,7 @@ func (m *Model) persistConversation() tea.Cmd {
for i := range messages { for i := range messages {
if messages[i].ID > 0 { if messages[i].ID > 0 {
// message has an ID, update its contents // message has an ID, update its contents
err := m.Shared.Ctx.Store.UpdateMessage(&messages[i]) err := m.State.Ctx.Store.UpdateMessage(&messages[i])
if err != nil { if err != nil {
return shared.MsgError(err) return shared.MsgError(err)
} }
@ -228,7 +228,7 @@ func (m *Model) persistConversation() tea.Cmd {
continue continue
} }
// messages is new, so add it as a reply to previous message // messages is new, so add it as a reply to previous message
saved, err := m.Shared.Ctx.Store.Reply(&messages[i-1], messages[i]) saved, err := m.State.Ctx.Store.Reply(&messages[i-1], messages[i])
if err != nil { if err != nil {
return shared.MsgError(err) return shared.MsgError(err)
} }
@ -239,29 +239,30 @@ func (m *Model) persistConversation() tea.Cmd {
return fmt.Errorf("Error: no messages to reply to") return fmt.Errorf("Error: no messages to reply to")
} }
} }
return msgConversationPersisted{false, conversation, messages} return msgConversationPersisted{conversation, messages}
} }
} }
func (m *Model) promptLLM() tea.Cmd { func (m *Model) promptLLM() tea.Cmd {
m.state = pendingResponse m.waitingForReply = true
m.replyCursor.Blink = false m.replyCursor.Blink = false
m.status = "Press ctrl+c to cancel"
m.tokenCount = 0 m.tokenCount = 0
m.startTime = time.Now() m.startTime = time.Now()
m.elapsed = 0 m.elapsed = 0
return func() tea.Msg { return func() tea.Msg {
model, provider, err := m.Shared.Ctx.GetModelProvider(*m.Shared.Ctx.Config.Defaults.Model) model, provider, err := m.State.Ctx.GetModelProvider(*m.State.Ctx.Config.Defaults.Model)
if err != nil { if err != nil {
return shared.MsgError(err) return shared.MsgError(err)
} }
requestParams := models.RequestParameters{ requestParams := models.RequestParameters{
Model: model, Model: model,
MaxTokens: *m.Shared.Ctx.Config.Defaults.MaxTokens, MaxTokens: *m.State.Ctx.Config.Defaults.MaxTokens,
Temperature: *m.Shared.Ctx.Config.Defaults.Temperature, Temperature: *m.State.Ctx.Config.Defaults.Temperature,
ToolBag: m.Shared.Ctx.EnabledTools, ToolBag: m.State.Ctx.EnabledTools,
} }
replyHandler := func(msg models.Message) { replyHandler := func(msg models.Message) {

View File

@ -33,7 +33,7 @@ func (m *Model) HandleInput(msg tea.KeyMsg) (bool, tea.Cmd) {
switch msg.String() { switch msg.String() {
case "esc": case "esc":
if m.state == pendingResponse { if m.waitingForReply {
m.stopSignal <- struct{}{} m.stopSignal <- struct{}{}
return true, nil return true, nil
} }
@ -41,7 +41,7 @@ func (m *Model) HandleInput(msg tea.KeyMsg) (bool, tea.Cmd) {
return shared.MsgViewChange(shared.StateConversations) return shared.MsgViewChange(shared.StateConversations)
} }
case "ctrl+c": case "ctrl+c":
if m.state == pendingResponse { if m.waitingForReply {
m.stopSignal <- struct{}{} m.stopSignal <- struct{}{}
return true, nil return true, nil
} }
@ -112,7 +112,9 @@ func (m *Model) handleMessagesKey(msg tea.KeyMsg) (bool, tea.Cmd) {
return cmd != nil, cmd return cmd != nil, cmd
case "ctrl+r": case "ctrl+r":
// resubmit the conversation with all messages up until and including the selected message // resubmit the conversation with all messages up until and including the selected message
if m.state == idle && m.selectedMessage < len(m.messages) { if m.waitingForReply || len(m.messages) == 0 {
return true, nil
}
m.messages = m.messages[:m.selectedMessage+1] m.messages = m.messages[:m.selectedMessage+1]
m.messageCache = m.messageCache[:m.selectedMessage+1] m.messageCache = m.messageCache[:m.selectedMessage+1]
cmd := m.promptLLM() cmd := m.promptLLM()
@ -120,7 +122,6 @@ func (m *Model) handleMessagesKey(msg tea.KeyMsg) (bool, tea.Cmd) {
m.content.GotoBottom() m.content.GotoBottom()
return true, cmd return true, cmd
} }
}
return false, nil return false, nil
} }
@ -140,8 +141,8 @@ func (m *Model) handleInputKey(msg tea.KeyMsg) (bool, tea.Cmd) {
m.input.Blur() m.input.Blur()
return true, nil return true, nil
case "ctrl+s": case "ctrl+s":
// TODO: call a "handleSend" function which returns a tea.Cmd // TODO: call a "handleSend" function with returns a tea.Cmd
if m.state != idle { if m.waitingForReply {
return false, nil return false, nil
} }

View File

@ -42,11 +42,11 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
// wake up spinners and cursors // wake up spinners and cursors
cmds = append(cmds, cursor.Blink, m.spinner.Tick) cmds = append(cmds, cursor.Blink, m.spinner.Tick)
if m.Shared.Values.ConvShortname != "" { if m.State.Values.ConvShortname != "" {
// (re)load conversation contents // (re)load conversation contents
cmds = append(cmds, m.loadConversation(m.Shared.Values.ConvShortname)) cmds = append(cmds, m.loadConversation(m.State.Values.ConvShortname))
if m.conversation.ShortName.String != m.Shared.Values.ConvShortname { if m.conversation.ShortName.String != m.State.Values.ConvShortname {
// clear existing messages if we're loading a new conversation // clear existing messages if we're loading a new conversation
m.messages = []models.Message{} m.messages = []models.Message{}
m.selectedMessage = 0 m.selectedMessage = 0
@ -90,19 +90,20 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
case msgResponseChunk: case msgResponseChunk:
cmds = append(cmds, m.waitForResponseChunk()) // wait for the next chunk cmds = append(cmds, m.waitForResponseChunk()) // wait for the next chunk
if msg.Content == "" { chunk := string(msg)
if chunk == "" {
break break
} }
last := len(m.messages) - 1 last := len(m.messages) - 1
if last >= 0 && m.messages[last].Role.IsAssistant() { if last >= 0 && m.messages[last].Role.IsAssistant() {
// append chunk to existing message // append chunk to existing message
m.setMessageContents(last, m.messages[last].Content+msg.Content) m.setMessageContents(last, m.messages[last].Content+chunk)
} else { } else {
// use chunk in new message // use chunk in new message
m.addMessage(models.Message{ m.addMessage(models.Message{
Role: models.MessageRoleAssistant, Role: models.MessageRoleAssistant,
Content: msg.Content, Content: chunk,
}) })
} }
m.updateContent() m.updateContent()
@ -141,16 +142,18 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
m.updateContent() m.updateContent()
case msgResponseEnd: case msgResponseEnd:
m.state = idle m.waitingForReply = false
last := len(m.messages) - 1 last := len(m.messages) - 1
if last < 0 { if last < 0 {
panic("Unexpected empty messages handling msgResponseEnd") panic("Unexpected empty messages handling msgResponseEnd")
} }
m.setMessageContents(last, strings.TrimSpace(m.messages[last].Content)) m.setMessageContents(last, strings.TrimSpace(m.messages[last].Content))
m.updateContent() m.updateContent()
m.status = "Press ctrl+s to send"
case msgResponseError: case msgResponseError:
m.state = idle m.waitingForReply = false
m.Shared.Err = error(msg) m.status = "Press ctrl+s to send"
m.State.Err = error(msg)
m.updateContent() m.updateContent()
case msgConversationTitleGenerated: case msgConversationTitleGenerated:
title := string(msg) title := string(msg)
@ -159,21 +162,18 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
cmds = append(cmds, m.updateConversationTitle(m.conversation)) cmds = append(cmds, m.updateConversationTitle(m.conversation))
} }
case cursor.BlinkMsg: case cursor.BlinkMsg:
if m.state == pendingResponse { if m.waitingForReply {
// ensure we show the updated "wait for response" cursor blink state // ensure we show the updated "wait for response" cursor blink state
m.updateContent() m.updateContent()
} }
case msgConversationPersisted: case msgConversationPersisted:
m.conversation = msg.conversation m.conversation = msg.conversation
m.messages = msg.messages m.messages = msg.messages
if msg.isNew {
m.rootMessages = []models.Message{m.messages[0]}
}
m.rebuildMessageCache() m.rebuildMessageCache()
m.updateContent() m.updateContent()
case msgMessageCloned: case msgMessageCloned:
if msg.Parent == nil { if msg.Parent == nil {
m.conversation = msg.Conversation m.conversation = &msg.Conversation
m.rootMessages = append(m.rootMessages, *msg) m.rootMessages = append(m.rootMessages, *msg)
} }
cmds = append(cmds, m.loadConversationMessages()) cmds = append(cmds, m.loadConversationMessages())

View File

@ -131,7 +131,7 @@ func (m *Model) renderMessage(i int) string {
sb := &strings.Builder{} sb := &strings.Builder{}
sb.Grow(len(msg.Content) * 2) sb.Grow(len(msg.Content) * 2)
if msg.Content != "" { if msg.Content != "" {
err := m.Shared.Ctx.Chroma.Highlight(sb, msg.Content) err := m.State.Ctx.Chroma.Highlight(sb, msg.Content)
if err != nil { if err != nil {
sb.Reset() sb.Reset()
sb.WriteString(msg.Content) sb.WriteString(msg.Content)
@ -139,7 +139,7 @@ func (m *Model) renderMessage(i int) string {
} }
// Show the assistant's cursor // Show the assistant's cursor
if m.state == pendingResponse && i == len(m.messages)-1 && msg.Role == models.MessageRoleAssistant { if m.waitingForReply && i == len(m.messages)-1 && msg.Role == models.MessageRoleAssistant {
sb.WriteString(m.replyCursor.View()) sb.WriteString(m.replyCursor.View())
} }
@ -195,7 +195,7 @@ func (m *Model) renderMessage(i int) string {
if msg.Content != "" { if msg.Content != "" {
sb.WriteString("\n\n") sb.WriteString("\n\n")
} }
_ = m.Shared.Ctx.Chroma.HighlightLang(sb, toolString, "yaml") _ = m.State.Ctx.Chroma.HighlightLang(sb, toolString, "yaml")
} }
content := strings.TrimRight(sb.String(), "\n") content := strings.TrimRight(sb.String(), "\n")
@ -237,7 +237,7 @@ func (m *Model) conversationMessagesView() string {
lineCnt += lipgloss.Height(heading) lineCnt += lipgloss.Height(heading)
var rendered string var rendered string
if m.state == pendingResponse && i == len(m.messages)-1 { if m.waitingForReply && i == len(m.messages)-1 {
// do a direct render of final (assistant) message to handle the // do a direct render of final (assistant) message to handle the
// assistant cursor blink // assistant cursor blink
rendered = m.renderMessage(i) rendered = m.renderMessage(i)
@ -251,7 +251,7 @@ func (m *Model) conversationMessagesView() string {
} }
// Render a placeholder for the incoming assistant reply // Render a placeholder for the incoming assistant reply
if m.state == pendingResponse && (len(m.messages) == 0 || m.messages[len(m.messages)-1].Role != models.MessageRoleAssistant) { if m.waitingForReply && (len(m.messages) == 0 || m.messages[len(m.messages)-1].Role != models.MessageRoleAssistant) {
heading := m.renderMessageHeading(-1, &models.Message{ heading := m.renderMessageHeading(-1, &models.Message{
Role: models.MessageRoleAssistant, Role: models.MessageRoleAssistant,
}) })
@ -289,12 +289,9 @@ func (m *Model) footerView() string {
saving = savingStyle.Foreground(lipgloss.Color("1")).Render("❌💾") saving = savingStyle.Foreground(lipgloss.Color("1")).Render("❌💾")
} }
var status string status := m.status
switch m.state { if m.waitingForReply {
case pendingResponse: status += m.spinner.View()
status = "Press ctrl+c to cancel" + m.spinner.View()
default:
status = "Press ctrl+s to send"
} }
leftSegments := []string{ leftSegments := []string{
@ -308,7 +305,7 @@ func (m *Model) footerView() string {
rightSegments = append(rightSegments, segmentStyle.Render(throughput)) rightSegments = append(rightSegments, segmentStyle.Render(throughput))
} }
model := fmt.Sprintf("Model: %s", *m.Shared.Ctx.Config.Defaults.Model) model := fmt.Sprintf("Model: %s", *m.State.Ctx.Config.Defaults.Model)
rightSegments = append(rightSegments, segmentStyle.Render(model)) rightSegments = append(rightSegments, segmentStyle.Render(model))
left := strings.Join(leftSegments, segmentSeparator) left := strings.Join(leftSegments, segmentSeparator)

View File

@ -28,7 +28,7 @@ type (
) )
type Model struct { type Model struct {
shared.Shared shared.State
shared.Sections shared.Sections
conversations []loadedConversation conversations []loadedConversation
@ -38,9 +38,9 @@ type Model struct {
content viewport.Model content viewport.Model
} }
func Conversations(shared shared.Shared) Model { func Conversations(state shared.State) Model {
m := Model{ m := Model{
Shared: shared, State: state,
content: viewport.New(0, 0), content: viewport.New(0, 0),
} }
return m return m
@ -155,7 +155,7 @@ func (m *Model) loadConversations() tea.Cmd {
loaded := make([]loadedConversation, len(messages)) loaded := make([]loadedConversation, len(messages))
for i, m := range messages { for i, m := range messages {
loaded[i].lastReply = m loaded[i].lastReply = m
loaded[i].conv = *m.Conversation loaded[i].conv = m.Conversation
} }
return msgConversationsLoaded(loaded) return msgConversationsLoaded(loaded)