diff --git a/pkg/api/api.go b/pkg/api/api.go
index 6042af8..b636b66 100644
--- a/pkg/api/api.go
+++ b/pkg/api/api.go
@@ -16,10 +16,11 @@ const (
 )
 
 type Message struct {
-	Content     string // TODO: support multi-part messages
-	Role        MessageRole
-	ToolCalls   []ToolCall
-	ToolResults []ToolResult
+	Content          string // TODO: support multi-part messages
+	ReasoningContent string
+	Role             MessageRole
+	ToolCalls        []ToolCall
+	ToolResults      []ToolResult
 }
 
 type ToolSpec struct {
@@ -49,10 +50,11 @@ type ToolResult struct {
 	Result string `json:"result,omitempty" yaml:"result"`
 }
 
-func NewMessageWithAssistant(content string) *Message {
+func NewMessageWithAssistant(content string, reasoning string) *Message {
 	return &Message{
 		Role:    MessageRoleAssistant,
 		Content: content,
+		ReasoningContent: reasoning,
 	}
 }
 
diff --git a/pkg/conversation/conversation.go b/pkg/conversation/conversation.go
index 7c7db06..e225bfc 100644
--- a/pkg/conversation/conversation.go
+++ b/pkg/conversation/conversation.go
@@ -38,10 +38,11 @@ type Message struct {
 	SelectedReplyID *uint
 	SelectedReply   *Message `gorm:"foreignKey:SelectedReplyID"`
 
-	Role        api.MessageRole
-	Content     string
-	ToolCalls   ToolCalls   // a json array of tool calls (from the model)
-	ToolResults ToolResults // a json array of tool results
+	Role             api.MessageRole
+	Content          string
+	ReasoningContent string
+	ToolCalls        ToolCalls   // a json array of tool calls (from the model)
+	ToolResults      ToolResults // a json array of tool results
 }
 
 func (m *MessageMeta) Scan(value interface{}) error {
diff --git a/pkg/conversation/tools.go b/pkg/conversation/tools.go
index ca61e88..be55b94 100644
--- a/pkg/conversation/tools.go
+++ b/pkg/conversation/tools.go
@@ -22,10 +22,11 @@ func ApplySystemPrompt(m []Message, system string, force bool) []Message {
 
 func MessageToAPI(m Message) api.Message {
 	return api.Message{
-		Role:        m.Role,
-		Content:     m.Content,
-		ToolCalls:   m.ToolCalls,
-		ToolResults: m.ToolResults,
+		Role:             m.Role,
+		Content:          m.Content,
+		ReasoningContent: m.ReasoningContent,
+		ToolCalls:        m.ToolCalls,
+		ToolResults:      m.ToolResults,
 	}
 }
 
@@ -39,10 +40,11 @@ func MessagesToAPI(messages []Message) []api.Message {
 
 func MessageFromAPI(m api.Message) Message {
 	return Message{
-		Role:        m.Role,
-		Content:     m.Content,
-		ToolCalls:   m.ToolCalls,
-		ToolResults: m.ToolResults,
+		Role:             m.Role,
+		Content:          m.Content,
+		ReasoningContent: m.ReasoningContent,
+		ToolCalls:        m.ToolCalls,
+		ToolResults:      m.ToolResults,
 	}
 }
 
diff --git a/pkg/provider/anthropic/anthropic.go b/pkg/provider/anthropic/anthropic.go
index 5b4cd9a..ec01648 100644
--- a/pkg/provider/anthropic/anthropic.go
+++ b/pkg/provider/anthropic/anthropic.go
@@ -443,5 +443,5 @@ func convertResponseToMessage(resp ChatCompletionResponse) (*api.Message, error)
 		return api.NewMessageWithToolCalls(content.String(), toolCalls), nil
 	}
 
-	return api.NewMessageWithAssistant(content.String()), nil
+	return api.NewMessageWithAssistant(content.String(), ""), nil
 }
diff --git a/pkg/provider/google/google.go b/pkg/provider/google/google.go
index 1d8bfe0..7a61f4a 100644
--- a/pkg/provider/google/google.go
+++ b/pkg/provider/google/google.go
@@ -340,7 +340,7 @@ func (c *Client) CreateChatCompletion(
 		return api.NewMessageWithToolCalls(content, convertToolCallToAPI(toolCalls)), nil
 	}
 
-	return api.NewMessageWithAssistant(content), nil
+	return api.NewMessageWithAssistant(content, ""), nil
 }
 
 func (c *Client) CreateChatCompletionStream(
@@ -432,5 +432,5 @@ func (c *Client) CreateChatCompletionStream(
 		return api.NewMessageWithToolCalls(content.String(), convertToolCallToAPI(toolCalls)), nil
 	}
 
-	return api.NewMessageWithAssistant(content.String()), nil
+	return api.NewMessageWithAssistant(content.String(), ""), nil
 }
diff --git a/pkg/provider/ollama/ollama.go b/pkg/provider/ollama/ollama.go
index 5b860bb..f023eea 100644
--- a/pkg/provider/ollama/ollama.go
+++ b/pkg/provider/ollama/ollama.go
@@ -115,7 +115,7 @@ func (c *OllamaClient) CreateChatCompletion(
 		return nil, err
 	}
 
-	return api.NewMessageWithAssistant(completionResp.Message.Content), nil
+	return api.NewMessageWithAssistant(completionResp.Message.Content, ""), nil
 }
 
 func (c *OllamaClient) CreateChatCompletionStream(
@@ -179,5 +179,5 @@ func (c *OllamaClient) CreateChatCompletionStream(
 		}
 	}
 
-	return api.NewMessageWithAssistant(content.String()), nil
+	return api.NewMessageWithAssistant(content.String(), ""), nil
 }
