Private
Public Access
1
0

Package restructure and API changes, several fixes

- More emphasis on `api` package. It now holds database model structs
  from `lmcli/models` (which is now gone) as well as the tool spec,
  call, and result types. `tools.Tool` is now `api.ToolSpec`.
  `api.ChatCompletionClient` was renamed to
  `api.ChatCompletionProvider`.

- Change ChatCompletion interface and implementations to no longer do
  automatic tool call recursion - they simply return a ToolCall message,
  leaving the caller to decide what to do with it (e.g. prompt for user
  confirmation before executing it)

- `api.ChatCompletionProvider` functions have had their ReplyCallback
  parameter removed, as now they only return a single reply.

- Added a top-level `agent` package, moved the current built-in tools
  implementations under `agent/toolbox`. `tools.ExecuteToolCalls` is now
  `agent.ExecuteToolCalls`.

- Fixed request context handling in openai, google, ollama (use
  `NewRequestWithContext`), cleaned up request cancellation in TUI

- Fix tool call TUI persistence bug (we were skipping messages with empty
  content)

- Now handle tool calling from TUI layer

TODO:
- Prompt users before executing tool calls
- Automatically send tool results to the model (or make this toggleable)
This commit is contained in:
2024-06-12 08:35:07 +00:00
parent 85a2abbbf3
commit 3fde58b77d
35 changed files with 608 additions and 749 deletions

View File

