Compare commits

..

No commits in common. "main" and "tui" have entirely different histories.
main ... tui

63 changed files with 2744 additions and 6125 deletions

199
README.md
View File

@ -1,174 +1,59 @@
# lmcli - Large ____ Model CLI # lmcli
`lmcli` is a versatile command-line interface for interacting with LLMs and LMMs. `lmcli` is a (Large) Language Model CLI.
## Features Current features:
- Perform one-shot prompts with `lmcli prompt <message>`
- Manage persistent conversations with the `new`, `reply`, `view`, `rm`,
`edit`, `retry`, `continue` sub-commands.
- Syntax highlighted output
- Tool calling, see the [Tools](#tools) section.
- Multiple model backends (Ollama, OpenAI, Anthropic, Google) Maybe features:
- Customizable agents with tool calling - Chat-like interface (`lmcli chat`) for rapid back-and-forth conversations
- Persistent conversation management - Support for additional models/APIs besides just OpenAI
- Message branching (edit and re-prompt to your heart's desire)
- Interactive terminal interface for seamless chat experiences
- Utilizes `$EDITOR`, write and edit prompts from the comfort of your own editor :)
- `vi`-like bindings
- Syntax highlighting!
## Screenshots
[TODO: Add screenshots of the TUI in action, showing different views and features]
## Installation
To install `lmcli`, make sure you have Go installed on your system, then run:
```sh
go install git.mlow.ca/mlow/lmcli@latest
```
## Configuration
`lmcli` uses a YAML configuration file located at `~/.config/lmcli/config.yaml`. Here's a sample configuration:
```yaml
defaults:
model: claude-3-5-sonnet-20240620
maxTokens: 3072
temperature: 0.2
conversations:
titleGenerationModel: claude-3-haiku-20240307
chroma:
style: onedark
formatter: terminal16m
agents:
- name: coder
tools:
- dir_tree
- read_file
#- write_file
systemPrompt: |
You are an experienced software engineer...
# ...
providers:
- kind: ollama
models:
- phi3:instruct
- llama3:8b
- kind: anthropic
apiKey: your-api-key-here
models:
- claude-3-5-sonnet-20240620
- claude-3-opus-20240229
- claude-3-haiku-20240307
- kind: openai
apiKey: your-api-key-here
models:
- gpt-4o
- gpt-4-turbo
- name: openrouter
kind: openai
apiKey: your-api-key-here
baseUrl: https://openrouter.ai/api/
models:
- qwen/qwen-2-72b-instruct
# ...
```
Customize this file to add your own providers, agents, and models.
### Syntax highlighting
Syntax highlighting is performed by [Chroma](https://github.com/alecthomas/chroma).
Refer to [`Chroma/styles`](https://github.com/alecthomas/chroma/tree/master/styles) for available styles (TODO: add support for custom Chroma styles).
Available formatters:
- `terminal` - 8 colors
- `terminal16` - 16 colors
- `terminal256` - 256 colors
- `terminal16m` - true color (default)
## Agents
Agents in `lmcli` combine a system prompt with a set of available tools. Agents are defined in `config.yaml` and are called upon with the `-a`/`--agent` flag.
Agent functionality is expected to be expanded on, bringing them to close parity with something like OpenAI's "Assistants" feature.
## Tools ## Tools
Tools must be explicitly enabled by adding the tool's name to the
`openai.enabledTools` array in `config.yaml`.
Tools are used by agents to acquire information from and interact with external systems. The following built-in tools are available: Note: all filesystem related tools operate relative to the current directory
only. They do not accept absolute paths, and efforts are made to ensure they
cannot escape above the working directory). **Close attention must be paid to
where you are running `lmcli`, as the model could at any time decide to use one
of these tools to discover and read potentially sensitive information from your
filesystem.**
- `dir_tree`: Display a directory structure It's best to only have tools enabled in `config.yaml` when you intend to be
- `read_file`: Read the contents of a file using them, since their descriptions (see `pkg/cli/functions.go`) count towards
- `write_file`: Write content to a file context usage.
- `file_insert_lines`: Insert lines at a specific position in a file
- `file_replace_lines`: Replace a range of lines in a file
Obviously, some of these tools carry significant risk. Use wisely :) Available tools:
More tool features are planned, including the ability to define arbitrary tools which call out to external scripts, tools to spawn sub-agents, perform web searches, etc. - `read_dir` - Read the contents of a directory.
- `read_file` - Read the contents of a file.
- `write_file` - Write contents to a file.
- `file_insert_lines` - Insert lines at a position within a file. Tricky for
the model to use, but can potentially save tokens.
- `file_replace_lines` - Remove or replace a range of lines within a file. Even
trickier for the model to use.
## Install
```shell
$ go install git.mlow.ca/mlow/lmcli@latest
```
## Usage ## Usage
```console Invoke `lmcli` at least once:
```shell
$ lmcli help $ lmcli help
lmcli - Large Language Model CLI
Usage:
lmcli <command> [flags]
lmcli [command]
Available Commands:
chat Open the chat interface
clone Clone conversations
completion Generate the autocompletion script for the specified shell
continue Continue a conversation from the last message
edit Edit the last user reply in a conversation
help Help about any command
list List conversations
new Start a new conversation
prompt Do a one-shot prompt
rename Rename a conversation
reply Reply to a conversation
retry Retry the last user reply in a conversation
rm Remove conversations
view View messages in a conversation
Flags:
-h, --help help for lmcli
Use "lmcli [command] --help" for more information about a command.
``` ```
### Examples Edit `~/.config/lmcli/config.yaml` and set `openai.apiKey` to your API key.
Start a new chat with the `coder` agent: Refer back to the output of `lmcli help` for usage.
```console Enjoy!
$ lmcli chat --agent coder
```
Start a new conversation, imperative style (no tui):
```console
$ lmcli new "Help me plan meals for the next week"
```
Send a one-shot prompt (no persistence):
```console
$ lmcli prompt "What is the answer to life, the universe, and everything?"
```
## Contributing
Contributions to `lmcli` are welcome! Feel free to open issues or submit pull requests on the project repository.
For a full list of planned features and improvements, check out the [TODO.md](TODO.md) file.
## License
To be determined
## Acknowledgements
`lmcli` is a small hobby project. Special thanks to the Go community and the creators of the libraries used in this project.

37
TODO.md
View File

@ -1,37 +0,0 @@
# TODO
- [x] Strip anthropic XML function call scheme from content, to reconstruct
when calling anthropic?
- [x] `dir_tree` tool
- [x] Implement native Anthropic API tool calling
- [x] Agents - a name given to a system prompt + set of available tools +
potentially other relevent data (e.g. external service credentials, files for
RAG, etc), which the user explicitly selects (e.g. `lmcli chat --agent
code-helper`, `lmcli chat -a financier`).
- [ ] Specialized agents which have integrations beyond basic tool calling,
e.g. a coding agent which bakes in efficient code context management
(only the current state of relevant files get shown to the model in the
system prompt, rather than having them in the conversation messages)
- [ ] Agents may have some form of long term memory management (key-value?
natural lang?).
- [ ] Sandboxed python, js interpreters (implemented with containers)
- [ ] Support for arbitrary external script tools
- [ ] Search - RAG driven search of existing conversation "hey, remind me of
the conversation we had six months ago about X")
- [ ] Conversation categorization - model driven category creation and
conversation classification
- [ ] Image input
- [ ] Image output (sixel support?)
- [ ] Conversation exports to html/pdf/json
- [ ] Store message generation model
- [ ] Hidden CoT
- [ ] Token accounting
## UI
- [x] Prettify/normalize tool_call and tool_result outputs so they can be
shown/optionally hidden in `lmcli view` and `lmcli chat`
- [x] Conversation deletion in conversations view
- [ ] User confirmation before calling (some?) tools
- [ ] Message deletion, Ctrl+D to delete a message and attach its children to
its parent, Ctrl+Shift+D to delete a message and its descendents
- [ ] Show available key bindings and their action in any given view

41
go.mod
View File

@ -3,41 +3,42 @@ module git.mlow.ca/mlow/lmcli
go 1.21 go 1.21
require ( require (
github.com/alecthomas/chroma/v2 v2.14.0 github.com/alecthomas/chroma/v2 v2.11.1
github.com/charmbracelet/bubbles v0.20.0 github.com/charmbracelet/bubbles v0.18.0
github.com/charmbracelet/bubbletea v1.1.1 github.com/charmbracelet/bubbletea v0.25.0
github.com/charmbracelet/lipgloss v0.13.0 github.com/charmbracelet/lipgloss v0.10.0
github.com/muesli/reflow v0.3.0 github.com/go-yaml/yaml v2.1.0+incompatible
github.com/spf13/cobra v1.8.1 github.com/sashabaranov/go-openai v1.17.7
github.com/spf13/cobra v1.8.0
github.com/sqids/sqids-go v0.4.1 github.com/sqids/sqids-go v0.4.1
gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/sqlite v1.5.4
gorm.io/driver/sqlite v1.5.6 gorm.io/gorm v1.25.5
gorm.io/gorm v1.25.12
) )
require ( require (
github.com/atotto/clipboard v0.1.4 // indirect github.com/atotto/clipboard v0.1.4 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/charmbracelet/x/ansi v0.3.1 // indirect github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81 // indirect
github.com/charmbracelet/x/term v0.2.0 // indirect github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/dlclark/regexp2 v1.11.4 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect github.com/jinzhu/now v1.1.5 // indirect
github.com/kr/pretty v0.3.1 // indirect github.com/kr/pretty v0.3.1 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.18 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.16 // indirect github.com/mattn/go-runewidth v0.0.15 // indirect
github.com/mattn/go-sqlite3 v1.14.23 // indirect github.com/mattn/go-sqlite3 v1.14.18 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect github.com/muesli/ansi v0.0.0-20211018074035-2e021307bc4b // indirect
github.com/muesli/cancelreader v0.2.2 // indirect github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/reflow v0.3.0 // indirect
github.com/muesli/termenv v0.15.2 // indirect github.com/muesli/termenv v0.15.2 // indirect
github.com/rivo/uniseg v0.4.7 // indirect github.com/rivo/uniseg v0.4.7 // indirect
github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/pflag v1.0.5 // indirect
golang.org/x/sync v0.8.0 // indirect golang.org/x/sync v0.1.0 // indirect
golang.org/x/sys v0.25.0 // indirect golang.org/x/sys v0.14.0 // indirect
golang.org/x/text v0.18.0 // indirect golang.org/x/term v0.6.0 // indirect
golang.org/x/text v0.3.8 // indirect
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
gopkg.in/yaml.v2 v2.2.2 // indirect
) )

91
go.sum
View File

@ -1,31 +1,27 @@
github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ= github.com/alecthomas/assert/v2 v2.2.1 h1:XivOgYcduV98QCahG8T5XTezV5bylXe+lBxLG2K2ink=
github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE= github.com/alecthomas/assert/v2 v2.2.1/go.mod h1:pXcQ2Asjp247dahGEmsZ6ru0UVwnkhktn7S0bBDLxvQ=
github.com/alecthomas/assert/v2 v2.7.0 h1:QtqSACNS3tF7oasA8CU6A6sXZSBDqnm7RfpLl9bZqbE= github.com/alecthomas/chroma/v2 v2.11.1 h1:m9uUtgcdAwgfFNxuqj7AIG75jD2YmL61BBIJWtdzJPs=
github.com/alecthomas/assert/v2 v2.7.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k= github.com/alecthomas/chroma/v2 v2.11.1/go.mod h1:4TQu7gdfuPjSh76j78ietmqh9LiurGF0EpseFXdKMBw=
github.com/alecthomas/chroma/v2 v2.14.0 h1:R3+wzpnUArGcQz7fCETQBzO5n9IMNi13iIs46aU4V9E= github.com/alecthomas/repr v0.2.0 h1:HAzS41CIzNW5syS8Mf9UwXhNH1J9aix/BvDRf1Ml2Yk=
github.com/alecthomas/chroma/v2 v2.14.0/go.mod h1:QolEbTfmUHIMVpBqxeDnNBj2uoeI4EbYP4i6n68SG4I= github.com/alecthomas/repr v0.2.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc=
github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/charmbracelet/bubbles v0.20.0 h1:jSZu6qD8cRQ6k9OMfR1WlM+ruM8fkPWkHvQWD9LIutE= github.com/charmbracelet/bubbles v0.18.0 h1:PYv1A036luoBGroX6VWjQIE9Syf2Wby2oOl/39KLfy0=
github.com/charmbracelet/bubbles v0.20.0/go.mod h1:39slydyswPy+uVOHZ5x/GjwVAFkCsV8IIVy+4MhzwwU= github.com/charmbracelet/bubbles v0.18.0/go.mod h1:08qhZhtIwzgrtBjAcJnij1t1H0ZRjwHyGsy6AL11PSw=
github.com/charmbracelet/bubbletea v1.1.1 h1:KJ2/DnmpfqFtDNVTvYZ6zpPFL9iRCRr0qqKOCvppbPY= github.com/charmbracelet/bubbletea v0.25.0 h1:bAfwk7jRz7FKFl9RzlIULPkStffg5k6pNt5dywy4TcM=
github.com/charmbracelet/bubbletea v1.1.1/go.mod h1:9Ogk0HrdbHolIKHdjfFpyXJmiCzGwy+FesYkZr7hYU4= github.com/charmbracelet/bubbletea v0.25.0/go.mod h1:EN3QDR1T5ZdWmdfDzYcqOCAps45+QIJbLOBxmVNWNNg=
github.com/charmbracelet/lipgloss v0.13.0 h1:4X3PPeoWEDCMvzDvGmTajSyYPcZM4+y8sCA/SsA3cjw= github.com/charmbracelet/lipgloss v0.10.0 h1:KWeXFSexGcfahHX+54URiZGkBFazf70JNMtwg/AFW3s=
github.com/charmbracelet/lipgloss v0.13.0/go.mod h1:nw4zy0SBX/F/eAO1cWdcvy6qnkDUxr8Lw7dvFrAIbbY= github.com/charmbracelet/lipgloss v0.10.0/go.mod h1:Wig9DSfvANsxqkRsqj6x87irdy123SR4dOXlKa91ciE=
github.com/charmbracelet/x/ansi v0.3.1 h1:CRO6lc/6HCx2/D6S/GZ87jDvRvk6GtPyFP+IljkNtqI= github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81 h1:q2hJAaP1k2wIvVRd/hEHD7lacgqrCPS+k8g1MndzfWY=
github.com/charmbracelet/x/ansi v0.3.1/go.mod h1:dk73KoMTT5AX5BsX0KrqhsTqAnhZZoCBjs7dGWp4Ktw= github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81/go.mod h1:YynlIjWYF8myEu6sdkwKIvGQq+cOckRm6So2avqoYAk=
github.com/charmbracelet/x/term v0.2.0 h1:cNB9Ot9q8I711MyZ7myUR5HFWL/lc3OpU8jZ4hwm0x0= github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/charmbracelet/x/term v0.2.0/go.mod h1:GVxgxAbjUrmpvIINHIQnJJKpMlHiZ4cktEQCN6GWyF0=
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo= github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= github.com/go-yaml/yaml v2.1.0+incompatible h1:RYi2hDdss1u4YE7GwixGzWwVo47T8UQwnTLB6vQiq+o=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= github.com/go-yaml/yaml v2.1.0+incompatible/go.mod h1:w2MrLa16VYP0jy6N7M5kHaCkaLENm+P+Tv+MfurjSw0=
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
@ -40,17 +36,17 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.18 h1:DOKFKCQ7FNG2L1rbrmstDN4QVRdS89Nkh85u68Uwp98=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.18/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk= github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.23 h1:gbShiuAP1W5j9UOksQ06aiiqPMxYecovVGwmTxWtuw0= github.com/mattn/go-sqlite3 v1.14.18 h1:JL0eqdCOq6DJVNPSvArO/bIV9/P7fbGrV00LZHc+5aI=
github.com/mattn/go-sqlite3 v1.14.23/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= github.com/muesli/ansi v0.0.0-20211018074035-2e021307bc4b h1:1XF24mVaiu7u+CFywTdcDo2ie1pzzhwjt6RHqzpMU34=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= github.com/muesli/ansi v0.0.0-20211018074035-2e021307bc4b/go.mod h1:fQuZ0gauxyBcmsdE3ZT4NasjaRdxmbCS0jRHsrWu3Ho=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s= github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
@ -65,26 +61,31 @@ github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUc
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= github.com/sashabaranov/go-openai v1.17.7 h1:MPcAwlwbeo7ZmhQczoOgZBHtIBY1TfZqsdx6+/ndloM=
github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= github.com/sashabaranov/go-openai v1.17.7/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0=
github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/sqids/sqids-go v0.4.1 h1:eQKYzmAZbLlRwHeHYPF35QhgxwZHLnlmVj9AkIj/rrw= github.com/sqids/sqids-go v0.4.1 h1:eQKYzmAZbLlRwHeHYPF35QhgxwZHLnlmVj9AkIj/rrw=
github.com/sqids/sqids-go v0.4.1/go.mod h1:EMwHuPQgSNFS0A49jESTfIQS+066XQTVhukrzEPScl8= github.com/sqids/sqids-go v0.4.1/go.mod h1:EMwHuPQgSNFS0A49jESTfIQS+066XQTVhukrzEPScl8=
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q=
golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw=
golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY=
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/sqlite v1.5.6 h1:fO/X46qn5NUEEOZtnjJRWRzZMe8nqJiQ9E+0hi+hKQE= gorm.io/driver/sqlite v1.5.4 h1:IqXwXi8M/ZlPzH/947tn5uik3aYQslP9BVveoax0nV0=
gorm.io/driver/sqlite v1.5.6/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4= gorm.io/driver/sqlite v1.5.4/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4=
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=

View File

@ -1,142 +0,0 @@
package toolbox
import (
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
toolutil "git.mlow.ca/mlow/lmcli/pkg/agents/toolbox/util"
"git.mlow.ca/mlow/lmcli/pkg/api"
)
const TREE_DESCRIPTION = `Retrieve a tree-like view of a directory's contents.
Use these results for your own reference in completing your task, they do not need to be shown to the user.
Example result:
{
"message": "success",
"result": ".
a_directory/
file1.txt (100 bytes)
file2.txt (200 bytes)
a_file.txt (123 bytes)
another_file.txt (456 bytes)"
}
`
var DirTreeTool = api.ToolSpec{
Name: "dir_tree",
Description: TREE_DESCRIPTION,
Parameters: []api.ToolParameter{
{
Name: "relative_path",
Type: "string",
Description: "If set, display the tree starting from this path relative to the current one.",
},
{
Name: "depth",
Type: "integer",
Description: "Depth of directory recursion. Defaults to 0 (no recursion), maximum of 5.",
},
},
Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) {
var relativeDir string
if tmp, ok := args["relative_path"]; ok {
relativeDir, ok = tmp.(string)
if !ok {
return "", fmt.Errorf("expected string for relative_path, got %T", tmp)
}
}
var depth int = 0 // Default value if not provided
if tmp, ok := args["depth"]; ok {
switch v := tmp.(type) {
case float64:
depth = int(v)
case string:
var err error
if depth, err = strconv.Atoi(v); err != nil {
return "", fmt.Errorf("invalid `depth` value, expected integer but got string that cannot convert: %v", tmp)
}
depth = max(0, min(5, depth))
default:
return "", fmt.Errorf("expected int or string for max_depth, got %T", tmp)
}
}
result := tree(relativeDir, depth)
ret, err := result.ToJson()
if err != nil {
return "", fmt.Errorf("could not serialize result: %v", err)
}
return ret, nil
},
}
func tree(path string, depth int) api.CallResult {
if path == "" {
path = "."
}
ok, reason := toolutil.IsPathWithinCWD(path)
if !ok {
return api.CallResult{Message: reason}
}
var treeOutput strings.Builder
treeOutput.WriteString(path + "\n")
err := buildTree(&treeOutput, path, "", depth)
if err != nil {
return api.CallResult{
Message: err.Error(),
}
}
return api.CallResult{Result: treeOutput.String()}
}
func buildTree(output *strings.Builder, path string, prefix string, depth int) error {
files, err := os.ReadDir(path)
if err != nil {
return err
}
for i, file := range files {
if strings.HasPrefix(file.Name(), ".") {
// Skip hidden files and directories
continue
}
isLast := i == len(files)-1
var branch string
if isLast {
branch = "└── "
} else {
branch = "├── "
}
info, _ := file.Info()
size := info.Size()
sizeStr := fmt.Sprintf(" (%d bytes)", size)
output.WriteString(prefix + branch + file.Name())
if file.IsDir() {
output.WriteString("/\n")
if depth > 0 {
var nextPrefix string
if isLast {
nextPrefix = prefix + " "
} else {
nextPrefix = prefix + "│ "
}
buildTree(output, filepath.Join(path, file.Name()), nextPrefix, depth-1)
}
} else {
output.WriteString(sizeStr + "\n")
}
}
return nil
}

View File

@ -1,178 +0,0 @@
package toolbox
import (
"fmt"
"os"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/agents/toolbox/util"
toolutil "git.mlow.ca/mlow/lmcli/pkg/agents/toolbox/util"
"git.mlow.ca/mlow/lmcli/pkg/api"
)
var MODIFY_FILE_DESCRIPTION = []string{
"Modify a file. If the file does not exist, it will be created.",
"",
"Content can be either inserted, replaced, or removed through a combination of the start, stop, and content parameters.",
"Use the start and stop line numbers to limit the range of modification to the file.",
"If both `start` and `stop` are left unset (or set to 0), the entire file's contents will be updated.",
"If `start` is set to n and `stop` to n+1, content will be inserted at line n (the content that was at line n will be shifted below the newly inserted content).",
"If only `start` is set, content from the given line and onwards will be updated.",
"If only `stop` is set, content up to but not including the given line will be updated.",
"",
"Examples:",
"1. Append to a file:",
" {\"path\": \"example.txt\", \"start\": <last_line_number + 1>, \"content\": \"New content to append\"}",
"",
"2. Insert at a specific line:",
" {\"path\": \"example.txt\", \"start\": 5, \"stop\": 5, \"content\": \"New line inserted above the previous line 5\"}",
"",
"3. Replace a range of lines:",
" {\"path\": \"example.txt\", \"start\": 3, \"stop\": 7, \"content\": \"New content replacing lines 3-7\"}",
"",
"4. Remove a range of lines:",
" {\"path\": \"example.txt\", \"start\": 2, \"stop\": 5}",
"",
"5. Replace entire file contents:",
" {\"path\": \"example.txt\", \"content\": \"New file contents\"}",
"",
"6. Update from a specific line to the end of the file:",
" {\"path\": \"example.txt\", \"start\": 10, \"content\": \"New content from line 10 onwards\"}",
"",
"7. Update from the beginning of the file to a specific line:",
" {\"path\": \"example.txt\", \"stop\": 6, \"content\": \"New content for first 5 lines\"}",
"",
"Note: Always use specific line numbers based on the current file content. Avoid using arbitrarily large numbers for start or stop.",
}
var ModifyFile = api.ToolSpec{
Name: "modify_file",
Description: strings.Join(MODIFY_FILE_DESCRIPTION, "\n"),
Parameters: []api.ToolParameter{
{
Name: "path",
Type: "string",
Description: "Path of the file to be modified, relative to the current working directory.",
Required: true,
},
{
Name: "start",
Type: "integer",
Description: `Start line of the range to modify (inclusive). If omitted, the beginning of the file is implied.`,
},
{
Name: "stop",
Type: "integer",
Description: `End line of the range to modify (inclusive). If omitted, the end of the file is implied.`,
},
{
Name: "content",
Type: "string",
Description: "Content to insert/replace at the range defined by `start` and `stop`. If omitted, the range is removed.",
},
},
Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) {
tmp, ok := args["path"]
if !ok {
return "", fmt.Errorf("path parameter to modify_file was not included.")
}
path, ok := tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
}
var start int
tmp, ok = args["start"]
if ok {
tmp, ok := tmp.(float64)
if !ok {
return "", fmt.Errorf("Invalid start in function arguments: %v", tmp)
}
start = int(tmp)
}
var stop int
tmp, ok = args["stop"]
if ok {
tmp, ok := tmp.(float64)
if !ok {
return "", fmt.Errorf("Invalid stop in function arguments: %v", tmp)
}
stop = int(tmp)
}
var content string
tmp, ok = args["content"]
if ok {
content, ok = tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
}
}
result := fileModifyContents(path, start, stop, content)
ret, err := result.ToJson()
if err != nil {
return "", fmt.Errorf("Could not serialize result: %v", err)
}
return ret, nil
},
}
func fileModifyContents(path string, startLine int, stopLine int, content string) api.CallResult {
ok, reason := toolutil.IsPathWithinCWD(path)
if !ok {
return api.CallResult{Message: reason}
}
// Read the existing file's content
data, err := os.ReadFile(path)
if err != nil {
if !os.IsNotExist(err) {
return api.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
}
_, err = os.Create(path)
if err != nil {
return api.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}
}
data = []byte{}
}
lines := strings.Split(string(data), "\n")
contentLines := strings.Split(strings.TrimSuffix(content, "\n"), "\n")
// If both start and stop are unset, update the entire file
if startLine == 0 && stopLine == 0 {
lines = contentLines
} else {
if startLine < 1 {
startLine = 1
}
if stopLine == 0 || stopLine > len(lines) {
stopLine = len(lines)
}
before := lines[:startLine-1]
after := lines[stopLine:]
// Handle insertion case
if startLine == stopLine {
lines = append(before, append(contentLines, lines[startLine-1:]...)...)
} else {
// If content is omitted, remove the specified range
if content == "" {
lines = append(before, after...)
} else {
lines = append(before, append(contentLines, after...)...)
}
}
}
newContent := strings.Join(lines, "\n")
// Write back to the file
err = os.WriteFile(path, []byte(newContent), 0644)
if err != nil {
return api.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
}
return api.CallResult{Result: util.AddLineNumbers(newContent)}
}

View File

@ -1,47 +0,0 @@
package agents
import (
"fmt"
"git.mlow.ca/mlow/lmcli/pkg/agents/toolbox"
"git.mlow.ca/mlow/lmcli/pkg/api"
)
var AvailableTools map[string]api.ToolSpec = map[string]api.ToolSpec{
"dir_tree": toolbox.DirTreeTool,
"read_dir": toolbox.ReadDirTool,
"read_file": toolbox.ReadFileTool,
"modify_file": toolbox.ModifyFile,
"write_file": toolbox.WriteFileTool,
}
func ExecuteToolCalls(calls []api.ToolCall, available []api.ToolSpec) ([]api.ToolResult, error) {
var toolResults []api.ToolResult
for _, call := range calls {
var tool *api.ToolSpec
for i := range available {
if available[i].Name == call.Name {
tool = &available[i]
break
}
}
if tool == nil {
return nil, fmt.Errorf("Requested tool '%s' is not available. Hallucination?", call.Name)
}
// Execute the tool
result, err := tool.Impl(tool, call.Parameters)
if err != nil {
return nil, fmt.Errorf("Tool '%s' error: %v\n", call.Name, err)
}
toolResult := api.ToolResult{
ToolCallID: call.ID,
ToolName: call.Name,
Result: result,
}
toolResults = append(toolResults, toolResult)
}
return toolResults, nil
}

View File

