Package restructure and API changes, several fixes
- More emphasis on `api` package. It now holds database model structs from `lmcli/models` (which is now gone) as well as the tool spec, call, and result types. `tools.Tool` is now `api.ToolSpec`. `api.ChatCompletionClient` was renamed to `api.ChatCompletionProvider`. - Change ChatCompletion interface and implementations to no longer do automatic tool call recursion - they simply return a ToolCall message which the caller can decide what to do with (e.g. prompt for user confirmation before executing) - `api.ChatCompletionProvider` functions have had their ReplyCallback parameter removed, as now they only return a single reply. - Added a top-level `agent` package, moved the current built-in tools implementations under `agent/toolbox`. `tools.ExecuteToolCalls` is now `agent.ExecuteToolCalls`. - Fixed request context handling in openai, google, ollama (use `NewRequestWithContext`), cleaned up request cancellation in TUI - Fix tool call tui persistence bug (we were skipping message with empty content) - Now handle tool calling from TUI layer TODO: - Prompt users before executing tool calls - Automatically send tool results to the model (or make this toggleable)
This commit is contained in:
@@ -11,11 +11,9 @@ import (
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||
)
|
||||
|
||||
func buildRequest(params model.RequestParameters, messages []model.Message) Request {
|
||||
func buildRequest(params api.RequestParameters, messages []api.Message) Request {
|
||||
requestBody := Request{
|
||||
Model: params.Model,
|
||||
Messages: make([]Message, len(messages)),
|
||||
@@ -30,7 +28,7 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
|
||||
}
|
||||
|
||||
startIdx := 0
|
||||
if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem {
|
||||
if len(messages) > 0 && messages[0].Role == api.MessageRoleSystem {
|
||||
requestBody.System = messages[0].Content
|
||||
requestBody.Messages = requestBody.Messages[1:]
|
||||
startIdx = 1
|
||||
@@ -48,7 +46,7 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
|
||||
message := &requestBody.Messages[i]
|
||||
|
||||
switch msg.Role {
|
||||
case model.MessageRoleToolCall:
|
||||
case api.MessageRoleToolCall:
|
||||
message.Role = "assistant"
|
||||
if msg.Content != "" {
|
||||
message.Content = msg.Content
|
||||
@@ -63,7 +61,7 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
|
||||
} else {
|
||||
message.Content = xmlString
|
||||
}
|
||||
case model.MessageRoleToolResult:
|
||||
case api.MessageRoleToolResult:
|
||||
xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults)
|
||||
xmlString, err := xmlFuncResults.XMLString()
|
||||
if err != nil {
|
||||
@@ -105,26 +103,25 @@ func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Resp
|
||||
|
||||
func (c *AnthropicClient) CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
callback api.ReplyCallback,
|
||||
) (string, error) {
|
||||
params api.RequestParameters,
|
||||
messages []api.Message,
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return "", fmt.Errorf("Can't create completion from no messages")
|
||||
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||
}
|
||||
|
||||
request := buildRequest(params, messages)
|
||||
|
||||
resp, err := sendRequest(ctx, c, request)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var response Response
|
||||
err = json.NewDecoder(resp.Body).Decode(&response)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decode response: %v", err)
|
||||
return nil, fmt.Errorf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
sb := strings.Builder{}
|
||||
@@ -137,34 +134,28 @@ func (c *AnthropicClient) CreateChatCompletion(
|
||||
}
|
||||
|
||||
for _, content := range response.Content {
|
||||
var reply model.Message
|
||||
switch content.Type {
|
||||
case "text":
|
||||
reply = model.Message{
|
||||
Role: model.MessageRoleAssistant,
|
||||
Content: content.Text,
|
||||
}
|
||||
sb.WriteString(reply.Content)
|
||||
sb.WriteString(content.Text)
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported message type: %s", content.Type)
|
||||
}
|
||||
if callback != nil {
|
||||
callback(reply)
|
||||
return nil, fmt.Errorf("unsupported message type: %s", content.Type)
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleAssistant,
|
||||
Content: sb.String(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *AnthropicClient) CreateChatCompletionStream(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
callback api.ReplyCallback,
|
||||
params api.RequestParameters,
|
||||
messages []api.Message,
|
||||
output chan<- api.Chunk,
|
||||
) (string, error) {
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return "", fmt.Errorf("Can't create completion from no messages")
|
||||
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||
}
|
||||
|
||||
request := buildRequest(params, messages)
|
||||
@@ -172,19 +163,18 @@ func (c *AnthropicClient) CreateChatCompletionStream(
|
||||
|
||||
resp, err := sendRequest(ctx, c, request)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
sb := strings.Builder{}
|
||||
|
||||
lastMessage := messages[len(messages)-1]
|
||||
continuation := false
|
||||
if messages[len(messages)-1].Role.IsAssistant() {
|
||||
// this is a continuation of a previous assistant reply, so we'll
|
||||
// include its contents in the final result
|
||||
// TODO: handle this at higher level
|
||||
sb.WriteString(lastMessage.Content)
|
||||
continuation = true
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
@@ -200,29 +190,29 @@ func (c *AnthropicClient) CreateChatCompletionStream(
|
||||
var event map[string]interface{}
|
||||
err := json.Unmarshal([]byte(line), &event)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to unmarshal event data '%s': %v", line, err)
|
||||
return nil, fmt.Errorf("failed to unmarshal event data '%s': %v", line, err)
|
||||
}
|
||||
eventType, ok := event["type"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("invalid event: %s", line)
|
||||
return nil, fmt.Errorf("invalid event: %s", line)
|
||||
}
|
||||
switch eventType {
|
||||
case "error":
|
||||
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
|
||||
return nil, fmt.Errorf("an error occurred: %s", event["error"])
|
||||
default:
|
||||
return sb.String(), fmt.Errorf("unknown event type: %s", eventType)
|
||||
return nil, fmt.Errorf("unknown event type: %s", eventType)
|
||||
}
|
||||
} else if strings.HasPrefix(line, "data:") {
|
||||
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
var event map[string]interface{}
|
||||
err := json.Unmarshal([]byte(data), &event)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to unmarshal event data: %v", err)
|
||||
return nil, fmt.Errorf("failed to unmarshal event data: %v", err)
|
||||
}
|
||||
|
||||
eventType, ok := event["type"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("invalid event type")
|
||||
return nil, fmt.Errorf("invalid event type")
|
||||
}
|
||||
|
||||
switch eventType {
|
||||
@@ -235,15 +225,15 @@ func (c *AnthropicClient) CreateChatCompletionStream(
|
||||
case "content_block_delta":
|
||||
delta, ok := event["delta"].(map[string]interface{})
|
||||
if !ok {
|
||||
return "", fmt.Errorf("invalid content block delta")
|
||||
return nil, fmt.Errorf("invalid content block delta")
|
||||
}
|
||||
text, ok := delta["text"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("invalid text delta")
|
||||
return nil, fmt.Errorf("invalid text delta")
|
||||
}
|
||||
sb.WriteString(text)
|
||||
output <- api.Chunk{
|
||||
Content: text,
|
||||
Content: text,
|
||||
TokenCount: 1,
|
||||
}
|
||||
case "content_block_stop":
|
||||
@@ -251,7 +241,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
|
||||
case "message_delta":
|
||||
delta, ok := event["delta"].(map[string]interface{})
|
||||
if !ok {
|
||||
return "", fmt.Errorf("invalid message delta")
|
||||
return nil, fmt.Errorf("invalid message delta")
|
||||
}
|
||||
stopReason, ok := delta["stop_reason"].(string)
|
||||
if ok && stopReason == "stop_sequence" {
|
||||
@@ -261,67 +251,39 @@ func (c *AnthropicClient) CreateChatCompletionStream(
|
||||
|
||||
start := strings.Index(content, "<function_calls>")
|
||||
if start == -1 {
|
||||
return content, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
|
||||
return nil, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
|
||||
}
|
||||
|
||||
sb.WriteString(FUNCTION_STOP_SEQUENCE)
|
||||
output <- api.Chunk{
|
||||
Content: FUNCTION_STOP_SEQUENCE,
|
||||
Content: FUNCTION_STOP_SEQUENCE,
|
||||
TokenCount: 1,
|
||||
}
|
||||
|
||||
funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE
|
||||
|
||||
var functionCalls XMLFunctionCalls
|
||||
err := xml.Unmarshal([]byte(funcCallXml), &functionCalls)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to unmarshal function_calls: %v", err)
|
||||
return nil, fmt.Errorf("failed to unmarshal function_calls: %v", err)
|
||||
}
|
||||
|
||||
toolCall := model.Message{
|
||||
Role: model.MessageRoleToolCall,
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleToolCall,
|
||||
// function call xml stripped from content for model interop
|
||||
Content: strings.TrimSpace(content[:start]),
|
||||
ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
|
||||
}
|
||||
|
||||
toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
toolResult := model.Message{
|
||||
Role: model.MessageRoleToolResult,
|
||||
ToolResults: toolResults,
|
||||
}
|
||||
|
||||
if callback != nil {
|
||||
callback(toolCall)
|
||||
callback(toolResult)
|
||||
}
|
||||
|
||||
if continuation {
|
||||
messages[len(messages)-1] = toolCall
|
||||
} else {
|
||||
messages = append(messages, toolCall)
|
||||
}
|
||||
|
||||
messages = append(messages, toolResult)
|
||||
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
case "message_stop":
|
||||
// return the completed message
|
||||
content := sb.String()
|
||||
if callback != nil {
|
||||
callback(model.Message{
|
||||
Role: model.MessageRoleAssistant,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
return content, nil
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleAssistant,
|
||||
Content: content,
|
||||
}, nil
|
||||
case "error":
|
||||
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
|
||||
return nil, fmt.Errorf("an error occurred: %s", event["error"])
|
||||
default:
|
||||
fmt.Printf("\nUnrecognized event: %s\n", data)
|
||||
}
|
||||
@@ -329,8 +291,8 @@ func (c *AnthropicClient) CreateChatCompletionStream(
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return "", fmt.Errorf("failed to read response body: %v", err)
|
||||
return nil, fmt.Errorf("failed to read response body: %v", err)
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("unexpected end of stream")
|
||||
return nil, fmt.Errorf("unexpected end of stream")
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
)
|
||||
|
||||
const FUNCTION_STOP_SEQUENCE = "</function_calls>"
|
||||
@@ -97,7 +97,7 @@ func parseFunctionParametersXML(params string) map[string]interface{} {
|
||||
return ret
|
||||
}
|
||||
|
||||
func convertToolsToXMLTools(tools []model.Tool) XMLTools {
|
||||
func convertToolsToXMLTools(tools []api.ToolSpec) XMLTools {
|
||||
converted := make([]XMLToolDescription, len(tools))
|
||||
for i, tool := range tools {
|
||||
converted[i].ToolName = tool.Name
|
||||
@@ -117,8 +117,8 @@ func convertToolsToXMLTools(tools []model.Tool) XMLTools {
|
||||
}
|
||||
}
|
||||
|
||||
func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []model.ToolCall {
|
||||
toolCalls := make([]model.ToolCall, len(functionCalls.Invoke))
|
||||
func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []api.ToolCall {
|
||||
toolCalls := make([]api.ToolCall, len(functionCalls.Invoke))
|
||||
for i, invoke := range functionCalls.Invoke {
|
||||
toolCalls[i].Name = invoke.ToolName
|
||||
toolCalls[i].Parameters = parseFunctionParametersXML(invoke.Parameters.String)
|
||||
@@ -126,7 +126,7 @@ func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []model.
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
func convertToolCallsToXMLFunctionCalls(toolCalls []model.ToolCall) XMLFunctionCalls {
|
||||
func convertToolCallsToXMLFunctionCalls(toolCalls []api.ToolCall) XMLFunctionCalls {
|
||||
converted := make([]XMLFunctionInvoke, len(toolCalls))
|
||||
for i, toolCall := range toolCalls {
|
||||
var params XMLFunctionInvokeParameters
|
||||
@@ -145,7 +145,7 @@ func convertToolCallsToXMLFunctionCalls(toolCalls []model.ToolCall) XMLFunctionC
|
||||
}
|
||||
}
|
||||
|
||||
func convertToolResultsToXMLFunctionResult(toolResults []model.ToolResult) XMLFunctionResults {
|
||||
func convertToolResultsToXMLFunctionResult(toolResults []api.ToolResult) XMLFunctionResults {
|
||||
converted := make([]XMLFunctionResult, len(toolResults))
|
||||
for i, result := range toolResults {
|
||||
converted[i].ToolName = result.ToolName
|
||||
@@ -156,11 +156,11 @@ func convertToolResultsToXMLFunctionResult(toolResults []model.ToolResult) XMLFu
|
||||
}
|
||||
}
|
||||
|
||||
func buildToolsSystemPrompt(tools []model.Tool) string {
|
||||
func buildToolsSystemPrompt(tools []api.ToolSpec) string {
|
||||
xmlTools := convertToolsToXMLTools(tools)
|
||||
xmlToolsString, err := xmlTools.XMLString()
|
||||
if err != nil {
|
||||
panic("Could not serialize []model.Tool to XMLTools")
|
||||
panic("Could not serialize []api.Tool to XMLTools")
|
||||
}
|
||||
return TOOL_PREAMBLE + "\n\n" + xmlToolsString + "\n\n" + TOOL_PREAMBLE_FOOTER
|
||||
}
|
||||
|
||||
@@ -11,11 +11,9 @@ import (
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||
)
|
||||
|
||||
func convertTools(tools []model.Tool) []Tool {
|
||||
func convertTools(tools []api.ToolSpec) []Tool {
|
||||
geminiTools := make([]Tool, len(tools))
|
||||
for i, tool := range tools {
|
||||
params := make(map[string]ToolParameter)
|
||||
@@ -50,7 +48,7 @@ func convertTools(tools []model.Tool) []Tool {
|
||||
return geminiTools
|
||||
}
|
||||
|
||||
func convertToolCallToGemini(toolCalls []model.ToolCall) []ContentPart {
|
||||
func convertToolCallToGemini(toolCalls []api.ToolCall) []ContentPart {
|
||||
converted := make([]ContentPart, len(toolCalls))
|
||||
for i, call := range toolCalls {
|
||||
args := make(map[string]string)
|
||||
@@ -65,8 +63,8 @@ func convertToolCallToGemini(toolCalls []model.ToolCall) []ContentPart {
|
||||
return converted
|
||||
}
|
||||
|
||||
func convertToolCallToAPI(functionCalls []FunctionCall) []model.ToolCall {
|
||||
converted := make([]model.ToolCall, len(functionCalls))
|
||||
func convertToolCallToAPI(functionCalls []FunctionCall) []api.ToolCall {
|
||||
converted := make([]api.ToolCall, len(functionCalls))
|
||||
for i, call := range functionCalls {
|
||||
params := make(map[string]interface{})
|
||||
for k, v := range call.Args {
|
||||
@@ -78,7 +76,7 @@ func convertToolCallToAPI(functionCalls []FunctionCall) []model.ToolCall {
|
||||
return converted
|
||||
}
|
||||
|
||||
func convertToolResultsToGemini(toolResults []model.ToolResult) ([]FunctionResponse, error) {
|
||||
func convertToolResultsToGemini(toolResults []api.ToolResult) ([]FunctionResponse, error) {
|
||||
results := make([]FunctionResponse, len(toolResults))
|
||||
for i, result := range toolResults {
|
||||
var obj interface{}
|
||||
@@ -95,14 +93,14 @@ func convertToolResultsToGemini(toolResults []model.ToolResult) ([]FunctionRespo
|
||||
}
|
||||
|
||||
func createGenerateContentRequest(
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
params api.RequestParameters,
|
||||
messages []api.Message,
|
||||
) (*GenerateContentRequest, error) {
|
||||
requestContents := make([]Content, 0, len(messages))
|
||||
|
||||
startIdx := 0
|
||||
var system string
|
||||
if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem {
|
||||
if len(messages) > 0 && messages[0].Role == api.MessageRoleSystem {
|
||||
system = messages[0].Content
|
||||
startIdx = 1
|
||||
}
|
||||
@@ -135,9 +133,9 @@ func createGenerateContentRequest(
|
||||
default:
|
||||
var role string
|
||||
switch m.Role {
|
||||
case model.MessageRoleAssistant:
|
||||
case api.MessageRoleAssistant:
|
||||
role = "model"
|
||||
case model.MessageRoleUser:
|
||||
case api.MessageRoleUser:
|
||||
role = "user"
|
||||
}
|
||||
|
||||
@@ -183,55 +181,14 @@ func createGenerateContentRequest(
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func handleToolCalls(
|
||||
params model.RequestParameters,
|
||||
content string,
|
||||
toolCalls []model.ToolCall,
|
||||
callback api.ReplyCallback,
|
||||
messages []model.Message,
|
||||
) ([]model.Message, error) {
|
||||
lastMessage := messages[len(messages)-1]
|
||||
continuation := false
|
||||
if lastMessage.Role.IsAssistant() {
|
||||
continuation = true
|
||||
}
|
||||
|
||||
toolCall := model.Message{
|
||||
Role: model.MessageRoleToolCall,
|
||||
Content: content,
|
||||
ToolCalls: toolCalls,
|
||||
}
|
||||
|
||||
toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toolResult := model.Message{
|
||||
Role: model.MessageRoleToolResult,
|
||||
ToolResults: toolResults,
|
||||
}
|
||||
|
||||
if callback != nil {
|
||||
callback(toolCall)
|
||||
callback(toolResult)
|
||||
}
|
||||
|
||||
if continuation {
|
||||
messages[len(messages)-1] = toolCall
|
||||
} else {
|
||||
messages = append(messages, toolCall)
|
||||
}
|
||||
messages = append(messages, toolResult)
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func (c *Client) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
|
||||
func (c *Client) sendRequest(req *http.Request) (*http.Response, error) {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req.WithContext(ctx))
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
bytes, _ := io.ReadAll(resp.Body)
|
||||
@@ -243,42 +200,41 @@ func (c *Client) sendRequest(ctx context.Context, req *http.Request) (*http.Resp
|
||||
|
||||
func (c *Client) CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
callback api.ReplyCallback,
|
||||
) (string, error) {
|
||||
params api.RequestParameters,
|
||||
messages []api.Message,
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return "", fmt.Errorf("Can't create completion from no messages")
|
||||
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||
}
|
||||
|
||||
req, err := createGenerateContentRequest(params, messages)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
jsonData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url := fmt.Sprintf(
|
||||
"%s/v1beta/models/%s:generateContent?key=%s",
|
||||
c.BaseURL, params.Model, c.APIKey,
|
||||
)
|
||||
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(ctx, httpReq)
|
||||
resp, err := c.sendRequest(httpReq)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var completionResp GenerateContentResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
choice := completionResp.Candidates[0]
|
||||
@@ -301,58 +257,50 @@ func (c *Client) CreateChatCompletion(
|
||||
}
|
||||
|
||||
if len(toolCalls) > 0 {
|
||||
messages, err := handleToolCalls(
|
||||
params, content, convertToolCallToAPI(toolCalls), callback, messages,
|
||||
)
|
||||
if err != nil {
|
||||
return content, err
|
||||
}
|
||||
|
||||
return c.CreateChatCompletion(ctx, params, messages, callback)
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleToolCall,
|
||||
Content: content,
|
||||
ToolCalls: convertToolCallToAPI(toolCalls),
|
||||
}, nil
|
||||
}
|
||||
|
||||
if callback != nil {
|
||||
callback(model.Message{
|
||||
Role: model.MessageRoleAssistant,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
|
||||
return content, nil
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleAssistant,
|
||||
Content: content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Client) CreateChatCompletionStream(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
callback api.ReplyCallback,
|
||||
params api.RequestParameters,
|
||||
messages []api.Message,
|
||||
output chan<- api.Chunk,
|
||||
) (string, error) {
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return "", fmt.Errorf("Can't create completion from no messages")
|
||||
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||
}
|
||||
|
||||
req, err := createGenerateContentRequest(params, messages)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
jsonData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url := fmt.Sprintf(
|
||||
"%s/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse",
|
||||
c.BaseURL, params.Model, c.APIKey,
|
||||
)
|
||||
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(ctx, httpReq)
|
||||
resp, err := c.sendRequest(httpReq)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
@@ -374,7 +322,7 @@ func (c *Client) CreateChatCompletionStream(
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
line = bytes.TrimSpace(line)
|
||||
@@ -387,7 +335,7 @@ func (c *Client) CreateChatCompletionStream(
|
||||
var resp GenerateContentResponse
|
||||
err = json.Unmarshal(line, &resp)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tokens := resp.UsageMetadata.CandidatesTokenCount - lastTokenCount
|
||||
@@ -409,21 +357,15 @@ func (c *Client) CreateChatCompletionStream(
|
||||
|
||||
// If there are function calls, handle them and recurse
|
||||
if len(toolCalls) > 0 {
|
||||
messages, err := handleToolCalls(
|
||||
params, content.String(), convertToolCallToAPI(toolCalls), callback, messages,
|
||||
)
|
||||
if err != nil {
|
||||
return content.String(), err
|
||||
}
|
||||
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleToolCall,
|
||||
Content: content.String(),
|
||||
ToolCalls: convertToolCallToAPI(toolCalls),
|
||||
}, nil
|
||||
}
|
||||
|
||||
if callback != nil {
|
||||
callback(model.Message{
|
||||
Role: model.MessageRoleAssistant,
|
||||
Content: content.String(),
|
||||
})
|
||||
}
|
||||
|
||||
return content.String(), nil
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleAssistant,
|
||||
Content: content.String(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
)
|
||||
|
||||
type OllamaClient struct {
|
||||
@@ -43,8 +42,8 @@ type OllamaResponse struct {
|
||||
}
|
||||
|
||||
func createOllamaRequest(
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
params api.RequestParameters,
|
||||
messages []api.Message,
|
||||
) OllamaRequest {
|
||||
requestMessages := make([]OllamaMessage, 0, len(messages))
|
||||
|
||||
@@ -64,11 +63,11 @@ func createOllamaRequest(
|
||||
return request
|
||||
}
|
||||
|
||||
func (c *OllamaClient) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
|
||||
func (c *OllamaClient) sendRequest(req *http.Request) (*http.Response, error) {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req.WithContext(ctx))
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -83,12 +82,11 @@ func (c *OllamaClient) sendRequest(ctx context.Context, req *http.Request) (*htt
|
||||
|
||||
func (c *OllamaClient) CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
callback api.ReplyCallback,
|
||||
) (string, error) {
|
||||
params api.RequestParameters,
|
||||
messages []api.Message,
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return "", fmt.Errorf("Can't create completion from no messages")
|
||||
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||
}
|
||||
|
||||
req := createOllamaRequest(params, messages)
|
||||
@@ -96,46 +94,40 @@ func (c *OllamaClient) CreateChatCompletion(
|
||||
|
||||
jsonData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(ctx, httpReq)
|
||||
resp, err := c.sendRequest(httpReq)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var completionResp OllamaResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
content := completionResp.Message.Content
|
||||
if callback != nil {
|
||||
callback(model.Message{
|
||||
Role: model.MessageRoleAssistant,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
|
||||
return content, nil
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleAssistant,
|
||||
Content: completionResp.Message.Content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *OllamaClient) CreateChatCompletionStream(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
callback api.ReplyCallback,
|
||||
params api.RequestParameters,
|
||||
messages []api.Message,
|
||||
output chan<- api.Chunk,
|
||||
) (string, error) {
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return "", fmt.Errorf("Can't create completion from no messages")
|
||||
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||
}
|
||||
|
||||
req := createOllamaRequest(params, messages)
|
||||
@@ -143,17 +135,17 @@ func (c *OllamaClient) CreateChatCompletionStream(
|
||||
|
||||
jsonData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(ctx, httpReq)
|
||||
resp, err := c.sendRequest(httpReq)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
@@ -166,7 +158,7 @@ func (c *OllamaClient) CreateChatCompletionStream(
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
line = bytes.TrimSpace(line)
|
||||
@@ -177,7 +169,7 @@ func (c *OllamaClient) CreateChatCompletionStream(
|
||||
var streamResp OllamaResponse
|
||||
err = json.Unmarshal(line, &streamResp)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(streamResp.Message.Content) > 0 {
|
||||
@@ -189,12 +181,8 @@ func (c *OllamaClient) CreateChatCompletionStream(
|
||||
}
|
||||
}
|
||||
|
||||
if callback != nil {
|
||||
callback(model.Message{
|
||||
Role: model.MessageRoleAssistant,
|
||||
Content: content.String(),
|
||||
})
|
||||
}
|
||||
|
||||
return content.String(), nil
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleAssistant,
|
||||
Content: content.String(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -11,11 +11,9 @@ import (
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||
)
|
||||
|
||||
func convertTools(tools []model.Tool) []Tool {
|
||||
func convertTools(tools []api.ToolSpec) []Tool {
|
||||
openaiTools := make([]Tool, len(tools))
|
||||
for i, tool := range tools {
|
||||
openaiTools[i].Type = "function"
|
||||
@@ -47,7 +45,7 @@ func convertTools(tools []model.Tool) []Tool {
|
||||
return openaiTools
|
||||
}
|
||||
|
||||
func convertToolCallToOpenAI(toolCalls []model.ToolCall) []ToolCall {
|
||||
func convertToolCallToOpenAI(toolCalls []api.ToolCall) []ToolCall {
|
||||
converted := make([]ToolCall, len(toolCalls))
|
||||
for i, call := range toolCalls {
|
||||
converted[i].Type = "function"
|
||||
@@ -60,8 +58,8 @@ func convertToolCallToOpenAI(toolCalls []model.ToolCall) []ToolCall {
|
||||
return converted
|
||||
}
|
||||
|
||||
func convertToolCallToAPI(toolCalls []ToolCall) []model.ToolCall {
|
||||
converted := make([]model.ToolCall, len(toolCalls))
|
||||
func convertToolCallToAPI(toolCalls []ToolCall) []api.ToolCall {
|
||||
converted := make([]api.ToolCall, len(toolCalls))
|
||||
for i, call := range toolCalls {
|
||||
converted[i].ID = call.ID
|
||||
converted[i].Name = call.Function.Name
|
||||
@@ -71,8 +69,8 @@ func convertToolCallToAPI(toolCalls []ToolCall) []model.ToolCall {
|
||||
}
|
||||
|
||||
func createChatCompletionRequest(
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
params api.RequestParameters,
|
||||
messages []api.Message,
|
||||
) ChatCompletionRequest {
|
||||
requestMessages := make([]ChatCompletionMessage, 0, len(messages))
|
||||
|
||||
@@ -117,56 +115,15 @@ func createChatCompletionRequest(
|
||||
return request
|
||||
}
|
||||
|
||||
func handleToolCalls(
|
||||
params model.RequestParameters,
|
||||
content string,
|
||||
toolCalls []ToolCall,
|
||||
callback api.ReplyCallback,
|
||||
messages []model.Message,
|
||||
) ([]model.Message, error) {
|
||||
lastMessage := messages[len(messages)-1]
|
||||
continuation := false
|
||||
if lastMessage.Role.IsAssistant() {
|
||||
continuation = true
|
||||
}
|
||||
|
||||
toolCall := model.Message{
|
||||
Role: model.MessageRoleToolCall,
|
||||
Content: content,
|
||||
ToolCalls: convertToolCallToAPI(toolCalls),
|
||||
}
|
||||
|
||||
toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toolResult := model.Message{
|
||||
Role: model.MessageRoleToolResult,
|
||||
ToolResults: toolResults,
|
||||
}
|
||||
|
||||
if callback != nil {
|
||||
callback(toolCall)
|
||||
callback(toolResult)
|
||||
}
|
||||
|
||||
if continuation {
|
||||
messages[len(messages)-1] = toolCall
|
||||
} else {
|
||||
messages = append(messages, toolCall)
|
||||
}
|
||||
messages = append(messages, toolResult)
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
|
||||
func (c *OpenAIClient) sendRequest(req *http.Request) (*http.Response, error) {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req.WithContext(ctx))
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
bytes, _ := io.ReadAll(resp.Body)
|
||||
@@ -178,35 +135,34 @@ func (c *OpenAIClient) sendRequest(ctx context.Context, req *http.Request) (*htt
|
||||
|
||||
func (c *OpenAIClient) CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
callback api.ReplyCallback,
|
||||
) (string, error) {
|
||||
params api.RequestParameters,
|
||||
messages []api.Message,
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return "", fmt.Errorf("Can't create completion from no messages")
|
||||
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||
}
|
||||
|
||||
req := createChatCompletionRequest(params, messages)
|
||||
jsonData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(ctx, httpReq)
|
||||
resp, err := c.sendRequest(httpReq)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var completionResp ChatCompletionResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
choice := completionResp.Choices[0]
|
||||
@@ -221,34 +177,27 @@ func (c *OpenAIClient) CreateChatCompletion(
|
||||
|
||||
toolCalls := choice.Message.ToolCalls
|
||||
if len(toolCalls) > 0 {
|
||||
messages, err := handleToolCalls(params, content, toolCalls, callback, messages)
|
||||
if err != nil {
|
||||
return content, err
|
||||
}
|
||||
|
||||
return c.CreateChatCompletion(ctx, params, messages, callback)
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleToolCall,
|
||||
Content: content,
|
||||
ToolCalls: convertToolCallToAPI(toolCalls),
|
||||
}, nil
|
||||
}
|
||||
|
||||
if callback != nil {
|
||||
callback(model.Message{
|
||||
Role: model.MessageRoleAssistant,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
|
||||
// Return the user-facing message.
|
||||
return content, nil
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleAssistant,
|
||||
Content: content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) CreateChatCompletionStream(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
callback api.ReplyCallback,
|
||||
params api.RequestParameters,
|
||||
messages []api.Message,
|
||||
output chan<- api.Chunk,
|
||||
) (string, error) {
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return "", fmt.Errorf("Can't create completion from no messages")
|
||||
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||
}
|
||||
|
||||
req := createChatCompletionRequest(params, messages)
|
||||
@@ -256,17 +205,17 @@ func (c *OpenAIClient) CreateChatCompletionStream(
|
||||
|
||||
jsonData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(ctx, httpReq)
|
||||
resp, err := c.sendRequest(httpReq)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
@@ -285,7 +234,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
line = bytes.TrimSpace(line)
|
||||
@@ -301,7 +250,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
|
||||
var streamResp ChatCompletionStreamResponse
|
||||
err = json.Unmarshal(line, &streamResp)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
delta := streamResp.Choices[0].Delta
|
||||
@@ -309,7 +258,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
|
||||
// Construct streamed tool_call arguments
|
||||
for _, tc := range delta.ToolCalls {
|
||||
if tc.Index == nil {
|
||||
return "", fmt.Errorf("Unexpected nil index for streamed tool call.")
|
||||
return nil, fmt.Errorf("Unexpected nil index for streamed tool call.")
|
||||
}
|
||||
if len(toolCalls) <= *tc.Index {
|
||||
toolCalls = append(toolCalls, tc)
|
||||
@@ -328,21 +277,15 @@ func (c *OpenAIClient) CreateChatCompletionStream(
|
||||
}
|
||||
|
||||
if len(toolCalls) > 0 {
|
||||
messages, err := handleToolCalls(params, content.String(), toolCalls, callback, messages)
|
||||
if err != nil {
|
||||
return content.String(), err
|
||||
}
|
||||
|
||||
// Recurse into CreateChatCompletionStream with the tool call replies
|
||||
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
|
||||
} else {
|
||||
if callback != nil {
|
||||
callback(model.Message{
|
||||
Role: model.MessageRoleAssistant,
|
||||
Content: content.String(),
|
||||
})
|
||||
}
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleToolCall,
|
||||
Content: content.String(),
|
||||
ToolCalls: convertToolCallToAPI(toolCalls),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return content.String(), nil
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleAssistant,
|
||||
Content: content.String(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user