@@ -11,11 +11,9 @@ import (
"strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
)
func convertTools(tools []model.Tool) []Tool {
func convertTools(tools []api.ToolSpec) []Tool {
geminiTools := make([]Tool, len(tools))
for i, tool := range tools {
params := make(map[string]ToolParameter)
@@ -50,7 +48,7 @@ func convertTools(tools []model.Tool) []Tool {
return geminiTools
}
func convertToolCallToGemini(toolCalls []model.ToolCall) []ContentPart {
func convertToolCallToGemini(toolCalls []api.ToolCall) []ContentPart {
converted := make([]ContentPart, len(toolCalls))
for i, call := range toolCalls {
args := make(map[string]string)
@@ -65,8 +63,8 @@ func convertToolCallToGemini(toolCalls []model.ToolCall) []ContentPart {
return converted
}
func convertToolCallToAPI(functionCalls []FunctionCall) []model.ToolCall {
converted := make([]model.ToolCall, len(functionCalls))
func convertToolCallToAPI(functionCalls []FunctionCall) []api.ToolCall {
converted := make([]api.ToolCall, len(functionCalls))
for i, call := range functionCalls {
params := make(map[string]interface{})
for k, v := range call.Args {
@@ -78,7 +76,7 @@ func convertToolCallToAPI(functionCalls []FunctionCall) []model.ToolCall {
return converted
}
func convertToolResultsToGemini(toolResults []model.ToolResult) ([]FunctionResponse, error) {
func convertToolResultsToGemini(toolResults []api.ToolResult) ([]FunctionResponse, error) {
results := make([]FunctionResponse, len(toolResults))
for i, result := range toolResults {
var obj interface{}
@@ -95,14 +93,14 @@ func convertToolResultsToGemini(toolResults []model.ToolResult) ([]FunctionRespo
}
func createGenerateContentRequest(
params model.RequestParameters,
messages []model.Message,
params api.RequestParameters,
messages []api.Message,
) (*GenerateContentRequest, error) {
requestContents := make([]Content, 0, len(messages))
startIdx := 0
var system string
if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem {
if len(messages) > 0 && messages[0].Role == api.MessageRoleSystem {
system = messages[0].Content
startIdx = 1
}
@@ -135,9 +133,9 @@ func createGenerateContentRequest(
default:
var role string
switch m.Role {
case model.MessageRoleAssistant:
case api.MessageRoleAssistant:
role = "model"
case model.MessageRoleUser:
case api.MessageRoleUser:
role = "user"
}
@@ -183,55 +181,14 @@ func createGenerateContentRequest(
return request, nil
}
func handleToolCalls(
params model.RequestParameters,
content string,
toolCalls []model.ToolCall,
callback api.ReplyCallback,
messages []model.Message,
) ([]model.Message, error) {
lastMessage := messages[len(messages)-1]
continuation := false
if lastMessage.Role.IsAssistant() {
continuation = true
}
toolCall := model.Message{
Role: model.MessageRoleToolCall,
Content: content,
ToolCalls: toolCalls,
}
toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
if err != nil {
return nil, err
}
toolResult := model.Message{
Role: model.MessageRoleToolResult,
ToolResults: toolResults,
}
if callback != nil {
callback(toolCall)
callback(toolResult)
}
if continuation {
messages[len(messages)-1] = toolCall
} else {
messages = append(messages, toolCall)
}
messages = append(messages, toolResult)
return messages, nil
}
func (c *Client) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
func (c *Client) sendRequest(req *http.Request) (*http.Response, error) {
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req.WithContext(ctx))
resp, err := client.Do(req)
if err != nil {
return nil, err
}
if resp.StatusCode != 200 {
bytes, _ := io.ReadAll(resp.Body)
@@ -243,42 +200,41 @@ func (c *Client) sendRequest(ctx context.Context, req *http.Request) (*http.Resp
func (c *Client) CreateChatCompletion(
ctx context.Context,
params model.RequestParameters,
messages []model.Message,
callback api.ReplyCallback,
) (string, error) {
params api.RequestParameters,
messages []api.Message,
) (*api.Message, error) {
if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages")
return nil, fmt.Errorf("Can't create completion from no messages")
}
req, err := createGenerateContentRequest(params, messages)
if err != nil {
return "", err
return nil, err
}
jsonData, err := json.Marshal(req)
if err != nil {
return "", err
return nil, err
}
url := fmt.Sprintf(
"%s/v1beta/models/%s:generateContent?key=%s",
c.BaseURL, params.Model, c.APIKey,
)
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return "", err
return nil, err
}
resp, err := c.sendRequest(ctx, httpReq)
resp, err := c.sendRequest(httpReq)
if err != nil {
return "", err
return nil, err
}
defer resp.Body.Close()
var completionResp GenerateContentResponse
err = json.NewDecoder(resp.Body).Decode(&completionResp)
if err != nil {
return "", err
return nil, err
}
choice := completionResp.Candidates[0]
@@ -301,58 +257,50 @@ func (c *Client) CreateChatCompletion(
}
if len(toolCalls) > 0 {
messages, err := handleToolCalls(
params, content, convertToolCallToAPI(toolCalls), callback, messages,
)
if err != nil {
return content, err
}
return c.CreateChatCompletion(ctx, params, messages, callback)
return &api.Message{
Role: api.MessageRoleToolCall,
Content: content,
ToolCalls: convertToolCallToAPI(toolCalls),
}, nil
}
if callback != nil {
callback(model.Message{
Role: model.MessageRoleAssistant,
Content: content,
})
}
return content, nil
return &api.Message{
Role: api.MessageRoleAssistant,
Content: content,
}, nil
}
func (c *Client) CreateChatCompletionStream(
ctx context.Context,
params model.RequestParameters,
messages []model.Message,
callback api.ReplyCallback,
params api.RequestParameters,
messages []api.Message,
output chan<- api.Chunk,
) (string, error) {
) (*api.Message, error) {
if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages")
return nil, fmt.Errorf("Can't create completion from no messages")
}
req, err := createGenerateContentRequest(params, messages)
if err != nil {
return "", err
return nil, err
}
jsonData, err := json.Marshal(req)
if err != nil {
return "", err
return nil, err
}
url := fmt.Sprintf(
"%s/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse",
c.BaseURL, params.Model, c.APIKey,
)
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return "", err
return nil, err
}
resp, err := c.sendRequest(ctx, httpReq)
resp, err := c.sendRequest(httpReq)
if err != nil {
return "", err
return nil, err
}
defer resp.Body.Close()
@@ -374,7 +322,7 @@ func (c *Client) CreateChatCompletionStream(
if err == io.EOF {
break
}
return "", err
return nil, err
}
line = bytes.TrimSpace(line)
@@ -387,7 +335,7 @@ func (c *Client) CreateChatCompletionStream(
var resp GenerateContentResponse
err = json.Unmarshal(line, &resp)
if err != nil {
return "", err
return nil, err
}
tokens := resp.UsageMetadata.CandidatesTokenCount - lastTokenCount
@@ -409,21 +357,15 @@ func (c *Client) CreateChatCompletionStream(
// If there are function calls, handle them and recurse
if len(toolCalls) > 0 {
messages, err := handleToolCalls(
params, content.String(), convertToolCallToAPI(toolCalls), callback, messages,
)
if err != nil {
return content.String(), err
}
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
return &api.Message{
Role: api.MessageRoleToolCall,
Content: content.String(),
ToolCalls: convertToolCallToAPI(toolCalls),
}, nil
}
if callback != nil {
callback(model.Message{
Role: model.MessageRoleAssistant,
Content: content.String(),
})
}
return content.String(), nil
return &api.Message{
Role: api.MessageRoleAssistant,
Content: content.String(),
}, nil
}