Private
Public Access
1
0

Package restructure and API changes, several fixes

- More emphasis on `api` package. It now holds database model structs
  from `lmcli/models` (which is now gone) as well as the tool spec,
  call, and result types. `tools.Tool` is now `api.ToolSpec`.
  `api.ChatCompletionClient` was renamed to
  `api.ChatCompletionProvider`.

- Change ChatCompletion interface and implementations to no longer do
  automatic tool call recursion - they simply return a ToolCall message
  which the caller can decide what to do with (e.g. prompt for user
  confirmation before executing)

- `api.ChatCompletionProvider` functions have had their ReplyCallback
  parameter removed, as now they only return a single reply.

- Added a top-level `agent` package, moved the current built-in tools
  implementations under `agent/toolbox`. `tools.ExecuteToolCalls` is now
  `agent.ExecuteToolCalls`.

- Fixed request context handling in openai, google, ollama (use
  `NewRequestWithContext`), cleaned up request cancellation in TUI

- Fix tool call tui persistence bug (we were skipping message with empty
  content)

- Now handle tool calling from TUI layer

TODO:
- Prompt users before executing tool calls
- Automatically send tool results to the model (or make this toggleable)
This commit is contained in:
2024-06-12 08:35:07 +00:00
parent 85a2abbbf3
commit 3fde58b77d
35 changed files with 608 additions and 749 deletions

View File

