Compare commits

..

No commits in common. "main" and "tools" have entirely different histories.
main ... tools

59 changed files with 1759 additions and 6757 deletions

181
README.md
View File

@ -1,174 +1,37 @@
# lmcli - Large ____ Model CLI # lmcli
`lmcli` is a versatile command-line interface for interacting with LLMs and LMMs. `lmcli` is a (Large) Language Model CLI.
## Features Current features:
- Perform one-shot prompts with `lmcli prompt <message>`
- Manage persistent conversations with the `new`, `reply`, `view`, and `rm`
sub-commands.
- Syntax highlighted output
- Multiple model backends (Ollama, OpenAI, Anthropic, Google) Planned features:
- Customizable agents with tool calling - Ask questions about content received on stdin
- Persistent conversation management - "functions" to allow reading (and possibly writing) to files within the
- Message branching (edit and re-prompt to your heart's desire) current working directory
- Interactive terminal interface for seamless chat experiences
- Utilizes `$EDITOR`, write and edit prompts from the comfort of your own editor :)
- `vi`-like bindings
- Syntax highlighting!
## Screenshots Maybe features:
- Natural language image generation, iterative editing
[TODO: Add screenshots of the TUI in action, showing different views and features] ## Install
## Installation ```shell
$ go install git.mlow.ca/mlow/lmcli@latest
To install `lmcli`, make sure you have Go installed on your system, then run:
```sh
go install git.mlow.ca/mlow/lmcli@latest
``` ```
## Configuration
`lmcli` uses a YAML configuration file located at `~/.config/lmcli/config.yaml`. Here's a sample configuration:
```yaml
defaults:
model: claude-3-5-sonnet-20240620
maxTokens: 3072
temperature: 0.2
conversations:
titleGenerationModel: claude-3-haiku-20240307
chroma:
style: onedark
formatter: terminal16m
agents:
- name: coder
tools:
- dir_tree
- read_file
#- write_file
systemPrompt: |
You are an experienced software engineer...
# ...
providers:
- kind: ollama
models:
- phi3:instruct
- llama3:8b
- kind: anthropic
apiKey: your-api-key-here
models:
- claude-3-5-sonnet-20240620
- claude-3-opus-20240229
- claude-3-haiku-20240307
- kind: openai
apiKey: your-api-key-here
models:
- gpt-4o
- gpt-4-turbo
- name: openrouter
kind: openai
apiKey: your-api-key-here
baseUrl: https://openrouter.ai/api/
models:
- qwen/qwen-2-72b-instruct
# ...
```
Customize this file to add your own providers, agents, and models.
### Syntax highlighting
Syntax highlighting is performed by [Chroma](https://github.com/alecthomas/chroma).
Refer to [`Chroma/styles`](https://github.com/alecthomas/chroma/tree/master/styles) for available styles (TODO: add support for custom Chroma styles).
Available formatters:
- `terminal` - 8 colors
- `terminal16` - 16 colors
- `terminal256` - 256 colors
- `terminal16m` - true color (default)
## Agents
Agents in `lmcli` combine a system prompt with a set of available tools. Agents are defined in `config.yaml` and are called upon with the `-a`/`--agent` flag.
Agent functionality is expected to be expanded on, bringing them to close parity with something like OpenAI's "Assistants" feature.
## Tools
Tools are used by agents to acquire information from and interact with external systems. The following built-in tools are available:
- `dir_tree`: Display a directory structure
- `read_file`: Read the contents of a file
- `write_file`: Write content to a file
- `file_insert_lines`: Insert lines at a specific position in a file
- `file_replace_lines`: Replace a range of lines in a file
Obviously, some of these tools carry significant risk. Use wisely :)
More tool features are planned, including the ability to define arbitrary tools which call out to external scripts, tools to spawn sub-agents, perform web searches, etc.
## Usage ## Usage
```console Invoke `lmcli` at least once:
```shell
$ lmcli help $ lmcli help
lmcli - Large Language Model CLI
Usage:
lmcli <command> [flags]
lmcli [command]
Available Commands:
chat Open the chat interface
clone Clone conversations
completion Generate the autocompletion script for the specified shell
continue Continue a conversation from the last message
edit Edit the last user reply in a conversation
help Help about any command
list List conversations
new Start a new conversation
prompt Do a one-shot prompt
rename Rename a conversation
reply Reply to a conversation
retry Retry the last user reply in a conversation
rm Remove conversations
view View messages in a conversation
Flags:
-h, --help help for lmcli
Use "lmcli [command] --help" for more information about a command.
``` ```
### Examples Edit `~/.config/lmcli/config.yaml` and set `openai.apiKey` to your API key.
Start a new chat with the `coder` agent: Refer back to the output of `lmcli help` for usage.
```console Enjoy!
$ lmcli chat --agent coder
```
Start a new conversation, imperative style (no tui):
```console
$ lmcli new "Help me plan meals for the next week"
```
Send a one-shot prompt (no persistence):
```console
$ lmcli prompt "What is the answer to life, the universe, and everything?"
```
## Contributing
Contributions to `lmcli` are welcome! Feel free to open issues or submit pull requests on the project repository.
For a full list of planned features and improvements, check out the [TODO.md](TODO.md) file.
## License
To be determined
## Acknowledgements
`lmcli` is a small hobby project. Special thanks to the Go community and the creators of the libraries used in this project.

34
TODO.md
View File

@ -1,34 +0,0 @@
# TODO
- [x] Strip anthropic XML function call scheme from content, to reconstruct
when calling anthropic?
- [x] `dir_tree` tool
- [x] Implement native Anthropic API tool calling
- [x] Agents - a name given to a system prompt + set of available tools +
potentially other relevent data (e.g. external service credentials, files for
RAG, etc), which the user explicitly selects (e.g. `lmcli chat --agent
code-helper`, `lmcli chat -a financier`).
- [ ] Specialized agents which have integrations beyond basic tool calling,
e.g. a coding agent which bakes in efficient code context management
(only the current state of relevant files get shown to the model in the
system prompt, rather than having them in the conversation messages)
- [ ] Agents may have some form of long term memory management (key-value?
natural lang?).
- [ ] Sandboxed python, js interpreter (both useful for different reasons)
- [ ] Support for arbitrary external script tools
- [ ] Search - RAG driven search of existing conversation "hey, remind me of
the conversation we had six months ago about X")
- [ ] Conversation categorization - model driven category creation and
conversation classification
- [ ] Image input
- [ ] Image output (sixel support?)
- [ ] Conversation exports to html/pdf/json
## UI
- [x] Prettify/normalize tool_call and tool_result outputs so they can be
shown/optionally hidden in `lmcli view` and `lmcli chat`
- [ ] User confirmation before calling (some?) tools
- [ ] Conversation deletion in conversations view
- [ ] Message deletion, Ctrl+D to delete a message and attach its children to
its parent, Ctrl+Shift+D to delete a message and its descendents
- [ ] Show available key bindings and their action in any given view

24
go.mod
View File

@ -4,39 +4,25 @@ go 1.21
require ( require (
github.com/alecthomas/chroma/v2 v2.11.1 github.com/alecthomas/chroma/v2 v2.11.1
github.com/charmbracelet/bubbles v0.18.0 github.com/go-yaml/yaml v2.1.0+incompatible
github.com/charmbracelet/bubbletea v0.25.0 github.com/gookit/color v1.5.4
github.com/charmbracelet/lipgloss v0.10.0 github.com/sashabaranov/go-openai v1.17.7
github.com/muesli/reflow v0.3.0
github.com/spf13/cobra v1.8.0 github.com/spf13/cobra v1.8.0
github.com/sqids/sqids-go v0.4.1 github.com/sqids/sqids-go v0.4.1
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/sqlite v1.5.4 gorm.io/driver/sqlite v1.5.4
gorm.io/gorm v1.25.5 gorm.io/gorm v1.25.5
) )
require ( require (
github.com/atotto/clipboard v0.1.4 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81 // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect github.com/jinzhu/now v1.1.5 // indirect
github.com/kr/pretty v0.3.1 // indirect github.com/kr/pretty v0.3.1 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mattn/go-isatty v0.0.18 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.15 // indirect
github.com/mattn/go-sqlite3 v1.14.18 // indirect github.com/mattn/go-sqlite3 v1.14.18 // indirect
github.com/muesli/ansi v0.0.0-20211018074035-2e021307bc4b // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.15.2 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/pflag v1.0.5 // indirect
golang.org/x/sync v0.1.0 // indirect github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 // indirect
golang.org/x/sys v0.14.0 // indirect golang.org/x/sys v0.14.0 // indirect
golang.org/x/term v0.6.0 // indirect
golang.org/x/text v0.3.8 // indirect
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
gopkg.in/yaml.v2 v2.2.2 // indirect
) )

57
go.sum
View File

@ -4,22 +4,16 @@ github.com/alecthomas/chroma/v2 v2.11.1 h1:m9uUtgcdAwgfFNxuqj7AIG75jD2YmL61BBIJW
github.com/alecthomas/chroma/v2 v2.11.1/go.mod h1:4TQu7gdfuPjSh76j78ietmqh9LiurGF0EpseFXdKMBw= github.com/alecthomas/chroma/v2 v2.11.1/go.mod h1:4TQu7gdfuPjSh76j78ietmqh9LiurGF0EpseFXdKMBw=
github.com/alecthomas/repr v0.2.0 h1:HAzS41CIzNW5syS8Mf9UwXhNH1J9aix/BvDRf1Ml2Yk= github.com/alecthomas/repr v0.2.0 h1:HAzS41CIzNW5syS8Mf9UwXhNH1J9aix/BvDRf1Ml2Yk=
github.com/alecthomas/repr v0.2.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= github.com/alecthomas/repr v0.2.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/charmbracelet/bubbles v0.18.0 h1:PYv1A036luoBGroX6VWjQIE9Syf2Wby2oOl/39KLfy0=
github.com/charmbracelet/bubbles v0.18.0/go.mod h1:08qhZhtIwzgrtBjAcJnij1t1H0ZRjwHyGsy6AL11PSw=
github.com/charmbracelet/bubbletea v0.25.0 h1:bAfwk7jRz7FKFl9RzlIULPkStffg5k6pNt5dywy4TcM=
github.com/charmbracelet/bubbletea v0.25.0/go.mod h1:EN3QDR1T5ZdWmdfDzYcqOCAps45+QIJbLOBxmVNWNNg=
github.com/charmbracelet/lipgloss v0.10.0 h1:KWeXFSexGcfahHX+54URiZGkBFazf70JNMtwg/AFW3s=
github.com/charmbracelet/lipgloss v0.10.0/go.mod h1:Wig9DSfvANsxqkRsqj6x87irdy123SR4dOXlKa91ciE=
github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81 h1:q2hJAaP1k2wIvVRd/hEHD7lacgqrCPS+k8g1MndzfWY=
github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81/go.mod h1:YynlIjWYF8myEu6sdkwKIvGQq+cOckRm6So2avqoYAk=
github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/go-yaml/yaml v2.1.0+incompatible h1:RYi2hDdss1u4YE7GwixGzWwVo47T8UQwnTLB6vQiq+o=
github.com/go-yaml/yaml v2.1.0+incompatible/go.mod h1:w2MrLa16VYP0jy6N7M5kHaCkaLENm+P+Tv+MfurjSw0=
github.com/gookit/color v1.5.4 h1:FZmqs7XOyGgCAxmWyPslpiok1k05wmY3SJTytgvYFs0=
github.com/gookit/color v1.5.4/go.mod h1:pZJOeOS8DM43rXbp4AZo1n9zCU2qjpcRko0b6/QJi9w=
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
@ -32,52 +26,33 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-isatty v0.0.18 h1:DOKFKCQ7FNG2L1rbrmstDN4QVRdS89Nkh85u68Uwp98=
github.com/mattn/go-isatty v0.0.18/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.18 h1:JL0eqdCOq6DJVNPSvArO/bIV9/P7fbGrV00LZHc+5aI= github.com/mattn/go-sqlite3 v1.14.18 h1:JL0eqdCOq6DJVNPSvArO/bIV9/P7fbGrV00LZHc+5aI=
github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/muesli/ansi v0.0.0-20211018074035-2e021307bc4b h1:1XF24mVaiu7u+CFywTdcDo2ie1pzzhwjt6RHqzpMU34=
github.com/muesli/ansi v0.0.0-20211018074035-2e021307bc4b/go.mod h1:fQuZ0gauxyBcmsdE3ZT4NasjaRdxmbCS0jRHsrWu3Ho=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8=
github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo=
github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sashabaranov/go-openai v1.17.7 h1:MPcAwlwbeo7ZmhQczoOgZBHtIBY1TfZqsdx6+/ndloM=
github.com/sashabaranov/go-openai v1.17.7/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0=
github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/sqids/sqids-go v0.4.1 h1:eQKYzmAZbLlRwHeHYPF35QhgxwZHLnlmVj9AkIj/rrw= github.com/sqids/sqids-go v0.4.1 h1:eQKYzmAZbLlRwHeHYPF35QhgxwZHLnlmVj9AkIj/rrw=
github.com/sqids/sqids-go v0.4.1/go.mod h1:EMwHuPQgSNFS0A49jESTfIQS+066XQTVhukrzEPScl8= github.com/sqids/sqids-go v0.4.1/go.mod h1:EMwHuPQgSNFS0A49jESTfIQS+066XQTVhukrzEPScl8=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 h1:QldyIu/L63oPpyvQmHgvgickp1Yw510KJOqX7H24mg8=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778/go.mod h1:2MuV+tbUrU1zIOPMxZ5EncGwgmMJsa+9ucAQZXxsObs=
golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q=
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw=
golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY=
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/sqlite v1.5.4 h1:IqXwXi8M/ZlPzH/947tn5uik3aYQslP9BVveoax0nV0= gorm.io/driver/sqlite v1.5.4 h1:IqXwXi8M/ZlPzH/947tn5uik3aYQslP9BVveoax0nV0=

17
main.go
View File

@ -1,18 +1,15 @@
package main package main
import ( import (
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "fmt"
"git.mlow.ca/mlow/lmcli/pkg/cmd" "os"
"git.mlow.ca/mlow/lmcli/pkg/cli"
) )
func main() { func main() {
ctx, err := lmcli.NewContext() if err := cli.Execute(); err != nil {
if err != nil { fmt.Fprint(os.Stderr, err)
lmcli.Fatal("%v\n", err) os.Exit(1)
}
root := cmd.RootCmd(ctx)
if err := root.Execute(); err != nil {
lmcli.Fatal("%v\n", err)
} }
} }

View File

@ -1,142 +0,0 @@
package toolbox
import (
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
toolutil "git.mlow.ca/mlow/lmcli/pkg/agents/toolbox/util"
"git.mlow.ca/mlow/lmcli/pkg/api"
)
const TREE_DESCRIPTION = `Retrieve a tree-like view of a directory's contents.
Use these results for your own reference in completing your task, they do not need to be shown to the user.
Example result:
{
"message": "success",
"result": ".
a_directory/
file1.txt (100 bytes)
file2.txt (200 bytes)
a_file.txt (123 bytes)
another_file.txt (456 bytes)"
}
`
var DirTreeTool = api.ToolSpec{
Name: "dir_tree",
Description: TREE_DESCRIPTION,
Parameters: []api.ToolParameter{
{
Name: "relative_path",
Type: "string",
Description: "If set, display the tree starting from this path relative to the current one.",
},
{
Name: "depth",
Type: "integer",
Description: "Depth of directory recursion. Defaults to 0 (no recursion), maximum of 5.",
},
},
Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) {
var relativeDir string
if tmp, ok := args["relative_path"]; ok {
relativeDir, ok = tmp.(string)
if !ok {
return "", fmt.Errorf("expected string for relative_path, got %T", tmp)
}
}
var depth int = 0 // Default value if not provided
if tmp, ok := args["depth"]; ok {
switch v := tmp.(type) {
case float64:
depth = int(v)
case string:
var err error
if depth, err = strconv.Atoi(v); err != nil {
return "", fmt.Errorf("invalid `depth` value, expected integer but got string that cannot convert: %v", tmp)
}
depth = max(0, min(5, depth))
default:
return "", fmt.Errorf("expected int or string for max_depth, got %T", tmp)
}
}
result := tree(relativeDir, depth)
ret, err := result.ToJson()
if err != nil {
return "", fmt.Errorf("could not serialize result: %v", err)
}
return ret, nil
},
}
func tree(path string, depth int) api.CallResult {
if path == "" {
path = "."
}
ok, reason := toolutil.IsPathWithinCWD(path)
if !ok {
return api.CallResult{Message: reason}
}
var treeOutput strings.Builder
treeOutput.WriteString(path + "\n")
err := buildTree(&treeOutput, path, "", depth)
if err != nil {
return api.CallResult{
Message: err.Error(),
}
}
return api.CallResult{Result: treeOutput.String()}
}
func buildTree(output *strings.Builder, path string, prefix string, depth int) error {
files, err := os.ReadDir(path)
if err != nil {
return err
}
for i, file := range files {
if strings.HasPrefix(file.Name(), ".") {
// Skip hidden files and directories
continue
}
isLast := i == len(files)-1
var branch string
if isLast {
branch = "└── "
} else {
branch = "├── "
}
info, _ := file.Info()
size := info.Size()
sizeStr := fmt.Sprintf(" (%d bytes)", size)
output.WriteString(prefix + branch + file.Name())
if file.IsDir() {
output.WriteString("/\n")
if depth > 0 {
var nextPrefix string
if isLast {
nextPrefix = prefix + " "
} else {
nextPrefix = prefix + "│ "
}
buildTree(output, filepath.Join(path, file.Name()), nextPrefix, depth-1)
}
} else {
output.WriteString(sizeStr + "\n")
}
}
return nil
}

View File

@ -1,114 +0,0 @@
package toolbox
import (
"fmt"
"os"
"strings"
toolutil "git.mlow.ca/mlow/lmcli/pkg/agents/toolbox/util"
"git.mlow.ca/mlow/lmcli/pkg/api"
)
const FILE_INSERT_LINES_DESCRIPTION = `Insert lines into a file, must specify path.
Make sure your inserts match the flow and indentation of surrounding content.`
var FileInsertLinesTool = api.ToolSpec{
Name: "file_insert_lines",
Description: FILE_INSERT_LINES_DESCRIPTION,
Parameters: []api.ToolParameter{
{
Name: "path",
Type: "string",
Description: "Path of the file to be modified, relative to the current working directory.",
Required: true,
},
{
Name: "position",
Type: "integer",
Description: `Which line to insert content *before*.`,
Required: true,
},
{
Name: "content",
Type: "string",
Description: `The content to insert.`,
Required: true,
},
},
Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) {
tmp, ok := args["path"]
if !ok {
return "", fmt.Errorf("path parameter to write_file was not included.")
}
path, ok := tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
}
var position int
tmp, ok = args["position"]
if ok {
tmp, ok := tmp.(float64)
if !ok {
return "", fmt.Errorf("Invalid position in function arguments: %v", tmp)
}
position = int(tmp)
}
var content string
tmp, ok = args["content"]
if ok {
content, ok = tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
}
}
result := fileInsertLines(path, position, content)
ret, err := result.ToJson()
if err != nil {
return "", fmt.Errorf("Could not serialize result: %v", err)
}
return ret, nil
},
}
func fileInsertLines(path string, position int, content string) api.CallResult {
ok, reason := toolutil.IsPathWithinCWD(path)
if !ok {
return api.CallResult{Message: reason}
}
// Read the existing file's content
data, err := os.ReadFile(path)
if err != nil {
if !os.IsNotExist(err) {
return api.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
}
_, err = os.Create(path)
if err != nil {
return api.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}
}
data = []byte{}
}
if position < 1 {
return api.CallResult{Message: "start_line cannot be less than 1"}
}
lines := strings.Split(string(data), "\n")
contentLines := strings.Split(strings.Trim(content, "\n"), "\n")
before := lines[:position-1]
after := lines[position-1:]
lines = append(before, append(contentLines, after...)...)
newContent := strings.Join(lines, "\n")
// Join the lines and write back to the file
err = os.WriteFile(path, []byte(newContent), 0644)
if err != nil {
return api.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
}
return api.CallResult{Result: newContent}
}

View File

@ -1,178 +0,0 @@
package toolbox
import (
"fmt"
"os"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/agents/toolbox/util"
toolutil "git.mlow.ca/mlow/lmcli/pkg/agents/toolbox/util"
"git.mlow.ca/mlow/lmcli/pkg/api"
)
var MODIFY_FILE_DESCRIPTION = []string{
"Modify a file. If the file does not exist, it will be created.",
"",
"Content can be either inserted, replaced, or removed through a combination of the start, stop, and content parameters.",
"Use the start and stop line numbers to limit the range of modification to the file.",
"If both `start` and `stop` are left unset (or set to 0), the entire file's contents will be updated.",
"If `start` is set to n and `stop` to n+1, content will be inserted at line n (the content that was at line n will be shifted below the newly inserted content).",
"If only `start` is set, content from the given line and onwards will be updated.",
"If only `stop` is set, content up to but not including the given line will be updated.",
"",
"Examples:",
"1. Append to a file:",
" {\"path\": \"example.txt\", \"start\": <last_line_number + 1>, \"content\": \"New content to append\"}",
"",
"2. Insert at a specific line:",
" {\"path\": \"example.txt\", \"start\": 5, \"stop\": 5, \"content\": \"New line inserted above the previous line 5\"}",
"",
"3. Replace a range of lines:",
" {\"path\": \"example.txt\", \"start\": 3, \"stop\": 7, \"content\": \"New content replacing lines 3-7\"}",
"",
"4. Remove a range of lines:",
" {\"path\": \"example.txt\", \"start\": 2, \"stop\": 5}",
"",
"5. Replace entire file contents:",
" {\"path\": \"example.txt\", \"content\": \"New file contents\"}",
"",
"6. Update from a specific line to the end of the file:",
" {\"path\": \"example.txt\", \"start\": 10, \"content\": \"New content from line 10 onwards\"}",
"",
"7. Update from the beginning of the file to a specific line:",
" {\"path\": \"example.txt\", \"stop\": 6, \"content\": \"New content for first 5 lines\"}",
"",
"Note: Always use specific line numbers based on the current file content. Avoid using arbitrarily large numbers for start or stop.",
}
var ModifyFile = api.ToolSpec{
Name: "modify_file",
Description: strings.Join(MODIFY_FILE_DESCRIPTION, "\n"),
Parameters: []api.ToolParameter{
{
Name: "path",
Type: "string",
Description: "Path of the file to be modified, relative to the current working directory.",
Required: true,
},
{
Name: "start",
Type: "integer",
Description: `Start line of the range to modify (inclusive). If omitted, the beginning of the file is implied.`,
},
{
Name: "stop",
Type: "integer",
Description: `End line of the range to modify (inclusive). If omitted, the end of the file is implied.`,
},
{
Name: "content",
Type: "string",
Description: "Content to insert/replace at the range defined by `start` and `stop`. If omitted, the range is removed.",
},
},
Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) {
tmp, ok := args["path"]
if !ok {
return "", fmt.Errorf("path parameter to modify_file was not included.")
}
path, ok := tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
}
var start int
tmp, ok = args["start"]
if ok {
tmp, ok := tmp.(float64)
if !ok {
return "", fmt.Errorf("Invalid start in function arguments: %v", tmp)
}
start = int(tmp)
}
var stop int
tmp, ok = args["stop"]
if ok {
tmp, ok := tmp.(float64)
if !ok {
return "", fmt.Errorf("Invalid stop in function arguments: %v", tmp)
}
stop = int(tmp)
}
var content string
tmp, ok = args["content"]
if ok {
content, ok = tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
}
}
result := fileModifyContents(path, start, stop, content)
ret, err := result.ToJson()
if err != nil {
return "", fmt.Errorf("Could not serialize result: %v", err)
}
return ret, nil
},
}
func fileModifyContents(path string, startLine int, stopLine int, content string) api.CallResult {
ok, reason := toolutil.IsPathWithinCWD(path)
if !ok {
return api.CallResult{Message: reason}
}
// Read the existing file's content
data, err := os.ReadFile(path)
if err != nil {
if !os.IsNotExist(err) {
return api.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
}
_, err = os.Create(path)
if err != nil {
return api.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}
}
data = []byte{}
}
lines := strings.Split(string(data), "\n")
contentLines := strings.Split(strings.TrimSuffix(content, "\n"), "\n")
// If both start and stop are unset, update the entire file
if startLine == 0 && stopLine == 0 {
lines = contentLines
} else {
if startLine < 1 {
startLine = 1
}
if stopLine == 0 || stopLine > len(lines) {
stopLine = len(lines)
}
before := lines[:startLine-1]
after := lines[stopLine:]
// Handle insertion case
if startLine == stopLine {
lines = append(before, append(contentLines, lines[startLine-1:]...)...)
} else {
// If content is omitted, remove the specified range
if content == "" {
lines = append(before, after...)
} else {
lines = append(before, append(contentLines, after...)...)
}
}
}
newContent := strings.Join(lines, "\n")
// Write back to the file
err = os.WriteFile(path, []byte(newContent), 0644)
if err != nil {
return api.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
}
return api.CallResult{Result: util.AddLineNumbers(newContent)}
}

View File

@ -1,100 +0,0 @@
package toolbox
import (
"fmt"
"os"
"path/filepath"
"strings"
toolutil "git.mlow.ca/mlow/lmcli/pkg/agents/toolbox/util"
"git.mlow.ca/mlow/lmcli/pkg/api"
)
const READ_DIR_DESCRIPTION = `Return the contents of the CWD (current working directory).
Example result:
{
"message": "success",
"result": [
{"name": "a_file.txt", "type": "file", "size": 123},
{"name": "a_directory/", "type": "dir", "size": 11},
...
]
}
For files, size represents the size of the file, in bytes.
For directories, size represents the number of entries in that directory.`
var ReadDirTool = api.ToolSpec{
Name: "read_dir",
Description: READ_DIR_DESCRIPTION,
Parameters: []api.ToolParameter{
{
Name: "relative_dir",
Type: "string",
Description: "If set, read the contents of a directory relative to the current one.",
},
},
Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) {
var relativeDir string
tmp, ok := args["relative_dir"]
if ok {
relativeDir, ok = tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid relative_dir in function arguments: %v", tmp)
}
}
result := readDir(relativeDir)
ret, err := result.ToJson()
if err != nil {
return "", fmt.Errorf("Could not serialize result: %v", err)
}
return ret, nil
},
}
func readDir(path string) api.CallResult {
if path == "" {
path = "."
}
ok, reason := toolutil.IsPathWithinCWD(path)
if !ok {
return api.CallResult{Message: reason}
}
files, err := os.ReadDir(path)
if err != nil {
return api.CallResult{
Message: err.Error(),
}
}
var dirContents []map[string]interface{}
for _, f := range files {
info, _ := f.Info()
name := f.Name()
if strings.HasPrefix(name, ".") {
// skip hidden files
continue
}
entryType := "file"
size := info.Size()
if info.IsDir() {
name += "/"
entryType = "dir"
subdirfiles, _ := os.ReadDir(filepath.Join(".", path, info.Name()))
size = int64(len(subdirfiles))
}
dirContents = append(dirContents, map[string]interface{}{
"name": name,
"type": entryType,
"size": size,
})
}
return api.CallResult{Result: dirContents}
}

View File

@ -1,65 +0,0 @@
package toolbox
import (
"fmt"
"os"
toolutil "git.mlow.ca/mlow/lmcli/pkg/agents/toolbox/util"
"git.mlow.ca/mlow/lmcli/pkg/api"
)
const READ_FILE_DESCRIPTION = `Retrieve the contents of a text file relative to the current working directory.
Use the file contents for your own reference in completing your task, they do not need to be shown to the user.
Each line of the returned content is prefixed with its line number and a tab (\t).
Example result:
{
"message": "success",
"result": "1\tthe contents\n2\tof the file\n"
}`
var ReadFileTool = api.ToolSpec{
Name: "read_file",
Description: READ_FILE_DESCRIPTION,
Parameters: []api.ToolParameter{
{
Name: "path",
Type: "string",
Description: "Path to a file within the current working directory to read.",
Required: true,
},
},
Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) {
tmp, ok := args["path"]
if !ok {
return "", fmt.Errorf("Path parameter to read_file was not included.")
}
path, ok := tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
}
result := readFile(path)
ret, err := result.ToJson()
if err != nil {
return "", fmt.Errorf("Could not serialize result: %v", err)
}
return ret, nil
},
}
func readFile(path string) api.CallResult {
ok, reason := toolutil.IsPathWithinCWD(path)
if !ok {
return api.CallResult{Message: reason}
}
data, err := os.ReadFile(path)
if err != nil {
return api.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
}
return api.CallResult{
Result: toolutil.AddLineNumbers(string(data)),
}
}

View File

@ -1,78 +0,0 @@
package util
import (
"fmt"
"os"
"path/filepath"
"strings"
)
// isPathContained attempts to verify whether `path` is the same as or
// contained within `directory`. It is overly cautious, returning false even if
// `path` IS contained within `directory`, but the two paths use different
// casing, and we happen to be on a case-insensitive filesystem.
// This is ultimately to attempt to stop an LLM from going outside of where I
// tell it to. Additional layers of security should be considered.. run in a
// VM/container.
func IsPathContained(directory string, path string) (bool, error) {
// Clean and resolve symlinks for both paths
path, err := filepath.Abs(path)
if err != nil {
return false, err
}
// check if path exists
_, err = os.Stat(path)
if err != nil {
if !os.IsNotExist(err) {
return false, fmt.Errorf("Could not stat path: %v", err)
}
} else {
path, err = filepath.EvalSymlinks(path)
if err != nil {
return false, err
}
}
directory, err = filepath.Abs(directory)
if err != nil {
return false, err
}
directory, err = filepath.EvalSymlinks(directory)
if err != nil {
return false, err
}
// Case insensitive checks
if !strings.EqualFold(path, directory) &&
!strings.HasPrefix(strings.ToLower(path), strings.ToLower(directory)+string(os.PathSeparator)) {
return false, nil
}
return true, nil
}
func IsPathWithinCWD(path string) (bool, string) {
cwd, err := os.Getwd()
if err != nil {
return false, "Failed to determine current working directory"
}
if ok, err := IsPathContained(cwd, path); !ok {
if err != nil {
return false, fmt.Sprintf("Could not determine whether path '%s' is within the current working directory: %s", path, err.Error())
}
return false, fmt.Sprintf("Path '%s' is not within the current working directory", path)
}
return true, ""
}
// AddLineNumbers takes a string of content and returns a new string with line
// numbers prefixed
func AddLineNumbers(content string) string {
lines := strings.Split(strings.TrimSuffix(content, "\n"), "\n")
result := strings.Builder{}
for i, line := range lines {
result.WriteString(fmt.Sprintf("%d\t%s\n", i+1, line))
}
return result.String()
}

View File

@ -1,71 +0,0 @@
package toolbox
import (
"fmt"
"os"
toolutil "git.mlow.ca/mlow/lmcli/pkg/agents/toolbox/util"
"git.mlow.ca/mlow/lmcli/pkg/api"
)
const WRITE_FILE_DESCRIPTION = `Write the provided contents to a file relative to the current working directory.
Example result:
{
"message": "success"
}`
var WriteFileTool = api.ToolSpec{
Name: "write_file",
Description: WRITE_FILE_DESCRIPTION,
Parameters: []api.ToolParameter{
{
Name: "path",
Type: "string",
Description: "Path to a file within the current working directory to write to.",
Required: true,
},
{
Name: "content",
Type: "string",
Description: "The content to write to the file. Overwrites any existing content!",
Required: true,
},
},
Impl: func(t *api.ToolSpec, args map[string]interface{}) (string, error) {
tmp, ok := args["path"]
if !ok {
return "", fmt.Errorf("Path parameter to write_file was not included.")
}
path, ok := tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
}
tmp, ok = args["content"]
if !ok {
return "", fmt.Errorf("Content parameter to write_file was not included.")
}
content, ok := tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
}
result := writeFile(path, content)
ret, err := result.ToJson()
if err != nil {
return "", fmt.Errorf("Could not serialize result: %v", err)
}
return ret, nil
},
}
func writeFile(path string, content string) api.CallResult {
ok, reason := toolutil.IsPathWithinCWD(path)
if !ok {
return api.CallResult{Message: reason}
}
err := os.WriteFile(path, []byte(content), 0644)
if err != nil {
return api.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
}
return api.CallResult{}
}

View File

@ -1,47 +0,0 @@
package agents
import (
"fmt"
"git.mlow.ca/mlow/lmcli/pkg/agents/toolbox"
"git.mlow.ca/mlow/lmcli/pkg/api"
)
var AvailableTools map[string]api.ToolSpec = map[string]api.ToolSpec{
"dir_tree": toolbox.DirTreeTool,
"read_dir": toolbox.ReadDirTool,
"read_file": toolbox.ReadFileTool,
"modify_file": toolbox.ModifyFile,
"write_file": toolbox.WriteFileTool,
}
func ExecuteToolCalls(calls []api.ToolCall, available []api.ToolSpec) ([]api.ToolResult, error) {
var toolResults []api.ToolResult
for _, call := range calls {
var tool *api.ToolSpec
for i := range available {
if available[i].Name == call.Name {
tool = &available[i]
break
}
}
if tool == nil {
return nil, fmt.Errorf("Requested tool '%s' is not available. Hallucination?", call.Name)
}
// Execute the tool
result, err := tool.Impl(tool, call.Parameters)
if err != nil {
return nil, fmt.Errorf("Tool '%s' error: %v\n", call.Name, err)
}
toolResult := api.ToolResult{
ToolCallID: call.ID,
ToolName: call.Name,
Result: result,
}
toolResults = append(toolResults, toolResult)
}
return toolResults, nil
}

View File

@ -1,49 +0,0 @@
package api
import (
"context"
)
type ReplyCallback func(Message)
type Chunk struct {
Content string
TokenCount uint
}
type RequestParameters struct {
Model string
MaxTokens int
Temperature float32
TopP float32
Toolbox []ToolSpec
}
type ChatCompletionProvider interface {
// CreateChatCompletion requests a response to the provided messages.
// Replies are appended to the given replies struct, and the
// complete user-facing response is returned as a string.
CreateChatCompletion(
ctx context.Context,
params RequestParameters,
messages []Message,
) (*Message, error)
// Like CreateChageCompletion, except the response is streamed via
// the output channel as it's received.
CreateChatCompletionStream(
ctx context.Context,
params RequestParameters,
messages []Message,
chunks chan<- Chunk,
) (*Message, error)
}
func IsAssistantContinuation(messages []Message) bool {
if len(messages) == 0 {
return false
}
return messages[len(messages)-1].Role == MessageRoleAssistant
}

View File

@ -1,11 +0,0 @@
package api
import "database/sql"
type Conversation struct {
ID uint `gorm:"primaryKey"`
ShortName sql.NullString
Title string
SelectedRootID *uint
SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"`
}

