Add Ollama support

Matt Low 2024-06-01 01:38:45 +00:00
parent 465b1d333e
commit ea576d24a6
2 changed files with 209 additions and 0 deletions


@@ -10,6 +10,7 @@ import (
	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic"
	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/google"
	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/ollama"
	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai"
	"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
	"git.mlow.ca/mlow/lmcli/pkg/util"
@@ -113,6 +114,14 @@ func (c *Context) GetModelProvider(model string) (string, provider.ChatCompletio
			BaseURL: url,
			APIKey:  *p.APIKey,
		}, nil
	case "ollama":
		url := "http://localhost:11434/api"
		if p.BaseURL != nil {
			url = *p.BaseURL
		}
		return model, &ollama.OllamaClient{
			BaseURL: url,
		}, nil
	case "openai":
		url := "https://api.openai.com/v1"
		if p.BaseURL != nil {
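For context (not part of this diff), a rough sketch of how the new branch gets exercised; the call-site names are assumptions, since the diff does not show how provider entries are configured:

	// Hypothetical call site: assumes the user's config maps "llama3" to a
	// provider entry of kind "ollama" with no explicit baseUrl.
	name, client, err := ctx.GetModelProvider("llama3")
	// client is then an *ollama.OllamaClient pointing at the default
	// http://localhost:11434/api.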


@@ -0,0 +1,200 @@
package ollama

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"

	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
)
// OllamaClient implements chat completions against an Ollama server's
// HTTP API.
type OllamaClient struct {
	BaseURL string
}

// OllamaMessage is a single message in an Ollama chat exchange.
type OllamaMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// OllamaRequest is the request body for the /chat endpoint.
type OllamaRequest struct {
	Model    string          `json:"model"`
	Messages []OllamaMessage `json:"messages"`
	Stream   bool            `json:"stream"`
}

// OllamaResponse is a /chat response. When streaming, one is decoded
// per line, with the timing/token stats present only on the final
// (Done) response.
type OllamaResponse struct {
	Model              string        `json:"model"`
	CreatedAt          string        `json:"created_at"`
	Message            OllamaMessage `json:"message"`
	Done               bool          `json:"done"`
	TotalDuration      uint64        `json:"total_duration,omitempty"`
	LoadDuration       uint64        `json:"load_duration,omitempty"`
	PromptEvalCount    uint64        `json:"prompt_eval_count,omitempty"`
	PromptEvalDuration uint64        `json:"prompt_eval_duration,omitempty"`
	EvalCount          uint64        `json:"eval_count,omitempty"`
	EvalDuration       uint64        `json:"eval_duration,omitempty"`
}
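
// For reference, a non-streaming exchange with /chat looks roughly like
// this (abridged; see Ollama's API documentation for the full shape):
//
//	request:  {"model": "llama3", "messages": [{"role": "user", "content": "Hi"}], "stream": false}
//	response: {"model": "llama3", "created_at": "...", "message": {"role": "assistant", "content": "Hello!"}, "done": true}
//
// With "stream": true the server instead writes one such JSON object per
// line, each carrying a fragment of the reply, finishing with "done": true.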
// createOllamaRequest converts lmcli messages into an Ollama /chat
// request body. Stream is left false; callers enable it as needed.
func createOllamaRequest(
	params model.RequestParameters,
	messages []model.Message,
) OllamaRequest {
	requestMessages := make([]OllamaMessage, 0, len(messages))
	for _, m := range messages {
		message := OllamaMessage{
			Role:    string(m.Role),
			Content: m.Content,
		}
		requestMessages = append(requestMessages, message)
	}

	request := OllamaRequest{
		Model:    params.Model,
		Messages: requestMessages,
	}

	return request
}
// sendRequest performs req with a JSON content type and returns the
// response, converting any non-200 status into an error carrying the
// response body.
func (c *OllamaClient) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
	req.Header.Set("Content-Type", "application/json")

	client := &http.Client{}
	resp, err := client.Do(req.WithContext(ctx))
	if err != nil {
		return nil, err
	}

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		resp.Body.Close()
		return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, string(body))
	}

	return resp, nil
}
// CreateChatCompletion sends a non-streaming chat request and returns
// the assistant's reply, passing it to callback when one is provided.
func (c *OllamaClient) CreateChatCompletion(
	ctx context.Context,
	params model.RequestParameters,
	messages []model.Message,
	callback provider.ReplyCallback,
) (string, error) {
	if len(messages) == 0 {
		return "", fmt.Errorf("can't create completion from no messages")
	}

	req := createOllamaRequest(params, messages)
	req.Stream = false

	jsonData, err := json.Marshal(req)
	if err != nil {
		return "", err
	}

	httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
	if err != nil {
		return "", err
	}

	resp, err := c.sendRequest(ctx, httpReq)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	var completionResp OllamaResponse
	err = json.NewDecoder(resp.Body).Decode(&completionResp)
	if err != nil {
		return "", err
	}

	content := completionResp.Message.Content
	if callback != nil {
		callback(model.Message{
			Role:    model.MessageRoleAssistant,
			Content: content,
		})
	}

	return content, nil
}
// CreateChatCompletionStream sends a streaming chat request, writing
// each reply fragment to output as it arrives and returning the full
// accumulated reply. The output channel is not closed.
func (c *OllamaClient) CreateChatCompletionStream(
	ctx context.Context,
	params model.RequestParameters,
	messages []model.Message,
	callback provider.ReplyCallback,
	output chan<- string,
) (string, error) {
	if len(messages) == 0 {
		return "", fmt.Errorf("can't create completion from no messages")
	}

	req := createOllamaRequest(params, messages)
	req.Stream = true

	jsonData, err := json.Marshal(req)
	if err != nil {
		return "", err
	}

	httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
	if err != nil {
		return "", err
	}

	resp, err := c.sendRequest(ctx, httpReq)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	content := strings.Builder{}

	// The streaming response is newline-delimited JSON: decode one
	// OllamaResponse per line until EOF.
	reader := bufio.NewReader(resp.Body)
	for {
		line, err := reader.ReadBytes('\n')
		if err != nil {
			if err == io.EOF {
				break
			}
			return "", err
		}

		line = bytes.TrimSpace(line)
		if len(line) == 0 {
			continue
		}

		var streamResp OllamaResponse
		err = json.Unmarshal(line, &streamResp)
		if err != nil {
			return "", err
		}

		if len(streamResp.Message.Content) > 0 {
			output <- streamResp.Message.Content
			content.WriteString(streamResp.Message.Content)
		}
	}

	if callback != nil {
		callback(model.Message{
			Role:    model.MessageRoleAssistant,
			Content: content.String(),
		})
	}

	return content.String(), nil
}
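
Not part of the commit, but for a sense of the call pattern: a minimal sketch of driving the streaming client directly. It assumes model.MessageRoleUser exists alongside MessageRoleAssistant, and that RequestParameters carries the model name, as the code above suggests.

package main

import (
	"context"
	"fmt"

	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/ollama"
)

func main() {
	client := &ollama.OllamaClient{BaseURL: "http://localhost:11434/api"}

	// Drain the output channel concurrently; the stream call blocks
	// until the reply is complete.
	output := make(chan string)
	done := make(chan struct{})
	go func() {
		for chunk := range output {
			fmt.Print(chunk)
		}
		close(done)
	}()

	reply, err := client.CreateChatCompletionStream(
		context.Background(),
		model.RequestParameters{Model: "llama3"}, // Model field assumed, per createOllamaRequest
		[]model.Message{{Role: model.MessageRoleUser, Content: "Hello!"}}, // MessageRoleUser assumed
		nil, // no reply callback
		output,
	)
	close(output)
	<-done
	fmt.Println()

	if err != nil {
		fmt.Println("error:", err)
		return
	}
	_ = reply // full accumulated reply, same text as streamed above
}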