@@ -11,11 +11,9 @@ import (
"strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
)
func buildRequest(params model.RequestParameters, messages []model.Message) Request {
func buildRequest(params api.RequestParameters, messages []api.Message) Request {
requestBody := Request{
Model: params.Model,
Messages: make([]Message, len(messages)),
@@ -30,7 +28,7 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
}
startIdx := 0
if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem {
if len(messages) > 0 && messages[0].Role == api.MessageRoleSystem {
requestBody.System = messages[0].Content
requestBody.Messages = requestBody.Messages[1:]
startIdx = 1
@@ -48,7 +46,7 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
message := &requestBody.Messages[i]
switch msg.Role {
case model.MessageRoleToolCall:
case api.MessageRoleToolCall:
message.Role = "assistant"
if msg.Content != "" {
message.Content = msg.Content
@@ -63,7 +61,7 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
} else {
message.Content = xmlString
}
case model.MessageRoleToolResult:
case api.MessageRoleToolResult:
xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults)
xmlString, err := xmlFuncResults.XMLString()
if err != nil {
@@ -105,26 +103,25 @@ func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Resp
func (c *AnthropicClient) CreateChatCompletion(
ctx context.Context,
params model.RequestParameters,
messages []model.Message,
callback api.ReplyCallback,
) (string, error) {
params api.RequestParameters,
messages []api.Message,
) (*api.Message, error) {
if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages")
return nil, fmt.Errorf("Can't create completion from no messages")
}
request := buildRequest(params, messages)
resp, err := sendRequest(ctx, c, request)
if err != nil {
return "", err
return nil, err
}
defer resp.Body.Close()
var response Response
err = json.NewDecoder(resp.Body).Decode(&response)
if err != nil {
return "", fmt.Errorf("failed to decode response: %v", err)
return nil, fmt.Errorf("failed to decode response: %v", err)
}
sb := strings.Builder{}
@@ -137,34 +134,28 @@ func (c *AnthropicClient) CreateChatCompletion(
}
for _, content := range response.Content {
var reply model.Message
switch content.Type {
case "text":
reply = model.Message{
Role: model.MessageRoleAssistant,
Content: content.Text,
}
sb.WriteString(reply.Content)
sb.WriteString(content.Text)
default:
return "", fmt.Errorf("unsupported message type: %s", content.Type)
}
if callback != nil {
callback(reply)
return nil, fmt.Errorf("unsupported message type: %s", content.Type)
}
}
return sb.String(), nil
return &api.Message{
Role: api.MessageRoleAssistant,
Content: sb.String(),
}, nil
}
func (c *AnthropicClient) CreateChatCompletionStream(
ctx context.Context,
params model.RequestParameters,
messages []model.Message,
callback api.ReplyCallback,
params api.RequestParameters,
messages []api.Message,
output chan<- api.Chunk,
) (string, error) {
) (*api.Message, error) {
if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages")
return nil, fmt.Errorf("Can't create completion from no messages")
}
request := buildRequest(params, messages)
@@ -172,19 +163,18 @@ func (c *AnthropicClient) CreateChatCompletionStream(
resp, err := sendRequest(ctx, c, request)
if err != nil {
return "", err
return nil, err
}
defer resp.Body.Close()
sb := strings.Builder{}
lastMessage := messages[len(messages)-1]
continuation := false
if messages[len(messages)-1].Role.IsAssistant() {
// this is a continuation of a previous assistant reply, so we'll
// include its contents in the final result
// TODO: handle this at higher level
sb.WriteString(lastMessage.Content)
continuation = true
}
scanner := bufio.NewScanner(resp.Body)
@@ -200,29 +190,29 @@ func (c *AnthropicClient) CreateChatCompletionStream(
var event map[string]interface{}
err := json.Unmarshal([]byte(line), &event)
if err != nil {
return "", fmt.Errorf("failed to unmarshal event data '%s': %v", line, err)
return nil, fmt.Errorf("failed to unmarshal event data '%s': %v", line, err)
}
eventType, ok := event["type"].(string)
if !ok {
return "", fmt.Errorf("invalid event: %s", line)
return nil, fmt.Errorf("invalid event: %s", line)
}
switch eventType {
case "error":
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
return nil, fmt.Errorf("an error occurred: %s", event["error"])
default:
return sb.String(), fmt.Errorf("unknown event type: %s", eventType)
return nil, fmt.Errorf("unknown event type: %s", eventType)
}
} else if strings.HasPrefix(line, "data:") {
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
var event map[string]interface{}
err := json.Unmarshal([]byte(data), &event)
if err != nil {
return "", fmt.Errorf("failed to unmarshal event data: %v", err)
return nil, fmt.Errorf("failed to unmarshal event data: %v", err)
}
eventType, ok := event["type"].(string)
if !ok {
return "", fmt.Errorf("invalid event type")
return nil, fmt.Errorf("invalid event type")
}
switch eventType {
@@ -235,15 +225,15 @@ func (c *AnthropicClient) CreateChatCompletionStream(
case "content_block_delta":
delta, ok := event["delta"].(map[string]interface{})
if !ok {
return "", fmt.Errorf("invalid content block delta")
return nil, fmt.Errorf("invalid content block delta")
}
text, ok := delta["text"].(string)
if !ok {
return "", fmt.Errorf("invalid text delta")
return nil, fmt.Errorf("invalid text delta")
}
sb.WriteString(text)
output <- api.Chunk{
Content: text,
Content: text,
TokenCount: 1,
}
case "content_block_stop":
@@ -251,7 +241,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
case "message_delta":
delta, ok := event["delta"].(map[string]interface{})
if !ok {
return "", fmt.Errorf("invalid message delta")
return nil, fmt.Errorf("invalid message delta")
}
stopReason, ok := delta["stop_reason"].(string)
if ok && stopReason == "stop_sequence" {
@@ -261,67 +251,39 @@ func (c *AnthropicClient) CreateChatCompletionStream(
start := strings.Index(content, "<function_calls>")
if start == -1 {
return content, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
return nil, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
}
sb.WriteString(FUNCTION_STOP_SEQUENCE)
output <- api.Chunk{
Content: FUNCTION_STOP_SEQUENCE,
Content: FUNCTION_STOP_SEQUENCE,
TokenCount: 1,
}
funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE
var functionCalls XMLFunctionCalls
err := xml.Unmarshal([]byte(funcCallXml), &functionCalls)
if err != nil {
return "", fmt.Errorf("failed to unmarshal function_calls: %v", err)
return nil, fmt.Errorf("failed to unmarshal function_calls: %v", err)
}
toolCall := model.Message{
Role: model.MessageRoleToolCall,
return &api.Message{
Role: api.MessageRoleToolCall,
// function call xml stripped from content for model interop
Content: strings.TrimSpace(content[:start]),
ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
}
toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
if err != nil {
return "", err
}
toolResult := model.Message{
Role: model.MessageRoleToolResult,
ToolResults: toolResults,
}
if callback != nil {
callback(toolCall)
callback(toolResult)
}
if continuation {
messages[len(messages)-1] = toolCall
} else {
messages = append(messages, toolCall)
}
messages = append(messages, toolResult)
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
}, nil
}
}
case "message_stop":
// return the completed message
content := sb.String()
if callback != nil {
callback(model.Message{
Role: model.MessageRoleAssistant,
Content: content,
})
}
return content, nil
return &api.Message{
Role: api.MessageRoleAssistant,
Content: content,
}, nil
case "error":
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
return nil, fmt.Errorf("an error occurred: %s", event["error"])
default:
fmt.Printf("\nUnrecognized event: %s\n", data)
}
@@ -329,8 +291,8 @@ func (c *AnthropicClient) CreateChatCompletionStream(
}
if err := scanner.Err(); err != nil {
return "", fmt.Errorf("failed to read response body: %v", err)
return nil, fmt.Errorf("failed to read response body: %v", err)
}
return "", fmt.Errorf("unexpected end of stream")
return nil, fmt.Errorf("unexpected end of stream")
}

View File

@@ -6,7 +6,7 @@ import (
"strings"
"text/template"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/api"
)
const FUNCTION_STOP_SEQUENCE = "</function_calls>"
@@ -97,7 +97,7 @@ func parseFunctionParametersXML(params string) map[string]interface{} {
return ret
}
func convertToolsToXMLTools(tools []model.Tool) XMLTools {
func convertToolsToXMLTools(tools []api.ToolSpec) XMLTools {
converted := make([]XMLToolDescription, len(tools))
for i, tool := range tools {
converted[i].ToolName = tool.Name
@@ -117,8 +117,8 @@ func convertToolsToXMLTools(tools []model.Tool) XMLTools {
}
}
func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []model.ToolCall {
toolCalls := make([]model.ToolCall, len(functionCalls.Invoke))
func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []api.ToolCall {
toolCalls := make([]api.ToolCall, len(functionCalls.Invoke))
for i, invoke := range functionCalls.Invoke {
toolCalls[i].Name = invoke.ToolName
toolCalls[i].Parameters = parseFunctionParametersXML(invoke.Parameters.String)
@@ -126,7 +126,7 @@ func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []model.
return toolCalls
}
func convertToolCallsToXMLFunctionCalls(toolCalls []model.ToolCall) XMLFunctionCalls {
func convertToolCallsToXMLFunctionCalls(toolCalls []api.ToolCall) XMLFunctionCalls {
converted := make([]XMLFunctionInvoke, len(toolCalls))
for i, toolCall := range toolCalls {
var params XMLFunctionInvokeParameters
@@ -145,7 +145,7 @@ func convertToolCallsToXMLFunctionCalls(toolCalls []model.ToolCall) XMLFunctionC
}
}
func convertToolResultsToXMLFunctionResult(toolResults []model.ToolResult) XMLFunctionResults {
func convertToolResultsToXMLFunctionResult(toolResults []api.ToolResult) XMLFunctionResults {
converted := make([]XMLFunctionResult, len(toolResults))
for i, result := range toolResults {
converted[i].ToolName = result.ToolName
@@ -156,11 +156,11 @@ func convertToolResultsToXMLFunctionResult(toolResults []model.ToolResult) XMLFu
}
}
func buildToolsSystemPrompt(tools []model.Tool) string {
func buildToolsSystemPrompt(tools []api.ToolSpec) string {
xmlTools := convertToolsToXMLTools(tools)
xmlToolsString, err := xmlTools.XMLString()
if err != nil {
panic("Could not serialize []model.Tool to XMLTools")
panic("Could not serialize []api.Tool to XMLTools")
}
return TOOL_PREAMBLE + "\n\n" + xmlToolsString + "\n\n" + TOOL_PREAMBLE_FOOTER
}