Add Ollama support
This commit is contained in:
parent
465b1d333e
commit
ea576d24a6
@ -10,6 +10,7 @@ import (
|
|||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/google"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/google"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/ollama"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/util"
|
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||||
@ -113,6 +114,14 @@ func (c *Context) GetModelProvider(model string) (string, provider.ChatCompletio
|
|||||||
BaseURL: url,
|
BaseURL: url,
|
||||||
APIKey: *p.APIKey,
|
APIKey: *p.APIKey,
|
||||||
}, nil
|
}, nil
|
||||||
|
case "ollama":
|
||||||
|
url := "http://localhost:11434/api"
|
||||||
|
if p.BaseURL != nil {
|
||||||
|
url = *p.BaseURL
|
||||||
|
}
|
||||||
|
return model, &ollama.OllamaClient{
|
||||||
|
BaseURL: url,
|
||||||
|
}, nil
|
||||||
case "openai":
|
case "openai":
|
||||||
url := "https://api.openai.com/v1"
|
url := "https://api.openai.com/v1"
|
||||||
if p.BaseURL != nil {
|
if p.BaseURL != nil {
|
||||||
|
200
pkg/lmcli/provider/ollama/ollama.go
Normal file
200
pkg/lmcli/provider/ollama/ollama.go
Normal file
@ -0,0 +1,200 @@
|
|||||||
|
package ollama
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
type OllamaClient struct {
|
||||||
|
BaseURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
type OllamaMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OllamaRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Messages []OllamaMessage `json:"messages"`
|
||||||
|
Stream bool `json:"stream"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OllamaResponse struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
Message OllamaMessage `json:"message"`
|
||||||
|
Done bool `json:"done"`
|
||||||
|
TotalDuration uint64 `json:"total_duration,omitempty"`
|
||||||
|
LoadDuration uint64 `json:"load_duration,omitempty"`
|
||||||
|
PromptEvalCount uint64 `json:"prompt_eval_count,omitempty"`
|
||||||
|
PromptEvalDuration uint64 `json:"prompt_eval_duration,omitempty"`
|
||||||
|
EvalCount uint64 `json:"eval_count,omitempty"`
|
||||||
|
EvalDuration uint64 `json:"eval_duration,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func createOllamaRequest(
|
||||||
|
params model.RequestParameters,
|
||||||
|
messages []model.Message,
|
||||||
|
) OllamaRequest {
|
||||||
|
requestMessages := make([]OllamaMessage, 0, len(messages))
|
||||||
|
|
||||||
|
for _, m := range messages {
|
||||||
|
message := OllamaMessage{
|
||||||
|
Role: string(m.Role),
|
||||||
|
Content: m.Content,
|
||||||
|
}
|
||||||
|
requestMessages = append(requestMessages, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
request := OllamaRequest{
|
||||||
|
Model: params.Model,
|
||||||
|
Messages: requestMessages,
|
||||||
|
}
|
||||||
|
|
||||||
|
return request
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *OllamaClient) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req.WithContext(ctx))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
bytes, _ := io.ReadAll(resp.Body)
|
||||||
|
return resp, fmt.Errorf("%v", string(bytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *OllamaClient) CreateChatCompletion(
|
||||||
|
ctx context.Context,
|
||||||
|
params model.RequestParameters,
|
||||||
|
messages []model.Message,
|
||||||
|
callback provider.ReplyCallback,
|
||||||
|
) (string, error) {
|
||||||
|
if len(messages) == 0 {
|
||||||
|
return "", fmt.Errorf("Can't create completion from no messages")
|
||||||
|
}
|
||||||
|
|
||||||
|
req := createOllamaRequest(params, messages)
|
||||||
|
req.Stream = false
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.sendRequest(ctx, httpReq)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
var completionResp OllamaResponse
|
||||||
|
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
content := completionResp.Message.Content
|
||||||
|
|
||||||
|
fmt.Println(content)
|
||||||
|
|
||||||
|
if callback != nil {
|
||||||
|
callback(model.Message{
|
||||||
|
Role: model.MessageRoleAssistant,
|
||||||
|
Content: content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return content, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *OllamaClient) CreateChatCompletionStream(
|
||||||
|
ctx context.Context,
|
||||||
|
params model.RequestParameters,
|
||||||
|
messages []model.Message,
|
||||||
|
callback provider.ReplyCallback,
|
||||||
|
output chan<- string,
|
||||||
|
) (string, error) {
|
||||||
|
if len(messages) == 0 {
|
||||||
|
return "", fmt.Errorf("Can't create completion from no messages")
|
||||||
|
}
|
||||||
|
|
||||||
|
req := createOllamaRequest(params, messages)
|
||||||
|
req.Stream = true
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.sendRequest(ctx, httpReq)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
content := strings.Builder{}
|
||||||
|
|
||||||
|
reader := bufio.NewReader(resp.Body)
|
||||||
|
for {
|
||||||
|
line, err := reader.ReadBytes('\n')
|
||||||
|
if err != nil {
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
line = bytes.TrimSpace(line)
|
||||||
|
if len(line) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var streamResp OllamaResponse
|
||||||
|
err = json.Unmarshal(line, &streamResp)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(streamResp.Message.Content) > 0 {
|
||||||
|
output <- streamResp.Message.Content
|
||||||
|
content.WriteString(streamResp.Message.Content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if callback != nil {
|
||||||
|
callback(model.Message{
|
||||||
|
Role: model.MessageRoleAssistant,
|
||||||
|
Content: content.String(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return content.String(), nil
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user