View File

@ -1,72 +0,0 @@
package api
import (
"time"
)
type MessageRole string
const (
MessageRoleSystem MessageRole = "system"
MessageRoleUser MessageRole = "user"
MessageRoleAssistant MessageRole = "assistant"
MessageRoleToolCall MessageRole = "tool_call"
MessageRoleToolResult MessageRole = "tool_result"
)
type Message struct {
ID uint `gorm:"primaryKey"`
ConversationID *uint `gorm:"index"`
Conversation *Conversation `gorm:"foreignKey:ConversationID"`
Content string
Role MessageRole
CreatedAt time.Time
ToolCalls ToolCalls // a json array of tool calls (from the model)
ToolResults ToolResults // a json array of tool results
ParentID *uint
Parent *Message `gorm:"foreignKey:ParentID"`
Replies []Message `gorm:"foreignKey:ParentID"`
SelectedReplyID *uint
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
}
func ApplySystemPrompt(m []Message, system string, force bool) []Message {
if len(m) > 0 && m[0].Role == MessageRoleSystem {
if force {
m[0].Content = system
}
return m
} else {
return append([]Message{{
Role: MessageRoleSystem,
Content: system,
}}, m...)
}
}
func (m *MessageRole) IsAssistant() bool {
switch *m {
case MessageRoleAssistant, MessageRoleToolCall:
return true
}
return false
}
// FriendlyRole returns a human friendly signifier for the message's role.
func (m MessageRole) FriendlyRole() string {
switch m {
case MessageRoleUser:
return "You"
case MessageRoleSystem:
return "System"
case MessageRoleAssistant:
return "Assistant"
case MessageRoleToolCall:
return "Tool Call"
case MessageRoleToolResult:
return "Tool Result"
default:
return string(m)
}
}

View File

@ -1,450 +0,0 @@
package anthropic
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
)
const ANTHROPIC_VERSION = "2023-06-01"
type AnthropicClient struct {
APIKey string
BaseURL string
}
type ChatCompletionMessage struct {
Role string `json:"role"`
Content interface{} `json:"content"`
}
type Tool struct {
Name string `json:"name"`
Description string `json:"description"`
InputSchema InputSchema `json:"input_schema"`
}
type InputSchema struct {
Type string `json:"type"`
Properties map[string]Property `json:"properties"`
Required []string `json:"required"`
}
type Property struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
}
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []ChatCompletionMessage `json:"messages"`
System string `json:"system,omitempty"`
Tools []Tool `json:"tools,omitempty"`
MaxTokens int `json:"max_tokens"`
Temperature float32 `json:"temperature,omitempty"`
Stream bool `json:"stream"`
}
type ContentBlock struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input interface{} `json:"input,omitempty"`
partialJsonAccumulator string
}
type ChatCompletionResponse struct {
ID string `json:"id"`
Type string `json:"type"`
Role string `json:"role"`
Model string `json:"model"`
Content []ContentBlock `json:"content"`
StopReason string `json:"stop_reason"`
Usage Usage `json:"usage"`
}
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
type StreamEvent struct {
Type string `json:"type"`
Message interface{} `json:"message,omitempty"`
Index int `json:"index,omitempty"`
Delta interface{} `json:"delta,omitempty"`
}
func convertTools(tools []api.ToolSpec) []Tool {
anthropicTools := make([]Tool, len(tools))
for i, tool := range tools {
properties := make(map[string]Property)
for _, param := range tool.Parameters {
properties[param.Name] = Property{
Type: param.Type,
Description: param.Description,
Enum: param.Enum,
}
}
var required []string
for _, param := range tool.Parameters {
if param.Required {
required = append(required, param.Name)
}
}
anthropicTools[i] = Tool{
Name: tool.Name,
Description: tool.Description,
InputSchema: InputSchema{
Type: "object",
Properties: properties,
Required: required,
},
}
}
return anthropicTools
}
func createChatCompletionRequest(
params api.RequestParameters,
messages []api.Message,
) (string, ChatCompletionRequest) {
requestMessages := make([]ChatCompletionMessage, 0, len(messages))
var systemMessage string
for _, m := range messages {
if m.Role == api.MessageRoleSystem {
systemMessage = m.Content
continue
}
var content interface{}
role := string(m.Role)
switch m.Role {
case api.MessageRoleToolCall:
role = "assistant"
contentBlocks := make([]map[string]interface{}, 0)
if m.Content != "" {
contentBlocks = append(contentBlocks, map[string]interface{}{
"type": "text",
"text": m.Content,
})
}
for _, toolCall := range m.ToolCalls {
contentBlocks = append(contentBlocks, map[string]interface{}{
"type": "tool_use",
"id": toolCall.ID,
"name": toolCall.Name,
"input": toolCall.Parameters,
})
}
content = contentBlocks
case api.MessageRoleToolResult:
role = "user"
contentBlocks := make([]map[string]interface{}, 0)
for _, result := range m.ToolResults {
contentBlock := map[string]interface{}{
"type": "tool_result",
"tool_use_id": result.ToolCallID,
"content": result.Result,
}
contentBlocks = append(contentBlocks, contentBlock)
}
content = contentBlocks
default:
content = m.Content
}
requestMessages = append(requestMessages, ChatCompletionMessage{
Role: role,
Content: content,
})
}
request := ChatCompletionRequest{
Model: params.Model,
Messages: requestMessages,
System: systemMessage,
MaxTokens: params.MaxTokens,
Temperature: params.Temperature,
}
if len(params.Toolbox) > 0 {
request.Tools = convertTools(params.Toolbox)
}
var prefill string
if api.IsAssistantContinuation(messages) {
prefill = messages[len(messages)-1].Content
}
return prefill, request
}
func (c *AnthropicClient) sendRequest(ctx context.Context, r ChatCompletionRequest) (*http.Response, error) {
jsonData, err := json.Marshal(r)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/v1/messages", bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
}
req.Header.Set("x-api-key", c.APIKey)
req.Header.Set("anthropic-version", ANTHROPIC_VERSION)
req.Header.Set("content-type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
if resp.StatusCode != 200 {
bytes, _ := io.ReadAll(resp.Body)
return resp, fmt.Errorf("%v", string(bytes))
}
return resp, err
}
func (c *AnthropicClient) CreateChatCompletion(
ctx context.Context,
params api.RequestParameters,
messages []api.Message,
) (*api.Message, error) {
if len(messages) == 0 {
return nil, fmt.Errorf("can't create completion from no messages")
}
_, req := createChatCompletionRequest(params, messages)
req.Stream = false
resp, err := c.sendRequest(ctx, req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
var completionResp ChatCompletionResponse
err = json.NewDecoder(resp.Body).Decode(&completionResp)
if err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
return convertResponseToMessage(completionResp)
}
func (c *AnthropicClient) CreateChatCompletionStream(
ctx context.Context,
params api.RequestParameters,
messages []api.Message,
output chan<- api.Chunk,
) (*api.Message, error) {
if len(messages) == 0 {
return nil, fmt.Errorf("can't create completion from no messages")
}
prefill, req := createChatCompletionRequest(params, messages)
req.Stream = true
resp, err := c.sendRequest(ctx, req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
contentBlocks := make(map[int]*ContentBlock)
var finalMessage *ChatCompletionResponse
var firstChunkReceived bool
reader := bufio.NewReader(resp.Body)
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
line, err := reader.ReadBytes('\n')
if err != nil {
if err == io.EOF {
break
}
return nil, fmt.Errorf("error reading stream: %w", err)
}
line = bytes.TrimSpace(line)
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
continue
}
line = bytes.TrimPrefix(line, []byte("data: "))
var streamEvent StreamEvent
err = json.Unmarshal(line, &streamEvent)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal stream event: %w", err)
}
switch streamEvent.Type {
case "message_start":
finalMessage = &ChatCompletionResponse{}
err = json.Unmarshal(line, &struct {
Message *ChatCompletionResponse `json:"message"`
}{Message: finalMessage})
if err != nil {
return nil, fmt.Errorf("failed to unmarshal message_start: %w", err)
}
case "content_block_start":
var contentBlockStart struct {
Index int `json:"index"`
ContentBlock ContentBlock `json:"content_block"`
}
err = json.Unmarshal(line, &contentBlockStart)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal content_block_start: %w", err)
}
contentBlocks[contentBlockStart.Index] = &contentBlockStart.ContentBlock
case "content_block_delta":
if streamEvent.Index >= len(contentBlocks) {
return nil, fmt.Errorf("received delta for non-existent content block index: %d", streamEvent.Index)
}
block := contentBlocks[streamEvent.Index]
delta, ok := streamEvent.Delta.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("unexpected delta type: %T", streamEvent.Delta)
}
deltaType, ok := delta["type"].(string)
if !ok {
return nil, fmt.Errorf("delta missing type field")
}
switch deltaType {
case "text_delta":
if text, ok := delta["text"].(string); ok {
if !firstChunkReceived {
if prefill == "" {
// if there is no prefil, ensure we trim leading whitespace
text = strings.TrimSpace(text)
}
firstChunkReceived = true
}
block.Text += text
output <- api.Chunk{
Content: text,
TokenCount: 1,
}
}
case "input_json_delta":
if block.Type != "tool_use" {
return nil, fmt.Errorf("received input_json_delta for non-tool_use block")
}
if partialJSON, ok := delta["partial_json"].(string); ok {
block.partialJsonAccumulator += partialJSON
}
}
case "content_block_stop":
if streamEvent.Index >= len(contentBlocks) {
return nil, fmt.Errorf("received stop for non-existent content block index: %d", streamEvent.Index)
}
block := contentBlocks[streamEvent.Index]
if block.Type == "tool_use" && block.partialJsonAccumulator != "" {
var inputData map[string]interface{}
err := json.Unmarshal([]byte(block.partialJsonAccumulator), &inputData)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal accumulated JSON for tool use: %w", err)
}
block.Input = inputData
}
case "message_delta":
if finalMessage == nil {
return nil, fmt.Errorf("received message_delta before message_start")
}
delta, ok := streamEvent.Delta.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("unexpected delta type in message_delta: %T", streamEvent.Delta)
}
if stopReason, ok := delta["stop_reason"].(string); ok {
finalMessage.StopReason = stopReason
}
case "message_stop":
// End of the stream
goto END_STREAM
case "error":
return nil, fmt.Errorf("received error event: %v", streamEvent.Message)
default:
// Ignore unknown event types
}
}
}
END_STREAM:
if finalMessage == nil {
return nil, fmt.Errorf("no final message received")
}
finalMessage.Content = make([]ContentBlock, len(contentBlocks))
for _, v := range contentBlocks {
finalMessage.Content = append(finalMessage.Content, *v)
}
return convertResponseToMessage(*finalMessage)
}
func convertResponseToMessage(resp ChatCompletionResponse) (*api.Message, error) {
content := strings.Builder{}
var toolCalls []api.ToolCall
for _, block := range resp.Content {
switch block.Type {
case "text":
content.WriteString(block.Text)
case "tool_use":
parameters, ok := block.Input.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("unexpected type for tool call parameters: %T", block.Input)
}
toolCalls = append(toolCalls, api.ToolCall{
ID: block.ID,
Name: block.Name,
Parameters: parameters,
})
}
}
message := &api.Message{
Role: api.MessageRoleAssistant,
Content: content.String(),
ToolCalls: toolCalls,
}
if len(toolCalls) > 0 {
message.Role = api.MessageRoleToolCall
}
return message, nil
}

View File

@ -1,450 +0,0 @@
package google
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
)
type Client struct {
APIKey string
BaseURL string
}
type ContentPart struct {
Text string `json:"text,omitempty"`
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
FunctionResp *FunctionResponse `json:"functionResponse,omitempty"`
}
type FunctionCall struct {
Name string `json:"name"`
Args map[string]string `json:"args"`
}
type FunctionResponse struct {
Name string `json:"name"`
Response interface{} `json:"response"`
}
type Content struct {
Role string `json:"role"`
Parts []ContentPart `json:"parts"`
}
type GenerationConfig struct {
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
Temperature *float32 `json:"temperature,omitempty"`
TopP *float32 `json:"topP,omitempty"`
TopK *int `json:"topK,omitempty"`
}
type GenerateContentRequest struct {
Contents []Content `json:"contents"`
Tools []Tool `json:"tools,omitempty"`
SystemInstruction *Content `json:"systemInstruction,omitempty"`
GenerationConfig *GenerationConfig `json:"generationConfig,omitempty"`
}
type Candidate struct {
Content Content `json:"content"`
FinishReason string `json:"finishReason"`
Index int `json:"index"`
}
type UsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount"`
CandidatesTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
}
type GenerateContentResponse struct {
Candidates []Candidate `json:"candidates"`
UsageMetadata UsageMetadata `json:"usageMetadata"`
}
type Tool struct {
FunctionDeclarations []FunctionDeclaration `json:"functionDeclarations"`
}
type FunctionDeclaration struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters ToolParameters `json:"parameters"`
}
type ToolParameters struct {
Type string `json:"type"`
Properties map[string]ToolParameter `json:"properties,omitempty"`
Required []string `json:"required,omitempty"`
}
type ToolParameter struct {
Type string `json:"type"`
Description string `json:"description"`
Values []string `json:"values,omitempty"`
}
func convertTools(tools []api.ToolSpec) []Tool {
geminiTools := make([]Tool, len(tools))
for i, tool := range tools {
params := make(map[string]ToolParameter)
var required []string
for _, param := range tool.Parameters {
// TODO: proper enum handing
params[param.Name] = ToolParameter{
Type: param.Type,
Description: param.Description,
Values: param.Enum,
}
if param.Required {
required = append(required, param.Name)
}
}
geminiTools[i] = Tool{
FunctionDeclarations: []FunctionDeclaration{
{
Name: tool.Name,
Description: tool.Description,
Parameters: ToolParameters{
Type: "OBJECT",
Properties: params,
Required: required,
},
},
},
}
}
return geminiTools
}
func convertToolCallToGemini(toolCalls []api.ToolCall) []ContentPart {
converted := make([]ContentPart, len(toolCalls))
for i, call := range toolCalls {
args := make(map[string]string)
for k, v := range call.Parameters {
args[k] = fmt.Sprintf("%v", v)
}
converted[i].FunctionCall = &FunctionCall{
Name: call.Name,
Args: args,
}
}
return converted
}
func convertToolCallToAPI(functionCalls []FunctionCall) []api.ToolCall {
converted := make([]api.ToolCall, len(functionCalls))
for i, call := range functionCalls {
params := make(map[string]interface{})
for k, v := range call.Args {
params[k] = v
}
converted[i].Name = call.Name
converted[i].Parameters = params
}
return converted
}
func convertToolResultsToGemini(toolResults []api.ToolResult) ([]FunctionResponse, error) {
results := make([]FunctionResponse, len(toolResults))
for i, result := range toolResults {
var obj interface{}
err := json.Unmarshal([]byte(result.Result), &obj)
if err != nil {
return nil, fmt.Errorf("Could not unmarshal %s: %v", result.Result, err)
}
results[i] = FunctionResponse{
Name: result.ToolName,
Response: obj,
}
}
return results, nil
}
func createGenerateContentRequest(
params api.RequestParameters,
messages []api.Message,
) (*GenerateContentRequest, error) {
requestContents := make([]Content, 0, len(messages))
startIdx := 0
var system string
if len(messages) > 0 && messages[0].Role == api.MessageRoleSystem {
system = messages[0].Content
startIdx = 1
}
for _, m := range messages[startIdx:] {
switch m.Role {
case "tool_call":
content := Content{
Role: "model",
Parts: convertToolCallToGemini(m.ToolCalls),
}
requestContents = append(requestContents, content)
case "tool_result":
results, err := convertToolResultsToGemini(m.ToolResults)
if err != nil {
return nil, err
}
// expand tool_result messages' results into multiple gemini messages
for _, result := range results {
content := Content{
Role: "function",
Parts: []ContentPart{
{
FunctionResp: &result,
},
},
}
requestContents = append(requestContents, content)
}
default:
var role string
switch m.Role {
case api.MessageRoleAssistant:
role = "model"
case api.MessageRoleUser:
role = "user"
}
if role == "" {
panic("Unhandled role: " + m.Role)
}
content := Content{
Role: role,
Parts: []ContentPart{
{
Text: m.Content,
},
},
}
requestContents = append(requestContents, content)
}
}
request := &GenerateContentRequest{
Contents: requestContents,
GenerationConfig: &GenerationConfig{
MaxOutputTokens: &params.MaxTokens,
Temperature: &params.Temperature,
TopP: &params.TopP,
},
}
if system != "" {
request.SystemInstruction = &Content{
Parts: []ContentPart{
{
Text: system,
},
},
}
}
if len(params.Toolbox) > 0 {
request.Tools = convertTools(params.Toolbox)
}
return request, nil
}
func (c *Client) sendRequest(req *http.Request) (*http.Response, error) {
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
if resp.StatusCode != 200 {
bytes, _ := io.ReadAll(resp.Body)
return resp, fmt.Errorf("%v", string(bytes))
}
return resp, err
}
func (c *Client) CreateChatCompletion(
ctx context.Context,
params api.RequestParameters,
messages []api.Message,
) (*api.Message, error) {
if len(messages) == 0 {
return nil, fmt.Errorf("Can't create completion from no messages")
}
req, err := createGenerateContentRequest(params, messages)
if err != nil {
return nil, err
}
jsonData, err := json.Marshal(req)
if err != nil {
return nil, err
}
url := fmt.Sprintf(
"%s/v1beta/models/%s:generateContent?key=%s",
c.BaseURL, params.Model, c.APIKey,
)
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
resp, err := c.sendRequest(httpReq)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var completionResp GenerateContentResponse
err = json.NewDecoder(resp.Body).Decode(&completionResp)
if err != nil {
return nil, err
}
choice := completionResp.Candidates[0]
var content string
lastMessage := messages[len(messages)-1]
if lastMessage.Role.IsAssistant() {
content = lastMessage.Content
}
var toolCalls []FunctionCall
for _, part := range choice.Content.Parts {
if part.Text != "" {
content += part.Text
}
if part.FunctionCall != nil {
toolCalls = append(toolCalls, *part.FunctionCall)
}
}
if len(toolCalls) > 0 {
return &api.Message{
Role: api.MessageRoleToolCall,
Content: content,
ToolCalls: convertToolCallToAPI(toolCalls),
}, nil
}
return &api.Message{
Role: api.MessageRoleAssistant,
Content: content,
}, nil
}
func (c *Client) CreateChatCompletionStream(
ctx context.Context,
params api.RequestParameters,
messages []api.Message,
output chan<- api.Chunk,
) (*api.Message, error) {
if len(messages) == 0 {
return nil, fmt.Errorf("Can't create completion from no messages")
}
req, err := createGenerateContentRequest(params, messages)
if err != nil {
return nil, err
}
jsonData, err := json.Marshal(req)
if err != nil {
return nil, err
}
url := fmt.Sprintf(
"%s/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse",
c.BaseURL, params.Model, c.APIKey,
)
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
resp, err := c.sendRequest(httpReq)
if err != nil {
return nil, err
}
defer resp.Body.Close()
content := strings.Builder{}
lastMessage := messages[len(messages)-1]
if lastMessage.Role.IsAssistant() {
content.WriteString(lastMessage.Content)
}
var toolCalls []FunctionCall
reader := bufio.NewReader(resp.Body)
lastTokenCount := 0
for {
line, err := reader.ReadBytes('\n')
if err != nil {
if err == io.EOF {
break
}
return nil, err
}
line = bytes.TrimSpace(line)
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
continue
}
line = bytes.TrimPrefix(line, []byte("data: "))
var resp GenerateContentResponse
err = json.Unmarshal(line, &resp)
if err != nil {
return nil, err
}
tokens := resp.UsageMetadata.CandidatesTokenCount - lastTokenCount
lastTokenCount += tokens
choice := resp.Candidates[0]
for _, part := range choice.Content.Parts {
if part.FunctionCall != nil {
toolCalls = append(toolCalls, *part.FunctionCall)
} else if part.Text != "" {
output <- api.Chunk{
Content: part.Text,
TokenCount: uint(tokens),
}
content.WriteString(part.Text)
}
}
}
// If there are function calls, handle them and recurse
if len(toolCalls) > 0 {
return &api.Message{
Role: api.MessageRoleToolCall,
Content: content.String(),
ToolCalls: convertToolCallToAPI(toolCalls),
}, nil
}
return &api.Message{
Role: api.MessageRoleAssistant,
Content: content.String(),
}, nil
}