diff --git a/pkg/provider/openai/openai.go b/pkg/provider/openai/openai.go
index 9d1f567..86667df 100644
--- a/pkg/provider/openai/openai.go
+++ b/pkg/provider/openai/openai.go
@@ -21,10 +21,11 @@
 }
 
 type ChatCompletionMessage struct {
-	Role       string     `json:"role"`
-	Content    string     `json:"content,omitempty"`
-	ToolCalls  []ToolCall `json:"tool_calls,omitempty"`
-	ToolCallID string     `json:"tool_call_id,omitempty"`
+	Role             string     `json:"role"`
+	Content          string     `json:"content,omitempty"`
+	ReasoningContent string     `json:"reasoning_content,omitempty"`
+	ToolCalls        []ToolCall `json:"tool_calls,omitempty"`
+	ToolCallID       string     `json:"tool_call_id,omitempty"`
 }
 
 type ToolCall struct {
@@ -256,7 +257,7 @@ func (c *OpenAIClient) CreateChatCompletion(
 		return api.NewMessageWithToolCalls(content, convertToolCallToAPI(toolCalls)), nil
 	}
 
-	return api.NewMessageWithAssistant(content), nil
+	return api.NewMessageWithAssistant(content, ""), nil
 }
 
 func (c *OpenAIClient) CreateChatCompletionStream(
@@ -279,6 +280,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 	defer resp.Body.Close()
 
 	content := strings.Builder{}
+	reasoning := strings.Builder{}
 	toolCalls := []ToolCall{}
 
 	lastMessage := messages[len(messages)-1]
@@ -333,11 +335,18 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 			}
 			content.WriteString(delta.Content)
 		}
+		if len(delta.ReasoningContent) > 0 {
+			output <- provider.Chunk{
+				ReasoningContent: delta.ReasoningContent,
+				TokenCount:       1,
+			}
+			reasoning.WriteString(delta.ReasoningContent)
+		}
 	}
 
 	if len(toolCalls) > 0 {
 		return api.NewMessageWithToolCalls(content.String(), convertToolCallToAPI(toolCalls)), nil
 	}
 
-	return api.NewMessageWithAssistant(content.String()), nil
+	return api.NewMessageWithAssistant(content.String(), reasoning.String()), nil
 }
