Refactor pkg/lmcli/provider
Moved `ChatCompletionClient` to `pkg/api`, moved individual providers to `pkg/api/provider`
This commit is contained in:
@@ -6,12 +6,12 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider/anthropic"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider/google"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider/ollama"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider/openai"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/google"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/ollama"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/util/tty"
|
||||
@@ -79,7 +79,7 @@ func (c *Context) GetModels() (models []string) {
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Context) GetModelProvider(model string) (string, provider.ChatCompletionClient, error) {
|
||||
func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionClient, error) {
|
||||
parts := strings.Split(model, "/")
|
||||
|
||||
var provider string
|
||||
|
||||
@@ -1,334 +0,0 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||
)
|
||||
|
||||
// buildRequest converts lmcli request parameters and messages into an
// Anthropic Messages API request body. Tool use is emulated through XML
// <function_calls> blocks embedded in message content, with
// FUNCTION_STOP_SEQUENCE registered as a stop sequence so generation
// halts once the model closes a call block.
func buildRequest(params model.RequestParameters, messages []model.Message) Request {
	requestBody := Request{
		Model:       params.Model,
		Messages:    make([]Message, len(messages)),
		MaxTokens:   params.MaxTokens,
		Temperature: params.Temperature,
		Stream:      false,

		StopSequences: []string{
			FUNCTION_STOP_SEQUENCE,
			"\n\nHuman:",
		},
	}

	startIdx := 0
	if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem {
		// Anthropic takes the system prompt as a dedicated request field
		// rather than a message, so lift it out and drop one slot from the
		// pre-allocated message slice; startIdx keeps the loop below in
		// sync with the shortened slice.
		requestBody.System = messages[0].Content
		requestBody.Messages = requestBody.Messages[1:]
		startIdx = 1
	}

	if len(params.ToolBag) > 0 {
		if len(requestBody.System) > 0 {
			// add a divider between existing system prompt and tools
			requestBody.System += "\n\n---\n\n"
		}
		requestBody.System += buildToolsSystemPrompt(params.ToolBag)
	}

	for i, msg := range messages[startIdx:] {
		message := &requestBody.Messages[i]

		switch msg.Role {
		case model.MessageRoleToolCall:
			// Tool calls become an assistant message whose content ends
			// with the serialized <function_calls> block.
			message.Role = "assistant"
			if msg.Content != "" {
				message.Content = msg.Content
			}
			xmlFuncCalls := convertToolCallsToXMLFunctionCalls(msg.ToolCalls)
			xmlString, err := xmlFuncCalls.XMLString()
			if err != nil {
				panic("Could not serialize []ToolCall to XMLFunctionCall")
			}
			if len(message.Content) > 0 {
				message.Content += fmt.Sprintf("\n\n%s", xmlString)
			} else {
				message.Content = xmlString
			}
		case model.MessageRoleToolResult:
			// Tool results are echoed back as a user message containing a
			// <function_results> block.
			xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults)
			xmlString, err := xmlFuncResults.XMLString()
			if err != nil {
				panic("Could not serialize []ToolResult to XMLFunctionResults")
			}
			message.Role = "user"
			message.Content = xmlString
		default:
			message.Role = string(msg.Role)
			message.Content = msg.Content
		}
	}
	return requestBody
}
|
||||
|
||||
func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Response, error) {
|
||||
jsonBody, err := json.Marshal(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %v", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/messages", bytes.NewBuffer(jsonBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP request: %v", err)
|
||||
}
|
||||
|
||||
req.Header.Set("x-api-key", c.APIKey)
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
req.Header.Set("content-type", "application/json")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send HTTP request: %v", err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// CreateChatCompletion performs a single (non-streaming) completion
// request and returns the accumulated assistant reply text. If the last
// message is an assistant message, its content is treated as a partial
// reply and prepended to the result. callback, when non-nil, is invoked
// once per reply message.
func (c *AnthropicClient) CreateChatCompletion(
	ctx context.Context,
	params model.RequestParameters,
	messages []model.Message,
	callback provider.ReplyCallback,
) (string, error) {
	if len(messages) == 0 {
		return "", fmt.Errorf("Can't create completion from no messages")
	}

	request := buildRequest(params, messages)

	resp, err := sendRequest(ctx, c, request)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	var response Response
	err = json.NewDecoder(resp.Body).Decode(&response)
	if err != nil {
		return "", fmt.Errorf("failed to decode response: %v", err)
	}

	sb := strings.Builder{}

	lastMessage := messages[len(messages)-1]
	if lastMessage.Role.IsAssistant() {
		// this is a continuation of a previous assistant reply, so we'll
		// include its contents in the final result
		sb.WriteString(lastMessage.Content)
	}

	// Only "text" content blocks are supported here; any other block type
	// aborts the completion.
	for _, content := range response.Content {
		var reply model.Message
		switch content.Type {
		case "text":
			reply = model.Message{
				Role:    model.MessageRoleAssistant,
				Content: content.Text,
			}
			sb.WriteString(reply.Content)
		default:
			return "", fmt.Errorf("unsupported message type: %s", content.Type)
		}
		if callback != nil {
			callback(reply)
		}
	}

	return sb.String(), nil
}
|
||||
|
||||
// CreateChatCompletionStream performs a streaming completion request,
// emitting text deltas on output as they arrive and returning the full
// accumulated reply. When the model halts on FUNCTION_STOP_SEQUENCE the
// embedded <function_calls> XML is parsed, the requested tools are
// executed, and the method recurses with the tool call/result messages
// appended so the model can continue from the results.
func (c *AnthropicClient) CreateChatCompletionStream(
	ctx context.Context,
	params model.RequestParameters,
	messages []model.Message,
	callback provider.ReplyCallback,
	output chan<- provider.Chunk,
) (string, error) {
	if len(messages) == 0 {
		return "", fmt.Errorf("Can't create completion from no messages")
	}

	request := buildRequest(params, messages)
	request.Stream = true

	resp, err := sendRequest(ctx, c, request)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	sb := strings.Builder{}

	lastMessage := messages[len(messages)-1]
	continuation := false
	if messages[len(messages)-1].Role.IsAssistant() {
		// this is a continuation of a previous assistant reply, so we'll
		// include its contents in the final result
		sb.WriteString(lastMessage.Content)
		continuation = true
	}

	// The response body is a server-sent-event stream: "data: {json}"
	// lines, except that a bare JSON object (starting with '{') signals
	// a top-level error payload.
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		line := scanner.Text()
		line = strings.TrimSpace(line)

		if len(line) == 0 {
			continue
		}

		if line[0] == '{' {
			// Bare JSON object outside the SSE framing: treat as an error
			// event and abort.
			var event map[string]interface{}
			err := json.Unmarshal([]byte(line), &event)
			if err != nil {
				return "", fmt.Errorf("failed to unmarshal event data '%s': %v", line, err)
			}
			eventType, ok := event["type"].(string)
			if !ok {
				return "", fmt.Errorf("invalid event: %s", line)
			}
			switch eventType {
			case "error":
				return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
			default:
				return sb.String(), fmt.Errorf("unknown event type: %s", eventType)
			}
		} else if strings.HasPrefix(line, "data:") {
			data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
			var event map[string]interface{}
			err := json.Unmarshal([]byte(data), &event)
			if err != nil {
				return "", fmt.Errorf("failed to unmarshal event data: %v", err)
			}

			eventType, ok := event["type"].(string)
			if !ok {
				return "", fmt.Errorf("invalid event type")
			}

			switch eventType {
			case "message_start":
				// noop
			case "ping":
				// signals start of text - currently ignoring
			case "content_block_start":
				// ignore?
			case "content_block_delta":
				// Incremental text: accumulate and forward to the caller.
				delta, ok := event["delta"].(map[string]interface{})
				if !ok {
					return "", fmt.Errorf("invalid content block delta")
				}
				text, ok := delta["text"].(string)
				if !ok {
					return "", fmt.Errorf("invalid text delta")
				}
				sb.WriteString(text)
				output <- provider.Chunk{
					Content: text,
				}
			case "content_block_stop":
				// ignore?
			case "message_delta":
				// A stop on FUNCTION_STOP_SEQUENCE means the model emitted a
				// complete <function_calls> block (the closing tag itself is
				// consumed by the stop sequence, so it is re-appended below).
				delta, ok := event["delta"].(map[string]interface{})
				if !ok {
					return "", fmt.Errorf("invalid message delta")
				}
				stopReason, ok := delta["stop_reason"].(string)
				if ok && stopReason == "stop_sequence" {
					stopSequence, ok := delta["stop_sequence"].(string)
					if ok && stopSequence == FUNCTION_STOP_SEQUENCE {
						content := sb.String()

						start := strings.Index(content, "<function_calls>")
						if start == -1 {
							return content, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
						}

						sb.WriteString(FUNCTION_STOP_SEQUENCE)
						output <- provider.Chunk{
							Content: FUNCTION_STOP_SEQUENCE,
						}

						funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE

						var functionCalls XMLFunctionCalls
						err := xml.Unmarshal([]byte(funcCallXml), &functionCalls)
						if err != nil {
							return "", fmt.Errorf("failed to unmarshal function_calls: %v", err)
						}

						toolCall := model.Message{
							Role: model.MessageRoleToolCall,
							// function call xml stripped from content for model interop
							Content:   strings.TrimSpace(content[:start]),
							ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
						}

						toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
						if err != nil {
							return "", err
						}

						toolResult := model.Message{
							Role:        model.MessageRoleToolResult,
							ToolResults: toolResults,
						}

						if callback != nil {
							callback(toolCall)
							callback(toolResult)
						}

						// On a continuation the partial assistant message is
						// replaced (its content is already folded into toolCall);
						// otherwise the tool call is appended. Then recurse with
						// the tool result so the model can continue.
						if continuation {
							messages[len(messages)-1] = toolCall
						} else {
							messages = append(messages, toolCall)
						}

						messages = append(messages, toolResult)
						return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
					}
				}
			case "message_stop":
				// return the completed message
				content := sb.String()
				if callback != nil {
					callback(model.Message{
						Role:    model.MessageRoleAssistant,
						Content: content,
					})
				}
				return content, nil
			case "error":
				return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
			default:
				fmt.Printf("\nUnrecognized event: %s\n", data)
			}
		}
	}

	if err := scanner.Err(); err != nil {
		return "", fmt.Errorf("failed to read response body: %v", err)
	}

	// The stream ended without a message_stop event.
	return "", fmt.Errorf("unexpected end of stream")
}
|
||||
@@ -1,232 +0,0 @@
|
||||
package anthropic
|
||||
|
||||
import (
	"bytes"
	"fmt"
	"sort"
	"strings"
	"text/template"

	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
)
|
||||
|
||||
// FUNCTION_STOP_SEQUENCE is registered as a stop sequence on every
// request: generation halts as soon as the model closes a
// <function_calls> block, so the client can execute the requested tools.
const FUNCTION_STOP_SEQUENCE = "</function_calls>"

// TOOL_PREAMBLE introduces the XML function-calling convention to the
// model; the serialized <tools> list is appended after it in the system
// prompt (see buildToolsSystemPrompt).
const TOOL_PREAMBLE = `You have access to the following tools when replying.

You may call them like this:

<function_calls>
<invoke>
<tool_name>$TOOL_NAME</tool_name>
<parameters>
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
...
</parameters>
</invoke>
</function_calls>

Here are the tools available:`

// TOOL_PREAMBLE_FOOTER follows the tool list and discourages gratuitous
// tool use.
const TOOL_PREAMBLE_FOOTER = `Recognize the utility of these tools in a broad range of different applications, and the power they give you to solve a wide range of different problems. However, ensure that the tools are used judiciously and only when clearly relevant to the user's request. Specifically:

1. Only use a tool if the user has explicitly requested or provided information that warrants its use. Do not make assumptions about files or data existing without the user mentioning them.

2. If there is ambiguity about whether using a tool is appropriate, ask a clarifying question to the user before proceeding. Confirm your understanding of their request and intent.

3. Prioritize providing direct responses and explanations based on your own knowledge and understanding. Use tools to supplement and enhance your responses when clearly applicable, but not as a default action.`
|
||||
|
||||
// XMLTools is the <tools> block serialized into the system prompt to
// describe the available tools to the model.
type XMLTools struct {
	XMLName          struct{}             `xml:"tools"`
	ToolDescriptions []XMLToolDescription `xml:"tool_description"`
}

// XMLToolDescription describes a single tool and its parameters.
type XMLToolDescription struct {
	ToolName    string             `xml:"tool_name"`
	Description string             `xml:"description"`
	Parameters  []XMLToolParameter `xml:"parameters>parameter"`
}

// XMLToolParameter describes one parameter of a tool.
type XMLToolParameter struct {
	Name        string `xml:"name"`
	Type        string `xml:"type"`
	Description string `xml:"description"`
}

// XMLFunctionCalls mirrors the <function_calls> block the model emits
// when invoking tools.
type XMLFunctionCalls struct {
	XMLName struct{}            `xml:"function_calls"`
	Invoke  []XMLFunctionInvoke `xml:"invoke"`
}

// XMLFunctionInvoke is a single tool invocation within <function_calls>.
type XMLFunctionInvoke struct {
	ToolName   string                      `xml:"tool_name"`
	Parameters XMLFunctionInvokeParameters `xml:"parameters"`
}

// XMLFunctionInvokeParameters captures the raw inner XML of a
// <parameters> element; it is parsed by parseFunctionParametersXML.
type XMLFunctionInvokeParameters struct {
	String string `xml:",innerxml"`
}

// XMLFunctionResults is the <function_results> block sent back to the
// model after tool execution.
type XMLFunctionResults struct {
	XMLName struct{}            `xml:"function_results"`
	Result  []XMLFunctionResult `xml:"result"`
}

// XMLFunctionResult is the output of a single tool invocation.
type XMLFunctionResult struct {
	ToolName string `xml:"tool_name"`
	Stdout   string `xml:"stdout"`
}
|
||||
|
||||
// accepts raw XML from XMLFunctionInvokeParameters.String, returns map of
|
||||
// parameters name to value
|
||||
// parseFunctionParametersXML extracts parameter names and values from
// the raw inner XML of an <invoke><parameters> element (see
// XMLFunctionInvokeParameters.String). Each line is expected to hold one
// <name>value</name> pair; malformed lines are skipped.
func parseFunctionParametersXML(params string) map[string]interface{} {
	parsed := make(map[string]interface{})
	for _, ln := range strings.Split(params, "\n") {
		// Locate the end of the opening tag and the start of the closing
		// tag; a line missing either is not a parameter pair.
		open := strings.Index(ln, ">")
		closing := strings.Index(ln, "</")
		if open == -1 || closing == -1 {
			continue
		}
		// Name sits between '<' and '>', value between '>' and '</'.
		parsed[ln[1:open]] = ln[open+1 : closing]
	}
	return parsed
}
|
||||
|
||||
func convertToolsToXMLTools(tools []model.Tool) XMLTools {
|
||||
converted := make([]XMLToolDescription, len(tools))
|
||||
for i, tool := range tools {
|
||||
converted[i].ToolName = tool.Name
|
||||
converted[i].Description = tool.Description
|
||||
|
||||
params := make([]XMLToolParameter, len(tool.Parameters))
|
||||
for j, param := range tool.Parameters {
|
||||
params[j].Name = param.Name
|
||||
params[j].Description = param.Description
|
||||
params[j].Type = param.Type
|
||||
}
|
||||
|
||||
converted[i].Parameters = params
|
||||
}
|
||||
return XMLTools{
|
||||
ToolDescriptions: converted,
|
||||
}
|
||||
}
|
||||
|
||||
func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []model.ToolCall {
|
||||
toolCalls := make([]model.ToolCall, len(functionCalls.Invoke))
|
||||
for i, invoke := range functionCalls.Invoke {
|
||||
toolCalls[i].Name = invoke.ToolName
|
||||
toolCalls[i].Parameters = parseFunctionParametersXML(invoke.Parameters.String)
|
||||
}
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
func convertToolCallsToXMLFunctionCalls(toolCalls []model.ToolCall) XMLFunctionCalls {
|
||||
converted := make([]XMLFunctionInvoke, len(toolCalls))
|
||||
for i, toolCall := range toolCalls {
|
||||
var params XMLFunctionInvokeParameters
|
||||
var paramXML string
|
||||
for key, value := range toolCall.Parameters {
|
||||
paramXML += fmt.Sprintf("<%s>%v</%s>\n", key, value, key)
|
||||
}
|
||||
params.String = paramXML
|
||||
converted[i] = XMLFunctionInvoke{
|
||||
ToolName: toolCall.Name,
|
||||
Parameters: params,
|
||||
}
|
||||
}
|
||||
return XMLFunctionCalls{
|
||||
Invoke: converted,
|
||||
}
|
||||
}
|
||||
|
||||
func convertToolResultsToXMLFunctionResult(toolResults []model.ToolResult) XMLFunctionResults {
|
||||
converted := make([]XMLFunctionResult, len(toolResults))
|
||||
for i, result := range toolResults {
|
||||
converted[i].ToolName = result.ToolName
|
||||
converted[i].Stdout = result.Result
|
||||
}
|
||||
return XMLFunctionResults{
|
||||
Result: converted,
|
||||
}
|
||||
}
|
||||
|
||||
func buildToolsSystemPrompt(tools []model.Tool) string {
|
||||
xmlTools := convertToolsToXMLTools(tools)
|
||||
xmlToolsString, err := xmlTools.XMLString()
|
||||
if err != nil {
|
||||
panic("Could not serialize []model.Tool to XMLTools")
|
||||
}
|
||||
return TOOL_PREAMBLE + "\n\n" + xmlToolsString + "\n\n" + TOOL_PREAMBLE_FOOTER
|
||||
}
|
||||
|
||||
func (x XMLTools) XMLString() (string, error) {
|
||||
tmpl, err := template.New("tools").Parse(`<tools>
|
||||
{{range .ToolDescriptions}}<tool_description>
|
||||
<tool_name>{{.ToolName}}</tool_name>
|
||||
<description>
|
||||
{{.Description}}
|
||||
</description>
|
||||
<parameters>
|
||||
{{range .Parameters}}<parameter>
|
||||
<name>{{.Name}}</name>
|
||||
<type>{{.Type}}</type>
|
||||
<description>{{.Description}}</description>
|
||||
</parameter>
|
||||
{{end}}</parameters>
|
||||
</tool_description>
|
||||
{{end}}</tools>`)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := tmpl.Execute(&buf, x); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func (x XMLFunctionResults) XMLString() (string, error) {
|
||||
tmpl, err := template.New("function_results").Parse(`<function_results>
|
||||
{{range .Result}}<result>
|
||||
<tool_name>{{.ToolName}}</tool_name>
|
||||
<stdout>{{.Stdout}}</stdout>
|
||||
</result>
|
||||
{{end}}</function_results>`)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := tmpl.Execute(&buf, x); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func (x XMLFunctionCalls) XMLString() (string, error) {
|
||||
tmpl, err := template.New("function_calls").Parse(`<function_calls>
|
||||
{{range .Invoke}}<invoke>
|
||||
<tool_name>{{.ToolName}}</tool_name>
|
||||
<parameters>{{.Parameters.String}}</parameters>
|
||||
</invoke>
|
||||
{{end}}</function_calls>`)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := tmpl.Execute(&buf, x); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
@@ -1,38 +0,0 @@
|
||||
package anthropic
|
||||
|
||||
// AnthropicClient holds the connection settings for the Anthropic
// Messages API.
type AnthropicClient struct {
	BaseURL string
	APIKey  string
}

// Message is a single role/content pair in an API request.
type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// Request is the JSON body of a /messages request.
type Request struct {
	Model         string    `json:"model"`
	Messages      []Message `json:"messages"`
	System        string    `json:"system,omitempty"`
	MaxTokens     int       `json:"max_tokens,omitempty"`
	StopSequences []string  `json:"stop_sequences,omitempty"`
	Stream        bool      `json:"stream,omitempty"`
	Temperature   float32   `json:"temperature,omitempty"`
	//TopP          float32   `json:"top_p,omitempty"`
	//TopK          float32   `json:"top_k,omitempty"`
}

// OriginalContent is one content block of a non-streaming response.
type OriginalContent struct {
	Type string `json:"type"`
	Text string `json:"text"`
}

// Response is the JSON body of a non-streaming /messages response.
type Response struct {
	Id           string            `json:"id"`
	Type         string            `json:"type"`
	Role         string            `json:"role"`
	Content      []OriginalContent `json:"content"`
	StopReason   string            `json:"stop_reason"`
	StopSequence string            `json:"stop_sequence"`
}
|
||||
|
||||
@@ -1,424 +0,0 @@
|
||||
package google
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||
)
|
||||
|
||||
func convertTools(tools []model.Tool) []Tool {
|
||||
geminiTools := make([]Tool, len(tools))
|
||||
for i, tool := range tools {
|
||||
params := make(map[string]ToolParameter)
|
||||
var required []string
|
||||
|
||||
for _, param := range tool.Parameters {
|
||||
// TODO: proper enum handing
|
||||
params[param.Name] = ToolParameter{
|
||||
Type: param.Type,
|
||||
Description: param.Description,
|
||||
Values: param.Enum,
|
||||
}
|
||||
if param.Required {
|
||||
required = append(required, param.Name)
|
||||
}
|
||||
}
|
||||
|
||||
geminiTools[i] = Tool{
|
||||
FunctionDeclarations: []FunctionDeclaration{
|
||||
{
|
||||
Name: tool.Name,
|
||||
Description: tool.Description,
|
||||
Parameters: ToolParameters{
|
||||
Type: "OBJECT",
|
||||
Properties: params,
|
||||
Required: required,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
return geminiTools
|
||||
}
|
||||
|
||||
func convertToolCallToGemini(toolCalls []model.ToolCall) []ContentPart {
|
||||
converted := make([]ContentPart, len(toolCalls))
|
||||
for i, call := range toolCalls {
|
||||
args := make(map[string]string)
|
||||
for k, v := range call.Parameters {
|
||||
args[k] = fmt.Sprintf("%v", v)
|
||||
}
|
||||
converted[i].FunctionCall = &FunctionCall{
|
||||
Name: call.Name,
|
||||
Args: args,
|
||||
}
|
||||
}
|
||||
return converted
|
||||
}
|
||||
|
||||
func convertToolCallToAPI(functionCalls []FunctionCall) []model.ToolCall {
|
||||
converted := make([]model.ToolCall, len(functionCalls))
|
||||
for i, call := range functionCalls {
|
||||
params := make(map[string]interface{})
|
||||
for k, v := range call.Args {
|
||||
params[k] = v
|
||||
}
|
||||
converted[i].Name = call.Name
|
||||
converted[i].Parameters = params
|
||||
}
|
||||
return converted
|
||||
}
|
||||
|
||||
func convertToolResultsToGemini(toolResults []model.ToolResult) ([]FunctionResponse, error) {
|
||||
results := make([]FunctionResponse, len(toolResults))
|
||||
for i, result := range toolResults {
|
||||
var obj interface{}
|
||||
err := json.Unmarshal([]byte(result.Result), &obj)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not unmarshal %s: %v", result.Result, err)
|
||||
}
|
||||
results[i] = FunctionResponse{
|
||||
Name: result.ToolName,
|
||||
Response: obj,
|
||||
}
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func createGenerateContentRequest(
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
) (*GenerateContentRequest, error) {
|
||||
requestContents := make([]Content, 0, len(messages))
|
||||
|
||||
startIdx := 0
|
||||
var system string
|
||||
if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem {
|
||||
system = messages[0].Content
|
||||
startIdx = 1
|
||||
}
|
||||
|
||||
for _, m := range messages[startIdx:] {
|
||||
switch m.Role {
|
||||
case "tool_call":
|
||||
content := Content{
|
||||
Role: "model",
|
||||
Parts: convertToolCallToGemini(m.ToolCalls),
|
||||
}
|
||||
requestContents = append(requestContents, content)
|
||||
case "tool_result":
|
||||
results, err := convertToolResultsToGemini(m.ToolResults)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// expand tool_result messages' results into multiple gemini messages
|
||||
for _, result := range results {
|
||||
content := Content{
|
||||
Role: "function",
|
||||
Parts: []ContentPart{
|
||||
{
|
||||
FunctionResp: &result,
|
||||
},
|
||||
},
|
||||
}
|
||||
requestContents = append(requestContents, content)
|
||||
}
|
||||
default:
|
||||
var role string
|
||||
switch m.Role {
|
||||
case model.MessageRoleAssistant:
|
||||
role = "model"
|
||||
case model.MessageRoleUser:
|
||||
role = "user"
|
||||
}
|
||||
|
||||
if role == "" {
|
||||
panic("Unhandled role: " + m.Role)
|
||||
}
|
||||
|
||||
content := Content{
|
||||
Role: role,
|
||||
Parts: []ContentPart{
|
||||
{
|
||||
Text: m.Content,
|
||||
},
|
||||
},
|
||||
}
|
||||
requestContents = append(requestContents, content)
|
||||
}
|
||||
}
|
||||
|
||||
request := &GenerateContentRequest{
|
||||
Contents: requestContents,
|
||||
GenerationConfig: &GenerationConfig{
|
||||
MaxOutputTokens: ¶ms.MaxTokens,
|
||||
Temperature: ¶ms.Temperature,
|
||||
TopP: ¶ms.TopP,
|
||||
},
|
||||
}
|
||||
|
||||
if system != "" {
|
||||
request.SystemInstruction = &Content{
|
||||
Parts: []ContentPart{
|
||||
{
|
||||
Text: system,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if len(params.ToolBag) > 0 {
|
||||
request.Tools = convertTools(params.ToolBag)
|
||||
}
|
||||
|
||||
return request, nil
|
||||
}
|
||||
|
||||
// handleToolCalls executes the given tool calls and returns the message
// slice extended with a tool-call message followed by a tool-result
// message. If the last message is an assistant message (a continuation),
// it is replaced by the tool-call message in place — note this mutates
// the caller's slice element. callback, when non-nil, is invoked for
// both new messages.
func handleToolCalls(
	params model.RequestParameters,
	content string,
	toolCalls []model.ToolCall,
	callback provider.ReplyCallback,
	messages []model.Message,
) ([]model.Message, error) {
	lastMessage := messages[len(messages)-1]
	continuation := false
	if lastMessage.Role.IsAssistant() {
		continuation = true
	}

	toolCall := model.Message{
		Role:      model.MessageRoleToolCall,
		Content:   content,
		ToolCalls: toolCalls,
	}

	toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
	if err != nil {
		return nil, err
	}

	toolResult := model.Message{
		Role:        model.MessageRoleToolResult,
		ToolResults: toolResults,
	}

	if callback != nil {
		callback(toolCall)
		callback(toolResult)
	}

	// Replace the partial assistant reply on a continuation (its content
	// is already folded into toolCall), otherwise append.
	if continuation {
		messages[len(messages)-1] = toolCall
	} else {
		messages = append(messages, toolCall)
	}
	messages = append(messages, toolResult)

	return messages, nil
}
|
||||
|
||||
func (c *Client) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req.WithContext(ctx))
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
bytes, _ := io.ReadAll(resp.Body)
|
||||
return resp, fmt.Errorf("%v", string(bytes))
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// CreateChatCompletion performs a single (non-streaming) generateContent
// request. If the response contains function calls they are executed via
// handleToolCalls and the method recurses with the results appended;
// otherwise the accumulated assistant text is returned. A trailing
// assistant message is treated as a partial reply and prepended.
func (c *Client) CreateChatCompletion(
	ctx context.Context,
	params model.RequestParameters,
	messages []model.Message,
	callback provider.ReplyCallback,
) (string, error) {
	if len(messages) == 0 {
		return "", fmt.Errorf("Can't create completion from no messages")
	}

	req, err := createGenerateContentRequest(params, messages)
	if err != nil {
		return "", err
	}
	jsonData, err := json.Marshal(req)
	if err != nil {
		return "", err
	}

	url := fmt.Sprintf(
		"%s/v1beta/models/%s:generateContent?key=%s",
		c.BaseURL, params.Model, c.APIKey,
	)
	httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
	if err != nil {
		return "", err
	}

	resp, err := c.sendRequest(ctx, httpReq)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	var completionResp GenerateContentResponse
	err = json.NewDecoder(resp.Body).Decode(&completionResp)
	if err != nil {
		return "", err
	}

	// NOTE(review): panics if Candidates is empty (e.g. a fully blocked
	// response) — consider a length check returning an error instead.
	choice := completionResp.Candidates[0]

	var content string
	lastMessage := messages[len(messages)-1]
	if lastMessage.Role.IsAssistant() {
		// Continuation of a previous assistant reply: include its content.
		content = lastMessage.Content
	}

	// Collect text and any function calls from the candidate's parts.
	var toolCalls []FunctionCall
	for _, part := range choice.Content.Parts {
		if part.Text != "" {
			content += part.Text
		}

		if part.FunctionCall != nil {
			toolCalls = append(toolCalls, *part.FunctionCall)
		}
	}

	if len(toolCalls) > 0 {
		messages, err := handleToolCalls(
			params, content, convertToolCallToAPI(toolCalls), callback, messages,
		)
		if err != nil {
			return content, err
		}

		// Recurse so the model can continue from the tool results.
		return c.CreateChatCompletion(ctx, params, messages, callback)
	}

	if callback != nil {
		callback(model.Message{
			Role:    model.MessageRoleAssistant,
			Content: content,
		})
	}

	return content, nil
}
|
||||
|
||||
// CreateChatCompletionStream performs a streaming generateContent
// request (SSE via alt=sse), emitting text parts on output as they
// arrive. Function calls collected during the stream are executed after
// it ends, and the method recurses with the results appended. Returns
// the full accumulated reply text.
func (c *Client) CreateChatCompletionStream(
	ctx context.Context,
	params model.RequestParameters,
	messages []model.Message,
	callback provider.ReplyCallback,
	output chan<- provider.Chunk,
) (string, error) {
	if len(messages) == 0 {
		return "", fmt.Errorf("Can't create completion from no messages")
	}

	req, err := createGenerateContentRequest(params, messages)
	if err != nil {
		return "", err
	}
	jsonData, err := json.Marshal(req)
	if err != nil {
		return "", err
	}

	url := fmt.Sprintf(
		"%s/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse",
		c.BaseURL, params.Model, c.APIKey,
	)
	httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
	if err != nil {
		return "", err
	}

	resp, err := c.sendRequest(ctx, httpReq)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	content := strings.Builder{}

	lastMessage := messages[len(messages)-1]
	if lastMessage.Role.IsAssistant() {
		// Continuation of a previous assistant reply: include its content.
		content.WriteString(lastMessage.Content)
	}

	var toolCalls []FunctionCall

	// Read the SSE stream line by line; only "data: {json}" lines carry
	// response payloads, everything else is skipped.
	reader := bufio.NewReader(resp.Body)
	for {
		line, err := reader.ReadBytes('\n')
		if err != nil {
			if err == io.EOF {
				break
			}
			return "", err
		}

		line = bytes.TrimSpace(line)
		if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
			continue
		}

		line = bytes.TrimPrefix(line, []byte("data: "))

		var streamResp GenerateContentResponse
		err = json.Unmarshal(line, &streamResp)
		if err != nil {
			return "", err
		}

		for _, candidate := range streamResp.Candidates {
			for _, part := range candidate.Content.Parts {
				if part.FunctionCall != nil {
					toolCalls = append(toolCalls, *part.FunctionCall)
				} else if part.Text != "" {
					output <- provider.Chunk {
						Content: part.Text,
					}
					content.WriteString(part.Text)
				}
			}
		}
	}

	// If there are function calls, handle them and recurse
	if len(toolCalls) > 0 {
		messages, err := handleToolCalls(
			params, content.String(), convertToolCallToAPI(toolCalls), callback, messages,
		)
		if err != nil {
			return content.String(), err
		}
		return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
	}

	if callback != nil {
		callback(model.Message{
			Role:    model.MessageRoleAssistant,
			Content: content.String(),
		})
	}

	return content.String(), nil
}
|
||||
@@ -1,80 +0,0 @@
|
||||
package google
|
||||
|
||||
// Client holds the configuration needed to talk to Google's generative
// language API (Gemini).
type Client struct {
	APIKey  string // API key, passed as a query parameter
	BaseURL string // API root, e.g. https://generativelanguage.googleapis.com
}

// ContentPart is one piece of a message: exactly one of Text,
// FunctionCall, or FunctionResp is expected to be set.
type ContentPart struct {
	Text         string            `json:"text,omitempty"`
	FunctionCall *FunctionCall     `json:"functionCall,omitempty"`
	FunctionResp *FunctionResponse `json:"functionResponse,omitempty"`
}

// FunctionCall is a model-requested tool invocation.
type FunctionCall struct {
	Name string            `json:"name"`
	Args map[string]string `json:"args"`
}

// FunctionResponse carries the result of an executed tool call back to
// the model.
type FunctionResponse struct {
	Name     string      `json:"name"`
	Response interface{} `json:"response"`
}

// Content is a single conversation turn made up of one or more parts.
type Content struct {
	Role  string        `json:"role"`
	Parts []ContentPart `json:"parts"`
}

// GenerationConfig mirrors the API's sampling options; nil pointer
// fields are omitted so the server-side defaults apply.
type GenerationConfig struct {
	MaxOutputTokens *int     `json:"maxOutputTokens,omitempty"`
	Temperature     *float32 `json:"temperature,omitempty"`
	TopP            *float32 `json:"topP,omitempty"`
	TopK            *int     `json:"topK,omitempty"`
}

// GenerateContentRequest is the request body for the
// generateContent/streamGenerateContent endpoints.
type GenerateContentRequest struct {
	Contents          []Content         `json:"contents"`
	Tools             []Tool            `json:"tools,omitempty"`
	SystemInstruction *Content          `json:"systemInstruction,omitempty"`
	GenerationConfig  *GenerationConfig `json:"generationConfig,omitempty"`
}

// Candidate is one generated response alternative.
type Candidate struct {
	Content      Content `json:"content"`
	FinishReason string  `json:"finishReason"`
	Index        int     `json:"index"`
}

// UsageMetadata reports token accounting for a response.
type UsageMetadata struct {
	PromptTokenCount     int `json:"promptTokenCount"`
	CandidatesTokenCount int `json:"candidatesTokenCount"`
	TotalTokenCount      int `json:"totalTokenCount"`
}

// GenerateContentResponse is the response body (one message per stream
// event when streaming).
type GenerateContentResponse struct {
	Candidates    []Candidate   `json:"candidates"`
	UsageMetadata UsageMetadata `json:"usageMetadata"`
}

// Tool groups the function declarations made available to the model.
type Tool struct {
	FunctionDeclarations []FunctionDeclaration `json:"functionDeclarations"`
}

// FunctionDeclaration describes a callable tool to the model.
type FunctionDeclaration struct {
	Name        string         `json:"name"`
	Description string         `json:"description"`
	Parameters  ToolParameters `json:"parameters"`
}

// ToolParameters is an OpenAPI-style schema for a tool's arguments.
type ToolParameters struct {
	Type       string                   `json:"type"`
	Properties map[string]ToolParameter `json:"properties,omitempty"`
	Required   []string                 `json:"required,omitempty"`
}

// ToolParameter describes a single tool argument.
type ToolParameter struct {
	Type        string   `json:"type"`
	Description string   `json:"description"`
	Values      []string `json:"values,omitempty"`
}
|
||||
@@ -1,199 +0,0 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
|
||||
)
|
||||
|
||||
// OllamaClient talks to a local or remote Ollama server's chat API.
type OllamaClient struct {
	BaseURL string // API root, e.g. http://localhost:11434/api
}

// OllamaMessage is one chat turn in Ollama's wire format.
type OllamaMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// OllamaRequest is the request body for the /chat endpoint.
type OllamaRequest struct {
	Model    string          `json:"model"`
	Messages []OllamaMessage `json:"messages"`
	Stream   bool            `json:"stream"`
}

// OllamaResponse is the /chat response; when streaming, one of these
// arrives per line, with Done marking the final event. Duration fields
// are reported in nanoseconds.
type OllamaResponse struct {
	Model              string        `json:"model"`
	CreatedAt          string        `json:"created_at"`
	Message            OllamaMessage `json:"message"`
	Done               bool          `json:"done"`
	TotalDuration      uint64        `json:"total_duration,omitempty"`
	LoadDuration       uint64        `json:"load_duration,omitempty"`
	PromptEvalCount    uint64        `json:"prompt_eval_count,omitempty"`
	PromptEvalDuration uint64        `json:"prompt_eval_duration,omitempty"`
	EvalCount          uint64        `json:"eval_count,omitempty"`
	EvalDuration       uint64        `json:"eval_duration,omitempty"`
}
|
||||
|
||||
func createOllamaRequest(
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
) OllamaRequest {
|
||||
requestMessages := make([]OllamaMessage, 0, len(messages))
|
||||
|
||||
for _, m := range messages {
|
||||
message := OllamaMessage{
|
||||
Role: string(m.Role),
|
||||
Content: m.Content,
|
||||
}
|
||||
requestMessages = append(requestMessages, message)
|
||||
}
|
||||
|
||||
request := OllamaRequest{
|
||||
Model: params.Model,
|
||||
Messages: requestMessages,
|
||||
}
|
||||
|
||||
return request
|
||||
}
|
||||
|
||||
func (c *OllamaClient) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
bytes, _ := io.ReadAll(resp.Body)
|
||||
return resp, fmt.Errorf("%v", string(bytes))
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *OllamaClient) CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
callback provider.ReplyCallback,
|
||||
) (string, error) {
|
||||
if len(messages) == 0 {
|
||||
return "", fmt.Errorf("Can't create completion from no messages")
|
||||
}
|
||||
|
||||
req := createOllamaRequest(params, messages)
|
||||
req.Stream = false
|
||||
|
||||
jsonData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(ctx, httpReq)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var completionResp OllamaResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
content := completionResp.Message.Content
|
||||
if callback != nil {
|
||||
callback(model.Message{
|
||||
Role: model.MessageRoleAssistant,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// CreateChatCompletionStream sends the conversation to Ollama's /chat
// endpoint in streaming mode, forwarding each received content fragment
// to output as a provider.Chunk. The full accumulated reply is returned,
// and, when callback is non-nil, also delivered as a completed assistant
// message once the stream ends.
func (c *OllamaClient) CreateChatCompletionStream(
	ctx context.Context,
	params model.RequestParameters,
	messages []model.Message,
	callback provider.ReplyCallback,
	output chan<- provider.Chunk,
) (string, error) {
	if len(messages) == 0 {
		return "", fmt.Errorf("Can't create completion from no messages")
	}

	req := createOllamaRequest(params, messages)
	req.Stream = true

	jsonData, err := json.Marshal(req)
	if err != nil {
		return "", err
	}

	httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
	if err != nil {
		return "", err
	}

	resp, err := c.sendRequest(ctx, httpReq)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	content := strings.Builder{}

	// Ollama streams one JSON object per line; read until EOF.
	reader := bufio.NewReader(resp.Body)
	for {
		line, err := reader.ReadBytes('\n')
		if err != nil {
			if err == io.EOF {
				break
			}
			return "", err
		}

		line = bytes.TrimSpace(line)
		if len(line) == 0 {
			continue
		}

		var streamResp OllamaResponse
		err = json.Unmarshal(line, &streamResp)
		if err != nil {
			return "", err
		}

		// Forward non-empty fragments downstream and accumulate them
		// for the final return value.
		if len(streamResp.Message.Content) > 0 {
			output <- provider.Chunk{
				Content: streamResp.Message.Content,
			}
			content.WriteString(streamResp.Message.Content)
		}
	}

	if callback != nil {
		callback(model.Message{
			Role:    model.MessageRoleAssistant,
			Content: content.String(),
		})
	}

	return content.String(), nil
}
|
||||
@@ -1,347 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||
)
|
||||
|
||||
func convertTools(tools []model.Tool) []Tool {
|
||||
openaiTools := make([]Tool, len(tools))
|
||||
for i, tool := range tools {
|
||||
openaiTools[i].Type = "function"
|
||||
|
||||
params := make(map[string]ToolParameter)
|
||||
var required []string
|
||||
|
||||
for _, param := range tool.Parameters {
|
||||
params[param.Name] = ToolParameter{
|
||||
Type: param.Type,
|
||||
Description: param.Description,
|
||||
Enum: param.Enum,
|
||||
}
|
||||
if param.Required {
|
||||
required = append(required, param.Name)
|
||||
}
|
||||
}
|
||||
|
||||
openaiTools[i].Function = FunctionDefinition{
|
||||
Name: tool.Name,
|
||||
Description: tool.Description,
|
||||
Parameters: ToolParameters{
|
||||
Type: "object",
|
||||
Properties: params,
|
||||
Required: required,
|
||||
},
|
||||
}
|
||||
}
|
||||
return openaiTools
|
||||
}
|
||||
|
||||
func convertToolCallToOpenAI(toolCalls []model.ToolCall) []ToolCall {
|
||||
converted := make([]ToolCall, len(toolCalls))
|
||||
for i, call := range toolCalls {
|
||||
converted[i].Type = "function"
|
||||
converted[i].ID = call.ID
|
||||
converted[i].Function.Name = call.Name
|
||||
|
||||
json, _ := json.Marshal(call.Parameters)
|
||||
converted[i].Function.Arguments = string(json)
|
||||
}
|
||||
return converted
|
||||
}
|
||||
|
||||
func convertToolCallToAPI(toolCalls []ToolCall) []model.ToolCall {
|
||||
converted := make([]model.ToolCall, len(toolCalls))
|
||||
for i, call := range toolCalls {
|
||||
converted[i].ID = call.ID
|
||||
converted[i].Name = call.Function.Name
|
||||
json.Unmarshal([]byte(call.Function.Arguments), &converted[i].Parameters)
|
||||
}
|
||||
return converted
|
||||
}
|
||||
|
||||
// createChatCompletionRequest maps the generic request parameters and
// message history onto OpenAI's chat completion request shape. Tool call
// and tool result messages are translated to OpenAI's "assistant" and
// "tool" roles respectively; when the tool bag is non-empty, the tools
// are attached with tool_choice "auto".
func createChatCompletionRequest(
	params model.RequestParameters,
	messages []model.Message,
) ChatCompletionRequest {
	requestMessages := make([]ChatCompletionMessage, 0, len(messages))

	for _, m := range messages {
		switch m.Role {
		case "tool_call":
			// OpenAI represents tool calls as an assistant message
			// carrying a tool_calls array.
			message := ChatCompletionMessage{}
			message.Role = "assistant"
			message.Content = m.Content
			message.ToolCalls = convertToolCallToOpenAI(m.ToolCalls)
			requestMessages = append(requestMessages, message)
		case "tool_result":
			// expand tool_result messages' results into multiple openAI messages
			for _, result := range m.ToolResults {
				message := ChatCompletionMessage{}
				message.Role = "tool"
				message.Content = result.Result
				message.ToolCallID = result.ToolCallID
				requestMessages = append(requestMessages, message)
			}
		default:
			// user/assistant/system roles pass through unchanged.
			message := ChatCompletionMessage{}
			message.Role = string(m.Role)
			message.Content = m.Content
			requestMessages = append(requestMessages, message)
		}
	}

	request := ChatCompletionRequest{
		Model:       params.Model,
		MaxTokens:   params.MaxTokens,
		Temperature: params.Temperature,
		Messages:    requestMessages,
		N:           1, // limit responses to 1 "choice". we use choices[0] to reference it
	}

	if len(params.ToolBag) > 0 {
		request.Tools = convertTools(params.ToolBag)
		request.ToolChoice = "auto"
	}

	return request
}
|
||||
|
||||
// handleToolCalls executes the given tool calls and appends the
// resulting tool_call and tool_result messages to the conversation,
// returning the updated slice. When the last message is an assistant
// message (a continuation), it is replaced by the tool_call message
// rather than appended after. Both new messages are also delivered via
// callback when it is non-nil.
//
// NOTE(review): `messages` is mutated in place in the continuation case
// and may be reallocated by append — callers must use the returned slice.
func handleToolCalls(
	params model.RequestParameters,
	content string,
	toolCalls []ToolCall,
	callback provider.ReplyCallback,
	messages []model.Message,
) ([]model.Message, error) {
	// assumes messages is non-empty (callers guard against zero
	// messages before reaching here).
	lastMessage := messages[len(messages)-1]
	continuation := false
	if lastMessage.Role.IsAssistant() {
		continuation = true
	}

	toolCall := model.Message{
		Role:      model.MessageRoleToolCall,
		Content:   content,
		ToolCalls: convertToolCallToAPI(toolCalls),
	}

	// Run the requested tools; a failure aborts without mutating the
	// conversation.
	toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
	if err != nil {
		return nil, err
	}

	toolResult := model.Message{
		Role:        model.MessageRoleToolResult,
		ToolResults: toolResults,
	}

	if callback != nil {
		callback(toolCall)
		callback(toolResult)
	}

	if continuation {
		messages[len(messages)-1] = toolCall
	} else {
		messages = append(messages, toolCall)
	}
	messages = append(messages, toolResult)

	return messages, nil
}
|
||||
|
||||
func (c *OpenAIClient) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req.WithContext(ctx))
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
bytes, _ := io.ReadAll(resp.Body)
|
||||
return resp, fmt.Errorf("%v", string(bytes))
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
callback provider.ReplyCallback,
|
||||
) (string, error) {
|
||||
if len(messages) == 0 {
|
||||
return "", fmt.Errorf("Can't create completion from no messages")
|
||||
}
|
||||
|
||||
req := createChatCompletionRequest(params, messages)
|
||||
jsonData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(ctx, httpReq)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var completionResp ChatCompletionResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
choice := completionResp.Choices[0]
|
||||
|
||||
var content string
|
||||
lastMessage := messages[len(messages)-1]
|
||||
if lastMessage.Role.IsAssistant() {
|
||||
content = lastMessage.Content + choice.Message.Content
|
||||
} else {
|
||||
content = choice.Message.Content
|
||||
}
|
||||
|
||||
toolCalls := choice.Message.ToolCalls
|
||||
if len(toolCalls) > 0 {
|
||||
messages, err := handleToolCalls(params, content, toolCalls, callback, messages)
|
||||
if err != nil {
|
||||
return content, err
|
||||
}
|
||||
|
||||
return c.CreateChatCompletion(ctx, params, messages, callback)
|
||||
}
|
||||
|
||||
if callback != nil {
|
||||
callback(model.Message{
|
||||
Role: model.MessageRoleAssistant,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
|
||||
// Return the user-facing message.
|
||||
return content, nil
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) CreateChatCompletionStream(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
callback provider.ReplyCallback,
|
||||
output chan<- provider.Chunk,
|
||||
) (string, error) {
|
||||
if len(messages) == 0 {
|
||||
return "", fmt.Errorf("Can't create completion from no messages")
|
||||
}
|
||||
|
||||
req := createChatCompletionRequest(params, messages)
|
||||
req.Stream = true
|
||||
|
||||
jsonData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(ctx, httpReq)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
content := strings.Builder{}
|
||||
toolCalls := []ToolCall{}
|
||||
|
||||
lastMessage := messages[len(messages)-1]
|
||||
if lastMessage.Role.IsAssistant() {
|
||||
content.WriteString(lastMessage.Content)
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
|
||||
continue
|
||||
}
|
||||
|
||||
line = bytes.TrimPrefix(line, []byte("data: "))
|
||||
if bytes.Equal(line, []byte("[DONE]")) {
|
||||
break
|
||||
}
|
||||
|
||||
var streamResp ChatCompletionStreamResponse
|
||||
err = json.Unmarshal(line, &streamResp)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
delta := streamResp.Choices[0].Delta
|
||||
if len(delta.ToolCalls) > 0 {
|
||||
// Construct streamed tool_call arguments
|
||||
for _, tc := range delta.ToolCalls {
|
||||
if tc.Index == nil {
|
||||
return "", fmt.Errorf("Unexpected nil index for streamed tool call.")
|
||||
}
|
||||
if len(toolCalls) <= *tc.Index {
|
||||
toolCalls = append(toolCalls, tc)
|
||||
} else {
|
||||
toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(delta.Content) > 0 {
|
||||
output <- provider.Chunk {
|
||||
Content: delta.Content,
|
||||
}
|
||||
content.WriteString(delta.Content)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toolCalls) > 0 {
|
||||
messages, err := handleToolCalls(params, content.String(), toolCalls, callback, messages)
|
||||
if err != nil {
|
||||
return content.String(), err
|
||||
}
|
||||
|
||||
// Recurse into CreateChatCompletionStream with the tool call replies
|
||||
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
|
||||
} else {
|
||||
if callback != nil {
|
||||
callback(model.Message{
|
||||
Role: model.MessageRoleAssistant,
|
||||
Content: content.String(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return content.String(), nil
|
||||
}
|
||||
@@ -1,71 +0,0 @@
|
||||
package openai
|
||||
|
||||
// OpenAIClient holds the configuration needed to talk to an
// OpenAI-compatible chat completions API.
type OpenAIClient struct {
	APIKey  string // bearer token for the Authorization header
	BaseURL string // API root, e.g. https://api.openai.com/v1
}

// ChatCompletionMessage is one chat turn in OpenAI's wire format.
// ToolCalls is set on assistant messages that request tools;
// ToolCallID links a "tool" role message back to its originating call.
type ChatCompletionMessage struct {
	Role       string     `json:"role"`
	Content    string     `json:"content,omitempty"`
	ToolCalls  []ToolCall `json:"tool_calls,omitempty"`
	ToolCallID string     `json:"tool_call_id,omitempty"`
}

// ToolCall is a model-requested function invocation. Index is only
// present on streamed deltas, where it identifies which in-progress
// call an argument fragment belongs to.
type ToolCall struct {
	Type     string             `json:"type"`
	ID       string             `json:"id"`
	Index    *int               `json:"index,omitempty"`
	Function FunctionDefinition `json:"function"`
}

// FunctionDefinition doubles as a tool declaration (Name/Description/
// Parameters, in requests) and a call payload (Arguments as raw JSON
// text, in responses).
type FunctionDefinition struct {
	Name        string         `json:"name"`
	Description string         `json:"description"`
	Parameters  ToolParameters `json:"parameters"`
	Arguments   string         `json:"arguments,omitempty"`
}

// ToolParameters is an OpenAPI-style schema for a tool's arguments.
type ToolParameters struct {
	Type       string                   `json:"type"`
	Properties map[string]ToolParameter `json:"properties,omitempty"`
	Required   []string                 `json:"required,omitempty"`
}

// ToolParameter describes a single tool argument.
type ToolParameter struct {
	Type        string   `json:"type"`
	Description string   `json:"description"`
	Enum        []string `json:"enum,omitempty"`
}

// Tool wraps a function declaration with its tool type ("function").
type Tool struct {
	Type     string             `json:"type"`
	Function FunctionDefinition `json:"function"`
}

// ChatCompletionRequest is the request body for /chat/completions.
type ChatCompletionRequest struct {
	Model       string                  `json:"model"`
	MaxTokens   int                     `json:"max_tokens,omitempty"`
	Temperature float32                 `json:"temperature,omitempty"`
	Messages    []ChatCompletionMessage `json:"messages"`
	N           int                     `json:"n"`
	Tools       []Tool                  `json:"tools,omitempty"`
	ToolChoice  string                  `json:"tool_choice,omitempty"`
	Stream      bool                    `json:"stream,omitempty"`
}

// ChatCompletionChoice is one response alternative (we request N=1).
type ChatCompletionChoice struct {
	Message ChatCompletionMessage `json:"message"`
}

// ChatCompletionResponse is the non-streamed response body.
type ChatCompletionResponse struct {
	Choices []ChatCompletionChoice `json:"choices"`
}

// ChatCompletionStreamChoice carries the incremental delta of a
// streamed response event.
type ChatCompletionStreamChoice struct {
	Delta ChatCompletionMessage `json:"delta"`
}

// ChatCompletionStreamResponse is one streamed (SSE) response event.
type ChatCompletionStreamResponse struct {
	Choices []ChatCompletionStreamChoice `json:"choices"`
}
|
||||
@@ -1,35 +0,0 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
)
|
||||
|
||||
// ReplyCallback is invoked with each completed message produced during
// a completion, including tool call and tool result messages.
type ReplyCallback func(model.Message)

// Chunk is a streamed fragment of assistant output.
type Chunk struct {
	Content string
}

// ChatCompletionClient is the interface implemented by each LLM
// provider backend (OpenAI, Anthropic, Google, Ollama, ...).
type ChatCompletionClient interface {
	// CreateChatCompletion requests a response to the provided messages.
	// Replies are appended to the given replies struct, and the
	// complete user-facing response is returned as a string.
	CreateChatCompletion(
		ctx context.Context,
		params model.RequestParameters,
		messages []model.Message,
		callback ReplyCallback,
	) (string, error)

	// Like CreateChatCompletion, except the response is streamed via
	// the output channel as it's received.
	CreateChatCompletionStream(
		ctx context.Context,
		params model.RequestParameters,
		messages []model.Message,
		callback ReplyCallback,
		output chan<- Chunk,
	) (string, error)
}
|
||||
Reference in New Issue
Block a user