Private
Public Access
1
0

Refactor pkg/lmcli/provider

Moved `ChangeCompletionInterface` to `pkg/api`, moved individual
providers to `pkg/api/provider`
This commit is contained in:
2024-06-09 16:42:53 +00:00
parent d2d946b776
commit a2c860252f
12 changed files with 37 additions and 37 deletions

View File

@@ -0,0 +1,334 @@
package anthropic
import (
"bufio"
"bytes"
"context"
"encoding/json"
"encoding/xml"
"fmt"
"net/http"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
)
func buildRequest(params model.RequestParameters, messages []model.Message) Request {
requestBody := Request{
Model: params.Model,
Messages: make([]Message, len(messages)),
MaxTokens: params.MaxTokens,
Temperature: params.Temperature,
Stream: false,
StopSequences: []string{
FUNCTION_STOP_SEQUENCE,
"\n\nHuman:",
},
}
startIdx := 0
if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem {
requestBody.System = messages[0].Content
requestBody.Messages = requestBody.Messages[1:]
startIdx = 1
}
if len(params.ToolBag) > 0 {
if len(requestBody.System) > 0 {
// add a divider between existing system prompt and tools
requestBody.System += "\n\n---\n\n"
}
requestBody.System += buildToolsSystemPrompt(params.ToolBag)
}
for i, msg := range messages[startIdx:] {
message := &requestBody.Messages[i]
switch msg.Role {
case model.MessageRoleToolCall:
message.Role = "assistant"
if msg.Content != "" {
message.Content = msg.Content
}
xmlFuncCalls := convertToolCallsToXMLFunctionCalls(msg.ToolCalls)
xmlString, err := xmlFuncCalls.XMLString()
if err != nil {
panic("Could not serialize []ToolCall to XMLFunctionCall")
}
if len(message.Content) > 0 {
message.Content += fmt.Sprintf("\n\n%s", xmlString)
} else {
message.Content = xmlString
}
case model.MessageRoleToolResult:
xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults)
xmlString, err := xmlFuncResults.XMLString()
if err != nil {
panic("Could not serialize []ToolResult to XMLFunctionResults")
}
message.Role = "user"
message.Content = xmlString
default:
message.Role = string(msg.Role)
message.Content = msg.Content
}
}
return requestBody
}
func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Response, error) {
jsonBody, err := json.Marshal(r)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %v", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/messages", bytes.NewBuffer(jsonBody))
if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %v", err)
}
req.Header.Set("x-api-key", c.APIKey)
req.Header.Set("anthropic-version", "2023-06-01")
req.Header.Set("content-type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send HTTP request: %v", err)
}
return resp, nil
}
func (c *AnthropicClient) CreateChatCompletion(
ctx context.Context,
params model.RequestParameters,
messages []model.Message,
callback api.ReplyCallback,
) (string, error) {
if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages")
}
request := buildRequest(params, messages)
resp, err := sendRequest(ctx, c, request)
if err != nil {
return "", err
}
defer resp.Body.Close()
var response Response
err = json.NewDecoder(resp.Body).Decode(&response)
if err != nil {
return "", fmt.Errorf("failed to decode response: %v", err)
}
sb := strings.Builder{}
lastMessage := messages[len(messages)-1]
if lastMessage.Role.IsAssistant() {
// this is a continuation of a previous assistant reply, so we'll
// include its contents in the final result
sb.WriteString(lastMessage.Content)
}
for _, content := range response.Content {
var reply model.Message
switch content.Type {
case "text":
reply = model.Message{
Role: model.MessageRoleAssistant,
Content: content.Text,
}
sb.WriteString(reply.Content)
default:
return "", fmt.Errorf("unsupported message type: %s", content.Type)
}
if callback != nil {
callback(reply)
}
}
return sb.String(), nil
}
func (c *AnthropicClient) CreateChatCompletionStream(
ctx context.Context,
params model.RequestParameters,
messages []model.Message,
callback api.ReplyCallback,
output chan<- api.Chunk,
) (string, error) {
if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages")
}
request := buildRequest(params, messages)
request.Stream = true
resp, err := sendRequest(ctx, c, request)
if err != nil {
return "", err
}
defer resp.Body.Close()
sb := strings.Builder{}
lastMessage := messages[len(messages)-1]
continuation := false
if messages[len(messages)-1].Role.IsAssistant() {
// this is a continuation of a previous assistant reply, so we'll
// include its contents in the final result
sb.WriteString(lastMessage.Content)
continuation = true
}
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
line := scanner.Text()
line = strings.TrimSpace(line)
if len(line) == 0 {
continue
}
if line[0] == '{' {
var event map[string]interface{}
err := json.Unmarshal([]byte(line), &event)
if err != nil {
return "", fmt.Errorf("failed to unmarshal event data '%s': %v", line, err)
}
eventType, ok := event["type"].(string)
if !ok {
return "", fmt.Errorf("invalid event: %s", line)
}
switch eventType {
case "error":
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
default:
return sb.String(), fmt.Errorf("unknown event type: %s", eventType)
}
} else if strings.HasPrefix(line, "data:") {
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
var event map[string]interface{}
err := json.Unmarshal([]byte(data), &event)
if err != nil {
return "", fmt.Errorf("failed to unmarshal event data: %v", err)
}
eventType, ok := event["type"].(string)
if !ok {
return "", fmt.Errorf("invalid event type")
}
switch eventType {
case "message_start":
// noop
case "ping":
// signals start of text - currently ignoring
case "content_block_start":
// ignore?
case "content_block_delta":
delta, ok := event["delta"].(map[string]interface{})
if !ok {
return "", fmt.Errorf("invalid content block delta")
}
text, ok := delta["text"].(string)
if !ok {
return "", fmt.Errorf("invalid text delta")
}
sb.WriteString(text)
output <- api.Chunk{
Content: text,
}
case "content_block_stop":
// ignore?
case "message_delta":
delta, ok := event["delta"].(map[string]interface{})
if !ok {
return "", fmt.Errorf("invalid message delta")
}
stopReason, ok := delta["stop_reason"].(string)
if ok && stopReason == "stop_sequence" {
stopSequence, ok := delta["stop_sequence"].(string)
if ok && stopSequence == FUNCTION_STOP_SEQUENCE {
content := sb.String()
start := strings.Index(content, "<function_calls>")
if start == -1 {
return content, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
}
sb.WriteString(FUNCTION_STOP_SEQUENCE)
output <- api.Chunk{
Content: FUNCTION_STOP_SEQUENCE,
}
funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE
var functionCalls XMLFunctionCalls
err := xml.Unmarshal([]byte(funcCallXml), &functionCalls)
if err != nil {
return "", fmt.Errorf("failed to unmarshal function_calls: %v", err)
}
toolCall := model.Message{
Role: model.MessageRoleToolCall,
// function call xml stripped from content for model interop
Content: strings.TrimSpace(content[:start]),
ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
}
toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
if err != nil {
return "", err
}
toolResult := model.Message{
Role: model.MessageRoleToolResult,
ToolResults: toolResults,
}
if callback != nil {
callback(toolCall)
callback(toolResult)
}
if continuation {
messages[len(messages)-1] = toolCall
} else {
messages = append(messages, toolCall)
}
messages = append(messages, toolResult)
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
}
}
case "message_stop":
// return the completed message
content := sb.String()
if callback != nil {
callback(model.Message{
Role: model.MessageRoleAssistant,
Content: content,
})
}
return content, nil
case "error":
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
default:
fmt.Printf("\nUnrecognized event: %s\n", data)
}
}
}
if err := scanner.Err(); err != nil {
return "", fmt.Errorf("failed to read response body: %v", err)
}
return "", fmt.Errorf("unexpected end of stream")
}