View File

@ -1,188 +0,0 @@
package ollama
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
)
type OllamaClient struct {
BaseURL string
}
type OllamaMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type OllamaRequest struct {
Model string `json:"model"`
Messages []OllamaMessage `json:"messages"`
Stream bool `json:"stream"`
}
type OllamaResponse struct {
Model string `json:"model"`
CreatedAt string `json:"created_at"`
Message OllamaMessage `json:"message"`
Done bool `json:"done"`
TotalDuration uint64 `json:"total_duration,omitempty"`
LoadDuration uint64 `json:"load_duration,omitempty"`
PromptEvalCount uint64 `json:"prompt_eval_count,omitempty"`
PromptEvalDuration uint64 `json:"prompt_eval_duration,omitempty"`
EvalCount uint64 `json:"eval_count,omitempty"`
EvalDuration uint64 `json:"eval_duration,omitempty"`
}
func createOllamaRequest(
params api.RequestParameters,
messages []api.Message,
) OllamaRequest {
requestMessages := make([]OllamaMessage, 0, len(messages))
for _, m := range messages {
message := OllamaMessage{
Role: string(m.Role),
Content: m.Content,
}
requestMessages = append(requestMessages, message)
}
request := OllamaRequest{
Model: params.Model,
Messages: requestMessages,
}
return request
}
func (c *OllamaClient) sendRequest(req *http.Request) (*http.Response, error) {
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
if resp.StatusCode != 200 {
bytes, _ := io.ReadAll(resp.Body)
return resp, fmt.Errorf("%v", string(bytes))
}
return resp, nil
}
func (c *OllamaClient) CreateChatCompletion(
ctx context.Context,
params api.RequestParameters,
messages []api.Message,
) (*api.Message, error) {
if len(messages) == 0 {
return nil, fmt.Errorf("Can't create completion from no messages")
}
req := createOllamaRequest(params, messages)
req.Stream = false
jsonData, err := json.Marshal(req)
if err != nil {
return nil, err
}
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
resp, err := c.sendRequest(httpReq)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var completionResp OllamaResponse
err = json.NewDecoder(resp.Body).Decode(&completionResp)
if err != nil {
return nil, err
}
return &api.Message{
Role: api.MessageRoleAssistant,
Content: completionResp.Message.Content,
}, nil
}
func (c *OllamaClient) CreateChatCompletionStream(
ctx context.Context,
params api.RequestParameters,
messages []api.Message,
output chan<- api.Chunk,
) (*api.Message, error) {
if len(messages) == 0 {
return nil, fmt.Errorf("Can't create completion from no messages")
}
req := createOllamaRequest(params, messages)
req.Stream = true
jsonData, err := json.Marshal(req)
if err != nil {
return nil, err
}
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
resp, err := c.sendRequest(httpReq)
if err != nil {
return nil, err
}
defer resp.Body.Close()
content := strings.Builder{}
reader := bufio.NewReader(resp.Body)
for {
line, err := reader.ReadBytes('\n')
if err != nil {
if err == io.EOF {
break
}
return nil, err
}
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}
var streamResp OllamaResponse
err = json.Unmarshal(line, &streamResp)
if err != nil {
return nil, err
}
if len(streamResp.Message.Content) > 0 {
output <- api.Chunk{
Content: streamResp.Message.Content,
TokenCount: 1,
}
content.WriteString(streamResp.Message.Content)
}
}
return &api.Message{
Role: api.MessageRoleAssistant,
Content: content.String(),
}, nil
}

View File

@ -1,352 +0,0 @@
package openai
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
)
type OpenAIClient struct {
APIKey string
BaseURL string
}
type ChatCompletionMessage struct {
Role string `json:"role"`
Content string `json:"content,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}
type ToolCall struct {
Type string `json:"type"`
ID string `json:"id"`
Index *int `json:"index,omitempty"`
Function FunctionDefinition `json:"function"`
}
type FunctionDefinition struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters ToolParameters `json:"parameters"`
Arguments string `json:"arguments,omitempty"`
}
type ToolParameters struct {
Type string `json:"type"`
Properties map[string]ToolParameter `json:"properties,omitempty"`
Required []string `json:"required,omitempty"`
}
type ToolParameter struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
}
type Tool struct {
Type string `json:"type"`
Function FunctionDefinition `json:"function"`
}
type ChatCompletionRequest struct {
Model string `json:"model"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float32 `json:"temperature,omitempty"`
Messages []ChatCompletionMessage `json:"messages"`
N int `json:"n"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice string `json:"tool_choice,omitempty"`
Stream bool `json:"stream,omitempty"`
}
type ChatCompletionChoice struct {
Message ChatCompletionMessage `json:"message"`
}
type ChatCompletionResponse struct {
Choices []ChatCompletionChoice `json:"choices"`
}
type ChatCompletionStreamChoice struct {
Delta ChatCompletionMessage `json:"delta"`
}
type ChatCompletionStreamResponse struct {
Choices []ChatCompletionStreamChoice `json:"choices"`
}
func convertTools(tools []api.ToolSpec) []Tool {
openaiTools := make([]Tool, len(tools))
for i, tool := range tools {
openaiTools[i].Type = "function"
params := make(map[string]ToolParameter)
var required []string
for _, param := range tool.Parameters {
params[param.Name] = ToolParameter{
Type: param.Type,
Description: param.Description,
Enum: param.Enum,
}
if param.Required {
required = append(required, param.Name)
}
}
openaiTools[i].Function = FunctionDefinition{
Name: tool.Name,
Description: tool.Description,
Parameters: ToolParameters{
Type: "object",
Properties: params,
Required: required,
},
}
}
return openaiTools
}
func convertToolCallToOpenAI(toolCalls []api.ToolCall) []ToolCall {
converted := make([]ToolCall, len(toolCalls))
for i, call := range toolCalls {
converted[i].Type = "function"
converted[i].ID = call.ID
converted[i].Function.Name = call.Name
json, _ := json.Marshal(call.Parameters)
converted[i].Function.Arguments = string(json)
}
return converted
}
func convertToolCallToAPI(toolCalls []ToolCall) []api.ToolCall {
converted := make([]api.ToolCall, len(toolCalls))
for i, call := range toolCalls {
converted[i].ID = call.ID
converted[i].Name = call.Function.Name
json.Unmarshal([]byte(call.Function.Arguments), &converted[i].Parameters)
}
return converted
}
func createChatCompletionRequest(
params api.RequestParameters,
messages []api.Message,
) ChatCompletionRequest {
requestMessages := make([]ChatCompletionMessage, 0, len(messages))
for _, m := range messages {
switch m.Role {
case "tool_call":
message := ChatCompletionMessage{}
message.Role = "assistant"
message.Content = m.Content
message.ToolCalls = convertToolCallToOpenAI(m.ToolCalls)
requestMessages = append(requestMessages, message)
case "tool_result":
// expand tool_result messages' results into multiple openAI messages
for _, result := range m.ToolResults {
message := ChatCompletionMessage{}
message.Role = "tool"
message.Content = result.Result
message.ToolCallID = result.ToolCallID
requestMessages = append(requestMessages, message)
}
default:
message := ChatCompletionMessage{}
message.Role = string(m.Role)
message.Content = m.Content
requestMessages = append(requestMessages, message)
}
}
request := ChatCompletionRequest{
Model: params.Model,
MaxTokens: params.MaxTokens,
Temperature: params.Temperature,
Messages: requestMessages,
N: 1, // limit responses to 1 "choice". we use choices[0] to reference it
}
if len(params.Toolbox) > 0 {
request.Tools = convertTools(params.Toolbox)
request.ToolChoice = "auto"
}
return request
}
func (c *OpenAIClient) sendRequest(ctx context.Context, r ChatCompletionRequest) (*http.Response, error) {
jsonData, err := json.Marshal(r)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/v1/chat/completions", bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.APIKey)
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
if resp.StatusCode != 200 {
bytes, _ := io.ReadAll(resp.Body)
return resp, fmt.Errorf("%v", string(bytes))
}
return resp, err
}
func (c *OpenAIClient) CreateChatCompletion(
ctx context.Context,
params api.RequestParameters,
messages []api.Message,
) (*api.Message, error) {
if len(messages) == 0 {
return nil, fmt.Errorf("Can't create completion from no messages")
}
req := createChatCompletionRequest(params, messages)
resp, err := c.sendRequest(ctx, req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var completionResp ChatCompletionResponse
err = json.NewDecoder(resp.Body).Decode(&completionResp)
if err != nil {
return nil, err
}
choice := completionResp.Choices[0]
var content string
lastMessage := messages[len(messages)-1]
if lastMessage.Role.IsAssistant() {
content = lastMessage.Content + choice.Message.Content
} else {
content = choice.Message.Content
}
toolCalls := choice.Message.ToolCalls
if len(toolCalls) > 0 {
return &api.Message{
Role: api.MessageRoleToolCall,
Content: content,
ToolCalls: convertToolCallToAPI(toolCalls),
}, nil
}
return &api.Message{
Role: api.MessageRoleAssistant,
Content: content,
}, nil
}
func (c *OpenAIClient) CreateChatCompletionStream(
ctx context.Context,
params api.RequestParameters,
messages []api.Message,
output chan<- api.Chunk,
) (*api.Message, error) {
if len(messages) == 0 {
return nil, fmt.Errorf("Can't create completion from no messages")
}
req := createChatCompletionRequest(params, messages)
req.Stream = true
resp, err := c.sendRequest(ctx, req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
content := strings.Builder{}
toolCalls := []ToolCall{}
lastMessage := messages[len(messages)-1]
if lastMessage.Role.IsAssistant() {
content.WriteString(lastMessage.Content)
}
reader := bufio.NewReader(resp.Body)
for {
line, err := reader.ReadBytes('\n')
if err != nil {
if err == io.EOF {
break
}
return nil, err
}
line = bytes.TrimSpace(line)
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
continue
}
line = bytes.TrimPrefix(line, []byte("data: "))
if bytes.Equal(line, []byte("[DONE]")) {
break
}
var streamResp ChatCompletionStreamResponse
err = json.Unmarshal(line, &streamResp)
if err != nil {
return nil, err
}
delta := streamResp.Choices[0].Delta
if len(delta.ToolCalls) > 0 {
// Construct streamed tool_call arguments
for _, tc := range delta.ToolCalls {
if tc.Index == nil {
return nil, fmt.Errorf("Unexpected nil index for streamed tool call.")
}
if len(toolCalls) <= *tc.Index {
toolCalls = append(toolCalls, tc)
} else {
toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
}
}
}
if len(delta.Content) > 0 {
output <- api.Chunk{
Content: delta.Content,
TokenCount: 1,
}
content.WriteString(delta.Content)
}
}
if len(toolCalls) > 0 {
return &api.Message{
Role: api.MessageRoleToolCall,
Content: content.String(),
ToolCalls: convertToolCallToAPI(toolCalls),
}, nil
}
return &api.Message{
Role: api.MessageRoleAssistant,
Content: content.String(),
}, nil
}

View File

@ -1,98 +0,0 @@
package api
import (
"database/sql/driver"
"encoding/json"
"fmt"
)
type ToolSpec struct {
Name string
Description string
Parameters []ToolParameter
Impl func(*ToolSpec, map[string]interface{}) (string, error)
}
type ToolParameter struct {
Name string `json:"name"`
Type string `json:"type"` // "string", "integer", "boolean"
Required bool `json:"required"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
}
type ToolCall struct {
ID string `json:"id" yaml:"-"`
Name string `json:"name" yaml:"tool"`
Parameters map[string]interface{} `json:"parameters" yaml:"parameters"`
}
type ToolResult struct {
ToolCallID string `json:"toolCallID" yaml:"-"`
ToolName string `json:"toolName,omitempty" yaml:"tool"`
Result string `json:"result,omitempty" yaml:"result"`
}
type ToolCalls []ToolCall
func (tc *ToolCalls) Scan(value any) (err error) {
s := value.(string)
if value == nil || s == "" {
*tc = nil
return
}
err = json.Unmarshal([]byte(s), tc)
return
}
func (tc ToolCalls) Value() (driver.Value, error) {
if len(tc) == 0 {
return "", nil
}
jsonBytes, err := json.Marshal(tc)
if err != nil {
return "", fmt.Errorf("Could not marshal ToolCalls to JSON: %v\n", err)
}
return string(jsonBytes), nil
}
type ToolResults []ToolResult
func (tr *ToolResults) Scan(value any) (err error) {
s := value.(string)
if value == nil || s == "" {
*tr = nil
return
}
err = json.Unmarshal([]byte(s), tr)
return
}
func (tr ToolResults) Value() (driver.Value, error) {
if len(tr) == 0 {
return "", nil
}
jsonBytes, err := json.Marshal([]ToolResult(tr))
if err != nil {
return "", fmt.Errorf("Could not marshal ToolResults to JSON: %v\n", err)
}
return string(jsonBytes), nil
}
type CallResult struct {
Message string `json:"message"`
Result any `json:"result,omitempty"`
}
func (r CallResult) ToJson() (string, error) {
if r.Message == "" {
// When message not supplied, assume success
r.Message = "success"
}
jsonBytes, err := json.Marshal(r)
if err != nil {
return "", fmt.Errorf("Could not marshal CallResult to JSON: %v\n", err)
}
return string(jsonBytes), nil
}

32
pkg/cli/cli.go Normal file
View File

@ -0,0 +1,32 @@
package cli
import (
"fmt"
"os"
)
var config *Config
var store *Store
func init() {
var err error
config, err = NewConfig()
if err != nil {
Fatal("%v\n", err)
}
store, err = NewStore()
if err != nil {
Fatal("%v\n", err)
}
}
func Fatal(format string, args ...any) {
fmt.Fprintf(os.Stderr, format, args...)
os.Exit(1)
}
func Warn(format string, args ...any) {
fmt.Fprintf(os.Stderr, format, args...)
}

550
pkg/cli/cmd.go Normal file
View File

@ -0,0 +1,550 @@
package cli
import (
"fmt"
"slices"
"strings"
"time"
"github.com/spf13/cobra"
)
var (
maxTokens int
model string
systemPrompt string
systemPromptFile string
)
func init() {
inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd, retryCmd, continueCmd}
for _, cmd := range inputCmds {
cmd.Flags().IntVar(&maxTokens, "length", *config.OpenAI.DefaultMaxLength, "Max response length in tokens")
cmd.Flags().StringVar(&model, "model", *config.OpenAI.DefaultModel, "The language model to use")
cmd.Flags().StringVar(&systemPrompt, "system-prompt", *config.ModelDefaults.SystemPrompt, "The system prompt to use.")
cmd.Flags().StringVar(&systemPromptFile, "system-prompt-file", "", "A path to a file whose contents are used as the system prompt.")
cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file")
}
rootCmd.AddCommand(
continueCmd,
lsCmd,
newCmd,
promptCmd,
replyCmd,
retryCmd,
rmCmd,
viewCmd,
)
}
func Execute() error {
return rootCmd.Execute()
}
func SystemPrompt() string {
if systemPromptFile != "" {
content, err := FileContents(systemPromptFile)
if err != nil {
Fatal("Could not read file contents at %s: %v", systemPromptFile, err)
}
return content
}
return systemPrompt
}
// LLMRequest prompts the LLM with the given Message, writes the (partial)
// response to stdout, and returns the (partial) response or any errors.
func LLMRequest(messages []Message) (string, error) {
// receiver receives the reponse from LLM
receiver := make(chan string)
defer close(receiver)
// start HandleDelayedContent goroutine to print received data to stdout
go HandleDelayedContent(receiver)
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
if response != "" {
if err != nil {
Warn("Received partial response. Error: %v\n", err)
err = nil
}
// there was some content, so break to a new line after it
fmt.Println()
}
return response, err
}
// InputFromArgsOrEditor returns either the provided input from the args slice
// (joined with spaces), or if len(args) is 0, opens an editor and returns
// whatever input was provided there. placeholder is a string which populates
// the editor and gets stripped from the final output.
func InputFromArgsOrEditor(args []string, placeholder string) (message string) {
var err error
if len(args) == 0 {
message, err = InputFromEditor(placeholder, "message.*.md")
if err != nil {
Fatal("Failed to get input: %v\n", err)
}
} else {
message = strings.Trim(strings.Join(args, " "), " \t\n")
}
return
}
var rootCmd = &cobra.Command{
Use: "lmcli",
Short: "Interact with Large Language Models",
Long: `lmcli is a CLI tool to interact with Large Language Models.`,
Run: func(cmd *cobra.Command, args []string) {
// execute `lm ls` by default
},
}
var lsCmd = &cobra.Command{
Use: "ls",
Short: "List existing conversations",
Long: `List all existing conversations in descending order of recent activity.`,
Run: func(cmd *cobra.Command, args []string) {
conversations, err := store.Conversations()
if err != nil {
fmt.Println("Could not fetch conversations.")
return
}
// Example output
// $ lmcli ls
// last hour:
// 98sg - 12 minutes ago - Project discussion
// last day:
// tj3l - 10 hours ago - Deep learning concepts
// last week:
// bwfm - 2 days ago - Machine learning study
// 8n3h - 3 days ago - Weekend plans
// f3n7 - 6 days ago - CLI development
// last month:
// 5hn2 - 8 days ago - Book club discussion
// b7ze - 20 days ago - Gardening tips and tricks
// last 6 months:
// 3jn2 - 30 days ago - Web development best practices
// 43jk - 2 months ago - Longboard maintenance
// g8d9 - 3 months ago - History book club
// 4lk3 - 4 months ago - Local events and meetups
// 43jn - 6 months ago - Mobile photography techniques
type ConversationLine struct {
timeSinceReply time.Duration
formatted string
}
now := time.Now()
categories := []string{
"recent",
"last hour",
"last 6 hours",
"last day",
"last week",
"last month",
"last 6 months",
"older",
}
categorized := map[string][]ConversationLine{}
for _, conversation := range conversations {
lastMessage, err := store.LastMessage(&conversation)
if lastMessage == nil || err != nil {
continue
}
messageAge := now.Sub(lastMessage.CreatedAt)
var category string
switch {
case messageAge <= 10*time.Minute:
category = "recent"
case messageAge <= time.Hour:
category = "last hour"
case messageAge <= 6*time.Hour:
category = "last 6 hours"
case messageAge <= 24*time.Hour:
category = "last day"
case messageAge <= 7*24*time.Hour:
category = "last week"
case messageAge <= 30*24*time.Hour:
category = "last month"
case messageAge <= 6*30*24*time.Hour: // Approximate as 6 months
category = "last 6 months"
default:
category = "older"
}
formatted := fmt.Sprintf(
"%s - %s - %s",
conversation.ShortName.String,
humanTimeElapsedSince(messageAge),
conversation.Title,
)
categorized[category] = append(
categorized[category],
ConversationLine{messageAge, formatted},
)
}
for _, category := range categories {
conversations, ok := categorized[category]
if !ok {
continue
}
slices.SortFunc(conversations, func(a, b ConversationLine) int {
return int(a.timeSinceReply - b.timeSinceReply)
})
fmt.Printf("%s:\n", category)
for _, conv := range conversations {
fmt.Printf(" %s\n", conv.formatted)
}
}
},
}
var rmCmd = &cobra.Command{
Use: "rm <conversation>",
Short: "Remove a conversation",
Long: `Removes a conversation by its short name.`,
Args: func(cmd *cobra.Command, args []string) error {
argCount := 1
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
return err
}
return nil
},
Run: func(cmd *cobra.Command, args []string) {
shortName := args[0]
conversation, err := store.ConversationByShortName(shortName)
if err != nil {
Fatal("Could not search for conversation: %v\n", err)
}
if conversation.ID == 0 {
Fatal("Conversation not found with short name: %s\n", shortName)
}
err = store.DeleteConversation(conversation)
if err != nil {
Fatal("Could not delete conversation: %v\n", err)
}
},
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp
if len(args) != 0 {
return nil, compMode
}
return store.ConversationShortNameCompletions(toComplete), compMode
},
}
var viewCmd = &cobra.Command{
Use: "view <conversation>",
Short: "View messages in a conversation",
Long: `Finds a conversation by its short name and displays its contents.`,
Args: func(cmd *cobra.Command, args []string) error {
argCount := 1
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
return err
}
return nil
},
Run: func(cmd *cobra.Command, args []string) {
shortName := args[0]
conversation, err := store.ConversationByShortName(shortName)
if conversation.ID == 0 {
Fatal("Conversation not found with short name: %s\n", shortName)
}
messages, err := store.Messages(conversation)
if err != nil {
Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title)
}
RenderConversation(messages, false)
},
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp
if len(args) != 0 {
return nil, compMode
}
return store.ConversationShortNameCompletions(toComplete), compMode
},
}
var replyCmd = &cobra.Command{
Use: "reply <conversation> [message]",
Short: "Send a reply to a conversation",
Long: `Sends a reply to conversation and writes the response to stdout.`,
Args: func(cmd *cobra.Command, args []string) error {
argCount := 1
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
return err
}
return nil
},
Run: func(cmd *cobra.Command, args []string) {
shortName := args[0]
conversation, err := store.ConversationByShortName(shortName)
if conversation.ID == 0 {
Fatal("Conversation not found with short name: %s\n", shortName)
}
messages, err := store.Messages(conversation)
if err != nil {
Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title)
}
messageContents := InputFromArgsOrEditor(args[1:], "# How would you like to reply?\n")
if messageContents == "" {
Fatal("No reply was provided.\n")
}
userReply := Message{
ConversationID: conversation.ID,
Role: "user",
OriginalContent: messageContents,
}
err = store.SaveMessage(&userReply)
if err != nil {
Warn("Could not save your reply: %v\n", err)
}
messages = append(messages, userReply)
RenderConversation(messages, true)
assistantReply := Message{
ConversationID: conversation.ID,
Role: "assistant",
}
assistantReply.RenderTTY()
response, err := LLMRequest(messages)
if err != nil {
Fatal("Error fetching LLM response: %v\n", err)
}
assistantReply.OriginalContent = response
err = store.SaveMessage(&assistantReply)
if err != nil {
Fatal("Could not save assistant reply: %v\n", err)
}
},
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp
if len(args) != 0 {
return nil, compMode
}
return store.ConversationShortNameCompletions(toComplete), compMode
},
}
var newCmd = &cobra.Command{
Use: "new [message]",
Short: "Start a new conversation",
Long: `Start a new conversation with the Large Language Model.`,
Run: func(cmd *cobra.Command, args []string) {
messageContents := InputFromArgsOrEditor(args, "# What would you like to say?\n")
if messageContents == "" {
Fatal("No message was provided.\n")
}
conversation := Conversation{}
err := store.SaveConversation(&conversation)
if err != nil {
Fatal("Could not save new conversation: %v\n", err)
}
messages := []Message{
{
ConversationID: conversation.ID,
Role: "system",
OriginalContent: SystemPrompt(),
},
{
ConversationID: conversation.ID,
Role: "user",
OriginalContent: messageContents,
},
}
for _, message := range messages {
err = store.SaveMessage(&message)
if err != nil {
Warn("Could not save %s message: %v\n", message.Role, err)
}
}
RenderConversation(messages, true)
reply := Message{
ConversationID: conversation.ID,
Role: "assistant",
}
reply.RenderTTY()
response, err := LLMRequest(messages)
if err != nil {
Fatal("Error fetching LLM response: %v\n", err)
}
reply.OriginalContent = response
err = store.SaveMessage(&reply)
if err != nil {
Fatal("Could not save reply: %v\n", err)
}
err = conversation.GenerateTitle()
if err != nil {
Warn("Could not generate title for conversation: %v\n", err)
}
err = store.SaveConversation(&conversation)
if err != nil {
Warn("Could not save conversation after generating title: %v\n", err)
}
},
}
var promptCmd = &cobra.Command{
Use: "prompt [message]",
Short: "Do a one-shot prompt",
Long: `Prompt the Large Language Model and get a response.`,
Run: func(cmd *cobra.Command, args []string) {
message := InputFromArgsOrEditor(args, "# What would you like to say?\n")
if message == "" {
Fatal("No message was provided.\n")
}
messages := []Message{
{
Role: "system",
OriginalContent: SystemPrompt(),
},
{
Role: "user",
OriginalContent: message,
},
}
_, err := LLMRequest(messages)
if err != nil {
Fatal("Error fetching LLM response: %v\n", err)
}
},
}
var retryCmd = &cobra.Command{
Use: "retry <conversation>",
Short: "Retries the last conversation prompt.",
Long: `Re-prompt the conversation up to the last user response. Can be used to regenerate the last assistant reply, or simply generate one if an error occurred.`,
Args: func(cmd *cobra.Command, args []string) error {
argCount := 1
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
return err
}
return nil
},
Run: func(cmd *cobra.Command, args []string) {
shortName := args[0]
conversation, err := store.ConversationByShortName(shortName)
if conversation.ID == 0 {
Fatal("Conversation not found with short name: %s\n", shortName)
}
messages, err := store.Messages(conversation)
if err != nil {
Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title)
}
var lastUserMessageIndex int
for i := len(messages) - 1; i >= 0; i-- {
if messages[i].Role == "user" {
lastUserMessageIndex = i
break
}
}
messages = messages[:lastUserMessageIndex+1]
RenderConversation(messages, true)
assistantReply := Message{
ConversationID: conversation.ID,
Role: "assistant",
}
assistantReply.RenderTTY()
response, err := LLMRequest(messages)
if err != nil {
Fatal("Error fetching LLM response: %v\n", err)
}
assistantReply.OriginalContent = response
err = store.SaveMessage(&assistantReply)
if err != nil {
Fatal("Could not save assistant reply: %v\n", err)
}
},
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp
if len(args) != 0 {
return nil, compMode
}
return store.ConversationShortNameCompletions(toComplete), compMode
},
}
var continueCmd = &cobra.Command{
Use: "continue <conversation>",
Short: "Continues where the previous prompt left off.",
Long: `Re-prompt the conversation with all existing prompts. Useful if a reply was cut short.`,
Args: func(cmd *cobra.Command, args []string) error {
argCount := 1
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
return err
}
return nil
},
Run: func(cmd *cobra.Command, args []string) {
shortName := args[0]
conversation, err := store.ConversationByShortName(shortName)
if conversation.ID == 0 {
Fatal("Conversation not found with short name: %s\n", shortName)
}
messages, err := store.Messages(conversation)
if err != nil {
Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title)
}
RenderConversation(messages, true)
assistantReply := Message{
ConversationID: conversation.ID,
Role: "assistant",
}
assistantReply.RenderTTY()
response, err := LLMRequest(messages)
if err != nil {
Fatal("Error fetching LLM response: %v\n", err)
}
assistantReply.OriginalContent = response
err = store.SaveMessage(&assistantReply)
if err != nil {
Fatal("Could not save assistant reply: %v\n", err)
}
},
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp
if len(args) != 0 {
return nil, compMode
}
return store.ConversationShortNameCompletions(toComplete), compMode
},
}