@ -1,126 +0,0 @@
package api
import (
"encoding/json"
"fmt"
)
type MessageRole string
const (
MessageRoleSystem MessageRole = "system"
MessageRoleUser MessageRole = "user"
MessageRoleAssistant MessageRole = "assistant"
MessageRoleToolCall MessageRole = "tool_call"
MessageRoleToolResult MessageRole = "tool_result"
)
type Message struct {
Content string // TODO: support multi-part messages
Role MessageRole
ToolCalls []ToolCall
ToolResults []ToolResult
}
type ToolSpec struct {
Name string
Description string
Parameters []ToolParameter
Impl func(*ToolSpec, map[string]interface{}) (string, error)
}
type ToolParameter struct {
Name string `json:"name"`
Type string `json:"type"` // "string", "integer", "boolean"
Required bool `json:"required"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
}
type ToolCall struct {
ID string `json:"id" yaml:"-"`
Name string `json:"name" yaml:"tool"`
Parameters map[string]interface{} `json:"parameters" yaml:"parameters"`
}
type ToolResult struct {
ToolCallID string `json:"toolCallID" yaml:"-"`
ToolName string `json:"toolName,omitempty" yaml:"tool"`
Result string `json:"result,omitempty" yaml:"result"`
}
func NewMessageWithAssistant(content string) *Message {
return &Message{
Role: MessageRoleAssistant,
Content: content,
}
}
func NewMessageWithToolCalls(content string, toolCalls []ToolCall) *Message {
return &Message{
Role: MessageRoleToolCall,
Content: content,
ToolCalls: toolCalls,
}
}
func (m MessageRole) IsAssistant() bool {
switch m {
case MessageRoleAssistant, MessageRoleToolCall:
return true
}
return false
}
func (m MessageRole) IsUser() bool {
switch m {
case MessageRoleUser, MessageRoleToolResult:
return true
}
return false
}
func (m MessageRole) IsSystem() bool {
switch m {
case MessageRoleSystem:
return true
}
return false
}
// FriendlyRole returns a human friendly signifier for the message's role.
func (m MessageRole) FriendlyRole() string {
switch m {
case MessageRoleUser:
return "You"
case MessageRoleSystem:
return "System"
case MessageRoleAssistant:
return "Assistant"
case MessageRoleToolCall:
return "Tool Call"
case MessageRoleToolResult:
return "Tool Result"
default:
return string(m)
}
}
// TODO: remove this
type CallResult struct {
Message string `json:"message"`
Result any `json:"result,omitempty"`
}
func (r CallResult) ToJson() (string, error) {
if r.Message == "" {
// When message not supplied, assume success
r.Message = "success"
}
jsonBytes, err := json.Marshal(r)
if err != nil {
return "", fmt.Errorf("Could not marshal CallResult to JSON: %v\n", err)
}
return string(jsonBytes), nil
}

View File

@ -3,10 +3,8 @@ package cmd
import ( import (
"fmt" "fmt"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/tui" "git.mlow.ca/mlow/lmcli/pkg/tui"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -16,34 +14,12 @@ func ChatCmd(ctx *lmcli.Context) *cobra.Command {
Short: "Open the chat interface", Short: "Open the chat interface",
Long: `Open the chat interface, optionally on a given conversation.`, Long: `Open the chat interface, optionally on a given conversation.`,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
err := validateGenerationFlags(ctx, cmd) // TODO: implement jump-to-conversation logic
if err != nil { shortname := ""
return err if len(args) == 1 {
shortname = args[0]
} }
err := tui.Launch(ctx, shortname)
var opts []tui.LaunchOption
list, err := cmd.Flags().GetBool("list")
if err != nil {
return err
}
if !list && len(args) == 1 {
shortname := args[0]
if shortname != ""{
conv, err := cmdutil.LookupConversationE(ctx, shortname)
if err != nil {
return err
}
opts = append(opts, tui.WithInitialConversation(conv))
}
}
if list {
opts = append(opts, tui.WithInitialView(shared.ViewConversations))
}
err = tui.Launch(ctx, opts...)
if err != nil { if err != nil {
return fmt.Errorf("Error fetching LLM response: %v", err) return fmt.Errorf("Error fetching LLM response: %v", err)
} }
@ -54,13 +30,8 @@ func ChatCmd(ctx *lmcli.Context) *cobra.Command {
if len(args) != 0 { if len(args) != 0 {
return nil, compMode return nil, compMode
} }
return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
}, },
} }
// -l, --list
cmd.Flags().BoolP("list", "l", false, "View/manage conversations")
applyGenerationFlags(ctx, cmd)
return cmd return cmd
} }

View File

@ -5,6 +5,7 @@ import (
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -27,12 +28,36 @@ func CloneCmd(ctx *lmcli.Context) *cobra.Command {
return err return err
} }
clone, messageCnt, err := ctx.Conversations.CloneConversation(*toClone) messagesToCopy, err := ctx.Store.Messages(toClone)
if err != nil { if err != nil {
return fmt.Errorf("Failed to clone conversation: %v", err) return fmt.Errorf("Could not retrieve messages for conversation: %s", toClone.ShortName.String)
} }
fmt.Printf("Cloned %d messages to: %s - %s\n", messageCnt, clone.ShortName.String, clone.Title) clone := &model.Conversation{
Title: toClone.Title + " - Clone",
}
if err := ctx.Store.SaveConversation(clone); err != nil {
return fmt.Errorf("Cloud not create clone: %s", err)
}
var errors []error
messageCnt := 0
for _, message := range messagesToCopy {
newMessage := message
newMessage.ConversationID = clone.ID
newMessage.ID = 0
if err := ctx.Store.SaveMessage(&newMessage); err != nil {
errors = append(errors, err)
} else {
messageCnt++
}
}
if len(errors) > 0 {
return fmt.Errorf("Messages failed to be cloned: %v", errors)
}
fmt.Printf("Cloned %d messages to: %s\n", messageCnt, clone.Title)
return nil return nil
}, },
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
@ -40,7 +65,7 @@ func CloneCmd(ctx *lmcli.Context) *cobra.Command {
if len(args) != 0 { if len(args) != 0 {
return nil, compMode return nil, compMode
} }
return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
}, },
} }
return cmd return cmd

View File

@ -1,8 +1,6 @@
package cmd package cmd
import ( import (
"fmt"
"slices"
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
@ -10,6 +8,10 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
var (
systemPromptFile string
)
func RootCmd(ctx *lmcli.Context) *cobra.Command { func RootCmd(ctx *lmcli.Context) *cobra.Command {
var root = &cobra.Command{ var root = &cobra.Command{
Use: "lmcli <command> [flags]", Use: "lmcli <command> [flags]",
@ -21,72 +23,58 @@ func RootCmd(ctx *lmcli.Context) *cobra.Command {
}, },
} }
chatCmd := ChatCmd(ctx)
continueCmd := ContinueCmd(ctx)
cloneCmd := CloneCmd(ctx)
editCmd := EditCmd(ctx)
listCmd := ListCmd(ctx)
newCmd := NewCmd(ctx)
promptCmd := PromptCmd(ctx)
renameCmd := RenameCmd(ctx)
replyCmd := ReplyCmd(ctx)
retryCmd := RetryCmd(ctx)
rmCmd := RemoveCmd(ctx)
viewCmd := ViewCmd(ctx)
inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd, retryCmd, continueCmd, editCmd}
for _, cmd := range inputCmds {
cmd.Flags().StringVar(ctx.Config.Defaults.Model, "model", *ctx.Config.Defaults.Model, "Which model to use")
cmd.RegisterFlagCompletionFunc("model", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) {
return ctx.GetModels(), cobra.ShellCompDirectiveDefault
})
cmd.Flags().IntVar(ctx.Config.Defaults.MaxTokens, "length", *ctx.Config.Defaults.MaxTokens, "Maximum response tokens")
cmd.Flags().StringVar(ctx.Config.Defaults.SystemPrompt, "system-prompt", *ctx.Config.Defaults.SystemPrompt, "System prompt")
cmd.Flags().StringVar(&systemPromptFile, "system-prompt-file", "", "A path to a file containing the system prompt")
cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file")
}
root.AddCommand( root.AddCommand(
ChatCmd(ctx), chatCmd,
ContinueCmd(ctx), cloneCmd,
CloneCmd(ctx), continueCmd,
EditCmd(ctx), editCmd,
ListCmd(ctx), listCmd,
NewCmd(ctx), newCmd,
PromptCmd(ctx), promptCmd,
RenameCmd(ctx), renameCmd,
ReplyCmd(ctx), replyCmd,
RetryCmd(ctx), retryCmd,
RemoveCmd(ctx), rmCmd,
ViewCmd(ctx), viewCmd,
) )
return root return root
} }
func applyGenerationFlags(ctx *lmcli.Context, cmd *cobra.Command) { func getSystemPrompt(ctx *lmcli.Context) string {
f := cmd.Flags() if systemPromptFile != "" {
content, err := util.ReadFileContents(systemPromptFile)
// -m, --model
f.StringVarP(
ctx.Config.Defaults.Model, "model", "m",
*ctx.Config.Defaults.Model, "Which model to generate a response with",
)
cmd.RegisterFlagCompletionFunc("model", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) {
return ctx.GetModels(), cobra.ShellCompDirectiveDefault
})
// -a, --agent
f.StringVarP(&ctx.Config.Defaults.Agent, "agent", "a", ctx.Config.Defaults.Agent, "Which agent to interact with")
cmd.RegisterFlagCompletionFunc("agent", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) {
return ctx.GetAgents(), cobra.ShellCompDirectiveDefault
})
// --max-length
f.IntVar(ctx.Config.Defaults.MaxTokens, "max-length", *ctx.Config.Defaults.MaxTokens, "Maximum response tokens")
// --temperature
f.Float32VarP(ctx.Config.Defaults.Temperature, "temperature", "t", *ctx.Config.Defaults.Temperature, "Sampling temperature")
// --system-prompt
f.StringVar(&ctx.Config.Defaults.SystemPrompt, "system-prompt", ctx.Config.Defaults.SystemPrompt, "System prompt")
// --system-prompt-file
f.StringVar(&ctx.Config.Defaults.SystemPromptFile, "system-prompt-file", ctx.Config.Defaults.SystemPromptFile, "A path to a file containing the system prompt")
cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file")
}
func validateGenerationFlags(ctx *lmcli.Context, cmd *cobra.Command) error {
f := cmd.Flags()
model, err := f.GetString("model")
if err != nil { if err != nil {
return fmt.Errorf("Error parsing --model: %w", err) lmcli.Fatal("Could not read file contents at %s: %v\n", systemPromptFile, err)
} }
if model != "" && !slices.Contains(ctx.GetModels(), model) { return content
return fmt.Errorf("Unknown model: %s", model)
} }
return *ctx.Config.Defaults.SystemPrompt
agent, err := f.GetString("agent")
if err != nil {
return fmt.Errorf("Error parsing --agent: %w", err)
}
if agent != "" && agent != "none" && !slices.Contains(ctx.GetAgents(), agent) {
return fmt.Errorf("Unknown agent: %s", agent)
}
return nil
} }
// inputFromArgsOrEditor returns either the provided input from the args slice // inputFromArgsOrEditor returns either the provided input from the args slice

View File

@ -4,9 +4,9 @@ import (
"fmt" "fmt"
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -23,15 +23,10 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
return nil return nil
}, },
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
err := validateGenerationFlags(ctx, cmd)
if err != nil {
return err
}
shortName := args[0] shortName := args[0]
c := cmdutil.LookupConversation(ctx, shortName) conversation := cmdutil.LookupConversation(ctx, shortName)
messages, err := ctx.Conversations.PathToLeaf(c.SelectedRoot) messages, err := ctx.Store.Messages(conversation)
if err != nil { if err != nil {
return fmt.Errorf("could not retrieve conversation messages: %v", err) return fmt.Errorf("could not retrieve conversation messages: %v", err)
} }
@ -41,7 +36,7 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
} }
lastMessage := &messages[len(messages)-1] lastMessage := &messages[len(messages)-1]
if lastMessage.Role != api.MessageRoleAssistant { if lastMessage.Role != model.MessageRoleAssistant {
return fmt.Errorf("the last message in the conversation is not an assistant message") return fmt.Errorf("the last message in the conversation is not an assistant message")
} }
@ -49,16 +44,16 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
fmt.Print(lastMessage.Content) fmt.Print(lastMessage.Content)
// Submit the LLM request, allowing it to continue the last message // Submit the LLM request, allowing it to continue the last message
continuedOutput, err := cmdutil.Prompt(ctx, messages, nil) continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
if err != nil { if err != nil {
return fmt.Errorf("error fetching LLM response: %v", err) return fmt.Errorf("error fetching LLM response: %v", err)
} }
// Append the new response to the original message // Append the new response to the original message
lastMessage.Content += strings.TrimRight(continuedOutput.Content, "\n\t ") lastMessage.Content += strings.TrimRight(continuedOutput, "\n\t ")
// Update the original message // Update the original message
err = ctx.Conversations.UpdateMessage(lastMessage) err = ctx.Store.UpdateMessage(lastMessage)
if err != nil { if err != nil {
return fmt.Errorf("could not update the last message: %v", err) return fmt.Errorf("could not update the last message: %v", err)
} }
@ -70,9 +65,8 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
if len(args) != 0 { if len(args) != 0 {
return nil, compMode return nil, compMode
} }
return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
}, },
} }
applyGenerationFlags(ctx, cmd)
return cmd return cmd
} }

View File

@ -3,9 +3,9 @@ package cmd
import ( import (
"fmt" "fmt"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -22,11 +22,11 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
}, },
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
shortName := args[0] shortName := args[0]
c := cmdutil.LookupConversation(ctx, shortName) conversation := cmdutil.LookupConversation(ctx, shortName)
messages, err := ctx.Conversations.PathToLeaf(c.SelectedRoot) messages, err := ctx.Store.Messages(conversation)
if err != nil { if err != nil {
return fmt.Errorf("Could not retrieve messages for conversation: %s", c.Title) return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
} }
offset, _ := cmd.Flags().GetInt("offset") offset, _ := cmd.Flags().GetInt("offset")
@ -39,7 +39,21 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
} }
desiredIdx := len(messages) - 1 - offset desiredIdx := len(messages) - 1 - offset
toEdit := messages[desiredIdx]
// walk backwards through the conversation deleting messages until and
// including the last user message
toRemove := []model.Message{}
var toEdit *model.Message
for i := len(messages) - 1; i >= 0; i-- {
if i == desiredIdx {
toEdit = &messages[i]
}
toRemove = append(toRemove, messages[i])
messages = messages[:i]
if toEdit != nil {
break
}
}
newContents := inputFromArgsOrEditor(args[1:], "# Save when finished editing\n", toEdit.Content) newContents := inputFromArgsOrEditor(args[1:], "# Save when finished editing\n", toEdit.Content)
switch newContents { switch newContents {
@ -49,51 +63,38 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
return fmt.Errorf("No message was provided.") return fmt.Errorf("No message was provided.")
} }
toEdit.Content = newContents
role, _ := cmd.Flags().GetString("role") role, _ := cmd.Flags().GetString("role")
if role != "" { if role == "" {
if role != string(api.MessageRoleUser) && role != string(api.MessageRoleAssistant) { role = string(toEdit.Role)
} else if role != string(model.MessageRoleUser) && role != string(model.MessageRoleAssistant) {
return fmt.Errorf("Invalid role specified. Please use 'user' or 'assistant'.") return fmt.Errorf("Invalid role specified. Please use 'user' or 'assistant'.")
} }
toEdit.Role = api.MessageRole(role)
}
// Update the message in-place for _, message := range toRemove {
inplace, _ := cmd.Flags().GetBool("in-place") err = ctx.Store.DeleteMessage(&message)
if inplace {
return ctx.Conversations.UpdateMessage(&toEdit)
}
// Otherwise, create a branch for the edited message
message, _, err := ctx.Conversations.CloneBranch(toEdit)
if err != nil { if err != nil {
return err lmcli.Warn("Could not delete message: %v\n", err)
}
} }
if desiredIdx > 0 { cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{
// update selected reply ConversationID: conversation.ID,
messages[desiredIdx-1].SelectedReply = message Role: model.MessageRole(role),
err = ctx.Conversations.UpdateMessage(&messages[desiredIdx-1]) Content: newContents,
} else { })
// update selected root return nil
c.SelectedRoot = message
err = ctx.Conversations.UpdateConversation(c)
}
return err
}, },
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp compMode := cobra.ShellCompDirectiveNoFileComp
if len(args) != 0 { if len(args) != 0 {
return nil, compMode return nil, compMode
} }
return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
}, },
} }
cmd.Flags().BoolP("in-place", "i", true, "Edit the message in-place, rather than creating a branch")
cmd.Flags().Int("offset", 1, "Offset from the last message to edit") cmd.Flags().Int("offset", 1, "Offset from the last message to edit")
cmd.Flags().StringP("role", "r", "", "Change the role of the edited message (user or assistant)") cmd.Flags().StringP("role", "r", "", "Role of the edited message (user or assistant)")
return cmd return cmd
} }

View File

@ -2,6 +2,7 @@ package cmd
import ( import (
"fmt" "fmt"
"slices"
"time" "time"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
@ -20,9 +21,9 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
Short: "List conversations", Short: "List conversations",
Long: `List conversations in order of recent activity`, Long: `List conversations in order of recent activity`,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
list, err := ctx.Conversations.LoadConversationList() conversations, err := ctx.Store.Conversations()
if err != nil { if err != nil {
return fmt.Errorf("Could not load conversations: %v", err) return fmt.Errorf("Could not fetch conversations: %v", err)
} }
type Category struct { type Category struct {
@ -57,12 +58,17 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
all, _ := cmd.Flags().GetBool("all") all, _ := cmd.Flags().GetBool("all")
for _, item := range list.Items { for _, conversation := range conversations {
age := now.Sub(item.LastMessageAt) lastMessage, err := ctx.Store.LastMessage(&conversation)
if lastMessage == nil || err != nil {
continue
}
messageAge := now.Sub(lastMessage.CreatedAt)
var category string var category string
for _, c := range categories { for _, c := range categories {
if age < c.cutoff { if messageAge < c.cutoff {
category = c.name category = c.name
break break
} }
@ -70,14 +76,14 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
formatted := fmt.Sprintf( formatted := fmt.Sprintf(
"%s - %s - %s", "%s - %s - %s",
item.ShortName, conversation.ShortName.String,
util.HumanTimeElapsedSince(age), util.HumanTimeElapsedSince(messageAge),
item.Title, conversation.Title,
) )
categorized[category] = append( categorized[category] = append(
categorized[category], categorized[category],
ConversationLine{age, formatted}, ConversationLine{messageAge, formatted},
) )
} }
@ -90,10 +96,14 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
continue continue
} }
slices.SortFunc(conversationLines, func(a, b ConversationLine) int {
return int(a.timeSinceReply - b.timeSinceReply)
})
fmt.Printf("%s:\n", category.name) fmt.Printf("%s:\n", category.name)
for _, conv := range conversationLines { for _, conv := range conversationLines {
if conversationsPrinted >= count && !all { if conversationsPrinted >= count && !all {
fmt.Printf("%d remaining conversation(s), use --all to view.\n", list.Total-conversationsPrinted) fmt.Printf("%d remaining message(s), use --all to view.\n", len(conversations)-conversationsPrinted)
break outer break outer
} }
@ -105,8 +115,8 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
}, },
} }
cmd.Flags().BoolP("all", "a", false, "Show all conversations") cmd.Flags().Bool("all", false, "Show all conversations")
cmd.Flags().IntP("count", "c", LS_COUNT, "How many conversations to show") cmd.Flags().Int("count", LS_COUNT, "How many conversations to show")
return cmd return cmd
} }

View File

@ -3,10 +3,9 @@ package cmd
import ( import (
"fmt" "fmt"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -16,42 +15,46 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
Short: "Start a new conversation", Short: "Start a new conversation",
Long: `Start a new conversation with the Large Language Model.`, Long: `Start a new conversation with the Large Language Model.`,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
err := validateGenerationFlags(ctx, cmd) messageContents := inputFromArgsOrEditor(args, "# What would you like to say?\n", "")
if err != nil { if messageContents == "" {
return err
}
input := inputFromArgsOrEditor(args, "# Start a new conversation below\n", "")
if input == "" {
return fmt.Errorf("No message was provided.") return fmt.Errorf("No message was provided.")
} }
messages := []conversation.Message{{ conversation := &model.Conversation{}
Role: api.MessageRoleUser, err := ctx.Store.SaveConversation(conversation)
Content: input,
}}
conversation, messages, err := ctx.Conversations.StartConversation(messages...)
if err != nil { if err != nil {
return fmt.Errorf("Could not start a new conversation: %v", err) return fmt.Errorf("Could not save new conversation: %v", err)
} }
cmdutil.HandleReply(ctx, &messages[len(messages)-1], true) messages := []model.Message{
{
ConversationID: conversation.ID,
Role: model.MessageRoleSystem,
Content: getSystemPrompt(ctx),
},
{
ConversationID: conversation.ID,
Role: model.MessageRoleUser,
Content: messageContents,
},
}
title, err := cmdutil.GenerateTitle(ctx, messages) cmdutil.HandleConversationReply(ctx, conversation, true, messages...)
title, err := cmdutil.GenerateTitle(ctx, conversation)
if err != nil { if err != nil {
lmcli.Warn("Could not generate title for conversation %s: %v\n", conversation.ShortName.String, err) lmcli.Warn("Could not generate title for conversation: %v\n", err)
} }
conversation.Title = title conversation.Title = title
err = ctx.Conversations.UpdateConversation(conversation)
err = ctx.Store.SaveConversation(conversation)
if err != nil { if err != nil {
lmcli.Warn("Could not save conversation title: %v\n", err) lmcli.Warn("Could not save conversation after generating title: %v\n", err)
} }
return nil return nil
}, },
} }
applyGenerationFlags(ctx, cmd)
return cmd return cmd
} }

View File

@ -3,10 +3,9 @@ package cmd
import ( import (
"fmt" "fmt"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -16,29 +15,28 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
Short: "Do a one-shot prompt", Short: "Do a one-shot prompt",
Long: `Prompt the Large Language Model and get a response.`, Long: `Prompt the Large Language Model and get a response.`,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
err := validateGenerationFlags(ctx, cmd) message := inputFromArgsOrEditor(args, "# What would you like to say?\n", "")
if err != nil { if message == "" {
return err
}
input := inputFromArgsOrEditor(args, "# Write your prompt below\n", "")
if input == "" {
return fmt.Errorf("No message was provided.") return fmt.Errorf("No message was provided.")
} }
messages := []conversation.Message{{ messages := []model.Message{
Role: api.MessageRoleUser, {
Content: input, Role: model.MessageRoleSystem,
}} Content: getSystemPrompt(ctx),
},
{
Role: model.MessageRoleUser,
Content: message,
},
}
_, err = cmdutil.Prompt(ctx, messages, nil) _, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
if err != nil { if err != nil {
return fmt.Errorf("Error fetching LLM response: %v", err) return fmt.Errorf("Error fetching LLM response: %v", err)
} }
return nil return nil
}, },
} }
applyGenerationFlags(ctx, cmd)
return cmd return cmd
} }

View File

@ -5,8 +5,8 @@ import (
"strings" "strings"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -23,14 +23,14 @@ func RemoveCmd(ctx *lmcli.Context) *cobra.Command {
return nil return nil
}, },
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
var toRemove []*conversation.Conversation var toRemove []*model.Conversation
for _, shortName := range args { for _, shortName := range args {
conversation := cmdutil.LookupConversation(ctx, shortName) conversation := cmdutil.LookupConversation(ctx, shortName)
toRemove = append(toRemove, conversation) toRemove = append(toRemove, conversation)
} }
var errors []error var errors []error
for _, c := range toRemove { for _, c := range toRemove {
err := ctx.Conversations.DeleteConversation(c) err := ctx.Store.DeleteConversation(c)
if err != nil { if err != nil {
errors = append(errors, fmt.Errorf("Could not remove conversation %s: %v", c.ShortName.String, err)) errors = append(errors, fmt.Errorf("Could not remove conversation %s: %v", c.ShortName.String, err))
} }
@ -44,7 +44,7 @@ func RemoveCmd(ctx *lmcli.Context) *cobra.Command {
compMode := cobra.ShellCompDirectiveNoFileComp compMode := cobra.ShellCompDirectiveNoFileComp
var completions []string var completions []string
outer: outer:
for _, completion := range ctx.Conversations.ConversationShortNameCompletions(toComplete) { for _, completion := range ctx.Store.ConversationShortNameCompletions(toComplete) {
parts := strings.Split(completion, "\t") parts := strings.Split(completion, "\t")
for _, arg := range args { for _, arg := range args {
if parts[0] == arg { if parts[0] == arg {

View File

@ -24,17 +24,12 @@ func RenameCmd(ctx *lmcli.Context) *cobra.Command {
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
shortName := args[0] shortName := args[0]
conversation := cmdutil.LookupConversation(ctx, shortName) conversation := cmdutil.LookupConversation(ctx, shortName)
var err error var err error
var title string
generate, _ := cmd.Flags().GetBool("generate") generate, _ := cmd.Flags().GetBool("generate")
var title string
if generate { if generate {
messages, err := ctx.Conversations.PathToLeaf(conversation.SelectedRoot) title, err = cmdutil.GenerateTitle(ctx, conversation)
if err != nil {
return fmt.Errorf("Could not retrieve conversation messages: %v", err)
}
title, err = cmdutil.GenerateTitle(ctx, messages)
if err != nil { if err != nil {
return fmt.Errorf("Could not generate conversation title: %v", err) return fmt.Errorf("Could not generate conversation title: %v", err)
} }
@ -46,9 +41,9 @@ func RenameCmd(ctx *lmcli.Context) *cobra.Command {
} }
conversation.Title = title conversation.Title = title
err = ctx.Conversations.UpdateConversation(conversation) err = ctx.Store.SaveConversation(conversation)
if err != nil { if err != nil {
lmcli.Warn("Could not update conversation title: %v\n", err) lmcli.Warn("Could not save conversation with new title: %v\n", err)
} }
return nil return nil
}, },
@ -57,7 +52,7 @@ func RenameCmd(ctx *lmcli.Context) *cobra.Command {
if len(args) != 0 { if len(args) != 0 {
return nil, compMode return nil, compMode
} }
return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
}, },
} }

View File

@ -3,10 +3,9 @@ package cmd
import ( import (
"fmt" "fmt"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -23,21 +22,17 @@ func ReplyCmd(ctx *lmcli.Context) *cobra.Command {
return nil return nil
}, },
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
err := validateGenerationFlags(ctx, cmd)
if err != nil {
return err
}
shortName := args[0] shortName := args[0]
c := cmdutil.LookupConversation(ctx, shortName) conversation := cmdutil.LookupConversation(ctx, shortName)
reply := inputFromArgsOrEditor(args[1:], "# How would you like to reply?\n", "") reply := inputFromArgsOrEditor(args[1:], "# How would you like to reply?\n", "")
if reply == "" { if reply == "" {
return fmt.Errorf("No reply was provided.") return fmt.Errorf("No reply was provided.")
} }
cmdutil.HandleConversationReply(ctx, c, true, conversation.Message{ cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{
Role: api.MessageRoleUser, ConversationID: conversation.ID,
Role: model.MessageRoleUser,
Content: reply, Content: reply,
}) })
return nil return nil
@ -47,10 +42,8 @@ func ReplyCmd(ctx *lmcli.Context) *cobra.Command {
if len(args) != 0 { if len(args) != 0 {
return nil, compMode return nil, compMode
} }
return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
}, },
} }
applyGenerationFlags(ctx, cmd)
return cmd return cmd
} }

View File

@ -3,9 +3,9 @@ package cmd
import ( import (
"fmt" "fmt"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -13,7 +13,7 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "retry <conversation>", Use: "retry <conversation>",
Short: "Retry the last user reply in a conversation", Short: "Retry the last user reply in a conversation",
Long: `Prompt the conversation from the last user response.`, Long: `Re-prompt the conversation up to the last user response. Can be used to regenerate the last assistant reply, or simply generate one if an error occurred.`,
Args: func(cmd *cobra.Command, args []string) error { Args: func(cmd *cobra.Command, args []string) error {
argCount := 1 argCount := 1
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil { if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
@ -22,44 +22,28 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
return nil return nil
}, },
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
err := validateGenerationFlags(ctx, cmd)
if err != nil {
return err
}
shortName := args[0] shortName := args[0]
c := cmdutil.LookupConversation(ctx, shortName) conversation := cmdutil.LookupConversation(ctx, shortName)
// Load the complete thread from the root message messages, err := ctx.Store.Messages(conversation)
messages, err := ctx.Conversations.PathToLeaf(c.SelectedRoot)
if err != nil { if err != nil {
return fmt.Errorf("Could not retrieve messages for conversation: %s", c.Title) return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
} }
offset, _ := cmd.Flags().GetInt("offset") // walk backwards through the conversation and delete messages, break
if offset < 0 { // when we find the latest user response
offset = -offset for i := len(messages) - 1; i >= 0; i-- {
if messages[i].Role == model.MessageRoleUser {
break
} }
if offset > len(messages)-1 { err = ctx.Store.DeleteMessage(&messages[i])
return fmt.Errorf("Offset %d is before the start of the conversation.", offset) if err != nil {
lmcli.Warn("Could not delete previous reply: %v\n", err)
}
} }
retryFromIdx := len(messages) - 1 - offset cmdutil.HandleConversationReply(ctx, conversation, true)
// decrease retryFromIdx until we hit a user message
for retryFromIdx >= 0 && messages[retryFromIdx].Role != api.MessageRoleUser {
retryFromIdx--
}
if messages[retryFromIdx].Role != api.MessageRoleUser {
return fmt.Errorf("No user messages to retry")
}
fmt.Printf("Idx: %d Message: %v\n", retryFromIdx, messages[retryFromIdx])
// Start a new branch at the last user message
cmdutil.HandleReply(ctx, &messages[retryFromIdx], true)
return nil return nil
}, },
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
@ -67,12 +51,8 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
if len(args) != 0 { if len(args) != 0 {
return nil, compMode return nil, compMode
} }
return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
}, },
} }
cmd.Flags().Int("offset", 0, "Offset from the last message to retry from.")
applyGenerationFlags(ctx, cmd)
return cmd return cmd
} }

View File

@ -2,59 +2,42 @@ package util
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"os" "os"
"strings" "strings"
"time" "time"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/provider"
"git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/util" "git.mlow.ca/mlow/lmcli/pkg/util"
"github.com/charmbracelet/lipgloss" "github.com/charmbracelet/lipgloss"
) )
// Prompt prompts the configured the configured model and streams the response // fetchAndShowCompletion prompts the LLM with the given messages and streams
// to stdout. Returns all model reply messages. // the response to stdout. Returns all model reply messages.
func Prompt(ctx *lmcli.Context, messages []conversation.Message, callback func(conversation.Message)) (*api.Message, error) { func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
m, _, p, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "") content := make(chan string) // receives the reponse from LLM
if err != nil {
return nil, err
}
params := provider.RequestParameters{
Model: m,
MaxTokens: *ctx.Config.Defaults.MaxTokens,
Temperature: *ctx.Config.Defaults.Temperature,
}
system := ctx.DefaultSystemPrompt()
agent := ctx.GetAgent(ctx.Config.Defaults.Agent)
if agent != nil {
if agent.SystemPrompt != "" {
system = agent.SystemPrompt
}
params.Toolbox = agent.Toolbox
}
if system != "" {
messages = conversation.ApplySystemPrompt(messages, system, false)
}
content := make(chan provider.Chunk)
defer close(content) defer close(content)
// render the content received over the channel // render all content received over the channel
go ShowDelayedContent(content) go ShowDelayedContent(content)
reply, err := p.CreateChatCompletionStream( completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Defaults.Model)
context.Background(), params, conversation.MessagesToAPI(messages), content, if err != nil {
) return "", err
}
if reply.Content != "" { requestParams := model.RequestParameters{
Model: *ctx.Config.Defaults.Model,
MaxTokens: *ctx.Config.Defaults.MaxTokens,
Temperature: *ctx.Config.Defaults.Temperature,
ToolBag: ctx.EnabledTools,
}
response, err := completionProvider.CreateChatCompletionStream(
context.Background(), requestParams, messages, callback, content,
)
if response != "" {
// there was some content, so break to a new line after it // there was some content, so break to a new line after it
fmt.Println() fmt.Println()
@ -63,99 +46,85 @@ func Prompt(ctx *lmcli.Context, messages []conversation.Message, callback func(c
err = nil err = nil
} }
} }
return reply, err return response, nil
} }
// lookupConversation either returns the conversation found by the // lookupConversation either returns the conversation found by the
// short name or exits the program // short name or exits the program
func LookupConversation(ctx *lmcli.Context, shortName string) *conversation.Conversation { func LookupConversation(ctx *lmcli.Context, shortName string) *model.Conversation {
c, err := ctx.Conversations.FindConversationByShortName(shortName) c, err := ctx.Store.ConversationByShortName(shortName)
if err != nil { if err != nil {
lmcli.Fatal("Could not lookup conversation: %v\n", err) lmcli.Fatal("Could not lookup conversation: %v\n", err)
} }
if c.ID == 0 { if c.ID == 0 {
lmcli.Fatal("Conversation not found: %s\n", shortName) lmcli.Fatal("Conversation not found with short name: %s\n", shortName)
} }
return c return c
} }
func LookupConversationE(ctx *lmcli.Context, shortName string) (*conversation.Conversation, error) { func LookupConversationE(ctx *lmcli.Context, shortName string) (*model.Conversation, error) {
c, err := ctx.Conversations.FindConversationByShortName(shortName) c, err := ctx.Store.ConversationByShortName(shortName)
if err != nil { if err != nil {
return nil, fmt.Errorf("Could not lookup conversation: %v", err) return nil, fmt.Errorf("Could not lookup conversation: %v", err)
} }
if c.ID == 0 { if c.ID == 0 {
return nil, fmt.Errorf("Conversation not found: %s", shortName) return nil, fmt.Errorf("Conversation not found with short name: %s", shortName)
} }
return c, nil return c, nil
} }
func HandleConversationReply(ctx *lmcli.Context, c *conversation.Conversation, persist bool, toSend ...conversation.Message) {
messages, err := ctx.Conversations.PathToLeaf(c.SelectedRoot)
if err != nil {
lmcli.Fatal("Could not load messages: %v\n", err)
}
HandleReply(ctx, &messages[len(messages)-1], persist, toSend...)
}
// handleConversationReply handles sending messages to an existing // handleConversationReply handles sending messages to an existing
// conversation, optionally persisting both the sent replies and responses. // conversation, optionally persisting both the sent replies and responses.
func HandleReply(ctx *lmcli.Context, to *conversation.Message, persist bool, messages ...conversation.Message) { func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist bool, toSend ...model.Message) {
if to == nil { existing, err := ctx.Store.Messages(c)
lmcli.Fatal("Can't prompt from an empty message.")
}
existing, err := ctx.Conversations.PathToRoot(to)
if err != nil { if err != nil {
lmcli.Fatal("Could not load messages: %v\n", err) lmcli.Fatal("Could not retrieve messages for conversation: %s\n", c.Title)
} }
RenderConversation(ctx, append(existing, messages...), true) if persist {
for _, message := range toSend {
var savedReplies []conversation.Message err = ctx.Store.SaveMessage(&message)
if persist && len(messages) > 0 {
savedReplies, err = ctx.Conversations.Reply(to, messages...)
if err != nil { if err != nil {
lmcli.Warn("Could not save messages: %v\n", err) lmcli.Warn("Could not save %s message: %v\n", message.Role, err)
} }
} }
}
allMessages := append(existing, toSend...)
RenderConversation(ctx, allMessages, true)
// render a message header with no contents // render a message header with no contents
RenderMessage(ctx, (&conversation.Message{Role: api.MessageRoleAssistant})) RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))
var lastSavedMessage *conversation.Message replyCallback := func(reply model.Message) {
lastSavedMessage = to
if len(savedReplies) > 0 {
lastSavedMessage = &savedReplies[len(savedReplies)-1]
}
replyCallback := func(reply conversation.Message) {
if !persist { if !persist {
return return
} }
savedReplies, err = ctx.Conversations.Reply(lastSavedMessage, reply)
reply.ConversationID = c.ID
err = ctx.Store.SaveMessage(&reply)
if err != nil { if err != nil {
lmcli.Warn("Could not save reply: %v\n", err) lmcli.Warn("Could not save reply: %v\n", err)
} }
lastSavedMessage = &savedReplies[0]
} }
_, err = Prompt(ctx, append(existing, messages...), replyCallback) _, err = FetchAndShowCompletion(ctx, allMessages, replyCallback)
if err != nil { if err != nil {
lmcli.Fatal("Error fetching LLM response: %v\n", err) lmcli.Fatal("Error fetching LLM response: %v\n", err)
} }
} }
func FormatForExternalPrompt(messages []conversation.Message, system bool) string { func FormatForExternalPrompt(messages []model.Message, system bool) string {
sb := strings.Builder{} sb := strings.Builder{}
for _, message := range messages { for _, message := range messages {
if message.Content == "" { if message.Content == "" {
continue continue
} }
switch message.Role { switch message.Role {
case api.MessageRoleAssistant, api.MessageRoleToolCall: case model.MessageRoleAssistant, model.MessageRoleToolCall:
sb.WriteString("Assistant:\n\n") sb.WriteString("Assistant:\n\n")
case api.MessageRoleUser: case model.MessageRoleUser:
sb.WriteString("User:\n\n") sb.WriteString("User:\n\n")
default: default:
continue continue
@ -165,76 +134,60 @@ func FormatForExternalPrompt(messages []conversation.Message, system bool) strin
return sb.String() return sb.String()
} }
func GenerateTitle(ctx *lmcli.Context, messages []conversation.Message) (string, error) { func GenerateTitle(ctx *lmcli.Context, c *model.Conversation) (string, error) {
const systemPrompt = `You will be shown a conversation between a user and an AI assistant. Your task is to generate a short title (8 words or less) for the provided conversation that reflects the conversation's topic. Your response is expected to be in JSON in the format shown below. messages, err := ctx.Store.Messages(c)
if err != nil {
return "", err
}
const prompt = `Above is an excerpt from a conversation between a user and AI assistant. Please reply with a short title (no more than 8 words) that reflects the topic of the conversation, read from the user's perspective.
Example conversation: Example conversation:
[{"role": "user", "content": "Can you help me with my math homework?"},{"role": "assistant", "content": "Sure, what topic are you struggling with?"}] """
User:
Hello!
Assistant:
Hello! How may I assist you?
"""
Example response: Example response:
{"title": "Help with math homework"} """
Title: A brief introduction
"""
` `
type msg struct { conversation := FormatForExternalPrompt(messages, false)
Role string
Content string generateRequest := []model.Message{
{
Role: model.MessageRoleUser,
Content: fmt.Sprintf("\"\"\"\n%s\n\"\"\"\n\n%s", conversation, prompt),
},
} }
var msgs []msg completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Conversations.TitleGenerationModel)
for _, m := range messages {
switch m.Role {
case api.MessageRoleAssistant, api.MessageRoleUser:
msgs = append(msgs, msg{string(m.Role), m.Content})
}
}
// Serialize the conversation to JSON
jsonBytes, err := json.Marshal(msgs)
if err != nil { if err != nil {
return "", err return "", err
} }
generateRequest := []conversation.Message{ requestParams := model.RequestParameters{
{ Model: *ctx.Config.Conversations.TitleGenerationModel,
Role: api.MessageRoleSystem,
Content: systemPrompt,
},
{
Role: api.MessageRoleUser,
Content: string(jsonBytes),
},
}
m, _, p, err := ctx.GetModelProvider(
*ctx.Config.Conversations.TitleGenerationModel, "",
)
if err != nil {
return "", err
}
requestParams := provider.RequestParameters{
Model: m,
MaxTokens: 25, MaxTokens: 25,
} }
response, err := p.CreateChatCompletion( response, err := completionProvider.CreateChatCompletion(context.Background(), requestParams, generateRequest, nil)
context.Background(), requestParams, conversation.MessagesToAPI(generateRequest),
)
if err != nil { if err != nil {
return "", err return "", err
} }
// Parse the JSON response response = strings.TrimPrefix(response, "Title: ")
var jsonResponse struct { response = strings.Trim(response, "\"")
Title string `json:"title"`
}
err = json.Unmarshal([]byte(response.Content), &jsonResponse)
if err != nil {
return "", err
}
return jsonResponse.Title, nil return response, nil
} }
// ShowWaitAnimation prints an animated ellipses to stdout until something is // ShowWaitAnimation prints an animated ellipses to stdout until something is
@ -274,7 +227,7 @@ func ShowWaitAnimation(signal chan any) {
// chunked) content is received on the channel, the waiting animation is // chunked) content is received on the channel, the waiting animation is
// replaced by the content. // replaced by the content.
// Blocks until the channel is closed. // Blocks until the channel is closed.
func ShowDelayedContent(content <-chan provider.Chunk) { func ShowDelayedContent(content <-chan string) {
waitSignal := make(chan any) waitSignal := make(chan any)
go ShowWaitAnimation(waitSignal) go ShowWaitAnimation(waitSignal)
@ -287,14 +240,14 @@ func ShowDelayedContent(content <-chan provider.Chunk) {
<-waitSignal <-waitSignal
firstChunk = false firstChunk = false
} }
fmt.Print(chunk.Content) fmt.Print(chunk)
} }
} }
// RenderConversation renders the given messages to TTY, with optional space // RenderConversation renders the given messages to TTY, with optional space
// for a subsequent message. spaceForResponse controls how many '\n' characters // for a subsequent message. spaceForResponse controls how many '\n' characters
// are printed immediately after the final message (1 if false, 2 if true) // are printed immediately after the final message (1 if false, 2 if true)
func RenderConversation(ctx *lmcli.Context, messages []conversation.Message, spaceForResponse bool) { func RenderConversation(ctx *lmcli.Context, messages []model.Message, spaceForResponse bool) {
l := len(messages) l := len(messages)
for i, message := range messages { for i, message := range messages {
RenderMessage(ctx, &message) RenderMessage(ctx, &message)
@ -305,7 +258,7 @@ func RenderConversation(ctx *lmcli.Context, messages []conversation.Message, spa
} }
} }
func RenderMessage(ctx *lmcli.Context, m *conversation.Message) { func RenderMessage(ctx *lmcli.Context, m *model.Message) {
var messageAge string var messageAge string
if m.CreatedAt.IsZero() { if m.CreatedAt.IsZero() {
messageAge = "now" messageAge = "now"
@ -317,11 +270,11 @@ func RenderMessage(ctx *lmcli.Context, m *conversation.Message) {
headerStyle := lipgloss.NewStyle().Bold(true) headerStyle := lipgloss.NewStyle().Bold(true)
switch m.Role { switch m.Role {
case api.MessageRoleSystem: case model.MessageRoleSystem:
headerStyle = headerStyle.Foreground(lipgloss.Color("9")) // bright red headerStyle = headerStyle.Foreground(lipgloss.Color("9")) // bright red
case api.MessageRoleUser: case model.MessageRoleUser:
headerStyle = headerStyle.Foreground(lipgloss.Color("10")) // bright green headerStyle = headerStyle.Foreground(lipgloss.Color("10")) // bright green
case api.MessageRoleAssistant: case model.MessageRoleAssistant:
headerStyle = headerStyle.Foreground(lipgloss.Color("12")) // bright blue headerStyle = headerStyle.Foreground(lipgloss.Color("12")) // bright blue
} }

View File

@ -24,9 +24,9 @@ func ViewCmd(ctx *lmcli.Context) *cobra.Command {
shortName := args[0] shortName := args[0]
conversation := cmdutil.LookupConversation(ctx, shortName) conversation := cmdutil.LookupConversation(ctx, shortName)
messages, err := ctx.Conversations.PathToLeaf(conversation.SelectedRoot) messages, err := ctx.Store.Messages(conversation)
if err != nil { if err != nil {
return fmt.Errorf("Could not retrieve messages for conversation %s: %v", conversation.ShortName.String, err) return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
} }
cmdutil.RenderConversation(ctx, messages, false) cmdutil.RenderConversation(ctx, messages, false)
@ -37,7 +37,7 @@ func ViewCmd(ctx *lmcli.Context) *cobra.Command {
if len(args) != 0 { if len(args) != 0 {
return nil, compMode return nil, compMode
} }
return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
}, },
} }

View File

@ -1,99 +0,0 @@
package conversation
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
"time"
"git.mlow.ca/mlow/lmcli/pkg/api"
)
type Conversation struct {
ID uint `gorm:"primaryKey"`
ShortName sql.NullString
Title string
SelectedRootID *uint
SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"`
RootMessages []Message `gorm:"-:all"`
LastMessageAt time.Time
}
type MessageMeta struct {
GenerationProvider *string `json:"generation_provider,omitempty"`
GenerationModel *string `json:"generation_model,omitempty"`
}
type Message struct {
ID uint `gorm:"primaryKey"`
CreatedAt time.Time
Metadata MessageMeta
ConversationID *uint `gorm:"index"`
Conversation *Conversation `gorm:"foreignKey:ConversationID"`
ParentID *uint
Parent *Message `gorm:"foreignKey:ParentID"`
Replies []Message `gorm:"foreignKey:ParentID"`
SelectedReplyID *uint
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
Role api.MessageRole
Content string
ToolCalls ToolCalls // a json array of tool calls (from the model)
ToolResults ToolResults // a json array of tool results
}
func (m *MessageMeta) Scan(value interface{}) error {
return json.Unmarshal(value.([]byte), m)
}
func (m MessageMeta) Value() (driver.Value, error) {
return json.Marshal(m)
}
type ToolCalls []api.ToolCall
func (tc *ToolCalls) Scan(value any) (err error) {
s := value.(string)
if value == nil || s == "" {
*tc = nil
return
}
err = json.Unmarshal([]byte(s), tc)
return
}
func (tc ToolCalls) Value() (driver.Value, error) {
if len(tc) == 0 {
return "", nil
}
jsonBytes, err := json.Marshal(tc)
if err != nil {
return "", fmt.Errorf("Could not marshal ToolCalls to JSON: %v\n", err)
}
return string(jsonBytes), nil
}
type ToolResults []api.ToolResult
func (tr *ToolResults) Scan(value any) (err error) {
s := value.(string)
if value == nil || s == "" {
*tr = nil
return
}
err = json.Unmarshal([]byte(s), tr)
return
}
func (tr ToolResults) Value() (driver.Value, error) {
if len(tr) == 0 {
return "", nil
}
jsonBytes, err := json.Marshal([]api.ToolResult(tr))
if err != nil {
return "", fmt.Errorf("Could not marshal ToolResults to JSON: %v\n", err)
}
return string(jsonBytes), nil
}

