Private
Public Access
1
0

Rough-in support for DeepSeek-style separate reasoning output

This commit is contained in:
2025-01-25 19:18:52 +00:00
parent fb3edad0c3
commit 9372c1d2c0
10 changed files with 94 additions and 39 deletions

View File

@@ -16,10 +16,11 @@ const (
) )
type Message struct { type Message struct {
Content string // TODO: support multi-part messages Content string // TODO: support multi-part messages
Role MessageRole ReasoningContent string
ToolCalls []ToolCall Role MessageRole
ToolResults []ToolResult ToolCalls []ToolCall
ToolResults []ToolResult
} }
type ToolSpec struct { type ToolSpec struct {
@@ -49,10 +50,11 @@ type ToolResult struct {
Result string `json:"result,omitempty" yaml:"result"` Result string `json:"result,omitempty" yaml:"result"`
} }
func NewMessageWithAssistant(content string) *Message { func NewMessageWithAssistant(content string, reasoning string) *Message {
return &Message{ return &Message{
Role: MessageRoleAssistant, Role: MessageRoleAssistant,
Content: content, Content: content,
ReasoningContent: reasoning,
} }
} }

View File

@@ -38,10 +38,11 @@ type Message struct {
SelectedReplyID *uint SelectedReplyID *uint
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"` SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
Role api.MessageRole Role api.MessageRole
Content string Content string
ToolCalls ToolCalls // a json array of tool calls (from the model) ReasoningContent string
ToolResults ToolResults // a json array of tool results ToolCalls ToolCalls // a json array of tool calls (from the model)
ToolResults ToolResults // a json array of tool results
} }
func (m *MessageMeta) Scan(value interface{}) error { func (m *MessageMeta) Scan(value interface{}) error {

View File

@@ -22,10 +22,11 @@ func ApplySystemPrompt(m []Message, system string, force bool) []Message {
func MessageToAPI(m Message) api.Message { func MessageToAPI(m Message) api.Message {
return api.Message{ return api.Message{
Role: m.Role, Role: m.Role,
Content: m.Content, Content: m.Content,
ToolCalls: m.ToolCalls, ReasoningContent: m.ReasoningContent,
ToolResults: m.ToolResults, ToolCalls: m.ToolCalls,
ToolResults: m.ToolResults,
} }
} }
@@ -39,10 +40,11 @@ func MessagesToAPI(messages []Message) []api.Message {
func MessageFromAPI(m api.Message) Message { func MessageFromAPI(m api.Message) Message {
return Message{ return Message{
Role: m.Role, Role: m.Role,
Content: m.Content, Content: m.Content,
ToolCalls: m.ToolCalls, ReasoningContent: m.ReasoningContent,
ToolResults: m.ToolResults, ToolCalls: m.ToolCalls,
ToolResults: m.ToolResults,
} }
} }

View File

@@ -443,5 +443,5 @@ func convertResponseToMessage(resp ChatCompletionResponse) (*api.Message, error)
return api.NewMessageWithToolCalls(content.String(), toolCalls), nil return api.NewMessageWithToolCalls(content.String(), toolCalls), nil
} }
return api.NewMessageWithAssistant(content.String()), nil return api.NewMessageWithAssistant(content.String(), ""), nil
} }

View File

@@ -340,7 +340,7 @@ func (c *Client) CreateChatCompletion(
return api.NewMessageWithToolCalls(content, convertToolCallToAPI(toolCalls)), nil return api.NewMessageWithToolCalls(content, convertToolCallToAPI(toolCalls)), nil
} }
return api.NewMessageWithAssistant(content), nil return api.NewMessageWithAssistant(content, ""), nil
} }
func (c *Client) CreateChatCompletionStream( func (c *Client) CreateChatCompletionStream(
@@ -432,5 +432,5 @@ func (c *Client) CreateChatCompletionStream(
return api.NewMessageWithToolCalls(content.String(), convertToolCallToAPI(toolCalls)), nil return api.NewMessageWithToolCalls(content.String(), convertToolCallToAPI(toolCalls)), nil
} }
return api.NewMessageWithAssistant(content.String()), nil return api.NewMessageWithAssistant(content.String(), ""), nil
} }