69
pkg/cli/config.go Normal file
View File

@ -0,0 +1,69 @@
package cli
import (
"fmt"
"os"
"path/filepath"
"github.com/go-yaml/yaml"
)
type Config struct {
ModelDefaults *struct {
SystemPrompt *string `yaml:"systemPrompt" default:"You are a helpful assistant."`
} `yaml:"modelDefaults"`
OpenAI *struct {
APIKey *string `yaml:"apiKey" default:"your_key_here"`
DefaultModel *string `yaml:"defaultModel" default:"gpt-4"`
DefaultMaxLength *int `yaml:"defaultMaxLength" default:"256"`
} `yaml:"openai"`
Chroma *struct {
Style *string `yaml:"style" default:"onedark"`
Formatter *string `yaml:"formatter" default:"terminal16m"`
} `yaml:"chroma"`
}
func ConfigDir() string {
var configDir string
xdgConfigHome := os.Getenv("XDG_CONFIG_HOME")
if xdgConfigHome != "" {
configDir = filepath.Join(xdgConfigHome, "lmcli")
} else {
userHomeDir, _ := os.UserHomeDir()
configDir = filepath.Join(userHomeDir, ".config/lmcli")
}
os.MkdirAll(configDir, 0755)
return configDir
}
func NewConfig() (*Config, error) {
configFile := filepath.Join(ConfigDir(), "config.yaml")
shouldWriteDefaults := false
c := &Config{}
configBytes, err := os.ReadFile(configFile)
if os.IsNotExist(err) {
shouldWriteDefaults = true
} else if err != nil {
return nil, fmt.Errorf("Could not read config file: %v", err)
} else {
yaml.Unmarshal(configBytes, c)
}
shouldWriteDefaults = SetStructDefaults(c)
if shouldWriteDefaults {
file, err := os.Create(configFile)
if err != nil {
return nil, fmt.Errorf("Could not open config file for writing: %v", err)
}
bytes, _ := yaml.Marshal(c)
_, err = file.Write(bytes)
if err != nil {
return nil, fmt.Errorf("Could not save default configuration: %v", err)
}
}
return c, nil
}

56
pkg/cli/conversation.go Normal file
View File

@ -0,0 +1,56 @@
package cli
import (
"fmt"
"strings"
)
// FriendlyRole returns a human friendly signifier for the message's role.
func (m *Message) FriendlyRole() string {
var friendlyRole string
switch m.Role {
case "user":
friendlyRole = "You"
case "system":
friendlyRole = "System"
case "assistant":
friendlyRole = "Assistant"
default:
friendlyRole = m.Role
}
return friendlyRole
}
func (c *Conversation) GenerateTitle() error {
const header = "Generate a consise 4-5 word title for the conversation below."
prompt := fmt.Sprintf("%s\n\n---\n\n%s", header, c.FormatForExternalPrompting())
messages := []Message{
{
Role: "user",
OriginalContent: prompt,
},
}
model := "gpt-3.5-turbo" // use cheap model to generate title
response, err := CreateChatCompletion(model, messages, 25)
if err != nil {
return err
}
c.Title = response
return nil
}
func (c *Conversation) FormatForExternalPrompting() string {
sb := strings.Builder{}
messages, err := store.Messages(c)
if err != nil {
Fatal("Could not retrieve messages for conversation %v", c)
}
for _, message := range messages {
sb.WriteString(fmt.Sprintf("<%s>\n", message.FriendlyRole()))
sb.WriteString(fmt.Sprintf("\"\"\"\n%s\n\"\"\"\n\n", message.OriginalContent))
}
return sb.String()
}

582
pkg/cli/functions.go Normal file
View File

@ -0,0 +1,582 @@
package cli
import (
"database/sql"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
openai "github.com/sashabaranov/go-openai"
)
type FunctionResult struct {
Message string `json:"message"`
Result any `json:"result,omitempty"`
}
type FunctionParameter struct {
Type string `json:"type"` // "string", "integer", "boolean"
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
}
type FunctionParameters struct {
Type string `json:"type"` // "object"
Properties map[string]FunctionParameter `json:"properties"`
Required []string `json:"required,omitempty"` // required function parameter names
}
type AvailableTool struct {
openai.Tool
// The tool's implementation. Returns a string, as tool call results
// are treated as normal messages with string contents.
Impl func(arguments map[string]interface{}) (string, error)
}
const (
READ_DIR_DESCRIPTION = `Return the contents of the CWD (current working directory).
Results are returned as JSON in the following format:
{
"message": "success", // if successful, or a different message indicating failure
// result may be an empty array if there are no files in the directory
"result": [
{"name": "a_file", "type": "file", "size": 123},
{"name": "a_directory/", "type": "dir", "size": 11},
... // more files or directories
]
}
For files, size represents the size (in bytes) of the file.
For directories, size represents the number of entries in that directory.`
READ_FILE_DESCRIPTION = `Read the contents of a text file relative to the current working directory.
Each line of the file is prefixed with its line number and a tabs (\t) to make
it make it easier to see which lines to change for other modifications.
Example result:
{
"message": "success", // if successful, or a different message indicating failure
"result": "1\tthe contents\n2\tof the file\n"
}`
WRITE_FILE_DESCRIPTION = `Write the provided contents to a file relative to the current working directory.
Note: only use this tool when you've been explicitly asked to create or write to a file.
When using this function, you do not need to share the content you intend to write with the user first.
Example result:
{
"message": "success", // if successful, or a different message indicating failure
}`
FILE_INSERT_LINES_DESCRIPTION = `Insert lines into a file, must specify path.
Make sure your inserts match the flow and indentation of surrounding content.`
FILE_REPLACE_LINES_DESCRIPTION = `Replace or remove a range of lines within a file, must specify path.
Useful for re-writing snippets/blocks of code or entire functions.
Be cautious with your edits. When replacing, ensure the replacement content matches the flow and indentation of surrounding content.`
)
var AvailableTools = map[string]AvailableTool{
"read_dir": {
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
Name: "read_dir",
Description: READ_DIR_DESCRIPTION,
Parameters: FunctionParameters{
Type: "object",
Properties: map[string]FunctionParameter{
"relative_dir": {
Type: "string",
Description: "If set, read the contents of a directory relative to the current one.",
},
},
},
}},
Impl: func(args map[string]interface{}) (string, error) {
var relativeDir string
tmp, ok := args["relative_dir"]
if ok {
relativeDir, ok = tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid relative_dir in function arguments: %v", tmp)
}
}
return ReadDir(relativeDir), nil
},
},
"read_file": {
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
Name: "read_file",
Description: READ_FILE_DESCRIPTION,
Parameters: FunctionParameters{
Type: "object",
Properties: map[string]FunctionParameter{
"path": {
Type: "string",
Description: "Path to a file within the current working directory to read.",
},
},
Required: []string{"path"},
},
}},
Impl: func(args map[string]interface{}) (string, error) {
tmp, ok := args["path"]
if !ok {
return "", fmt.Errorf("Path parameter to read_file was not included.")
}
path, ok := tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
}
return ReadFile(path), nil
},
},
"write_file": {
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
Name: "write_file",
Description: WRITE_FILE_DESCRIPTION,
Parameters: FunctionParameters{
Type: "object",
Properties: map[string]FunctionParameter{
"path": {
Type: "string",
Description: "Path to a file within the current working directory to write to.",
},
"content": {
Type: "string",
Description: "The content to write to the file. Overwrites any existing content!",
},
},
Required: []string{"path", "content"},
},
}},
Impl: func(args map[string]interface{}) (string, error) {
tmp, ok := args["path"]
if !ok {
return "", fmt.Errorf("Path parameter to write_file was not included.")
}
path, ok := tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
}
tmp, ok = args["content"]
if !ok {
return "", fmt.Errorf("Content parameter to write_file was not included.")
}
content, ok := tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
}
return WriteFile(path, content), nil
},
},
"file_insert_lines": {
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
Name: "file_insert_lines",
Description: FILE_INSERT_LINES_DESCRIPTION,
Parameters: FunctionParameters{
Type: "object",
Properties: map[string]FunctionParameter{
"path": {
Type: "string",
Description: "Path of the file to be modified, relative to the current working directory.",
},
"position": {
Type: "integer",
Description: `Which line to insert content *before*.`,
},
"content": {
Type: "string",
Description: `The content to insert.`,
},
},
Required: []string{"path", "position", "content"},
},
}},
Impl: func(args map[string]interface{}) (string, error) {
tmp, ok := args["path"]
if !ok {
return "", fmt.Errorf("path parameter to write_file was not included.")
}
path, ok := tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
}
var position int
tmp, ok = args["position"]
if ok {
tmp, ok := tmp.(float64)
if !ok {
return "", fmt.Errorf("Invalid position in function arguments: %v", tmp)
}
position = int(tmp)
}
var content string
tmp, ok = args["content"]
if ok {
content, ok = tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
}
}
return FileInsertLines(path, position, content), nil
},
},
"file_replace_lines": {
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
Name: "file_replace_lines",
Description: FILE_REPLACE_LINES_DESCRIPTION,
Parameters: FunctionParameters{
Type: "object",
Properties: map[string]FunctionParameter{
"path": {
Type: "string",
Description: "Path of the file to be modified, relative to the current working directory.",
},
"start_line": {
Type: "integer",
Description: `Line number which specifies the start of the replacement range (inclusive).`,
},
"end_line": {
Type: "integer",
Description: `Line number which specifies the end of the replacement range (inclusive). If unset, range extends to end of file.`,
},
"content": {
Type: "string",
Description: `Content to replace specified range. Omit to remove the specified range.`,
},
},
Required: []string{"path", "start_line"},
},
}},
Impl: func(args map[string]interface{}) (string, error) {
tmp, ok := args["path"]
if !ok {
return "", fmt.Errorf("path parameter to write_file was not included.")
}
path, ok := tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
}
var start_line int
tmp, ok = args["start_line"]
if ok {
tmp, ok := tmp.(float64)
if !ok {
return "", fmt.Errorf("Invalid start_line in function arguments: %v", tmp)
}
start_line = int(tmp)
}
var end_line int
tmp, ok = args["end_line"]
if ok {
tmp, ok := tmp.(float64)
if !ok {
return "", fmt.Errorf("Invalid end_line in function arguments: %v", tmp)
}
end_line = int(tmp)
}
var content string
tmp, ok = args["content"]
if ok {
content, ok = tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
}
}
return FileReplaceLines(path, start_line, end_line, content), nil
},
},
}
func resultToJson(result FunctionResult) string {
if result.Message == "" {
// When message not supplied, assume success
result.Message = "success"
}
jsonBytes, err := json.Marshal(result)
if err != nil {
fmt.Printf("Could not marshal FunctionResult to JSON: %v\n", err)
}
return string(jsonBytes)
}
// ExecuteToolCalls handles the execution of all tool_calls provided, and
// returns their results formatted as []Message(s) with role: 'tool' and.
func ExecuteToolCalls(toolCalls []openai.ToolCall) ([]Message, error) {
var toolResults []Message
for _, toolCall := range toolCalls {
if toolCall.Type != "function" {
// unsupported tool type
continue
}
tool, ok := AvailableTools[toolCall.Function.Name]
if !ok {
return nil, fmt.Errorf("Requested tool '%s' does not exist. Hallucination?", toolCall.Function.Name)
}
var functionArgs map[string]interface{}
err := json.Unmarshal([]byte(toolCall.Function.Arguments), &functionArgs)
if err != nil {
return nil, fmt.Errorf("Could not unmarshal tool arguments. Malformed JSON? Error: %v", err)
}
// TODO: ability to silence this
fmt.Fprintf(os.Stderr, "INFO: Executing tool '%s' with args %s\n", toolCall.Function.Name, toolCall.Function.Arguments)
// Execute the tool
toolResult, err := tool.Impl(functionArgs)
if err != nil {
// This can happen if the model missed or supplied invalid tool args
return nil, fmt.Errorf("Tool '%s' error: %v\n", toolCall.Function.Name, err)
}
toolResults = append(toolResults, Message{
Role: "tool",
OriginalContent: toolResult,
ToolCallID: sql.NullString{String: toolCall.ID, Valid: true},
// name is not required since the introduction of ToolCallID
// hypothesis: by setting it, we inform the model of what a
// function's purpose was if future requests omit the function
// definition
})
}
return toolResults, nil
}
// isPathContained attempts to verify whether `path` is the same as or
// contained within `directory`. It is overly cautious, returning false even if
// `path` IS contained within `directory`, but the two paths use different
// casing, and we happen to be on a case-insensitive filesystem.
// This is ultimately to attempt to stop an LLM from going outside of where I
// tell it to. Additional layers of security should be considered.. run in a
// VM/container.
func isPathContained(directory string, path string) (bool, error) {
// Clean and resolve symlinks for both paths
path, err := filepath.Abs(path)
if err != nil {
return false, err
}
// check if path exists
_, err = os.Stat(path)
if err != nil {
if !os.IsNotExist(err) {
return false, fmt.Errorf("Could not stat path: %v", err)
}
} else {
path, err = filepath.EvalSymlinks(path)
if err != nil {
return false, err
}
}
directory, err = filepath.Abs(directory)
if err != nil {
return false, err
}
directory, err = filepath.EvalSymlinks(directory)
if err != nil {
return false, err
}
// Case insensitive checks
if !strings.EqualFold(path, directory) &&
!strings.HasPrefix(strings.ToLower(path), strings.ToLower(directory)+string(os.PathSeparator)) {
return false, nil
}
return true, nil
}
func isPathWithinCWD(path string) (bool, *FunctionResult) {
cwd, err := os.Getwd()
if err != nil {
return false, &FunctionResult{Message: "Failed to determine current working directory"}
}
if ok, err := isPathContained(cwd, path); !ok {
if err != nil {
return false, &FunctionResult{Message: fmt.Sprintf("Could not determine whether path '%s' is within the current working directory: %s", path, err.Error())}
}
return false, &FunctionResult{Message: fmt.Sprintf("Path '%s' is not within the current working directory", path)}
}
return true, nil
}
func ReadDir(path string) string {
// TODO(?): implement whitelist - list of directories which model is allowed to work in
if path == "" {
path = "."
}
ok, res := isPathWithinCWD(path)
if !ok {
return resultToJson(*res)
}
files, err := os.ReadDir(path)
if err != nil {
return resultToJson(FunctionResult{
Message: err.Error(),
})
}
var dirContents []map[string]interface{}
for _, f := range files {
info, _ := f.Info()
name := f.Name()
if strings.HasPrefix(name, ".") {
// skip hidden files
continue
}
entryType := "file"
size := info.Size()
if info.IsDir() {
name += "/"
entryType = "dir"
subdirfiles, _ := os.ReadDir(filepath.Join(".", path, info.Name()))
size = int64(len(subdirfiles))
}
dirContents = append(dirContents, map[string]interface{}{
"name": name,
"type": entryType,
"size": size,
})
}
return resultToJson(FunctionResult{Result: dirContents})
}
func ReadFile(path string) string {
ok, res := isPathWithinCWD(path)
if !ok {
return resultToJson(*res)
}
data, err := os.ReadFile(path)
if err != nil {
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())})
}
lines := strings.Split(string(data), "\n")
content := strings.Builder{}
for i, line := range lines {
content.WriteString(fmt.Sprintf("%d\t%s\n", i+1, line))
}
return resultToJson(FunctionResult{
Result: content.String(),
})
}
func WriteFile(path string, content string) string {
ok, res := isPathWithinCWD(path)
if !ok {
return resultToJson(*res)
}
err := os.WriteFile(path, []byte(content), 0644)
if err != nil {
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())})
}
return resultToJson(FunctionResult{})
}
func FileInsertLines(path string, position int, content string) string {
ok, res := isPathWithinCWD(path)
if !ok {
return resultToJson(*res)
}
// Read the existing file's content
data, err := os.ReadFile(path)
if err != nil {
if !os.IsNotExist(err) {
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())})
}
_, err = os.Create(path)
if err != nil {
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())})
}
data = []byte{}
}
if position < 1 {
return resultToJson(FunctionResult{Message: "start_line cannot be less than 1"})
}
lines := strings.Split(string(data), "\n")
contentLines := strings.Split(strings.Trim(content, "\n"), "\n")
before := lines[:position-1]
after := lines[position-1:]
lines = append(before, append(contentLines, after...)...)
newContent := strings.Join(lines, "\n")
// Join the lines and write back to the file
err = os.WriteFile(path, []byte(newContent), 0644)
if err != nil {
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())})
}
return resultToJson(FunctionResult{Result: newContent})
}
func FileReplaceLines(path string, startLine int, endLine int, content string) string {
ok, res := isPathWithinCWD(path)
if !ok {
return resultToJson(*res)
}
// Read the existing file's content
data, err := os.ReadFile(path)
if err != nil {
if !os.IsNotExist(err) {
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())})
}
_, err = os.Create(path)
if err != nil {
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())})
}
data = []byte{}
}
if startLine < 1 {
return resultToJson(FunctionResult{Message: "start_line cannot be less than 1"})
}
lines := strings.Split(string(data), "\n")
contentLines := strings.Split(strings.Trim(content, "\n"), "\n")
if endLine == 0 || endLine > len(lines) {
endLine = len(lines)
}
before := lines[:startLine-1]
after := lines[endLine:]
lines = append(before, append(contentLines, after...)...)
newContent := strings.Join(lines, "\n")
// Join the lines and write back to the file
err = os.WriteFile(path, []byte(newContent), 0644)
if err != nil {
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())})
}
return resultToJson(FunctionResult{Result: newContent})
}

162
pkg/cli/openai.go Normal file
View File

@ -0,0 +1,162 @@
package cli
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"io"
"strings"
openai "github.com/sashabaranov/go-openai"
)
func CreateChatCompletionRequest(model string, messages []Message, maxTokens int) openai.ChatCompletionRequest {
chatCompletionMessages := []openai.ChatCompletionMessage{}
for _, m := range messages {
message := openai.ChatCompletionMessage{
Role: m.Role,
Content: m.OriginalContent,
}
if m.ToolCallID.Valid {
message.ToolCallID = m.ToolCallID.String
}
if m.ToolCalls.Valid {
// unmarshal directly into chatMessage.ToolCalls
err := json.Unmarshal([]byte(m.ToolCalls.String), &message.ToolCalls)
if err != nil {
// TODO: handle, this shouldn't really happen since
// we only save the successfully marshal'd data to database
fmt.Printf("Error unmarshalling the tool_calls JSON: %v\n", err)
}
}
chatCompletionMessages = append(chatCompletionMessages, message)
}
var tools []openai.Tool
for _, t := range AvailableTools {
// TODO: support some way to limit which tools are available per-request
tools = append(tools, t.Tool)
}
return openai.ChatCompletionRequest{
Model: model,
Messages: chatCompletionMessages,
MaxTokens: maxTokens,
N: 1, // limit responses to 1 "choice". we use choices[0] to reference it
Tools: tools,
ToolChoice: "auto", // TODO: allow limiting/forcing which function is called?
}
}
// CreateChatCompletion submits a Chat Completion API request and returns the
// response. CreateChatCompletion will recursively call itself in the case of
// tool calls, until a response is received with the final user-facing output.
func CreateChatCompletion(model string, messages []Message, maxTokens int) (string, error) {
client := openai.NewClient(*config.OpenAI.APIKey)
req := CreateChatCompletionRequest(model, messages, maxTokens)
resp, err := client.CreateChatCompletion(context.Background(), req)
if err != nil {
return "", err
}
choice := resp.Choices[0]
if len(choice.Message.ToolCalls) > 0 {
if choice.Message.Content != "" {
return "", fmt.Errorf("Model replied with user-facing content in addition to tool calls. Unsupported.")
}
// Append the assistant's reply with its request for tool calls
toolCallJson, _ := json.Marshal(choice.Message.ToolCalls)
messages = append(messages, Message{
Role: "assistant",
ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
})
toolReplies, err := ExecuteToolCalls(choice.Message.ToolCalls)
if err != nil {
return "", err
}
// Recurse into CreateChatCompletion with the tool call replies added
// to the original messages
return CreateChatCompletion(model, append(messages, toolReplies...), maxTokens)
}
// Return the user-facing message.
return choice.Message.Content, nil
}
// CreateChatCompletionStream submits a streaming Chat Completion API request
// and both returns and streams the response to the provided output channel.
// May return a partial response if an error occurs mid-stream.
func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan<- string) (string, error) {
client := openai.NewClient(*config.OpenAI.APIKey)
req := CreateChatCompletionRequest(model, messages, maxTokens)
stream, err := client.CreateChatCompletionStream(context.Background(), req)
if err != nil {
return "", err
}
defer stream.Close()
content := strings.Builder{}
toolCalls := []openai.ToolCall{}
// Iterate stream segments
for {
response, e := stream.Recv()
if errors.Is(e, io.EOF) {
break
}
if e != nil {
err = e
break
}
delta := response.Choices[0].Delta
if len(delta.ToolCalls) > 0 {
// Construct streamed tool_call arguments
for _, tc := range delta.ToolCalls {
if tc.Index == nil {
return "", fmt.Errorf("Unexpected nil index for streamed tool call.")
}
if len(toolCalls) <= *tc.Index {
toolCalls = append(toolCalls, tc)
} else {
toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
}
}
} else {
output <- delta.Content
content.WriteString(delta.Content)
}
}
if len(toolCalls) > 0 {
if content.String() != "" {
return "", fmt.Errorf("Model replied with user-facing content in addition to tool calls. Unsupported.")
}
// Append the assistant's reply with its request for tool calls
toolCallJson, _ := json.Marshal(toolCalls)
messages = append(messages, Message{
Role: "assistant",
ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
})
toolReplies, err := ExecuteToolCalls(toolCalls)
if err != nil {
return "", err
}
// Recurse into CreateChatCompletionStream with the tool call replies
// added to the original messages
return CreateChatCompletionStream(model, append(messages, toolReplies...), maxTokens, output)
}
return content.String(), err
}

133
pkg/cli/store.go Normal file
View File

@ -0,0 +1,133 @@
package cli
import (
"database/sql"
"fmt"
"os"
"path/filepath"
"strings"
"time"
sqids "github.com/sqids/sqids-go"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
type Store struct {
db *gorm.DB
sqids *sqids.Sqids
}
type Message struct {
ID uint `gorm:"primaryKey"`
ConversationID uint `gorm:"foreignKey:ConversationID"`
Conversation Conversation
OriginalContent string
Role string // one of: 'user', 'assistant', 'tool'
CreatedAt time.Time
ToolCallID sql.NullString
ToolCalls sql.NullString // a json-encoded array of tool calls from the model
}
type Conversation struct {
ID uint `gorm:"primaryKey"`
ShortName sql.NullString
Title string
}
func DataDir() string {
var dataDir string
xdgDataHome := os.Getenv("XDG_DATA_HOME")
if xdgDataHome != "" {
dataDir = filepath.Join(xdgDataHome, "lmcli")
} else {
userHomeDir, _ := os.UserHomeDir()
dataDir = filepath.Join(userHomeDir, ".local/share/lmcli")
}
os.MkdirAll(dataDir, 0755)
return dataDir
}
func NewStore() (*Store, error) {
databaseFile := filepath.Join(DataDir(), "conversations.db")
db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{})
if err != nil {
return nil, fmt.Errorf("Error establishing connection to store: %v", err)
}
models := []any{
&Conversation{},
&Message{},
}
for _, x := range models {
err := db.AutoMigrate(x)
if err != nil {
return nil, fmt.Errorf("Could not perform database migrations: %v", err)
}
}
_sqids, _ := sqids.New(sqids.Options{MinLength: 4})
return &Store{db, _sqids}, nil
}
func (s *Store) SaveConversation(conversation *Conversation) error {
err := s.db.Save(&conversation).Error
if err != nil {
return err
}
if !conversation.ShortName.Valid {
shortName, _ := s.sqids.Encode([]uint64{uint64(conversation.ID)})
conversation.ShortName = sql.NullString{String: shortName, Valid: true}
err = s.db.Save(&conversation).Error
}
return err
}
func (s *Store) DeleteConversation(conversation *Conversation) error {
s.db.Where("conversation_id = ?", conversation.ID).Delete(&Message{})
return s.db.Delete(&conversation).Error
}
func (s *Store) SaveMessage(message *Message) error {
return s.db.Create(message).Error
}
func (s *Store) Conversations() ([]Conversation, error) {
var conversations []Conversation
err := s.db.Find(&conversations).Error
return conversations, err
}
func (s *Store) ConversationShortNameCompletions(shortName string) []string {
var completions []string
conversations, _ := s.Conversations() // ignore error for completions
for _, conversation := range conversations {
if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) {
completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title))
}
}
return completions
}
func (s *Store) ConversationByShortName(shortName string) (*Conversation, error) {
var conversation Conversation
err := s.db.Where("short_name = ?", shortName).Find(&conversation).Error
return &conversation, err
}
func (s *Store) Messages(conversation *Conversation) ([]Message, error) {
var messages []Message
err := s.db.Where("conversation_id = ?", conversation.ID).Find(&messages).Error
return messages, err
}
func (s *Store) LastMessage(conversation *Conversation) (*Message, error) {
var message Message
err := s.db.Where("conversation_id = ?", conversation.ID).Last(&message).Error
return &message, err
}

113
pkg/cli/tty.go Normal file
View File