View File

@@ -0,0 +1,232 @@
package anthropic
import (
"bytes"
"fmt"
"strings"
"text/template"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
)
const FUNCTION_STOP_SEQUENCE = "</function_calls>"
const TOOL_PREAMBLE = `You have access to the following tools when replying.
You may call them like this:
<function_calls>
<invoke>
<tool_name>$TOOL_NAME</tool_name>
<parameters>
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
...
</parameters>
</invoke>
</function_calls>
Here are the tools available:`
const TOOL_PREAMBLE_FOOTER = `Recognize the utility of these tools in a broad range of different applications, and the power they give you to solve a wide range of different problems. However, ensure that the tools are used judiciously and only when clearly relevant to the user's request. Specifically:
1. Only use a tool if the user has explicitly requested or provided information that warrants its use. Do not make assumptions about files or data existing without the user mentioning them.
2. If there is ambiguity about whether using a tool is appropriate, ask a clarifying question to the user before proceeding. Confirm your understanding of their request and intent.
3. Prioritize providing direct responses and explanations based on your own knowledge and understanding. Use tools to supplement and enhance your responses when clearly applicable, but not as a default action.`
type XMLTools struct {
XMLName struct{} `xml:"tools"`
ToolDescriptions []XMLToolDescription `xml:"tool_description"`
}
type XMLToolDescription struct {
ToolName string `xml:"tool_name"`
Description string `xml:"description"`
Parameters []XMLToolParameter `xml:"parameters>parameter"`
}
type XMLToolParameter struct {
Name string `xml:"name"`
Type string `xml:"type"`
Description string `xml:"description"`
}
type XMLFunctionCalls struct {
XMLName struct{} `xml:"function_calls"`
Invoke []XMLFunctionInvoke `xml:"invoke"`
}
type XMLFunctionInvoke struct {
ToolName string `xml:"tool_name"`
Parameters XMLFunctionInvokeParameters `xml:"parameters"`
}
type XMLFunctionInvokeParameters struct {
String string `xml:",innerxml"`
}
type XMLFunctionResults struct {
XMLName struct{} `xml:"function_results"`
Result []XMLFunctionResult `xml:"result"`
}
type XMLFunctionResult struct {
ToolName string `xml:"tool_name"`
Stdout string `xml:"stdout"`
}
// accepts raw XML from XMLFunctionInvokeParameters.String, returns map of
// parameters name to value
func parseFunctionParametersXML(params string) map[string]interface{} {
lines := strings.Split(params, "\n")
ret := make(map[string]interface{}, len(lines))
for _, line := range lines {
i := strings.Index(line, ">")
if i == -1 {
continue
}
j := strings.Index(line, "</")
if j == -1 {
continue
}
// chop from after opening < to first > to get parameter name,
// then chop after > to first </ to get parameter value
ret[line[1:i]] = line[i+1 : j]
}
return ret
}
func convertToolsToXMLTools(tools []model.Tool) XMLTools {
converted := make([]XMLToolDescription, len(tools))
for i, tool := range tools {
converted[i].ToolName = tool.Name
converted[i].Description = tool.Description
params := make([]XMLToolParameter, len(tool.Parameters))
for j, param := range tool.Parameters {
params[j].Name = param.Name
params[j].Description = param.Description
params[j].Type = param.Type
}
converted[i].Parameters = params
}
return XMLTools{
ToolDescriptions: converted,
}
}
func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []model.ToolCall {
toolCalls := make([]model.ToolCall, len(functionCalls.Invoke))
for i, invoke := range functionCalls.Invoke {
toolCalls[i].Name = invoke.ToolName
toolCalls[i].Parameters = parseFunctionParametersXML(invoke.Parameters.String)
}
return toolCalls
}
func convertToolCallsToXMLFunctionCalls(toolCalls []model.ToolCall) XMLFunctionCalls {
converted := make([]XMLFunctionInvoke, len(toolCalls))
for i, toolCall := range toolCalls {
var params XMLFunctionInvokeParameters
var paramXML string
for key, value := range toolCall.Parameters {
paramXML += fmt.Sprintf("<%s>%v</%s>\n", key, value, key)
}
params.String = paramXML
converted[i] = XMLFunctionInvoke{
ToolName: toolCall.Name,
Parameters: params,
}
}
return XMLFunctionCalls{
Invoke: converted,
}
}
func convertToolResultsToXMLFunctionResult(toolResults []model.ToolResult) XMLFunctionResults {
converted := make([]XMLFunctionResult, len(toolResults))
for i, result := range toolResults {
converted[i].ToolName = result.ToolName
converted[i].Stdout = result.Result
}
return XMLFunctionResults{
Result: converted,
}
}
func buildToolsSystemPrompt(tools []model.Tool) string {
xmlTools := convertToolsToXMLTools(tools)
xmlToolsString, err := xmlTools.XMLString()
if err != nil {
panic("Could not serialize []model.Tool to XMLTools")
}
return TOOL_PREAMBLE + "\n\n" + xmlToolsString + "\n\n" + TOOL_PREAMBLE_FOOTER
}
func (x XMLTools) XMLString() (string, error) {
tmpl, err := template.New("tools").Parse(`<tools>
{{range .ToolDescriptions}}<tool_description>
<tool_name>{{.ToolName}}</tool_name>
<description>
{{.Description}}
</description>
<parameters>
{{range .Parameters}}<parameter>
<name>{{.Name}}</name>
<type>{{.Type}}</type>
<description>{{.Description}}</description>
</parameter>
{{end}}</parameters>
</tool_description>
{{end}}</tools>`)
if err != nil {
return "", err
}
var buf bytes.Buffer
if err := tmpl.Execute(&buf, x); err != nil {
return "", err
}
return buf.String(), nil
}
func (x XMLFunctionResults) XMLString() (string, error) {
tmpl, err := template.New("function_results").Parse(`<function_results>
{{range .Result}}<result>
<tool_name>{{.ToolName}}</tool_name>
<stdout>{{.Stdout}}</stdout>
</result>
{{end}}</function_results>`)
if err != nil {
return "", err
}
var buf bytes.Buffer
if err := tmpl.Execute(&buf, x); err != nil {
return "", err
}
return buf.String(), nil
}
func (x XMLFunctionCalls) XMLString() (string, error) {
tmpl, err := template.New("function_calls").Parse(`<function_calls>
{{range .Invoke}}<invoke>
<tool_name>{{.ToolName}}</tool_name>
<parameters>{{.Parameters.String}}</parameters>
</invoke>
{{end}}</function_calls>`)
if err != nil {
return "", err
}
var buf bytes.Buffer
if err := tmpl.Execute(&buf, x); err != nil {
return "", err
}
return buf.String(), nil
}

View File

@@ -0,0 +1,38 @@
package anthropic
type AnthropicClient struct {
BaseURL string
APIKey string
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
type Request struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
System string `json:"system,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float32 `json:"temperature,omitempty"`
//TopP float32 `json:"top_p,omitempty"`
//TopK float32 `json:"top_k,omitempty"`
}
type OriginalContent struct {
Type string `json:"type"`
Text string `json:"text"`
}
type Response struct {
Id string `json:"id"`
Type string `json:"type"`
Role string `json:"role"`
Content []OriginalContent `json:"content"`
StopReason string `json:"stop_reason"`
StopSequence string `json:"stop_sequence"`
}