package google import ( "bufio" "bytes" "context" "encoding/json" "fmt" "io" "net/http" "strings" "git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools" ) func convertTools(tools []model.Tool) []Tool { geminiTools := make([]Tool, len(tools)) for i, tool := range tools { params := make(map[string]ToolParameter) var required []string for _, param := range tool.Parameters { // TODO: proper enum handing params[param.Name] = ToolParameter{ Type: param.Type, Description: param.Description, Values: param.Enum, } if param.Required { required = append(required, param.Name) } } geminiTools[i] = Tool{ FunctionDeclarations: []FunctionDeclaration{ { Name: tool.Name, Description: tool.Description, Parameters: ToolParameters{ Type: "OBJECT", Properties: params, Required: required, }, }, }, } } return geminiTools } func convertToolCallToGemini(toolCalls []model.ToolCall) []ContentPart { converted := make([]ContentPart, len(toolCalls)) for i, call := range toolCalls { args := make(map[string]string) for k, v := range call.Parameters { args[k] = fmt.Sprintf("%v", v) } converted[i].FunctionCall = &FunctionCall{ Name: call.Name, Args: args, } } return converted } func convertToolCallToAPI(functionCalls []FunctionCall) []model.ToolCall { converted := make([]model.ToolCall, len(functionCalls)) for i, call := range functionCalls { params := make(map[string]interface{}) for k, v := range call.Args { params[k] = v } converted[i].Name = call.Name converted[i].Parameters = params } return converted } func convertToolResultsToGemini(toolResults []model.ToolResult) ([]FunctionResponse, error) { results := make([]FunctionResponse, len(toolResults)) for i, result := range toolResults { var obj interface{} err := json.Unmarshal([]byte(result.Result), &obj) if err != nil { return nil, fmt.Errorf("Could not unmarshal %s: %v", result.Result, err) } results[i] = FunctionResponse{ Name: result.ToolName, Response: obj, } } return results, nil } func createGenerateContentRequest( params model.RequestParameters, messages []model.Message, ) (*GenerateContentRequest, error) { requestContents := make([]Content, 0, len(messages)) startIdx := 0 var system string if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem { system = messages[0].Content startIdx = 1 } for _, m := range messages[startIdx:] { switch m.Role { case "tool_call": content := Content{ Role: "model", Parts: convertToolCallToGemini(m.ToolCalls), } requestContents = append(requestContents, content) case "tool_result": results, err := convertToolResultsToGemini(m.ToolResults) if err != nil { return nil, err } // expand tool_result messages' results into multiple gemini messages for _, result := range results { content := Content{ Role: "function", Parts: []ContentPart{ { FunctionResp: &result, }, }, } requestContents = append(requestContents, content) } default: var role string switch m.Role { case model.MessageRoleAssistant: role = "model" case model.MessageRoleUser: role = "user" } if role == "" { panic("Unhandled role: " + m.Role) } content := Content{ Role: role, Parts: []ContentPart{ { Text: m.Content, }, }, } requestContents = append(requestContents, content) } } request := &GenerateContentRequest{ Contents: requestContents, GenerationConfig: &GenerationConfig{ MaxOutputTokens: ¶ms.MaxTokens, Temperature: ¶ms.Temperature, TopP: ¶ms.TopP, }, } if system != "" { request.SystemInstruction = &Content{ Parts: []ContentPart{ { Text: system, }, }, } } if len(params.ToolBag) > 0 { request.Tools = convertTools(params.ToolBag) } return request, nil } func handleToolCalls( params model.RequestParameters, content string, toolCalls []model.ToolCall, callback api.ReplyCallback, messages []model.Message, ) ([]model.Message, error) { lastMessage := messages[len(messages)-1] continuation := false if lastMessage.Role.IsAssistant() { continuation = true } toolCall := model.Message{ Role: model.MessageRoleToolCall, Content: content, ToolCalls: toolCalls, } toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag) if err != nil { return nil, err } toolResult := model.Message{ Role: model.MessageRoleToolResult, ToolResults: toolResults, } if callback != nil { callback(toolCall) callback(toolResult) } if continuation { messages[len(messages)-1] = toolCall } else { messages = append(messages, toolCall) } messages = append(messages, toolResult) return messages, nil } func (c *Client) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) { req.Header.Set("Content-Type", "application/json") client := &http.Client{} resp, err := client.Do(req.WithContext(ctx)) if resp.StatusCode != 200 { bytes, _ := io.ReadAll(resp.Body) return resp, fmt.Errorf("%v", string(bytes)) } return resp, err } func (c *Client) CreateChatCompletion( ctx context.Context, params model.RequestParameters, messages []model.Message, callback api.ReplyCallback, ) (string, error) { if len(messages) == 0 { return "", fmt.Errorf("Can't create completion from no messages") } req, err := createGenerateContentRequest(params, messages) if err != nil { return "", err } jsonData, err := json.Marshal(req) if err != nil { return "", err } url := fmt.Sprintf( "%s/v1beta/models/%s:generateContent?key=%s", c.BaseURL, params.Model, c.APIKey, ) httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) if err != nil { return "", err } resp, err := c.sendRequest(ctx, httpReq) if err != nil { return "", err } defer resp.Body.Close() var completionResp GenerateContentResponse err = json.NewDecoder(resp.Body).Decode(&completionResp) if err != nil { return "", err } choice := completionResp.Candidates[0] var content string lastMessage := messages[len(messages)-1] if lastMessage.Role.IsAssistant() { content = lastMessage.Content } var toolCalls []FunctionCall for _, part := range choice.Content.Parts { if part.Text != "" { content += part.Text } if part.FunctionCall != nil { toolCalls = append(toolCalls, *part.FunctionCall) } } if len(toolCalls) > 0 { messages, err := handleToolCalls( params, content, convertToolCallToAPI(toolCalls), callback, messages, ) if err != nil { return content, err } return c.CreateChatCompletion(ctx, params, messages, callback) } if callback != nil { callback(model.Message{ Role: model.MessageRoleAssistant, Content: content, }) } return content, nil } func (c *Client) CreateChatCompletionStream( ctx context.Context, params model.RequestParameters, messages []model.Message, callback api.ReplyCallback, output chan<- api.Chunk, ) (string, error) { if len(messages) == 0 { return "", fmt.Errorf("Can't create completion from no messages") } req, err := createGenerateContentRequest(params, messages) if err != nil { return "", err } jsonData, err := json.Marshal(req) if err != nil { return "", err } url := fmt.Sprintf( "%s/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse", c.BaseURL, params.Model, c.APIKey, ) httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) if err != nil { return "", err } resp, err := c.sendRequest(ctx, httpReq) if err != nil { return "", err } defer resp.Body.Close() content := strings.Builder{} lastMessage := messages[len(messages)-1] if lastMessage.Role.IsAssistant() { content.WriteString(lastMessage.Content) } var toolCalls []FunctionCall reader := bufio.NewReader(resp.Body) for { line, err := reader.ReadBytes('\n') if err != nil { if err == io.EOF { break } return "", err } line = bytes.TrimSpace(line) if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) { continue } line = bytes.TrimPrefix(line, []byte("data: ")) var streamResp GenerateContentResponse err = json.Unmarshal(line, &streamResp) if err != nil { return "", err } for _, candidate := range streamResp.Candidates { for _, part := range candidate.Content.Parts { if part.FunctionCall != nil { toolCalls = append(toolCalls, *part.FunctionCall) } else if part.Text != "" { output <- api.Chunk { Content: part.Text, } content.WriteString(part.Text) } } } } // If there are function calls, handle them and recurse if len(toolCalls) > 0 { messages, err := handleToolCalls( params, content.String(), convertToolCallToAPI(toolCalls), callback, messages, ) if err != nil { return content.String(), err } return c.CreateChatCompletionStream(ctx, params, messages, callback, output) } if callback != nil { callback(model.Message{ Role: model.MessageRoleAssistant, Content: content.String(), }) } return content.String(), nil }