@ -0,0 +1,113 @@
package cli
import (
"fmt"
"os"
"time"
"github.com/alecthomas/chroma/v2/quick"
"github.com/gookit/color"
)
// ShowWaitAnimation "draws" an animated ellipses to stdout until something is
// received on the signal channel. An empty string sent to the channel to
// noftify the caller that the animation has completed (carriage returned).
func ShowWaitAnimation(signal chan any) {
animationStep := 0
for {
select {
case _ = <-signal:
fmt.Print("\r")
signal <- ""
return
default:
modSix := animationStep % 6
if modSix == 3 || modSix == 0 {
fmt.Print("\r")
}
if modSix < 3 {
fmt.Print(".")
} else {
fmt.Print(" ")
}
animationStep++
time.Sleep(250 * time.Millisecond)
}
}
}
// HandleDelayedContent displays a waiting animation to stdout while waiting
// for content to be received on the provided channel. As soon as any (possibly
// chunked) content is received on the channel, the waiting animation is
// replaced by the content.
// Blocks until the channel is closed.
func HandleDelayedContent(content <-chan string) {
waitSignal := make(chan any)
go ShowWaitAnimation(waitSignal)
firstChunk := true
for chunk := range content {
if firstChunk {
// notify wait animation that we've received data
waitSignal <- ""
// wait for signal that wait animation has completed
<-waitSignal
firstChunk = false
}
fmt.Print(chunk)
}
}
// RenderConversation renders the given messages to TTY, with optional space
// for a subsequent message. spaceForResponse controls how many '\n' characters
// are printed immediately after the final message (1 if false, 2 if true)
func RenderConversation(messages []Message, spaceForResponse bool) {
l := len(messages)
for i, message := range messages {
message.RenderTTY()
if i < l-1 || spaceForResponse {
// print an additional space before the next message
fmt.Println()
}
}
}
// HighlightMarkdown applies syntax highlighting to the provided markdown text
// and writes it to stdout.
func HighlightMarkdown(markdownText string) error {
return quick.Highlight(os.Stdout, markdownText, "md", *config.Chroma.Formatter, *config.Chroma.Style)
}
func (m *Message) RenderTTY() {
var messageAge string
if m.CreatedAt.IsZero() {
messageAge = "now"
} else {
now := time.Now()
messageAge = humanTimeElapsedSince(now.Sub(m.CreatedAt))
}
var roleStyle color.Style
switch m.Role {
case "system":
roleStyle = color.Style{color.HiRed}
case "user":
roleStyle = color.Style{color.HiGreen}
case "assistant":
roleStyle = color.Style{color.HiBlue}
default:
roleStyle = color.Style{color.FgWhite}
}
roleStyle.Add(color.Bold)
headerColor := color.FgYellow
separator := headerColor.Sprint("===")
timestamp := headerColor.Sprint(messageAge)
role := roleStyle.Sprint(m.FriendlyRole())
fmt.Printf("%s %s - %s %s\n\n", separator, role, timestamp, separator)
if m.OriginalContent != "" {
HighlightMarkdown(m.OriginalContent)
fmt.Println()
}
}

View File

@ -1,4 +1,4 @@
package util package cli
import ( import (
"fmt" "fmt"
@ -17,11 +17,11 @@ import (
// contents of the file exactly match the value of placeholder (no edits to the // contents of the file exactly match the value of placeholder (no edits to the
// file were made), then an empty string is returned. Otherwise, the contents // file were made), then an empty string is returned. Otherwise, the contents
// are returned. Example patten: message.*.md // are returned. Example patten: message.*.md
func InputFromEditor(placeholder string, pattern string, content string) (string, error) { func InputFromEditor(placeholder string, pattern string) (string, error) {
msgFile, _ := os.CreateTemp("/tmp", pattern) msgFile, _ := os.CreateTemp("/tmp", pattern)
defer os.Remove(msgFile.Name()) defer os.Remove(msgFile.Name())
os.WriteFile(msgFile.Name(), []byte(placeholder+content), os.ModeAppend) os.WriteFile(msgFile.Name(), []byte(placeholder), os.ModeAppend)
editor := os.Getenv("EDITOR") editor := os.Getenv("EDITOR")
if editor == "" { if editor == "" {
@ -38,7 +38,7 @@ func InputFromEditor(placeholder string, pattern string, content string) (string
} }
bytes, _ := os.ReadFile(msgFile.Name()) bytes, _ := os.ReadFile(msgFile.Name())
content = string(bytes) content := string(bytes)
if placeholder != "" { if placeholder != "" {
if content == placeholder { if content == placeholder {
@ -56,7 +56,7 @@ func InputFromEditor(placeholder string, pattern string, content string) (string
// humanTimeElapsedSince returns a human-friendly "in the past" representation // humanTimeElapsedSince returns a human-friendly "in the past" representation
// of the given duration. // of the given duration.
func HumanTimeElapsedSince(d time.Duration) string { func humanTimeElapsedSince(d time.Duration) string {
seconds := d.Seconds() seconds := d.Seconds()
minutes := seconds / 60 minutes := seconds / 60
hours := minutes / 60 hours := minutes / 60
@ -137,8 +137,8 @@ func SetStructDefaults(data interface{}) bool {
} }
// Get the "default" struct tag // Get the "default" struct tag
defaultTag, ok := v.Type().Field(i).Tag.Lookup("default") defaultTag := v.Type().Field(i).Tag.Get("default")
if (!ok) { if defaultTag == "" {
continue continue
} }
@ -147,18 +147,10 @@ func SetStructDefaults(data interface{}) bool {
case reflect.String: case reflect.String:
defaultValue := defaultTag defaultValue := defaultTag
field.Set(reflect.ValueOf(&defaultValue)) field.Set(reflect.ValueOf(&defaultValue))
case reflect.Uint, reflect.Uint32, reflect.Uint64:
intValue, _ := strconv.ParseUint(defaultTag, 10, e.Bits())
field.Set(reflect.New(e))
field.Elem().SetUint(intValue)
case reflect.Int, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int32, reflect.Int64:
intValue, _ := strconv.ParseInt(defaultTag, 10, e.Bits()) intValue, _ := strconv.ParseInt(defaultTag, 10, 64)
field.Set(reflect.New(e)) field.Set(reflect.New(e))
field.Elem().SetInt(intValue) field.Elem().SetInt(intValue)
case reflect.Float32, reflect.Float64:
floatValue, _ := strconv.ParseFloat(defaultTag, e.Bits())
field.Set(reflect.New(e))
field.Elem().SetFloat(floatValue)
case reflect.Bool: case reflect.Bool:
boolValue := defaultTag == "true" boolValue := defaultTag == "true"
field.Set(reflect.ValueOf(&boolValue)) field.Set(reflect.ValueOf(&boolValue))
@ -168,8 +160,10 @@ func SetStructDefaults(data interface{}) bool {
return changed return changed
} }
// ReadFileContents returns the string contents of the given file. // FileContents returns the string contents of the given file.
func ReadFileContents(file string) (string, error) { // TODO: we should support retrieving the content (or an approximation of)
// non-text documents, e.g. PDFs.
func FileContents(file string) (string, error) {
path := filepath.Clean(file) path := filepath.Clean(file)
content, err := os.ReadFile(path) content, err := os.ReadFile(path)
if err != nil { if err != nil {

View File

@ -1,48 +0,0 @@
package cmd
import (
"fmt"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/tui"
"github.com/spf13/cobra"
)
func ChatCmd(ctx *lmcli.Context) *cobra.Command {
cmd := &cobra.Command{
Use: "chat [conversation]",
Short: "Open the chat interface",
Long: `Open the chat interface, optionally on a given conversation.`,
RunE: func(cmd *cobra.Command, args []string) error {
err := validateGenerationFlags(ctx, cmd)
if err != nil {
return err
}
shortname := ""
if len(args) == 1 {
shortname = args[0]
}
if shortname != ""{
_, err := cmdutil.LookupConversationE(ctx, shortname)
if err != nil {
return err
}
}
err = tui.Launch(ctx, shortname)
if err != nil {
return fmt.Errorf("Error fetching LLM response: %v", err)
}
return nil
},
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp
if len(args) != 0 {
return nil, compMode
}
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
},
}
applyGenerationFlags(ctx, cmd)
return cmd
}

View File

@ -1,47 +0,0 @@
package cmd
import (
"fmt"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
"github.com/spf13/cobra"
)
func CloneCmd(ctx *lmcli.Context) *cobra.Command {
cmd := &cobra.Command{
Use: "clone <conversation>",
Short: "Clone conversations",
Long: `Clones the provided conversation.`,
Args: func(cmd *cobra.Command, args []string) error {
argCount := 1
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
return err
}
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
shortName := args[0]
toClone, err := cmdutil.LookupConversationE(ctx, shortName)
if err != nil {
return err
}
clone, messageCnt, err := ctx.Store.CloneConversation(*toClone)
if err != nil {
return fmt.Errorf("Failed to clone conversation: %v", err)
}
fmt.Printf("Cloned %d messages to: %s - %s\n", messageCnt, clone.ShortName.String, clone.Title)
return nil
},
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp
if len(args) != 0 {
return nil, compMode
}
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
},
}
return cmd
}

View File

@ -1,108 +0,0 @@
package cmd
import (
"fmt"
"slices"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/util"
"github.com/spf13/cobra"
)
func RootCmd(ctx *lmcli.Context) *cobra.Command {
var root = &cobra.Command{
Use: "lmcli <command> [flags]",
Long: `lmcli - Large Language Model CLI`,
SilenceErrors: true,
SilenceUsage: true,
Run: func(cmd *cobra.Command, args []string) {
cmd.Usage()
},
}
root.AddCommand(
ChatCmd(ctx),
ContinueCmd(ctx),
CloneCmd(ctx),
EditCmd(ctx),
ListCmd(ctx),
NewCmd(ctx),
PromptCmd(ctx),
RenameCmd(ctx),
ReplyCmd(ctx),
RetryCmd(ctx),
RemoveCmd(ctx),
ViewCmd(ctx),
)
return root
}
func applyGenerationFlags(ctx *lmcli.Context, cmd *cobra.Command) {
f := cmd.Flags()
// -m, --model
f.StringVarP(
ctx.Config.Defaults.Model, "model", "m",
*ctx.Config.Defaults.Model, "Which model to generate a response with",
)
cmd.RegisterFlagCompletionFunc("model", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) {
return ctx.GetModels(), cobra.ShellCompDirectiveDefault
})
// -a, --agent
f.StringVarP(&ctx.Config.Defaults.Agent, "agent", "a", ctx.Config.Defaults.Agent, "Which agent to interact with")
cmd.RegisterFlagCompletionFunc("agent", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) {
return ctx.GetAgents(), cobra.ShellCompDirectiveDefault
})
// --max-length
f.IntVar(ctx.Config.Defaults.MaxTokens, "max-length", *ctx.Config.Defaults.MaxTokens, "Maximum response tokens")
// --temperature
f.Float32VarP(ctx.Config.Defaults.Temperature, "temperature", "t", *ctx.Config.Defaults.Temperature, "Sampling temperature")
// --system-prompt
f.StringVar(&ctx.Config.Defaults.SystemPrompt, "system-prompt", ctx.Config.Defaults.SystemPrompt, "System prompt")
// --system-prompt-file
f.StringVar(&ctx.Config.Defaults.SystemPromptFile, "system-prompt-file", ctx.Config.Defaults.SystemPromptFile, "A path to a file containing the system prompt")
cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file")
}
func validateGenerationFlags(ctx *lmcli.Context, cmd *cobra.Command) error {
f := cmd.Flags()
model, err := f.GetString("model")
if err != nil {
return fmt.Errorf("Error parsing --model: %w", err)
}
if model != "" && !slices.Contains(ctx.GetModels(), model) {
return fmt.Errorf("Unknown model: %s", model)
}
agent, err := f.GetString("agent")
if err != nil {
return fmt.Errorf("Error parsing --agent: %w", err)
}
if agent != "" && !slices.Contains(ctx.GetAgents(), agent) {
return fmt.Errorf("Unknown agent: %s", agent)
}
return nil
}
// inputFromArgsOrEditor returns either the provided input from the args slice
// (joined with spaces), or if len(args) is 0, opens an editor and returns
// whatever input was provided there. placeholder is a string which populates
// the editor and gets stripped from the final output.
func inputFromArgsOrEditor(args []string, placeholder string, existingMessage string) (message string) {
var err error
if len(args) == 0 {
message, err = util.InputFromEditor(placeholder, "message.*.md", existingMessage)
if err != nil {
lmcli.Fatal("Failed to get input: %v\n", err)
}
} else {
message = strings.Join(args, " ")
}
message = strings.Trim(message, " \t\n")
return
}

View File

@ -1,78 +0,0 @@
package cmd
import (
"fmt"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
"github.com/spf13/cobra"
)
func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
cmd := &cobra.Command{
Use: "continue <conversation>",
Short: "Continue a conversation from the last message",
Long: `Re-prompt the conversation with all existing prompts. Useful if a reply was cut short.`,
Args: func(cmd *cobra.Command, args []string) error {
argCount := 1
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
return err
}
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
err := validateGenerationFlags(ctx, cmd)
if err != nil {
return err
}
shortName := args[0]
conversation := cmdutil.LookupConversation(ctx, shortName)
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
if err != nil {
return fmt.Errorf("could not retrieve conversation messages: %v", err)
}
if len(messages) < 2 {
return fmt.Errorf("conversation expected to have at least 2 messages")
}
lastMessage := &messages[len(messages)-1]
if lastMessage.Role != api.MessageRoleAssistant {
return fmt.Errorf("the last message in the conversation is not an assistant message")
}
// Output the contents of the last message so far
fmt.Print(lastMessage.Content)
// Submit the LLM request, allowing it to continue the last message
continuedOutput, err := cmdutil.Prompt(ctx, messages, nil)
if err != nil {
return fmt.Errorf("error fetching LLM response: %v", err)
}
// Append the new response to the original message
lastMessage.Content += strings.TrimRight(continuedOutput.Content, "\n\t ")
// Update the original message
err = ctx.Store.UpdateMessage(lastMessage)
if err != nil {
return fmt.Errorf("could not update the last message: %v", err)
}
return nil
},
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp
if len(args) != 0 {
return nil, compMode
}
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
},
}
applyGenerationFlags(ctx, cmd)
return cmd
}

View File

@ -1,99 +0,0 @@
package cmd
import (
"fmt"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
"github.com/spf13/cobra"
)
func EditCmd(ctx *lmcli.Context) *cobra.Command {
cmd := &cobra.Command{
Use: "edit <conversation>",
Short: "Edit the last user reply in a conversation",
Args: func(cmd *cobra.Command, args []string) error {
argCount := 1
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
return err
}
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
shortName := args[0]
conversation := cmdutil.LookupConversation(ctx, shortName)
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
if err != nil {
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
}
offset, _ := cmd.Flags().GetInt("offset")
if offset < 0 {
offset = -offset
}
if offset > len(messages)-1 {
return fmt.Errorf("Offset %d is before the start of the conversation.", offset)
}
desiredIdx := len(messages) - 1 - offset
toEdit := messages[desiredIdx]
newContents := inputFromArgsOrEditor(args[1:], "# Save when finished editing\n", toEdit.Content)
switch newContents {
case toEdit.Content:
return fmt.Errorf("No edits were made.")
case "":
return fmt.Errorf("No message was provided.")
}
toEdit.Content = newContents
role, _ := cmd.Flags().GetString("role")
if role != "" {
if role != string(api.MessageRoleUser) && role != string(api.MessageRoleAssistant) {
return fmt.Errorf("Invalid role specified. Please use 'user' or 'assistant'.")
}
toEdit.Role = api.MessageRole(role)
}
// Update the message in-place
inplace, _ := cmd.Flags().GetBool("in-place")
if inplace {
return ctx.Store.UpdateMessage(&toEdit)
}
// Otherwise, create a branch for the edited message
message, _, err := ctx.Store.CloneBranch(toEdit)
if err != nil {
return err
}
if desiredIdx > 0 {
// update selected reply
messages[desiredIdx-1].SelectedReply = message
err = ctx.Store.UpdateMessage(&messages[desiredIdx-1])
} else {
// update selected root
conversation.SelectedRoot = message
err = ctx.Store.UpdateConversation(conversation)
}
return err
},
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp
if len(args) != 0 {
return nil, compMode
}
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
},
}
cmd.Flags().BoolP("in-place", "i", true, "Edit the message in-place, rather than creating a branch")
cmd.Flags().Int("offset", 1, "Offset from the last message to edit")
cmd.Flags().StringP("role", "r", "", "Change the role of the edited message (user or assistant)")
return cmd
}

View File

@ -1,112 +0,0 @@
package cmd
import (
"fmt"
"time"
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/util"
"github.com/spf13/cobra"
)
const (
LS_COUNT int = 5
)
func ListCmd(ctx *lmcli.Context) *cobra.Command {
cmd := &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List conversations",
Long: `List conversations in order of recent activity`,
RunE: func(cmd *cobra.Command, args []string) error {
messages, err := ctx.Store.LatestConversationMessages()
if err != nil {
return fmt.Errorf("Could not fetch conversations: %v", err)
}
type Category struct {
name string
cutoff time.Duration
}
type ConversationLine struct {
timeSinceReply time.Duration
formatted string
}
now := time.Now()
midnight := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
monthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
dayOfWeek := int(now.Weekday())
categories := []Category{
{"today", now.Sub(midnight)},
{"yesterday", now.Sub(midnight.AddDate(0, 0, -1))},
{"this week", now.Sub(midnight.AddDate(0, 0, -dayOfWeek))},
{"last week", now.Sub(midnight.AddDate(0, 0, -(dayOfWeek + 7)))},
{"this month", now.Sub(monthStart)},
{"last month", now.Sub(monthStart.AddDate(0, -1, 0))},
{"2 months ago", now.Sub(monthStart.AddDate(0, -2, 0))},
{"3 months ago", now.Sub(monthStart.AddDate(0, -3, 0))},
{"4 months ago", now.Sub(monthStart.AddDate(0, -4, 0))},
{"5 months ago", now.Sub(monthStart.AddDate(0, -5, 0))},
{"older", now.Sub(time.Time{})},
}
categorized := map[string][]ConversationLine{}
all, _ := cmd.Flags().GetBool("all")
for _, message := range messages {
messageAge := now.Sub(message.CreatedAt)
var category string
for _, c := range categories {
if messageAge < c.cutoff {
category = c.name
break
}
}
formatted := fmt.Sprintf(
"%s - %s - %s",
message.Conversation.ShortName.String,
util.HumanTimeElapsedSince(messageAge),
message.Conversation.Title,
)
categorized[category] = append(
categorized[category],
ConversationLine{messageAge, formatted},
)
}
count, _ := cmd.Flags().GetInt("count")
var conversationsPrinted int
outer:
for _, category := range categories {
conversationLines, ok := categorized[category.name]
if !ok {
continue
}
fmt.Printf("%s:\n", category.name)
for _, conv := range conversationLines {
if conversationsPrinted >= count && !all {
fmt.Printf("%d remaining conversation(s), use --all to view.\n", len(messages)-conversationsPrinted)
break outer
}
fmt.Printf(" %s\n", conv.formatted)
conversationsPrinted++
}
}
return nil
},
}
cmd.Flags().BoolP("all", "a", false, "Show all conversations")
cmd.Flags().IntP("count", "c", LS_COUNT, "How many conversations to show")
return cmd
}

View File

@ -1,56 +0,0 @@
package cmd
import (
"fmt"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
"github.com/spf13/cobra"
)
func NewCmd(ctx *lmcli.Context) *cobra.Command {
cmd := &cobra.Command{
Use: "new [message]",
Short: "Start a new conversation",
Long: `Start a new conversation with the Large Language Model.`,
RunE: func(cmd *cobra.Command, args []string) error {
err := validateGenerationFlags(ctx, cmd)
if err != nil {
return err
}
input := inputFromArgsOrEditor(args, "# Start a new conversation below\n", "")
if input == "" {
return fmt.Errorf("No message was provided.")
}
messages := []api.Message{{
Role: api.MessageRoleUser,
Content: input,
}}
conversation, messages, err := ctx.Store.StartConversation(messages...)
if err != nil {
return fmt.Errorf("Could not start a new conversation: %v", err)
}
cmdutil.HandleReply(ctx, &messages[len(messages)-1], true)
title, err := cmdutil.GenerateTitle(ctx, messages)
if err != nil {
lmcli.Warn("Could not generate title for conversation %s: %v\n", conversation.ShortName.String, err)
}
conversation.Title = title
err = ctx.Store.UpdateConversation(conversation)
if err != nil {
lmcli.Warn("Could not save conversation title: %v\n", err)
}
return nil
},
}
applyGenerationFlags(ctx, cmd)
return cmd
}

View File

@ -1,43 +0,0 @@
package cmd
import (
"fmt"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
"github.com/spf13/cobra"
)
func PromptCmd(ctx *lmcli.Context) *cobra.Command {
cmd := &cobra.Command{
Use: "prompt [message]",
Short: "Do a one-shot prompt",
Long: `Prompt the Large Language Model and get a response.`,
RunE: func(cmd *cobra.Command, args []string) error {
err := validateGenerationFlags(ctx, cmd)
if err != nil {
return err
}
input := inputFromArgsOrEditor(args, "# Write your prompt below\n", "")
if input == "" {
return fmt.Errorf("No message was provided.")
}
messages := []api.Message{{
Role: api.MessageRoleUser,
Content: input,
}}
_, err = cmdutil.Prompt(ctx, messages, nil)
if err != nil {
return fmt.Errorf("Error fetching LLM response: %v", err)
}
return nil
},
}
applyGenerationFlags(ctx, cmd)
return cmd
}

View File

@ -1,60 +0,0 @@
package cmd
import (
"fmt"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
"github.com/spf13/cobra"
)
func RemoveCmd(ctx *lmcli.Context) *cobra.Command {
cmd := &cobra.Command{
Use: "rm <conversation>...",
Short: "Remove conversations",
Long: `Remove conversations by their short names.`,
Args: func(cmd *cobra.Command, args []string) error {
argCount := 1
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
return err
}
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
var toRemove []*api.Conversation
for _, shortName := range args {
conversation := cmdutil.LookupConversation(ctx, shortName)
toRemove = append(toRemove, conversation)
}
var errors []error
for _, c := range toRemove {
err := ctx.Store.DeleteConversation(c)
if err != nil {
errors = append(errors, fmt.Errorf("Could not remove conversation %s: %v", c.ShortName.String, err))
}
}
if len(errors) > 0 {
return fmt.Errorf("Could not remove some conversations: %v", errors)
}
return nil
},
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp
var completions []string
outer:
for _, completion := range ctx.Store.ConversationShortNameCompletions(toComplete) {
parts := strings.Split(completion, "\t")
for _, arg := range args {
if parts[0] == arg {
continue outer
}
}
completions = append(completions, completion)
}
return completions, compMode
},
}
return cmd
}

View File

@ -1,67 +0,0 @@
package cmd
import (
"fmt"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
"github.com/spf13/cobra"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
)
func RenameCmd(ctx *lmcli.Context) *cobra.Command {
cmd := &cobra.Command{
Use: "rename <conversation> [title]",
Short: "Rename a conversation",
Long: `Renames a conversation, either with the provided title or by generating a new name.`,
Args: func(cmd *cobra.Command, args []string) error {
argCount := 1
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
return err
}
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
shortName := args[0]
conversation := cmdutil.LookupConversation(ctx, shortName)
var err error
var title string
generate, _ := cmd.Flags().GetBool("generate")
if generate {
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
if err != nil {
return fmt.Errorf("Could not retrieve conversation messages: %v", err)
}
title, err = cmdutil.GenerateTitle(ctx, messages)
if err != nil {
return fmt.Errorf("Could not generate conversation title: %v", err)
}
} else {
if len(args) < 2 {
return fmt.Errorf("Conversation title not provided.")
}
title = strings.Join(args[1:], " ")
}
conversation.Title = title
err = ctx.Store.UpdateConversation(conversation)
if err != nil {
lmcli.Warn("Could not update conversation title: %v\n", err)
}
return nil
},
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp
if len(args) != 0 {
return nil, compMode
}
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
},
}
cmd.Flags().Bool("generate", false, "Generate a conversation title")
return cmd
}

View File

@ -1,55 +0,0 @@
package cmd
import (
"fmt"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
"github.com/spf13/cobra"
)
func ReplyCmd(ctx *lmcli.Context) *cobra.Command {
cmd := &cobra.Command{
Use: "reply <conversation> [message]",
Short: "Reply to a conversation",
Long: `Sends a reply to conversation and writes the response to stdout.`,
Args: func(cmd *cobra.Command, args []string) error {
argCount := 1
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
return err
}
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
err := validateGenerationFlags(ctx, cmd)
if err != nil {
return err
}
shortName := args[0]
conversation := cmdutil.LookupConversation(ctx, shortName)
reply := inputFromArgsOrEditor(args[1:], "# How would you like to reply?\n", "")
if reply == "" {
return fmt.Errorf("No reply was provided.")
}
cmdutil.HandleConversationReply(ctx, conversation, true, api.Message{
Role: api.MessageRoleUser,
Content: reply,
})
return nil
},
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp
if len(args) != 0 {
return nil, compMode
}
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
},
}
applyGenerationFlags(ctx, cmd)
return cmd
}

View File

@ -1,78 +0,0 @@
package cmd
import (
"fmt"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
"github.com/spf13/cobra"
)
func RetryCmd(ctx *lmcli.Context) *cobra.Command {
cmd := &cobra.Command{
Use: "retry <conversation>",
Short: "Retry the last user reply in a conversation",
Long: `Prompt the conversation from the last user response.`,
Args: func(cmd *cobra.Command, args []string) error {
argCount := 1
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
return err
}
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
err := validateGenerationFlags(ctx, cmd)
if err != nil {
return err
}
shortName := args[0]
conversation := cmdutil.LookupConversation(ctx, shortName)
// Load the complete thread from the root message
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
if err != nil {
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
}
offset, _ := cmd.Flags().GetInt("offset")
if offset < 0 {
offset = -offset
}
if offset > len(messages)-1 {
return fmt.Errorf("Offset %d is before the start of the conversation.", offset)
}
retryFromIdx := len(messages) - 1 - offset
// decrease retryFromIdx until we hit a user message
for retryFromIdx >= 0 && messages[retryFromIdx].Role != api.MessageRoleUser {
retryFromIdx--
}
if messages[retryFromIdx].Role != api.MessageRoleUser {
return fmt.Errorf("No user messages to retry")
}
fmt.Printf("Idx: %d Message: %v\n", retryFromIdx, messages[retryFromIdx])
// Start a new branch at the last user message
cmdutil.HandleReply(ctx, &messages[retryFromIdx], true)
return nil
},
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp
if len(args) != 0 {
return nil, compMode
}
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
},
}
cmd.Flags().Int("offset", 0, "Offset from the last message to retry from.")
applyGenerationFlags(ctx, cmd)
return cmd
}

View File