View File

@@ -115,7 +115,7 @@ func (c *OllamaClient) CreateChatCompletion(
return nil, err return nil, err
} }
return api.NewMessageWithAssistant(completionResp.Message.Content), nil return api.NewMessageWithAssistant(completionResp.Message.Content, ""), nil
} }
func (c *OllamaClient) CreateChatCompletionStream( func (c *OllamaClient) CreateChatCompletionStream(
@@ -179,5 +179,5 @@ func (c *OllamaClient) CreateChatCompletionStream(
} }
} }
return api.NewMessageWithAssistant(content.String()), nil return api.NewMessageWithAssistant(content.String(), ""), nil
} }

View File

@@ -21,10 +21,11 @@ type OpenAIClient struct {
} }
type ChatCompletionMessage struct { type ChatCompletionMessage struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content,omitempty"` Content string `json:"content,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"` ReasoningContent string `json:"reasoning_content,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
} }
type ToolCall struct { type ToolCall struct {
@@ -256,7 +257,7 @@ func (c *OpenAIClient) CreateChatCompletion(
return api.NewMessageWithToolCalls(content, convertToolCallToAPI(toolCalls)), nil return api.NewMessageWithToolCalls(content, convertToolCallToAPI(toolCalls)), nil
} }
return api.NewMessageWithAssistant(content), nil return api.NewMessageWithAssistant(content, ""), nil
} }
func (c *OpenAIClient) CreateChatCompletionStream( func (c *OpenAIClient) CreateChatCompletionStream(
@@ -279,6 +280,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
defer resp.Body.Close() defer resp.Body.Close()
content := strings.Builder{} content := strings.Builder{}
reasoning := strings.Builder{}
toolCalls := []ToolCall{} toolCalls := []ToolCall{}
lastMessage := messages[len(messages)-1] lastMessage := messages[len(messages)-1]
@@ -333,11 +335,18 @@ func (c *OpenAIClient) CreateChatCompletionStream(
} }
content.WriteString(delta.Content) content.WriteString(delta.Content)
} }
if len(delta.ReasoningContent) > 0 {
output <- provider.Chunk{
ReasoningContent: delta.ReasoningContent,
TokenCount: 1,
}
reasoning.WriteString(delta.ReasoningContent)
}
} }
if len(toolCalls) > 0 { if len(toolCalls) > 0 {
return api.NewMessageWithToolCalls(content.String(), convertToolCallToAPI(toolCalls)), nil return api.NewMessageWithToolCalls(content.String(), convertToolCallToAPI(toolCalls)), nil
} }
return api.NewMessageWithAssistant(content.String()), nil return api.NewMessageWithAssistant(content.String(), reasoning.String()), nil
} }

View File