View File

@ -1,520 +0,0 @@
package conversation
import (
"database/sql"
"errors"
"fmt"
"slices"
"strings"
"time"
sqids "github.com/sqids/sqids-go"
"gorm.io/gorm"
)
// Repo exposes low-level message and conversation management. See
// Service for high-level helpers
type Repo interface {
LoadConversationList() (ConversationList, error)
FindConversationByShortName(shortName string) (*Conversation, error)
ConversationShortNameCompletions(search string) []string
GetConversationByID(int uint) (*Conversation, error)
GetRootMessages(conversationID uint) ([]Message, error)
CreateConversation(title string) (*Conversation, error)
UpdateConversation(*Conversation) error
DeleteConversation(*Conversation) error
DeleteConversationById(id uint) error
GetMessageByID(messageID uint) (*Message, error)
SaveMessage(message Message) (*Message, error)
UpdateMessage(message *Message) error
DeleteMessage(message *Message, prune bool) error
CloneBranch(toClone Message) (*Message, uint, error)
Reply(to *Message, messages ...Message) ([]Message, error)
PathToRoot(message *Message) ([]Message, error)
PathToLeaf(message *Message) ([]Message, error)
// Retrieves and return the "selected thread" of the conversation.
// The "selected thread" of the conversation is a chain of messages
// starting from the Conversation's SelectedRoot Message, following each
// Message's SelectedReply until the tail Message is reached.
GetSelectedThread(*Conversation) ([]Message, error)
// Start a new conversation with the given messages
StartConversation(messages ...Message) (*Conversation, []Message, error)
CloneConversation(toClone Conversation) (*Conversation, uint, error)
}
type repo struct {
db *gorm.DB
sqids *sqids.Sqids
}
func NewRepo(db *gorm.DB) (Repo, error) {
models := []any{
&Conversation{},
&Message{},
}
for _, x := range models {
err := db.AutoMigrate(x)
if err != nil {
return nil, fmt.Errorf("Could not perform database migrations: %v", err)
}
}
_sqids, _ := sqids.New(sqids.Options{MinLength: 4})
return &repo{db, _sqids}, nil
}
type ConversationListItem struct {
ID uint
ShortName string
Title string
LastMessageAt time.Time
}
type ConversationList struct {
Total int
Items []ConversationListItem
}
// LoadConversationList loads existing conversations, ordered by the date
// of their latest message, from most recent to oldest.
func (s *repo) LoadConversationList() (ConversationList, error) {
list := ConversationList{}
var convos []Conversation
err := s.db.Order("last_message_at DESC").Find(&convos).Error
if err != nil {
return list, err
}
for _, c := range convos {
list.Items = append(list.Items, ConversationListItem{
ID: c.ID,
ShortName: c.ShortName.String,
Title: c.Title,
LastMessageAt: c.LastMessageAt,
})
}
list.Total = len(list.Items)
return list, nil
}
func (s *repo) FindConversationByShortName(shortName string) (*Conversation, error) {
if shortName == "" {
return nil, errors.New("shortName is empty")
}
var conversation Conversation
err := s.db.Preload("SelectedRoot").Where("short_name = ?", shortName).Find(&conversation).Error
return &conversation, err
}
func (s *repo) ConversationShortNameCompletions(shortName string) []string {
var conversations []Conversation
// ignore error for completions
s.db.Find(&conversations)
completions := make([]string, 0, len(conversations))
for _, conversation := range conversations {
if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) {
completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title))
}
}
return completions
}
func (s *repo) GetConversationByID(id uint) (*Conversation, error) {
var conversation Conversation
err := s.db.Preload("SelectedRoot").Where("id = ?", id).Find(&conversation).Error
if err != nil {
return nil, fmt.Errorf("Cannot get conversation %d: %v", id, err)
}
rootMessages, err := s.GetRootMessages(id)
if err != nil {
return nil, fmt.Errorf("Could not load conversation's root messages %d: %v", id, err)
}
conversation.RootMessages = rootMessages
return &conversation, nil
}
func (s *repo) CreateConversation(title string) (*Conversation, error) {
// Create the new conversation
c := &Conversation{Title: title}
err := s.db.Create(c).Error
if err != nil {
return nil, err
}
// Generate and save its "short name"
shortName, _ := s.sqids.Encode([]uint64{uint64(c.ID)})
c.ShortName = sql.NullString{String: shortName, Valid: true}
err = s.db.Updates(c).Error
if err != nil {
return nil, err
}
return c, nil
}
func (s *repo) UpdateConversation(c *Conversation) error {
if c == nil || c.ID == 0 {
return fmt.Errorf("Conversation is nil or invalid (missing ID)")
}
return s.db.Updates(c).Error
}
func (s *repo) DeleteConversation(c *Conversation) error {
if c == nil || c.ID == 0 {
return fmt.Errorf("Conversation is nil or invalid (missing ID)")
}
return s.DeleteConversationById(c.ID)
}
func (s *repo) DeleteConversationById(id uint) error {
if id == 0 {
return fmt.Errorf("Invalid conversation ID: %d", id)
}
err := s.db.Where("conversation_id = ?", id).Delete(&Message{}).Error
if err != nil {
return err
}
return s.db.Where("id = ?", id).Delete(&Conversation{}).Error
}
func (s *repo) SaveMessage(m Message) (*Message, error) {
if m.Conversation == nil {
return nil, fmt.Errorf("Can't save a message without a conversation (this is a bug)")
}
newMessage := m
newMessage.ID = 0
newMessage.CreatedAt = time.Now()
return &newMessage, s.db.Create(&newMessage).Error
}
func (s *repo) UpdateMessage(m *Message) error {
if m == nil || m.ID == 0 {
return fmt.Errorf("Message is nil or invalid (missing ID)")
}
return s.db.Updates(m).Error
}
func (s *repo) DeleteMessage(message *Message, prune bool) error {
return s.db.Delete(&message).Error
}
func (s *repo) GetMessageByID(messageID uint) (*Message, error) {
var message Message
err := s.db.Preload("Parent").Preload("Replies").Preload("SelectedReply").Where("id = ?", messageID).Find(&message).Error
return &message, err
}
// Reply to a message with a series of messages (each followed by the next)
func (s *repo) Reply(to *Message, messages ...Message) ([]Message, error) {
var savedMessages []Message
err := s.db.Transaction(func(tx *gorm.DB) error {
currentParent := to
for i := range messages {
parent := currentParent
message := messages[i]
message.Parent = parent
message.Conversation = parent.Conversation
message.ID = 0
message.CreatedAt = time.Time{}
if err := tx.Create(&message).Error; err != nil {
return err
}
// update parent selected reply
parent.Replies = append(parent.Replies, message)
parent.SelectedReply = &message
if err := tx.Model(parent).Update("selected_reply_id", message.ID).Error; err != nil {
return err
}
savedMessages = append(savedMessages, message)
currentParent = &message
}
return nil
})
if err != nil {
return savedMessages, err
}
to.Conversation.LastMessageAt = savedMessages[len(savedMessages)-1].CreatedAt
err = s.UpdateConversation(to.Conversation)
return savedMessages, err
}
// CloneBranch returns a deep clone of the given message and its replies, returning
// a new message object. The new message will be attached to the same parent as
// the messageToClone
func (s *repo) CloneBranch(messageToClone Message) (*Message, uint, error) {
newMessage := messageToClone
newMessage.ID = 0
newMessage.Replies = nil
newMessage.SelectedReplyID = nil
newMessage.SelectedReply = nil
originalReplies := messageToClone.Replies
if err := s.db.Create(&newMessage).Error; err != nil {
return nil, 0, fmt.Errorf("Could not clone message: %s", err)
}
var replyCount uint = 0
for _, reply := range originalReplies {
replyCount++
newReply := reply
newReply.ConversationID = messageToClone.ConversationID
newReply.ParentID = &newMessage.ID
newReply.Parent = &newMessage
res, c, err := s.CloneBranch(newReply)
if err != nil {
return nil, 0, err
}
newMessage.Replies = append(newMessage.Replies, *res)
replyCount += c
if reply.ID == *messageToClone.SelectedReplyID {
newMessage.SelectedReplyID = &res.ID
if err := s.UpdateMessage(&newMessage); err != nil {
return nil, 0, fmt.Errorf("Could not update parent select reply ID: %v", err)
}
}
}
return &newMessage, replyCount, nil
}
func fetchMessages(db *gorm.DB) ([]Message, error) {
var messages []Message
if err := db.Preload("Conversation").Find(&messages).Error; err != nil {
return nil, fmt.Errorf("Could not fetch messages: %v", err)
}
messageMap := make(map[uint]Message)
for i, message := range messages {
messageMap[messages[i].ID] = message
}
// Create a map to store replies by their parent ID
repliesMap := make(map[uint][]Message)
for i, message := range messages {
if messages[i].ParentID != nil {
repliesMap[*messages[i].ParentID] = append(repliesMap[*messages[i].ParentID], message)
}
}
// Assign replies, parent, and selected reply to each message
for i := range messages {
if replies, exists := repliesMap[messages[i].ID]; exists {
messages[i].Replies = make([]Message, len(replies))
for j, m := range replies {
messages[i].Replies[j] = m
}
}
if messages[i].ParentID != nil {
if parent, exists := messageMap[*messages[i].ParentID]; exists {
messages[i].Parent = &parent
}
}
if messages[i].SelectedReplyID != nil {
if selectedReply, exists := messageMap[*messages[i].SelectedReplyID]; exists {
messages[i].SelectedReply = &selectedReply
}
}
}
return messages, nil
}
func (r repo) GetRootMessages(conversationID uint) ([]Message, error) {
var rootMessages []Message
err := r.db.Where("conversation_id = ? AND parent_id IS NULL", conversationID).Find(&rootMessages).Error
if err != nil {
return nil, fmt.Errorf("Could not retrieve root messages for conversation %d: %v", conversationID, err)
}
return rootMessages, nil
}
func (s *repo) buildPath(message *Message, getNext func(*Message) *uint) ([]Message, error) {
var messages []Message
messages, err := fetchMessages(s.db.Where("conversation_id = ?", message.ConversationID))
if err != nil {
return nil, err
}
// Create a map to store messages by their ID
messageMap := make(map[uint]*Message, len(messages))
for i := range messages {
messageMap[messages[i].ID] = &messages[i]
}
// Construct Replies
repliesMap := make(map[uint][]*Message, len(messages))
for _, m := range messageMap {
if m.ParentID == nil {
continue
}
if p, ok := messageMap[*m.ParentID]; ok {
repliesMap[p.ID] = append(repliesMap[p.ID], m)
}
}
// Add replies to messages
for _, m := range messageMap {
if replies, ok := repliesMap[m.ID]; ok {
m.Replies = make([]Message, len(replies))
for idx, reply := range replies {
m.Replies[idx] = *reply
}
}
}
// Build the path
var path []Message
nextID := &message.ID
for {
current, exists := messageMap[*nextID]
if !exists {
return nil, fmt.Errorf("Message with ID %d not found in conversation", *nextID)
}
path = append(path, *current)
nextID = getNext(current)
if nextID == nil {
break
}
}
return path, nil
}
// PathToRoot traverses the provided message's Parent until reaching the tree
// root and returns a slice of all messages traversed in chronological order
// (starting with the root and ending with the message provided)
func (s *repo) PathToRoot(message *Message) ([]Message, error) {
if message == nil || message.ID <= 0 {
return nil, fmt.Errorf("Message is nil or has invalid ID")
}
path, err := s.buildPath(message, func(m *Message) *uint {
return m.ParentID
})
if err != nil {
return nil, err
}
slices.Reverse(path)
return path, nil
}
// PathToLeaf traverses the provided message's SelectedReply until reaching a
// tree leaf and returns a slice of all messages traversed in chronological
// order (starting with the message provided and ending with the leaf)
func (s *repo) PathToLeaf(message *Message) ([]Message, error) {
if message == nil || message.ID <= 0 {
return nil, fmt.Errorf("Message is nil or has invalid ID")
}
return s.buildPath(message, func(m *Message) *uint {
return m.SelectedReplyID
})
}
func (s *repo) StartConversation(messages ...Message) (*Conversation, []Message, error) {
if len(messages) == 0 {
return nil, nil, fmt.Errorf("Must provide at least 1 message")
}
// Create new conversation
conversation, err := s.CreateConversation("")
if err != nil {
return nil, nil, err
}
messages[0].Conversation = conversation
// Create first message
firstMessage, err := s.SaveMessage(messages[0])
if err != nil {
return nil, nil, err
}
messages[0] = *firstMessage
// Update conversation's selected root message
conversation.RootMessages = []Message{messages[0]}
conversation.SelectedRoot = &messages[0]
conversation.LastMessageAt = messages[0].CreatedAt
// Add additional replies to conversation
if len(messages) > 1 {
newMessages, err := s.Reply(&messages[0], messages[1:]...)
if err != nil {
return nil, nil, err
}
messages = append([]Message{messages[0]}, newMessages...)
conversation.LastMessageAt = messages[len(messages)-1].CreatedAt
}
err = s.UpdateConversation(conversation)
return conversation, messages, err
}
// CloneConversation clones the given conversation and all of its meesages
func (s *repo) CloneConversation(toClone Conversation) (*Conversation, uint, error) {
rootMessages, err := s.GetRootMessages(toClone.ID)
if err != nil {
return nil, 0, fmt.Errorf("Could not create clone: %v", err)
}
clone, err := s.CreateConversation(toClone.Title + " - Clone")
if err != nil {
return nil, 0, fmt.Errorf("Could not create clone: %v", err)
}
var errors []error
var messageCnt uint = 0
for _, root := range rootMessages {
messageCnt++
newRoot := root
newRoot.ConversationID = &clone.ID
cloned, count, err := s.CloneBranch(newRoot)
if err != nil {
errors = append(errors, err)
continue
}
messageCnt += count
if root.ID == *toClone.SelectedRootID {
clone.SelectedRootID = &cloned.ID
if err := s.UpdateConversation(clone); err != nil {
errors = append(errors, fmt.Errorf("Could not set selected root on clone: %v", err))
}
}
}
if len(errors) > 0 {
return nil, 0, fmt.Errorf("Messages failed to be cloned: %v", errors)
}
return clone, messageCnt, nil
}
func (s *repo) GetSelectedThread(c *Conversation) ([]Message, error) {
if c.SelectedRoot == nil {
return nil, fmt.Errorf("No SelectedRoot on conversation - this is a bug")
}
return s.PathToLeaf(c.SelectedRoot)
}

View File

@ -1,55 +0,0 @@
package conversation
import (
"git.mlow.ca/mlow/lmcli/pkg/api"
)
// ApplySystemPrompt updates the contents of an existing system Message if it
// exists, or returns a new slice with the system Message prepended.
func ApplySystemPrompt(m []Message, system string, force bool) []Message {
if len(m) > 0 && m[0].Role == api.MessageRoleSystem {
if force {
m[0].Content = system
}
return m
} else {
return append([]Message{{
Role: api.MessageRoleSystem,
Content: system,
}}, m...)
}
}
func MessageToAPI(m Message) api.Message {
return api.Message{
Role: m.Role,
Content: m.Content,
ToolCalls: m.ToolCalls,
ToolResults: m.ToolResults,
}
}
func MessagesToAPI(messages []Message) []api.Message {
ret := make([]api.Message, 0, len(messages))
for _, m := range messages {
ret = append(ret, MessageToAPI(m))
}
return ret
}
func MessageFromAPI(m api.Message) Message {
return Message{
Role: m.Role,
Content: m.Content,
ToolCalls: m.ToolCalls,
ToolResults: m.ToolResults,
}
}
func MessagesFromAPI(messages []api.Message) []Message {
ret := make([]Message, 0, len(messages))
for _, m := range messages {
ret = append(ret, MessageFromAPI(m))
}
return ret
}

View File