@ -1,337 +0,0 @@
package util
import (
"context"
"encoding/json"
"fmt"
"os"
"strings"
"time"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/util"
"github.com/charmbracelet/lipgloss"
)
// Prompt prompts the configured the configured model and streams the response
// to stdout. Returns all model reply messages.
func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Message)) (*api.Message, error) {
m, provider, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model)
if err != nil {
return nil, err
}
params := api.RequestParameters{
Model: m,
MaxTokens: *ctx.Config.Defaults.MaxTokens,
Temperature: *ctx.Config.Defaults.Temperature,
}
system := ctx.DefaultSystemPrompt()
agent := ctx.GetAgent(ctx.Config.Defaults.Agent)
if agent != nil {
if agent.SystemPrompt != "" {
system = agent.SystemPrompt
}
params.Toolbox = agent.Toolbox
}
if system != "" {
messages = api.ApplySystemPrompt(messages, system, false)
}
content := make(chan api.Chunk)
defer close(content)
// render the content received over the channel
go ShowDelayedContent(content)
reply, err := provider.CreateChatCompletionStream(
context.Background(), params, messages, content,
)
if reply.Content != "" {
// there was some content, so break to a new line after it
fmt.Println()
if err != nil {
lmcli.Warn("Received partial response. Error: %v\n", err)
err = nil
}
}
return reply, err
}
// lookupConversation either returns the conversation found by the
// short name or exits the program
func LookupConversation(ctx *lmcli.Context, shortName string) *api.Conversation {
c, err := ctx.Store.ConversationByShortName(shortName)
if err != nil {
lmcli.Fatal("Could not lookup conversation: %v\n", err)
}
if c.ID == 0 {
lmcli.Fatal("Conversation not found: %s\n", shortName)
}
return c
}
func LookupConversationE(ctx *lmcli.Context, shortName string) (*api.Conversation, error) {
c, err := ctx.Store.ConversationByShortName(shortName)
if err != nil {
return nil, fmt.Errorf("Could not lookup conversation: %v", err)
}
if c.ID == 0 {
return nil, fmt.Errorf("Conversation not found: %s", shortName)
}
return c, nil
}
func HandleConversationReply(ctx *lmcli.Context, c *api.Conversation, persist bool, toSend ...api.Message) {
messages, err := ctx.Store.PathToLeaf(c.SelectedRoot)
if err != nil {
lmcli.Fatal("Could not load messages: %v\n", err)
}
HandleReply(ctx, &messages[len(messages)-1], persist, toSend...)
}
// handleConversationReply handles sending messages to an existing
// conversation, optionally persisting both the sent replies and responses.
func HandleReply(ctx *lmcli.Context, to *api.Message, persist bool, messages ...api.Message) {
if to == nil {
lmcli.Fatal("Can't prompt from an empty message.")
}
existing, err := ctx.Store.PathToRoot(to)
if err != nil {
lmcli.Fatal("Could not load messages: %v\n", err)
}
RenderConversation(ctx, append(existing, messages...), true)
var savedReplies []api.Message
if persist && len(messages) > 0 {
savedReplies, err = ctx.Store.Reply(to, messages...)
if err != nil {
lmcli.Warn("Could not save messages: %v\n", err)
}
}
// render a message header with no contents
RenderMessage(ctx, (&api.Message{Role: api.MessageRoleAssistant}))
var lastSavedMessage *api.Message
lastSavedMessage = to
if len(savedReplies) > 0 {
lastSavedMessage = &savedReplies[len(savedReplies)-1]
}
replyCallback := func(reply api.Message) {
if !persist {
return
}
savedReplies, err = ctx.Store.Reply(lastSavedMessage, reply)
if err != nil {
lmcli.Warn("Could not save reply: %v\n", err)
}
lastSavedMessage = &savedReplies[0]
}
_, err = Prompt(ctx, append(existing, messages...), replyCallback)
if err != nil {
lmcli.Fatal("Error fetching LLM response: %v\n", err)
}
}
func FormatForExternalPrompt(messages []api.Message, system bool) string {
sb := strings.Builder{}
for _, message := range messages {
if message.Content == "" {
continue
}
switch message.Role {
case api.MessageRoleAssistant, api.MessageRoleToolCall:
sb.WriteString("Assistant:\n\n")
case api.MessageRoleUser:
sb.WriteString("User:\n\n")
default:
continue
}
sb.WriteString(fmt.Sprintf("%s", lipgloss.NewStyle().PaddingLeft(1).Render(message.Content)))
}
return sb.String()
}
func GenerateTitle(ctx *lmcli.Context, messages []api.Message) (string, error) {
const systemPrompt = `You will be shown a conversation between a user and an AI assistant. Your task is to generate a short title (8 words or less) for the provided conversation that reflects the conversation's topic. Your response is expected to be in JSON in the format shown below.
Example conversation:
[{"role": "user", "content": "Can you help me with my math homework?"},{"role": "assistant", "content": "Sure, what topic are you struggling with?"}]
Example response:
{"title": "Help with math homework"}
`
type msg struct {
Role string
Content string
}
var msgs []msg
for _, m := range messages {
switch m.Role {
case api.MessageRoleAssistant, api.MessageRoleUser:
msgs = append(msgs, msg{string(m.Role), m.Content})
}
}
// Serialize the conversation to JSON
conversation, err := json.Marshal(msgs)
if err != nil {
return "", err
}
generateRequest := []api.Message{
{
Role: api.MessageRoleSystem,
Content: systemPrompt,
},
{
Role: api.MessageRoleUser,
Content: string(conversation),
},
}
m, provider, err := ctx.GetModelProvider(
*ctx.Config.Conversations.TitleGenerationModel,
)
if err != nil {
return "", err
}
requestParams := api.RequestParameters{
Model: m,
MaxTokens: 25,
}
response, err := provider.CreateChatCompletion(
context.Background(), requestParams, generateRequest,
)
if err != nil {
return "", err
}
// Parse the JSON response
var jsonResponse struct {
Title string `json:"title"`
}
err = json.Unmarshal([]byte(response.Content), &jsonResponse)
if err != nil {
return "", err
}
return jsonResponse.Title, nil
}
// ShowWaitAnimation prints an animated ellipses to stdout until something is
// received on the signal channel. An empty string sent to the channel to
// notify the caller that the animation has completed (carriage returned).
func ShowWaitAnimation(signal chan any) {
// Save the current cursor position
fmt.Print("\033[s")
animationStep := 0
for {
select {
case _ = <-signal:
// Relmcli the cursor position
fmt.Print("\033[u")
signal <- ""
return
default:
// Move the cursor to the saved position
modSix := animationStep % 6
if modSix == 3 || modSix == 0 {
fmt.Print("\033[u")
}
if modSix < 3 {
fmt.Print(".")
} else {
fmt.Print(" ")
}
animationStep++
time.Sleep(250 * time.Millisecond)
}
}
}
// ShowDelayedContent displays a waiting animation to stdout while waiting
// for content to be received on the provided channel. As soon as any (possibly
// chunked) content is received on the channel, the waiting animation is
// replaced by the content.
// Blocks until the channel is closed.
func ShowDelayedContent(content <-chan api.Chunk) {
waitSignal := make(chan any)
go ShowWaitAnimation(waitSignal)
firstChunk := true
for chunk := range content {
if firstChunk {
// notify wait animation that we've received data
waitSignal <- ""
// wait for signal that wait animation has completed
<-waitSignal
firstChunk = false
}
fmt.Print(chunk.Content)
}
}
// RenderConversation renders the given messages to TTY, with optional space
// for a subsequent message. spaceForResponse controls how many '\n' characters
// are printed immediately after the final message (1 if false, 2 if true)
func RenderConversation(ctx *lmcli.Context, messages []api.Message, spaceForResponse bool) {
l := len(messages)
for i, message := range messages {
RenderMessage(ctx, &message)
if i < l-1 || spaceForResponse {
// print an additional space before the next message
fmt.Println()
}
}
}
func RenderMessage(ctx *lmcli.Context, m *api.Message) {
var messageAge string
if m.CreatedAt.IsZero() {
messageAge = "now"
} else {
now := time.Now()
messageAge = util.HumanTimeElapsedSince(now.Sub(m.CreatedAt))
}
headerStyle := lipgloss.NewStyle().Bold(true)
switch m.Role {
case api.MessageRoleSystem:
headerStyle = headerStyle.Foreground(lipgloss.Color("9")) // bright red
case api.MessageRoleUser:
headerStyle = headerStyle.Foreground(lipgloss.Color("10")) // bright green
case api.MessageRoleAssistant:
headerStyle = headerStyle.Foreground(lipgloss.Color("12")) // bright blue
}
role := headerStyle.Render(m.Role.FriendlyRole())
separatorStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("3"))
separator := separatorStyle.Render("===")
timestamp := separatorStyle.Render(messageAge)
fmt.Printf("%s %s - %s %s\n\n", separator, role, timestamp, separator)
if m.Content != "" {
ctx.Chroma.Highlight(os.Stdout, m.Content)
fmt.Println()
}
}

View File

@ -1,45 +0,0 @@
package cmd
import (
"fmt"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
"github.com/spf13/cobra"
)
func ViewCmd(ctx *lmcli.Context) *cobra.Command {
cmd := &cobra.Command{
Use: "view <conversation>",
Short: "View messages in a conversation",
Long: `Finds a conversation by its short name and displays its contents.`,
Args: func(cmd *cobra.Command, args []string) error {
argCount := 1
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
return err
}
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
shortName := args[0]
conversation := cmdutil.LookupConversation(ctx, shortName)
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
if err != nil {
return fmt.Errorf("Could not retrieve messages for conversation %s: %v", conversation.ShortName.String, err)
}
cmdutil.RenderConversation(ctx, messages, false)
return nil
},
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp
if len(args) != 0 {
return nil, compMode
}
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
},
}
return cmd
}

View File

@ -1,76 +0,0 @@
package lmcli
import (
"fmt"
"os"
"git.mlow.ca/mlow/lmcli/pkg/util"
"gopkg.in/yaml.v3"
)
type Config struct {
Defaults *struct {
Model *string `yaml:"model" default:"gpt-4"`
MaxTokens *int `yaml:"maxTokens" default:"256"`
Temperature *float32 `yaml:"temperature" default:"0.2"`
SystemPrompt string `yaml:"systemPrompt,omitempty"`
SystemPromptFile string `yaml:"systemPromptFile,omitempty"`
// CLI only
Agent string `yaml:"-"`
} `yaml:"defaults"`
Conversations *struct {
TitleGenerationModel *string `yaml:"titleGenerationModel" default:"gpt-3.5-turbo"`
} `yaml:"conversations"`
Chroma *struct {
Style *string `yaml:"style" default:"onedark"`
Formatter *string `yaml:"formatter" default:"terminal16m"`
} `yaml:"chroma"`
Agents []*struct {
Name string `yaml:"name"`
SystemPrompt string `yaml:"systemPrompt"`
Tools []string `yaml:"tools"`
} `yaml:"agents"`
Providers []*struct {
Name string `yaml:"name,omitempty"`
Kind string `yaml:"kind"`
BaseURL string `yaml:"baseUrl,omitempty"`
APIKey string `yaml:"apiKey,omitempty"`
Models []string `yaml:"models"`
} `yaml:"providers"`
}
func NewConfig(configFile string) (*Config, error) {
shouldWriteDefaults := false
c := &Config{}
configExists := true
configBytes, err := os.ReadFile(configFile)
if os.IsNotExist(err) {
configExists = false
} else if err != nil {
return nil, fmt.Errorf("Could not read config file: %v", err)
} else {
yaml.Unmarshal(configBytes, c)
}
shouldWriteDefaults = util.SetStructDefaults(c)
if !configExists || shouldWriteDefaults {
if configExists {
fmt.Printf("Saving new defaults to configuration, backing up existing configuration to %s\n", configFile+".bak")
os.Rename(configFile, configFile+".bak")
}
fmt.Printf("Writing configuration file to %s\n", configFile)
file, err := os.Create(configFile)
if err != nil {
return nil, fmt.Errorf("Could not open config file for writing: %v", err)
}
encoder := yaml.NewEncoder(file)
encoder.SetIndent(2)
err = encoder.Encode(c)
if err != nil {
return nil, fmt.Errorf("Could not save default configuration: %v", err)
}
}
return c, nil
}

View File

@ -1,229 +0,0 @@
package lmcli
import (
"fmt"
"os"
"path/filepath"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/agents"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/api/provider/anthropic"
"git.mlow.ca/mlow/lmcli/pkg/api/provider/google"
"git.mlow.ca/mlow/lmcli/pkg/api/provider/ollama"
"git.mlow.ca/mlow/lmcli/pkg/api/provider/openai"
"git.mlow.ca/mlow/lmcli/pkg/util"
"git.mlow.ca/mlow/lmcli/pkg/util/tty"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
type Agent struct {
Name string
SystemPrompt string
Toolbox []api.ToolSpec
}
type Context struct {
// high level app configuration, may be mutated at runtime
Config Config
Store ConversationStore
Chroma *tty.ChromaHighlighter
}
func NewContext() (*Context, error) {
configFile := filepath.Join(configDir(), "config.yaml")
config, err := NewConfig(configFile)
if err != nil {
return nil, err
}
databaseFile := filepath.Join(dataDir(), "conversations.db")
db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{
//Logger: logger.Default.LogMode(logger.Info),
})
if err != nil {
return nil, fmt.Errorf("Error establishing connection to store: %v", err)
}
store, err := NewSQLStore(db)
if err != nil {
return nil, err
}
chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style)
return &Context{*config, store, chroma}, nil
}
func (c *Context) GetModels() (models []string) {
modelCounts := make(map[string]int)
for _, p := range c.Config.Providers {
name := p.Kind
if p.Name != "" {
name = p.Name
}
for _, m := range p.Models {
modelCounts[m]++
models = append(models, fmt.Sprintf("%s@%s", m, name))
}
}
for m, c := range modelCounts {
if c == 1 {
models = append(models, m)
}
}
return
}
func (c *Context) GetAgents() (agents []string) {
for _, p := range c.Config.Agents {
agents = append(agents, p.Name)
}
return
}
func (c *Context) GetAgent(name string) *Agent {
if name == "" {
return nil
}
for _, a := range c.Config.Agents {
if name != a.Name {
continue
}
var enabledTools []api.ToolSpec
for _, toolName := range a.Tools {
tool, ok := agents.AvailableTools[toolName]
if ok {
enabledTools = append(enabledTools, tool)
}
}
return &Agent{
Name: a.Name,
SystemPrompt: a.SystemPrompt,
Toolbox: enabledTools,
}
}
return nil
}
func (c *Context) DefaultSystemPrompt() string {
if c.Config.Defaults.SystemPromptFile != "" {
content, err := util.ReadFileContents(c.Config.Defaults.SystemPromptFile)
if err != nil {
Fatal("Could not read file contents at %s: %v\n", c.Config.Defaults.SystemPromptFile, err)
}
return content
}
return c.Config.Defaults.SystemPrompt
}
func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProvider, error) {
parts := strings.Split(model, "@")
var provider string
if len(parts) > 1 {
model = parts[0]
provider = parts[1]
}
for _, p := range c.Config.Providers {
name := p.Kind
if p.Name != "" {
name = p.Name
}
if provider != "" && name != provider {
continue
}
for _, m := range p.Models {
if m == model {
switch p.Kind {
case "anthropic":
url := "https://api.anthropic.com"
if p.BaseURL != "" {
url = p.BaseURL
}
return model, &anthropic.AnthropicClient{
BaseURL: url,
APIKey: p.APIKey,
}, nil
case "google":
url := "https://generativelanguage.googleapis.com"
if p.BaseURL != "" {
url = p.BaseURL
}
return model, &google.Client{
BaseURL: url,
APIKey: p.APIKey,
}, nil
case "ollama":
url := "http://localhost:11434/api"
if p.BaseURL != "" {
url = p.BaseURL
}
return model, &ollama.OllamaClient{
BaseURL: url,
}, nil
case "openai":
url := "https://api.openai.com"
if p.BaseURL != "" {
url = p.BaseURL
}
return model, &openai.OpenAIClient{
BaseURL: url,
APIKey: p.APIKey,
}, nil
default:
return "", nil, fmt.Errorf("unknown provider kind: %s", p.Kind)
}
}
}
}
return "", nil, fmt.Errorf("unknown model: %s", model)
}
func configDir() string {
var configDir string
xdgConfigHome := os.Getenv("XDG_CONFIG_HOME")
if xdgConfigHome != "" {
configDir = filepath.Join(xdgConfigHome, "lmcli")
} else {
userHomeDir, _ := os.UserHomeDir()
configDir = filepath.Join(userHomeDir, ".config/lmcli")
}
os.MkdirAll(configDir, 0755)
return configDir
}
func dataDir() string {
var dataDir string
xdgDataHome := os.Getenv("XDG_DATA_HOME")
if xdgDataHome != "" {
dataDir = filepath.Join(xdgDataHome, "lmcli")
} else {
userHomeDir, _ := os.UserHomeDir()
dataDir = filepath.Join(userHomeDir, ".local/share/lmcli")
}
os.MkdirAll(dataDir, 0755)
return dataDir
}
func Fatal(format string, args ...any) {
fmt.Fprintf(os.Stderr, format, args...)
os.Exit(1)
}
func Warn(format string, args ...any) {
fmt.Fprintf(os.Stderr, format, args...)
}

View File

@ -1,433 +0,0 @@
package lmcli
import (
"database/sql"
"errors"
"fmt"
"slices"
"strings"
"time"
"git.mlow.ca/mlow/lmcli/pkg/api"
sqids "github.com/sqids/sqids-go"
"gorm.io/gorm"
)
type ConversationStore interface {
ConversationByShortName(shortName string) (*api.Conversation, error)
ConversationShortNameCompletions(search string) []string
RootMessages(conversationID uint) ([]api.Message, error)
LatestConversationMessages() ([]api.Message, error)
StartConversation(messages ...api.Message) (*api.Conversation, []api.Message, error)
UpdateConversation(conversation *api.Conversation) error
DeleteConversation(conversation *api.Conversation) error
CloneConversation(toClone api.Conversation) (*api.Conversation, uint, error)
MessageByID(messageID uint) (*api.Message, error)
MessageReplies(messageID uint) ([]api.Message, error)
UpdateMessage(message *api.Message) error
DeleteMessage(message *api.Message, prune bool) error
CloneBranch(toClone api.Message) (*api.Message, uint, error)
Reply(to *api.Message, messages ...api.Message) ([]api.Message, error)
PathToRoot(message *api.Message) ([]api.Message, error)
PathToLeaf(message *api.Message) ([]api.Message, error)
}
type SQLStore struct {
db *gorm.DB
sqids *sqids.Sqids
}
func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
models := []any{
&api.Conversation{},
&api.Message{},
}
for _, x := range models {
err := db.AutoMigrate(x)
if err != nil {
return nil, fmt.Errorf("Could not perform database migrations: %v", err)
}
}
_sqids, _ := sqids.New(sqids.Options{MinLength: 4})
return &SQLStore{db, _sqids}, nil
}
func (s *SQLStore) createConversation() (*api.Conversation, error) {
// Create the new conversation
c := &api.Conversation{}
err := s.db.Save(c).Error
if err != nil {
return nil, err
}
// Generate and save its "short name"
shortName, _ := s.sqids.Encode([]uint64{uint64(c.ID)})
c.ShortName = sql.NullString{String: shortName, Valid: true}
err = s.db.Updates(c).Error
if err != nil {
return nil, err
}
return c, nil
}
func (s *SQLStore) UpdateConversation(c *api.Conversation) error {
if c == nil || c.ID == 0 {
return fmt.Errorf("Conversation is nil or invalid (missing ID)")
}
return s.db.Updates(c).Error
}
func (s *SQLStore) DeleteConversation(c *api.Conversation) error {
// Delete messages first
err := s.db.Where("conversation_id = ?", c.ID).Delete(&api.Message{}).Error
if err != nil {
return err
}
return s.db.Delete(c).Error
}
func (s *SQLStore) DeleteMessage(message *api.Message, prune bool) error {
panic("Not yet implemented")
//return s.db.Delete(&message).Error
}
func (s *SQLStore) UpdateMessage(m *api.Message) error {
if m == nil || m.ID == 0 {
return fmt.Errorf("Message is nil or invalid (missing ID)")
}
return s.db.Updates(m).Error
}
func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
var conversations []api.Conversation
// ignore error for completions
s.db.Find(&conversations)
completions := make([]string, 0, len(conversations))
for _, conversation := range conversations {
if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) {
completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title))
}
}
return completions
}
func (s *SQLStore) ConversationByShortName(shortName string) (*api.Conversation, error) {
if shortName == "" {
return nil, errors.New("shortName is empty")
}
var conversation api.Conversation
err := s.db.Preload("SelectedRoot").Where("short_name = ?", shortName).Find(&conversation).Error
return &conversation, err
}
func (s *SQLStore) RootMessages(conversationID uint) ([]api.Message, error) {
var rootMessages []api.Message
err := s.db.Where("conversation_id = ? AND parent_id IS NULL", conversationID).Find(&rootMessages).Error
if err != nil {
return nil, err
}
return rootMessages, nil
}
func (s *SQLStore) MessageByID(messageID uint) (*api.Message, error) {
var message api.Message
err := s.db.Preload("Parent").Preload("Replies").Preload("SelectedReply").Where("id = ?", messageID).Find(&message).Error
return &message, err
}
func (s *SQLStore) MessageReplies(messageID uint) ([]api.Message, error) {
var replies []api.Message
err := s.db.Where("parent_id = ?", messageID).Find(&replies).Error
return replies, err
}
// StartConversation starts a new conversation with the provided messages
func (s *SQLStore) StartConversation(messages ...api.Message) (*api.Conversation, []api.Message, error) {
if len(messages) == 0 {
return nil, nil, fmt.Errorf("Must provide at least 1 message")
}
// Create new conversation
conversation, err := s.createConversation()
if err != nil {
return nil, nil, err
}
// Create first message
messages[0].Conversation = conversation
err = s.db.Create(&messages[0]).Error
if err != nil {
return nil, nil, err
}
// Update conversation's selected root message
conversation.SelectedRoot = &messages[0]
err = s.UpdateConversation(conversation)
if err != nil {
return nil, nil, err
}
// Add additional replies to conversation
if len(messages) > 1 {
newMessages, err := s.Reply(&messages[0], messages[1:]...)
if err != nil {
return nil, nil, err
}
messages = append([]api.Message{messages[0]}, newMessages...)
}
return conversation, messages, nil
}
// CloneConversation clones the given conversation and all of its root meesages
func (s *SQLStore) CloneConversation(toClone api.Conversation) (*api.Conversation, uint, error) {
rootMessages, err := s.RootMessages(toClone.ID)
if err != nil {
return nil, 0, err
}
clone, err := s.createConversation()
if err != nil {
return nil, 0, fmt.Errorf("Could not create clone: %s", err)
}
clone.Title = toClone.Title + " - Clone"
var errors []error
var messageCnt uint = 0
for _, root := range rootMessages {
messageCnt++
newRoot := root
newRoot.ConversationID = &clone.ID
cloned, count, err := s.CloneBranch(newRoot)
if err != nil {
errors = append(errors, err)
continue
}
messageCnt += count
if root.ID == *toClone.SelectedRootID {
clone.SelectedRootID = &cloned.ID
if err := s.UpdateConversation(clone); err != nil {
errors = append(errors, fmt.Errorf("Could not set selected root on clone: %v", err))
}
}
}
if len(errors) > 0 {
return nil, 0, fmt.Errorf("Messages failed to be cloned: %v", errors)
}
return clone, messageCnt, nil
}
// Reply to a message with a series of messages (each following the next)
func (s *SQLStore) Reply(to *api.Message, messages ...api.Message) ([]api.Message, error) {
var savedMessages []api.Message
err := s.db.Transaction(func(tx *gorm.DB) error {
currentParent := to
for i := range messages {
parent := currentParent
message := messages[i]
message.Parent = parent
message.Conversation = parent.Conversation
message.ID = 0
message.CreatedAt = time.Time{}
if err := tx.Create(&message).Error; err != nil {
return err
}
// update parent selected reply
parent.Replies = append(parent.Replies, message)
parent.SelectedReply = &message
if err := tx.Model(parent).Update("selected_reply_id", message.ID).Error; err != nil {
return err
}
savedMessages = append(savedMessages, message)
currentParent = &message
}
return nil
})
return savedMessages, err
}
// CloneBranch returns a deep clone of the given message and its replies, returning
// a new message object. The new message will be attached to the same parent as
// the messageToClone
func (s *SQLStore) CloneBranch(messageToClone api.Message) (*api.Message, uint, error) {
newMessage := messageToClone
newMessage.ID = 0
newMessage.Replies = nil
newMessage.SelectedReplyID = nil
newMessage.SelectedReply = nil
originalReplies, err := s.MessageReplies(messageToClone.ID)
if err != nil {
return nil, 0, fmt.Errorf("Could not fetch message %d replies: %v", messageToClone.ID, err)
}
if err := s.db.Create(&newMessage).Error; err != nil {
return nil, 0, fmt.Errorf("Could not clone message: %s", err)
}
var replyCount uint = 0
for _, reply := range originalReplies {
replyCount++
newReply := reply
newReply.ConversationID = messageToClone.ConversationID
newReply.ParentID = &newMessage.ID
newReply.Parent = &newMessage
res, c, err := s.CloneBranch(newReply)
if err != nil {
return nil, 0, err
}
newMessage.Replies = append(newMessage.Replies, *res)
replyCount += c
if reply.ID == *messageToClone.SelectedReplyID {
newMessage.SelectedReplyID = &res.ID
if err := s.UpdateMessage(&newMessage); err != nil {
return nil, 0, fmt.Errorf("Could not update parent select reply ID: %v", err)
}
}
}
return &newMessage, replyCount, nil
}
func fetchMessages(db *gorm.DB) ([]api.Message, error) {
var messages []api.Message
if err := db.Preload("Conversation").Find(&messages).Error; err != nil {
return nil, fmt.Errorf("Could not fetch messages: %v", err)
}
messageMap := make(map[uint]api.Message)
for i, message := range messages {
messageMap[messages[i].ID] = message
}
// Create a map to store replies by their parent ID
repliesMap := make(map[uint][]api.Message)
for i, message := range messages {
if messages[i].ParentID != nil {
repliesMap[*messages[i].ParentID] = append(repliesMap[*messages[i].ParentID], message)
}
}
// Assign replies, parent, and selected reply to each message
for i := range messages {
if replies, exists := repliesMap[messages[i].ID]; exists {
messages[i].Replies = make([]api.Message, len(replies))
for j, m := range replies {
messages[i].Replies[j] = m
}
}
if messages[i].ParentID != nil {
if parent, exists := messageMap[*messages[i].ParentID]; exists {
messages[i].Parent = &parent
}
}
if messages[i].SelectedReplyID != nil {
if selectedReply, exists := messageMap[*messages[i].SelectedReplyID]; exists {
messages[i].SelectedReply = &selectedReply
}
}
}
return messages, nil
}
func (s *SQLStore) buildPath(message *api.Message, getNext func(*api.Message) *uint) ([]api.Message, error) {
var messages []api.Message
messages, err := fetchMessages(s.db.Where("conversation_id = ?", message.ConversationID))
if err != nil {
return nil, err
}
// Create a map to store messages by their ID
messageMap := make(map[uint]*api.Message)
for i := range messages {
messageMap[messages[i].ID] = &messages[i]
}
// Build the path
var path []api.Message
nextID := &message.ID
for {
current, exists := messageMap[*nextID]
if !exists {
return nil, fmt.Errorf("Message with ID %d not found in conversation", *nextID)
}
path = append(path, *current)
nextID = getNext(current)
if nextID == nil {
break
}
}
return path, nil
}
// PathToRoot traverses the provided message's Parent until reaching the tree
// root and returns a slice of all messages traversed in chronological order
// (starting with the root and ending with the message provided)
func (s *SQLStore) PathToRoot(message *api.Message) ([]api.Message, error) {
if message == nil || message.ID <= 0 {
return nil, fmt.Errorf("Message is nil or has invalid ID")
}
path, err := s.buildPath(message, func(m *api.Message) *uint {
return m.ParentID
})
if err != nil {
return nil, err
}
slices.Reverse(path)
return path, nil
}
// PathToLeaf traverses the provided message's SelectedReply until reaching a
// tree leaf and returns a slice of all messages traversed in chronological
// order (starting with the message provided and ending with the leaf)
func (s *SQLStore) PathToLeaf(message *api.Message) ([]api.Message, error) {
if message == nil || message.ID <= 0 {
return nil, fmt.Errorf("Message is nil or has invalid ID")
}
return s.buildPath(message, func(m *api.Message) *uint {
return m.SelectedReplyID
})
}
func (s *SQLStore) LatestConversationMessages() ([]api.Message, error) {
var latestMessages []api.Message
subQuery := s.db.Model(&api.Message{}).
Select("MAX(created_at) as max_created_at, conversation_id").
Group("conversation_id")
err := s.db.Model(&api.Message{}).
Joins("JOIN (?) as sub on messages.conversation_id = sub.conversation_id AND messages.created_at = sub.max_created_at", subQuery).
Group("messages.conversation_id").
Order("created_at DESC").
Preload("Conversation").
Find(&latestMessages).Error
if err != nil {
return nil, err
}
return latestMessages, nil
}

