// Source: lmcli/pkg/provider/ollama/ollama.go (Go, 184 lines, 4.0 KiB).
// Blame date shown by the web view for this file: 2024-05-31 19:38:45 -06:00.

package ollama
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/provider"
// (web-view residue: blame timestamp 2024-05-31 19:38:45 -06:00)
)
// OllamaClient is the receiver for the Ollama chat-completion methods in this
// package.
type OllamaClient struct {
	// BaseURL is the URL prefix to which endpoint paths such as "/chat" are
	// appended verbatim.
	BaseURL string
}
// OllamaMessage is a single chat message as encoded in Ollama request and
// response JSON.
type OllamaMessage struct {
	// Role is serialized under the "role" key.
	Role string `json:"role"`
	// Content is serialized under the "content" key.
	Content string `json:"content"`
}
// OllamaRequest is the JSON body POSTed to the chat endpoint.
type OllamaRequest struct {
	// Model is serialized under the "model" key.
	Model string `json:"model"`
	// Messages is the conversation, serialized under the "messages" key.
	Messages []OllamaMessage `json:"messages"`
	// Stream is serialized under the "stream" key; it is always emitted, even
	// when false, because the tag carries no omitempty.
	Stream bool `json:"stream"`
}
// OllamaResponse is the JSON object decoded from the chat endpoint: the whole
// body of a non-streaming reply, or one line of a streamed reply.
type OllamaResponse struct {
	Model     string        `json:"model"`
	CreatedAt string        `json:"created_at"`
	Message   OllamaMessage `json:"message"`
	Done      bool          `json:"done"`

	// Counters and durations reported by the server; each is left out of
	// encoded JSON when zero (omitempty).
	// NOTE(review): the units of the *Duration fields are not established in
	// this file — confirm against the Ollama API documentation.
	TotalDuration      uint64 `json:"total_duration,omitempty"`
	LoadDuration       uint64 `json:"load_duration,omitempty"`
	PromptEvalCount    uint64 `json:"prompt_eval_count,omitempty"`
	PromptEvalDuration uint64 `json:"prompt_eval_duration,omitempty"`
	EvalCount          uint64 `json:"eval_count,omitempty"`
	EvalDuration       uint64 `json:"eval_duration,omitempty"`
}
func createOllamaRequest(
params provider.RequestParameters,
messages []api.Message,
2024-05-31 19:38:45 -06:00
) OllamaRequest {
requestMessages := make([]OllamaMessage, 0, len(messages))
for _, m := range messages {
message := OllamaMessage{
Role: string(m.Role),
Content: m.Content,
}
requestMessages = append(requestMessages, message)
}
request := OllamaRequest{
Model: params.Model,
Messages: requestMessages,
}
return request
}
func (c *OllamaClient) sendRequest(req *http.Request) (*http.Response, error) {
2024-05-31 19:38:45 -06:00
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
2024-05-31 19:38:45 -06:00
if err != nil {
return nil, err
}
if resp.StatusCode != 200 {
bytes, _ := io.ReadAll(resp.Body)
return resp, fmt.Errorf("%v", string(bytes))
}
return resp, nil
}
func (c *OllamaClient) CreateChatCompletion(
ctx context.Context,
params provider.RequestParameters,
messages []api.Message,
) (*api.Message, error) {
2024-05-31 19:38:45 -06:00
if len(messages) == 0 {
return nil, fmt.Errorf("Can't create completion from no messages")
2024-05-31 19:38:45 -06:00
}
req := createOllamaRequest(params, messages)
req.Stream = false
jsonData, err := json.Marshal(req)
if err != nil {
return nil, err
2024-05-31 19:38:45 -06:00
}
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
2024-05-31 19:38:45 -06:00
if err != nil {
return nil, err
2024-05-31 19:38:45 -06:00
}
resp, err := c.sendRequest(httpReq)
2024-05-31 19:38:45 -06:00
if err != nil {
return nil, err
2024-05-31 19:38:45 -06:00
}
defer resp.Body.Close()
var completionResp OllamaResponse
err = json.NewDecoder(resp.Body).Decode(&completionResp)
if err != nil {
return nil, err
2024-05-31 19:38:45 -06:00
}
return api.NewMessageWithAssistant(completionResp.Message.Content), nil
2024-05-31 19:38:45 -06:00
}
func (c *OllamaClient) CreateChatCompletionStream(
ctx context.Context,
params provider.RequestParameters,
messages []api.Message,
output chan<- provider.Chunk,
) (*api.Message, error) {
2024-05-31 19:38:45 -06:00
if len(messages) == 0 {
return nil, fmt.Errorf("Can't create completion from no messages")
2024-05-31 19:38:45 -06:00
}
req := createOllamaRequest(params, messages)
req.Stream = true
jsonData, err := json.Marshal(req)
if err != nil {
return nil, err
2024-05-31 19:38:45 -06:00
}
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
2024-05-31 19:38:45 -06:00
if err != nil {
return nil, err
2024-05-31 19:38:45 -06:00
}
resp, err := c.sendRequest(httpReq)
2024-05-31 19:38:45 -06:00
if err != nil {
return nil, err
2024-05-31 19:38:45 -06:00
}
defer resp.Body.Close()
content := strings.Builder{}
reader := bufio.NewReader(resp.Body)
for {
line, err := reader.ReadBytes('\n')
if err != nil {
if err == io.EOF {
break
}
return nil, err
2024-05-31 19:38:45 -06:00
}
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}
var streamResp OllamaResponse
err = json.Unmarshal(line, &streamResp)
if err != nil {
return nil, err
2024-05-31 19:38:45 -06:00
}
if len(streamResp.Message.Content) > 0 {
output <- provider.Chunk{
Content: streamResp.Message.Content,
TokenCount: 1,
}
2024-05-31 19:38:45 -06:00
content.WriteString(streamResp.Message.Content)
}
}
return api.NewMessageWithAssistant(content.String()), nil
2024-05-31 19:38:45 -06:00
}