diff --git a/pkg/provider/provider.go b/pkg/provider/provider.go
index e14b7da..d387ced 100644
--- a/pkg/provider/provider.go
+++ b/pkg/provider/provider.go
@@ -7,8 +7,9 @@ import (
 )
 
 type Chunk struct {
-	Content    string
-	TokenCount uint
+	Content          string
+	ReasoningContent string
+	TokenCount       uint
 }
 
 type RequestParameters struct {
diff --git a/pkg/tui/views/chat/update.go b/pkg/tui/views/chat/update.go
index fa95972..7b81244 100644
--- a/pkg/tui/views/chat/update.go
+++ b/pkg/tui/views/chat/update.go
@@ -33,6 +33,14 @@ func (m *Model) setMessageContents(i int, content string) {
 	m.messageCache[i] = m.renderMessage(i)
 }
 
+func (m *Model) setReasoningContents(i int, content string) {
+	if i >= len(m.App.Messages) {
+		panic("i out of range")
+	}
+	m.App.Messages[i].ReasoningContent = content
+	m.messageCache[i] = m.renderMessage(i)
+}
+
 func (m *Model) rebuildMessageCache() {
 	m.messageCache = make([]string, len(m.App.Messages))
 	for i := range m.App.Messages {
@@ -108,7 +116,7 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
 	case msgChatResponseChunk:
 		cmds = append(cmds, m.waitForResponseChunk()) // wait for the next chunk
 
-		if msg.Content == "" {
+		if msg.Content == "" && msg.ReasoningContent == "" {
 			// skip empty chunks
 			break
 		}
@@ -116,19 +124,27 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
 		last := len(m.App.Messages) - 1
 		if last >= 0 && m.App.Messages[last].Role.IsAssistant() {
 			// append chunk to existing message
-			m.setMessageContents(last, m.App.Messages[last].Content+msg.Content)
+			if msg.Content != "" {
+				m.setMessageContents(last, m.App.Messages[last].Content+msg.Content)
+			}
+			if msg.ReasoningContent != "" {
+				m.setReasoningContents(last, m.App.Messages[last].ReasoningContent+msg.ReasoningContent)
+			}
 		} else {
 			// use chunk in a new message
 			m.addMessage(conversation.Message{
 				Role:    api.MessageRoleAssistant,
 				Content: msg.Content,
+				ReasoningContent: msg.ReasoningContent,
 			})
 		}
 		m.updateContent()
 
 		// show cursor and reset blink interval (simulate typing)
-		m.replyCursor.Blink = false
-		cmds = append(cmds, m.replyCursor.BlinkCmd())
+		if msg.ReasoningContent == "" || m.showDetails {
+			m.replyCursor.Blink = false
+			cmds = append(cmds, m.replyCursor.BlinkCmd())
+		}
 
 		m.tokenCount += msg.TokenCount
 		m.elapsed = time.Now().Sub(m.startTime)
@@ -137,6 +153,7 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
 
 		reply := conversation.Message(msg)
 		reply.Content = strings.TrimSpace(reply.Content)
+		reply.ReasoningContent = strings.TrimSpace(reply.ReasoningContent)
 
 		last := len(m.App.Messages) - 1
 		if last < 0 {
diff --git a/pkg/tui/views/chat/view.go b/pkg/tui/views/chat/view.go
index 8b364a7..0b686cb 100644
--- a/pkg/tui/views/chat/view.go
+++ b/pkg/tui/views/chat/view.go
@@ -116,19 +116,42 @@ func (m *Model) renderMessage(i int) string {
 
 	// Write message contents
 	sb := &strings.Builder{}
-	sb.Grow(len(msg.Content) * 2)
+	sb.Grow((len(msg.Content) + len(msg.ReasoningContent)) * 2)
+
+	isLast := i == len(m.App.Messages)-1
+	isAssistant := msg.Role == api.MessageRoleAssistant
+	hasReasoning := msg.ReasoningContent != ""
+
+	if hasReasoning {
+		reasoning := strings.Builder{}
+		reasoning.WriteString("\n")
+		if m.showDetails {
+			//_ = m.App.Ctx.Chroma.Highlight(sb, msg.ReasoningContent)
+			reasoning.WriteString(msg.ReasoningContent)
+		} else {
+			reasoning.WriteString("...")
+		}
+		if m.state == pendingResponse && isLast && isAssistant && msg.Content == "" {
+			// Show the assistant's cursor
+			reasoning.WriteString(m.replyCursor.View())
+		}
+		reasoning.WriteString("\n")
+		_ = m.App.Ctx.Chroma.Highlight(sb, reasoning.String())
+	}
+	if msg.Content != "" {
+		if hasReasoning {
+			sb.WriteString("\n\n")
+		}
 	err := m.App.Ctx.Chroma.Highlight(sb, msg.Content)
 	if err != nil {
+		// FIXME: on highlight error this reset also discards the rendered reasoning text above
 		sb.Reset()
 		sb.WriteString(msg.Content)
 	}
+	}
 
-	isLast := i == len(m.App.Messages)-1
-	isAssistant := msg.Role == api.MessageRoleAssistant
-
-	if m.state == pendingResponse && isLast && isAssistant {
+	if m.state == pendingResponse && isLast && isAssistant && (!hasReasoning || msg.Content != "") {
 		// Show the assistant's cursor
 		sb.WriteString(m.replyCursor.View())
 	}