View File

@ -1,67 +0,0 @@
package bubbles
import (
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
type ConfirmPrompt struct {
Question string
Style lipgloss.Style
Payload interface{}
value bool
answered bool
focused bool
}
func NewConfirmPrompt(question string, payload interface{}) ConfirmPrompt {
return ConfirmPrompt{
Question: question,
Style: lipgloss.NewStyle(),
Payload: payload,
focused: true, // focus by default
}
}
type MsgConfirmPromptAnswered struct {
Value bool
Payload interface{}
}
func (b ConfirmPrompt) Update(msg tea.Msg) (ConfirmPrompt, tea.Cmd) {
switch msg := msg.(type) {
case tea.KeyMsg:
if !b.focused || b.answered {
return b, nil
}
switch msg.String() {
case "y", "Y":
b.value = true
b.answered = true
b.focused = false
return b, func() tea.Msg { return MsgConfirmPromptAnswered{true, b.Payload} }
case "n", "N", "esc":
b.value = false
b.answered = true
b.focused = false
return b, func() tea.Msg { return MsgConfirmPromptAnswered{false, b.Payload} }
}
}
return b, nil
}
func (b ConfirmPrompt) View() string {
return b.Style.Render(b.Question) + lipgloss.NewStyle().Faint(true).Render(" (y/n)")
}
func (b *ConfirmPrompt) Focus() {
b.focused = true
}
func (b *ConfirmPrompt) Blur() {
b.focused = false
}
func (b ConfirmPrompt) Focused() bool {
return b.focused
}

View File

@ -1,52 +0,0 @@
package shared
import (
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
tea "github.com/charmbracelet/bubbletea"
)
type Values struct {
ConvShortname string
}
type Shared struct {
Ctx *lmcli.Context
Values *Values
Width int
Height int
Err error
}
// a convenience struct for holding rendered content for indiviudal UI
// elements
type Sections struct {
Header string
Content string
Error string
Input string
Footer string
}
type (
// send to change the current state
MsgViewChange View
// sent to a state when it is entered
MsgViewEnter struct{}
// sent when an error occurs
MsgError error
)
func WrapError(err error) tea.Cmd {
return func() tea.Msg {
return MsgError(err)
}
}
type View int
const (
StateChat View = iota
StateConversations
//StateSettings
//StateHelp
)

View File

@ -1,8 +0,0 @@
package styles
import "github.com/charmbracelet/lipgloss"
var Header = lipgloss.NewStyle().
PaddingLeft(1).
PaddingRight(1).
Background(lipgloss.Color("0"))

View File

@ -1,133 +0,0 @@
package tui
// The terminal UI for lmcli, launched from the `lmcli chat` command
// TODO:
// - change model
// - rename conversation
// - set system prompt
import (
"fmt"
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
"git.mlow.ca/mlow/lmcli/pkg/tui/views/chat"
"git.mlow.ca/mlow/lmcli/pkg/tui/views/conversations"
tea "github.com/charmbracelet/bubbletea"
)
// Application model
type Model struct {
shared.Shared
state shared.View
chat chat.Model
conversations conversations.Model
}
func initialModel(ctx *lmcli.Context, values shared.Values) Model {
m := Model{
Shared: shared.Shared{
Ctx: ctx,
Values: &values,
},
}
m.state = shared.StateChat
m.chat = chat.Chat(m.Shared)
m.conversations = conversations.Conversations(m.Shared)
return m
}
func (m Model) Init() tea.Cmd {
return tea.Batch(
m.conversations.Init(),
m.chat.Init(),
func() tea.Msg {
return shared.MsgViewChange(m.state)
},
)
}
func (m *Model) handleGlobalInput(msg tea.KeyMsg) (bool, tea.Cmd) {
// delegate input to the active child state first, only handling it at the
// global level if the child state does not
var cmds []tea.Cmd
switch m.state {
case shared.StateChat:
handled, cmd := m.chat.HandleInput(msg)
cmds = append(cmds, cmd)
if handled {
m.chat, cmd = m.chat.Update(nil)
cmds = append(cmds, cmd)
return true, tea.Batch(cmds...)
}
case shared.StateConversations:
handled, cmd := m.conversations.HandleInput(msg)
cmds = append(cmds, cmd)
if handled {
m.conversations, cmd = m.conversations.Update(nil)
cmds = append(cmds, cmd)
return true, tea.Batch(cmds...)
}
}
switch msg.String() {
case "ctrl+c", "ctrl+q":
return true, tea.Quit
}
return false, nil
}
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmds []tea.Cmd
switch msg := msg.(type) {
case tea.KeyMsg:
handled, cmd := m.handleGlobalInput(msg)
if handled {
return m, cmd
}
case shared.MsgViewChange:
m.state = shared.View(msg)
switch m.state {
case shared.StateChat:
m.chat.HandleResize(m.Width, m.Height)
case shared.StateConversations:
m.conversations.HandleResize(m.Width, m.Height)
}
return m, func() tea.Msg { return shared.MsgViewEnter(struct{}{}) }
case tea.WindowSizeMsg:
m.Width, m.Height = msg.Width, msg.Height
}
var cmd tea.Cmd
switch m.state {
case shared.StateConversations:
m.conversations, cmd = m.conversations.Update(msg)
case shared.StateChat:
m.chat, cmd = m.chat.Update(msg)
}
if cmd != nil {
cmds = append(cmds, cmd)
}
return m, tea.Batch(cmds...)
}
func (m Model) View() string {
switch m.state {
case shared.StateConversations:
return m.conversations.View()
case shared.StateChat:
return m.chat.View()
}
return ""
}
func Launch(ctx *lmcli.Context, convShortname string) error {
p := tea.NewProgram(initialModel(ctx, shared.Values{ConvShortname: convShortname}), tea.WithAltScreen())
if _, err := p.Run(); err != nil {
return fmt.Errorf("Error running program: %v", err)
}
return nil
}

View File

@ -1,101 +0,0 @@
package util
import (
"fmt"
"os"
"os/exec"
"strings"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/muesli/reflow/ansi"
)
type MsgTempfileEditorClosed string
// OpenTempfileEditor opens $EDITOR on a temporary file with the given content.
// Upon closing, the contents of the file are read and returned wrapped in a
// MsgTempfileEditorClosed
func OpenTempfileEditor(pattern string, content string, placeholder string) tea.Cmd {
msgFile, _ := os.CreateTemp("/tmp", pattern)
err := os.WriteFile(msgFile.Name(), []byte(placeholder+content), os.ModeAppend)
if err != nil {
return func() tea.Msg { return err }
}
editor := os.Getenv("EDITOR")
if editor == "" {
editor = "vim"
}
c := exec.Command(editor, msgFile.Name())
return tea.ExecProcess(c, func(err error) tea.Msg {
bytes, err := os.ReadFile(msgFile.Name())
if err != nil {
return err
}
os.Remove(msgFile.Name())
fileContents := string(bytes)
if strings.HasPrefix(fileContents, placeholder) {
fileContents = fileContents[len(placeholder):]
}
stripped := strings.Trim(fileContents, "\n \t")
return MsgTempfileEditorClosed(stripped)
})
}
// similar to lipgloss.Height, except returns 0 instead of 1 on empty strings
func Height(str string) int {
if str == "" {
return 0
}
return strings.Count(str, "\n") + 1
}
// truncate a string until its rendered cell width + the provided tail fits
// within the given width
func TruncateToCellWidth(str string, width int, tail string) string {
cellWidth := ansi.PrintableRuneWidth(str)
if cellWidth <= width {
return str
}
tailWidth := ansi.PrintableRuneWidth(tail)
for {
str = str[:len(str)-((cellWidth+tailWidth)-width)]
cellWidth = ansi.PrintableRuneWidth(str)
if cellWidth+tailWidth <= max(width, 0) {
break
}
}
return str + tail
}
func ScrollIntoView(vp *viewport.Model, offset int, edge int) {
currentOffset := vp.YOffset
if offset >= currentOffset && offset < currentOffset+vp.Height {
return
}
distance := currentOffset - offset
if distance < 0 {
// we should scroll down until it just comes into view
vp.SetYOffset(currentOffset - (distance + (vp.Height - edge)) + 1)
} else {
// we should scroll up
vp.SetYOffset(currentOffset - distance - edge)
}
}
func ErrorBanner(err error, width int) string {
if err == nil {
return ""
}
return lipgloss.NewStyle().
Width(width).
AlignHorizontal(lipgloss.Center).
Bold(true).
Foreground(lipgloss.Color("1")).
Render(fmt.Sprintf("%s", err))
}

View File

@ -1,174 +0,0 @@
package chat
import (
"time"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
"github.com/charmbracelet/bubbles/cursor"
"github.com/charmbracelet/bubbles/spinner"
"github.com/charmbracelet/bubbles/textarea"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// custom tea.Msg types
type (
// sent when a conversation is (re)loaded
msgConversationLoaded struct {
conversation *api.Conversation
rootMessages []api.Message
}
// sent when a new conversation title generated
msgConversationTitleGenerated string
// sent when the conversation has been persisted, triggers a reload of contents
msgConversationPersisted struct {
isNew bool
conversation *api.Conversation
messages []api.Message
}
// sent when a conversation's messages are laoded
msgMessagesLoaded []api.Message
// a special case of common.MsgError that stops the response waiting animation
msgChatResponseError error
// sent on each chunk received from LLM
msgChatResponseChunk api.Chunk
// sent on each completed reply
msgChatResponse *api.Message
// sent when the response is canceled
msgChatResponseCanceled struct{}
// sent when results from a tool call are returned
msgToolResults []api.ToolResult
// sent when the given message is made the new selected reply of its parent
msgSelectedReplyCycled *api.Message
// sent when the given message is made the new selected root of the current conversation
msgSelectedRootCycled *api.Message
// sent when a message's contents are updated and saved
msgMessageUpdated *api.Message
// sent when a message is cloned, with the cloned message
msgMessageCloned *api.Message
)
type focusState int
const (
focusInput focusState = iota
focusMessages
)
type editorTarget int
const (
input editorTarget = iota
selectedMessage
)
type state int
const (
idle state = iota
loading
pendingResponse
)
type Model struct {
shared.Shared
shared.Sections
// app state
state state // current overall status of the view
conversation *api.Conversation
rootMessages []api.Message
messages []api.Message
selectedMessage int
editorTarget editorTarget
stopSignal chan struct{}
replyChan chan api.Message
chatReplyChunks chan api.Chunk
persistence bool // whether we will save new messages in the conversation
// ui state
focus focusState
wrap bool // whether message content is wrapped to viewport width
showToolResults bool // whether tool calls and results are shown
messageCache []string // cache of syntax highlighted and wrapped message content
messageOffsets []int
// ui elements
content viewport.Model
input textarea.Model
spinner spinner.Model
replyCursor cursor.Model // cursor to indicate incoming response
// metrics
tokenCount uint
startTime time.Time
elapsed time.Duration
}
func Chat(shared shared.Shared) Model {
m := Model{
Shared: shared,
state: idle,
conversation: &api.Conversation{},
persistence: true,
stopSignal: make(chan struct{}),
replyChan: make(chan api.Message),
chatReplyChunks: make(chan api.Chunk),
wrap: true,
selectedMessage: -1,
content: viewport.New(0, 0),
input: textarea.New(),
spinner: spinner.New(spinner.WithSpinner(
spinner.Spinner{
Frames: []string{
". ",
".. ",
"...",
".. ",
". ",
" ",
},
FPS: time.Second / 3,
},
)),
replyCursor: cursor.New(),
}
m.replyCursor.SetChar(" ")
m.replyCursor.Focus()
system := shared.Ctx.DefaultSystemPrompt()
agent := shared.Ctx.GetAgent(shared.Ctx.Config.Defaults.Agent)
if agent != nil && agent.SystemPrompt != "" {
system = agent.SystemPrompt
}
if system != "" {
m.messages = api.ApplySystemPrompt(m.messages, system, false)
}
m.input.Focus()
m.input.MaxHeight = 0
m.input.CharLimit = 0
m.input.ShowLineNumbers = false
m.input.Placeholder = "Enter a message"
m.input.FocusedStyle.CursorLine = lipgloss.NewStyle()
m.input.FocusedStyle.Base = inputFocusedStyle
m.input.BlurredStyle.Base = inputBlurredStyle
return m
}
func (m Model) Init() tea.Cmd {
return tea.Batch(
m.waitForResponseChunk(),
)
}

View File

@ -1,308 +0,0 @@
package chat
import (
"context"
"errors"
"fmt"
"time"
"git.mlow.ca/mlow/lmcli/pkg/agents"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
tea "github.com/charmbracelet/bubbletea"
)
func (m *Model) setMessage(i int, msg api.Message) {
if i >= len(m.messages) {
panic("i out of range")
}
m.messages[i] = msg
m.messageCache[i] = m.renderMessage(i)
}
func (m *Model) addMessage(msg api.Message) {
m.messages = append(m.messages, msg)
m.messageCache = append(m.messageCache, m.renderMessage(len(m.messages)-1))
}
func (m *Model) setMessageContents(i int, content string) {
if i >= len(m.messages) {
panic("i out of range")
}
m.messages[i].Content = content
m.messageCache[i] = m.renderMessage(i)
}
func (m *Model) rebuildMessageCache() {
m.messageCache = make([]string, len(m.messages))
for i := range m.messages {
m.messageCache[i] = m.renderMessage(i)
}
}
func (m *Model) updateContent() {
atBottom := m.content.AtBottom()
m.content.SetContent(m.conversationMessagesView())
if atBottom {
// if we were at bottom before the update, scroll with the output
m.content.GotoBottom()
}
}
func (m *Model) loadConversation(shortname string) tea.Cmd {
return func() tea.Msg {
if shortname == "" {
return nil
}
c, err := m.Shared.Ctx.Store.ConversationByShortName(shortname)
if err != nil {
return shared.MsgError(fmt.Errorf("Could not lookup conversation: %v", err))
}
if c.ID == 0 {
return shared.MsgError(fmt.Errorf("Conversation not found: %s", shortname))
}
rootMessages, err := m.Shared.Ctx.Store.RootMessages(c.ID)
if err != nil {
return shared.MsgError(fmt.Errorf("Could not load conversation root messages: %v\n", err))
}
return msgConversationLoaded{c, rootMessages}
}
}
func (m *Model) loadConversationMessages() tea.Cmd {
return func() tea.Msg {
messages, err := m.Shared.Ctx.Store.PathToLeaf(m.conversation.SelectedRoot)
if err != nil {
return shared.MsgError(fmt.Errorf("Could not load conversation messages: %v\n", err))
}
return msgMessagesLoaded(messages)
}
}
func (m *Model) generateConversationTitle() tea.Cmd {
return func() tea.Msg {
title, err := cmdutil.GenerateTitle(m.Shared.Ctx, m.messages)
if err != nil {
return shared.MsgError(err)
}
return msgConversationTitleGenerated(title)
}
}
func (m *Model) updateConversationTitle(conversation *api.Conversation) tea.Cmd {
return func() tea.Msg {
err := m.Shared.Ctx.Store.UpdateConversation(conversation)
if err != nil {
return shared.WrapError(err)
}
return nil
}
}
// Clones the given message (and its descendents). If selected is true, updates
// either its parent's SelectedReply or its conversation's SelectedRoot to
// point to the new clone
func (m *Model) cloneMessage(message api.Message, selected bool) tea.Cmd {
return func() tea.Msg {
msg, _, err := m.Ctx.Store.CloneBranch(message)
if err != nil {
return shared.WrapError(fmt.Errorf("Could not clone message: %v", err))
}
if selected {
if msg.Parent == nil {
msg.Conversation.SelectedRoot = msg
err = m.Shared.Ctx.Store.UpdateConversation(msg.Conversation)
} else {
msg.Parent.SelectedReply = msg
err = m.Shared.Ctx.Store.UpdateMessage(msg.Parent)
}
if err != nil {
return shared.WrapError(fmt.Errorf("Could not update selected message: %v", err))
}
}
return msgMessageCloned(msg)
}
}
func (m *Model) updateMessageContent(message *api.Message) tea.Cmd {
return func() tea.Msg {
err := m.Shared.Ctx.Store.UpdateMessage(message)
if err != nil {
return shared.WrapError(fmt.Errorf("Could not update message: %v", err))
}
return msgMessageUpdated(message)
}
}
func cycleSelectedMessage(selected *api.Message, choices []api.Message, dir MessageCycleDirection) (*api.Message, error) {
currentIndex := -1
for i, reply := range choices {
if reply.ID == selected.ID {
currentIndex = i
break
}
}
if currentIndex < 0 {
// this should probably be an assert
return nil, fmt.Errorf("Selected message %d not found in choices, this is a bug", selected.ID)
}
var next int
if dir == CyclePrev {
// Wrap around to the last reply if at the beginning
next = (currentIndex - 1 + len(choices)) % len(choices)
} else {
// Wrap around to the first reply if at the end
next = (currentIndex + 1) % len(choices)
}
return &choices[next], nil
}
func (m *Model) cycleSelectedRoot(conv *api.Conversation, dir MessageCycleDirection) tea.Cmd {
if len(m.rootMessages) < 2 {
return nil
}
return func() tea.Msg {
nextRoot, err := cycleSelectedMessage(conv.SelectedRoot, m.rootMessages, dir)
if err != nil {
return shared.WrapError(err)
}
conv.SelectedRoot = nextRoot
err = m.Shared.Ctx.Store.UpdateConversation(conv)
if err != nil {
return shared.WrapError(fmt.Errorf("Could not update conversation SelectedRoot: %v", err))
}
return msgSelectedRootCycled(nextRoot)
}
}
func (m *Model) cycleSelectedReply(message *api.Message, dir MessageCycleDirection) tea.Cmd {
if len(message.Replies) < 2 {
return nil
}
return func() tea.Msg {
nextReply, err := cycleSelectedMessage(message.SelectedReply, message.Replies, dir)
if err != nil {
return shared.WrapError(err)
}
message.SelectedReply = nextReply
err = m.Shared.Ctx.Store.UpdateMessage(message)
if err != nil {
return shared.WrapError(fmt.Errorf("Could not update message SelectedReply: %v", err))
}
return msgSelectedReplyCycled(nextReply)
}
}
func (m *Model) persistConversation() tea.Cmd {
conversation := m.conversation
messages := m.messages
var err error
if conversation.ID == 0 {
return func() tea.Msg {
// Start a new conversation with all messages so far
conversation, messages, err = m.Shared.Ctx.Store.StartConversation(messages...)
if err != nil {
return shared.MsgError(fmt.Errorf("Could not start new conversation: %v", err))
}
return msgConversationPersisted{true, conversation, messages}
}
}
return func() tea.Msg {
// else, we'll handle updating an existing conversation's messages
for i := range messages {
if messages[i].ID > 0 {
// message has an ID, update it
err := m.Shared.Ctx.Store.UpdateMessage(&messages[i])
if err != nil {
return shared.MsgError(err)
}
} else if i > 0 {
// messages is new, so add it as a reply to previous message
saved, err := m.Shared.Ctx.Store.Reply(&messages[i-1], messages[i])
if err != nil {
return shared.MsgError(err)
}
messages[i] = saved[0]
} else {
// message has no id and no previous messages to add it to
// this shouldn't happen?
return fmt.Errorf("Error: no messages to reply to")
}
}
return msgConversationPersisted{false, conversation, messages}
}
}
func (m *Model) executeToolCalls(toolCalls []api.ToolCall) tea.Cmd {
return func() tea.Msg {
agent := m.Shared.Ctx.GetAgent(m.Shared.Ctx.Config.Defaults.Agent)
if agent == nil {
return shared.MsgError(fmt.Errorf("Attempted to execute tool calls with no agent configured"))
}
results, err := agents.ExecuteToolCalls(toolCalls, agent.Toolbox)
if err != nil {
return shared.MsgError(err)
}
return msgToolResults(results)
}
}
func (m *Model) promptLLM() tea.Cmd {
m.state = pendingResponse
m.replyCursor.Blink = false
m.startTime = time.Now()
m.elapsed = 0
m.tokenCount = 0
return func() tea.Msg {
model, provider, err := m.Shared.Ctx.GetModelProvider(*m.Shared.Ctx.Config.Defaults.Model)
if err != nil {
return shared.MsgError(err)
}
params := api.RequestParameters{
Model: model,
MaxTokens: *m.Shared.Ctx.Config.Defaults.MaxTokens,
Temperature: *m.Shared.Ctx.Config.Defaults.Temperature,
}
agent := m.Shared.Ctx.GetAgent(m.Shared.Ctx.Config.Defaults.Agent)
if agent != nil {
params.Toolbox = agent.Toolbox
}
ctx, cancel := context.WithCancel(context.Background())
go func() {
select {
case <-m.stopSignal:
cancel()
}
}()
resp, err := provider.CreateChatCompletionStream(
ctx, params, m.messages, m.chatReplyChunks,
)
if errors.Is(err, context.Canceled) {
return msgChatResponseCanceled(struct{}{})
}
if err != nil {
return msgChatResponseError(err)
}
return msgChatResponse(resp)
}
}

View File

@ -1,180 +0,0 @@
package chat
import (
"fmt"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
tea "github.com/charmbracelet/bubbletea"
)
type MessageCycleDirection int
const (
CycleNext MessageCycleDirection = 1
CyclePrev MessageCycleDirection = -1
)
func (m *Model) HandleInput(msg tea.KeyMsg) (bool, tea.Cmd) {
switch m.focus {
case focusInput:
consumed, cmd := m.handleInputKey(msg)
if consumed {
return true, cmd
}
case focusMessages:
consumed, cmd := m.handleMessagesKey(msg)
if consumed {
return true, cmd
}
}
switch msg.String() {
case "esc":
if m.state == pendingResponse {
m.stopSignal <- struct{}{}
return true, nil
}
return true, func() tea.Msg {
return shared.MsgViewChange(shared.StateConversations)
}
case "ctrl+c":
if m.state == pendingResponse {
m.stopSignal <- struct{}{}
return true, nil
}
case "ctrl+p":
m.persistence = !m.persistence
return true, nil
case "ctrl+t":
m.showToolResults = !m.showToolResults
m.rebuildMessageCache()
m.updateContent()
return true, nil
case "ctrl+w":
m.wrap = !m.wrap
m.rebuildMessageCache()
m.updateContent()
return true, nil
}
return false, nil
}
// handleMessagesKey handles input when the messages pane is focused
func (m *Model) handleMessagesKey(msg tea.KeyMsg) (bool, tea.Cmd) {
switch msg.String() {
case "tab", "enter":
m.focus = focusInput
m.updateContent()
m.input.Focus()
return true, nil
case "e":
if m.selectedMessage < len(m.messages) {
m.editorTarget = selectedMessage
return true, tuiutil.OpenTempfileEditor(
"message.*.md",
m.messages[m.selectedMessage].Content,
"# Edit the message below\n",
)
}
return false, nil
case "ctrl+k":
if m.selectedMessage > 0 && len(m.messages) == len(m.messageOffsets) {
m.selectedMessage--
m.updateContent()
offset := m.messageOffsets[m.selectedMessage]
tuiutil.ScrollIntoView(&m.content, offset, m.content.Height/2)
}
return true, nil
case "ctrl+j":
if m.selectedMessage < len(m.messages)-1 && len(m.messages) == len(m.messageOffsets) {
m.selectedMessage++
m.updateContent()
offset := m.messageOffsets[m.selectedMessage]
tuiutil.ScrollIntoView(&m.content, offset, m.content.Height/2)
}
return true, nil
case "ctrl+h", "ctrl+l":
dir := CyclePrev
if msg.String() == "ctrl+l" {
dir = CycleNext
}
var cmd tea.Cmd
if m.selectedMessage == 0 {
cmd = m.cycleSelectedRoot(m.conversation, dir)
} else if m.selectedMessage > 0 {
cmd = m.cycleSelectedReply(&m.messages[m.selectedMessage-1], dir)
}
return cmd != nil, cmd
case "ctrl+r":
// resubmit the conversation with all messages up until and including the selected message
if m.state == idle && m.selectedMessage < len(m.messages) {
m.messages = m.messages[:m.selectedMessage+1]
m.messageCache = m.messageCache[:m.selectedMessage+1]
cmd := m.promptLLM()
m.updateContent()
m.content.GotoBottom()
return true, cmd
}
}
return false, nil
}
// handleInputKey handles input when the input textarea is focused
func (m *Model) handleInputKey(msg tea.KeyMsg) (bool, tea.Cmd) {
switch msg.String() {
case "esc":
m.focus = focusMessages
if len(m.messages) > 0 {
if m.selectedMessage < 0 || m.selectedMessage >= len(m.messages) {
m.selectedMessage = len(m.messages) - 1
}
offset := m.messageOffsets[m.selectedMessage]
tuiutil.ScrollIntoView(&m.content, offset, m.content.Height/2)
}
m.updateContent()
m.input.Blur()
return true, nil
case "ctrl+s":
// TODO: call a "handleSend" function which returns a tea.Cmd
if m.state != idle {
return false, nil
}
input := strings.TrimSpace(m.input.Value())
if input == "" {
return true, nil
}
if len(m.messages) > 0 && m.messages[len(m.messages)-1].Role == api.MessageRoleUser {
return true, shared.WrapError(fmt.Errorf("Can't reply to a user message"))
}
m.addMessage(api.Message{
Role: api.MessageRoleUser,
Content: input,
})
m.input.SetValue("")
var cmds []tea.Cmd
if m.persistence {
cmds = append(cmds, m.persistConversation())
}
cmds = append(cmds, m.promptLLM())
m.updateContent()
m.content.GotoBottom()
return true, tea.Batch(cmds...)
case "ctrl+e":
cmd := tuiutil.OpenTempfileEditor("message.*.md", m.input.Value(), "# Edit your input below\n")
m.editorTarget = input
return true, cmd
}
return false, nil
}

View File

@ -1,268 +0,0 @@
package chat
import (
"strings"
"time"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
"github.com/charmbracelet/bubbles/cursor"
tea "github.com/charmbracelet/bubbletea"
)
func (m *Model) HandleResize(width, height int) {
m.Width, m.Height = width, height
m.content.Width = width
m.input.SetWidth(width - m.input.FocusedStyle.Base.GetHorizontalFrameSize())
if len(m.messages) > 0 {
m.rebuildMessageCache()
m.updateContent()
}
}
func (m *Model) waitForResponseChunk() tea.Cmd {
return func() tea.Msg {
return msgChatResponseChunk(<-m.chatReplyChunks)
}
}
func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
var cmds []tea.Cmd
switch msg := msg.(type) {
case tea.WindowSizeMsg:
m.HandleResize(msg.Width, msg.Height)
case shared.MsgViewEnter:
// wake up spinners and cursors
cmds = append(cmds, cursor.Blink, m.spinner.Tick)
if m.Shared.Values.ConvShortname != "" {
// (re)load conversation contents
cmds = append(cmds, m.loadConversation(m.Shared.Values.ConvShortname))
if m.conversation.ShortName.String != m.Shared.Values.ConvShortname {
// clear existing messages if we're loading a new conversation
m.messages = []api.Message{}
m.selectedMessage = 0
}
}
m.rebuildMessageCache()
m.updateContent()
case tuiutil.MsgTempfileEditorClosed:
contents := string(msg)
switch m.editorTarget {
case input:
m.input.SetValue(contents)
case selectedMessage:
toEdit := m.messages[m.selectedMessage]
if toEdit.Content != contents {
toEdit.Content = contents
m.setMessage(m.selectedMessage, toEdit)
if m.persistence && toEdit.ID > 0 {
// create clone of message with its new contents
cmds = append(cmds, m.cloneMessage(toEdit, true))
}
}
}
case msgConversationLoaded:
m.conversation = msg.conversation
m.rootMessages = msg.rootMessages
m.selectedMessage = -1
if len(m.rootMessages) > 0 {
cmds = append(cmds, m.loadConversationMessages())
}
case msgMessagesLoaded:
m.messages = msg
if m.selectedMessage == -1 {
m.selectedMessage = len(msg) - 1
} else {
m.selectedMessage = min(m.selectedMessage, len(m.messages))
}
m.rebuildMessageCache()
m.updateContent()
case msgChatResponseChunk:
cmds = append(cmds, m.waitForResponseChunk()) // wait for the next chunk
if msg.Content == "" {
break
}
last := len(m.messages) - 1
if last >= 0 && m.messages[last].Role.IsAssistant() {
// append chunk to existing message
m.setMessageContents(last, m.messages[last].Content+msg.Content)
} else {
// use chunk in a new message
m.addMessage(api.Message{
Role: api.MessageRoleAssistant,
Content: msg.Content,
})
}
m.updateContent()
// show cursor and reset blink interval (simulate typing)
m.replyCursor.Blink = false
cmds = append(cmds, m.replyCursor.BlinkCmd())
m.tokenCount += msg.TokenCount
m.elapsed = time.Now().Sub(m.startTime)
case msgChatResponse:
m.state = idle
reply := (*api.Message)(msg)
reply.Content = strings.TrimSpace(reply.Content)
last := len(m.messages) - 1
if last < 0 {
panic("Unexpected empty messages handling msgAssistantReply")
}
if m.messages[last].Role.IsAssistant() {
// TODO: handle continuations gracefully - some models support them well, others fail horribly.
m.setMessage(last, *reply)
} else {
m.addMessage(*reply)
}
switch reply.Role {
case api.MessageRoleToolCall:
// TODO: user confirmation before execution
// m.state = waitingForConfirmation
cmds = append(cmds, m.executeToolCalls(reply.ToolCalls))
}
if m.persistence {
cmds = append(cmds, m.persistConversation())
}
if m.conversation.Title == "" {
cmds = append(cmds, m.generateConversationTitle())
}
m.updateContent()
case msgChatResponseCanceled:
m.state = idle
m.updateContent()
case msgChatResponseError:
m.state = idle
m.Shared.Err = error(msg)
m.updateContent()
case msgToolResults:
last := len(m.messages) - 1
if last < 0 {
panic("Unexpected empty messages handling msgAssistantReply")
}
if m.messages[last].Role != api.MessageRoleToolCall {
panic("Previous message not a tool call, unexpected")
}
m.addMessage(api.Message{
Role: api.MessageRoleToolResult,
ToolResults: api.ToolResults(msg),
})
if m.persistence {
cmds = append(cmds, m.persistConversation())
}
m.updateContent()
case msgConversationTitleGenerated:
title := string(msg)
m.conversation.Title = title
if m.persistence {
cmds = append(cmds, m.updateConversationTitle(m.conversation))
}
case cursor.BlinkMsg:
if m.state == pendingResponse {
// ensure we show the updated "wait for response" cursor blink state
last := len(m.messages)-1
m.messageCache[last] = m.renderMessage(last)
m.updateContent()
}
case msgConversationPersisted:
m.conversation = msg.conversation
m.messages = msg.messages
if msg.isNew {
m.rootMessages = []api.Message{m.messages[0]}
}
m.rebuildMessageCache()
m.updateContent()
case msgMessageCloned:
if msg.Parent == nil {
m.conversation = msg.Conversation
m.rootMessages = append(m.rootMessages, *msg)
}
cmds = append(cmds, m.loadConversationMessages())
case msgSelectedRootCycled, msgSelectedReplyCycled, msgMessageUpdated:
cmds = append(cmds, m.loadConversationMessages())
}
var cmd tea.Cmd
m.spinner, cmd = m.spinner.Update(msg)
if cmd != nil {
cmds = append(cmds, cmd)
}
m.replyCursor, cmd = m.replyCursor.Update(msg)
if cmd != nil {
cmds = append(cmds, cmd)
}
prevInputLineCnt := m.input.LineCount()
inputCaptured := false
m.input, cmd = m.input.Update(msg)
if cmd != nil {
inputCaptured = true
cmds = append(cmds, cmd)
}
if !inputCaptured {
m.content, cmd = m.content.Update(msg)
if cmd != nil {
cmds = append(cmds, cmd)
}
}
// update views once window dimensions are known
if m.Width > 0 {
m.Header = m.headerView()
m.Footer = m.footerView()
m.Error = tuiutil.ErrorBanner(m.Err, m.Width)
fixedHeight := tuiutil.Height(m.Header) + tuiutil.Height(m.Error) + tuiutil.Height(m.Footer)
// calculate clamped input height to accomodate input text
// minimum 4 lines, maximum half of content area
newHeight := max(4, min((m.Height-fixedHeight-1)/2, m.input.LineCount()))
m.input.SetHeight(newHeight)
m.Input = m.input.View()
// remaining height towards content
m.content.Height = m.Height - fixedHeight - tuiutil.Height(m.Input)
m.Content = m.content.View()
}
// this is a pretty nasty hack to ensure the input area viewport doesn't
// scroll below its content, which can happen when the input viewport
// height has grown, or previously entered lines have been deleted
if prevInputLineCnt != m.input.LineCount() {
// dist is the distance we'd need to scroll up from the current cursor
// position to position the last input line at the bottom of the
// viewport. if negative, we're already scrolled above the bottom
dist := m.input.Line() - (m.input.LineCount() - m.input.Height())
if dist > 0 {
for i := 0; i < dist; i++ {
// move cursor up until content reaches the bottom of the viewport
m.input.CursorUp()
}
m.input, _ = m.input.Update(nil)
for i := 0; i < dist; i++ {
// move cursor back down to its previous position
m.input.CursorDown()
}
m.input, _ = m.input.Update(nil)
}
}
return m, tea.Batch(cmds...)
}

