Compare commits
2 Commits
c50b6b154d
...
a1fc8a637b
Author | SHA1 | Date | |
---|---|---|---|
a1fc8a637b | |||
94d84ba7d7 |
2
TODO.md
2
TODO.md
@ -3,7 +3,7 @@
|
|||||||
- [x] Strip anthropic XML function call scheme from content, to reconstruct
|
- [x] Strip anthropic XML function call scheme from content, to reconstruct
|
||||||
when calling anthropic?
|
when calling anthropic?
|
||||||
- [x] `dir_tree` tool
|
- [x] `dir_tree` tool
|
||||||
- [ ] Implement native Anthropic API tool calling
|
- [x] Implement native Anthropic API tool calling
|
||||||
- [ ] Agents - a name given to a system prompt + set of available tools +
|
- [ ] Agents - a name given to a system prompt + set of available tools +
|
||||||
potentially other relevent data (e.g. external service credentials, files for
|
potentially other relevent data (e.g. external service credentials, files for
|
||||||
RAG, etc), which the user explicitly selects (e.g. `lmcli chat --agent
|
RAG, etc), which the user explicitly selects (e.g. `lmcli chat --agent
|
||||||
|
@ -40,3 +40,10 @@ type ChatCompletionProvider interface {
|
|||||||
chunks chan<- Chunk,
|
chunks chan<- Chunk,
|
||||||
) (*Message, error)
|
) (*Message, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func IsAssistantContinuation(messages []Message) bool {
|
||||||
|
if len(messages) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return messages[len(messages)-1].Role == MessageRoleAssistant
|
||||||
|
}
|
||||||
|
@ -5,100 +5,223 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"encoding/xml"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
func buildRequest(params api.RequestParameters, messages []api.Message) Request {
|
const ANTHROPIC_VERSION = "2023-06-01"
|
||||||
requestBody := Request{
|
|
||||||
Model: params.Model,
|
|
||||||
Messages: make([]Message, len(messages)),
|
|
||||||
MaxTokens: params.MaxTokens,
|
|
||||||
Temperature: params.Temperature,
|
|
||||||
Stream: false,
|
|
||||||
|
|
||||||
StopSequences: []string{
|
type AnthropicClient struct {
|
||||||
FUNCTION_STOP_SEQUENCE,
|
APIKey string
|
||||||
"\n\nHuman:",
|
BaseURL string
|
||||||
},
|
}
|
||||||
|
|
||||||
|
type ChatCompletionMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content interface{} `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Tool struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
InputSchema InputSchema `json:"input_schema"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type InputSchema struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Properties map[string]Property `json:"properties"`
|
||||||
|
Required []string `json:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Property struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Enum []string `json:"enum,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Messages []ChatCompletionMessage `json:"messages"`
|
||||||
|
System string `json:"system,omitempty"`
|
||||||
|
Tools []Tool `json:"tools,omitempty"`
|
||||||
|
MaxTokens int `json:"max_tokens"`
|
||||||
|
Temperature float32 `json:"temperature,omitempty"`
|
||||||
|
Stream bool `json:"stream"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ContentBlock struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
ID string `json:"id,omitempty"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
Input interface{} `json:"input,omitempty"`
|
||||||
|
partialJsonAccumulator string
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Role string `json:"role"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Content []ContentBlock `json:"content"`
|
||||||
|
StopReason string `json:"stop_reason"`
|
||||||
|
Usage Usage `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Usage struct {
|
||||||
|
InputTokens int `json:"input_tokens"`
|
||||||
|
OutputTokens int `json:"output_tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type StreamEvent struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Message interface{} `json:"message,omitempty"`
|
||||||
|
Index int `json:"index,omitempty"`
|
||||||
|
Delta interface{} `json:"delta,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertTools(tools []api.ToolSpec) []Tool {
|
||||||
|
anthropicTools := make([]Tool, len(tools))
|
||||||
|
for i, tool := range tools {
|
||||||
|
properties := make(map[string]Property)
|
||||||
|
for _, param := range tool.Parameters {
|
||||||
|
properties[param.Name] = Property{
|
||||||
|
Type: param.Type,
|
||||||
|
Description: param.Description,
|
||||||
|
Enum: param.Enum,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var required []string
|
||||||
|
for _, param := range tool.Parameters {
|
||||||
|
if param.Required {
|
||||||
|
required = append(required, param.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
anthropicTools[i] = Tool{
|
||||||
|
Name: tool.Name,
|
||||||
|
Description: tool.Description,
|
||||||
|
InputSchema: InputSchema{
|
||||||
|
Type: "object",
|
||||||
|
Properties: properties,
|
||||||
|
Required: required,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return anthropicTools
|
||||||
|
}
|
||||||
|
|
||||||
|
func createChatCompletionRequest(
|
||||||
|
params api.RequestParameters,
|
||||||
|
messages []api.Message,
|
||||||
|
) (string, ChatCompletionRequest) {
|
||||||
|
requestMessages := make([]ChatCompletionMessage, 0, len(messages))
|
||||||
|
var systemMessage string
|
||||||
|
|
||||||
|
for _, m := range messages {
|
||||||
|
if m.Role == api.MessageRoleSystem {
|
||||||
|
systemMessage = m.Content
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var content interface{}
|
||||||
|
role := string(m.Role)
|
||||||
|
|
||||||
|
switch m.Role {
|
||||||
|
case api.MessageRoleToolCall:
|
||||||
|
role = "assistant"
|
||||||
|
contentBlocks := make([]map[string]interface{}, 0)
|
||||||
|
if m.Content != "" {
|
||||||
|
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||||
|
"type": "text",
|
||||||
|
"text": m.Content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
for _, toolCall := range m.ToolCalls {
|
||||||
|
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": toolCall.ID,
|
||||||
|
"name": toolCall.Name,
|
||||||
|
"input": toolCall.Parameters,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
content = contentBlocks
|
||||||
|
|
||||||
|
case api.MessageRoleToolResult:
|
||||||
|
role = "user"
|
||||||
|
contentBlocks := make([]map[string]interface{}, 0)
|
||||||
|
for _, result := range m.ToolResults {
|
||||||
|
contentBlock := map[string]interface{}{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": result.ToolCallID,
|
||||||
|
"content": result.Result,
|
||||||
|
}
|
||||||
|
contentBlocks = append(contentBlocks, contentBlock)
|
||||||
|
}
|
||||||
|
content = contentBlocks
|
||||||
|
|
||||||
|
default:
|
||||||
|
content = m.Content
|
||||||
|
}
|
||||||
|
|
||||||
|
requestMessages = append(requestMessages, ChatCompletionMessage{
|
||||||
|
Role: role,
|
||||||
|
Content: content,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
startIdx := 0
|
request := ChatCompletionRequest{
|
||||||
if len(messages) > 0 && messages[0].Role == api.MessageRoleSystem {
|
Model: params.Model,
|
||||||
requestBody.System = messages[0].Content
|
Messages: requestMessages,
|
||||||
requestBody.Messages = requestBody.Messages[1:]
|
System: systemMessage,
|
||||||
startIdx = 1
|
MaxTokens: params.MaxTokens,
|
||||||
|
Temperature: params.Temperature,
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(params.ToolBag) > 0 {
|
if len(params.ToolBag) > 0 {
|
||||||
if len(requestBody.System) > 0 {
|
request.Tools = convertTools(params.ToolBag)
|
||||||
// add a divider between existing system prompt and tools
|
|
||||||
requestBody.System += "\n\n---\n\n"
|
|
||||||
}
|
|
||||||
requestBody.System += buildToolsSystemPrompt(params.ToolBag)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, msg := range messages[startIdx:] {
|
var prefill string
|
||||||
message := &requestBody.Messages[i]
|
if api.IsAssistantContinuation(messages) {
|
||||||
|
prefill = messages[len(messages)-1].Content
|
||||||
switch msg.Role {
|
|
||||||
case api.MessageRoleToolCall:
|
|
||||||
message.Role = "assistant"
|
|
||||||
if msg.Content != "" {
|
|
||||||
message.Content = msg.Content
|
|
||||||
}
|
|
||||||
xmlFuncCalls := convertToolCallsToXMLFunctionCalls(msg.ToolCalls)
|
|
||||||
xmlString, err := xmlFuncCalls.XMLString()
|
|
||||||
if err != nil {
|
|
||||||
panic("Could not serialize []ToolCall to XMLFunctionCall")
|
|
||||||
}
|
|
||||||
if len(message.Content) > 0 {
|
|
||||||
message.Content += fmt.Sprintf("\n\n%s", xmlString)
|
|
||||||
} else {
|
|
||||||
message.Content = xmlString
|
|
||||||
}
|
|
||||||
case api.MessageRoleToolResult:
|
|
||||||
xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults)
|
|
||||||
xmlString, err := xmlFuncResults.XMLString()
|
|
||||||
if err != nil {
|
|
||||||
panic("Could not serialize []ToolResult to XMLFunctionResults")
|
|
||||||
}
|
|
||||||
message.Role = "user"
|
|
||||||
message.Content = xmlString
|
|
||||||
default:
|
|
||||||
message.Role = string(msg.Role)
|
|
||||||
message.Content = msg.Content
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return requestBody
|
|
||||||
|
return prefill, request
|
||||||
}
|
}
|
||||||
|
|
||||||
func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Response, error) {
|
func (c *AnthropicClient) sendRequest(ctx context.Context, r ChatCompletionRequest) (*http.Response, error) {
|
||||||
jsonBody, err := json.Marshal(r)
|
jsonData, err := json.Marshal(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to marshal request body: %v", err)
|
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/messages", bytes.NewBuffer(jsonBody))
|
req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/v1/messages", bytes.NewBuffer(jsonData))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create HTTP request: %v", err)
|
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Set("x-api-key", c.APIKey)
|
req.Header.Set("x-api-key", c.APIKey)
|
||||||
req.Header.Set("anthropic-version", "2023-06-01")
|
req.Header.Set("anthropic-version", ANTHROPIC_VERSION)
|
||||||
req.Header.Set("content-type", "application/json")
|
req.Header.Set("content-type", "application/json")
|
||||||
|
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to send HTTP request: %v", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return resp, nil
|
if resp.StatusCode != 200 {
|
||||||
|
bytes, _ := io.ReadAll(resp.Body)
|
||||||
|
return resp, fmt.Errorf("%v", string(bytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *AnthropicClient) CreateChatCompletion(
|
func (c *AnthropicClient) CreateChatCompletion(
|
||||||
@ -107,45 +230,25 @@ func (c *AnthropicClient) CreateChatCompletion(
|
|||||||
messages []api.Message,
|
messages []api.Message,
|
||||||
) (*api.Message, error) {
|
) (*api.Message, error) {
|
||||||
if len(messages) == 0 {
|
if len(messages) == 0 {
|
||||||
return nil, fmt.Errorf("Can't create completion from no messages")
|
return nil, fmt.Errorf("can't create completion from no messages")
|
||||||
}
|
}
|
||||||
|
|
||||||
request := buildRequest(params, messages)
|
_, req := createChatCompletionRequest(params, messages)
|
||||||
|
req.Stream = false
|
||||||
|
|
||||||
resp, err := sendRequest(ctx, c, request)
|
resp, err := c.sendRequest(ctx, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
var response Response
|
var completionResp ChatCompletionResponse
|
||||||
err = json.NewDecoder(resp.Body).Decode(&response)
|
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to decode response: %v", err)
|
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
sb := strings.Builder{}
|
return convertResponseToMessage(completionResp)
|
||||||
|
|
||||||
lastMessage := messages[len(messages)-1]
|
|
||||||
if lastMessage.Role.IsAssistant() {
|
|
||||||
// this is a continuation of a previous assistant reply, so we'll
|
|
||||||
// include its contents in the final result
|
|
||||||
sb.WriteString(lastMessage.Content)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, content := range response.Content {
|
|
||||||
switch content.Type {
|
|
||||||
case "text":
|
|
||||||
sb.WriteString(content.Text)
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unsupported message type: %s", content.Type)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &api.Message{
|
|
||||||
Role: api.MessageRoleAssistant,
|
|
||||||
Content: sb.String(),
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *AnthropicClient) CreateChatCompletionStream(
|
func (c *AnthropicClient) CreateChatCompletionStream(
|
||||||
@ -155,144 +258,193 @@ func (c *AnthropicClient) CreateChatCompletionStream(
|
|||||||
output chan<- api.Chunk,
|
output chan<- api.Chunk,
|
||||||
) (*api.Message, error) {
|
) (*api.Message, error) {
|
||||||
if len(messages) == 0 {
|
if len(messages) == 0 {
|
||||||
return nil, fmt.Errorf("Can't create completion from no messages")
|
return nil, fmt.Errorf("can't create completion from no messages")
|
||||||
}
|
}
|
||||||
|
|
||||||
request := buildRequest(params, messages)
|
prefill, req := createChatCompletionRequest(params, messages)
|
||||||
request.Stream = true
|
req.Stream = true
|
||||||
|
|
||||||
resp, err := sendRequest(ctx, c, request)
|
resp, err := c.sendRequest(ctx, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
sb := strings.Builder{}
|
contentBlocks := make(map[int]*ContentBlock)
|
||||||
|
var finalMessage *ChatCompletionResponse
|
||||||
|
|
||||||
lastMessage := messages[len(messages)-1]
|
var firstChunkReceived bool
|
||||||
if messages[len(messages)-1].Role.IsAssistant() {
|
|
||||||
// this is a continuation of a previous assistant reply, so we'll
|
|
||||||
// include its contents in the final result
|
|
||||||
// TODO: handle this at higher level
|
|
||||||
sb.WriteString(lastMessage.Content)
|
|
||||||
}
|
|
||||||
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
reader := bufio.NewReader(resp.Body)
|
||||||
for scanner.Scan() {
|
for {
|
||||||
line := scanner.Text()
|
select {
|
||||||
line = strings.TrimSpace(line)
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
if len(line) == 0 {
|
default:
|
||||||
continue
|
line, err := reader.ReadBytes('\n')
|
||||||
}
|
|
||||||
|
|
||||||
if line[0] == '{' {
|
|
||||||
var event map[string]interface{}
|
|
||||||
err := json.Unmarshal([]byte(line), &event)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to unmarshal event data '%s': %v", line, err)
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("error reading stream: %w", err)
|
||||||
}
|
}
|
||||||
eventType, ok := event["type"].(string)
|
|
||||||
if !ok {
|
line = bytes.TrimSpace(line)
|
||||||
return nil, fmt.Errorf("invalid event: %s", line)
|
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
switch eventType {
|
|
||||||
case "error":
|
line = bytes.TrimPrefix(line, []byte("data: "))
|
||||||
return nil, fmt.Errorf("an error occurred: %s", event["error"])
|
|
||||||
default:
|
var streamEvent StreamEvent
|
||||||
return nil, fmt.Errorf("unknown event type: %s", eventType)
|
err = json.Unmarshal(line, &streamEvent)
|
||||||
}
|
|
||||||
} else if strings.HasPrefix(line, "data:") {
|
|
||||||
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
|
||||||
var event map[string]interface{}
|
|
||||||
err := json.Unmarshal([]byte(data), &event)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to unmarshal event data: %v", err)
|
return nil, fmt.Errorf("failed to unmarshal stream event: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
eventType, ok := event["type"].(string)
|
switch streamEvent.Type {
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("invalid event type")
|
|
||||||
}
|
|
||||||
|
|
||||||
switch eventType {
|
|
||||||
case "message_start":
|
case "message_start":
|
||||||
// noop
|
finalMessage = &ChatCompletionResponse{}
|
||||||
case "ping":
|
err = json.Unmarshal(line, &struct {
|
||||||
// signals start of text - currently ignoring
|
Message *ChatCompletionResponse `json:"message"`
|
||||||
|
}{Message: finalMessage})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal message_start: %w", err)
|
||||||
|
}
|
||||||
case "content_block_start":
|
case "content_block_start":
|
||||||
// ignore?
|
var contentBlockStart struct {
|
||||||
|
Index int `json:"index"`
|
||||||
|
ContentBlock ContentBlock `json:"content_block"`
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(line, &contentBlockStart)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal content_block_start: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
contentBlocks[contentBlockStart.Index] = &contentBlockStart.ContentBlock
|
||||||
case "content_block_delta":
|
case "content_block_delta":
|
||||||
delta, ok := event["delta"].(map[string]interface{})
|
if streamEvent.Index >= len(contentBlocks) {
|
||||||
if !ok {
|
return nil, fmt.Errorf("received delta for non-existent content block index: %d", streamEvent.Index)
|
||||||
return nil, fmt.Errorf("invalid content block delta")
|
|
||||||
}
|
}
|
||||||
text, ok := delta["text"].(string)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("invalid text delta")
|
|
||||||
}
|
|
||||||
sb.WriteString(text)
|
|
||||||
output <- api.Chunk{
|
|
||||||
Content: text,
|
|
||||||
TokenCount: 1,
|
|
||||||
}
|
|
||||||
case "content_block_stop":
|
|
||||||
// ignore?
|
|
||||||
case "message_delta":
|
|
||||||
delta, ok := event["delta"].(map[string]interface{})
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("invalid message delta")
|
|
||||||
}
|
|
||||||
stopReason, ok := delta["stop_reason"].(string)
|
|
||||||
if ok && stopReason == "stop_sequence" {
|
|
||||||
stopSequence, ok := delta["stop_sequence"].(string)
|
|
||||||
if ok && stopSequence == FUNCTION_STOP_SEQUENCE {
|
|
||||||
content := sb.String()
|
|
||||||
|
|
||||||
start := strings.Index(content, "<function_calls>")
|
block := contentBlocks[streamEvent.Index]
|
||||||
if start == -1 {
|
delta, ok := streamEvent.Delta.(map[string]interface{})
|
||||||
return nil, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected delta type: %T", streamEvent.Delta)
|
||||||
|
}
|
||||||
|
|
||||||
|
deltaType, ok := delta["type"].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("delta missing type field")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch deltaType {
|
||||||
|
case "text_delta":
|
||||||
|
if text, ok := delta["text"].(string); ok {
|
||||||
|
if !firstChunkReceived {
|
||||||
|
if prefill == "" {
|
||||||
|
// if there is no prefil, ensure we trim leading whitespace
|
||||||
|
text = strings.TrimSpace(text)
|
||||||
|
}
|
||||||
|
firstChunkReceived = true
|
||||||
}
|
}
|
||||||
|
block.Text += text
|
||||||
sb.WriteString(FUNCTION_STOP_SEQUENCE)
|
|
||||||
output <- api.Chunk{
|
output <- api.Chunk{
|
||||||
Content: FUNCTION_STOP_SEQUENCE,
|
Content: text,
|
||||||
TokenCount: 1,
|
TokenCount: 1,
|
||||||
}
|
}
|
||||||
funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE
|
}
|
||||||
|
case "input_json_delta":
|
||||||
var functionCalls XMLFunctionCalls
|
if block.Type != "tool_use" {
|
||||||
err := xml.Unmarshal([]byte(funcCallXml), &functionCalls)
|
return nil, fmt.Errorf("received input_json_delta for non-tool_use block")
|
||||||
if err != nil {
|
}
|
||||||
return nil, fmt.Errorf("failed to unmarshal function_calls: %v", err)
|
if partialJSON, ok := delta["partial_json"].(string); ok {
|
||||||
}
|
block.partialJsonAccumulator += partialJSON
|
||||||
|
|
||||||
return &api.Message{
|
|
||||||
Role: api.MessageRoleToolCall,
|
|
||||||
// function call xml stripped from content for model interop
|
|
||||||
Content: strings.TrimSpace(content[:start]),
|
|
||||||
ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case "content_block_stop":
|
||||||
|
if streamEvent.Index >= len(contentBlocks) {
|
||||||
|
return nil, fmt.Errorf("received stop for non-existent content block index: %d", streamEvent.Index)
|
||||||
|
}
|
||||||
|
|
||||||
|
block := contentBlocks[streamEvent.Index]
|
||||||
|
if block.Type == "tool_use" && block.partialJsonAccumulator != "" {
|
||||||
|
var inputData map[string]interface{}
|
||||||
|
err := json.Unmarshal([]byte(block.partialJsonAccumulator), &inputData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal accumulated JSON for tool use: %w", err)
|
||||||
|
}
|
||||||
|
block.Input = inputData
|
||||||
|
}
|
||||||
|
case "message_delta":
|
||||||
|
if finalMessage == nil {
|
||||||
|
return nil, fmt.Errorf("received message_delta before message_start")
|
||||||
|
}
|
||||||
|
delta, ok := streamEvent.Delta.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected delta type in message_delta: %T", streamEvent.Delta)
|
||||||
|
}
|
||||||
|
if stopReason, ok := delta["stop_reason"].(string); ok {
|
||||||
|
finalMessage.StopReason = stopReason
|
||||||
|
}
|
||||||
|
|
||||||
case "message_stop":
|
case "message_stop":
|
||||||
// return the completed message
|
// End of the stream
|
||||||
content := sb.String()
|
goto END_STREAM
|
||||||
return &api.Message{
|
|
||||||
Role: api.MessageRoleAssistant,
|
|
||||||
Content: content,
|
|
||||||
}, nil
|
|
||||||
case "error":
|
case "error":
|
||||||
return nil, fmt.Errorf("an error occurred: %s", event["error"])
|
return nil, fmt.Errorf("received error event: %v", streamEvent.Message)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
fmt.Printf("\nUnrecognized event: %s\n", data)
|
// Ignore unknown event types
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := scanner.Err(); err != nil {
|
END_STREAM:
|
||||||
return nil, fmt.Errorf("failed to read response body: %v", err)
|
if finalMessage == nil {
|
||||||
|
return nil, fmt.Errorf("no final message received")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, fmt.Errorf("unexpected end of stream")
|
finalMessage.Content = make([]ContentBlock, len(contentBlocks))
|
||||||
|
for _, v := range contentBlocks {
|
||||||
|
finalMessage.Content = append(finalMessage.Content, *v)
|
||||||
|
}
|
||||||
|
|
||||||
|
return convertResponseToMessage(*finalMessage)
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertResponseToMessage(resp ChatCompletionResponse) (*api.Message, error) {
|
||||||
|
content := strings.Builder{}
|
||||||
|
var toolCalls []api.ToolCall
|
||||||
|
|
||||||
|
for _, block := range resp.Content {
|
||||||
|
switch block.Type {
|
||||||
|
case "text":
|
||||||
|
content.WriteString(block.Text)
|
||||||
|
case "tool_use":
|
||||||
|
parameters, ok := block.Input.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected type for tool call parameters: %T", block.Input)
|
||||||
|
}
|
||||||
|
toolCalls = append(toolCalls, api.ToolCall{
|
||||||
|
ID: block.ID,
|
||||||
|
Name: block.Name,
|
||||||
|
Parameters: parameters,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
message := &api.Message{
|
||||||
|
Role: api.MessageRoleAssistant,
|
||||||
|
Content: content.String(),
|
||||||
|
ToolCalls: toolCalls,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
message.Role = api.MessageRoleToolCall
|
||||||
|
}
|
||||||
|
|
||||||
|
return message, nil
|
||||||
}
|
}
|
||||||
|
@ -1,232 +0,0 @@
|
|||||||
package anthropic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"text/template"
|
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
|
||||||
)
|
|
||||||
|
|
||||||
const FUNCTION_STOP_SEQUENCE = "</function_calls>"
|
|
||||||
|
|
||||||
const TOOL_PREAMBLE = `You have access to the following tools when replying.
|
|
||||||
|
|
||||||
You may call them like this:
|
|
||||||
|
|
||||||
<function_calls>
|
|
||||||
<invoke>
|
|
||||||
<tool_name>$TOOL_NAME</tool_name>
|
|
||||||
<parameters>
|
|
||||||
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
|
|
||||||
...
|
|
||||||
</parameters>
|
|
||||||
</invoke>
|
|
||||||
</function_calls>
|
|
||||||
|
|
||||||
Here are the tools available:`
|
|
||||||
|
|
||||||
const TOOL_PREAMBLE_FOOTER = `Recognize the utility of these tools in a broad range of different applications, and the power they give you to solve a wide range of different problems. However, ensure that the tools are used judiciously and only when clearly relevant to the user's request. Specifically:
|
|
||||||
|
|
||||||
1. Only use a tool if the user has explicitly requested or provided information that warrants its use. Do not make assumptions about files or data existing without the user mentioning them.
|
|
||||||
|
|
||||||
2. If there is ambiguity about whether using a tool is appropriate, ask a clarifying question to the user before proceeding. Confirm your understanding of their request and intent.
|
|
||||||
|
|
||||||
3. Prioritize providing direct responses and explanations based on your own knowledge and understanding. Use tools to supplement and enhance your responses when clearly applicable, but not as a default action.`
|
|
||||||
|
|
||||||
type XMLTools struct {
|
|
||||||
XMLName struct{} `xml:"tools"`
|
|
||||||
ToolDescriptions []XMLToolDescription `xml:"tool_description"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type XMLToolDescription struct {
|
|
||||||
ToolName string `xml:"tool_name"`
|
|
||||||
Description string `xml:"description"`
|
|
||||||
Parameters []XMLToolParameter `xml:"parameters>parameter"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type XMLToolParameter struct {
|
|
||||||
Name string `xml:"name"`
|
|
||||||
Type string `xml:"type"`
|
|
||||||
Description string `xml:"description"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type XMLFunctionCalls struct {
|
|
||||||
XMLName struct{} `xml:"function_calls"`
|
|
||||||
Invoke []XMLFunctionInvoke `xml:"invoke"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type XMLFunctionInvoke struct {
|
|
||||||
ToolName string `xml:"tool_name"`
|
|
||||||
Parameters XMLFunctionInvokeParameters `xml:"parameters"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type XMLFunctionInvokeParameters struct {
|
|
||||||
String string `xml:",innerxml"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type XMLFunctionResults struct {
|
|
||||||
XMLName struct{} `xml:"function_results"`
|
|
||||||
Result []XMLFunctionResult `xml:"result"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type XMLFunctionResult struct {
|
|
||||||
ToolName string `xml:"tool_name"`
|
|
||||||
Stdout string `xml:"stdout"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// accepts raw XML from XMLFunctionInvokeParameters.String, returns map of
|
|
||||||
// parameters name to value
|
|
||||||
func parseFunctionParametersXML(params string) map[string]interface{} {
|
|
||||||
lines := strings.Split(params, "\n")
|
|
||||||
ret := make(map[string]interface{}, len(lines))
|
|
||||||
for _, line := range lines {
|
|
||||||
i := strings.Index(line, ">")
|
|
||||||
if i == -1 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
j := strings.Index(line, "</")
|
|
||||||
if j == -1 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// chop from after opening < to first > to get parameter name,
|
|
||||||
// then chop after > to first </ to get parameter value
|
|
||||||
ret[line[1:i]] = line[i+1 : j]
|
|
||||||
}
|
|
||||||
return ret
|
|
||||||
}
|
|
||||||
|
|
||||||
func convertToolsToXMLTools(tools []api.ToolSpec) XMLTools {
|
|
||||||
converted := make([]XMLToolDescription, len(tools))
|
|
||||||
for i, tool := range tools {
|
|
||||||
converted[i].ToolName = tool.Name
|
|
||||||
converted[i].Description = tool.Description
|
|
||||||
|
|
||||||
params := make([]XMLToolParameter, len(tool.Parameters))
|
|
||||||
for j, param := range tool.Parameters {
|
|
||||||
params[j].Name = param.Name
|
|
||||||
params[j].Description = param.Description
|
|
||||||
params[j].Type = param.Type
|
|
||||||
}
|
|
||||||
|
|
||||||
converted[i].Parameters = params
|
|
||||||
}
|
|
||||||
return XMLTools{
|
|
||||||
ToolDescriptions: converted,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []api.ToolCall {
|
|
||||||
toolCalls := make([]api.ToolCall, len(functionCalls.Invoke))
|
|
||||||
for i, invoke := range functionCalls.Invoke {
|
|
||||||
toolCalls[i].Name = invoke.ToolName
|
|
||||||
toolCalls[i].Parameters = parseFunctionParametersXML(invoke.Parameters.String)
|
|
||||||
}
|
|
||||||
return toolCalls
|
|
||||||
}
|
|
||||||
|
|
||||||
func convertToolCallsToXMLFunctionCalls(toolCalls []api.ToolCall) XMLFunctionCalls {
|
|
||||||
converted := make([]XMLFunctionInvoke, len(toolCalls))
|
|
||||||
for i, toolCall := range toolCalls {
|
|
||||||
var params XMLFunctionInvokeParameters
|
|
||||||
var paramXML string
|
|
||||||
for key, value := range toolCall.Parameters {
|
|
||||||
paramXML += fmt.Sprintf("<%s>%v</%s>\n", key, value, key)
|
|
||||||
}
|
|
||||||
params.String = paramXML
|
|
||||||
converted[i] = XMLFunctionInvoke{
|
|
||||||
ToolName: toolCall.Name,
|
|
||||||
Parameters: params,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return XMLFunctionCalls{
|
|
||||||
Invoke: converted,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func convertToolResultsToXMLFunctionResult(toolResults []api.ToolResult) XMLFunctionResults {
|
|
||||||
converted := make([]XMLFunctionResult, len(toolResults))
|
|
||||||
for i, result := range toolResults {
|
|
||||||
converted[i].ToolName = result.ToolName
|
|
||||||
converted[i].Stdout = result.Result
|
|
||||||
}
|
|
||||||
return XMLFunctionResults{
|
|
||||||
Result: converted,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildToolsSystemPrompt(tools []api.ToolSpec) string {
|
|
||||||
xmlTools := convertToolsToXMLTools(tools)
|
|
||||||
xmlToolsString, err := xmlTools.XMLString()
|
|
||||||
if err != nil {
|
|
||||||
panic("Could not serialize []api.Tool to XMLTools")
|
|
||||||
}
|
|
||||||
return TOOL_PREAMBLE + "\n\n" + xmlToolsString + "\n\n" + TOOL_PREAMBLE_FOOTER
|
|
||||||
}
|
|
||||||
|
|
||||||
func (x XMLTools) XMLString() (string, error) {
|
|
||||||
tmpl, err := template.New("tools").Parse(`<tools>
|
|
||||||
{{range .ToolDescriptions}}<tool_description>
|
|
||||||
<tool_name>{{.ToolName}}</tool_name>
|
|
||||||
<description>
|
|
||||||
{{.Description}}
|
|
||||||
</description>
|
|
||||||
<parameters>
|
|
||||||
{{range .Parameters}}<parameter>
|
|
||||||
<name>{{.Name}}</name>
|
|
||||||
<type>{{.Type}}</type>
|
|
||||||
<description>{{.Description}}</description>
|
|
||||||
</parameter>
|
|
||||||
{{end}}</parameters>
|
|
||||||
</tool_description>
|
|
||||||
{{end}}</tools>`)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
var buf bytes.Buffer
|
|
||||||
if err := tmpl.Execute(&buf, x); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return buf.String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (x XMLFunctionResults) XMLString() (string, error) {
|
|
||||||
tmpl, err := template.New("function_results").Parse(`<function_results>
|
|
||||||
{{range .Result}}<result>
|
|
||||||
<tool_name>{{.ToolName}}</tool_name>
|
|
||||||
<stdout>{{.Stdout}}</stdout>
|
|
||||||
</result>
|
|
||||||
{{end}}</function_results>`)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
var buf bytes.Buffer
|
|
||||||
if err := tmpl.Execute(&buf, x); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return buf.String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (x XMLFunctionCalls) XMLString() (string, error) {
|
|
||||||
tmpl, err := template.New("function_calls").Parse(`<function_calls>
|
|
||||||
{{range .Invoke}}<invoke>
|
|
||||||
<tool_name>{{.ToolName}}</tool_name>
|
|
||||||
<parameters>{{.Parameters.String}}</parameters>
|
|
||||||
</invoke>
|
|
||||||
{{end}}</function_calls>`)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
var buf bytes.Buffer
|
|
||||||
if err := tmpl.Execute(&buf, x); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return buf.String(), nil
|
|
||||||
}
|
|
@ -1,38 +0,0 @@
|
|||||||
package anthropic
|
|
||||||
|
|
||||||
type AnthropicClient struct {
|
|
||||||
BaseURL string
|
|
||||||
APIKey string
|
|
||||||
}
|
|
||||||
|
|
||||||
type Message struct {
|
|
||||||
Role string `json:"role"`
|
|
||||||
Content string `json:"content"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Request struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
Messages []Message `json:"messages"`
|
|
||||||
System string `json:"system,omitempty"`
|
|
||||||
MaxTokens int `json:"max_tokens,omitempty"`
|
|
||||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
|
||||||
Stream bool `json:"stream,omitempty"`
|
|
||||||
Temperature float32 `json:"temperature,omitempty"`
|
|
||||||
//TopP float32 `json:"top_p,omitempty"`
|
|
||||||
//TopK float32 `json:"top_k,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type OriginalContent struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Text string `json:"text"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Response struct {
|
|
||||||
Id string `json:"id"`
|
|
||||||
Type string `json:"type"`
|
|
||||||
Role string `json:"role"`
|
|
||||||
Content []OriginalContent `json:"content"`
|
|
||||||
StopReason string `json:"stop_reason"`
|
|
||||||
StopSequence string `json:"stop_sequence"`
|
|
||||||
}
|
|
||||||
|
|
@ -13,6 +13,85 @@ import (
|
|||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Client struct {
|
||||||
|
APIKey string
|
||||||
|
BaseURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
type ContentPart struct {
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
||||||
|
FunctionResp *FunctionResponse `json:"functionResponse,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FunctionCall struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Args map[string]string `json:"args"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FunctionResponse struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Response interface{} `json:"response"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Content struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Parts []ContentPart `json:"parts"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GenerationConfig struct {
|
||||||
|
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
|
||||||
|
Temperature *float32 `json:"temperature,omitempty"`
|
||||||
|
TopP *float32 `json:"topP,omitempty"`
|
||||||
|
TopK *int `json:"topK,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GenerateContentRequest struct {
|
||||||
|
Contents []Content `json:"contents"`
|
||||||
|
Tools []Tool `json:"tools,omitempty"`
|
||||||
|
SystemInstruction *Content `json:"systemInstruction,omitempty"`
|
||||||
|
GenerationConfig *GenerationConfig `json:"generationConfig,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Candidate struct {
|
||||||
|
Content Content `json:"content"`
|
||||||
|
FinishReason string `json:"finishReason"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type UsageMetadata struct {
|
||||||
|
PromptTokenCount int `json:"promptTokenCount"`
|
||||||
|
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||||
|
TotalTokenCount int `json:"totalTokenCount"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GenerateContentResponse struct {
|
||||||
|
Candidates []Candidate `json:"candidates"`
|
||||||
|
UsageMetadata UsageMetadata `json:"usageMetadata"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Tool struct {
|
||||||
|
FunctionDeclarations []FunctionDeclaration `json:"functionDeclarations"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FunctionDeclaration struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Parameters ToolParameters `json:"parameters"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolParameters struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Properties map[string]ToolParameter `json:"properties,omitempty"`
|
||||||
|
Required []string `json:"required,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolParameter struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Values []string `json:"values,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
func convertTools(tools []api.ToolSpec) []Tool {
|
func convertTools(tools []api.ToolSpec) []Tool {
|
||||||
geminiTools := make([]Tool, len(tools))
|
geminiTools := make([]Tool, len(tools))
|
||||||
for i, tool := range tools {
|
for i, tool := range tools {
|
||||||
|
@ -1,80 +0,0 @@
|
|||||||
package google
|
|
||||||
|
|
||||||
type Client struct {
|
|
||||||
APIKey string
|
|
||||||
BaseURL string
|
|
||||||
}
|
|
||||||
|
|
||||||
type ContentPart struct {
|
|
||||||
Text string `json:"text,omitempty"`
|
|
||||||
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
|
||||||
FunctionResp *FunctionResponse `json:"functionResponse,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type FunctionCall struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Args map[string]string `json:"args"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type FunctionResponse struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Response interface{} `json:"response"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Content struct {
|
|
||||||
Role string `json:"role"`
|
|
||||||
Parts []ContentPart `json:"parts"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type GenerationConfig struct {
|
|
||||||
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
|
|
||||||
Temperature *float32 `json:"temperature,omitempty"`
|
|
||||||
TopP *float32 `json:"topP,omitempty"`
|
|
||||||
TopK *int `json:"topK,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type GenerateContentRequest struct {
|
|
||||||
Contents []Content `json:"contents"`
|
|
||||||
Tools []Tool `json:"tools,omitempty"`
|
|
||||||
SystemInstruction *Content `json:"systemInstruction,omitempty"`
|
|
||||||
GenerationConfig *GenerationConfig `json:"generationConfig,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Candidate struct {
|
|
||||||
Content Content `json:"content"`
|
|
||||||
FinishReason string `json:"finishReason"`
|
|
||||||
Index int `json:"index"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type UsageMetadata struct {
|
|
||||||
PromptTokenCount int `json:"promptTokenCount"`
|
|
||||||
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
|
||||||
TotalTokenCount int `json:"totalTokenCount"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type GenerateContentResponse struct {
|
|
||||||
Candidates []Candidate `json:"candidates"`
|
|
||||||
UsageMetadata UsageMetadata `json:"usageMetadata"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Tool struct {
|
|
||||||
FunctionDeclarations []FunctionDeclaration `json:"functionDeclarations"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type FunctionDeclaration struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
Parameters ToolParameters `json:"parameters"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolParameters struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Properties map[string]ToolParameter `json:"properties,omitempty"`
|
|
||||||
Required []string `json:"required,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolParameter struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
Values []string `json:"values,omitempty"`
|
|
||||||
}
|
|
@ -13,6 +13,76 @@ import (
|
|||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type OpenAIClient struct {
|
||||||
|
APIKey string
|
||||||
|
BaseURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content,omitempty"`
|
||||||
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||||
|
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolCall struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
ID string `json:"id"`
|
||||||
|
Index *int `json:"index,omitempty"`
|
||||||
|
Function FunctionDefinition `json:"function"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FunctionDefinition struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Parameters ToolParameters `json:"parameters"`
|
||||||
|
Arguments string `json:"arguments,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolParameters struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Properties map[string]ToolParameter `json:"properties,omitempty"`
|
||||||
|
Required []string `json:"required,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolParameter struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Enum []string `json:"enum,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Tool struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Function FunctionDefinition `json:"function"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
MaxTokens int `json:"max_tokens,omitempty"`
|
||||||
|
Temperature float32 `json:"temperature,omitempty"`
|
||||||
|
Messages []ChatCompletionMessage `json:"messages"`
|
||||||
|
N int `json:"n"`
|
||||||
|
Tools []Tool `json:"tools,omitempty"`
|
||||||
|
ToolChoice string `json:"tool_choice,omitempty"`
|
||||||
|
Stream bool `json:"stream,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionChoice struct {
|
||||||
|
Message ChatCompletionMessage `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionResponse struct {
|
||||||
|
Choices []ChatCompletionChoice `json:"choices"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionStreamChoice struct {
|
||||||
|
Delta ChatCompletionMessage `json:"delta"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionStreamResponse struct {
|
||||||
|
Choices []ChatCompletionStreamChoice `json:"choices"`
|
||||||
|
}
|
||||||
|
|
||||||
func convertTools(tools []api.ToolSpec) []Tool {
|
func convertTools(tools []api.ToolSpec) []Tool {
|
||||||
openaiTools := make([]Tool, len(tools))
|
openaiTools := make([]Tool, len(tools))
|
||||||
for i, tool := range tools {
|
for i, tool := range tools {
|
||||||
|
@ -1,71 +0,0 @@
|
|||||||
package openai
|
|
||||||
|
|
||||||
type OpenAIClient struct {
|
|
||||||
APIKey string
|
|
||||||
BaseURL string
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatCompletionMessage struct {
|
|
||||||
Role string `json:"role"`
|
|
||||||
Content string `json:"content,omitempty"`
|
|
||||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
|
||||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolCall struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
ID string `json:"id"`
|
|
||||||
Index *int `json:"index,omitempty"`
|
|
||||||
Function FunctionDefinition `json:"function"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type FunctionDefinition struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
Parameters ToolParameters `json:"parameters"`
|
|
||||||
Arguments string `json:"arguments,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolParameters struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Properties map[string]ToolParameter `json:"properties,omitempty"`
|
|
||||||
Required []string `json:"required,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolParameter struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
Enum []string `json:"enum,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Tool struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Function FunctionDefinition `json:"function"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatCompletionRequest struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
MaxTokens int `json:"max_tokens,omitempty"`
|
|
||||||
Temperature float32 `json:"temperature,omitempty"`
|
|
||||||
Messages []ChatCompletionMessage `json:"messages"`
|
|
||||||
N int `json:"n"`
|
|
||||||
Tools []Tool `json:"tools,omitempty"`
|
|
||||||
ToolChoice string `json:"tool_choice,omitempty"`
|
|
||||||
Stream bool `json:"stream,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatCompletionChoice struct {
|
|
||||||
Message ChatCompletionMessage `json:"message"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatCompletionResponse struct {
|
|
||||||
Choices []ChatCompletionChoice `json:"choices"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatCompletionStreamChoice struct {
|
|
||||||
Delta ChatCompletionMessage `json:"delta"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatCompletionStreamResponse struct {
|
|
||||||
Choices []ChatCompletionStreamChoice `json:"choices"`
|
|
||||||
}
|
|
@ -96,7 +96,7 @@ func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProv
|
|||||||
if m == model {
|
if m == model {
|
||||||
switch *p.Kind {
|
switch *p.Kind {
|
||||||
case "anthropic":
|
case "anthropic":
|
||||||
url := "https://api.anthropic.com/v1"
|
url := "https://api.anthropic.com"
|
||||||
if p.BaseURL != nil {
|
if p.BaseURL != nil {
|
||||||
url = *p.BaseURL
|
url = *p.BaseURL
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user