@@ -7,8 +7,9 @@ import (
) )
type Chunk struct { type Chunk struct {
Content string Content string
TokenCount uint ReasoningContent string
TokenCount uint
} }
type RequestParameters struct { type RequestParameters struct {

View File

@@ -33,6 +33,14 @@ func (m *Model) setMessageContents(i int, content string) {
m.messageCache[i] = m.renderMessage(i) m.messageCache[i] = m.renderMessage(i)
} }
func (m *Model) setReasoningContents(i int, content string) {
if i >= len(m.App.Messages) {
panic("i out of range")
}
m.App.Messages[i].ReasoningContent = content
m.messageCache[i] = m.renderMessage(i)
}
func (m *Model) rebuildMessageCache() { func (m *Model) rebuildMessageCache() {
m.messageCache = make([]string, len(m.App.Messages)) m.messageCache = make([]string, len(m.App.Messages))
for i := range m.App.Messages { for i := range m.App.Messages {
@@ -108,7 +116,7 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
case msgChatResponseChunk: case msgChatResponseChunk:
cmds = append(cmds, m.waitForResponseChunk()) // wait for the next chunk cmds = append(cmds, m.waitForResponseChunk()) // wait for the next chunk
if msg.Content == "" { if msg.Content == "" && msg.ReasoningContent == "" {
// skip empty chunks // skip empty chunks
break break
} }
@@ -116,19 +124,27 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
last := len(m.App.Messages) - 1 last := len(m.App.Messages) - 1
if last >= 0 && m.App.Messages[last].Role.IsAssistant() { if last >= 0 && m.App.Messages[last].Role.IsAssistant() {
// append chunk to existing message // append chunk to existing message
m.setMessageContents(last, m.App.Messages[last].Content+msg.Content) if msg.Content != "" {
m.setMessageContents(last, m.App.Messages[last].Content+msg.Content)
}
if msg.ReasoningContent != "" {
m.setReasoningContents(last, m.App.Messages[last].ReasoningContent+msg.ReasoningContent)
}
} else { } else {
// use chunk in a new message // use chunk in a new message
m.addMessage(conversation.Message{ m.addMessage(conversation.Message{
Role: api.MessageRoleAssistant, Role: api.MessageRoleAssistant,
Content: msg.Content, Content: msg.Content,
ReasoningContent: msg.ReasoningContent,
}) })
} }
m.updateContent() m.updateContent()
// show cursor and reset blink interval (simulate typing) // show cursor and reset blink interval (simulate typing)
m.replyCursor.Blink = false if msg.ReasoningContent == "" || m.showDetails {
cmds = append(cmds, m.replyCursor.BlinkCmd()) m.replyCursor.Blink = false
cmds = append(cmds, m.replyCursor.BlinkCmd())
}
m.tokenCount += msg.TokenCount m.tokenCount += msg.TokenCount
m.elapsed = time.Now().Sub(m.startTime) m.elapsed = time.Now().Sub(m.startTime)
@@ -137,6 +153,7 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
reply := conversation.Message(msg) reply := conversation.Message(msg)
reply.Content = strings.TrimSpace(reply.Content) reply.Content = strings.TrimSpace(reply.Content)
reply.ReasoningContent = strings.TrimSpace(reply.ReasoningContent)
last := len(m.App.Messages) - 1 last := len(m.App.Messages) - 1
if last < 0 { if last < 0 {

View File

@@ -116,19 +116,42 @@ func (m *Model) renderMessage(i int) string {
// Write message contents // Write message contents
sb := &strings.Builder{} sb := &strings.Builder{}
sb.Grow(len(msg.Content) * 2) sb.Grow((len(msg.Content) + len(msg.ReasoningContent) * 2))
isLast := i == len(m.App.Messages)-1
isAssistant := msg.Role == api.MessageRoleAssistant
hasReasoning := msg.ReasoningContent != ""
if hasReasoning {
reasoning := strings.Builder{}
reasoning.WriteString("<thinking>\n")
if m.showDetails {
//_ = m.App.Ctx.Chroma.Highlight(sb, msg.ReasoningContent)
reasoning.WriteString(msg.ReasoningContent)
} else {
reasoning.WriteString("...")
}
if m.state == pendingResponse && isLast && isAssistant && msg.Content == "" {
// Show the assistant's cursor
reasoning.WriteString(m.replyCursor.View())
}
reasoning.WriteString("\n</thinking>")
_ = m.App.Ctx.Chroma.Highlight(sb, reasoning.String())
}
if msg.Content != "" { if msg.Content != "" {
if hasReasoning {
sb.WriteString("\n\n")
}
err := m.App.Ctx.Chroma.Highlight(sb, msg.Content) err := m.App.Ctx.Chroma.Highlight(sb, msg.Content)
if err != nil { if err != nil {
// This would wipe out the thinking text
sb.Reset() sb.Reset()
sb.WriteString(msg.Content) sb.WriteString(msg.Content)
} }
} }
isLast := i == len(m.App.Messages)-1 if m.state == pendingResponse && isLast && isAssistant && (!hasReasoning || msg.Content != "") {
isAssistant := msg.Role == api.MessageRoleAssistant
if m.state == pendingResponse && isLast && isAssistant {
// Show the assistant's cursor // Show the assistant's cursor
sb.WriteString(m.replyCursor.View()) sb.WriteString(m.replyCursor.View())
} }