View File

@ -1,321 +0,0 @@
package chat
import (
"encoding/json"
"fmt"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/tui/styles"
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
"github.com/charmbracelet/lipgloss"
"github.com/muesli/reflow/wordwrap"
"github.com/muesli/reflow/wrap"
"gopkg.in/yaml.v3"
)
// styles
var (
messageHeadingStyle = lipgloss.NewStyle().
MarginTop(1).
MarginBottom(1).
PaddingLeft(1).
Bold(true)
userStyle = lipgloss.NewStyle().Faint(true).Foreground(lipgloss.Color("10"))
assistantStyle = lipgloss.NewStyle().Faint(true).Foreground(lipgloss.Color("12"))
messageStyle = lipgloss.NewStyle().
PaddingLeft(2).
PaddingRight(2)
inputFocusedStyle = lipgloss.NewStyle().
Border(lipgloss.RoundedBorder(), true, true, true, false)
inputBlurredStyle = lipgloss.NewStyle().
Faint(true).
Border(lipgloss.RoundedBorder(), true, true, true, false)
footerStyle = lipgloss.NewStyle()
)
func (m Model) View() string {
if m.Width == 0 {
return ""
}
sections := make([]string, 0, 6)
if m.Header != "" {
sections = append(sections, m.Header)
}
sections = append(sections, m.Content)
if m.Error != "" {
sections = append(sections, m.Error)
}
sections = append(sections, m.Input)
if m.Footer != "" {
sections = append(sections, m.Footer)
}
return lipgloss.JoinVertical(lipgloss.Left, sections...)
}
func (m *Model) renderMessageHeading(i int, message *api.Message) string {
icon := ""
friendly := message.Role.FriendlyRole()
style := lipgloss.NewStyle().Faint(true).Bold(true)
switch message.Role {
case api.MessageRoleSystem:
icon = "⚙️"
case api.MessageRoleUser:
style = userStyle
case api.MessageRoleAssistant:
style = assistantStyle
case api.MessageRoleToolCall:
style = assistantStyle
friendly = api.MessageRoleAssistant.FriendlyRole()
case api.MessageRoleToolResult:
icon = "🔧"
}
user := style.Render(icon + friendly)
var prefix string
var suffix string
faint := lipgloss.NewStyle().Faint(true)
if i == 0 && len(m.rootMessages) > 1 && m.conversation.SelectedRootID != nil {
selectedRootIndex := 0
for j, reply := range m.rootMessages {
if reply.ID == *m.conversation.SelectedRootID {
selectedRootIndex = j
break
}
}
suffix += faint.Render(fmt.Sprintf(" <%d/%d>", selectedRootIndex+1, len(m.rootMessages)))
}
if i > 0 && len(m.messages[i-1].Replies) > 1 {
// Find the selected reply index
selectedReplyIndex := 0
for j, reply := range m.messages[i-1].Replies {
if reply.ID == *m.messages[i-1].SelectedReplyID {
selectedReplyIndex = j
break
}
}
suffix += faint.Render(fmt.Sprintf(" <%d/%d>", selectedReplyIndex+1, len(m.messages[i-1].Replies)))
}
if m.focus == focusMessages {
if i == m.selectedMessage {
prefix = "> "
}
}
if message.ID == 0 {
suffix += faint.Render(" (not saved)")
}
return messageHeadingStyle.Render(prefix + user + suffix)
}
// renderMessages renders the message at the given index as it should be shown
// *at this moment* - we render differently depending on the current application
// state (window size, etc, etc).
func (m *Model) renderMessage(i int) string {
msg := &m.messages[i]
// Write message contents
sb := &strings.Builder{}
sb.Grow(len(msg.Content) * 2)
if msg.Content != "" {
err := m.Shared.Ctx.Chroma.Highlight(sb, msg.Content)
if err != nil {
sb.Reset()
sb.WriteString(msg.Content)
}
}
isLast := i == len(m.messages)-1
isAssistant := msg.Role == api.MessageRoleAssistant
if m.state == pendingResponse && isLast && isAssistant {
// Show the assistant's cursor
sb.WriteString(m.replyCursor.View())
}
// Write tool call info
var toolString string
switch msg.Role {
case api.MessageRoleToolCall:
bytes, err := yaml.Marshal(msg.ToolCalls)
if err != nil {
toolString = "Could not serialize ToolCalls"
} else {
toolString = "tool_calls:\n" + string(bytes)
}
case api.MessageRoleToolResult:
type renderedResult struct {
ToolName string `yaml:"tool"`
Result any `yaml:"result,omitempty"`
}
var toolResults []renderedResult
for _, result := range msg.ToolResults {
if m.showToolResults {
var jsonResult interface{}
err := json.Unmarshal([]byte(result.Result), &jsonResult)
if err != nil {
// If parsing as JSON fails, treat Result as a plain string
toolResults = append(toolResults, renderedResult{
ToolName: result.ToolName,
Result: result.Result,
})
} else {
// If parsing as JSON succeeds, marshal the parsed JSON into YAML
toolResults = append(toolResults, renderedResult{
ToolName: result.ToolName,
Result: &jsonResult,
})
}
} else {
// Only show the tool name when results are hidden
toolResults = append(toolResults, renderedResult{
ToolName: result.ToolName,
Result: "(hidden, press ctrl+t to view)",
})
}
}
bytes, err := yaml.Marshal(toolResults)
if err != nil {
toolString = "Could not serialize ToolResults"
} else {
toolString = "tool_results:\n" + string(bytes)
}
}
if toolString != "" {
toolString = strings.TrimRight(toolString, "\n")
if msg.Content != "" {
sb.WriteString("\n\n")
}
_ = m.Shared.Ctx.Chroma.HighlightLang(sb, toolString, "yaml")
}
content := strings.TrimRight(sb.String(), "\n")
if m.wrap {
wrapWidth := m.content.Width - messageStyle.GetHorizontalPadding()
// first we word-wrap text to slightly less than desired width (since
// wordwrap seems to have an off-by-1 issue), then hard wrap at
// desired with
content = wrap.String(wordwrap.String(content, wrapWidth-2), wrapWidth)
}
return messageStyle.Width(0).Render(content)
}
// render the conversation into a string
func (m *Model) conversationMessagesView() string {
sb := strings.Builder{}
m.messageOffsets = make([]int, len(m.messages))
lineCnt := 1
for i, message := range m.messages {
m.messageOffsets[i] = lineCnt
heading := m.renderMessageHeading(i, &message)
sb.WriteString(heading)
sb.WriteString("\n")
lineCnt += lipgloss.Height(heading)
rendered := m.messageCache[i]
sb.WriteString(rendered)
sb.WriteString("\n")
lineCnt += lipgloss.Height(rendered)
}
// Render a placeholder for the incoming assistant reply
if m.state == pendingResponse && m.messages[len(m.messages)-1].Role != api.MessageRoleAssistant {
heading := m.renderMessageHeading(-1, &api.Message{
Role: api.MessageRoleAssistant,
})
sb.WriteString(heading)
sb.WriteString("\n")
sb.WriteString(messageStyle.Width(0).Render(m.replyCursor.View()))
sb.WriteString("\n")
}
return sb.String()
}
func (m *Model) headerView() string {
titleStyle := lipgloss.NewStyle().Bold(true)
var title string
if m.conversation != nil && m.conversation.Title != "" {
title = m.conversation.Title
} else {
title = "Untitled"
}
title = tuiutil.TruncateToCellWidth(title, m.Width-styles.Header.GetHorizontalPadding(), "...")
header := titleStyle.Render(title)
return styles.Header.Width(m.Width).Render(header)
}
func (m *Model) footerView() string {
segmentStyle := lipgloss.NewStyle().PaddingLeft(1).PaddingRight(1).Faint(true)
segmentSeparator := "|"
savingStyle := segmentStyle.Copy().Bold(true)
saving := ""
if m.persistence {
saving = savingStyle.Foreground(lipgloss.Color("2")).Render("✅💾")
} else {
saving = savingStyle.Foreground(lipgloss.Color("1")).Render("❌💾")
}
var status string
switch m.state {
case pendingResponse:
status = "Press ctrl+c to cancel" + m.spinner.View()
default:
status = "Press ctrl+s to send"
}
leftSegments := []string{
saving,
segmentStyle.Render(status),
}
rightSegments := []string{}
if m.elapsed > 0 && m.tokenCount > 0 {
throughput := fmt.Sprintf("%.0f t/sec", float64(m.tokenCount)/m.elapsed.Seconds())
rightSegments = append(rightSegments, segmentStyle.Render(throughput))
}
model := fmt.Sprintf("Model: %s", *m.Shared.Ctx.Config.Defaults.Model)
rightSegments = append(rightSegments, segmentStyle.Render(model))
left := strings.Join(leftSegments, segmentSeparator)
right := strings.Join(rightSegments, segmentSeparator)
totalWidth := lipgloss.Width(left) + lipgloss.Width(right)
remaining := m.Width - totalWidth
var padding string
if remaining > 0 {
padding = strings.Repeat(" ", remaining)
}
footer := left + padding + right
if remaining < 0 {
footer = tuiutil.TruncateToCellWidth(footer, m.Width, "...")
}
return footerStyle.Width(m.Width).Render(footer)
}

View File

@ -1,342 +0,0 @@
package conversations
import (
"fmt"
"strings"
"time"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/tui/bubbles"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
"git.mlow.ca/mlow/lmcli/pkg/tui/styles"
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
"git.mlow.ca/mlow/lmcli/pkg/util"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
type loadedConversation struct {
conv api.Conversation
lastReply api.Message
}
type (
// sent when conversation list is loaded
msgConversationsLoaded ([]loadedConversation)
// sent when a conversation is selected
msgConversationSelected api.Conversation
// sent when a conversation is deleted
msgConversationDeleted struct{}
)
// Prompt payloads
type (
deleteConversationPayload api.Conversation
)
type Model struct {
shared.Shared
shared.Sections
conversations []loadedConversation
cursor int // index of the currently selected conversation
itemOffsets []int // keeps track of the viewport y offset of each rendered item
content viewport.Model
confirmPrompt bubbles.ConfirmPrompt
}
func Conversations(shared shared.Shared) Model {
m := Model{
Shared: shared,
content: viewport.New(0, 0),
}
return m
}
func (m *Model) HandleInput(msg tea.KeyMsg) (bool, tea.Cmd) {
if m.confirmPrompt.Focused() {
var cmd tea.Cmd
m.confirmPrompt, cmd = m.confirmPrompt.Update(msg)
if cmd != nil {
return true, cmd
}
}
switch msg.String() {
case "enter":
if len(m.conversations) > 0 && m.cursor < len(m.conversations) {
return true, func() tea.Msg {
return msgConversationSelected(m.conversations[m.cursor].conv)
}
}
case "j", "down":
if m.cursor < len(m.conversations)-1 {
m.cursor++
if m.cursor == len(m.conversations)-1 {
// if last conversation, simply scroll to the bottom
m.content.GotoBottom()
} else {
// this hack positions the *next* conversatoin slightly
// *off* the screen, ensuring the entire m.cursor is shown,
// even if its height may not be constant due to wrapping.
tuiutil.ScrollIntoView(&m.content, m.itemOffsets[m.cursor+1], -1)
}
m.content.SetContent(m.renderConversationList())
} else {
m.cursor = len(m.conversations) - 1
m.content.GotoBottom()
}
return true, nil
case "k", "up":
if m.cursor > 0 {
m.cursor--
if m.cursor == 0 {
m.content.GotoTop()
} else {
tuiutil.ScrollIntoView(&m.content, m.itemOffsets[m.cursor], 1)
}
m.content.SetContent(m.renderConversationList())
} else {
m.cursor = 0
m.content.GotoTop()
}
return true, nil
case "n":
// new conversation
case "d":
if !m.confirmPrompt.Focused() && len(m.conversations) > 0 && m.cursor < len(m.conversations) {
title := m.conversations[m.cursor].conv.Title
if title == "" {
title = "(untitled)"
}
m.confirmPrompt = bubbles.NewConfirmPrompt(
fmt.Sprintf("Delete '%s'?", title),
deleteConversationPayload(m.conversations[m.cursor].conv),
)
m.confirmPrompt.Style = lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("3"))
return true, nil
}
case "c":
// copy/clone conversation
case "r":
// show prompt to rename conversation
case "shift+r":
// show prompt to generate name for conversation
}
return false, nil
}
func (m Model) Init() tea.Cmd {
return nil
}
func (m *Model) HandleResize(width, height int) {
m.Width, m.Height = width, height
m.content.Width = width
}
func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
var cmds []tea.Cmd
switch msg := msg.(type) {
case shared.MsgViewEnter:
cmds = append(cmds, m.loadConversations())
m.content.SetContent(m.renderConversationList())
case tea.WindowSizeMsg:
m.HandleResize(msg.Width, msg.Height)
m.content.SetContent(m.renderConversationList())
case msgConversationsLoaded:
m.conversations = msg
m.cursor = max(0, min(len(m.conversations), m.cursor))
m.content.SetContent(m.renderConversationList())
case msgConversationSelected:
m.Values.ConvShortname = msg.ShortName.String
cmds = append(cmds, func() tea.Msg {
return shared.MsgViewChange(shared.StateChat)
})
case bubbles.MsgConfirmPromptAnswered:
m.confirmPrompt.Blur()
if msg.Value {
switch payload := msg.Payload.(type) {
case deleteConversationPayload:
cmds = append(cmds, m.deleteConversation(api.Conversation(payload)))
}
}
case msgConversationDeleted:
cmds = append(cmds, m.loadConversations())
}
var cmd tea.Cmd
m.content, cmd = m.content.Update(msg)
if cmd != nil {
cmds = append(cmds, cmd)
}
if m.Width > 0 {
wrap := lipgloss.NewStyle().Width(m.Width)
m.Header = m.headerView()
m.Footer = "" // TODO: "Press ? for help"
if m.confirmPrompt.Focused() {
m.Footer = wrap.Render(m.confirmPrompt.View())
}
m.Error = tuiutil.ErrorBanner(m.Err, m.Width)
fixedHeight := tuiutil.Height(m.Header) + tuiutil.Height(m.Error) + tuiutil.Height(m.Footer)
m.content.Height = m.Height - fixedHeight
m.Content = m.content.View()
}
return m, tea.Batch(cmds...)
}
func (m *Model) loadConversations() tea.Cmd {
return func() tea.Msg {
messages, err := m.Ctx.Store.LatestConversationMessages()
if err != nil {
return shared.MsgError(fmt.Errorf("Could not load conversations: %v", err))
}
loaded := make([]loadedConversation, len(messages))
for i, m := range messages {
loaded[i].lastReply = m
loaded[i].conv = *m.Conversation
}
return msgConversationsLoaded(loaded)
}
}
func (m *Model) deleteConversation(conv api.Conversation) tea.Cmd {
return func() tea.Msg {
err := m.Ctx.Store.DeleteConversation(&conv)
if err != nil {
return shared.MsgError(fmt.Errorf("Could not delete conversation: %v", err))
}
return msgConversationDeleted{}
}
}
func (m Model) View() string {
if m.Width == 0 {
return ""
}
sections := make([]string, 0, 6)
if m.Header != "" {
sections = append(sections, m.Header)
}
sections = append(sections, m.Content)
if m.Error != "" {
sections = append(sections, m.Error)
}
if m.Footer != "" {
sections = append(sections, m.Footer)
}
return lipgloss.JoinVertical(lipgloss.Left, sections...)
}
func (m *Model) headerView() string {
titleStyle := lipgloss.NewStyle().Bold(true)
header := titleStyle.Render("Conversations")
return styles.Header.Width(m.Width).Render(header)
}
func (m *Model) renderConversationList() string {
now := time.Now()
midnight := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
monthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
dayOfWeek := int(now.Weekday())
categories := []struct {
name string
cutoff time.Duration
}{
{"Today", now.Sub(midnight)},
{"Yesterday", now.Sub(midnight.AddDate(0, 0, -1))},
{"This week", now.Sub(midnight.AddDate(0, 0, -dayOfWeek))},
{"Last week", now.Sub(midnight.AddDate(0, 0, -(dayOfWeek + 7)))},
{"This month", now.Sub(monthStart)},
{"Last month", now.Sub(monthStart.AddDate(0, -1, 0))},
{"2 Months ago", now.Sub(monthStart.AddDate(0, -2, 0))},
{"3 Months ago", now.Sub(monthStart.AddDate(0, -3, 0))},
{"4 Months ago", now.Sub(monthStart.AddDate(0, -4, 0))},
{"5 Months ago", now.Sub(monthStart.AddDate(0, -5, 0))},
{"6 Months ago", now.Sub(monthStart.AddDate(0, -6, 0))},
{"Older", now.Sub(time.Time{})},
}
categoryStyle := lipgloss.NewStyle().
MarginBottom(1).
Foreground(lipgloss.Color("12")).
PaddingLeft(1).
Bold(true)
itemStyle := lipgloss.NewStyle().
MarginBottom(1)
ageStyle := lipgloss.NewStyle().Faint(true).SetString()
titleStyle := lipgloss.NewStyle().Bold(true)
untitledStyle := lipgloss.NewStyle().Faint(true).Italic(true)
selectedStyle := lipgloss.NewStyle().Faint(true).Foreground(lipgloss.Color("6"))
var (
currentOffset int
currentCategory string
sb strings.Builder
)
m.itemOffsets = make([]int, len(m.conversations))
sb.WriteRune('\n')
currentOffset += 1
for i, c := range m.conversations {
lastReplyAge := now.Sub(c.lastReply.CreatedAt)
var category string
for _, g := range categories {
if lastReplyAge < g.cutoff {
category = g.name
break
}
}
// print the category
if category != currentCategory {
currentCategory = category
heading := categoryStyle.Render(currentCategory)
sb.WriteString(heading)
currentOffset += tuiutil.Height(heading)
sb.WriteRune('\n')
}
tStyle := titleStyle.Copy()
if c.conv.Title == "" {
tStyle = tStyle.Inherit(untitledStyle).SetString("(untitled)")
}
if i == m.cursor {
tStyle = tStyle.Inherit(selectedStyle)
}
title := tStyle.Width(m.Width - 3).PaddingLeft(2).Render(c.conv.Title)
if i == m.cursor {
title = ">" + title[1:]
}
m.itemOffsets[i] = currentOffset
item := itemStyle.Render(fmt.Sprintf(
"%s\n %s",
title,
ageStyle.Render(util.HumanTimeElapsedSince(lastReplyAge)),
))
sb.WriteString(item)
currentOffset += tuiutil.Height(item)
if i < len(m.conversations)-1 {
sb.WriteRune('\n')
}
}
return sb.String()
}

View File

@ -1,86 +0,0 @@
package tty
import (
"io"
"strings"
"github.com/alecthomas/chroma/v2"
"github.com/alecthomas/chroma/v2/formatters"
"github.com/alecthomas/chroma/v2/lexers"
"github.com/alecthomas/chroma/v2/styles"
)
type ChromaHighlighter struct {
lexer chroma.Lexer
formatter chroma.Formatter
style *chroma.Style
}
func NewChromaHighlighter(lang, format, style string) *ChromaHighlighter {
l := lexers.Get(lang)
if l == nil {
l = lexers.Fallback
}
l = chroma.Coalesce(l)
f := formatters.Get(format)
if f == nil {
f = formatters.Fallback
}
s := styles.Get(style)
if s == nil {
s = styles.Fallback
}
return &ChromaHighlighter{
lexer: l,
formatter: f,
style: s,
}
}
func (s *ChromaHighlighter) Highlight(w io.Writer, text string) error {
it, err := s.lexer.Tokenise(nil, text)
if err != nil {
return err
}
return s.formatter.Format(w, s.style, it)
}
func (s *ChromaHighlighter) HighlightS(text string) (string, error) {
it, err := s.lexer.Tokenise(nil, text)
if err != nil {
return "", err
}
sb := strings.Builder{}
sb.Grow(len(text) * 2)
s.formatter.Format(&sb, s.style, it)
return sb.String(), nil
}
func (s *ChromaHighlighter) HighlightLang(w io.Writer, text string, lang string) (error) {
l := lexers.Get(lang)
if l == nil {
l = lexers.Fallback
}
l = chroma.Coalesce(l)
old := s.lexer
s.lexer = l
err := s.Highlight(w, text)
s.lexer = old
return err
}
func (s *ChromaHighlighter) HighlightLangS(text string, lang string) (string, error) {
l := lexers.Get(lang)
if l == nil {
l = lexers.Fallback
}
l = chroma.Coalesce(l)
old := s.lexer
s.lexer = l
highlighted, err := s.HighlightS(text)
s.lexer = old
return highlighted, err
}