@ -5,39 +5,34 @@ import (
"os" "os"
"git.mlow.ca/mlow/lmcli/pkg/util" "git.mlow.ca/mlow/lmcli/pkg/util"
"gopkg.in/yaml.v3" "github.com/go-yaml/yaml"
) )
type Config struct { type Config struct {
Defaults *struct { Defaults *struct {
Model *string `yaml:"model" default:"gpt-4"` SystemPrompt *string `yaml:"systemPrompt" default:"You are a helpful assistant."`
MaxTokens *int `yaml:"maxTokens" default:"256"` MaxTokens *int `yaml:"maxTokens" default:"256"`
Temperature *float32 `yaml:"temperature" default:"0.2"` Temperature *float32 `yaml:"temperature" default:"0.7"`
SystemPrompt string `yaml:"systemPrompt,omitempty"` Model *string `yaml:"model" default:"gpt-4"`
SystemPromptFile string `yaml:"systemPromptFile,omitempty"`
Agent string `yaml:"agent"`
} `yaml:"defaults"` } `yaml:"defaults"`
Conversations *struct { Conversations *struct {
TitleGenerationModel *string `yaml:"titleGenerationModel" default:"gpt-3.5-turbo"` TitleGenerationModel *string `yaml:"titleGenerationModel" default:"gpt-3.5-turbo"`
} `yaml:"conversations"` } `yaml:"conversations"`
Tools *struct {
EnabledTools *[]string `yaml:"enabledTools"`
} `yaml:"tools"`
OpenAI *struct {
APIKey *string `yaml:"apiKey" default:"your_key_here"`
Models *[]string `yaml:"models"`
} `yaml:"openai"`
Anthropic *struct {
APIKey *string `yaml:"apiKey" default:"your_key_here"`
Models *[]string `yaml:"models"`
} `yaml:"anthropic"`
Chroma *struct { Chroma *struct {
Style *string `yaml:"style" default:"onedark"` Style *string `yaml:"style" default:"onedark"`
Formatter *string `yaml:"formatter" default:"terminal16m"` Formatter *string `yaml:"formatter" default:"terminal16m"`
} `yaml:"chroma"` } `yaml:"chroma"`
Agents []*struct {
Name string `yaml:"name"`
SystemPrompt string `yaml:"systemPrompt"`
Tools []string `yaml:"tools"`
} `yaml:"agents"`
Providers []*struct {
Name string `yaml:"name,omitempty"`
Display string `yaml:"display,omitempty"`
Kind string `yaml:"kind"`
BaseURL string `yaml:"baseUrl,omitempty"`
APIKey string `yaml:"apiKey,omitempty"`
Models []string `yaml:"models"`
Headers map[string]string `yaml:"headers"`
} `yaml:"providers"`
} }
func NewConfig(configFile string) (*Config, error) { func NewConfig(configFile string) (*Config, error) {
@ -65,9 +60,8 @@ func NewConfig(configFile string) (*Config, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("Could not open config file for writing: %v", err) return nil, fmt.Errorf("Could not open config file for writing: %v", err)
} }
encoder := yaml.NewEncoder(file) bytes, _ := yaml.Marshal(c)
encoder.SetIndent(2) _, err = file.Write(bytes)
err = encoder.Encode(c)
if err != nil { if err != nil {
return nil, fmt.Errorf("Could not save default configuration: %v", err) return nil, fmt.Errorf("Could not save default configuration: %v", err)
} }

View File

@ -1,231 +1,86 @@
package lmcli package lmcli
import ( import (
"errors"
"fmt" "fmt"
"io/fs"
"log"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"time"
"git.mlow.ca/mlow/lmcli/pkg/agents" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
"git.mlow.ca/mlow/lmcli/pkg/provider" "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic"
"git.mlow.ca/mlow/lmcli/pkg/provider/anthropic" "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai"
"git.mlow.ca/mlow/lmcli/pkg/provider/google" "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
"git.mlow.ca/mlow/lmcli/pkg/provider/ollama"
"git.mlow.ca/mlow/lmcli/pkg/provider/openai"
"git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/util"
"git.mlow.ca/mlow/lmcli/pkg/util/tty" "git.mlow.ca/mlow/lmcli/pkg/util/tty"
"gorm.io/driver/sqlite" "gorm.io/driver/sqlite"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/logger"
) )
type Agent struct {
Name string
SystemPrompt string
Toolbox []api.ToolSpec
}
type Context struct { type Context struct {
// high level app configuration, may be mutated at runtime Config *Config
Config Config Store ConversationStore
Conversations conversation.Repo
Chroma *tty.ChromaHighlighter Chroma *tty.ChromaHighlighter
EnabledTools []model.Tool
} }
func NewContext() (*Context, error) { func NewContext() (*Context, error) {
configFile := filepath.Join(configDir(), "config.yaml") configFile := filepath.Join(configDir(), "config.yaml")
config, err := NewConfig(configFile) config, err := NewConfig(configFile)
if err != nil { if err != nil {
return nil, err Fatal("%v\n", err)
} }
store, err := getConversationService()
if err != nil {
return nil, err
}
chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style)
return &Context{*config, store, chroma}, nil
}
func createOrOpenAppend(path string) (*os.File, error) {
var file *os.File
if _, err := os.Stat(path); errors.Is(err, os.ErrNotExist) {
file, err = os.Create(path)
if err != nil {
return nil, err
}
} else {
file, err = os.OpenFile(path, os.O_APPEND, fs.ModeAppend)
if err != nil {
return nil, err
}
}
return file, nil
}
func getConversationService() (conversation.Repo, error) {
databaseFile := filepath.Join(dataDir(), "conversations.db") databaseFile := filepath.Join(dataDir(), "conversations.db")
gormLogFile, err := createOrOpenAppend(filepath.Join(dataDir(), "database.log")) db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{})
if err != nil {
return nil, fmt.Errorf("Could not open database log file: %v", err)
}
db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{
Logger: logger.New(log.New(gormLogFile, "\n", log.LstdFlags), logger.Config{
SlowThreshold: 200 * time.Millisecond,
LogLevel: logger.Info,
IgnoreRecordNotFoundError: false,
Colorful: true,
}),
})
if err != nil { if err != nil {
return nil, fmt.Errorf("Error establishing connection to store: %v", err) return nil, fmt.Errorf("Error establishing connection to store: %v", err)
} }
repo, err := conversation.NewRepo(db) store, err := NewSQLStore(db)
if err != nil { if err != nil {
return nil, err Fatal("%v\n", err)
}
return repo, nil
}
func (c *Context) GetModels() (models []string) {
modelCounts := make(map[string]int)
for _, p := range c.Config.Providers {
name := p.Kind
if p.Name != "" {
name = p.Name
} }
for _, m := range p.Models { chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style)
modelCounts[m]++
models = append(models, fmt.Sprintf("%s@%s", m, name))
}
}
for m, c := range modelCounts { var enabledTools []model.Tool
if c == 1 { for _, toolName := range *config.Tools.EnabledTools {
models = append(models, m) tool, ok := tools.AvailableTools[toolName]
}
}
return
}
func (c *Context) GetAgents() (agents []string) {
for _, p := range c.Config.Agents {
agents = append(agents, p.Name)
}
return
}
func (c *Context) GetAgent(name string) *Agent {
if name == "" || name == "none" {
return nil
}
for _, a := range c.Config.Agents {
if name != a.Name {
continue
}
var enabledTools []api.ToolSpec
for _, toolName := range a.Tools {
tool, ok := agents.AvailableTools[toolName]
if ok { if ok {
enabledTools = append(enabledTools, tool) enabledTools = append(enabledTools, tool)
} }
} }
return &Agent{ return &Context{config, store, chroma, enabledTools}, nil
Name: a.Name,
SystemPrompt: a.SystemPrompt,
Toolbox: enabledTools,
}
}
return nil
} }
func (c *Context) DefaultSystemPrompt() string { func (c *Context) GetModels() (models []string) {
if c.Config.Defaults.SystemPromptFile != "" { for _, m := range *c.Config.Anthropic.Models {
content, err := util.ReadFileContents(c.Config.Defaults.SystemPromptFile) models = append(models, m)
if err != nil {
Fatal("Could not read file contents at %s: %v\n", c.Config.Defaults.SystemPromptFile, err)
} }
return content for _, m := range *c.Config.OpenAI.Models {
models = append(models, m)
} }
return c.Config.Defaults.SystemPrompt return
} }
func (c *Context) GetModelProvider(model string, provider string) (string, string, provider.ChatCompletionProvider, error) { func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionClient, error) {
parts := strings.Split(model, "@") for _, m := range *c.Config.Anthropic.Models {
if provider == "" && len(parts) > 1 {
model = parts[0]
provider = parts[1]
}
for _, p := range c.Config.Providers {
name := p.Kind
if p.Name != "" {
name = p.Name
}
if provider != "" && name != provider {
continue
}
for _, m := range p.Models {
if m == model { if m == model {
switch p.Kind { anthropic := &anthropic.AnthropicClient{
case "anthropic": APIKey: *c.Config.Anthropic.APIKey,
url := "https://api.anthropic.com"
if p.BaseURL != "" {
url = p.BaseURL
} }
return model, name, &anthropic.AnthropicClient{ return anthropic, nil
BaseURL: url,
APIKey: p.APIKey,
}, nil
case "google":
url := "https://generativelanguage.googleapis.com"
if p.BaseURL != "" {
url = p.BaseURL
}
return model, name, &google.Client{
BaseURL: url,
APIKey: p.APIKey,
}, nil
case "ollama":
url := "http://localhost:11434/api"
if p.BaseURL != "" {
url = p.BaseURL
}
return model, name, &ollama.OllamaClient{
BaseURL: url,
}, nil
case "openai":
url := "https://api.openai.com"
if p.BaseURL != "" {
url = p.BaseURL
}
return model, name, &openai.OpenAIClient{
BaseURL: url,
APIKey: p.APIKey,
Headers: p.Headers,
}, nil
default:
return "", "", nil, fmt.Errorf("unknown provider kind: %s", p.Kind)
} }
} }
for _, m := range *c.Config.OpenAI.Models {
if m == model {
openai := &openai.OpenAIClient{
APIKey: *c.Config.OpenAI.APIKey,
}
return openai, nil
} }
} }
return "", "", nil, fmt.Errorf("unknown model: %s", model) return nil, fmt.Errorf("unknown model: %s", model)
} }
func configDir() string { func configDir() string {

View File

@ -0,0 +1,58 @@
package model
import (
"database/sql"
"time"
)
type MessageRole string
const (
MessageRoleSystem MessageRole = "system"
MessageRoleUser MessageRole = "user"
MessageRoleAssistant MessageRole = "assistant"
MessageRoleToolCall MessageRole = "tool_call"
MessageRoleToolResult MessageRole = "tool_result"
)
type Message struct {
ID uint `gorm:"primaryKey"`
ConversationID uint `gorm:"foreignKey:ConversationID"`
Content string
Role MessageRole
CreatedAt time.Time
ToolCalls ToolCalls // a json array of tool calls (from the modl)
ToolResults ToolResults // a json array of tool results
}
type Conversation struct {
ID uint `gorm:"primaryKey"`
ShortName sql.NullString
Title string
}
type RequestParameters struct {
Model string
MaxTokens int
Temperature float32
TopP float32
SystemPrompt string
ToolBag []Tool
}
// FriendlyRole returns a human friendly signifier for the message's role.
func (m *MessageRole) FriendlyRole() string {
var friendlyRole string
switch *m {
case MessageRoleUser:
friendlyRole = "You"
case MessageRoleSystem:
friendlyRole = "System"
case MessageRoleAssistant:
friendlyRole = "Assistant"
default:
friendlyRole = string(*m)
}
return friendlyRole
}

98
pkg/lmcli/model/tool.go Normal file
View File

@ -0,0 +1,98 @@
package model
import (
"database/sql/driver"
"encoding/json"
"fmt"
)
type Tool struct {
Name string
Description string
Parameters []ToolParameter
Impl func(*Tool, map[string]interface{}) (string, error)
}
type ToolParameter struct {
Name string `json:"name"`
Type string `json:"type"` // "string", "integer", "boolean"
Required bool `json:"required"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
}
type ToolCall struct {
ID string `json:"id"`
Name string `json:"name"`
Parameters map[string]interface{} `json:"parameters"`
}
type ToolCalls []ToolCall
func (tc *ToolCalls) Scan(value any) (err error) {
s := value.(string)
if value == nil || s == "" {
*tc = nil
return
}
err = json.Unmarshal([]byte(s), tc)
return
}
func (tc ToolCalls) Value() (driver.Value, error) {
if len(tc) == 0 {
return "", nil
}
jsonBytes, err := json.Marshal(tc)
if err != nil {
return "", fmt.Errorf("Could not marshal ToolCalls to JSON: %v\n", err)
}
return string(jsonBytes), nil
}
type ToolResult struct {
ToolCallID string `json:"toolCallID"`
ToolName string `json:"toolName,omitempty"`
Result string `json:"result,omitempty"`
}
type ToolResults []ToolResult
func (tr *ToolResults) Scan(value any) (err error) {
s := value.(string)
if value == nil || s == "" {
*tr = nil
return
}
err = json.Unmarshal([]byte(s), tr)
return
}
func (tr ToolResults) Value() (driver.Value, error) {
if len(tr) == 0 {
return "", nil
}
jsonBytes, err := json.Marshal([]ToolResult(tr))
if err != nil {
return "", fmt.Errorf("Could not marshal ToolResults to JSON: %v\n", err)
}
return string(jsonBytes), nil
}
type CallResult struct {
Message string `json:"message"`
Result any `json:"result,omitempty"`
}
func (r CallResult) ToJson() (string, error) {
if r.Message == "" {
// When message not supplied, assume success
r.Message = "success"
}
jsonBytes, err := json.Marshal(r)
if err != nil {
return "", fmt.Errorf("Could not marshal CallResult to JSON: %v\n", err)
}
return string(jsonBytes), nil
}

View File

@ -0,0 +1,348 @@
package anthropic
import (
"bufio"
"bytes"
"context"
"encoding/json"
"encoding/xml"
"fmt"
"net/http"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
)
type AnthropicClient struct {
APIKey string
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
type Request struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
System string `json:"system,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float32 `json:"temperature,omitempty"`
//TopP float32 `json:"top_p,omitempty"`
//TopK float32 `json:"top_k,omitempty"`
}
type OriginalContent struct {
Type string `json:"type"`
Text string `json:"text"`
}
type Response struct {
Id string `json:"id"`
Type string `json:"type"`
Role string `json:"role"`
Content []OriginalContent `json:"content"`
}
const FUNCTION_STOP_SEQUENCE = "</function_calls>"
func buildRequest(params model.RequestParameters, messages []model.Message) Request {
requestBody := Request{
Model: params.Model,
Messages: make([]Message, len(messages)),
System: params.SystemPrompt,
MaxTokens: params.MaxTokens,
Temperature: params.Temperature,
Stream: false,
StopSequences: []string{
FUNCTION_STOP_SEQUENCE,
"\n\nHuman:",
},
}
startIdx := 0
if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem {
requestBody.System = messages[0].Content
requestBody.Messages = requestBody.Messages[1:]
startIdx = 1
}
if len(params.ToolBag) > 0 {
if len(requestBody.System) > 0 {
// add a divider between existing system prompt and tools
requestBody.System += "\n\n---\n\n"
}
requestBody.System += buildToolsSystemPrompt(params.ToolBag)
}
for i, msg := range messages[startIdx:] {
message := &requestBody.Messages[i]
switch msg.Role {
case model.MessageRoleToolCall:
message.Role = "assistant"
if msg.Content != "" {
message.Content = msg.Content
}
xmlFuncCalls := convertToolCallsToXMLFunctionCalls(msg.ToolCalls)
xmlString, err := xmlFuncCalls.XMLString()
if err != nil {
panic("Could not serialize []ToolCall to XMLFunctionCall")
}
if len(message.Content) > 0 {
message.Content += fmt.Sprintf("\n\n%s", xmlString)
} else {
message.Content = xmlString
}
case model.MessageRoleToolResult:
xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults)
xmlString, err := xmlFuncResults.XMLString()
if err != nil {
panic("Could not serialize []ToolResult to XMLFunctionResults")
}
message.Role = "user"
message.Content = xmlString
default:
message.Role = string(msg.Role)
message.Content = msg.Content
}
}
return requestBody
}
func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Response, error) {
url := "https://api.anthropic.com/v1/messages"
jsonBody, err := json.Marshal(r)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %v", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonBody))
if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %v", err)
}
req.Header.Set("x-api-key", c.APIKey)
req.Header.Set("anthropic-version", "2023-06-01")
req.Header.Set("content-type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send HTTP request: %v", err)
}
return resp, nil
}
func (c *AnthropicClient) CreateChatCompletion(
ctx context.Context,
params model.RequestParameters,
messages []model.Message,
callback provider.ReplyCallback,
) (string, error) {
request := buildRequest(params, messages)
resp, err := sendRequest(ctx, c, request)
if err != nil {
return "", err
}
defer resp.Body.Close()
var response Response
err = json.NewDecoder(resp.Body).Decode(&response)
if err != nil {
return "", fmt.Errorf("failed to decode response: %v", err)
}
sb := strings.Builder{}
for _, content := range response.Content {
var reply model.Message
switch content.Type {
case "text":
reply = model.Message{
Role: model.MessageRoleAssistant,
Content: content.Text,
}
sb.WriteString(reply.Content)
default:
return "", fmt.Errorf("unsupported message type: %s", content.Type)
}
if callback != nil {
callback(reply)
}
}
return sb.String(), nil
}
func (c *AnthropicClient) CreateChatCompletionStream(
ctx context.Context,
params model.RequestParameters,
messages []model.Message,
callback provider.ReplyCallback,
output chan<- string,
) (string, error) {
request := buildRequest(params, messages)
request.Stream = true
resp, err := sendRequest(ctx, c, request)
if err != nil {
return "", err
}
defer resp.Body.Close()
scanner := bufio.NewScanner(resp.Body)
sb := strings.Builder{}
isToolCall := false
for scanner.Scan() {
line := scanner.Text()
line = strings.TrimSpace(line)
if len(line) == 0 {
continue
}
if line[0] == '{' {
var event map[string]interface{}
err := json.Unmarshal([]byte(line), &event)
if err != nil {
return "", fmt.Errorf("failed to unmarshal event data '%s': %v", line, err)
}
eventType, ok := event["type"].(string)
if !ok {
return "", fmt.Errorf("invalid event: %s", line)
}
switch eventType {
case "error":
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
default:
return sb.String(), fmt.Errorf("unknown event type: %s", eventType)
}
} else if strings.HasPrefix(line, "data:") {
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
var event map[string]interface{}
err := json.Unmarshal([]byte(data), &event)
if err != nil {
return "", fmt.Errorf("failed to unmarshal event data: %v", err)
}
eventType, ok := event["type"].(string)
if !ok {
return "", fmt.Errorf("invalid event type")
}
switch eventType {
case "message_start":
// noop
case "ping":
// write an empty string to signal start of text
output <- ""
case "content_block_start":
// ignore?
case "content_block_delta":
delta, ok := event["delta"].(map[string]interface{})
if !ok {
return "", fmt.Errorf("invalid content block delta")
}
text, ok := delta["text"].(string)
if !ok {
return "", fmt.Errorf("invalid text delta")
}
sb.WriteString(text)
output <- text
case "content_block_stop":
// ignore?
case "message_delta":
delta, ok := event["delta"].(map[string]interface{})
if !ok {
return "", fmt.Errorf("invalid message delta")
}
stopReason, ok := delta["stop_reason"].(string)
if ok && stopReason == "stop_sequence" {
stopSequence, ok := delta["stop_sequence"].(string)
if ok && stopSequence == FUNCTION_STOP_SEQUENCE {
content := sb.String()
start := strings.Index(content, "<function_calls>")
if start == -1 {
return content, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
}
isToolCall = true
funcCallXml := content[start:]
funcCallXml += FUNCTION_STOP_SEQUENCE
sb.WriteString(FUNCTION_STOP_SEQUENCE)
output <- FUNCTION_STOP_SEQUENCE
// Extract function calls
var functionCalls XMLFunctionCalls
err := xml.Unmarshal([]byte(sb.String()), &functionCalls)
if err != nil {
return "", fmt.Errorf("failed to unmarshal function_calls: %v", err)
}
// Execute function calls
toolCall := model.Message{
Role: model.MessageRoleToolCall,
// xml stripped from content
Content: content[:start],
ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
}
toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
if err != nil {
return "", err
}
toolReply := model.Message{
Role: model.MessageRoleToolResult,
ToolResults: toolResults,
}
if callback != nil {
callback(toolCall)
callback(toolReply)
}
// Recurse into CreateChatCompletionStream with the tool call replies
// added to the original messages
messages = append(append(messages, toolCall), toolReply)
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
}
}
case "message_stop":
// return the completed message
if callback != nil {
if !isToolCall {
callback(model.Message{
Role: model.MessageRoleAssistant,
Content: sb.String(),
})
}
}
return sb.String(), nil
case "error":
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
default:
fmt.Printf("\nUnrecognized event: %s\n", data)
}
}
}
if err := scanner.Err(); err != nil {
return "", fmt.Errorf("failed to read response body: %v", err)
}
return "", fmt.Errorf("unexpected end of stream")
}

View File

@ -0,0 +1,230 @@
package anthropic
import (
"bytes"
"fmt"
"strings"
"text/template"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
)
const TOOL_PREAMBLE = `You have access to the following tools when replying.
You may call them like this:
<function_calls>
<invoke>
<tool_name>$TOOL_NAME</tool_name>
<parameters>
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
...
</parameters>
</invoke>
</function_calls>
Here are the tools available:`
const TOOL_PREAMBLE_FOOTER = `Recognize the utility of these tools in a broad range of different applications, and the power they give you to solve a wide range of different problems. However, ensure that the tools are used judiciously and only when clearly relevant to the user's request. Specifically:
1. Only use a tool if the user has explicitly requested or provided information that warrants its use. Do not make assumptions about files or data existing without the user mentioning them.
2. If there is ambiguity about whether using a tool is appropriate, ask a clarifying question to the user before proceeding. Confirm your understanding of their request and intent.
3. Prioritize providing direct responses and explanations based on your own knowledge and understanding. Use tools to supplement and enhance your responses when clearly applicable, but not as a default action.`
type XMLTools struct {
XMLName struct{} `xml:"tools"`
ToolDescriptions []XMLToolDescription `xml:"tool_description"`
}
type XMLToolDescription struct {
ToolName string `xml:"tool_name"`
Description string `xml:"description"`
Parameters []XMLToolParameter `xml:"parameters>parameter"`
}
type XMLToolParameter struct {
Name string `xml:"name"`
Type string `xml:"type"`
Description string `xml:"description"`
}
type XMLFunctionCalls struct {
XMLName struct{} `xml:"function_calls"`
Invoke []XMLFunctionInvoke `xml:"invoke"`
}
type XMLFunctionInvoke struct {
ToolName string `xml:"tool_name"`
Parameters XMLFunctionInvokeParameters `xml:"parameters"`
}
type XMLFunctionInvokeParameters struct {
String string `xml:",innerxml"`
}
type XMLFunctionResults struct {
XMLName struct{} `xml:"function_results"`
Result []XMLFunctionResult `xml:"result"`
}
type XMLFunctionResult struct {
ToolName string `xml:"tool_name"`
Stdout string `xml:"stdout"`
}
// accepts raw XML from XMLFunctionInvokeParameters.String, returns map of
// parameters name to value
func parseFunctionParametersXML(params string) map[string]interface{} {
lines := strings.Split(params, "\n")
ret := make(map[string]interface{}, len(lines))
for _, line := range lines {
i := strings.Index(line, ">")
if i == -1 {
continue
}
j := strings.Index(line, "</")
if j == -1 {
continue
}
// chop from after opening < to first > to get parameter name,
// then chop after > to first </ to get parameter value
ret[line[1:i]] = line[i+1 : j]
}
return ret
}
func convertToolsToXMLTools(tools []model.Tool) XMLTools {
converted := make([]XMLToolDescription, len(tools))
for i, tool := range tools {
converted[i].ToolName = tool.Name
converted[i].Description = tool.Description
params := make([]XMLToolParameter, len(tool.Parameters))
for j, param := range tool.Parameters {
params[j].Name = param.Name
params[j].Description = param.Description
params[j].Type = param.Type
}
converted[i].Parameters = params
}
return XMLTools{
ToolDescriptions: converted,
}
}
func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []model.ToolCall {
toolCalls := make([]model.ToolCall, len(functionCalls.Invoke))
for i, invoke := range functionCalls.Invoke {
toolCalls[i].Name = invoke.ToolName
toolCalls[i].Parameters = parseFunctionParametersXML(invoke.Parameters.String)
}
return toolCalls
}
func convertToolCallsToXMLFunctionCalls(toolCalls []model.ToolCall) XMLFunctionCalls {
converted := make([]XMLFunctionInvoke, len(toolCalls))
for i, toolCall := range toolCalls {
var params XMLFunctionInvokeParameters
var paramXML string
for key, value := range toolCall.Parameters {
paramXML += fmt.Sprintf("<%s>%v</%s>\n", key, value, key)
}
params.String = paramXML
converted[i] = XMLFunctionInvoke{
ToolName: toolCall.Name,
Parameters: params,
}
}
return XMLFunctionCalls{
Invoke: converted,
}
}
func convertToolResultsToXMLFunctionResult(toolResults []model.ToolResult) XMLFunctionResults {
converted := make([]XMLFunctionResult, len(toolResults))
for i, result := range toolResults {
converted[i].ToolName = result.ToolName
converted[i].Stdout = result.Result
}
return XMLFunctionResults{
Result: converted,
}
}
func buildToolsSystemPrompt(tools []model.Tool) string {
xmlTools := convertToolsToXMLTools(tools)
xmlToolsString, err := xmlTools.XMLString()
if err != nil {
panic("Could not serialize []model.Tool to XMLTools")
}
return TOOL_PREAMBLE + "\n\n" + xmlToolsString + "\n\n" + TOOL_PREAMBLE_FOOTER
}
func (x XMLTools) XMLString() (string, error) {
tmpl, err := template.New("tools").Parse(`<tools>
{{range .ToolDescriptions}}<tool_description>
<tool_name>{{.ToolName}}</tool_name>
<description>
{{.Description}}
</description>
<parameters>
{{range .Parameters}}<parameter>
<name>{{.Name}}</name>
<type>{{.Type}}</type>
<description>{{.Description}}</description>
</parameter>
{{end}}</parameters>
</tool_description>
{{end}}</tools>`)
if err != nil {
return "", err
}
var buf bytes.Buffer
if err := tmpl.Execute(&buf, x); err != nil {
return "", err
}
return buf.String(), nil
}
func (x XMLFunctionResults) XMLString() (string, error) {
tmpl, err := template.New("function_results").Parse(`<function_results>
{{range .Result}}<result>
<tool_name>{{.ToolName}}</tool_name>
<stdout>{{.Stdout}}</stdout>
</result>
{{end}}</function_results>`)
if err != nil {
return "", err
}
var buf bytes.Buffer
if err := tmpl.Execute(&buf, x); err != nil {
return "", err
}
return buf.String(), nil
}
func (x XMLFunctionCalls) XMLString() (string, error) {
tmpl, err := template.New("function_calls").Parse(`<function_calls>
{{range .Invoke}}<invoke>
<tool_name>{{.ToolName}}</tool_name>
<parameters>{{.Parameters.String}}</parameters>
</invoke>
{{end}}</function_calls>`)
if err != nil {
return "", err
}
var buf bytes.Buffer
if err := tmpl.Execute(&buf, x); err != nil {
return "", err
}
return buf.String(), nil
}

View File

@ -0,0 +1,278 @@
package openai
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
openai "github.com/sashabaranov/go-openai"
)
type OpenAIClient struct {
APIKey string
}
type OpenAIToolParameters struct {
Type string `json:"type"`
Properties map[string]OpenAIToolParameter `json:"properties,omitempty"`
Required []string `json:"required,omitempty"`
}
type OpenAIToolParameter struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
}
func convertTools(tools []model.Tool) []openai.Tool {
openaiTools := make([]openai.Tool, len(tools))
for i, tool := range tools {
openaiTools[i].Type = "function"
params := make(map[string]OpenAIToolParameter)
var required []string
for _, param := range tool.Parameters {
params[param.Name] = OpenAIToolParameter{
Type: param.Type,
Description: param.Description,
Enum: param.Enum,
}
if param.Required {
required = append(required, param.Name)
}
}
openaiTools[i].Function = openai.FunctionDefinition{
Name: tool.Name,
Description: tool.Description,
Parameters: OpenAIToolParameters{
Type: "object",
Properties: params,
Required: required,
},
}
}
return openaiTools
}
func convertToolCallToOpenAI(toolCalls []model.ToolCall) []openai.ToolCall {
converted := make([]openai.ToolCall, len(toolCalls))
for i, call := range toolCalls {
converted[i].Type = "function"
converted[i].ID = call.ID
converted[i].Function.Name = call.Name
json, _ := json.Marshal(call.Parameters)
converted[i].Function.Arguments = string(json)
}
return converted
}
func convertToolCallToAPI(toolCalls []openai.ToolCall) []model.ToolCall {
converted := make([]model.ToolCall, len(toolCalls))
for i, call := range toolCalls {
converted[i].ID = call.ID
converted[i].Name = call.Function.Name
json.Unmarshal([]byte(call.Function.Arguments), &converted[i].Parameters)
}
return converted
}
func createChatCompletionRequest(
c *OpenAIClient,
params model.RequestParameters,
messages []model.Message,
) openai.ChatCompletionRequest {
requestMessages := make([]openai.ChatCompletionMessage, 0, len(messages))
for _, m := range messages {
switch m.Role {
case "tool_call":
message := openai.ChatCompletionMessage{}
message.Role = "assistant"
message.Content = m.Content
message.ToolCalls = convertToolCallToOpenAI(m.ToolCalls)
requestMessages = append(requestMessages, message)
case "tool_result":
// expand tool_result messages' results into multiple openAI messages
for _, result := range m.ToolResults {
message := openai.ChatCompletionMessage{}
message.Role = "tool"
message.Content = result.Result
message.ToolCallID = result.ToolCallID
requestMessages = append(requestMessages, message)
}
default:
message := openai.ChatCompletionMessage{}
message.Role = string(m.Role)
message.Content = m.Content
requestMessages = append(requestMessages, message)
}
}
request := openai.ChatCompletionRequest{
Model: params.Model,
MaxTokens: params.MaxTokens,
Temperature: params.Temperature,
Messages: requestMessages,
N: 1, // limit responses to 1 "choice". we use choices[0] to reference it
}
if len(params.ToolBag) > 0 {
request.Tools = convertTools(params.ToolBag)
request.ToolChoice = "auto"
}
return request
}
func handleToolCalls(
params model.RequestParameters,
content string,
toolCalls []openai.ToolCall,
) ([]model.Message, error) {
toolCall := model.Message{
Role: model.MessageRoleToolCall,
Content: content,
ToolCalls: convertToolCallToAPI(toolCalls),
}
toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
if err != nil {
return nil, err
}
toolResult := model.Message{
Role: model.MessageRoleToolResult,
ToolResults: toolResults,
}
return []model.Message{toolCall, toolResult}, nil
}
func (c *OpenAIClient) CreateChatCompletion(
ctx context.Context,
params model.RequestParameters,
messages []model.Message,
callback provider.ReplyCallback,
) (string, error) {
client := openai.NewClient(c.APIKey)
req := createChatCompletionRequest(c, params, messages)
resp, err := client.CreateChatCompletion(ctx, req)
if err != nil {
return "", err
}
choice := resp.Choices[0]
toolCalls := choice.Message.ToolCalls
if len(toolCalls) > 0 {
results, err := handleToolCalls(params, choice.Message.Content, toolCalls)
if err != nil {
return "", err
}
if callback != nil {
for _, result := range results {
callback(result)
}
}
// Recurse into CreateChatCompletion with the tool call replies
messages = append(messages, results...)
return c.CreateChatCompletion(ctx, params, messages, callback)
}
if callback != nil {
callback(model.Message{
Role: model.MessageRoleAssistant,
Content: choice.Message.Content,
})
}
// Return the user-facing message.
return choice.Message.Content, nil
}
func (c *OpenAIClient) CreateChatCompletionStream(
ctx context.Context,
params model.RequestParameters,
messages []model.Message,
callback provider.ReplyCallback,
output chan<- string,
) (string, error) {
client := openai.NewClient(c.APIKey)
req := createChatCompletionRequest(c, params, messages)
stream, err := client.CreateChatCompletionStream(ctx, req)
if err != nil {
return "", err
}
defer stream.Close()
content := strings.Builder{}
toolCalls := []openai.ToolCall{}
// Iterate stream segments
for {
response, e := stream.Recv()
if errors.Is(e, io.EOF) {
break
}
if e != nil {
err = e
break
}
delta := response.Choices[0].Delta
if len(delta.ToolCalls) > 0 {
// Construct streamed tool_call arguments
for _, tc := range delta.ToolCalls {
if tc.Index == nil {
return "", fmt.Errorf("Unexpected nil index for streamed tool call.")
}
if len(toolCalls) <= *tc.Index {
toolCalls = append(toolCalls, tc)
} else {
toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
}
}
} else {
output <- delta.Content
content.WriteString(delta.Content)
}
}
if len(toolCalls) > 0 {
results, err := handleToolCalls(params, content.String(), toolCalls)
if err != nil {
return content.String(), err
}
if callback != nil {
for _, result := range results {
callback(result)
}
}
// Recurse into CreateChatCompletionStream with the tool call replies
messages = append(messages, results...)
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
} else {
if callback != nil {
callback(model.Message{
Role: model.MessageRoleAssistant,
Content: content.String(),
})
}
}
return content.String(), err
}

View File

@ -0,0 +1,31 @@
package provider
import (
"context"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
)
type ReplyCallback func(model.Message)
type ChatCompletionClient interface {
// CreateChatCompletion requests a response to the provided messages.
// Replies are appended to the given replies struct, and the
// complete user-facing response is returned as a string.
CreateChatCompletion(
ctx context.Context,
params model.RequestParameters,
messages []model.Message,
callback ReplyCallback,
) (string, error)
// Like CreateChageCompletion, except the response is streamed via
// the output channel as it's received.
CreateChatCompletionStream(
ctx context.Context,
params model.RequestParameters,
messages []model.Message,
callback ReplyCallback,
output chan<- string,
) (string, error)
}

132
pkg/lmcli/store.go Normal file
View File

@ -0,0 +1,132 @@
package lmcli
import (
"database/sql"
"errors"
"fmt"
"strings"
"time"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
sqids "github.com/sqids/sqids-go"
"gorm.io/gorm"
)
type ConversationStore interface {
Conversations() ([]model.Conversation, error)
ConversationByShortName(shortName string) (*model.Conversation, error)
ConversationShortNameCompletions(search string) []string
SaveConversation(conversation *model.Conversation) error
DeleteConversation(conversation *model.Conversation) error
Messages(conversation *model.Conversation) ([]model.Message, error)
LastMessage(conversation *model.Conversation) (*model.Message, error)
SaveMessage(message *model.Message) error
DeleteMessage(message *model.Message) error
UpdateMessage(message *model.Message) error
AddReply(conversation *model.Conversation, message model.Message) (*model.Message, error)
}
type SQLStore struct {
db *gorm.DB
sqids *sqids.Sqids
}
func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
models := []any{
&model.Conversation{},
&model.Message{},
}
for _, x := range models {
err := db.AutoMigrate(x)
if err != nil {
return nil, fmt.Errorf("Could not perform database migrations: %v", err)
}
}
_sqids, _ := sqids.New(sqids.Options{MinLength: 4})
return &SQLStore{db, _sqids}, nil
}
func (s *SQLStore) SaveConversation(conversation *model.Conversation) error {
err := s.db.Save(&conversation).Error
if err != nil {
return err
}
if !conversation.ShortName.Valid {
shortName, _ := s.sqids.Encode([]uint64{uint64(conversation.ID)})
conversation.ShortName = sql.NullString{String: shortName, Valid: true}
err = s.db.Save(&conversation).Error
}
return err
}
func (s *SQLStore) DeleteConversation(conversation *model.Conversation) error {
s.db.Where("conversation_id = ?", conversation.ID).Delete(&model.Message{})
return s.db.Delete(&conversation).Error
}
func (s *SQLStore) SaveMessage(message *model.Message) error {
return s.db.Create(message).Error
}
func (s *SQLStore) DeleteMessage(message *model.Message) error {
return s.db.Delete(&message).Error
}
func (s *SQLStore) UpdateMessage(message *model.Message) error {
return s.db.Updates(&message).Error
}
func (s *SQLStore) Conversations() ([]model.Conversation, error) {
var conversations []model.Conversation
err := s.db.Find(&conversations).Error
return conversations, err
}
func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
var completions []string
conversations, _ := s.Conversations() // ignore error for completions
for _, conversation := range conversations {
if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) {
completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title))
}
}
return completions
}
func (s *SQLStore) ConversationByShortName(shortName string) (*model.Conversation, error) {
if shortName == "" {
return nil, errors.New("shortName is empty")
}
var conversation model.Conversation
err := s.db.Where("short_name = ?", shortName).Find(&conversation).Error
return &conversation, err
}
func (s *SQLStore) Messages(conversation *model.Conversation) ([]model.Message, error) {
var messages []model.Message
err := s.db.Where("conversation_id = ?", conversation.ID).Find(&messages).Error
return messages, err
}
func (s *SQLStore) LastMessage(conversation *model.Conversation) (*model.Message, error) {
var message model.Message
err := s.db.Where("conversation_id = ?", conversation.ID).Last(&message).Error
return &message, err
}
// AddReply adds the given messages as a reply to the given conversation, can be
// used to easily copy a message associated with one conversation, to another
func (s *SQLStore) AddReply(c *model.Conversation, m model.Message) (*model.Message, error) {
m.ConversationID = c.ID
m.ID = 0
m.CreatedAt = time.Time{}
return &m, s.SaveMessage(&m)
}

View File

@ -1,22 +1,22 @@
package toolbox package tools
import ( import (
"fmt" "fmt"
"os" "os"
"strings" "strings"
toolutil "git.mlow.ca/mlow/lmcli/pkg/agents/toolbox/util" toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
) )
const FILE_INSERT_LINES_DESCRIPTION = `Insert lines into a file, must specify path. const FILE_INSERT_LINES_DESCRIPTION = `Insert lines into a file, must specify path.
Make sure your inserts match the flow and indentation of surrounding content.` Make sure your inserts match the flow and indentation of surrounding content.`
var FileInsertLinesTool = api.ToolSpec{ var FileInsertLinesTool = model.Tool{
Name: "file_insert_lines", Name: "file_insert_lines",
Description: FILE_INSERT_LINES_DESCRIPTION, Description: FILE_INSERT_LINES_DESCRIPTION,
Parameters: []api.ToolParameter{ Parameters: []model.ToolParameter{
{ {
Name: "path", Name: "path",
Type: "string", Type: "string",
@ -36,7 +36,7 @@ var FileInsertLinesTool = api.ToolSpec{
Required: true, Required: true,
}, },
}, },
Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) { Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
tmp, ok := args["path"] tmp, ok := args["path"]
if !ok { if !ok {
return "", fmt.Errorf("path parameter to write_file was not included.") return "", fmt.Errorf("path parameter to write_file was not included.")
@ -72,27 +72,27 @@ var FileInsertLinesTool = api.ToolSpec{
}, },
} }
func fileInsertLines(path string, position int, content string) api.CallResult { func fileInsertLines(path string, position int, content string) model.CallResult {
ok, reason := toolutil.IsPathWithinCWD(path) ok, reason := toolutil.IsPathWithinCWD(path)
if !ok { if !ok {
return api.CallResult{Message: reason} return model.CallResult{Message: reason}
} }
// Read the existing file's content // Read the existing file's content
data, err := os.ReadFile(path) data, err := os.ReadFile(path)
if err != nil { if err != nil {
if !os.IsNotExist(err) { if !os.IsNotExist(err) {
return api.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())} return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
} }
_, err = os.Create(path) _, err = os.Create(path)
if err != nil { if err != nil {
return api.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())} return model.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}
} }
data = []byte{} data = []byte{}
} }
if position < 1 { if position < 1 {
return api.CallResult{Message: "start_line cannot be less than 1"} return model.CallResult{Message: "start_line cannot be less than 1"}
} }
lines := strings.Split(string(data), "\n") lines := strings.Split(string(data), "\n")
@ -107,8 +107,8 @@ func fileInsertLines(path string, position int, content string) api.CallResult {
// Join the lines and write back to the file // Join the lines and write back to the file
err = os.WriteFile(path, []byte(newContent), 0644) err = os.WriteFile(path, []byte(newContent), 0644)
if err != nil { if err != nil {
return api.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())} return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
} }
return api.CallResult{Result: newContent} return model.CallResult{Result: newContent}
} }

View File

@ -0,0 +1,133 @@
package tools
import (
"fmt"
"os"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
)
const FILE_REPLACE_LINES_DESCRIPTION = `Replace or remove a range of lines within a file, must specify path.
Useful for re-writing snippets/blocks of code or entire functions.
Plan your edits carefully and ensure any new content matches the flow and indentation of surrounding text.`
var FileReplaceLinesTool = model.Tool{
Name: "file_replace_lines",
Description: FILE_REPLACE_LINES_DESCRIPTION,
Parameters: []model.ToolParameter{
{
Name: "path",
Type: "string",
Description: "Path of the file to be modified, relative to the current working directory.",
Required: true,
},
{
Name: "start_line",
Type: "integer",
Description: `Line number which specifies the start of the replacement range (inclusive).`,
Required: true,
},
{
Name: "end_line",
Type: "integer",
Description: `Line number which specifies the end of the replacement range (inclusive). If unset, range extends to end of file.`,
},
{
Name: "content",
Type: "string",
Description: `Content to replace specified range. Omit to remove the specified range.`,
},
},
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
tmp, ok := args["path"]
if !ok {
return "", fmt.Errorf("path parameter to write_file was not included.")
}
path, ok := tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
}
var start_line int
tmp, ok = args["start_line"]
if ok {
tmp, ok := tmp.(float64)
if !ok {
return "", fmt.Errorf("Invalid start_line in function arguments: %v", tmp)
}
start_line = int(tmp)
}
var end_line int
tmp, ok = args["end_line"]
if ok {
tmp, ok := tmp.(float64)
if !ok {
return "", fmt.Errorf("Invalid end_line in function arguments: %v", tmp)
}
end_line = int(tmp)
}
var content string
tmp, ok = args["content"]
if ok {
content, ok = tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
}
}
result := fileReplaceLines(path, start_line, end_line, content)
ret, err := result.ToJson()
if err != nil {
return "", fmt.Errorf("Could not serialize result: %v", err)
}
return ret, nil
},
}
func fileReplaceLines(path string, startLine int, endLine int, content string) model.CallResult {
ok, reason := toolutil.IsPathWithinCWD(path)
if !ok {
return model.CallResult{Message: reason}
}
// Read the existing file's content
data, err := os.ReadFile(path)
if err != nil {
if !os.IsNotExist(err) {
return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
}
_, err = os.Create(path)
if err != nil {
return model.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}
}
data = []byte{}
}
if startLine < 1 {
return model.CallResult{Message: "start_line cannot be less than 1"}
}
lines := strings.Split(string(data), "\n")
contentLines := strings.Split(strings.Trim(content, "\n"), "\n")
if endLine == 0 || endLine > len(lines) {
endLine = len(lines)
}
before := lines[:startLine-1]
after := lines[endLine:]
lines = append(before, append(contentLines, after...)...)
newContent := strings.Join(lines, "\n")
// Join the lines and write back to the file
err = os.WriteFile(path, []byte(newContent), 0644)
if err != nil {
return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
}
return model.CallResult{Result: newContent}
}

