From ea576d24a6980c43a35ee665c5b41b63dc490198 Mon Sep 17 00:00:00 2001
From: Matt Low
Date: Sat, 1 Jun 2024 01:38:45 +0000
Subject: [PATCH] Add Ollama support

---
 pkg/lmcli/lmcli.go                  |   9 ++
 pkg/lmcli/provider/ollama/ollama.go | 212 ++++++++++++++++++++++++++++
 2 files changed, 221 insertions(+)
 create mode 100644 pkg/lmcli/provider/ollama/ollama.go

diff --git a/pkg/lmcli/lmcli.go b/pkg/lmcli/lmcli.go
index f47b8d9..7f288c9 100644
--- a/pkg/lmcli/lmcli.go
+++ b/pkg/lmcli/lmcli.go
@@ -10,6 +10,7 @@ import (
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/google"
+	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/ollama"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
 	"git.mlow.ca/mlow/lmcli/pkg/util"
@@ -113,6 +114,14 @@ func (c *Context) GetModelProvider(model string) (string, provider.ChatCompletio
 			BaseURL: url,
 			APIKey:  *p.APIKey,
 		}, nil
+	case "ollama":
+		url := "http://localhost:11434/api"
+		if p.BaseURL != nil {
+			url = *p.BaseURL
+		}
+		return model, &ollama.OllamaClient{
+			BaseURL: url,
+		}, nil
 	case "openai":
 		url := "https://api.openai.com/v1"
 		if p.BaseURL != nil {
diff --git a/pkg/lmcli/provider/ollama/ollama.go b/pkg/lmcli/provider/ollama/ollama.go
new file mode 100644
index 0000000..6202ae4
--- /dev/null
+++ b/pkg/lmcli/provider/ollama/ollama.go
@@ -0,0 +1,212 @@
+package ollama
+
+import (
+	"bufio"
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"strings"
+
+	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
+	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
+)
+
+// OllamaClient implements chat completions against an Ollama server's HTTP API.
+type OllamaClient struct {
+	BaseURL string
+}
+
+// OllamaMessage is a single message in an Ollama chat request or response.
+type OllamaMessage struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+// OllamaRequest is the request body for the /chat endpoint.
+type OllamaRequest struct {
+	Model    string          `json:"model"`
+	Messages []OllamaMessage `json:"messages"`
+	Stream   bool            `json:"stream"`
+}
+
+// OllamaResponse is a single (possibly streamed) response from the /chat endpoint.
+type OllamaResponse struct {
+	Model              string        `json:"model"`
+	CreatedAt          string        `json:"created_at"`
+	Message            OllamaMessage `json:"message"`
+	Done               bool          `json:"done"`
+	TotalDuration      uint64        `json:"total_duration,omitempty"`
+	LoadDuration       uint64        `json:"load_duration,omitempty"`
+	PromptEvalCount    uint64        `json:"prompt_eval_count,omitempty"`
+	PromptEvalDuration uint64        `json:"prompt_eval_duration,omitempty"`
+	EvalCount          uint64        `json:"eval_count,omitempty"`
+	EvalDuration       uint64        `json:"eval_duration,omitempty"`
+}
+
+// createOllamaRequest converts lmcli request parameters and messages into an
+// Ollama /chat request body.
+func createOllamaRequest(
+	params model.RequestParameters,
+	messages []model.Message,
+) OllamaRequest {
+	requestMessages := make([]OllamaMessage, 0, len(messages))
+
+	for _, m := range messages {
+		message := OllamaMessage{
+			Role:    string(m.Role),
+			Content: m.Content,
+		}
+		requestMessages = append(requestMessages, message)
+	}
+
+	request := OllamaRequest{
+		Model:    params.Model,
+		Messages: requestMessages,
+	}
+
+	return request
+}
+
+// sendRequest sends req and returns the response, or an error containing the
+// response body when the server does not return 200 OK.
+func (c *OllamaClient) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
+	req.Header.Set("Content-Type", "application/json")
+
+	client := &http.Client{}
+	resp, err := client.Do(req.WithContext(ctx))
+	if err != nil {
+		return nil, err
+	}
+
+	if resp.StatusCode != http.StatusOK {
+		body, _ := io.ReadAll(resp.Body)
+		resp.Body.Close()
+		return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, string(body))
+	}
+
+	return resp, nil
+}
+
+// CreateChatCompletion sends a non-streaming chat request and returns the
+// assistant's reply.
+func (c *OllamaClient) CreateChatCompletion(
+	ctx context.Context,
+	params model.RequestParameters,
+	messages []model.Message,
+	callback provider.ReplyCallback,
+) (string, error) {
+	if len(messages) == 0 {
+		return "", fmt.Errorf("can't create completion from no messages")
completion from no messages") + } + + req := createOllamaRequest(params, messages) + req.Stream = false + + jsonData, err := json.Marshal(req) + if err != nil { + return "", err + } + + httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData)) + if err != nil { + return "", err + } + + resp, err := c.sendRequest(ctx, httpReq) + if err != nil { + return "", err + } + defer resp.Body.Close() + + var completionResp OllamaResponse + err = json.NewDecoder(resp.Body).Decode(&completionResp) + if err != nil { + return "", err + } + + content := completionResp.Message.Content + + fmt.Println(content) + + if callback != nil { + callback(model.Message{ + Role: model.MessageRoleAssistant, + Content: content, + }) + } + + return content, nil +} + +func (c *OllamaClient) CreateChatCompletionStream( + ctx context.Context, + params model.RequestParameters, + messages []model.Message, + callback provider.ReplyCallback, + output chan<- string, +) (string, error) { + if len(messages) == 0 { + return "", fmt.Errorf("Can't create completion from no messages") + } + + req := createOllamaRequest(params, messages) + req.Stream = true + + jsonData, err := json.Marshal(req) + if err != nil { + return "", err + } + + httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData)) + if err != nil { + return "", err + } + + resp, err := c.sendRequest(ctx, httpReq) + if err != nil { + return "", err + } + defer resp.Body.Close() + + content := strings.Builder{} + + reader := bufio.NewReader(resp.Body) + for { + line, err := reader.ReadBytes('\n') + if err != nil { + if err == io.EOF { + break + } + return "", err + } + + line = bytes.TrimSpace(line) + if len(line) == 0 { + continue + } + + var streamResp OllamaResponse + err = json.Unmarshal(line, &streamResp) + if err != nil { + return "", err + } + + if len(streamResp.Message.Content) > 0 { + output <- streamResp.Message.Content + content.WriteString(streamResp.Message.Content) + } + } + + if callback != nil { + callback(model.Message{ + Role: model.MessageRoleAssistant, + Content: content.String(), + }) + } + + return content.String(), nil +}