lmcli/pkg/provider/ollama/ollama.go
Matt Low 0384c7cb66 Large refactor - it compiles!
This refactor splits out all conversation concerns into a new
`conversation` package. There is now a split between `conversation` and
`api`'s representation of `Message`, the latter storing the minimum
information required for interaction with LLM providers. Conversion
between the two is necessary when making LLM calls.
2024-10-22 17:53:13 +00:00

184 lines
4.0 KiB
Go

package ollama
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/provider"
)
type (
	// OllamaClient talks to an Ollama server over HTTP. BaseURL is the prefix
	// that endpoint paths such as "/chat" are appended to when building
	// requests.
	OllamaClient struct {
		BaseURL string
	}

	// OllamaMessage is a single chat message in the JSON shape used by both
	// requests and responses.
	OllamaMessage struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	}

	// OllamaRequest is the JSON body posted to the chat endpoint. Stream
	// selects between a single response object and a line-delimited stream.
	OllamaRequest struct {
		Model    string          `json:"model"`
		Messages []OllamaMessage `json:"messages"`
		Stream   bool            `json:"stream"`
	}

	// OllamaResponse is one JSON object returned by the chat endpoint: the
	// entire reply for non-streaming requests, or one line of the stream for
	// streaming requests. The count/duration statistics are tagged omitempty.
	// NOTE(review): units of the *Duration fields are not established here.
	OllamaResponse struct {
		Model              string        `json:"model"`
		CreatedAt          string        `json:"created_at"`
		Message            OllamaMessage `json:"message"`
		Done               bool          `json:"done"`
		TotalDuration      uint64        `json:"total_duration,omitempty"`
		LoadDuration       uint64        `json:"load_duration,omitempty"`
		PromptEvalCount    uint64        `json:"prompt_eval_count,omitempty"`
		PromptEvalDuration uint64        `json:"prompt_eval_duration,omitempty"`
		EvalCount          uint64        `json:"eval_count,omitempty"`
		EvalDuration       uint64        `json:"eval_duration,omitempty"`
	}
)
// createOllamaRequest converts messages into the wire format of the Ollama
// chat endpoint, copying each message's role and content in order. Of params,
// only Model is used; Stream is left at its zero value (false) for the caller
// to set.
//
// NOTE(review): any other fields RequestParameters may carry are not
// forwarded to Ollama here — confirm that is intended.
func createOllamaRequest(
	params provider.RequestParameters,
	messages []api.Message,
) OllamaRequest {
	converted := make([]OllamaMessage, len(messages))
	for i, msg := range messages {
		converted[i] = OllamaMessage{
			Role:    string(msg.Role),
			Content: msg.Content,
		}
	}

	return OllamaRequest{
		Model:    params.Model,
		Messages: converted,
	}
}
// sendRequest sets the JSON content type on req, performs it, and checks that
// the server answered with 200 OK.
//
// On success the caller owns the returned response and must close its Body.
// On a non-200 status the (size-bounded) body is read, the Body is closed,
// and its text is folded into the returned error together with the HTTP
// status. The returned *http.Response is nil whenever the error is non-nil.
func (c *OllamaClient) sendRequest(req *http.Request) (*http.Response, error) {
	req.Header.Set("Content-Type", "application/json")

	// A zero-value Client uses http.DefaultTransport, so connections are still
	// pooled across calls. No Client.Timeout is set because streamed replies can
	// legitimately run long; cancellation comes from the request's context.
	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}

	if resp.StatusCode != http.StatusOK {
		// Callers return immediately on error and never touch resp, so the body
		// must be closed here or the underlying connection leaks.
		defer resp.Body.Close()

		// Bound the read so a misbehaving server cannot make us buffer an
		// arbitrarily large error page.
		body, readErr := io.ReadAll(io.LimitReader(resp.Body, 64<<10))
		if readErr != nil {
			return nil, fmt.Errorf("ollama: %s (reading error body: %w)", resp.Status, readErr)
		}
		msg := strings.TrimSpace(string(body))
		if msg == "" {
			return nil, fmt.Errorf("ollama: %s", resp.Status)
		}
		return nil, fmt.Errorf("ollama: %s: %s", resp.Status, msg)
	}

	return resp, nil
}
// CreateChatCompletion posts messages to the Ollama chat endpoint with
// streaming disabled and returns the reply content as a single assistant
// message.
//
// It returns an error if messages is empty, if the request cannot be encoded
// or sent, if the server responds with a non-200 status, or if the response
// body cannot be decoded as an OllamaResponse.
func (c *OllamaClient) CreateChatCompletion(
	ctx context.Context,
	params provider.RequestParameters,
	messages []api.Message,
) (*api.Message, error) {
	if len(messages) == 0 {
		return nil, fmt.Errorf("Can't create completion from no messages")
	}

	body := createOllamaRequest(params, messages)
	body.Stream = false // one complete JSON object, decoded in a single pass below

	payload, err := json.Marshal(body)
	if err != nil {
		return nil, err
	}

	endpoint := c.BaseURL + "/chat"
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return nil, err
	}

	httpResp, err := c.sendRequest(httpReq)
	if err != nil {
		return nil, err
	}
	defer httpResp.Body.Close()

	var reply OllamaResponse
	if err := json.NewDecoder(httpResp.Body).Decode(&reply); err != nil {
		return nil, err
	}
	return api.NewMessageWithAssistant(reply.Message.Content), nil
}
// CreateChatCompletionStream posts messages to the Ollama chat endpoint with
// streaming enabled and forwards each non-empty content fragment to output as
// it arrives.
//
// The response body is read line by line; each non-blank line is decoded as
// one OllamaResponse. Every forwarded fragment is sent with TokenCount 1 (a
// chunk is counted as one token; the stream's eval counts are not consulted).
// When the body is exhausted, the concatenated content is returned as a single
// assistant message.
//
// If ctx is cancelled while waiting to hand a chunk to output, ctx.Err() is
// returned instead of blocking forever. The method never closes output.
func (c *OllamaClient) CreateChatCompletionStream(
	ctx context.Context,
	params provider.RequestParameters,
	messages []api.Message,
	output chan<- provider.Chunk,
) (*api.Message, error) {
	if len(messages) == 0 {
		return nil, fmt.Errorf("Can't create completion from no messages")
	}

	req := createOllamaRequest(params, messages)
	req.Stream = true

	jsonData, err := json.Marshal(req)
	if err != nil {
		return nil, err
	}

	httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
	if err != nil {
		return nil, err
	}

	resp, err := c.sendRequest(httpReq)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	content := strings.Builder{}
	reader := bufio.NewReader(resp.Body)
	for {
		line, readErr := reader.ReadBytes('\n')
		if readErr != nil && readErr != io.EOF {
			return nil, readErr
		}

		// ReadBytes hands back any trailing bytes together with io.EOF, so a
		// final object without a terminating newline must be processed before
		// the loop stops; otherwise the last fragment would be silently lost.
		if line = bytes.TrimSpace(line); len(line) > 0 {
			var streamResp OllamaResponse
			if err := json.Unmarshal(line, &streamResp); err != nil {
				return nil, fmt.Errorf("decoding ollama stream line: %w", err)
			}
			if text := streamResp.Message.Content; len(text) > 0 {
				// Don't block forever if the consumer has gone away.
				select {
				case output <- provider.Chunk{Content: text, TokenCount: 1}:
				case <-ctx.Done():
					return nil, ctx.Err()
				}
				content.WriteString(text)
			}
		}

		if readErr == io.EOF {
			break
		}
	}

	return api.NewMessageWithAssistant(content.String()), nil
}