View File

@ -1,4 +1,4 @@
package toolbox package tools
import ( import (
"fmt" "fmt"
@ -6,8 +6,8 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
toolutil "git.mlow.ca/mlow/lmcli/pkg/agents/toolbox/util" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/api" toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
) )
const READ_DIR_DESCRIPTION = `Return the contents of the CWD (current working directory). const READ_DIR_DESCRIPTION = `Return the contents of the CWD (current working directory).
@ -25,17 +25,17 @@ Example result:
For files, size represents the size of the file, in bytes. For files, size represents the size of the file, in bytes.
For directories, size represents the number of entries in that directory.` For directories, size represents the number of entries in that directory.`
var ReadDirTool = api.ToolSpec{ var ReadDirTool = model.Tool{
Name: "read_dir", Name: "read_dir",
Description: READ_DIR_DESCRIPTION, Description: READ_DIR_DESCRIPTION,
Parameters: []api.ToolParameter{ Parameters: []model.ToolParameter{
{ {
Name: "relative_dir", Name: "relative_dir",
Type: "string", Type: "string",
Description: "If set, read the contents of a directory relative to the current one.", Description: "If set, read the contents of a directory relative to the current one.",
}, },
}, },
Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) { Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
var relativeDir string var relativeDir string
tmp, ok := args["relative_dir"] tmp, ok := args["relative_dir"]
if ok { if ok {
@ -53,18 +53,18 @@ var ReadDirTool = api.ToolSpec{
}, },
} }
func readDir(path string) api.CallResult { func readDir(path string) model.CallResult {
if path == "" { if path == "" {
path = "." path = "."
} }
ok, reason := toolutil.IsPathWithinCWD(path) ok, reason := toolutil.IsPathWithinCWD(path)
if !ok { if !ok {
return api.CallResult{Message: reason} return model.CallResult{Message: reason}
} }
files, err := os.ReadDir(path) files, err := os.ReadDir(path)
if err != nil { if err != nil {
return api.CallResult{ return model.CallResult{
Message: err.Error(), Message: err.Error(),
} }
} }
@ -96,5 +96,5 @@ func readDir(path string) api.CallResult {
}) })
} }
return api.CallResult{Result: dirContents} return model.CallResult{Result: dirContents}
} }

View File

@ -1,16 +1,15 @@
package toolbox package tools
import ( import (
"fmt" "fmt"
"os" "os"
"strings"
toolutil "git.mlow.ca/mlow/lmcli/pkg/agents/toolbox/util" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/api" toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
) )
const READ_FILE_DESCRIPTION = `Retrieve the contents of a text file relative to the current working directory. const READ_FILE_DESCRIPTION = `Read the contents of a text file relative to the current working directory.
Use the file contents for your own reference in completing your task, they do not need to be shown to the user.
Each line of the returned content is prefixed with its line number and a tab (\t). Each line of the returned content is prefixed with its line number and a tab (\t).
@ -20,10 +19,10 @@ Example result:
"result": "1\tthe contents\n2\tof the file\n" "result": "1\tthe contents\n2\tof the file\n"
}` }`
var ReadFileTool = api.ToolSpec{ var ReadFileTool = model.Tool{
Name: "read_file", Name: "read_file",
Description: READ_FILE_DESCRIPTION, Description: READ_FILE_DESCRIPTION,
Parameters: []api.ToolParameter{ Parameters: []model.ToolParameter{
{ {
Name: "path", Name: "path",
Type: "string", Type: "string",
@ -32,7 +31,7 @@ var ReadFileTool = api.ToolSpec{
}, },
}, },
Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) { Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
tmp, ok := args["path"] tmp, ok := args["path"]
if !ok { if !ok {
return "", fmt.Errorf("Path parameter to read_file was not included.") return "", fmt.Errorf("Path parameter to read_file was not included.")
@ -50,16 +49,23 @@ var ReadFileTool = api.ToolSpec{
}, },
} }
func readFile(path string) api.CallResult { func readFile(path string) model.CallResult {
ok, reason := toolutil.IsPathWithinCWD(path) ok, reason := toolutil.IsPathWithinCWD(path)
if !ok { if !ok {
return api.CallResult{Message: reason} return model.CallResult{Message: reason}
} }
data, err := os.ReadFile(path) data, err := os.ReadFile(path)
if err != nil { if err != nil {
return api.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())} return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
} }
return api.CallResult{
Result: toolutil.AddLineNumbers(string(data)), lines := strings.Split(string(data), "\n")
content := strings.Builder{}
for i, line := range lines {
content.WriteString(fmt.Sprintf("%d\t%s\n", i+1, line))
}
return model.CallResult{
Result: content.String(),
} }
} }

47
pkg/lmcli/tools/tools.go Normal file
View File

@ -0,0 +1,47 @@
package tools
import (
"fmt"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
)
var AvailableTools map[string]model.Tool = map[string]model.Tool{
"read_dir": ReadDirTool,
"read_file": ReadFileTool,
"write_file": WriteFileTool,
"file_insert_lines": FileInsertLinesTool,
"file_replace_lines": FileReplaceLinesTool,
}
func ExecuteToolCalls(toolCalls []model.ToolCall, toolBag []model.Tool) ([]model.ToolResult, error) {
var toolResults []model.ToolResult
for _, toolCall := range toolCalls {
var tool *model.Tool
for _, available := range toolBag {
if available.Name == toolCall.Name {
tool = &available
break
}
}
if tool == nil {
return nil, fmt.Errorf("Requested tool '%s' does not exist. Hallucination?", toolCall.Name)
}
// Execute the tool
result, err := tool.Impl(tool, toolCall.Parameters)
if err != nil {
// This can happen if the model missed or supplied invalid tool args
return nil, fmt.Errorf("Tool '%s' error: %v\n", toolCall.Name, err)
}
toolResult := model.ToolResult{
ToolCallID: toolCall.ID,
ToolName: toolCall.Name,
Result: result,
}
toolResults = append(toolResults, toolResult)
}
return toolResults, nil
}

View File

@ -65,14 +65,3 @@ func IsPathWithinCWD(path string) (bool, string) {
} }
return true, "" return true, ""
} }
// AddLineNumbers takes a string of content and returns a new string with line
// numbers prefixed
func AddLineNumbers(content string) string {
lines := strings.Split(strings.TrimSuffix(content, "\n"), "\n")
result := strings.Builder{}
for i, line := range lines {
result.WriteString(fmt.Sprintf("%d\t%s\n", i+1, line))
}
return result.String()
}

View File

@ -1,11 +1,11 @@
package toolbox package tools
import ( import (
"fmt" "fmt"
"os" "os"
toolutil "git.mlow.ca/mlow/lmcli/pkg/agents/toolbox/util" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/api" toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
) )
const WRITE_FILE_DESCRIPTION = `Write the provided contents to a file relative to the current working directory. const WRITE_FILE_DESCRIPTION = `Write the provided contents to a file relative to the current working directory.
@ -15,10 +15,10 @@ Example result:
"message": "success" "message": "success"
}` }`
var WriteFileTool = api.ToolSpec{ var WriteFileTool = model.Tool{
Name: "write_file", Name: "write_file",
Description: WRITE_FILE_DESCRIPTION, Description: WRITE_FILE_DESCRIPTION,
Parameters: []api.ToolParameter{ Parameters: []model.ToolParameter{
{ {
Name: "path", Name: "path",
Type: "string", Type: "string",
@ -32,7 +32,7 @@ var WriteFileTool = api.ToolSpec{
Required: true, Required: true,
}, },
}, },
Impl: func(t *api.ToolSpec, args map[string]interface{}) (string, error) { Impl: func(t *model.Tool, args map[string]interface{}) (string, error) {
tmp, ok := args["path"] tmp, ok := args["path"]
if !ok { if !ok {
return "", fmt.Errorf("Path parameter to write_file was not included.") return "", fmt.Errorf("Path parameter to write_file was not included.")
@ -58,14 +58,14 @@ var WriteFileTool = api.ToolSpec{
}, },
} }
func writeFile(path string, content string) api.CallResult { func writeFile(path string, content string) model.CallResult {
ok, reason := toolutil.IsPathWithinCWD(path) ok, reason := toolutil.IsPathWithinCWD(path)
if !ok { if !ok {
return api.CallResult{Message: reason} return model.CallResult{Message: reason}
} }
err := os.WriteFile(path, []byte(content), 0644) err := os.WriteFile(path, []byte(content), 0644)
if err != nil { if err != nil {
return api.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())} return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
} }
return api.CallResult{} return model.CallResult{}
} }

View File

@ -1,447 +0,0 @@
package anthropic
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/provider"
)
const ANTHROPIC_VERSION = "2023-06-01"
type AnthropicClient struct {
APIKey string
BaseURL string
}
type ChatCompletionMessage struct {
Role string `json:"role"`
Content interface{} `json:"content"`
}
type Tool struct {
Name string `json:"name"`
Description string `json:"description"`
InputSchema InputSchema `json:"input_schema"`
}
type InputSchema struct {
Type string `json:"type"`
Properties map[string]Property `json:"properties"`
Required []string `json:"required"`
}
type Property struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
}
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []ChatCompletionMessage `json:"messages"`
System string `json:"system,omitempty"`
Tools []Tool `json:"tools,omitempty"`
MaxTokens int `json:"max_tokens"`
Temperature float32 `json:"temperature,omitempty"`
Stream bool `json:"stream"`
}
type ContentBlock struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input interface{} `json:"input,omitempty"`
partialJsonAccumulator string
}
type ChatCompletionResponse struct {
ID string `json:"id"`
Type string `json:"type"`
Role string `json:"role"`
Model string `json:"model"`
Content []ContentBlock `json:"content"`
StopReason string `json:"stop_reason"`
Usage Usage `json:"usage"`
}
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
type StreamEvent struct {
Type string `json:"type"`
Message interface{} `json:"message,omitempty"`
Index int `json:"index,omitempty"`
Delta interface{} `json:"delta,omitempty"`
}
func convertTools(tools []api.ToolSpec) []Tool {
anthropicTools := make([]Tool, len(tools))
for i, tool := range tools {
properties := make(map[string]Property)
for _, param := range tool.Parameters {
properties[param.Name] = Property{
Type: param.Type,
Description: param.Description,
Enum: param.Enum,
}
}
var required []string
for _, param := range tool.Parameters {
if param.Required {
required = append(required, param.Name)
}
}
anthropicTools[i] = Tool{
Name: tool.Name,
Description: tool.Description,
InputSchema: InputSchema{
Type: "object",
Properties: properties,
Required: required,
},
}
}
return anthropicTools
}
func createChatCompletionRequest(
params provider.RequestParameters,
messages []api.Message,
) (string, ChatCompletionRequest) {
requestMessages := make([]ChatCompletionMessage, 0, len(messages))
var systemMessage string
for _, m := range messages {
if m.Role == api.MessageRoleSystem {
systemMessage = m.Content
continue
}
var content interface{}
role := string(m.Role)
switch m.Role {
case api.MessageRoleToolCall:
role = "assistant"
contentBlocks := make([]map[string]interface{}, 0)
if m.Content != "" {
contentBlocks = append(contentBlocks, map[string]interface{}{
"type": "text",
"text": m.Content,
})
}
for _, toolCall := range m.ToolCalls {
contentBlocks = append(contentBlocks, map[string]interface{}{
"type": "tool_use",
"id": toolCall.ID,
"name": toolCall.Name,
"input": toolCall.Parameters,
})
}
content = contentBlocks
case api.MessageRoleToolResult:
role = "user"
contentBlocks := make([]map[string]interface{}, 0)
for _, result := range m.ToolResults {
contentBlock := map[string]interface{}{
"type": "tool_result",
"tool_use_id": result.ToolCallID,
"content": result.Result,
}
contentBlocks = append(contentBlocks, contentBlock)
}
content = contentBlocks
default:
content = m.Content
}
requestMessages = append(requestMessages, ChatCompletionMessage{
Role: role,
Content: content,
})
}
request := ChatCompletionRequest{
Model: params.Model,
Messages: requestMessages,
System: systemMessage,
MaxTokens: params.MaxTokens,
Temperature: params.Temperature,
}
if len(params.Toolbox) > 0 {
request.Tools = convertTools(params.Toolbox)
}
var prefill string
if len(messages) > 0 && messages[len(messages)-1].Role == api.MessageRoleAssistant {
// Prompting on an assitant message, use its content as prefill
prefill = messages[len(messages)-1].Content
}
return prefill, request
}
func (c *AnthropicClient) sendRequest(ctx context.Context, r ChatCompletionRequest) (*http.Response, error) {
jsonData, err := json.Marshal(r)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/v1/messages", bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
}
req.Header.Set("x-api-key", c.APIKey)
req.Header.Set("anthropic-version", ANTHROPIC_VERSION)
req.Header.Set("content-type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
if resp.StatusCode != 200 {
bytes, _ := io.ReadAll(resp.Body)
return resp, fmt.Errorf("%v", string(bytes))
}
return resp, err
}
func (c *AnthropicClient) CreateChatCompletion(
ctx context.Context,
params provider.RequestParameters,
messages []api.Message,
) (*api.Message, error) {
if len(messages) == 0 {
return nil, fmt.Errorf("can't create completion from no messages")
}
_, req := createChatCompletionRequest(params, messages)
req.Stream = false
resp, err := c.sendRequest(ctx, req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
var completionResp ChatCompletionResponse
err = json.NewDecoder(resp.Body).Decode(&completionResp)
if err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
return convertResponseToMessage(completionResp)
}
func (c *AnthropicClient) CreateChatCompletionStream(
ctx context.Context,
params provider.RequestParameters,
messages []api.Message,
output chan<- provider.Chunk,
) (*api.Message, error) {
if len(messages) == 0 {
return nil, fmt.Errorf("can't create completion from no messages")
}
prefill, req := createChatCompletionRequest(params, messages)
req.Stream = true
resp, err := c.sendRequest(ctx, req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
contentBlocks := make(map[int]*ContentBlock)
var finalMessage *ChatCompletionResponse
var firstChunkReceived bool
reader := bufio.NewReader(resp.Body)
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
line, err := reader.ReadBytes('\n')
if err != nil {
if err == io.EOF {
break
}
return nil, fmt.Errorf("error reading stream: %w", err)
}
line = bytes.TrimSpace(line)
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
continue
}
line = bytes.TrimPrefix(line, []byte("data: "))
var streamEvent StreamEvent
err = json.Unmarshal(line, &streamEvent)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal stream event: %w", err)
}
switch streamEvent.Type {
case "message_start":
finalMessage = &ChatCompletionResponse{}
err = json.Unmarshal(line, &struct {
Message *ChatCompletionResponse `json:"message"`
}{Message: finalMessage})
if err != nil {
return nil, fmt.Errorf("failed to unmarshal message_start: %w", err)
}
case "content_block_start":
var contentBlockStart struct {
Index int `json:"index"`
ContentBlock ContentBlock `json:"content_block"`
}
err = json.Unmarshal(line, &contentBlockStart)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal content_block_start: %w", err)
}
contentBlocks[contentBlockStart.Index] = &contentBlockStart.ContentBlock
case "content_block_delta":
if streamEvent.Index >= len(contentBlocks) {
return nil, fmt.Errorf("received delta for non-existent content block index: %d", streamEvent.Index)
}
block := contentBlocks[streamEvent.Index]
delta, ok := streamEvent.Delta.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("unexpected delta type: %T", streamEvent.Delta)
}
deltaType, ok := delta["type"].(string)
if !ok {
return nil, fmt.Errorf("delta missing type field")
}
switch deltaType {
case "text_delta":
if text, ok := delta["text"].(string); ok {
if !firstChunkReceived {
if prefill == "" {
// if there is no prefil, ensure we trim leading whitespace
text = strings.TrimSpace(text)
}
firstChunkReceived = true
}
block.Text += text
output <- provider.Chunk{
Content: text,
// rough, anthropic performs some chunking
TokenCount: uint(len(strings.Split(text, " "))),
}
}
case "input_json_delta":
if block.Type != "tool_use" {
return nil, fmt.Errorf("received input_json_delta for non-tool_use block")
}
if partialJSON, ok := delta["partial_json"].(string); ok {
block.partialJsonAccumulator += partialJSON
}
}
case "content_block_stop":
if streamEvent.Index >= len(contentBlocks) {
return nil, fmt.Errorf("received stop for non-existent content block index: %d", streamEvent.Index)
}
block := contentBlocks[streamEvent.Index]
if block.Type == "tool_use" && block.partialJsonAccumulator != "" {
var inputData map[string]interface{}
err := json.Unmarshal([]byte(block.partialJsonAccumulator), &inputData)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal accumulated JSON for tool use: %w", err)
}
block.Input = inputData
}
case "message_delta":
if finalMessage == nil {
return nil, fmt.Errorf("received message_delta before message_start")
}
delta, ok := streamEvent.Delta.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("unexpected delta type in message_delta: %T", streamEvent.Delta)
}
if stopReason, ok := delta["stop_reason"].(string); ok {
finalMessage.StopReason = stopReason
}
case "message_stop":
// End of the stream
goto END_STREAM
case "error":
return nil, fmt.Errorf("received error event: %v", streamEvent.Message)
default:
// Ignore unknown event types
}
}
}
END_STREAM:
if finalMessage == nil {
return nil, fmt.Errorf("no final message received")
}
finalMessage.Content = make([]ContentBlock, len(contentBlocks))
for _, v := range contentBlocks {
finalMessage.Content = append(finalMessage.Content, *v)
}
return convertResponseToMessage(*finalMessage)
}
func convertResponseToMessage(resp ChatCompletionResponse) (*api.Message, error) {
content := strings.Builder{}
var toolCalls []api.ToolCall
for _, block := range resp.Content {
switch block.Type {
case "text":
content.WriteString(block.Text)
case "tool_use":
parameters, ok := block.Input.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("unexpected type for tool call parameters: %T", block.Input)
}
toolCalls = append(toolCalls, api.ToolCall{
ID: block.ID,
Name: block.Name,
Parameters: parameters,
})
}
}
if len(toolCalls) > 0 {
return api.NewMessageWithToolCalls(content.String(), toolCalls), nil
}
return api.NewMessageWithAssistant(content.String()), nil
}

View File

@ -1,436 +0,0 @@
package google
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/provider"
)
type Client struct {
APIKey string
BaseURL string
}
type ContentPart struct {
Text string `json:"text,omitempty"`
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
FunctionResp *FunctionResponse `json:"functionResponse,omitempty"`
}
type FunctionCall struct {
Name string `json:"name"`
Args map[string]string `json:"args"`
}
type FunctionResponse struct {
Name string `json:"name"`
Response interface{} `json:"response"`
}
type Content struct {
Role string `json:"role"`
Parts []ContentPart `json:"parts"`
}
type GenerationConfig struct {
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
Temperature *float32 `json:"temperature,omitempty"`
TopP *float32 `json:"topP,omitempty"`
TopK *int `json:"topK,omitempty"`
}
type GenerateContentRequest struct {
Contents []Content `json:"contents"`
Tools []Tool `json:"tools,omitempty"`
SystemInstruction *Content `json:"systemInstruction,omitempty"`
GenerationConfig *GenerationConfig `json:"generationConfig,omitempty"`
}
type Candidate struct {
Content Content `json:"content"`
FinishReason string `json:"finishReason"`
Index int `json:"index"`
}
type UsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount"`
CandidatesTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
}
type GenerateContentResponse struct {
Candidates []Candidate `json:"candidates"`
UsageMetadata UsageMetadata `json:"usageMetadata"`
}
type Tool struct {
FunctionDeclarations []FunctionDeclaration `json:"functionDeclarations"`
}
type FunctionDeclaration struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters ToolParameters `json:"parameters"`
}
type ToolParameters struct {
Type string `json:"type"`
Properties map[string]ToolParameter `json:"properties,omitempty"`
Required []string `json:"required,omitempty"`
}
type ToolParameter struct {
Type string `json:"type"`
Description string `json:"description"`
Values []string `json:"values,omitempty"`
}
func convertTools(tools []api.ToolSpec) []Tool {
geminiTools := make([]Tool, len(tools))
for i, tool := range tools {
params := make(map[string]ToolParameter)
var required []string
for _, param := range tool.Parameters {
// TODO: proper enum handing
params[param.Name] = ToolParameter{
Type: param.Type,
Description: param.Description,
Values: param.Enum,
}
if param.Required {
required = append(required, param.Name)
}
}
geminiTools[i] = Tool{
FunctionDeclarations: []FunctionDeclaration{
{
Name: tool.Name,
Description: tool.Description,
Parameters: ToolParameters{
Type: "OBJECT",
Properties: params,
Required: required,
},
},
},
}
}
return geminiTools
}
func convertToolCallToGemini(toolCalls []api.ToolCall) []ContentPart {
converted := make([]ContentPart, len(toolCalls))
for i, call := range toolCalls {
args := make(map[string]string)
for k, v := range call.Parameters {
args[k] = fmt.Sprintf("%v", v)
}
converted[i].FunctionCall = &FunctionCall{
Name: call.Name,
Args: args,
}
}
return converted
}
func convertToolCallToAPI(functionCalls []FunctionCall) []api.ToolCall {
converted := make([]api.ToolCall, len(functionCalls))
for i, call := range functionCalls {
params := make(map[string]interface{})
for k, v := range call.Args {
params[k] = v
}
converted[i].Name = call.Name
converted[i].Parameters = params
}
return converted
}
func convertToolResultsToGemini(toolResults []api.ToolResult) ([]FunctionResponse, error) {
results := make([]FunctionResponse, len(toolResults))
for i, result := range toolResults {
var obj interface{}
err := json.Unmarshal([]byte(result.Result), &obj)
if err != nil {
return nil, fmt.Errorf("Could not unmarshal %s: %v", result.Result, err)
}
results[i] = FunctionResponse{
Name: result.ToolName,
Response: obj,
}
}
return results, nil
}
func createGenerateContentRequest(
params provider.RequestParameters,
messages []api.Message,
) (*GenerateContentRequest, error) {
requestContents := make([]Content, 0, len(messages))
startIdx := 0
var system string
if len(messages) > 0 && messages[0].Role == api.MessageRoleSystem {
system = messages[0].Content
startIdx = 1
}
for _, m := range messages[startIdx:] {
switch m.Role {
case "tool_call":
content := Content{
Role: "model",
Parts: convertToolCallToGemini(m.ToolCalls),
}
requestContents = append(requestContents, content)
case "tool_result":
results, err := convertToolResultsToGemini(m.ToolResults)
if err != nil {
return nil, err
}
// expand tool_result messages' results into multiple gemini messages
for _, result := range results {
content := Content{
Role: "function",
Parts: []ContentPart{
{
FunctionResp: &result,
},
},
}
requestContents = append(requestContents, content)
}
default:
var role string
switch m.Role {
case api.MessageRoleAssistant:
role = "model"
case api.MessageRoleUser:
role = "user"
}
if role == "" {
panic("Unhandled role: " + m.Role)
}
content := Content{
Role: role,
Parts: []ContentPart{
{
Text: m.Content,
},
},
}
requestContents = append(requestContents, content)
}
}
request := &GenerateContentRequest{
Contents: requestContents,
GenerationConfig: &GenerationConfig{
MaxOutputTokens: &params.MaxTokens,
Temperature: &params.Temperature,
TopP: &params.TopP,
},
}
if system != "" {
request.SystemInstruction = &Content{
Parts: []ContentPart{
{
Text: system,
},
},
}
}
if len(params.Toolbox) > 0 {
request.Tools = convertTools(params.Toolbox)
}
return request, nil
}
func (c *Client) sendRequest(req *http.Request) (*http.Response, error) {
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
if resp.StatusCode != 200 {
bytes, _ := io.ReadAll(resp.Body)
return resp, fmt.Errorf("%v", string(bytes))
}
return resp, err
}
func (c *Client) CreateChatCompletion(
ctx context.Context,
params provider.RequestParameters,
messages []api.Message,
) (*api.Message, error) {
if len(messages) == 0 {
return nil, fmt.Errorf("Can't create completion from no messages")
}
req, err := createGenerateContentRequest(params, messages)
if err != nil {
return nil, err
}
jsonData, err := json.Marshal(req)
if err != nil {
return nil, err
}
url := fmt.Sprintf(
"%s/v1beta/models/%s:generateContent?key=%s",
c.BaseURL, params.Model, c.APIKey,
)
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
resp, err := c.sendRequest(httpReq)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var completionResp GenerateContentResponse
err = json.NewDecoder(resp.Body).Decode(&completionResp)
if err != nil {
return nil, err
}
choice := completionResp.Candidates[0]
var content string
lastMessage := messages[len(messages)-1]
if lastMessage.Role.IsAssistant() {
content = lastMessage.Content
}
var toolCalls []FunctionCall
for _, part := range choice.Content.Parts {
if part.Text != "" {
content += part.Text
}
if part.FunctionCall != nil {
toolCalls = append(toolCalls, *part.FunctionCall)
}
}
if len(toolCalls) > 0 {
return api.NewMessageWithToolCalls(content, convertToolCallToAPI(toolCalls)), nil
}
return api.NewMessageWithAssistant(content), nil
}
func (c *Client) CreateChatCompletionStream(
ctx context.Context,
params provider.RequestParameters,
messages []api.Message,
output chan<- provider.Chunk,
) (*api.Message, error) {
if len(messages) == 0 {
return nil, fmt.Errorf("Can't create completion from no messages")
}
req, err := createGenerateContentRequest(params, messages)
if err != nil {
return nil, err
}
jsonData, err := json.Marshal(req)
if err != nil {
return nil, err
}
url := fmt.Sprintf(
"%s/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse",
c.BaseURL, params.Model, c.APIKey,
)
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
resp, err := c.sendRequest(httpReq)
if err != nil {
return nil, err
}
defer resp.Body.Close()
content := strings.Builder{}
lastMessage := messages[len(messages)-1]
if lastMessage.Role.IsAssistant() {
content.WriteString(lastMessage.Content)
}
var toolCalls []FunctionCall
reader := bufio.NewReader(resp.Body)
lastTokenCount := 0
for {
line, err := reader.ReadBytes('\n')
if err != nil {
if err == io.EOF {
break
}
return nil, err
}
line = bytes.TrimSpace(line)
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
continue
}
line = bytes.TrimPrefix(line, []byte("data: "))
var resp GenerateContentResponse
err = json.Unmarshal(line, &resp)
if err != nil {
return nil, err
}
tokens := resp.UsageMetadata.CandidatesTokenCount - lastTokenCount
lastTokenCount += tokens
choice := resp.Candidates[0]
for _, part := range choice.Content.Parts {
if part.FunctionCall != nil {
toolCalls = append(toolCalls, *part.FunctionCall)
} else if part.Text != "" {
output <- provider.Chunk{
Content: part.Text,
TokenCount: uint(tokens),
}
content.WriteString(part.Text)
}
}
}
if len(toolCalls) > 0 {
return api.NewMessageWithToolCalls(content.String(), convertToolCallToAPI(toolCalls)), nil
}
return api.NewMessageWithAssistant(content.String()), nil
}

View File

@ -1,183 +0,0 @@
package ollama
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/provider"
)
type OllamaClient struct {
BaseURL string
}
type OllamaMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type OllamaRequest struct {
Model string `json:"model"`
Messages []OllamaMessage `json:"messages"`
Stream bool `json:"stream"`
}
type OllamaResponse struct {
Model string `json:"model"`
CreatedAt string `json:"created_at"`
Message OllamaMessage `json:"message"`
Done bool `json:"done"`
TotalDuration uint64 `json:"total_duration,omitempty"`
LoadDuration uint64 `json:"load_duration,omitempty"`
PromptEvalCount uint64 `json:"prompt_eval_count,omitempty"`
PromptEvalDuration uint64 `json:"prompt_eval_duration,omitempty"`
EvalCount uint64 `json:"eval_count,omitempty"`
EvalDuration uint64 `json:"eval_duration,omitempty"`
}
func createOllamaRequest(
params provider.RequestParameters,
messages []api.Message,
) OllamaRequest {
requestMessages := make([]OllamaMessage, 0, len(messages))
for _, m := range messages {
message := OllamaMessage{
Role: string(m.Role),
Content: m.Content,
}
requestMessages = append(requestMessages, message)
}
request := OllamaRequest{
Model: params.Model,
Messages: requestMessages,
}
return request
}
func (c *OllamaClient) sendRequest(req *http.Request) (*http.Response, error) {
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
if resp.StatusCode != 200 {
bytes, _ := io.ReadAll(resp.Body)
return resp, fmt.Errorf("%v", string(bytes))
}
return resp, nil
}
func (c *OllamaClient) CreateChatCompletion(
ctx context.Context,
params provider.RequestParameters,
messages []api.Message,
) (*api.Message, error) {
if len(messages) == 0 {
return nil, fmt.Errorf("Can't create completion from no messages")
}
req := createOllamaRequest(params, messages)
req.Stream = false
jsonData, err := json.Marshal(req)
if err != nil {
return nil, err
}
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
resp, err := c.sendRequest(httpReq)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var completionResp OllamaResponse
err = json.NewDecoder(resp.Body).Decode(&completionResp)
if err != nil {
return nil, err
}
return api.NewMessageWithAssistant(completionResp.Message.Content), nil
}
func (c *OllamaClient) CreateChatCompletionStream(
ctx context.Context,
params provider.RequestParameters,
messages []api.Message,
output chan<- provider.Chunk,
) (*api.Message, error) {
if len(messages) == 0 {
return nil, fmt.Errorf("Can't create completion from no messages")
}
req := createOllamaRequest(params, messages)
req.Stream = true
jsonData, err := json.Marshal(req)
if err != nil {
return nil, err
}
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
resp, err := c.sendRequest(httpReq)
if err != nil {
return nil, err
}
defer resp.Body.Close()
content := strings.Builder{}
reader := bufio.NewReader(resp.Body)
for {
line, err := reader.ReadBytes('\n')
if err != nil {
if err == io.EOF {
break
}
return nil, err
}
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}
var streamResp OllamaResponse
err = json.Unmarshal(line, &streamResp)
if err != nil {
return nil, err
}
if len(streamResp.Message.Content) > 0 {
output <- provider.Chunk{
Content: streamResp.Message.Content,
TokenCount: 1,
}
content.WriteString(streamResp.Message.Content)
}
}
return api.NewMessageWithAssistant(content.String()), nil
}

View File

@ -1,343 +0,0 @@
package openai
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/provider"
)
type OpenAIClient struct {
APIKey string
BaseURL string
Headers map[string]string
}
type ChatCompletionMessage struct {
Role string `json:"role"`
Content string `json:"content,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}
type ToolCall struct {
Type string `json:"type"`
ID string `json:"id"`
Index *int `json:"index,omitempty"`
Function FunctionDefinition `json:"function"`
}
type FunctionDefinition struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters ToolParameters `json:"parameters"`
Arguments string `json:"arguments,omitempty"`
}
type ToolParameters struct {
Type string `json:"type"`
Properties map[string]ToolParameter `json:"properties,omitempty"`
Required []string `json:"required,omitempty"`
}
type ToolParameter struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
}
type Tool struct {
Type string `json:"type"`
Function FunctionDefinition `json:"function"`
}
type ChatCompletionRequest struct {
Model string `json:"model"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float32 `json:"temperature,omitempty"`
Messages []ChatCompletionMessage `json:"messages"`
N int `json:"n"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice string `json:"tool_choice,omitempty"`
Stream bool `json:"stream,omitempty"`
}
type ChatCompletionChoice struct {
Message ChatCompletionMessage `json:"message"`
}
type ChatCompletionResponse struct {
Choices []ChatCompletionChoice `json:"choices"`
}
type ChatCompletionStreamChoice struct {
Delta ChatCompletionMessage `json:"delta"`
}
type ChatCompletionStreamResponse struct {
Choices []ChatCompletionStreamChoice `json:"choices"`
}
func convertTools(tools []api.ToolSpec) []Tool {
openaiTools := make([]Tool, len(tools))
for i, tool := range tools {
openaiTools[i].Type = "function"
params := make(map[string]ToolParameter)
var required []string
for _, param := range tool.Parameters {
params[param.Name] = ToolParameter{
Type: param.Type,
Description: param.Description,
Enum: param.Enum,
}
if param.Required {
required = append(required, param.Name)
}
}
openaiTools[i].Function = FunctionDefinition{
Name: tool.Name,
Description: tool.Description,
Parameters: ToolParameters{
Type: "object",
Properties: params,
Required: required,
},
}
}
return openaiTools
}
func convertToolCallToOpenAI(toolCalls []api.ToolCall) []ToolCall {
converted := make([]ToolCall, len(toolCalls))
for i, call := range toolCalls {
converted[i].Type = "function"
converted[i].ID = call.ID
converted[i].Function.Name = call.Name
json, _ := json.Marshal(call.Parameters)
converted[i].Function.Arguments = string(json)
}
return converted
}
func convertToolCallToAPI(toolCalls []ToolCall) []api.ToolCall {
converted := make([]api.ToolCall, len(toolCalls))
for i, call := range toolCalls {
converted[i].ID = call.ID
converted[i].Name = call.Function.Name
json.Unmarshal([]byte(call.Function.Arguments), &converted[i].Parameters)
}
return converted
}
func createChatCompletionRequest(
params provider.RequestParameters,
messages []api.Message,
) ChatCompletionRequest {
requestMessages := make([]ChatCompletionMessage, 0, len(messages))
for _, m := range messages {
switch m.Role {
case "tool_call":
message := ChatCompletionMessage{}
message.Role = "assistant"
message.Content = m.Content
message.ToolCalls = convertToolCallToOpenAI(m.ToolCalls)
requestMessages = append(requestMessages, message)
case "tool_result":
// expand tool_result messages' results into multiple openAI messages
for _, result := range m.ToolResults {
message := ChatCompletionMessage{}
message.Role = "tool"
message.Content = result.Result
message.ToolCallID = result.ToolCallID
requestMessages = append(requestMessages, message)
}
default:
message := ChatCompletionMessage{}
message.Role = string(m.Role)
message.Content = m.Content
requestMessages = append(requestMessages, message)
}
}
request := ChatCompletionRequest{
Model: params.Model,
MaxTokens: params.MaxTokens,
Temperature: params.Temperature,
Messages: requestMessages,
N: 1, // limit responses to 1 "choice". we use choices[0] to reference it
}
if len(params.Toolbox) > 0 {
request.Tools = convertTools(params.Toolbox)
request.ToolChoice = "auto"
}
return request
}
func (c *OpenAIClient) sendRequest(ctx context.Context, r ChatCompletionRequest) (*http.Response, error) {
jsonData, err := json.Marshal(r)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/v1/chat/completions", bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.APIKey)
for header, val := range c.Headers {
req.Header.Set(header, val)
}
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
if resp.StatusCode != 200 {
bytes, _ := io.ReadAll(resp.Body)
return resp, fmt.Errorf("%v", string(bytes))
}
return resp, err
}
func (c *OpenAIClient) CreateChatCompletion(
ctx context.Context,
params provider.RequestParameters,
messages []api.Message,
) (*api.Message, error) {
if len(messages) == 0 {
return nil, fmt.Errorf("Can't create completion from no messages")
}
req := createChatCompletionRequest(params, messages)
resp, err := c.sendRequest(ctx, req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var completionResp ChatCompletionResponse
err = json.NewDecoder(resp.Body).Decode(&completionResp)
if err != nil {
return nil, err
}
choice := completionResp.Choices[0]
var content string
lastMessage := messages[len(messages)-1]
if lastMessage.Role.IsAssistant() {
content = lastMessage.Content + choice.Message.Content
} else {
content = choice.Message.Content
}
toolCalls := choice.Message.ToolCalls
if len(toolCalls) > 0 {
return api.NewMessageWithToolCalls(content, convertToolCallToAPI(toolCalls)), nil
}
return api.NewMessageWithAssistant(content), nil
}
func (c *OpenAIClient) CreateChatCompletionStream(
ctx context.Context,
params provider.RequestParameters,
messages []api.Message,
output chan<- provider.Chunk,
) (*api.Message, error) {
if len(messages) == 0 {
return nil, fmt.Errorf("Can't create completion from no messages")
}
req := createChatCompletionRequest(params, messages)
req.Stream = true
resp, err := c.sendRequest(ctx, req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
content := strings.Builder{}
toolCalls := []ToolCall{}
lastMessage := messages[len(messages)-1]
if lastMessage.Role.IsAssistant() {
content.WriteString(lastMessage.Content)
}
reader := bufio.NewReader(resp.Body)
for {
line, err := reader.ReadBytes('\n')
if err != nil {
if err == io.EOF {
break
}
return nil, err
}
line = bytes.TrimSpace(line)
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
continue
}
line = bytes.TrimPrefix(line, []byte("data: "))
if bytes.Equal(line, []byte("[DONE]")) {
break
}
var streamResp ChatCompletionStreamResponse
err = json.Unmarshal(line, &streamResp)
if err != nil {
return nil, err
}
delta := streamResp.Choices[0].Delta
if len(delta.ToolCalls) > 0 {
// Construct streamed tool_call arguments
for _, tc := range delta.ToolCalls {
if tc.Index == nil {
return nil, fmt.Errorf("Unexpected nil index for streamed tool call.")
}
if len(toolCalls) <= *tc.Index {
toolCalls = append(toolCalls, tc)
} else {
toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
}
}
}
if len(delta.Content) > 0 {
output <- provider.Chunk{
Content: delta.Content,
TokenCount: 1,
}
content.WriteString(delta.Content)
}
}
if len(toolCalls) > 0 {
return api.NewMessageWithToolCalls(content.String(), convertToolCallToAPI(toolCalls)), nil
}
return api.NewMessageWithAssistant(content.String()), nil
}

View File

@ -1,41 +0,0 @@
package provider
import (
"context"
"git.mlow.ca/mlow/lmcli/pkg/api"
)
type Chunk struct {
Content string
TokenCount uint
}
type RequestParameters struct {
Model string
MaxTokens int
Temperature float32
TopP float32
Toolbox []api.ToolSpec
}
type ChatCompletionProvider interface {
// CreateChatCompletion generates a chat completion response to the
// provided messages.
CreateChatCompletion(
ctx context.Context,
params RequestParameters,
messages []api.Message,
) (*api.Message, error)
// Like CreateChageCompletion, except the response is streamed via
// the output channel.
CreateChatCompletionStream(
ctx context.Context,
params RequestParameters,
messages []api.Message,
chunks chan<- Chunk,
) (*api.Message, error)
}

View File

@ -1,67 +0,0 @@
package bubbles
import (
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
type ConfirmPrompt struct {
Question string
Style lipgloss.Style
Payload interface{}
value bool
answered bool
focused bool
}
func NewConfirmPrompt(question string, payload interface{}) ConfirmPrompt {
return ConfirmPrompt{
Question: question,
Style: lipgloss.NewStyle(),
Payload: payload,
focused: true, // focus by default
}
}
type MsgConfirmPromptAnswered struct {
Value bool
Payload interface{}
}
func (b ConfirmPrompt) Update(msg tea.Msg) (ConfirmPrompt, tea.Cmd) {
switch msg := msg.(type) {
case tea.KeyMsg:
if !b.focused || b.answered {
return b, nil
}
switch msg.String() {
case "y", "Y":
b.value = true
b.answered = true
b.focused = false
return b, func() tea.Msg { return MsgConfirmPromptAnswered{true, b.Payload} }
case "n", "N", "esc":
b.value = false
b.answered = true
b.focused = false
return b, func() tea.Msg { return MsgConfirmPromptAnswered{false, b.Payload} }
}
}
return b, nil
}
func (b ConfirmPrompt) View() string {
return b.Style.Render(b.Question) + lipgloss.NewStyle().Faint(true).Render(" (y/n)")
}
func (b *ConfirmPrompt) Focus() {
b.focused = true
}
func (b *ConfirmPrompt) Blur() {
b.focused = false
}
func (b ConfirmPrompt) Focused() bool {
return b.focused
}

View File

@ -1,260 +0,0 @@
package list
import (
"fmt"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
"github.com/charmbracelet/bubbles/textinput"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
type Option struct {
Label string
Value interface{}
}
type OptionGroup struct {
Name string
Options []Option
}
type Model struct {
ID int
HeaderStyle lipgloss.Style
ItemStyle lipgloss.Style
SelectedStyle lipgloss.Style
ItemRender func(Option, bool) string
Width int
Height int
optionGroups []OptionGroup
selected int
filterInput textinput.Model
filteredIndices []filteredIndex
content viewport.Model
itemYOffsets []int
}
type filteredIndex struct {
groupIndex int
optionIndex int
}
type MsgOptionSelected struct {
ID int
Option Option
}
func New(opts []Option) Model {
return NewWithGroups([]OptionGroup{{Options: opts}})
}
func NewWithGroups(groups []OptionGroup) Model {
ti := textinput.New()
ti.Prompt = "/"
ti.PromptStyle = lipgloss.NewStyle().Faint(true)
m := Model{
HeaderStyle: lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("12")).Padding(1, 0, 1, 1),
ItemStyle: lipgloss.NewStyle(),
SelectedStyle: lipgloss.NewStyle().Faint(true).Foreground(lipgloss.Color("6")),
optionGroups: groups,
selected: 0,
filterInput: ti,
filteredIndices: make([]filteredIndex, 0),
content: viewport.New(0, 0),
itemYOffsets: make([]int, 0),
}
m.filterItems()
m.content.SetContent(m.renderOptionsList())
return m
}
func (m *Model) Focused() {
m.filterInput.Focused()
}
func (m *Model) Focus() {
m.filterInput.Focus()
}
func (m *Model) Blur() {
m.filterInput.Blur()
}
func (m *Model) filterItems() {
filterText := strings.ToLower(m.filterInput.Value())
var prevSelection *filteredIndex
if m.selected <= len(m.filteredIndices)-1 {
prevSelection = &m.filteredIndices[m.selected]
}
m.filteredIndices = make([]filteredIndex, 0)
for groupIndex, group := range m.optionGroups {
for optionIndex, option := range group.Options {
if filterText == "" ||
strings.Contains(strings.ToLower(option.Label), filterText) ||
(group.Name != "" && strings.Contains(strings.ToLower(group.Name), filterText)) {
m.filteredIndices = append(m.filteredIndices, filteredIndex{groupIndex, optionIndex})
}
}
}
found := false
if len(m.filteredIndices) > 0 && prevSelection != nil {
// Preserve previous selection if possible
for i, filterIdx := range m.filteredIndices {
if prevSelection.groupIndex == filterIdx.groupIndex && prevSelection.optionIndex == filterIdx.optionIndex {
m.selected = i
found = true
break
}
}
}
if !found {
m.selected = 0
}
}
func (m *Model) Update(msg tea.Msg) (Model, tea.Cmd) {
var cmd tea.Cmd
switch msg := msg.(type) {
case tea.KeyMsg:
if m.filterInput.Focused() {
switch msg.String() {
case "esc":
m.filterInput.Blur()
m.filterInput.SetValue("")
m.filterItems()
m.refreshContent()
return *m, shared.KeyHandled(msg)
case "enter":
m.filterInput.Blur()
m.refreshContent()
break
case "up", "down":
break
default:
m.filterInput, cmd = m.filterInput.Update(msg)
m.filterItems()
m.refreshContent()
return *m, cmd
}
}
switch msg.String() {
case "up", "k":
m.moveSelection(-1)
return *m, shared.KeyHandled(msg)
case "down", "j":
m.moveSelection(1)
return *m, shared.KeyHandled(msg)
case "enter":
return *m, func() tea.Msg {
idx := m.filteredIndices[m.selected]
return MsgOptionSelected{
ID: m.ID,
Option: m.optionGroups[idx.groupIndex].Options[idx.optionIndex],
}
}
case "/":
m.filterInput.Focus()
return *m, textinput.Blink
}
}
m.content, cmd = m.content.Update(msg)
return *m, cmd
}
func (m *Model) refreshContent() {
m.content.SetContent(m.renderOptionsList())
m.ensureSelectedVisible()
}
func (m *Model) ensureSelectedVisible() {
if m.selected == 0 {
m.content.GotoTop()
} else if m.selected == len(m.filteredIndices)-1 {
m.content.GotoBottom()
} else {
tuiutil.ScrollIntoView(&m.content, m.itemYOffsets[m.selected], 0)
}
}
func (m *Model) moveSelection(delta int) {
prev := m.selected
m.selected = min(len(m.filteredIndices)-1, max(0, m.selected+delta))
if prev != m.selected {
m.refreshContent()
}
}
func (m *Model) View() string {
filter := ""
if m.filterInput.Focused() {
m.filterInput.Width = m.Width
filter = m.filterInput.View()
}
contentHeight := m.Height - tuiutil.Height(filter)
m.content.Width, m.content.Height = m.Width, contentHeight
parts := []string{m.content.View()}
if filter != "" {
parts = append(parts, filter)
}
return lipgloss.JoinVertical(lipgloss.Left, parts...)
}
func (m *Model) renderOptionsList() string {
yOffset := 0
lastGroupIndex := -1
m.itemYOffsets = make([]int, len(m.filteredIndices))
var sb strings.Builder
for i, idx := range m.filteredIndices {
if idx.groupIndex != lastGroupIndex {
group := m.optionGroups[idx.groupIndex].Name
if group != "" {
headingStr := m.HeaderStyle.Render(group)
yOffset += tuiutil.Height(headingStr)
sb.WriteString(headingStr)
sb.WriteRune('\n')
}
lastGroupIndex = idx.groupIndex
}
m.itemYOffsets[i] = yOffset
option := m.optionGroups[idx.groupIndex].Options[idx.optionIndex]
var item string
if m.ItemRender != nil {
item = m.ItemRender(option, i == m.selected)
} else {
prefix := " "
if i == m.selected {
prefix = "> "
item = m.SelectedStyle.Render(option.Label)
} else {
item = m.ItemStyle.Render(option.Label)
}
item = fmt.Sprintf("%s%s", prefix, item)
}
sb.WriteString(item)
yOffset += tuiutil.Height(item)
if i < len(m.filteredIndices)-1 {
sb.WriteRune('\n')
}
}
return sb.String()
}

View File

@ -1,281 +0,0 @@
package model
import (
"context"
"fmt"
"git.mlow.ca/mlow/lmcli/pkg/agents"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/provider"
"github.com/charmbracelet/lipgloss"
)
type AppModel struct {
Ctx *lmcli.Context
Conversations conversation.ConversationList
Conversation conversation.Conversation
Messages []conversation.Message
Model string
ProviderName string
Provider provider.ChatCompletionProvider
Agent *lmcli.Agent
}
func NewAppModel(ctx *lmcli.Context, initialConversation *conversation.Conversation) *AppModel {
app := &AppModel{
Ctx: ctx,
Model: *ctx.Config.Defaults.Model,
}
if initialConversation == nil {
app.NewConversation()
} else {
}
model, provider, _, _ := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
app.Model = model
app.ProviderName = provider
app.Agent = ctx.GetAgent(ctx.Config.Defaults.Agent)
return app
}
var (
defaultStyle = lipgloss.NewStyle().Faint(true)
accentStyle = defaultStyle.Foreground(lipgloss.Color("6"))
)
func (a *AppModel) ActiveModel(style lipgloss.Style) string {
defaultStyle := style.Inherit(defaultStyle)
accentStyle := style.Inherit(accentStyle)
return defaultStyle.Render(a.Model) + accentStyle.Render("@") + defaultStyle.Render(a.ProviderName)
}
type MessageCycleDirection int
const (
CycleNext MessageCycleDirection = 1
CyclePrev MessageCycleDirection = -1
)
func (m *AppModel) ClearConversation() {
m.Conversation = conversation.Conversation{}
m.Messages = []conversation.Message{}
}
func (m *AppModel) ApplySystemPrompt() {
var system string
agent := m.Ctx.GetAgent(m.Ctx.Config.Defaults.Agent)
if agent != nil && agent.SystemPrompt != "" {
system = agent.SystemPrompt
}
if system == "" {
system = m.Ctx.DefaultSystemPrompt()
}
if system != "" {
m.Messages = conversation.ApplySystemPrompt(m.Messages, system, false)
}
}
func (m *AppModel) NewConversation() {
m.ClearConversation()
m.ApplySystemPrompt()
}
func (a *AppModel) LoadConversationMessages() ([]conversation.Message, error) {
messages, err := a.Ctx.Conversations.PathToLeaf(a.Conversation.SelectedRoot)
if err != nil {
return nil, fmt.Errorf("Could not load conversation messages: %v %v", a.Conversation.SelectedRoot, err)
}
return messages, nil
}
func (a *AppModel) GenerateConversationTitle(messages []conversation.Message) (string, error) {
return cmdutil.GenerateTitle(a.Ctx, messages)
}
func (a *AppModel) CloneMessage(message conversation.Message, selected bool) (*conversation.Message, error) {
msg, _, err := a.Ctx.Conversations.CloneBranch(message)
if err != nil {
return nil, fmt.Errorf("Could not clone message: %v", err)
}
if selected {
if msg.Parent == nil {
msg.Conversation.SelectedRoot = msg
err = a.Ctx.Conversations.UpdateConversation(msg.Conversation)
} else {
msg.Parent.SelectedReply = msg
err = a.Ctx.Conversations.UpdateMessage(msg.Parent)
}
if err != nil {
return nil, fmt.Errorf("Could not update selected message: %v", err)
}
}
return msg, nil
}
func (a *AppModel) UpdateMessageContent(message *conversation.Message) error {
return a.Ctx.Conversations.UpdateMessage(message)
}
func cycleSelectedMessage(selected *conversation.Message, choices []conversation.Message, dir MessageCycleDirection) (*conversation.Message, error) {
currentIndex := -1
for i, reply := range choices {
if reply.ID == selected.ID {
currentIndex = i
break
}
}
if currentIndex < 0 {
return nil, fmt.Errorf("Selected message %d not found in choices, this is a bug", selected.ID)
}
var next int
if dir == CyclePrev {
next = (currentIndex - 1 + len(choices)) % len(choices)
} else {
next = (currentIndex + 1) % len(choices)
}
return &choices[next], nil
}
func (a *AppModel) CycleSelectedRoot(conv *conversation.Conversation, dir MessageCycleDirection) (*conversation.Message, error) {
if len(conv.RootMessages) < 2 {
return nil, nil
}
nextRoot, err := cycleSelectedMessage(conv.SelectedRoot, conv.RootMessages, dir)
if err != nil {
return nil, err
}
conv.SelectedRoot = nextRoot
err = a.Ctx.Conversations.UpdateConversation(conv)
if err != nil {
return nil, fmt.Errorf("Could not update conversation SelectedRoot: %v", err)
}
return nextRoot, nil
}
func (a *AppModel) CycleSelectedReply(message *conversation.Message, dir MessageCycleDirection) (*conversation.Message, error) {
if len(message.Replies) < 2 {
return nil, nil
}
nextReply, err := cycleSelectedMessage(message.SelectedReply, message.Replies, dir)
if err != nil {
return nil, err
}
message.SelectedReply = nextReply
err = a.Ctx.Conversations.UpdateMessage(message)
if err != nil {
return nil, fmt.Errorf("Could not update message SelectedReply: %v", err)
}
return nextReply, nil
}
func (a *AppModel) PersistMessages() ([]conversation.Message, error) {
messages := make([]conversation.Message, len(a.Messages))
for i, m := range a.Messages {
if i == 0 && m.ID == 0 {
m.Conversation = &a.Conversation
m, err := a.Ctx.Conversations.SaveMessage(m)
if err != nil {
return nil, fmt.Errorf("Could not create first message %d: %v", a.Messages[i].ID, err)
}
messages[i] = *m
// let's set the conversation root message(s), as this is the first message
m.Conversation.RootMessages = []conversation.Message{*m}
m.Conversation.SelectedRoot = &m.Conversation.RootMessages[0]
a.Ctx.Conversations.UpdateConversation(m.Conversation)
} else if m.ID > 0 {
// Existing message, update it
err := a.Ctx.Conversations.UpdateMessage(&m)
if err != nil {
return nil, fmt.Errorf("Could not update message %d: %v", a.Messages[i].ID, err)
}
messages[i] = m
} else if i > 0 {
// New message, reply to previous
replies, err := a.Ctx.Conversations.Reply(&messages[i-1], m)
if err != nil {
return nil, fmt.Errorf("Could not reply with new message: %v", err)
}
messages[i] = replies[0]
} else {
return nil, fmt.Errorf("No messages to reply to (this is a bug)")
}
}
return messages, nil
}
func (a *AppModel) PersistConversation() (conversation.Conversation, error) {
conv := a.Conversation
var err error
if a.Conversation.ID > 0 {
err = a.Ctx.Conversations.UpdateConversation(&conv)
} else {
c, e := a.Ctx.Conversations.CreateConversation("")
err = e
if e == nil && c != nil {
conv = *c
}
}
return conv, err
}
func (a *AppModel) ExecuteToolCalls(toolCalls []api.ToolCall) ([]api.ToolResult, error) {
agent := a.Ctx.GetAgent(a.Ctx.Config.Defaults.Agent)
if agent == nil {
return nil, fmt.Errorf("Attempted to execute tool calls with no agent configured")
}
return agents.ExecuteToolCalls(toolCalls, agent.Toolbox)
}
func (a *AppModel) Prompt(
messages []conversation.Message,
chatReplyChunks chan provider.Chunk,
stopSignal chan struct{},
) (*conversation.Message, error) {
model, _, p, err := a.Ctx.GetModelProvider(a.Model, a.ProviderName)
if err != nil {
return nil, err
}
params := provider.RequestParameters{
Model: model,
MaxTokens: *a.Ctx.Config.Defaults.MaxTokens,
Temperature: *a.Ctx.Config.Defaults.Temperature,
}
if a.Agent != nil {
params.Toolbox = a.Agent.Toolbox
}
ctx, cancel := context.WithCancel(context.Background())
go func() {
select {
case <-stopSignal:
cancel()
}
}()
msg, err := p.CreateChatCompletionStream(
ctx, params, conversation.MessagesToAPI(messages), chatReplyChunks,
)
if msg != nil {
msg := conversation.MessageFromAPI(*msg)
msg.Metadata.GenerationProvider = &a.ProviderName
msg.Metadata.GenerationModel = &a.Model
return &msg, err
}
return nil, err
}

View File

@ -1,66 +0,0 @@
package shared
import (
tea "github.com/charmbracelet/bubbletea"
)
// An analogue to tea.Model with support for checking if the model has been
// initialized before
type ViewModel interface {
Init() tea.Cmd
Update(tea.Msg) (ViewModel, tea.Cmd)
// View methods
Header(width int) string
// Render the view's main content into a container of the given dimensions
Content(width, height int) string
Footer(width int) string
}
type View int
const (
ViewChat View = iota
ViewConversations
ViewSettings
//StateHelp
)
type (
// send to change the current state
MsgViewChange View
// sent to a state when it is entered, with the view we're leaving
MsgViewEnter View
// sent when a recoverable error occurs (displayed to user)
MsgError struct { Err error }
// sent when the view has handled a key input
MsgKeyHandled tea.KeyMsg
)
func ViewEnter(from View) tea.Cmd {
return func() tea.Msg {
return MsgViewEnter(from)
}
}
func ChangeView(view View) tea.Cmd {
return func() tea.Msg {
return MsgViewChange(view)
}
}
func KeyHandled(key tea.KeyMsg) tea.Cmd {
return func() tea.Msg {
return MsgKeyHandled(key)
}
}
func WrapError(err error) tea.Cmd {
return func() tea.Msg {
return MsgError{ Err: err }
}
}
func AsMsgError(err error) MsgError {
return MsgError{ Err: err }
}

View File

@ -1,8 +0,0 @@
package styles
import "github.com/charmbracelet/lipgloss"
var Header = lipgloss.NewStyle().
PaddingLeft(1).
PaddingRight(1).
Background(lipgloss.Color("0"))

View File

@ -1,163 +1,845 @@
package tui package tui
import ( // The terminal UI for lmcli, launched from the `lmcli chat` command
"fmt" // TODO:
// - conversation list view
// - change model
// - rename conversation
// - set system prompt
// - system prompt library?
"git.mlow.ca/mlow/lmcli/pkg/conversation" import (
"context"
"fmt"
"strings"
"time"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/tui/model" models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared" "github.com/charmbracelet/bubbles/spinner"
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util" "github.com/charmbracelet/bubbles/textarea"
"git.mlow.ca/mlow/lmcli/pkg/tui/views/chat" "github.com/charmbracelet/bubbles/viewport"
"git.mlow.ca/mlow/lmcli/pkg/tui/views/conversations"
"git.mlow.ca/mlow/lmcli/pkg/tui/views/settings"
tea "github.com/charmbracelet/bubbletea" tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss" "github.com/charmbracelet/lipgloss"
"github.com/muesli/reflow/wordwrap"
) )
type Model struct { type focusState int
App *model.AppModel
// window size const (
focusInput focusState = iota
focusMessages
)
type editorTarget int
const (
input editorTarget = iota
selectedMessage
)
type model struct {
width int width int
height int height int
// errors to display ctx *lmcli.Context
// TODO: allow dismissing errors convShortname string
errs []error
activeView shared.View // application state
views map[shared.View]shared.ViewModel conversation *models.Conversation
messages []models.Message
waitingForReply bool
editorTarget editorTarget
stopSignal chan interface{}
replyChan chan models.Message
replyChunkChan chan string
persistence bool // whether we will save new messages in the conversation
err error
// ui state
focus focusState
wrap bool // whether message content is wrapped to viewport width
status string // a general status message
highlightCache []string // a cache of syntax highlighted message content
messageOffsets []int
selectedMessage int
// ui elements
content viewport.Model
input textarea.Model
spinner spinner.Model
} }
func initialModel(ctx *lmcli.Context, opts LaunchOptions) *Model { type message struct {
app := model.NewAppModel(ctx, opts.InitialConversation) role string
content string
}
m := Model{ // custom tea.Msg types
App: app, type (
activeView: opts.InitialView, // sent on each chunk received from LLM
views: map[shared.View]shared.ViewModel{ msgResponseChunk string
shared.ViewChat: chat.Chat(app), // sent when response is finished being received
shared.ViewConversations: conversations.Conversations(app), msgResponseEnd string
shared.ViewSettings: settings.Settings(app), // a special case of msgError that stops the response waiting animation
}, msgResponseError error
// sent on each completed reply
msgAssistantReply models.Message
// sent when a conversation is (re)loaded
msgConversationLoaded *models.Conversation
// sent when a new conversation title is set
msgConversationTitleChanged string
// send when a conversation's messages are laoded
msgMessagesLoaded []models.Message
// sent when an error occurs
msgError error
)
// styles
var (
userStyle = lipgloss.NewStyle().Faint(true).Bold(true).Foreground(lipgloss.Color("10"))
assistantStyle = lipgloss.NewStyle().Faint(true).Bold(true).Foreground(lipgloss.Color("12"))
messageStyle = lipgloss.NewStyle().PaddingLeft(2).PaddingRight(2)
headerStyle = lipgloss.NewStyle().
Background(lipgloss.Color("0"))
conversationStyle = lipgloss.NewStyle().
MarginTop(1).
MarginBottom(1)
footerStyle = lipgloss.NewStyle().
BorderTop(true).
BorderStyle(lipgloss.NormalBorder())
)
func (m model) Init() tea.Cmd {
return tea.Batch(
textarea.Blink,
m.spinner.Tick,
m.loadConversation(m.convShortname),
m.waitForChunk(),
m.waitForReply(),
)
}
func wrapError(err error) tea.Cmd {
return func() tea.Msg {
return msgError(err)
} }
return &m
} }
func (m *Model) Init() tea.Cmd { func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmds []tea.Cmd var cmds []tea.Cmd
for _, v := range m.views {
// Init views switch msg := msg.(type) {
cmds = append(cmds, v.Init()) case msgTempfileEditorClosed:
contents := string(msg)
switch m.editorTarget {
case input:
m.input.SetValue(contents)
case selectedMessage:
m.setMessageContents(m.selectedMessage, contents)
if m.persistence && m.messages[m.selectedMessage].ID > 0 {
// update persisted message
err := m.ctx.Store.UpdateMessage(&m.messages[m.selectedMessage])
if err != nil {
cmds = append(cmds, wrapError(fmt.Errorf("Could not save edited message: %v", err)))
} }
cmds = append(cmds, func() tea.Msg { }
// Initial view change m.updateContent()
return shared.MsgViewChange(m.activeView) }
case tea.KeyMsg:
switch msg.String() {
case "ctrl+c":
if m.waitingForReply {
m.stopSignal <- ""
} else {
return m, tea.Quit
}
case "ctrl+p":
m.persistence = !m.persistence
case "ctrl+w":
m.wrap = !m.wrap
m.updateContent()
case "q":
if m.focus != focusInput {
return m, tea.Quit
}
default:
var inputHandled tea.Cmd
switch m.focus {
case focusInput:
inputHandled = m.handleInputKey(msg)
case focusMessages:
inputHandled = m.handleMessagesKey(msg)
}
if inputHandled != nil {
return m, inputHandled
}
}
case tea.WindowSizeMsg:
m.width = msg.Width
m.height = msg.Height
m.content.Width = msg.Width
m.content.Height = msg.Height - m.getFixedComponentHeight()
m.input.SetWidth(msg.Width - 1)
m.updateContent()
case msgConversationLoaded:
m.conversation = (*models.Conversation)(msg)
cmds = append(cmds, m.loadMessages(m.conversation))
case msgMessagesLoaded:
m.setMessages(msg)
m.updateContent()
case msgResponseChunk:
chunk := string(msg)
last := len(m.messages) - 1
if last >= 0 && m.messages[last].Role == models.MessageRoleAssistant {
m.setMessageContents(last, m.messages[last].Content+chunk)
} else {
m.addMessage(models.Message{
Role: models.MessageRoleAssistant,
Content: chunk,
}) })
return tea.Batch(cmds...) }
m.updateContent()
cmds = append(cmds, m.waitForChunk()) // wait for the next chunk
case msgAssistantReply:
// the last reply that was being worked on is finished
reply := models.Message(msg)
last := len(m.messages) - 1
if last < 0 {
panic("Unexpected empty messages handling msgReply")
}
m.setMessageContents(last, strings.TrimSpace(m.messages[last].Content))
if m.messages[last].Role == models.MessageRoleAssistant {
// the last message was an assistant message, so this is a continuation
if reply.Role == models.MessageRoleToolCall {
// update last message rrole to tool call
m.messages[last].Role = models.MessageRoleToolCall
}
} else {
m.addMessage(reply)
}
if m.persistence {
var err error
if m.conversation.ID == 0 {
err = m.ctx.Store.SaveConversation(m.conversation)
}
if err != nil {
cmds = append(cmds, wrapError(err))
} else {
cmds = append(cmds, m.persistConversation())
}
}
if m.conversation.Title == "" {
cmds = append(cmds, m.generateConversationTitle())
}
m.updateContent()
cmds = append(cmds, m.waitForReply())
case msgResponseEnd:
m.waitingForReply = false
last := len(m.messages) - 1
if last < 0 {
panic("Unexpected empty messages handling msgResponseEnd")
}
m.setMessageContents(last, strings.TrimSpace(m.messages[last].Content))
m.updateContent()
m.status = "Press ctrl+s to send"
case msgResponseError:
m.waitingForReply = false
m.status = "Press ctrl+s to send"
m.err = error(msg)
case msgConversationTitleChanged:
title := string(msg)
m.conversation.Title = title
if m.persistence {
err := m.ctx.Store.SaveConversation(m.conversation)
if err != nil {
cmds = append(cmds, wrapError(err))
}
}
case msgError:
m.err = error(msg)
}
var cmd tea.Cmd
m.spinner, cmd = m.spinner.Update(msg)
if cmd != nil {
cmds = append(cmds, cmd)
}
inputCaptured := false
m.input, cmd = m.input.Update(msg)
if cmd != nil {
inputCaptured = true
cmds = append(cmds, cmd)
}
if !inputCaptured {
m.content, cmd = m.content.Update(msg)
if cmd != nil {
cmds = append(cmds, cmd)
}
}
return m, tea.Batch(cmds...)
} }
func (m *Model) handleGlobalInput(msg tea.KeyMsg) tea.Cmd { func (m model) View() string {
view, cmd := m.views[m.activeView].Update(msg) if m.width == 0 {
m.views[m.activeView] = view // this is the case upon initial startup, but it's also a safe bet that
if cmd != nil { // we can just skip rendering if the terminal is really 0 width...
return cmd // without this, the m.*View() functions may crash
return ""
} }
sections := make([]string, 0, 6)
sections = append(sections, m.headerView())
sections = append(sections, m.contentView())
error := m.errorView()
if error != "" {
sections = append(sections, error)
}
sections = append(sections, m.inputView())
sections = append(sections, m.footerView())
return lipgloss.JoinVertical(
lipgloss.Left,
sections...,
)
}
// returns the total height of "fixed" components, which are those which don't
// change height dependent on window size.
func (m *model) getFixedComponentHeight() int {
h := 0
h += m.input.Height()
h += lipgloss.Height(m.headerView())
h += lipgloss.Height(m.footerView())
errorView := m.errorView()
if errorView != "" {
h += lipgloss.Height(errorView)
}
return h
}
func (m *model) headerView() string {
titleStyle := lipgloss.NewStyle().
PaddingLeft(1).
PaddingRight(1).
Bold(true)
var title string
if m.conversation != nil && m.conversation.Title != "" {
title = m.conversation.Title
} else {
title = "Untitled"
}
part := titleStyle.Render(title)
return headerStyle.Width(m.width).Render(part)
}
func (m *model) contentView() string {
return m.content.View()
}
func (m *model) errorView() string {
if m.err == nil {
return ""
}
return lipgloss.NewStyle().
Width(m.width).
AlignHorizontal(lipgloss.Center).
Bold(true).
Foreground(lipgloss.Color("1")).
Render(fmt.Sprintf("%s", m.err))
}
func (m *model) inputView() string {
return m.input.View()
}
func (m *model) footerView() string {
segmentStyle := lipgloss.NewStyle().PaddingLeft(1).PaddingRight(1).Faint(true)
segmentSeparator := "|"
savingStyle := segmentStyle.Copy().Bold(true)
saving := ""
if m.persistence {
saving = savingStyle.Foreground(lipgloss.Color("2")).Render("✅💾")
} else {
saving = savingStyle.Foreground(lipgloss.Color("1")).Render("❌💾")
}
status := m.status
if m.waitingForReply {
status += m.spinner.View()
}
leftSegments := []string{
saving,
segmentStyle.Render(status),
}
rightSegments := []string{
segmentStyle.Render(fmt.Sprintf("Model: %s", *m.ctx.Config.Defaults.Model)),
}
left := strings.Join(leftSegments, segmentSeparator)
right := strings.Join(rightSegments, segmentSeparator)
totalWidth := lipgloss.Width(left) + lipgloss.Width(right)
remaining := m.width - totalWidth
var padding string
if remaining > 0 {
padding = strings.Repeat(" ", remaining)
}
footer := left + padding + right
if remaining < 0 {
ellipses := "... "
// this doesn't work very well, due to trying to trim a string with
// ansii chars already in it
footer = footer[:(len(footer)+remaining)-len(ellipses)-3] + ellipses
}
return footerStyle.Width(m.width).Render(footer)
}
func initialModel(ctx *lmcli.Context, convShortname string) model {
m := model{
ctx: ctx,
convShortname: convShortname,
conversation: &models.Conversation{},
persistence: true,
stopSignal: make(chan interface{}),
replyChan: make(chan models.Message),
replyChunkChan: make(chan string),
wrap: true,
selectedMessage: -1,
}
m.content = viewport.New(0, 0)
m.input = textarea.New()
m.input.CharLimit = 0
m.input.Placeholder = "Enter a message"
m.input.FocusedStyle.CursorLine = lipgloss.NewStyle()
m.input.ShowLineNumbers = false
m.input.SetHeight(4)
m.input.Focus()
m.spinner = spinner.New(spinner.WithSpinner(
spinner.Spinner{
Frames: []string{
". ",
".. ",
"...",
".. ",
". ",
" ",
},
FPS: time.Second / 3,
},
))
m.waitingForReply = false
m.status = "Press ctrl+s to send"
return m
}
// fraction is the fraction of the total screen height into view the offset
// should be scrolled into view. 0.5 = items will be snapped to middle of
// view
func scrollIntoView(vp *viewport.Model, offset int, fraction float32) {
currentOffset := vp.YOffset
if offset >= currentOffset && offset < currentOffset+vp.Height {
return
}
distance := currentOffset - offset
if distance < 0 {
// we should scroll down until it just comes into view
vp.SetYOffset(currentOffset - (distance + (vp.Height - int(float32(vp.Height)*fraction))) + 1)
} else {
// we should scroll up
vp.SetYOffset(currentOffset - distance - int(float32(vp.Height)*fraction))
}
}
func (m *model) handleMessagesKey(msg tea.KeyMsg) tea.Cmd {
switch msg.String() { switch msg.String() {
case "ctrl+c", "ctrl+q": case "tab":
return tea.Quit m.focus = focusInput
m.updateContent()
m.input.Focus()
case "e":
message := m.messages[m.selectedMessage]
cmd := openTempfileEditor("message.*.md", message.Content, "# Edit the message below\n")
m.editorTarget = selectedMessage
return cmd
case "ctrl+k":
if m.selectedMessage > 0 && len(m.messages) == len(m.messageOffsets) {
m.selectedMessage--
m.updateContent()
offset := m.messageOffsets[m.selectedMessage]
scrollIntoView(&m.content, offset, 0.1)
}
case "ctrl+j":
if m.selectedMessage < len(m.messages)-1 && len(m.messages) == len(m.messageOffsets) {
m.selectedMessage++
m.updateContent()
offset := m.messageOffsets[m.selectedMessage]
scrollIntoView(&m.content, offset, 0.1)
}
case "ctrl+r":
// resubmit the conversation with all messages up until and including
// the selected message
if len(m.messages) == 0 {
return nil
}
m.messages = m.messages[:m.selectedMessage+1]
m.highlightCache = m.highlightCache[:m.selectedMessage+1]
m.updateContent()
m.content.GotoBottom()
return m.promptLLM()
} }
return nil return nil
} }
func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { func (m *model) handleInputKey(msg tea.KeyMsg) tea.Cmd {
switch msg := msg.(type) { switch msg.String() {
case tea.WindowSizeMsg: case "esc":
m.width, m.height = msg.Width, msg.Height m.focus = focusMessages
case tea.KeyMsg: if m.selectedMessage < 0 || m.selectedMessage >= len(m.messages) {
cmd := m.handleGlobalInput(msg) m.selectedMessage = len(m.messages) - 1
}
m.updateContent()
m.input.Blur()
case "ctrl+s":
userInput := strings.TrimSpace(m.input.Value())
if strings.TrimSpace(userInput) == "" {
return nil
}
if len(m.messages) > 0 && m.messages[len(m.messages)-1].Role == models.MessageRoleUser {
return wrapError(fmt.Errorf("Can't reply to a user message"))
}
reply := models.Message{
Role: models.MessageRoleUser,
Content: userInput,
}
if m.persistence {
var err error
if m.conversation.ID == 0 {
err = m.ctx.Store.SaveConversation(m.conversation)
}
if err != nil {
return wrapError(err)
}
// ensure all messages up to the one we're about to add are
// persistent
cmd := m.persistConversation()
if cmd != nil { if cmd != nil {
return m, cmd return cmd
} }
case shared.MsgViewChange: // persist our new message, returning with any possible errors
currView := m.activeView savedReply, err := m.ctx.Store.AddReply(m.conversation, reply)
m.activeView = shared.View(msg) if err != nil {
return m, tea.Batch(tea.WindowSize(), shared.ViewEnter(currView)) return wrapError(err)
case shared.MsgError: }
m.errs = append(m.errs, msg.Err) reply = *savedReply
} }
view, cmd := m.views[m.activeView].Update(msg) m.input.SetValue("")
m.views[m.activeView] = view m.addMessage(reply)
return m, cmd
m.updateContent()
m.content.GotoBottom()
return m.promptLLM()
case "ctrl+e":
cmd := openTempfileEditor("message.*.md", m.input.Value(), "# Edit your input below\n")
m.editorTarget = input
return cmd
}
return nil
} }
func (m *Model) View() string { func (m *model) loadConversation(shortname string) tea.Cmd {
if m.width == 0 || m.height == 0 { return func() tea.Msg {
// we're dimensionless! if shortname == "" {
return "" return nil
} }
c, err := m.ctx.Store.ConversationByShortName(shortname)
header := m.views[m.activeView].Header(m.width) if err != nil {
footer := m.views[m.activeView].Footer(m.width) return msgError(fmt.Errorf("Could not lookup conversation: %v", err))
fixedUIHeight := tuiutil.Height(header) + tuiutil.Height(footer)
errBanners := make([]string, len(m.errs))
for idx, err := range m.errs {
errBanners[idx] = tuiutil.ErrorBanner(err, m.width)
fixedUIHeight += tuiutil.Height(errBanners[idx])
} }
if c.ID == 0 {
content := m.views[m.activeView].Content(m.width, m.height-fixedUIHeight) return msgError(fmt.Errorf("Conversation not found: %s", shortname))
sections := make([]string, 0, 4)
if header != "" {
sections = append(sections, header)
} }
if content != "" { return msgConversationLoaded(c)
sections = append(sections, content)
}
if footer != "" {
sections = append(sections, footer)
}
for _, errBanner := range errBanners {
sections = append(sections, errBanner)
}
return lipgloss.JoinVertical(lipgloss.Left, sections...)
}
type LaunchOptions struct {
InitialConversation *conversation.Conversation
InitialView shared.View
}
type LaunchOption func(*LaunchOptions)
func WithInitialConversation(conv *conversation.Conversation) LaunchOption {
return func(opts *LaunchOptions) {
opts.InitialConversation = conv
} }
} }
func WithInitialView(view shared.View) LaunchOption { func (m *model) loadMessages(c *models.Conversation) tea.Cmd {
return func(opts *LaunchOptions) { return func() tea.Msg {
opts.InitialView = view messages, err := m.ctx.Store.Messages(c)
if err != nil {
return msgError(fmt.Errorf("Could not load conversation messages: %v\n", err))
}
return msgMessagesLoaded(messages)
} }
} }
func Launch(ctx *lmcli.Context, options ...LaunchOption) error { func (m *model) waitForReply() tea.Cmd {
opts := &LaunchOptions{ return func() tea.Msg {
InitialView: shared.ViewChat, return msgAssistantReply(<-m.replyChan)
} }
for _, opt := range options { }
opt(opts)
func (m *model) waitForChunk() tea.Cmd {
return func() tea.Msg {
return msgResponseChunk(<-m.replyChunkChan)
}
}
func (m *model) generateConversationTitle() tea.Cmd {
return func() tea.Msg {
title, err := cmdutil.GenerateTitle(m.ctx, m.conversation)
if err != nil {
return msgError(err)
}
return msgConversationTitleChanged(title)
}
}
func (m *model) promptLLM() tea.Cmd {
m.waitingForReply = true
m.status = "Press ctrl+c to cancel"
return func() tea.Msg {
completionProvider, err := m.ctx.GetCompletionProvider(*m.ctx.Config.Defaults.Model)
if err != nil {
return msgError(err)
} }
program := tea.NewProgram(initialModel(ctx, *opts), tea.WithAltScreen()) requestParams := models.RequestParameters{
if _, err := program.Run(); err != nil { Model: *m.ctx.Config.Defaults.Model,
MaxTokens: *m.ctx.Config.Defaults.MaxTokens,
Temperature: *m.ctx.Config.Defaults.Temperature,
ToolBag: m.ctx.EnabledTools,
}
replyHandler := func(msg models.Message) {
m.replyChan <- msg
}
ctx, cancel := context.WithCancel(context.Background())
canceled := false
go func() {
select {
case <-m.stopSignal:
canceled = true
cancel()
}
}()
resp, err := completionProvider.CreateChatCompletionStream(
ctx, requestParams, m.messages, replyHandler, m.replyChunkChan,
)
if err != nil && !canceled {
return msgResponseError(err)
}
return msgResponseEnd(resp)
}
}
func (m *model) persistConversation() tea.Cmd {
existingMessages, err := m.ctx.Store.Messages(m.conversation)
if err != nil {
return wrapError(fmt.Errorf("Could not retrieve existing conversation messages while trying to save: %v", err))
}
existingById := make(map[uint]*models.Message, len(existingMessages))
for _, msg := range existingMessages {
existingById[msg.ID] = &msg
}
currentById := make(map[uint]*models.Message, len(m.messages))
for _, msg := range m.messages {
currentById[msg.ID] = &msg
}
for _, msg := range existingMessages {
_, ok := currentById[msg.ID]
if !ok {
err := m.ctx.Store.DeleteMessage(&msg)
if err != nil {
return wrapError(fmt.Errorf("Failed to remove messages: %v", err))
}
}
}
for i, msg := range m.messages {
if msg.ID > 0 {
exist, ok := existingById[msg.ID]
if ok {
if msg.Content == exist.Content {
continue
}
// update message when contents don't match that of store
err := m.ctx.Store.UpdateMessage(&msg)
if err != nil {
return wrapError(err)
}
} else {
// this would be quite odd... and I'm not sure how to handle
// it at the time of writing this
}
} else {
newMessage, err := m.ctx.Store.AddReply(m.conversation, msg)
if err != nil {
return wrapError(err)
}
m.setMessage(i, *newMessage)
}
}
return nil
}
func (m *model) setMessages(messages []models.Message) {
m.messages = messages
m.highlightCache = make([]string, len(messages))
for i, msg := range m.messages {
highlighted, _ := m.ctx.Chroma.HighlightS(msg.Content)
m.highlightCache[i] = highlighted
}
}
func (m *model) setMessage(i int, msg models.Message) {
if i >= len(m.messages) {
panic("i out of range")
}
highlighted, _ := m.ctx.Chroma.HighlightS(msg.Content)
m.messages[i] = msg
m.highlightCache[i] = highlighted
}
func (m *model) addMessage(msg models.Message) {
highlighted, _ := m.ctx.Chroma.HighlightS(msg.Content)
m.messages = append(m.messages, msg)
m.highlightCache = append(m.highlightCache, highlighted)
}
func (m *model) setMessageContents(i int, content string) {
if i >= len(m.messages) {
panic("i out of range")
}
highlighted, _ := m.ctx.Chroma.HighlightS(content)
m.messages[i].Content = content
m.highlightCache[i] = highlighted
}
func (m *model) updateContent() {
atBottom := m.content.AtBottom()
m.content.SetContent(m.conversationView())
if atBottom {
// if we were at bottom before the update, scroll with the output
m.content.GotoBottom()
}
}
// render the conversation into a string
func (m *model) conversationView() string {
sb := strings.Builder{}
msgCnt := len(m.messages)
m.messageOffsets = make([]int, len(m.messages))
lineCnt := conversationStyle.GetMarginTop()
for i, message := range m.messages {
m.messageOffsets[i] = lineCnt
icon := "⚙️"
friendly := message.Role.FriendlyRole()
style := lipgloss.NewStyle().Bold(true).Faint(true)
switch message.Role {
case models.MessageRoleUser:
icon = ""
style = userStyle
case models.MessageRoleAssistant:
icon = ""
style = assistantStyle
case models.MessageRoleToolCall, models.MessageRoleToolResult:
icon = "🔧"
}
// write message heading with space for content
user := style.Render(icon + friendly)
var prefix string
var suffix string
faint := lipgloss.NewStyle().Faint(true)
if m.focus == focusMessages {
if i == m.selectedMessage {
prefix = "> "
}
suffix += faint.Render(fmt.Sprintf(" (%d/%d)", i+1, msgCnt))
}
if message.ID == 0 {
suffix += faint.Render(" (not saved)")
}
header := lipgloss.NewStyle().PaddingLeft(1).Render(prefix + user + suffix)
sb.WriteString(header)
lineCnt += lipgloss.Height(header)
// TODO: special rendering for tool calls/results?
if message.Content != "" {
sb.WriteString("\n\n")
lineCnt += 1
// write message contents
var highlighted string
if m.highlightCache[i] == "" {
highlighted = message.Content
} else {
highlighted = m.highlightCache[i]
}
var contents string
if m.wrap {
wrapWidth := m.content.Width - messageStyle.GetHorizontalPadding() - 2
wrapped := wordwrap.String(highlighted, wrapWidth)
contents = wrapped
} else {
contents = highlighted
}
sb.WriteString(messageStyle.Width(0).Render(contents))
lineCnt += lipgloss.Height(contents)
}
if i < msgCnt-1 {
sb.WriteString("\n\n")
lineCnt += 1
}
}
return conversationStyle.Render(sb.String())
}
func Launch(ctx *lmcli.Context, convShortname string) error {
p := tea.NewProgram(initialModel(ctx, convShortname), tea.WithAltScreen())
if _, err := p.Run(); err != nil {
return fmt.Errorf("Error running program: %v", err) return fmt.Errorf("Error running program: %v", err)
} }
return nil return nil

42
pkg/tui/util.go Normal file
View File

@ -0,0 +1,42 @@
package tui
import (
"os"
"os/exec"
"strings"
tea "github.com/charmbracelet/bubbletea"
)
type msgTempfileEditorClosed string
// openTempfileEditor opens an $EDITOR on a new temporary file with the given
// content. Upon closing, the contents of the file are read back returned
// wrapped in a msgTempfileEditorClosed returned by the tea.Cmd
func openTempfileEditor(pattern string, content string, placeholder string) tea.Cmd {
msgFile, _ := os.CreateTemp("/tmp", pattern)
err := os.WriteFile(msgFile.Name(), []byte(placeholder+content), os.ModeAppend)
if err != nil {
return wrapError(err)
}
editor := os.Getenv("EDITOR")
if editor == "" {
editor = "vim"
}
c := exec.Command(editor, msgFile.Name())
return tea.ExecProcess(c, func(err error) tea.Msg {
bytes, err := os.ReadFile(msgFile.Name())
if err != nil {
return msgError(err)
}
fileContents := string(bytes)
if strings.HasPrefix(fileContents, placeholder) {
fileContents = fileContents[len(placeholder):]
}
stripped := strings.Trim(fileContents, "\n \t")
return msgTempfileEditorClosed(stripped)
})
}

View File

@ -1,137 +0,0 @@
package util
import (
"fmt"
"os"
"os/exec"
"strings"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/muesli/reflow/ansi"
)
type MsgTempfileEditorClosed string
// OpenTempfileEditor opens $EDITOR on a temporary file with the given content.
// Upon closing, the contents of the file are read and returned wrapped in a
// MsgTempfileEditorClosed
func OpenTempfileEditor(pattern string, content string, placeholder string) tea.Cmd {
msgFile, _ := os.CreateTemp("/tmp", pattern)
err := os.WriteFile(msgFile.Name(), []byte(placeholder+content), os.ModeAppend)
if err != nil {
return func() tea.Msg { return err }
}
editor := os.Getenv("EDITOR")
if editor == "" {
editor = "vim"
}
c := exec.Command(editor, msgFile.Name())
return tea.ExecProcess(c, func(err error) tea.Msg {
bytes, err := os.ReadFile(msgFile.Name())
if err != nil {
return err
}
os.Remove(msgFile.Name())
fileContents := string(bytes)
if strings.HasPrefix(fileContents, placeholder) {
fileContents = fileContents[len(placeholder):]
}
stripped := strings.Trim(fileContents, "\n \t")
return MsgTempfileEditorClosed(stripped)
})
}
// similar to lipgloss.Height, except returns 0 instead of 1 on empty strings
func Height(str string) int {
if str == "" {
return 0
}
return strings.Count(str, "\n") + 1
}
func Width(str string) int {
if str == "" {
return 0
}
return ansi.PrintableRuneWidth(str)
}
func TruncateRightToCellWidth(str string, width int, tail string) string {
cellWidth := ansi.PrintableRuneWidth(str)
if cellWidth <= width {
return str
}
tailWidth := ansi.PrintableRuneWidth(tail)
if width <= tailWidth {
return tail[:width]
}
targetWidth := width - tailWidth
runes := []rune(str)
for i := len(runes) - 1; i >= 0; i-- {
str = string(runes[:i])
if ansi.PrintableRuneWidth(str) <= targetWidth {
return str + tail
}
}
return tail
}
func TruncateLeftToCellWidth(str string, width int, tail string) string {
cellWidth := ansi.PrintableRuneWidth(str)
if cellWidth <= width {
return str
}
tailWidth := ansi.PrintableRuneWidth(tail)
if width <= tailWidth {
return tail[:width]
}
targetWidth := width - tailWidth
runes := []rune(str)
for i := 0; i < len(runes); i++ {
str = string(runes[i:])
if ansi.PrintableRuneWidth(str) <= targetWidth {
return tail + str
}
}
return tail
}
func ScrollIntoView(vp *viewport.Model, offset int, edge int) {
currentOffset := vp.YOffset
if offset >= currentOffset && offset < currentOffset+vp.Height {
return
}
distance := currentOffset - offset
if distance < 0 {
// we should scroll down until it just comes into view
vp.SetYOffset(currentOffset - (distance + (vp.Height - edge)) + 1)
} else {
// we should scroll up
vp.SetYOffset(currentOffset - distance - edge)
}
}
func ErrorBanner(err error, width int) string {
if err == nil {
return ""
}
return lipgloss.NewStyle().
Width(width).
AlignHorizontal(lipgloss.Center).
Bold(true).
Foreground(lipgloss.Color("1")).
Render(fmt.Sprintf("%s", err))
}

View File

@ -1,168 +0,0 @@
package chat
import (
"time"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/provider"
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
"github.com/charmbracelet/bubbles/cursor"
"github.com/charmbracelet/bubbles/spinner"
"github.com/charmbracelet/bubbles/textarea"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// custom tea.Msg types
type (
// sent when a new conversation title generated
msgConversationTitleGenerated string
// sent when the conversation has been persisted, triggers a reload of contents
msgConversationPersisted conversation.Conversation
msgMessagesPersisted []conversation.Message
// sent when a conversation's messages are laoded
msgConversationMessagesLoaded struct {
messages []conversation.Message
}
// a special case of common.MsgError that stops the response waiting animation
msgChatResponseError struct {
Err error
}
// sent on each chunk received from LLM
msgChatResponseChunk provider.Chunk
// sent on each completed reply
msgChatResponse conversation.Message
// sent when the response is canceled
msgChatResponseCanceled struct{}
// sent when results from a tool call are returned
msgToolResults []api.ToolResult
// sent when the given message is made the new selected reply of its parent
msgSelectedReplyCycled *conversation.Message
// sent when the given message is made the new selected root of the current conversation
msgSelectedRootCycled *conversation.Message
// sent when a message's contents are updated and saved
msgMessageUpdated *conversation.Message
// sent when a message is cloned, with the cloned message
msgMessageCloned *conversation.Message
)
type focusState int
const (
focusInput focusState = iota
focusMessages
)
type editorTarget int
const (
input editorTarget = iota
selectedMessage
)
type state int
const (
idle state = iota
loading
pendingResponse
)
type Model struct {
// App state
App *model.AppModel
Height int
Width int
// Chat view state
state state // current overall status of the view
selectedMessage int
editorTarget editorTarget
stopSignal chan struct{}
replyChan chan conversation.Message
chatReplyChunks chan provider.Chunk
persistence bool // whether we will save new messages in the conversation
// UI state
focus focusState
showDetails bool // whether various details are shown in the UI (e.g. system prompt, tool calls/results, message metadata)
wrap bool // whether message content is wrapped to viewport width
messageCache []string // cache of syntax highlighted and wrapped message content
messageOffsets []int
// ui elements
content viewport.Model
input textarea.Model
spinner spinner.Model
replyCursor cursor.Model // cursor to indicate incoming response
// metrics
tokenCount uint
startTime time.Time
elapsed time.Duration
}
func getSpinner() spinner.Model {
return spinner.New(spinner.WithSpinner(
spinner.Spinner{
Frames: []string{
"∙∙∙",
"●∙∙",
"●●∙",
"●●●",
"∙●●",
"∙∙●",
"∙∙∙",
"∙∙●",
"∙●●",
"●●●",
"●●∙",
"●∙∙",
},
FPS: 440 * time.Millisecond,
},
))
}
func Chat(app *model.AppModel) *Model {
m := Model{
App: app,
state: idle,
persistence: true,
stopSignal: make(chan struct{}),
replyChan: make(chan conversation.Message),
chatReplyChunks: make(chan provider.Chunk),
wrap: true,
selectedMessage: -1,
content: viewport.New(0, 0),
input: textarea.New(),
spinner: getSpinner(),
replyCursor: cursor.New(),
}
m.replyCursor.SetChar(" ")
m.replyCursor.Focus()
m.input.Focus()
m.input.MaxHeight = 0
m.input.CharLimit = 0
m.input.ShowLineNumbers = false
m.input.Placeholder = "Enter a message"
m.input.FocusedStyle.CursorLine = lipgloss.NewStyle()
m.input.FocusedStyle.Base = inputFocusedStyle
m.input.BlurredStyle.Base = inputBlurredStyle
return &m
}
func (m *Model) Init() tea.Cmd {
return tea.Batch(
m.waitForResponseChunk(),
)
}

View File

@ -1,136 +0,0 @@
package chat
import (
"time"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
tea "github.com/charmbracelet/bubbletea"
)
func (m *Model) waitForResponseChunk() tea.Cmd {
return func() tea.Msg {
return msgChatResponseChunk(<-m.chatReplyChunks)
}
}
func (m *Model) loadConversationMessages() tea.Cmd {
return func() tea.Msg {
messages, err := m.App.LoadConversationMessages()
if err != nil {
return shared.AsMsgError(err)
}
return msgConversationMessagesLoaded{messages}
}
}
func (m *Model) generateConversationTitle() tea.Cmd {
return func() tea.Msg {
title, err := m.App.GenerateConversationTitle(m.App.Messages)
if err != nil {
return shared.AsMsgError(err)
}
return msgConversationTitleGenerated(title)
}
}
func (m *Model) cloneMessage(message conversation.Message, selected bool) tea.Cmd {
return func() tea.Msg {
msg, err := m.App.CloneMessage(message, selected)
if err != nil {
return shared.WrapError(err)
}
return msgMessageCloned(msg)
}
}
func (m *Model) updateMessageContent(message *conversation.Message) tea.Cmd {
return func() tea.Msg {
err := m.App.UpdateMessageContent(message)
if err != nil {
return shared.WrapError(err)
}
return msgMessageUpdated(message)
}
}
func (m *Model) cycleSelectedRoot(conv *conversation.Conversation, dir model.MessageCycleDirection) tea.Cmd {
if len(conv.RootMessages) < 2 {
return nil
}
return func() tea.Msg {
nextRoot, err := m.App.CycleSelectedRoot(conv, dir)
if err != nil {
return shared.WrapError(err)
}
return msgSelectedRootCycled(nextRoot)
}
}
func (m *Model) cycleSelectedReply(message *conversation.Message, dir model.MessageCycleDirection) tea.Cmd {
if len(message.Replies) < 2 {
return nil
}
return func() tea.Msg {
nextReply, err := m.App.CycleSelectedReply(message, dir)
if err != nil {
return shared.WrapError(err)
}
return msgSelectedReplyCycled(nextReply)
}
}
func (m *Model) persistConversation() tea.Cmd {
return func() tea.Msg {
conversation, err := m.App.PersistConversation()
if err != nil {
return shared.AsMsgError(err)
}
return msgConversationPersisted(conversation)
}
}
func (m *Model) persistMessages() tea.Cmd {
return func() tea.Msg {
messages, err := m.App.PersistMessages()
if err != nil {
return shared.AsMsgError(err)
}
return msgMessagesPersisted(messages)
}
}
func (m *Model) executeToolCalls(toolCalls []api.ToolCall) tea.Cmd {
return func() tea.Msg {
results, err := m.App.ExecuteToolCalls(toolCalls)
if err != nil {
return shared.AsMsgError(err)
}
return msgToolResults(results)
}
}
func (m *Model) promptLLM() tea.Cmd {
m.state = pendingResponse
m.spinner = getSpinner()
m.replyCursor.Blink = false
m.startTime = time.Now()
m.elapsed = 0
m.tokenCount = 0
return tea.Batch(
m.spinner.Tick,
func() tea.Msg {
resp, err := m.App.Prompt(m.App.Messages, m.chatReplyChunks, m.stopSignal)
if err != nil {
return msgChatResponseError{Err: err}
}
return msgChatResponse(*resp)
},
)
}

View File

@ -1,201 +0,0 @@
package chat
import (
"fmt"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
tea "github.com/charmbracelet/bubbletea"
)
func (m *Model) handleInput(msg tea.KeyMsg) tea.Cmd {
switch m.focus {
case focusInput:
cmd := m.handleInputKey(msg)
if cmd != nil {
return cmd
}
case focusMessages:
cmd := m.handleMessagesKey(msg)
if cmd != nil {
return cmd
}
}
switch msg.String() {
case "esc":
if m.state == pendingResponse {
m.stopSignal <- struct{}{}
return shared.KeyHandled(msg)
}
return func() tea.Msg {
return shared.MsgViewChange(shared.ViewConversations)
}
case "ctrl+c":
if m.state == pendingResponse {
m.stopSignal <- struct{}{}
return shared.KeyHandled(msg)
}
case "ctrl+g":
if m.state == pendingResponse {
m.stopSignal <- struct{}{}
return shared.KeyHandled(msg)
}
return func() tea.Msg {
return shared.MsgViewChange(shared.ViewSettings)
}
case "ctrl+p":
m.persistence = !m.persistence
return shared.KeyHandled(msg)
case "ctrl+t":
m.showDetails = !m.showDetails
m.rebuildMessageCache()
m.updateContent()
return shared.KeyHandled(msg)
case "ctrl+w":
m.wrap = !m.wrap
m.rebuildMessageCache()
m.updateContent()
return shared.KeyHandled(msg)
case "ctrl+n":
m.App.NewConversation()
m.rebuildMessageCache()
m.updateContent()
return shared.KeyHandled(msg)
}
return nil
}
func (m *Model) scrollSelection(dir int) {
if m.selectedMessage+dir < 0 || m.selectedMessage+dir >= len(m.App.Messages) {
return
}
newIdx := m.selectedMessage
for i := newIdx + dir; i >= 0 && i < len(m.App.Messages); i += dir {
if !m.showDetails && m.App.Messages[i].Role.IsSystem() {
continue
}
newIdx = i
break
}
if newIdx != m.selectedMessage {
m.selectedMessage = newIdx
m.updateContent()
}
yOffset := m.messageOffsets[m.selectedMessage]
tuiutil.ScrollIntoView(&m.content, yOffset, m.content.Height/2)
}
// handleMessagesKey handles input when the messages pane is focused
func (m *Model) handleMessagesKey(msg tea.KeyMsg) tea.Cmd {
switch msg.String() {
case "tab", "enter":
m.focus = focusInput
m.updateContent()
m.input.Focus()
return shared.KeyHandled(msg)
case "e":
if m.selectedMessage < len(m.App.Messages) {
m.editorTarget = selectedMessage
return tuiutil.OpenTempfileEditor(
"message.*.md",
m.App.Messages[m.selectedMessage].Content,
"# Edit the message below\n",
)
}
return nil
case "ctrl+k", "ctrl+up":
if m.selectedMessage > 0 {
m.scrollSelection(-1)
}
return shared.KeyHandled(msg)
case "ctrl+j", "ctrl+down":
if m.selectedMessage < len(m.App.Messages)-1 {
m.scrollSelection(1)
}
return shared.KeyHandled(msg)
case "ctrl+h", "ctrl+left", "ctrl+l", "ctrl+right":
dir := model.CyclePrev
if msg.String() == "ctrl+l" || msg.String() == "ctrl+right" {
dir = model.CycleNext
}
var cmd tea.Cmd
if m.selectedMessage == 0 {
cmd = m.cycleSelectedRoot(&m.App.Conversation, dir)
} else if m.selectedMessage > 0 {
cmd = m.cycleSelectedReply(&m.App.Messages[m.selectedMessage-1], dir)
}
return cmd
case "ctrl+r":
// prompt the model with all messages up to and including the selected message
if m.state == idle && m.selectedMessage < len(m.App.Messages) {
m.App.Messages = m.App.Messages[:m.selectedMessage+1]
m.messageCache = m.messageCache[:m.selectedMessage+1]
cmd := m.promptLLM()
m.updateContent()
m.content.GotoBottom()
return cmd
}
}
return nil
}
// handleInputKey handles input when the input textarea is focused
func (m *Model) handleInputKey(msg tea.KeyMsg) tea.Cmd {
switch msg.String() {
case "esc":
m.focus = focusMessages
if len(m.App.Messages) > 0 {
if m.selectedMessage < 0 || m.selectedMessage >= len(m.App.Messages) {
m.selectedMessage = len(m.App.Messages) - 1
}
offset := m.messageOffsets[m.selectedMessage]
tuiutil.ScrollIntoView(&m.content, offset, m.content.Height/2)
}
m.updateContent()
m.input.Blur()
return shared.KeyHandled(msg)
case "ctrl+s":
if m.state != idle {
return nil
}
input := strings.TrimSpace(m.input.Value())
if input == "" {
return shared.KeyHandled(msg)
}
if len(m.App.Messages) > 0 && m.App.Messages[len(m.App.Messages)-1].Role.IsUser() {
return shared.WrapError(fmt.Errorf("Can't reply to a user message"))
}
m.addMessage(conversation.Message{
Role: api.MessageRoleUser,
Content: input,
})
m.input.SetValue("")
var cmds []tea.Cmd
if m.persistence {
cmds = append(cmds, m.persistConversation())
}
cmds = append(cmds, m.promptLLM())
m.updateContent()
m.content.GotoBottom()
return tea.Batch(cmds...)
case "ctrl+e":
cmd := tuiutil.OpenTempfileEditor("message.*.md", m.input.Value(), "# Edit your input below\n")
m.editorTarget = input
return cmd
}
return nil
}

View File

@ -1,270 +0,0 @@
package chat
import (
"strings"
"time"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
"github.com/charmbracelet/bubbles/cursor"
tea "github.com/charmbracelet/bubbletea"
)
func (m *Model) setMessage(i int, msg conversation.Message) {
if i >= len(m.App.Messages) {
panic("i out of range")
}
m.App.Messages[i] = msg
m.messageCache[i] = m.renderMessage(i)
}
func (m *Model) addMessage(msg conversation.Message) {
m.App.Messages = append(m.App.Messages, msg)
m.messageCache = append(m.messageCache, m.renderMessage(len(m.App.Messages)-1))
}
func (m *Model) setMessageContents(i int, content string) {
if i >= len(m.App.Messages) {
panic("i out of range")
}
m.App.Messages[i].Content = content
m.messageCache[i] = m.renderMessage(i)
}
func (m *Model) rebuildMessageCache() {
m.messageCache = make([]string, len(m.App.Messages))
for i := range m.App.Messages {
m.messageCache[i] = m.renderMessage(i)
}
}
func (m *Model) updateContent() {
atBottom := m.content.AtBottom()
m.content.SetContent(m.conversationMessagesView())
if atBottom {
m.content.GotoBottom()
}
}
func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
inputHandled := false
var cmds []tea.Cmd
switch msg := msg.(type) {
case tea.KeyMsg:
cmd := m.handleInput(msg)
if cmd != nil {
inputHandled = true
cmds = append(cmds, cmd)
}
case tea.WindowSizeMsg:
m.Width, m.Height = msg.Width, msg.Height
m.content.Width = msg.Width
m.input.SetWidth(msg.Width - m.input.FocusedStyle.Base.GetHorizontalFrameSize())
if len(m.App.Messages) > 0 {
m.rebuildMessageCache()
m.updateContent()
}
case shared.MsgViewEnter:
// wake up spinners and cursors
cmds = append(cmds, cursor.Blink, m.spinner.Tick)
// Refresh view
m.rebuildMessageCache()
m.updateContent()
if m.App.Conversation.ID > 0 {
// (re)load conversation contents
cmds = append(cmds, m.loadConversationMessages())
}
case tuiutil.MsgTempfileEditorClosed:
contents := string(msg)
switch m.editorTarget {
case input:
m.input.SetValue(contents)
case selectedMessage:
toEdit := m.App.Messages[m.selectedMessage]
if toEdit.Content != contents {
toEdit.Content = contents
m.setMessage(m.selectedMessage, toEdit)
if m.persistence && toEdit.ID > 0 {
// create clone of message with its new contents
cmds = append(cmds, m.cloneMessage(toEdit, true))
}
}
}
case msgConversationMessagesLoaded:
m.App.Messages = msg.messages
if m.selectedMessage == -1 {
m.selectedMessage = len(msg.messages) - 1
} else {
m.selectedMessage = min(m.selectedMessage, len(m.App.Messages))
}
m.rebuildMessageCache()
m.updateContent()
case msgChatResponseChunk:
cmds = append(cmds, m.waitForResponseChunk()) // wait for the next chunk
if msg.Content == "" {
break
}
last := len(m.App.Messages) - 1
if last >= 0 && m.App.Messages[last].Role.IsAssistant() {
// append chunk to existing message
m.setMessageContents(last, m.App.Messages[last].Content+msg.Content)
} else {
// use chunk in a new message
m.addMessage(conversation.Message{
Role: api.MessageRoleAssistant,
Content: msg.Content,
})
}
m.updateContent()
// show cursor and reset blink interval (simulate typing)
m.replyCursor.Blink = false
cmds = append(cmds, m.replyCursor.BlinkCmd())
m.tokenCount += msg.TokenCount
m.elapsed = time.Now().Sub(m.startTime)
case msgChatResponse:
m.state = idle
reply := conversation.Message(msg)
reply.Content = strings.TrimSpace(reply.Content)
last := len(m.App.Messages) - 1
if last < 0 {
panic("Unexpected empty messages handling msgAssistantReply")
}
if m.App.Messages[last].Role.IsAssistant() {
// TODO: handle continuations gracefully - only some models support them
m.setMessage(last, reply)
} else {
m.addMessage(reply)
}
if reply.Role == api.MessageRoleToolCall {
// TODO: user confirmation before execution
// m.state = confirmToolUse
cmds = append(cmds, m.executeToolCalls(reply.ToolCalls))
}
if m.persistence {
cmds = append(cmds, m.persistConversation())
}
if m.App.Conversation.Title == "" && len(m.App.Messages) > 0 {
cmds = append(cmds, m.generateConversationTitle())
}
case msgChatResponseCanceled:
m.state = idle
m.updateContent()
case msgChatResponseError:
m.state = idle
m.updateContent()
return m, shared.WrapError(msg.Err)
case msgToolResults:
last := len(m.App.Messages) - 1
if last < 0 {
panic("Unexpected empty messages handling msgAssistantReply")
}
if m.App.Messages[last].Role != api.MessageRoleToolCall {
panic("Previous message not a tool call, unexpected")
}
m.addMessage(conversation.Message{
Role: api.MessageRoleToolResult,
ToolResults: conversation.ToolResults(msg),
})
if m.persistence {
cmds = append(cmds, m.persistConversation())
}
m.updateContent()
case msgConversationTitleGenerated:
title := string(msg)
m.App.Conversation.Title = title
if m.persistence && m.App.Conversation.ID > 0 {
cmds = append(cmds, m.persistConversation())
}
case cursor.BlinkMsg:
if m.state == pendingResponse {
// ensure we show the updated "wait for response" cursor blink state
last := len(m.App.Messages) - 1
m.messageCache[last] = m.renderMessage(last)
m.updateContent()
}
case msgConversationPersisted:
m.App.Conversation = conversation.Conversation(msg)
cmds = append(cmds, m.persistMessages())
case msgMessagesPersisted:
m.App.Messages = msg
m.rebuildMessageCache()
m.updateContent()
case msgMessageCloned:
cmds = append(cmds, m.loadConversationMessages())
case msgSelectedRootCycled, msgSelectedReplyCycled, msgMessageUpdated:
cmds = append(cmds, m.loadConversationMessages())
}
var cmd tea.Cmd
m.spinner, cmd = m.spinner.Update(msg)
if cmd != nil {
cmds = append(cmds, cmd)
}
m.replyCursor, cmd = m.replyCursor.Update(msg)
if cmd != nil {
cmds = append(cmds, cmd)
}
prevInputLineCnt := m.input.LineCount()
if !inputHandled {
m.input, cmd = m.input.Update(msg)
if cmd != nil {
inputHandled = true
cmds = append(cmds, cmd)
}
}
if !inputHandled {
m.content, cmd = m.content.Update(msg)
if cmd != nil {
cmds = append(cmds, cmd)
}
}
// this is a pretty nasty hack to ensure the input area viewport doesn't
// scroll below its content, which can happen when the input viewport
// height has grown, or previously entered lines have been deleted
if prevInputLineCnt != m.input.LineCount() {
// dist is the distance we'd need to scroll up from the current cursor
// position to position the last input line at the bottom of the
// viewport. if negative, we're already scrolled above the bottom
dist := m.input.Line() - (m.input.LineCount() - m.input.Height())
if dist > 0 {
for i := 0; i < dist; i++ {
// move cursor up until content reaches the bottom of the viewport
m.input.CursorUp()
}
m.input, _ = m.input.Update(nil)
for i := 0; i < dist; i++ {
// move cursor back down to its previous position
m.input.CursorDown()
}
m.input, _ = m.input.Update(nil)
}
}
if len(cmds) > 0 {
return m, tea.Batch(cmds...)
}
return m, nil
}

View File

@ -1,365 +0,0 @@
package chat
import (
"encoding/json"
"fmt"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/tui/styles"
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
"github.com/charmbracelet/lipgloss"
"github.com/muesli/reflow/wordwrap"
"github.com/muesli/reflow/wrap"
"gopkg.in/yaml.v3"
)
// styles
var (
boldStyle = lipgloss.NewStyle().Bold(true)
faintStyle = lipgloss.NewStyle().Faint(true)
boldFaintStyle = lipgloss.NewStyle().Faint(true).Bold(true)
messageHeadingStyle = lipgloss.NewStyle().
MarginTop(1).
MarginBottom(1)
userStyle = boldFaintStyle.Foreground(lipgloss.Color("10"))
assistantStyle = boldFaintStyle.Foreground(lipgloss.Color("12"))
systemStyle = boldStyle.Foreground(lipgloss.Color("8"))
messageStyle = lipgloss.NewStyle().
PaddingLeft(2).
PaddingRight(2)
inputFocusedStyle = lipgloss.NewStyle().
Border(lipgloss.RoundedBorder(), true, true, true, false)
inputBlurredStyle = lipgloss.NewStyle().
Faint(true).
Border(lipgloss.RoundedBorder(), true, true, true, false)
footerStyle = lipgloss.NewStyle().Padding(0, 1)
)
func (m *Model) renderMessageHeading(i int, message *conversation.Message) string {
friendly := message.Role.FriendlyRole()
style := systemStyle
switch message.Role {
case api.MessageRoleUser:
style = userStyle
case api.MessageRoleAssistant:
style = assistantStyle
case api.MessageRoleToolCall:
style = assistantStyle
friendly = api.MessageRoleAssistant.FriendlyRole()
case api.MessageRoleSystem:
case api.MessageRoleToolResult:
}
user := style.Render(friendly)
var prefix, suffix string
if i == m.selectedMessage && m.focus == focusMessages {
prefix = "> "
} else {
prefix = " "
}
if i == 0 && m.App.Conversation.SelectedRootID != nil && len(m.App.Conversation.RootMessages) > 1 {
selectedRootIndex := 0
for j, reply := range m.App.Conversation.RootMessages {
if reply.ID == *m.App.Conversation.SelectedRootID {
selectedRootIndex = j
break
}
}
suffix += faintStyle.Render(fmt.Sprintf(" <%d/%d>", selectedRootIndex+1, len(m.App.Conversation.RootMessages)))
}
if i > 0 && len(m.App.Messages[i-1].Replies) > 1 {
// Find the selected reply index
selectedReplyIndex := 0
for j, reply := range m.App.Messages[i-1].Replies {
if reply.ID == *m.App.Messages[i-1].SelectedReplyID {
selectedReplyIndex = j
break
}
}
suffix += faintStyle.Render(fmt.Sprintf(" <%d/%d>", selectedReplyIndex+1, len(m.App.Messages[i-1].Replies)))
}
if message.ID == 0 {
suffix += faintStyle.Render(" (not saved)")
}
heading := prefix + user + suffix
if message.Metadata.GenerationModel != nil && m.showDetails {
heading += faintStyle.Render(
fmt.Sprintf(" | %s", *message.Metadata.GenerationModel),
)
}
return messageHeadingStyle.Render(heading)
}
// renderMessages renders the message at the given index as it should be shown
// *at this moment* - we render differently depending on the current application
// state (window size, etc, etc).
func (m *Model) renderMessage(i int) string {
msg := &m.App.Messages[i]
// Write message contents
sb := &strings.Builder{}
sb.Grow(len(msg.Content) * 2)
if msg.Content != "" {
err := m.App.Ctx.Chroma.Highlight(sb, msg.Content)
if err != nil {
sb.Reset()
sb.WriteString(msg.Content)
}
}
isLast := i == len(m.App.Messages)-1
isAssistant := msg.Role == api.MessageRoleAssistant
if m.state == pendingResponse && isLast && isAssistant {
// Show the assistant's cursor
sb.WriteString(m.replyCursor.View())
}
// Write tool call info
var toolString string
switch msg.Role {
case api.MessageRoleToolCall:
bytes, err := yaml.Marshal(msg.ToolCalls)
if err != nil {
toolString = "Could not serialize ToolCalls"
} else {
toolString = "tool_calls:\n" + string(bytes)
}
case api.MessageRoleToolResult:
type renderedResult struct {
ToolName string `yaml:"tool"`
Result any `yaml:"result,omitempty"`
}
var toolResults []renderedResult
for _, result := range msg.ToolResults {
if m.showDetails {
var jsonResult interface{}
err := json.Unmarshal([]byte(result.Result), &jsonResult)
if err != nil {
// If parsing as JSON fails, treat Result as a plain string
toolResults = append(toolResults, renderedResult{
ToolName: result.ToolName,
Result: result.Result,
})
} else {
// If parsing as JSON succeeds, marshal the parsed JSON into YAML
toolResults = append(toolResults, renderedResult{
ToolName: result.ToolName,
Result: &jsonResult,
})
}
} else {
// Only show the tool name when results are hidden
toolResults = append(toolResults, renderedResult{
ToolName: result.ToolName,
Result: "(hidden, press ctrl+t to view)",
})
}
}
bytes, err := yaml.Marshal(toolResults)
if err != nil {
toolString = "Could not serialize ToolResults"
} else {
toolString = "tool_results:\n" + string(bytes)
}
}
if toolString != "" {
toolString = strings.TrimRight(toolString, "\n")
if msg.Content != "" {
sb.WriteString("\n\n")
}
_ = m.App.Ctx.Chroma.HighlightLang(sb, toolString, "yaml")
}
content := strings.TrimRight(sb.String(), "\n")
if m.wrap {
wrapWidth := m.content.Width - messageStyle.GetHorizontalPadding()
// first we word-wrap text to slightly less than desired width (since
// wordwrap seems to have an off-by-1 issue), then hard wrap at
// desired with
content = wrap.String(wordwrap.String(content, wrapWidth-2), wrapWidth)
}
return messageStyle.Width(0).Render(content)
}
// render the conversation into a string
func (m *Model) conversationMessagesView() string {
m.messageOffsets = make([]int, len(m.App.Messages))
lineCnt := 1
sb := strings.Builder{}
for i, message := range m.App.Messages {
m.messageOffsets[i] = lineCnt
if !m.showDetails && message.Role.IsSystem() {
continue
}
heading := m.renderMessageHeading(i, &message)
sb.WriteString(heading)
sb.WriteString("\n")
lineCnt += lipgloss.Height(heading)
rendered := m.messageCache[i]
sb.WriteString(rendered)
sb.WriteString("\n")
lineCnt += lipgloss.Height(rendered)
}
// Render a placeholder for the incoming assistant reply
if m.state == pendingResponse && m.App.Messages[len(m.App.Messages)-1].Role != api.MessageRoleAssistant {
heading := m.renderMessageHeading(-1, &conversation.Message{
Role: api.MessageRoleAssistant,
Metadata: conversation.MessageMeta{
GenerationModel: &m.App.Model,
},
})
sb.WriteString(heading)
sb.WriteString("\n")
sb.WriteString(messageStyle.Width(0).Render(m.replyCursor.View()))
sb.WriteString("\n")
}
return sb.String()
}
func (m *Model) Content(width, height int) string {
// calculate clamped input height to accomodate input text
// minimum 4 lines, maximum half of content area
inputHeight := max(4, min(height/2, m.input.LineCount()))
m.input.SetHeight(inputHeight)
input := m.input.View()
// remaining height towards content
m.content.Width, m.content.Height = width, height-tuiutil.Height(input)
content := m.content.View()
return lipgloss.JoinVertical(lipgloss.Left, content, input)
}
func (m *Model) Header(width int) string {
titleStyle := lipgloss.NewStyle().Bold(true)
var title string
if m.App.Conversation.Title != "" {
title = m.App.Conversation.Title
} else {
title = "Untitled"
}
title = tuiutil.TruncateRightToCellWidth(title, width-styles.Header.GetHorizontalPadding(), "...")
header := titleStyle.Render(title)
return styles.Header.Width(width).Render(header)
}
func (m *Model) Footer(width int) string {
segmentStyle := lipgloss.NewStyle().Faint(true)
segmentSeparator := segmentStyle.Render(" | ")
// Left segments
leftSegments := make([]string, 0, 4)
if m.state == pendingResponse {
leftSegments = append(leftSegments, segmentStyle.Render(m.spinner.View()))
} else {
leftSegments = append(leftSegments, segmentStyle.Render("∙∙∙"))
}
if m.elapsed > 0 && m.tokenCount > 0 {
throughput := fmt.Sprintf("%.0f t/sec", float64(m.tokenCount)/m.elapsed.Seconds())
leftSegments = append(leftSegments, segmentStyle.Render(throughput))
}
// var status string
// switch m.state {
// case pendingResponse:
// status = "Press ctrl+c to cancel"
// default:
// status = "Press ctrl+s to send"
// }
// leftSegments = append(leftSegments, segmentStyle.Render(status))
// Right segments
rightSegments := make([]string, 0, 8)
if m.App.Agent != nil {
rightSegments = append(rightSegments, segmentStyle.Render(m.App.Agent.Name))
}
model := segmentStyle.Render(m.App.ActiveModel(lipgloss.NewStyle()))
rightSegments = append(rightSegments, model)
savingStyle := segmentStyle.Bold(true)
saving := ""
if m.persistence {
saving = savingStyle.Foreground(lipgloss.Color("2")).Render("💾✅")
} else {
saving = savingStyle.Foreground(lipgloss.Color("1")).Render("💾❌")
}
rightSegments = append(rightSegments, saving)
return m.layoutFooter(width, leftSegments, rightSegments, segmentSeparator)
}
func (m *Model) layoutFooter(
width int,
leftSegments []string,
rightSegments []string,
segmentSeparator string,
) string {
left := strings.Join(leftSegments, segmentSeparator)
right := strings.Join(rightSegments, segmentSeparator)
leftWidth := tuiutil.Width(left)
rightWidth := tuiutil.Width(right)
sepWidth := tuiutil.Width(segmentSeparator)
frameWidth := footerStyle.GetHorizontalFrameSize()
availableWidth := width - frameWidth - leftWidth - rightWidth
if availableWidth >= sepWidth {
// Everything fits
padding := strings.Repeat(" ", availableWidth)
return footerStyle.Render(left + padding + right)
}
// Inserted between left and right segments when they're being truncated
div := "..."
totalAvailableWidth := width - frameWidth
availableTruncWidth := totalAvailableWidth - len(div)
minVisibleLength := 3
if availableTruncWidth < 2*minVisibleLength {
minVisibleLength = availableTruncWidth / 2
}
leftProportion := float64(leftWidth) / float64(leftWidth+rightWidth)
newLeftWidth := int(max(float64(minVisibleLength), leftProportion*float64(availableTruncWidth)))
newRightWidth := totalAvailableWidth - newLeftWidth
truncatedLeft := faintStyle.Render(tuiutil.TruncateRightToCellWidth(left, newLeftWidth, ""))
truncatedRight := faintStyle.Render(tuiutil.TruncateLeftToCellWidth(right, newRightWidth, "..."))
return footerStyle.Width(width).Render(truncatedLeft + truncatedRight)
}

View File

@ -1,326 +0,0 @@
package conversations
import (
"fmt"
"strings"
"time"
"git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/tui/bubbles"
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
"git.mlow.ca/mlow/lmcli/pkg/tui/styles"
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
"git.mlow.ca/mlow/lmcli/pkg/util"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
type (
// sent when conversation list is loaded
msgConversationsLoaded conversation.ConversationList
// sent when a single conversation is loaded
msgConversationLoaded conversation.Conversation
// sent when a conversation is deleted
msgConversationDeleted struct{}
)
type Model struct {
App *model.AppModel
width int
height int
cursor int
itemOffsets []int // conversation y offsets
content viewport.Model
confirmPrompt bubbles.ConfirmPrompt
}
func Conversations(app *model.AppModel) *Model {
viewport.New(0, 0)
m := Model{
App: app,
content: viewport.New(0, 0),
}
return &m
}
func (m *Model) handleInput(msg tea.KeyMsg) tea.Cmd {
if m.confirmPrompt.Focused() {
var cmd tea.Cmd
m.confirmPrompt, cmd = m.confirmPrompt.Update(msg)
if cmd != nil {
return cmd
}
}
conversations := m.App.Conversations.Items
switch msg.String() {
case "enter":
if len(conversations) > 0 && m.cursor < len(conversations) {
return m.loadConversation(conversations[m.cursor].ID)
}
case "j", "down":
if m.cursor < len(conversations)-1 {
m.cursor++
if m.cursor == len(conversations)-1 {
m.content.GotoBottom()
} else {
// this hack positions the *next* conversatoin slightly
// *off* the screen, ensuring the entire m.cursor is shown,
// even if its height may not be constant due to wrapping.
tuiutil.ScrollIntoView(&m.content, m.itemOffsets[m.cursor+1], -1)
}
m.content.SetContent(m.renderConversationList())
} else {
m.cursor = len(conversations) - 1
m.content.GotoBottom()
}
return shared.KeyHandled(msg)
case "k", "up":
if m.cursor > 0 {
m.cursor--
if m.cursor == 0 {
m.content.GotoTop()
} else {
tuiutil.ScrollIntoView(&m.content, m.itemOffsets[m.cursor], 1)
}
m.content.SetContent(m.renderConversationList())
} else {
m.cursor = 0
m.content.GotoTop()
}
return shared.KeyHandled(msg)
case "n":
m.App.NewConversation()
return shared.ChangeView(shared.ViewChat)
case "d":
if !m.confirmPrompt.Focused() && len(conversations) > 0 && m.cursor < len(conversations) {
title := conversations[m.cursor].Title
if title == "" {
title = "(untitled)"
}
m.confirmPrompt = bubbles.NewConfirmPrompt(
fmt.Sprintf("Delete '%s'?", title),
conversations[m.cursor],
)
m.confirmPrompt.Style = lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("3"))
return shared.KeyHandled(msg)
}
case "c":
// copy/clone conversation
case "r":
// show prompt to rename conversation
case "shift+r":
// show prompt to generate name for conversation
}
return nil
}
func (m *Model) Init() tea.Cmd {
return nil
}
func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
isInput := false
inputHandled := false
var cmds []tea.Cmd
switch msg := msg.(type) {
case tea.KeyMsg:
isInput = true
cmd := m.handleInput(msg)
if cmd != nil {
cmds = append(cmds, cmd)
inputHandled = true
}
case shared.MsgViewEnter:
cmds = append(cmds, m.loadConversations())
m.content.SetContent(m.renderConversationList())
case tea.WindowSizeMsg:
m.width, m.height = msg.Width, msg.Height
m.content.SetContent(m.renderConversationList())
case msgConversationsLoaded:
m.App.Conversations = conversation.ConversationList(msg)
m.cursor = max(0, min(len(m.App.Conversations.Items), m.cursor))
m.content.SetContent(m.renderConversationList())
case msgConversationLoaded:
m.App.ClearConversation()
m.App.Conversation = conversation.Conversation(msg)
cmds = append(cmds, func() tea.Msg {
return shared.MsgViewChange(shared.ViewChat)
})
case bubbles.MsgConfirmPromptAnswered:
m.confirmPrompt.Blur()
if msg.Value {
conv, ok := msg.Payload.(conversation.ConversationListItem)
if ok {
cmds = append(cmds, m.deleteConversation(conv))
}
}
case msgConversationDeleted:
cmds = append(cmds, m.loadConversations())
}
if !isInput || !inputHandled {
content, cmd := m.content.Update(msg)
m.content = content
if cmd != nil {
cmds = append(cmds, cmd)
}
}
if len(cmds) > 0 {
return m, tea.Batch(cmds...)
}
return m, nil
}
func (m *Model) loadConversations() tea.Cmd {
return func() tea.Msg {
list, err := m.App.Ctx.Conversations.LoadConversationList()
if err != nil {
return shared.AsMsgError(fmt.Errorf("Could not load conversations: %v", err))
}
return msgConversationsLoaded(list)
}
}
func (m *Model) loadConversation(conversationID uint) tea.Cmd {
return func() tea.Msg {
conversation, err := m.App.Ctx.Conversations.GetConversationByID(conversationID)
if err != nil {
return shared.AsMsgError(fmt.Errorf("Could not load conversation %d: %v", conversationID, err))
}
return msgConversationLoaded(*conversation)
}
}
func (m *Model) deleteConversation(conv conversation.ConversationListItem) tea.Cmd {
return func() tea.Msg {
err := m.App.Ctx.Conversations.DeleteConversationById(conv.ID)
if err != nil {
return shared.AsMsgError(fmt.Errorf("Could not delete conversation: %v", err))
}
return msgConversationDeleted{}
}
}
func (m *Model) Header(width int) string {
titleStyle := lipgloss.NewStyle().Bold(true)
header := titleStyle.Render("Conversations")
return styles.Header.Width(width).Render(header)
}
func (m *Model) Content(width int, height int) string {
m.content.Width, m.content.Height = width, height
return m.content.View()
}
func (m *Model) Footer(width int) string {
if m.confirmPrompt.Focused() {
return lipgloss.NewStyle().Width(width).Render(m.confirmPrompt.View())
}
return ""
}
func (m *Model) renderConversationList() string {
now := time.Now()
midnight := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
monthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
dayOfWeek := int(now.Weekday())
categories := []struct {
name string
cutoff time.Duration
}{
{"Today", now.Sub(midnight)},
{"Yesterday", now.Sub(midnight.AddDate(0, 0, -1))},
{"This week", now.Sub(midnight.AddDate(0, 0, -dayOfWeek))},
{"Last week", now.Sub(midnight.AddDate(0, 0, -(dayOfWeek + 7)))},
{"This month", now.Sub(monthStart)},
{"Last month", now.Sub(monthStart.AddDate(0, -1, 0))},
{"2 Months ago", now.Sub(monthStart.AddDate(0, -2, 0))},
{"3 Months ago", now.Sub(monthStart.AddDate(0, -3, 0))},
{"4 Months ago", now.Sub(monthStart.AddDate(0, -4, 0))},
{"5 Months ago", now.Sub(monthStart.AddDate(0, -5, 0))},
{"6 Months ago", now.Sub(monthStart.AddDate(0, -6, 0))},
{"Older", now.Sub(time.Time{})},
}
categoryStyle := lipgloss.NewStyle().
MarginBottom(1).
Foreground(lipgloss.Color("12")).
PaddingLeft(1).
Bold(true)
itemStyle := lipgloss.NewStyle().
MarginBottom(1)
ageStyle := lipgloss.NewStyle().Faint(true).SetString()
titleStyle := lipgloss.NewStyle().Bold(true)
untitledStyle := lipgloss.NewStyle().Faint(true).Italic(true)
selectedStyle := lipgloss.NewStyle().Faint(true).Foreground(lipgloss.Color("6"))
var (
currentOffset int
currentCategory string
sb strings.Builder
)
m.itemOffsets = make([]int, len(m.App.Conversations.Items))
sb.WriteRune('\n')
currentOffset += 1
for i, c := range m.App.Conversations.Items {
lastReplyAge := now.Sub(c.LastMessageAt)
var category string
for _, g := range categories {
if lastReplyAge < g.cutoff {
category = g.name
break
}
}
// print the category
if category != currentCategory {
currentCategory = category
heading := categoryStyle.Render(currentCategory)
sb.WriteString(heading)
currentOffset += tuiutil.Height(heading)
sb.WriteRune('\n')
}
tStyle := titleStyle
if c.Title == "" {
tStyle = tStyle.Inherit(untitledStyle).SetString("(untitled)")
}
if i == m.cursor {
tStyle = tStyle.Inherit(selectedStyle)
}
title := tStyle.Width(m.width - 3).PaddingLeft(2).Render(c.Title)
if i == m.cursor {
title = ">" + title[1:]
}
m.itemOffsets[i] = currentOffset
item := itemStyle.Render(fmt.Sprintf(
"%s\n %s",
title,
ageStyle.Render(util.HumanTimeElapsedSince(lastReplyAge)),
))
sb.WriteString(item)
currentOffset += tuiutil.Height(item)
if i < len(m.App.Conversations.Items)-1 {
sb.WriteRune('\n')
}
}
return sb.String()
}

View File

@ -1,137 +0,0 @@
package settings
import (
"strings"
"git.mlow.ca/mlow/lmcli/pkg/tui/bubbles/list"
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
"git.mlow.ca/mlow/lmcli/pkg/tui/styles"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
type Model struct {
App *model.AppModel
prevView shared.View
content viewport.Model
modelList list.Model
width int
height int
}
type modelOpt struct {
provider string
model string
}
const (
modelListId int = iota + 1
)
func Settings(app *model.AppModel) *Model {
m := &Model{
App: app,
content: viewport.New(0, 0),
}
return m
}
func (m *Model) Init() tea.Cmd {
m.modelList = list.NewWithGroups(m.getModelOptions())
m.modelList.ID = modelListId
return nil
}
func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
var cmd tea.Cmd
switch msg := msg.(type) {
case tea.KeyMsg:
m.modelList, cmd = m.modelList.Update(msg)
if cmd != nil {
return m, cmd
}
switch msg.String() {
case "esc":
return m, func() tea.Msg {
return shared.MsgViewChange(m.prevView)
}
}
case shared.MsgViewEnter:
m.prevView = shared.View(msg)
m.modelList.Focus()
m.content.SetContent(m.renderContent())
case tea.WindowSizeMsg:
m.width, m.height = msg.Width, msg.Height
m.content.Width = msg.Width
m.content.Height = msg.Height
m.content.SetContent(m.renderContent())
case list.MsgOptionSelected:
switch msg.ID {
case modelListId:
if modelOpt, ok := msg.Option.Value.(modelOpt); ok {
m.App.Model = modelOpt.model
m.App.ProviderName = modelOpt.provider
}
return m, shared.ChangeView(m.prevView)
}
}
m.modelList, cmd = m.modelList.Update(msg)
if cmd != nil {
return m, cmd
}
m.content.SetContent(m.renderContent())
return m, nil
}
func (m *Model) getModelOptions() []list.OptionGroup {
modelOpts := []list.OptionGroup{}
for _, p := range m.App.Ctx.Config.Providers {
provider := p.Name
if provider == "" {
provider = p.Kind
}
providerLabel := p.Display
if providerLabel == "" {
providerLabel = strings.ToUpper(provider[:1]) + provider[1:]
}
group := list.OptionGroup{
Name: providerLabel,
}
for _, model := range p.Models {
group.Options = append(group.Options, list.Option{
Label: model,
Value: modelOpt{provider, model},
})
}
modelOpts = append(modelOpts, group)
}
return modelOpts
}
func (m *Model) Header(width int) string {
boldStyle := lipgloss.NewStyle().Bold(true)
// TODO: update header depending on active settings mode (model, agent, etc)
header := boldStyle.Render("Model selection")
return styles.Header.Width(width).Render(header)
}
func (m *Model) Content(width, height int) string {
// TODO: see Header()
currentModel := " Active model: " + m.App.ActiveModel(lipgloss.NewStyle())
m.modelList.Width, m.modelList.Height = width, height - 2
return "\n" + currentModel + "\n" + m.modelList.View()
}
func (m *Model) Footer(width int) string {
return ""
}
func (m *Model) renderContent() string {
return m.modelList.View()
}

View File

@ -58,29 +58,3 @@ func (s *ChromaHighlighter) HighlightS(text string) (string, error) {
s.formatter.Format(&sb, s.style, it) s.formatter.Format(&sb, s.style, it)
return sb.String(), nil return sb.String(), nil
} }
func (s *ChromaHighlighter) HighlightLang(w io.Writer, text string, lang string) (error) {
l := lexers.Get(lang)
if l == nil {
l = lexers.Fallback
}
l = chroma.Coalesce(l)
old := s.lexer
s.lexer = l
err := s.Highlight(w, text)
s.lexer = old
return err
}
func (s *ChromaHighlighter) HighlightLangS(text string, lang string) (string, error) {
l := lexers.Get(lang)
if l == nil {
l = lexers.Fallback
}
l = chroma.Coalesce(l)
old := s.lexer
s.lexer = l
highlighted, err := s.HighlightS(text)
s.lexer = old
return highlighted, err
}

View File

@ -21,7 +21,7 @@ func InputFromEditor(placeholder string, pattern string, content string) (string
msgFile, _ := os.CreateTemp("/tmp", pattern) msgFile, _ := os.CreateTemp("/tmp", pattern)
defer os.Remove(msgFile.Name()) defer os.Remove(msgFile.Name())
os.WriteFile(msgFile.Name(), []byte(placeholder+content), os.ModeAppend) os.WriteFile(msgFile.Name(), []byte(placeholder + content), os.ModeAppend)
editor := os.Getenv("EDITOR") editor := os.Getenv("EDITOR")
if editor == "" { if editor == "" {
@ -137,8 +137,8 @@ func SetStructDefaults(data interface{}) bool {
} }
// Get the "default" struct tag // Get the "default" struct tag
defaultTag, ok := v.Type().Field(i).Tag.Lookup("default") defaultTag := v.Type().Field(i).Tag.Get("default")
if !ok { if defaultTag == "" {
continue continue
} }
@ -147,16 +147,16 @@ func SetStructDefaults(data interface{}) bool {
case reflect.String: case reflect.String:
defaultValue := defaultTag defaultValue := defaultTag
field.Set(reflect.ValueOf(&defaultValue)) field.Set(reflect.ValueOf(&defaultValue))
case reflect.Uint, reflect.Uint32, reflect.Uint64:
intValue, _ := strconv.ParseUint(defaultTag, 10, e.Bits())
field.Set(reflect.New(e))
field.Elem().SetUint(intValue)
case reflect.Int, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int32, reflect.Int64:
intValue, _ := strconv.ParseInt(defaultTag, 10, e.Bits()) intValue, _ := strconv.ParseInt(defaultTag, 10, 64)
field.Set(reflect.New(e)) field.Set(reflect.New(e))
field.Elem().SetInt(intValue) field.Elem().SetInt(intValue)
case reflect.Float32, reflect.Float64: case reflect.Float32:
floatValue, _ := strconv.ParseFloat(defaultTag, e.Bits()) floatValue, _ := strconv.ParseFloat(defaultTag, 32)
field.Set(reflect.New(e))
field.Elem().SetFloat(floatValue)
case reflect.Float64:
floatValue, _ := strconv.ParseFloat(defaultTag, 64)
field.Set(reflect.New(e)) field.Set(reflect.New(e))
field.Elem().SetFloat(floatValue) field.Elem().SetFloat(floatValue)
case reflect.Bool: case reflect.Bool: