Compare commits

...

31 Commits

Author SHA1 Message Date
f05e2e30f7 Update TODO.md 2024-10-27 18:03:15 +00:00
ec21a02ec0 Tweaks/cleanups to conversation management in tui
- Pass around message/conversation values instead of pointers where it
makes more sense, and store values instead of pointers in the globally
(within the TUI) shared `App` (pointers provide no utility here).

- Split conversation persistence into separate conversation/message
  saving stages
2024-10-25 16:57:15 +00:00
07c96082e7 Add LastMessageAt field to conversation
Replaced `LatestConversationMessages` with `LoadConversationList`, which
utilizes `LastMessageAt` for much faster conversation loading in the
conversation listing TUI and `lmcli list` command.
2024-10-22 17:53:13 +00:00
0384c7cb66 Large refactor - it compiles!
This refactor splits out all conversation concerns into a new
`conversation` package. There is now a split between `conversation` and
`api`s representation of `Message`, the latter storing the minimum
information required for interaction with LLM providers. There is
necessary conversation between the two when making LLM calls.
2024-10-22 17:53:13 +00:00
2ea8a73eb5 Final(?) chat view footer fix 2024-10-22 17:53:13 +00:00
c9a7eee090 Update go.mod 2024-10-20 22:39:35 +00:00
ae1e85e166 Update Anthropic token counting logic 2024-10-20 02:38:25 +00:00
304820c919 tui: Show selected message indicator only when focused on messages 2024-10-03 16:47:28 +00:00
93c2fb3d1e tui: revamp footer (some more)
Simplified layout logic, reorganized elements
2024-10-03 16:47:28 +00:00
bb48bc9abd Display generation model in message header and other tweaks
Adjusted `ctrl+t` in chat view to toggle `showDetails` which toggles the
display of system messages, message metadata (generation model), and
tool call details

Modified message selection update logic to skip messages that aren't
shown
2024-09-30 20:26:09 +00:00
5d13c3e056 Add metadata json field to Message, store generation model/provider 2024-09-30 17:44:43 +00:00
327a128b2f Moved api.ChatCompletionProvider, api.Chunk to api/provider 2024-09-30 16:15:42 +00:00
a441866f2f Configure database logging to file 2024-09-30 16:10:51 +00:00
ce7b07ad95 tui: rename 2024-09-30 04:32:47 +00:00
2fed682969 tui: Chat view footer rewrite
Rewrote footer handling to better handle truncation, and use
`ActiveModel` to return the (stylized) active model
2024-09-26 18:32:33 +00:00
69cdc0a5aa tui: Add setting view with support for changing the current model 2024-09-26 18:32:22 +00:00
3ec2675632 tui: Error handling tweak
Moved errors to bottom of screen, fix infinite loop by typing errors
properly
2024-09-23 04:43:26 +00:00
172bfc57e1 Allow specifying --agent none to mean no agent 2024-09-23 03:04:43 +00:00
a46d211e10 Improve TUI system prompt handling
+ allow setting a default agent
2024-09-23 03:00:03 +00:00
676aa7b004 Refactor TUI rendering handling and general cleanup
Improves render handling by moving the responsibility of laying out the
whole UI from each view and into the main `tui` model. Our `ViewModel`
interface has now diverged from bubbletea's `Model` and introduces
individual `Header`, `Content`, and `Footer` methods for rendering those
UI elements.

Also moved away from using value receivers on our Update and View
functions (as is common across Bubbletea) to pointer receivers, which
cleaned up some of the weirder aspects of the code (e.g. before we
essentially had no choice but to do our rendering in `Update` in order
to calculate and update the final height of the main content's
`viewport`).
2024-09-23 02:49:08 +00:00
b7c89a4dd1 Update TODO.md 2024-09-21 20:13:47 +00:00
b8e3172ce0 Start new conversations from TUI 2024-09-21 02:47:03 +00:00
a1fdf3f7cd Deprecation fix 2024-09-21 02:46:51 +00:00
a488ec4fd8 Fixed message loading
Root messages weren't being loaded since the refactor, and there was
dead code
2024-09-21 02:32:54 +00:00
463ca9ef40 TUI view management and input handling cleanup 2024-09-16 16:18:18 +00:00
24b5cdbbf6 More monior TUI refactor/cleanup
`tui/tui.go` is no longer responsible for passing window resize updates
to all views, instead we request a new window size message to be sent at
the same time we enter the view, allowing the view to catch and handle
it.

Add `Initialized` to `tui/shared/View` model, now we only call
`Init` on a view before entering it for the first time, rather than
calling `Init` on all views when the application starts.

Renames file, small cleanups
2024-09-16 14:04:08 +00:00
7c0bfefc65 Update deps 2024-09-16 03:49:04 +00:00
443c8096d3 TUI refactor
- Clean up, improved startup logic, initial conversation load
- Moved converation/message business logic (mostly) into `model/tui`
2024-09-16 00:48:45 +00:00
1570988b98 Fix LatestConversationMessages preload
Load the conversation's selected root as well
2024-09-16 00:37:42 +00:00
434fc4672b Allow custom headers on OpenAI providers (to be added to more later) 2024-08-12 17:14:53 +00:00
fe838f400f Minor adjustment to seleted message style 2024-07-10 01:21:06 +00:00
45 changed files with 2421 additions and 1571 deletions

View File

@ -14,7 +14,7 @@
system prompt, rather than having them in the conversation messages) system prompt, rather than having them in the conversation messages)
- [ ] Agents may have some form of long term memory management (key-value? - [ ] Agents may have some form of long term memory management (key-value?
natural lang?). natural lang?).
- [ ] Sandboxed python, js interpreter (both useful for different reasons) - [ ] Sandboxed python, js interpreters (implemented with containers)
- [ ] Support for arbitrary external script tools - [ ] Support for arbitrary external script tools
- [ ] Search - RAG driven search of existing conversation "hey, remind me of - [ ] Search - RAG driven search of existing conversation "hey, remind me of
the conversation we had six months ago about X") the conversation we had six months ago about X")
@ -23,12 +23,15 @@
- [ ] Image input - [ ] Image input
- [ ] Image output (sixel support?) - [ ] Image output (sixel support?)
- [ ] Conversation exports to html/pdf/json - [ ] Conversation exports to html/pdf/json
- [ ] Store message generation model
- [ ] Hidden CoT
- [ ] Token accounting
## UI ## UI
- [x] Prettify/normalize tool_call and tool_result outputs so they can be - [x] Prettify/normalize tool_call and tool_result outputs so they can be
shown/optionally hidden in `lmcli view` and `lmcli chat` shown/optionally hidden in `lmcli view` and `lmcli chat`
- [x] Conversation deletion in conversations view
- [ ] User confirmation before calling (some?) tools - [ ] User confirmation before calling (some?) tools
- [ ] Conversation deletion in conversations view
- [ ] Message deletion, Ctrl+D to delete a message and attach its children to - [ ] Message deletion, Ctrl+D to delete a message and attach its children to
its parent, Ctrl+Shift+D to delete a message and its descendents its parent, Ctrl+Shift+D to delete a message and its descendents
- [ ] Show available key bindings and their action in any given view - [ ] Show available key bindings and their action in any given view

35
go.mod
View File

@ -3,40 +3,41 @@ module git.mlow.ca/mlow/lmcli
go 1.21 go 1.21
require ( require (
github.com/alecthomas/chroma/v2 v2.11.1 github.com/alecthomas/chroma/v2 v2.14.0
github.com/charmbracelet/bubbles v0.18.0 github.com/charmbracelet/bubbles v0.20.0
github.com/charmbracelet/bubbletea v0.25.0 github.com/charmbracelet/bubbletea v1.1.1
github.com/charmbracelet/lipgloss v0.10.0 github.com/charmbracelet/lipgloss v0.13.0
github.com/muesli/reflow v0.3.0 github.com/muesli/reflow v0.3.0
github.com/spf13/cobra v1.8.0 github.com/spf13/cobra v1.8.1
github.com/sqids/sqids-go v0.4.1 github.com/sqids/sqids-go v0.4.1
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/sqlite v1.5.4 gorm.io/driver/sqlite v1.5.6
gorm.io/gorm v1.25.5 gorm.io/gorm v1.25.12
) )
require ( require (
github.com/atotto/clipboard v0.1.4 // indirect github.com/atotto/clipboard v0.1.4 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81 // indirect github.com/charmbracelet/x/ansi v0.3.1 // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect github.com/charmbracelet/x/term v0.2.0 // indirect
github.com/dlclark/regexp2 v1.11.4 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect github.com/jinzhu/now v1.1.5 // indirect
github.com/kr/pretty v0.3.1 // indirect github.com/kr/pretty v0.3.1 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mattn/go-isatty v0.0.18 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.15 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/mattn/go-sqlite3 v1.14.18 // indirect github.com/mattn/go-sqlite3 v1.14.23 // indirect
github.com/muesli/ansi v0.0.0-20211018074035-2e021307bc4b // indirect github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.15.2 // indirect github.com/muesli/termenv v0.15.2 // indirect
github.com/rivo/uniseg v0.4.7 // indirect github.com/rivo/uniseg v0.4.7 // indirect
github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/pflag v1.0.5 // indirect
golang.org/x/sync v0.1.0 // indirect golang.org/x/sync v0.8.0 // indirect
golang.org/x/sys v0.14.0 // indirect golang.org/x/sys v0.25.0 // indirect
golang.org/x/term v0.6.0 // indirect golang.org/x/text v0.18.0 // indirect
golang.org/x/text v0.3.8 // indirect
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
) )

84
go.sum
View File

@ -1,25 +1,31 @@
github.com/alecthomas/assert/v2 v2.2.1 h1:XivOgYcduV98QCahG8T5XTezV5bylXe+lBxLG2K2ink= github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ=
github.com/alecthomas/assert/v2 v2.2.1/go.mod h1:pXcQ2Asjp247dahGEmsZ6ru0UVwnkhktn7S0bBDLxvQ= github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE=
github.com/alecthomas/chroma/v2 v2.11.1 h1:m9uUtgcdAwgfFNxuqj7AIG75jD2YmL61BBIJWtdzJPs= github.com/alecthomas/assert/v2 v2.7.0 h1:QtqSACNS3tF7oasA8CU6A6sXZSBDqnm7RfpLl9bZqbE=
github.com/alecthomas/chroma/v2 v2.11.1/go.mod h1:4TQu7gdfuPjSh76j78ietmqh9LiurGF0EpseFXdKMBw= github.com/alecthomas/assert/v2 v2.7.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
github.com/alecthomas/repr v0.2.0 h1:HAzS41CIzNW5syS8Mf9UwXhNH1J9aix/BvDRf1Ml2Yk= github.com/alecthomas/chroma/v2 v2.14.0 h1:R3+wzpnUArGcQz7fCETQBzO5n9IMNi13iIs46aU4V9E=
github.com/alecthomas/repr v0.2.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= github.com/alecthomas/chroma/v2 v2.14.0/go.mod h1:QolEbTfmUHIMVpBqxeDnNBj2uoeI4EbYP4i6n68SG4I=
github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc=
github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/charmbracelet/bubbles v0.18.0 h1:PYv1A036luoBGroX6VWjQIE9Syf2Wby2oOl/39KLfy0= github.com/charmbracelet/bubbles v0.20.0 h1:jSZu6qD8cRQ6k9OMfR1WlM+ruM8fkPWkHvQWD9LIutE=
github.com/charmbracelet/bubbles v0.18.0/go.mod h1:08qhZhtIwzgrtBjAcJnij1t1H0ZRjwHyGsy6AL11PSw= github.com/charmbracelet/bubbles v0.20.0/go.mod h1:39slydyswPy+uVOHZ5x/GjwVAFkCsV8IIVy+4MhzwwU=
github.com/charmbracelet/bubbletea v0.25.0 h1:bAfwk7jRz7FKFl9RzlIULPkStffg5k6pNt5dywy4TcM= github.com/charmbracelet/bubbletea v1.1.1 h1:KJ2/DnmpfqFtDNVTvYZ6zpPFL9iRCRr0qqKOCvppbPY=
github.com/charmbracelet/bubbletea v0.25.0/go.mod h1:EN3QDR1T5ZdWmdfDzYcqOCAps45+QIJbLOBxmVNWNNg= github.com/charmbracelet/bubbletea v1.1.1/go.mod h1:9Ogk0HrdbHolIKHdjfFpyXJmiCzGwy+FesYkZr7hYU4=
github.com/charmbracelet/lipgloss v0.10.0 h1:KWeXFSexGcfahHX+54URiZGkBFazf70JNMtwg/AFW3s= github.com/charmbracelet/lipgloss v0.13.0 h1:4X3PPeoWEDCMvzDvGmTajSyYPcZM4+y8sCA/SsA3cjw=
github.com/charmbracelet/lipgloss v0.10.0/go.mod h1:Wig9DSfvANsxqkRsqj6x87irdy123SR4dOXlKa91ciE= github.com/charmbracelet/lipgloss v0.13.0/go.mod h1:nw4zy0SBX/F/eAO1cWdcvy6qnkDUxr8Lw7dvFrAIbbY=
github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81 h1:q2hJAaP1k2wIvVRd/hEHD7lacgqrCPS+k8g1MndzfWY= github.com/charmbracelet/x/ansi v0.3.1 h1:CRO6lc/6HCx2/D6S/GZ87jDvRvk6GtPyFP+IljkNtqI=
github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81/go.mod h1:YynlIjWYF8myEu6sdkwKIvGQq+cOckRm6So2avqoYAk= github.com/charmbracelet/x/ansi v0.3.1/go.mod h1:dk73KoMTT5AX5BsX0KrqhsTqAnhZZoCBjs7dGWp4Ktw=
github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/charmbracelet/x/term v0.2.0 h1:cNB9Ot9q8I711MyZ7myUR5HFWL/lc3OpU8jZ4hwm0x0=
github.com/charmbracelet/x/term v0.2.0/go.mod h1:GVxgxAbjUrmpvIINHIQnJJKpMlHiZ4cktEQCN6GWyF0=
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
@ -34,17 +40,17 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-isatty v0.0.18 h1:DOKFKCQ7FNG2L1rbrmstDN4QVRdS89Nkh85u68Uwp98= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.18/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk= github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.18 h1:JL0eqdCOq6DJVNPSvArO/bIV9/P7fbGrV00LZHc+5aI= github.com/mattn/go-sqlite3 v1.14.23 h1:gbShiuAP1W5j9UOksQ06aiiqPMxYecovVGwmTxWtuw0=
github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mattn/go-sqlite3 v1.14.23/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/muesli/ansi v0.0.0-20211018074035-2e021307bc4b h1:1XF24mVaiu7u+CFywTdcDo2ie1pzzhwjt6RHqzpMU34= github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
github.com/muesli/ansi v0.0.0-20211018074035-2e021307bc4b/go.mod h1:fQuZ0gauxyBcmsdE3ZT4NasjaRdxmbCS0jRHsrWu3Ho= github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s= github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
@ -59,28 +65,26 @@ github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUc
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/sqids/sqids-go v0.4.1 h1:eQKYzmAZbLlRwHeHYPF35QhgxwZHLnlmVj9AkIj/rrw= github.com/sqids/sqids-go v0.4.1 h1:eQKYzmAZbLlRwHeHYPF35QhgxwZHLnlmVj9AkIj/rrw=
github.com/sqids/sqids-go v0.4.1/go.mod h1:EMwHuPQgSNFS0A49jESTfIQS+066XQTVhukrzEPScl8= github.com/sqids/sqids-go v0.4.1/go.mod h1:EMwHuPQgSNFS0A49jESTfIQS+066XQTVhukrzEPScl8=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34=
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw= golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224=
golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY=
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/sqlite v1.5.4 h1:IqXwXi8M/ZlPzH/947tn5uik3aYQslP9BVveoax0nV0= gorm.io/driver/sqlite v1.5.6 h1:fO/X46qn5NUEEOZtnjJRWRzZMe8nqJiQ9E+0hi+hKQE=
gorm.io/driver/sqlite v1.5.4/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4= gorm.io/driver/sqlite v1.5.6/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4=
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls= gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=

View File

@ -1,49 +1,126 @@
package api package api
import ( import (
"context" "encoding/json"
"fmt"
) )
type ReplyCallback func(Message) type MessageRole string
type Chunk struct { const (
Content string MessageRoleSystem MessageRole = "system"
TokenCount uint MessageRoleUser MessageRole = "user"
MessageRoleAssistant MessageRole = "assistant"
MessageRoleToolCall MessageRole = "tool_call"
MessageRoleToolResult MessageRole = "tool_result"
)
type Message struct {
Content string // TODO: support multi-part messages
Role MessageRole
ToolCalls []ToolCall
ToolResults []ToolResult
} }
type RequestParameters struct { type ToolSpec struct {
Model string Name string
Description string
MaxTokens int Parameters []ToolParameter
Temperature float32 Impl func(*ToolSpec, map[string]interface{}) (string, error)
TopP float32
Toolbox []ToolSpec
} }
type ChatCompletionProvider interface { type ToolParameter struct {
// CreateChatCompletion requests a response to the provided messages. Name string `json:"name"`
// Replies are appended to the given replies struct, and the Type string `json:"type"` // "string", "integer", "boolean"
// complete user-facing response is returned as a string. Required bool `json:"required"`
CreateChatCompletion( Description string `json:"description"`
ctx context.Context, Enum []string `json:"enum,omitempty"`
params RequestParameters,
messages []Message,
) (*Message, error)
// Like CreateChageCompletion, except the response is streamed via
// the output channel as it's received.
CreateChatCompletionStream(
ctx context.Context,
params RequestParameters,
messages []Message,
chunks chan<- Chunk,
) (*Message, error)
} }
func IsAssistantContinuation(messages []Message) bool { type ToolCall struct {
if len(messages) == 0 { ID string `json:"id" yaml:"-"`
Name string `json:"name" yaml:"tool"`
Parameters map[string]interface{} `json:"parameters" yaml:"parameters"`
}
type ToolResult struct {
ToolCallID string `json:"toolCallID" yaml:"-"`
ToolName string `json:"toolName,omitempty" yaml:"tool"`
Result string `json:"result,omitempty" yaml:"result"`
}
func NewMessageWithAssistant(content string) *Message {
return &Message{
Role: MessageRoleAssistant,
Content: content,
}
}
func NewMessageWithToolCalls(content string, toolCalls []ToolCall) *Message {
return &Message{
Role: MessageRoleToolCall,
Content: content,
ToolCalls: toolCalls,
}
}
func (m MessageRole) IsAssistant() bool {
switch m {
case MessageRoleAssistant, MessageRoleToolCall:
return true
}
return false return false
} }
return messages[len(messages)-1].Role == MessageRoleAssistant
func (m MessageRole) IsUser() bool {
switch m {
case MessageRoleUser, MessageRoleToolResult:
return true
}
return false
}
func (m MessageRole) IsSystem() bool {
switch m {
case MessageRoleSystem:
return true
}
return false
}
// FriendlyRole returns a human friendly signifier for the message's role.
func (m MessageRole) FriendlyRole() string {
switch m {
case MessageRoleUser:
return "You"
case MessageRoleSystem:
return "System"
case MessageRoleAssistant:
return "Assistant"
case MessageRoleToolCall:
return "Tool Call"
case MessageRoleToolResult:
return "Tool Result"
default:
return string(m)
}
}
// TODO: remove this
type CallResult struct {
Message string `json:"message"`
Result any `json:"result,omitempty"`
}
func (r CallResult) ToJson() (string, error) {
if r.Message == "" {
// When message not supplied, assume success
r.Message = "success"
}
jsonBytes, err := json.Marshal(r)
if err != nil {
return "", fmt.Errorf("Could not marshal CallResult to JSON: %v\n", err)
}
return string(jsonBytes), nil
} }

View File

@ -1,11 +0,0 @@
package api
import "database/sql"
type Conversation struct {
ID uint `gorm:"primaryKey"`
ShortName sql.NullString
Title string
SelectedRootID *uint
SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"`
}

View File

@ -1,72 +0,0 @@
package api
import (
"time"
)
type MessageRole string
const (
MessageRoleSystem MessageRole = "system"
MessageRoleUser MessageRole = "user"
MessageRoleAssistant MessageRole = "assistant"
MessageRoleToolCall MessageRole = "tool_call"
MessageRoleToolResult MessageRole = "tool_result"
)
type Message struct {
ID uint `gorm:"primaryKey"`
ConversationID *uint `gorm:"index"`
Conversation *Conversation `gorm:"foreignKey:ConversationID"`
Content string
Role MessageRole
CreatedAt time.Time
ToolCalls ToolCalls // a json array of tool calls (from the model)
ToolResults ToolResults // a json array of tool results
ParentID *uint
Parent *Message `gorm:"foreignKey:ParentID"`
Replies []Message `gorm:"foreignKey:ParentID"`
SelectedReplyID *uint
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
}
func ApplySystemPrompt(m []Message, system string, force bool) []Message {
if len(m) > 0 && m[0].Role == MessageRoleSystem {
if force {
m[0].Content = system
}
return m
} else {
return append([]Message{{
Role: MessageRoleSystem,
Content: system,
}}, m...)
}
}
func (m *MessageRole) IsAssistant() bool {
switch *m {
case MessageRoleAssistant, MessageRoleToolCall:
return true
}
return false
}
// FriendlyRole returns a human friendly signifier for the message's role.
func (m MessageRole) FriendlyRole() string {
switch m {
case MessageRoleUser:
return "You"
case MessageRoleSystem:
return "System"
case MessageRoleAssistant:
return "Assistant"
case MessageRoleToolCall:
return "Tool Call"
case MessageRoleToolResult:
return "Tool Result"
default:
return string(m)
}
}

View File

@ -1,98 +0,0 @@
package api
import (
"database/sql/driver"
"encoding/json"
"fmt"
)
type ToolSpec struct {
Name string
Description string
Parameters []ToolParameter
Impl func(*ToolSpec, map[string]interface{}) (string, error)
}
type ToolParameter struct {
Name string `json:"name"`
Type string `json:"type"` // "string", "integer", "boolean"
Required bool `json:"required"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
}
type ToolCall struct {
ID string `json:"id" yaml:"-"`
Name string `json:"name" yaml:"tool"`
Parameters map[string]interface{} `json:"parameters" yaml:"parameters"`
}
type ToolResult struct {
ToolCallID string `json:"toolCallID" yaml:"-"`
ToolName string `json:"toolName,omitempty" yaml:"tool"`
Result string `json:"result,omitempty" yaml:"result"`
}
type ToolCalls []ToolCall
func (tc *ToolCalls) Scan(value any) (err error) {
s := value.(string)
if value == nil || s == "" {
*tc = nil
return
}
err = json.Unmarshal([]byte(s), tc)
return
}
func (tc ToolCalls) Value() (driver.Value, error) {
if len(tc) == 0 {
return "", nil
}
jsonBytes, err := json.Marshal(tc)
if err != nil {
return "", fmt.Errorf("Could not marshal ToolCalls to JSON: %v\n", err)
}
return string(jsonBytes), nil
}
type ToolResults []ToolResult
func (tr *ToolResults) Scan(value any) (err error) {
s := value.(string)
if value == nil || s == "" {
*tr = nil
return
}
err = json.Unmarshal([]byte(s), tr)
return
}
func (tr ToolResults) Value() (driver.Value, error) {
if len(tr) == 0 {
return "", nil
}
jsonBytes, err := json.Marshal([]ToolResult(tr))
if err != nil {
return "", fmt.Errorf("Could not marshal ToolResults to JSON: %v\n", err)
}
return string(jsonBytes), nil
}
type CallResult struct {
Message string `json:"message"`
Result any `json:"result,omitempty"`
}
func (r CallResult) ToJson() (string, error) {
if r.Message == "" {
// When message not supplied, assume success
r.Message = "success"
}
jsonBytes, err := json.Marshal(r)
if err != nil {
return "", fmt.Errorf("Could not marshal CallResult to JSON: %v\n", err)
}
return string(jsonBytes), nil
}

View File

@ -6,6 +6,7 @@ import (
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/tui" "git.mlow.ca/mlow/lmcli/pkg/tui"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -19,17 +20,30 @@ func ChatCmd(ctx *lmcli.Context) *cobra.Command {
if err != nil { if err != nil {
return err return err
} }
shortname := ""
if len(args) == 1 { var opts []tui.LaunchOption
shortname = args[0]
} list, err := cmd.Flags().GetBool("list")
if shortname != ""{
_, err := cmdutil.LookupConversationE(ctx, shortname)
if err != nil { if err != nil {
return err return err
} }
if !list && len(args) == 1 {
shortname := args[0]
if shortname != ""{
conv, err := cmdutil.LookupConversationE(ctx, shortname)
if err != nil {
return err
} }
err = tui.Launch(ctx, shortname) opts = append(opts, tui.WithInitialConversation(conv))
}
}
if list {
opts = append(opts, tui.WithInitialView(shared.ViewConversations))
}
err = tui.Launch(ctx, opts...)
if err != nil { if err != nil {
return fmt.Errorf("Error fetching LLM response: %v", err) return fmt.Errorf("Error fetching LLM response: %v", err)
} }
@ -40,9 +54,13 @@ func ChatCmd(ctx *lmcli.Context) *cobra.Command {
if len(args) != 0 { if len(args) != 0 {
return nil, compMode return nil, compMode
} }
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode
}, },
} }
// -l, --list
cmd.Flags().BoolP("list", "l", false, "View/manage conversations")
applyGenerationFlags(ctx, cmd) applyGenerationFlags(ctx, cmd)
return cmd return cmd
} }

View File

@ -27,7 +27,7 @@ func CloneCmd(ctx *lmcli.Context) *cobra.Command {
return err return err
} }
clone, messageCnt, err := ctx.Store.CloneConversation(*toClone) clone, messageCnt, err := ctx.Conversations.CloneConversation(*toClone)
if err != nil { if err != nil {
return fmt.Errorf("Failed to clone conversation: %v", err) return fmt.Errorf("Failed to clone conversation: %v", err)
} }
@ -40,7 +40,7 @@ func CloneCmd(ctx *lmcli.Context) *cobra.Command {
if len(args) != 0 { if len(args) != 0 {
return nil, compMode return nil, compMode
} }
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode
}, },
} }
return cmd return cmd

View File

@ -83,7 +83,7 @@ func validateGenerationFlags(ctx *lmcli.Context, cmd *cobra.Command) error {
if err != nil { if err != nil {
return fmt.Errorf("Error parsing --agent: %w", err) return fmt.Errorf("Error parsing --agent: %w", err)
} }
if agent != "" && !slices.Contains(ctx.GetAgents(), agent) { if agent != "" && agent != "none" && !slices.Contains(ctx.GetAgents(), agent) {
return fmt.Errorf("Unknown agent: %s", agent) return fmt.Errorf("Unknown agent: %s", agent)
} }
return nil return nil

View File

@ -29,9 +29,9 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
} }
shortName := args[0] shortName := args[0]
conversation := cmdutil.LookupConversation(ctx, shortName) c := cmdutil.LookupConversation(ctx, shortName)
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot) messages, err := ctx.Conversations.PathToLeaf(c.SelectedRoot)
if err != nil { if err != nil {
return fmt.Errorf("could not retrieve conversation messages: %v", err) return fmt.Errorf("could not retrieve conversation messages: %v", err)
} }
@ -58,7 +58,7 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
lastMessage.Content += strings.TrimRight(continuedOutput.Content, "\n\t ") lastMessage.Content += strings.TrimRight(continuedOutput.Content, "\n\t ")
// Update the original message // Update the original message
err = ctx.Store.UpdateMessage(lastMessage) err = ctx.Conversations.UpdateMessage(lastMessage)
if err != nil { if err != nil {
return fmt.Errorf("could not update the last message: %v", err) return fmt.Errorf("could not update the last message: %v", err)
} }
@ -70,7 +70,7 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
if len(args) != 0 { if len(args) != 0 {
return nil, compMode return nil, compMode
} }
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode
}, },
} }
applyGenerationFlags(ctx, cmd) applyGenerationFlags(ctx, cmd)

View File

@ -22,11 +22,11 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
}, },
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
shortName := args[0] shortName := args[0]
conversation := cmdutil.LookupConversation(ctx, shortName) c := cmdutil.LookupConversation(ctx, shortName)
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot) messages, err := ctx.Conversations.PathToLeaf(c.SelectedRoot)
if err != nil { if err != nil {
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title) return fmt.Errorf("Could not retrieve messages for conversation: %s", c.Title)
} }
offset, _ := cmd.Flags().GetInt("offset") offset, _ := cmd.Flags().GetInt("offset")
@ -62,11 +62,11 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
// Update the message in-place // Update the message in-place
inplace, _ := cmd.Flags().GetBool("in-place") inplace, _ := cmd.Flags().GetBool("in-place")
if inplace { if inplace {
return ctx.Store.UpdateMessage(&toEdit) return ctx.Conversations.UpdateMessage(&toEdit)
} }
// Otherwise, create a branch for the edited message // Otherwise, create a branch for the edited message
message, _, err := ctx.Store.CloneBranch(toEdit) message, _, err := ctx.Conversations.CloneBranch(toEdit)
if err != nil { if err != nil {
return err return err
} }
@ -74,11 +74,11 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
if desiredIdx > 0 { if desiredIdx > 0 {
// update selected reply // update selected reply
messages[desiredIdx-1].SelectedReply = message messages[desiredIdx-1].SelectedReply = message
err = ctx.Store.UpdateMessage(&messages[desiredIdx-1]) err = ctx.Conversations.UpdateMessage(&messages[desiredIdx-1])
} else { } else {
// update selected root // update selected root
conversation.SelectedRoot = message c.SelectedRoot = message
err = ctx.Store.UpdateConversation(conversation) err = ctx.Conversations.UpdateConversation(c)
} }
return err return err
}, },
@ -87,7 +87,7 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
if len(args) != 0 { if len(args) != 0 {
return nil, compMode return nil, compMode
} }
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode
}, },
} }

View File

@ -20,9 +20,9 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
Short: "List conversations", Short: "List conversations",
Long: `List conversations in order of recent activity`, Long: `List conversations in order of recent activity`,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
messages, err := ctx.Store.LatestConversationMessages() list, err := ctx.Conversations.LoadConversationList()
if err != nil { if err != nil {
return fmt.Errorf("Could not fetch conversations: %v", err) return fmt.Errorf("Could not load conversations: %v", err)
} }
type Category struct { type Category struct {
@ -57,12 +57,12 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
all, _ := cmd.Flags().GetBool("all") all, _ := cmd.Flags().GetBool("all")
for _, message := range messages { for _, item := range list.Items {
messageAge := now.Sub(message.CreatedAt) age := now.Sub(item.LastMessageAt)
var category string var category string
for _, c := range categories { for _, c := range categories {
if messageAge < c.cutoff { if age < c.cutoff {
category = c.name category = c.name
break break
} }
@ -70,14 +70,14 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
formatted := fmt.Sprintf( formatted := fmt.Sprintf(
"%s - %s - %s", "%s - %s - %s",
message.Conversation.ShortName.String, item.ShortName,
util.HumanTimeElapsedSince(messageAge), util.HumanTimeElapsedSince(age),
message.Conversation.Title, item.Title,
) )
categorized[category] = append( categorized[category] = append(
categorized[category], categorized[category],
ConversationLine{messageAge, formatted}, ConversationLine{age, formatted},
) )
} }
@ -93,7 +93,7 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
fmt.Printf("%s:\n", category.name) fmt.Printf("%s:\n", category.name)
for _, conv := range conversationLines { for _, conv := range conversationLines {
if conversationsPrinted >= count && !all { if conversationsPrinted >= count && !all {
fmt.Printf("%d remaining conversation(s), use --all to view.\n", len(messages)-conversationsPrinted) fmt.Printf("%d remaining conversation(s), use --all to view.\n", list.Total-conversationsPrinted)
break outer break outer
} }

View File

@ -5,6 +5,7 @@ import (
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -25,12 +26,12 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
return fmt.Errorf("No message was provided.") return fmt.Errorf("No message was provided.")
} }
messages := []api.Message{{ messages := []conversation.Message{{
Role: api.MessageRoleUser, Role: api.MessageRoleUser,
Content: input, Content: input,
}} }}
conversation, messages, err := ctx.Store.StartConversation(messages...) conversation, messages, err := ctx.Conversations.StartConversation(messages...)
if err != nil { if err != nil {
return fmt.Errorf("Could not start a new conversation: %v", err) return fmt.Errorf("Could not start a new conversation: %v", err)
} }
@ -43,7 +44,7 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
} }
conversation.Title = title conversation.Title = title
err = ctx.Store.UpdateConversation(conversation) err = ctx.Conversations.UpdateConversation(conversation)
if err != nil { if err != nil {
lmcli.Warn("Could not save conversation title: %v\n", err) lmcli.Warn("Could not save conversation title: %v\n", err)
} }

View File

@ -5,6 +5,7 @@ import (
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -25,7 +26,7 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
return fmt.Errorf("No message was provided.") return fmt.Errorf("No message was provided.")
} }
messages := []api.Message{{ messages := []conversation.Message{{
Role: api.MessageRoleUser, Role: api.MessageRoleUser,
Content: input, Content: input,
}} }}

View File

@ -4,8 +4,8 @@ import (
"fmt" "fmt"
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -23,14 +23,14 @@ func RemoveCmd(ctx *lmcli.Context) *cobra.Command {
return nil return nil
}, },
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
var toRemove []*api.Conversation var toRemove []*conversation.Conversation
for _, shortName := range args { for _, shortName := range args {
conversation := cmdutil.LookupConversation(ctx, shortName) conversation := cmdutil.LookupConversation(ctx, shortName)
toRemove = append(toRemove, conversation) toRemove = append(toRemove, conversation)
} }
var errors []error var errors []error
for _, c := range toRemove { for _, c := range toRemove {
err := ctx.Store.DeleteConversation(c) err := ctx.Conversations.DeleteConversation(c)
if err != nil { if err != nil {
errors = append(errors, fmt.Errorf("Could not remove conversation %s: %v", c.ShortName.String, err)) errors = append(errors, fmt.Errorf("Could not remove conversation %s: %v", c.ShortName.String, err))
} }
@ -44,7 +44,7 @@ func RemoveCmd(ctx *lmcli.Context) *cobra.Command {
compMode := cobra.ShellCompDirectiveNoFileComp compMode := cobra.ShellCompDirectiveNoFileComp
var completions []string var completions []string
outer: outer:
for _, completion := range ctx.Store.ConversationShortNameCompletions(toComplete) { for _, completion := range ctx.Conversations.ConversationShortNameCompletions(toComplete) {
parts := strings.Split(completion, "\t") parts := strings.Split(completion, "\t")
for _, arg := range args { for _, arg := range args {
if parts[0] == arg { if parts[0] == arg {

View File

@ -30,7 +30,7 @@ func RenameCmd(ctx *lmcli.Context) *cobra.Command {
generate, _ := cmd.Flags().GetBool("generate") generate, _ := cmd.Flags().GetBool("generate")
if generate { if generate {
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot) messages, err := ctx.Conversations.PathToLeaf(conversation.SelectedRoot)
if err != nil { if err != nil {
return fmt.Errorf("Could not retrieve conversation messages: %v", err) return fmt.Errorf("Could not retrieve conversation messages: %v", err)
} }
@ -46,7 +46,7 @@ func RenameCmd(ctx *lmcli.Context) *cobra.Command {
} }
conversation.Title = title conversation.Title = title
err = ctx.Store.UpdateConversation(conversation) err = ctx.Conversations.UpdateConversation(conversation)
if err != nil { if err != nil {
lmcli.Warn("Could not update conversation title: %v\n", err) lmcli.Warn("Could not update conversation title: %v\n", err)
} }
@ -57,7 +57,7 @@ func RenameCmd(ctx *lmcli.Context) *cobra.Command {
if len(args) != 0 { if len(args) != 0 {
return nil, compMode return nil, compMode
} }
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode
}, },
} }

View File

@ -5,6 +5,7 @@ import (
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -28,14 +29,14 @@ func ReplyCmd(ctx *lmcli.Context) *cobra.Command {
} }
shortName := args[0] shortName := args[0]
conversation := cmdutil.LookupConversation(ctx, shortName) c := cmdutil.LookupConversation(ctx, shortName)
reply := inputFromArgsOrEditor(args[1:], "# How would you like to reply?\n", "") reply := inputFromArgsOrEditor(args[1:], "# How would you like to reply?\n", "")
if reply == "" { if reply == "" {
return fmt.Errorf("No reply was provided.") return fmt.Errorf("No reply was provided.")
} }
cmdutil.HandleConversationReply(ctx, conversation, true, api.Message{ cmdutil.HandleConversationReply(ctx, c, true, conversation.Message{
Role: api.MessageRoleUser, Role: api.MessageRoleUser,
Content: reply, Content: reply,
}) })
@ -46,7 +47,7 @@ func ReplyCmd(ctx *lmcli.Context) *cobra.Command {
if len(args) != 0 { if len(args) != 0 {
return nil, compMode return nil, compMode
} }
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode
}, },
} }

View File

@ -28,12 +28,12 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
} }
shortName := args[0] shortName := args[0]
conversation := cmdutil.LookupConversation(ctx, shortName) c := cmdutil.LookupConversation(ctx, shortName)
// Load the complete thread from the root message // Load the complete thread from the root message
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot) messages, err := ctx.Conversations.PathToLeaf(c.SelectedRoot)
if err != nil { if err != nil {
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title) return fmt.Errorf("Could not retrieve messages for conversation: %s", c.Title)
} }
offset, _ := cmd.Flags().GetInt("offset") offset, _ := cmd.Flags().GetInt("offset")
@ -67,7 +67,7 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
if len(args) != 0 { if len(args) != 0 {
return nil, compMode return nil, compMode
} }
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode
}, },
} }

View File

@ -9,6 +9,8 @@ import (
"time" "time"
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/provider"
"git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/util" "git.mlow.ca/mlow/lmcli/pkg/util"
"github.com/charmbracelet/lipgloss" "github.com/charmbracelet/lipgloss"
@ -16,13 +18,13 @@ import (
// Prompt prompts the configured the configured model and streams the response // Prompt prompts the configured the configured model and streams the response
// to stdout. Returns all model reply messages. // to stdout. Returns all model reply messages.
func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Message)) (*api.Message, error) { func Prompt(ctx *lmcli.Context, messages []conversation.Message, callback func(conversation.Message)) (*api.Message, error) {
m, provider, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model) m, _, p, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
params := api.RequestParameters{ params := provider.RequestParameters{
Model: m, Model: m,
MaxTokens: *ctx.Config.Defaults.MaxTokens, MaxTokens: *ctx.Config.Defaults.MaxTokens,
Temperature: *ctx.Config.Defaults.Temperature, Temperature: *ctx.Config.Defaults.Temperature,
@ -39,17 +41,17 @@ func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Messag
} }
if system != "" { if system != "" {
messages = api.ApplySystemPrompt(messages, system, false) messages = conversation.ApplySystemPrompt(messages, system, false)
} }
content := make(chan api.Chunk) content := make(chan provider.Chunk)
defer close(content) defer close(content)
// render the content received over the channel // render the content received over the channel
go ShowDelayedContent(content) go ShowDelayedContent(content)
reply, err := provider.CreateChatCompletionStream( reply, err := p.CreateChatCompletionStream(
context.Background(), params, messages, content, context.Background(), params, conversation.MessagesToAPI(messages), content,
) )
if reply.Content != "" { if reply.Content != "" {
@ -66,8 +68,8 @@ func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Messag
// lookupConversation either returns the conversation found by the // lookupConversation either returns the conversation found by the
// short name or exits the program // short name or exits the program
func LookupConversation(ctx *lmcli.Context, shortName string) *api.Conversation { func LookupConversation(ctx *lmcli.Context, shortName string) *conversation.Conversation {
c, err := ctx.Store.ConversationByShortName(shortName) c, err := ctx.Conversations.FindConversationByShortName(shortName)
if err != nil { if err != nil {
lmcli.Fatal("Could not lookup conversation: %v\n", err) lmcli.Fatal("Could not lookup conversation: %v\n", err)
} }
@ -77,8 +79,8 @@ func LookupConversation(ctx *lmcli.Context, shortName string) *api.Conversation
return c return c
} }
func LookupConversationE(ctx *lmcli.Context, shortName string) (*api.Conversation, error) { func LookupConversationE(ctx *lmcli.Context, shortName string) (*conversation.Conversation, error) {
c, err := ctx.Store.ConversationByShortName(shortName) c, err := ctx.Conversations.FindConversationByShortName(shortName)
if err != nil { if err != nil {
return nil, fmt.Errorf("Could not lookup conversation: %v", err) return nil, fmt.Errorf("Could not lookup conversation: %v", err)
} }
@ -88,8 +90,8 @@ func LookupConversationE(ctx *lmcli.Context, shortName string) (*api.Conversatio
return c, nil return c, nil
} }
func HandleConversationReply(ctx *lmcli.Context, c *api.Conversation, persist bool, toSend ...api.Message) { func HandleConversationReply(ctx *lmcli.Context, c *conversation.Conversation, persist bool, toSend ...conversation.Message) {
messages, err := ctx.Store.PathToLeaf(c.SelectedRoot) messages, err := ctx.Conversations.PathToLeaf(c.SelectedRoot)
if err != nil { if err != nil {
lmcli.Fatal("Could not load messages: %v\n", err) lmcli.Fatal("Could not load messages: %v\n", err)
} }
@ -98,40 +100,40 @@ func HandleConversationReply(ctx *lmcli.Context, c *api.Conversation, persist bo
// handleConversationReply handles sending messages to an existing // handleConversationReply handles sending messages to an existing
// conversation, optionally persisting both the sent replies and responses. // conversation, optionally persisting both the sent replies and responses.
func HandleReply(ctx *lmcli.Context, to *api.Message, persist bool, messages ...api.Message) { func HandleReply(ctx *lmcli.Context, to *conversation.Message, persist bool, messages ...conversation.Message) {
if to == nil { if to == nil {
lmcli.Fatal("Can't prompt from an empty message.") lmcli.Fatal("Can't prompt from an empty message.")
} }
existing, err := ctx.Store.PathToRoot(to) existing, err := ctx.Conversations.PathToRoot(to)
if err != nil { if err != nil {
lmcli.Fatal("Could not load messages: %v\n", err) lmcli.Fatal("Could not load messages: %v\n", err)
} }
RenderConversation(ctx, append(existing, messages...), true) RenderConversation(ctx, append(existing, messages...), true)
var savedReplies []api.Message var savedReplies []conversation.Message
if persist && len(messages) > 0 { if persist && len(messages) > 0 {
savedReplies, err = ctx.Store.Reply(to, messages...) savedReplies, err = ctx.Conversations.Reply(to, messages...)
if err != nil { if err != nil {
lmcli.Warn("Could not save messages: %v\n", err) lmcli.Warn("Could not save messages: %v\n", err)
} }
} }
// render a message header with no contents // render a message header with no contents
RenderMessage(ctx, (&api.Message{Role: api.MessageRoleAssistant})) RenderMessage(ctx, (&conversation.Message{Role: api.MessageRoleAssistant}))
var lastSavedMessage *api.Message var lastSavedMessage *conversation.Message
lastSavedMessage = to lastSavedMessage = to
if len(savedReplies) > 0 { if len(savedReplies) > 0 {
lastSavedMessage = &savedReplies[len(savedReplies)-1] lastSavedMessage = &savedReplies[len(savedReplies)-1]
} }
replyCallback := func(reply api.Message) { replyCallback := func(reply conversation.Message) {
if !persist { if !persist {
return return
} }
savedReplies, err = ctx.Store.Reply(lastSavedMessage, reply) savedReplies, err = ctx.Conversations.Reply(lastSavedMessage, reply)
if err != nil { if err != nil {
lmcli.Warn("Could not save reply: %v\n", err) lmcli.Warn("Could not save reply: %v\n", err)
} }
@ -144,7 +146,7 @@ func HandleReply(ctx *lmcli.Context, to *api.Message, persist bool, messages ...
} }
} }
func FormatForExternalPrompt(messages []api.Message, system bool) string { func FormatForExternalPrompt(messages []conversation.Message, system bool) string {
sb := strings.Builder{} sb := strings.Builder{}
for _, message := range messages { for _, message := range messages {
if message.Content == "" { if message.Content == "" {
@ -163,7 +165,7 @@ func FormatForExternalPrompt(messages []api.Message, system bool) string {
return sb.String() return sb.String()
} }
func GenerateTitle(ctx *lmcli.Context, messages []api.Message) (string, error) { func GenerateTitle(ctx *lmcli.Context, messages []conversation.Message) (string, error) {
const systemPrompt = `You will be shown a conversation between a user and an AI assistant. Your task is to generate a short title (8 words or less) for the provided conversation that reflects the conversation's topic. Your response is expected to be in JSON in the format shown below. const systemPrompt = `You will be shown a conversation between a user and an AI assistant. Your task is to generate a short title (8 words or less) for the provided conversation that reflects the conversation's topic. Your response is expected to be in JSON in the format shown below.
Example conversation: Example conversation:
@ -188,36 +190,36 @@ Example response:
} }
// Serialize the conversation to JSON // Serialize the conversation to JSON
conversation, err := json.Marshal(msgs) jsonBytes, err := json.Marshal(msgs)
if err != nil { if err != nil {
return "", err return "", err
} }
generateRequest := []api.Message{ generateRequest := []conversation.Message{
{ {
Role: api.MessageRoleSystem, Role: api.MessageRoleSystem,
Content: systemPrompt, Content: systemPrompt,
}, },
{ {
Role: api.MessageRoleUser, Role: api.MessageRoleUser,
Content: string(conversation), Content: string(jsonBytes),
}, },
} }
m, provider, err := ctx.GetModelProvider( m, _, p, err := ctx.GetModelProvider(
*ctx.Config.Conversations.TitleGenerationModel, *ctx.Config.Conversations.TitleGenerationModel, "",
) )
if err != nil { if err != nil {
return "", err return "", err
} }
requestParams := api.RequestParameters{ requestParams := provider.RequestParameters{
Model: m, Model: m,
MaxTokens: 25, MaxTokens: 25,
} }
response, err := provider.CreateChatCompletion( response, err := p.CreateChatCompletion(
context.Background(), requestParams, generateRequest, context.Background(), requestParams, conversation.MessagesToAPI(generateRequest),
) )
if err != nil { if err != nil {
return "", err return "", err
@ -272,7 +274,7 @@ func ShowWaitAnimation(signal chan any) {
// chunked) content is received on the channel, the waiting animation is // chunked) content is received on the channel, the waiting animation is
// replaced by the content. // replaced by the content.
// Blocks until the channel is closed. // Blocks until the channel is closed.
func ShowDelayedContent(content <-chan api.Chunk) { func ShowDelayedContent(content <-chan provider.Chunk) {
waitSignal := make(chan any) waitSignal := make(chan any)
go ShowWaitAnimation(waitSignal) go ShowWaitAnimation(waitSignal)
@ -292,7 +294,7 @@ func ShowDelayedContent(content <-chan api.Chunk) {
// RenderConversation renders the given messages to TTY, with optional space // RenderConversation renders the given messages to TTY, with optional space
// for a subsequent message. spaceForResponse controls how many '\n' characters // for a subsequent message. spaceForResponse controls how many '\n' characters
// are printed immediately after the final message (1 if false, 2 if true) // are printed immediately after the final message (1 if false, 2 if true)
func RenderConversation(ctx *lmcli.Context, messages []api.Message, spaceForResponse bool) { func RenderConversation(ctx *lmcli.Context, messages []conversation.Message, spaceForResponse bool) {
l := len(messages) l := len(messages)
for i, message := range messages { for i, message := range messages {
RenderMessage(ctx, &message) RenderMessage(ctx, &message)
@ -303,7 +305,7 @@ func RenderConversation(ctx *lmcli.Context, messages []api.Message, spaceForResp
} }
} }
func RenderMessage(ctx *lmcli.Context, m *api.Message) { func RenderMessage(ctx *lmcli.Context, m *conversation.Message) {
var messageAge string var messageAge string
if m.CreatedAt.IsZero() { if m.CreatedAt.IsZero() {
messageAge = "now" messageAge = "now"

View File

@ -24,7 +24,7 @@ func ViewCmd(ctx *lmcli.Context) *cobra.Command {
shortName := args[0] shortName := args[0]
conversation := cmdutil.LookupConversation(ctx, shortName) conversation := cmdutil.LookupConversation(ctx, shortName)
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot) messages, err := ctx.Conversations.PathToLeaf(conversation.SelectedRoot)
if err != nil { if err != nil {
return fmt.Errorf("Could not retrieve messages for conversation %s: %v", conversation.ShortName.String, err) return fmt.Errorf("Could not retrieve messages for conversation %s: %v", conversation.ShortName.String, err)
} }
@ -37,7 +37,7 @@ func ViewCmd(ctx *lmcli.Context) *cobra.Command {
if len(args) != 0 { if len(args) != 0 {
return nil, compMode return nil, compMode
} }
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode
}, },
} }

View File

@ -0,0 +1,99 @@
package conversation
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
"time"
"git.mlow.ca/mlow/lmcli/pkg/api"
)
type Conversation struct {
ID uint `gorm:"primaryKey"`
ShortName sql.NullString
Title string
SelectedRootID *uint
SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"`
RootMessages []Message `gorm:"-:all"`
LastMessageAt time.Time
}
type MessageMeta struct {
GenerationProvider *string `json:"generation_provider,omitempty"`
GenerationModel *string `json:"generation_model,omitempty"`
}
type Message struct {
ID uint `gorm:"primaryKey"`
CreatedAt time.Time
Metadata MessageMeta
ConversationID *uint `gorm:"index"`
Conversation *Conversation `gorm:"foreignKey:ConversationID"`
ParentID *uint
Parent *Message `gorm:"foreignKey:ParentID"`
Replies []Message `gorm:"foreignKey:ParentID"`
SelectedReplyID *uint
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
Role api.MessageRole
Content string
ToolCalls ToolCalls // a json array of tool calls (from the model)
ToolResults ToolResults // a json array of tool results
}
func (m *MessageMeta) Scan(value interface{}) error {
return json.Unmarshal(value.([]byte), m)
}
func (m MessageMeta) Value() (driver.Value, error) {
return json.Marshal(m)
}
type ToolCalls []api.ToolCall
func (tc *ToolCalls) Scan(value any) (err error) {
s := value.(string)
if value == nil || s == "" {
*tc = nil
return
}
err = json.Unmarshal([]byte(s), tc)
return
}
func (tc ToolCalls) Value() (driver.Value, error) {
if len(tc) == 0 {
return "", nil
}
jsonBytes, err := json.Marshal(tc)
if err != nil {
return "", fmt.Errorf("Could not marshal ToolCalls to JSON: %v\n", err)
}
return string(jsonBytes), nil
}
type ToolResults []api.ToolResult
func (tr *ToolResults) Scan(value any) (err error) {
s := value.(string)
if value == nil || s == "" {
*tr = nil
return
}
err = json.Unmarshal([]byte(s), tr)
return
}
func (tr ToolResults) Value() (driver.Value, error) {
if len(tr) == 0 {
return "", nil
}
jsonBytes, err := json.Marshal([]api.ToolResult(tr))
if err != nil {
return "", fmt.Errorf("Could not marshal ToolResults to JSON: %v\n", err)
}
return string(jsonBytes), nil
}

View File

@ -1,4 +1,4 @@
package lmcli package conversation
import ( import (
"database/sql" "database/sql"
@ -8,43 +8,57 @@ import (
"strings" "strings"
"time" "time"
"git.mlow.ca/mlow/lmcli/pkg/api"
sqids "github.com/sqids/sqids-go" sqids "github.com/sqids/sqids-go"
"gorm.io/gorm" "gorm.io/gorm"
) )
type ConversationStore interface { // Repo exposes low-level message and conversation management. See
ConversationByShortName(shortName string) (*api.Conversation, error) // Service for high-level helpers
type Repo interface {
LoadConversationList() (ConversationList, error)
FindConversationByShortName(shortName string) (*Conversation, error)
ConversationShortNameCompletions(search string) []string ConversationShortNameCompletions(search string) []string
RootMessages(conversationID uint) ([]api.Message, error) GetConversationByID(int uint) (*Conversation, error)
LatestConversationMessages() ([]api.Message, error) GetRootMessages(conversationID uint) ([]Message, error)
StartConversation(messages ...api.Message) (*api.Conversation, []api.Message, error) CreateConversation(title string) (*Conversation, error)
UpdateConversation(conversation *api.Conversation) error UpdateConversation(*Conversation) error
DeleteConversation(conversation *api.Conversation) error DeleteConversation(*Conversation) error
CloneConversation(toClone api.Conversation) (*api.Conversation, uint, error) DeleteConversationById(id uint) error
MessageByID(messageID uint) (*api.Message, error) GetMessageByID(messageID uint) (*Message, error)
MessageReplies(messageID uint) ([]api.Message, error)
UpdateMessage(message *api.Message) error SaveMessage(message Message) (*Message, error)
DeleteMessage(message *api.Message, prune bool) error UpdateMessage(message *Message) error
CloneBranch(toClone api.Message) (*api.Message, uint, error) DeleteMessage(message *Message, prune bool) error
Reply(to *api.Message, messages ...api.Message) ([]api.Message, error) CloneBranch(toClone Message) (*Message, uint, error)
Reply(to *Message, messages ...Message) ([]Message, error)
PathToRoot(message *api.Message) ([]api.Message, error) PathToRoot(message *Message) ([]Message, error)
PathToLeaf(message *api.Message) ([]api.Message, error) PathToLeaf(message *Message) ([]Message, error)
// Retrieves and return the "selected thread" of the conversation.
// The "selected thread" of the conversation is a chain of messages
// starting from the Conversation's SelectedRoot Message, following each
// Message's SelectedReply until the tail Message is reached.
GetSelectedThread(*Conversation) ([]Message, error)
// Start a new conversation with the given messages
StartConversation(messages ...Message) (*Conversation, []Message, error)
CloneConversation(toClone Conversation) (*Conversation, uint, error)
} }
type SQLStore struct { type repo struct {
db *gorm.DB db *gorm.DB
sqids *sqids.Sqids sqids *sqids.Sqids
} }
func NewSQLStore(db *gorm.DB) (*SQLStore, error) { func NewRepo(db *gorm.DB) (Repo, error) {
models := []any{ models := []any{
&api.Conversation{}, &Conversation{},
&api.Message{}, &Message{},
} }
for _, x := range models { for _, x := range models {
@ -55,13 +69,86 @@ func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
} }
_sqids, _ := sqids.New(sqids.Options{MinLength: 4}) _sqids, _ := sqids.New(sqids.Options{MinLength: 4})
return &SQLStore{db, _sqids}, nil return &repo{db, _sqids}, nil
} }
func (s *SQLStore) createConversation() (*api.Conversation, error) { type ConversationListItem struct {
ID uint
ShortName string
Title string
LastMessageAt time.Time
}
type ConversationList struct {
Total int
Items []ConversationListItem
}
// LoadConversationList loads existing conversations, ordered by the date
// of their latest message, from most recent to oldest.
func (s *repo) LoadConversationList() (ConversationList, error) {
list := ConversationList{}
var convos []Conversation
err := s.db.Order("last_message_at DESC").Find(&convos).Error
if err != nil {
return list, err
}
for _, c := range convos {
list.Items = append(list.Items, ConversationListItem{
ID: c.ID,
ShortName: c.ShortName.String,
Title: c.Title,
LastMessageAt: c.LastMessageAt,
})
}
list.Total = len(list.Items)
return list, nil
}
func (s *repo) FindConversationByShortName(shortName string) (*Conversation, error) {
if shortName == "" {
return nil, errors.New("shortName is empty")
}
var conversation Conversation
err := s.db.Preload("SelectedRoot").Where("short_name = ?", shortName).Find(&conversation).Error
return &conversation, err
}
func (s *repo) ConversationShortNameCompletions(shortName string) []string {
var conversations []Conversation
// ignore error for completions
s.db.Find(&conversations)
completions := make([]string, 0, len(conversations))
for _, conversation := range conversations {
if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) {
completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title))
}
}
return completions
}
func (s *repo) GetConversationByID(id uint) (*Conversation, error) {
var conversation Conversation
err := s.db.Preload("SelectedRoot").Where("id = ?", id).Find(&conversation).Error
if err != nil {
return nil, fmt.Errorf("Cannot get conversation %d: %v", id, err)
}
rootMessages, err := s.GetRootMessages(id)
if err != nil {
return nil, fmt.Errorf("Could not load conversation's root messages %d: %v", id, err)
}
conversation.RootMessages = rootMessages
return &conversation, nil
}
func (s *repo) CreateConversation(title string) (*Conversation, error) {
// Create the new conversation // Create the new conversation
c := &api.Conversation{} c := &Conversation{Title: title}
err := s.db.Save(c).Error err := s.db.Create(c).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -75,159 +162,61 @@ func (s *SQLStore) createConversation() (*api.Conversation, error) {
return c, nil return c, nil
} }
func (s *SQLStore) UpdateConversation(c *api.Conversation) error { func (s *repo) UpdateConversation(c *Conversation) error {
if c == nil || c.ID == 0 { if c == nil || c.ID == 0 {
return fmt.Errorf("Conversation is nil or invalid (missing ID)") return fmt.Errorf("Conversation is nil or invalid (missing ID)")
} }
return s.db.Updates(c).Error return s.db.Updates(c).Error
} }
func (s *SQLStore) DeleteConversation(c *api.Conversation) error { func (s *repo) DeleteConversation(c *Conversation) error {
// Delete messages first if c == nil || c.ID == 0 {
err := s.db.Where("conversation_id = ?", c.ID).Delete(&api.Message{}).Error return fmt.Errorf("Conversation is nil or invalid (missing ID)")
}
return s.DeleteConversationById(c.ID)
}
func (s *repo) DeleteConversationById(id uint) error {
if id == 0 {
return fmt.Errorf("Invalid conversation ID: %d", id)
}
err := s.db.Where("conversation_id = ?", id).Delete(&Message{}).Error
if err != nil { if err != nil {
return err return err
} }
return s.db.Delete(c).Error return s.db.Where("id = ?", id).Delete(&Conversation{}).Error
} }
func (s *SQLStore) DeleteMessage(message *api.Message, prune bool) error { func (s *repo) SaveMessage(m Message) (*Message, error) {
panic("Not yet implemented") if m.Conversation == nil {
//return s.db.Delete(&message).Error return nil, fmt.Errorf("Can't save a message without a conversation (this is a bug)")
}
newMessage := m
newMessage.ID = 0
newMessage.CreatedAt = time.Now()
return &newMessage, s.db.Create(&newMessage).Error
} }
func (s *SQLStore) UpdateMessage(m *api.Message) error { func (s *repo) UpdateMessage(m *Message) error {
if m == nil || m.ID == 0 { if m == nil || m.ID == 0 {
return fmt.Errorf("Message is nil or invalid (missing ID)") return fmt.Errorf("Message is nil or invalid (missing ID)")
} }
return s.db.Updates(m).Error return s.db.Updates(m).Error
} }
func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string { func (s *repo) DeleteMessage(message *Message, prune bool) error {
var conversations []api.Conversation return s.db.Delete(&message).Error
// ignore error for completions
s.db.Find(&conversations)
completions := make([]string, 0, len(conversations))
for _, conversation := range conversations {
if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) {
completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title))
}
}
return completions
} }
func (s *SQLStore) ConversationByShortName(shortName string) (*api.Conversation, error) { func (s *repo) GetMessageByID(messageID uint) (*Message, error) {
if shortName == "" { var message Message
return nil, errors.New("shortName is empty")
}
var conversation api.Conversation
err := s.db.Preload("SelectedRoot").Where("short_name = ?", shortName).Find(&conversation).Error
return &conversation, err
}
func (s *SQLStore) RootMessages(conversationID uint) ([]api.Message, error) {
var rootMessages []api.Message
err := s.db.Where("conversation_id = ? AND parent_id IS NULL", conversationID).Find(&rootMessages).Error
if err != nil {
return nil, err
}
return rootMessages, nil
}
func (s *SQLStore) MessageByID(messageID uint) (*api.Message, error) {
var message api.Message
err := s.db.Preload("Parent").Preload("Replies").Preload("SelectedReply").Where("id = ?", messageID).Find(&message).Error err := s.db.Preload("Parent").Preload("Replies").Preload("SelectedReply").Where("id = ?", messageID).Find(&message).Error
return &message, err return &message, err
} }
func (s *SQLStore) MessageReplies(messageID uint) ([]api.Message, error) { // Reply to a message with a series of messages (each followed by the next)
var replies []api.Message func (s *repo) Reply(to *Message, messages ...Message) ([]Message, error) {
err := s.db.Where("parent_id = ?", messageID).Find(&replies).Error var savedMessages []Message
return replies, err
}
// StartConversation starts a new conversation with the provided messages
func (s *SQLStore) StartConversation(messages ...api.Message) (*api.Conversation, []api.Message, error) {
if len(messages) == 0 {
return nil, nil, fmt.Errorf("Must provide at least 1 message")
}
// Create new conversation
conversation, err := s.createConversation()
if err != nil {
return nil, nil, err
}
// Create first message
messages[0].Conversation = conversation
err = s.db.Create(&messages[0]).Error
if err != nil {
return nil, nil, err
}
// Update conversation's selected root message
conversation.SelectedRoot = &messages[0]
err = s.UpdateConversation(conversation)
if err != nil {
return nil, nil, err
}
// Add additional replies to conversation
if len(messages) > 1 {
newMessages, err := s.Reply(&messages[0], messages[1:]...)
if err != nil {
return nil, nil, err
}
messages = append([]api.Message{messages[0]}, newMessages...)
}
return conversation, messages, nil
}
// CloneConversation clones the given conversation and all of its root meesages
func (s *SQLStore) CloneConversation(toClone api.Conversation) (*api.Conversation, uint, error) {
rootMessages, err := s.RootMessages(toClone.ID)
if err != nil {
return nil, 0, err
}
clone, err := s.createConversation()
if err != nil {
return nil, 0, fmt.Errorf("Could not create clone: %s", err)
}
clone.Title = toClone.Title + " - Clone"
var errors []error
var messageCnt uint = 0
for _, root := range rootMessages {
messageCnt++
newRoot := root
newRoot.ConversationID = &clone.ID
cloned, count, err := s.CloneBranch(newRoot)
if err != nil {
errors = append(errors, err)
continue
}
messageCnt += count
if root.ID == *toClone.SelectedRootID {
clone.SelectedRootID = &cloned.ID
if err := s.UpdateConversation(clone); err != nil {
errors = append(errors, fmt.Errorf("Could not set selected root on clone: %v", err))
}
}
}
if len(errors) > 0 {
return nil, 0, fmt.Errorf("Messages failed to be cloned: %v", errors)
}
return clone, messageCnt, nil
}
// Reply to a message with a series of messages (each following the next)
func (s *SQLStore) Reply(to *api.Message, messages ...api.Message) ([]api.Message, error) {
var savedMessages []api.Message
err := s.db.Transaction(func(tx *gorm.DB) error { err := s.db.Transaction(func(tx *gorm.DB) error {
currentParent := to currentParent := to
@ -256,23 +245,26 @@ func (s *SQLStore) Reply(to *api.Message, messages ...api.Message) ([]api.Messag
return nil return nil
}) })
if err != nil {
return savedMessages, err
}
to.Conversation.LastMessageAt = savedMessages[len(savedMessages)-1].CreatedAt
err = s.UpdateConversation(to.Conversation)
return savedMessages, err return savedMessages, err
} }
// CloneBranch returns a deep clone of the given message and its replies, returning // CloneBranch returns a deep clone of the given message and its replies, returning
// a new message object. The new message will be attached to the same parent as // a new message object. The new message will be attached to the same parent as
// the messageToClone // the messageToClone
func (s *SQLStore) CloneBranch(messageToClone api.Message) (*api.Message, uint, error) { func (s *repo) CloneBranch(messageToClone Message) (*Message, uint, error) {
newMessage := messageToClone newMessage := messageToClone
newMessage.ID = 0 newMessage.ID = 0
newMessage.Replies = nil newMessage.Replies = nil
newMessage.SelectedReplyID = nil newMessage.SelectedReplyID = nil
newMessage.SelectedReply = nil newMessage.SelectedReply = nil
originalReplies, err := s.MessageReplies(messageToClone.ID) originalReplies := messageToClone.Replies
if err != nil {
return nil, 0, fmt.Errorf("Could not fetch message %d replies: %v", messageToClone.ID, err)
}
if err := s.db.Create(&newMessage).Error; err != nil { if err := s.db.Create(&newMessage).Error; err != nil {
return nil, 0, fmt.Errorf("Could not clone message: %s", err) return nil, 0, fmt.Errorf("Could not clone message: %s", err)
@ -304,19 +296,19 @@ func (s *SQLStore) CloneBranch(messageToClone api.Message) (*api.Message, uint,
return &newMessage, replyCount, nil return &newMessage, replyCount, nil
} }
func fetchMessages(db *gorm.DB) ([]api.Message, error) { func fetchMessages(db *gorm.DB) ([]Message, error) {
var messages []api.Message var messages []Message
if err := db.Preload("Conversation").Find(&messages).Error; err != nil { if err := db.Preload("Conversation").Find(&messages).Error; err != nil {
return nil, fmt.Errorf("Could not fetch messages: %v", err) return nil, fmt.Errorf("Could not fetch messages: %v", err)
} }
messageMap := make(map[uint]api.Message) messageMap := make(map[uint]Message)
for i, message := range messages { for i, message := range messages {
messageMap[messages[i].ID] = message messageMap[messages[i].ID] = message
} }
// Create a map to store replies by their parent ID // Create a map to store replies by their parent ID
repliesMap := make(map[uint][]api.Message) repliesMap := make(map[uint][]Message)
for i, message := range messages { for i, message := range messages {
if messages[i].ParentID != nil { if messages[i].ParentID != nil {
repliesMap[*messages[i].ParentID] = append(repliesMap[*messages[i].ParentID], message) repliesMap[*messages[i].ParentID] = append(repliesMap[*messages[i].ParentID], message)
@ -326,7 +318,7 @@ func fetchMessages(db *gorm.DB) ([]api.Message, error) {
// Assign replies, parent, and selected reply to each message // Assign replies, parent, and selected reply to each message
for i := range messages { for i := range messages {
if replies, exists := repliesMap[messages[i].ID]; exists { if replies, exists := repliesMap[messages[i].ID]; exists {
messages[i].Replies = make([]api.Message, len(replies)) messages[i].Replies = make([]Message, len(replies))
for j, m := range replies { for j, m := range replies {
messages[i].Replies[j] = m messages[i].Replies[j] = m
} }
@ -345,21 +337,51 @@ func fetchMessages(db *gorm.DB) ([]api.Message, error) {
return messages, nil return messages, nil
} }
func (s *SQLStore) buildPath(message *api.Message, getNext func(*api.Message) *uint) ([]api.Message, error) { func (r repo) GetRootMessages(conversationID uint) ([]Message, error) {
var messages []api.Message var rootMessages []Message
err := r.db.Where("conversation_id = ? AND parent_id IS NULL", conversationID).Find(&rootMessages).Error
if err != nil {
return nil, fmt.Errorf("Could not retrieve root messages for conversation %d: %v", conversationID, err)
}
return rootMessages, nil
}
func (s *repo) buildPath(message *Message, getNext func(*Message) *uint) ([]Message, error) {
var messages []Message
messages, err := fetchMessages(s.db.Where("conversation_id = ?", message.ConversationID)) messages, err := fetchMessages(s.db.Where("conversation_id = ?", message.ConversationID))
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Create a map to store messages by their ID // Create a map to store messages by their ID
messageMap := make(map[uint]*api.Message) messageMap := make(map[uint]*Message, len(messages))
for i := range messages { for i := range messages {
messageMap[messages[i].ID] = &messages[i] messageMap[messages[i].ID] = &messages[i]
} }
// Construct Replies
repliesMap := make(map[uint][]*Message, len(messages))
for _, m := range messageMap {
if m.ParentID == nil {
continue
}
if p, ok := messageMap[*m.ParentID]; ok {
repliesMap[p.ID] = append(repliesMap[p.ID], m)
}
}
// Add replies to messages
for _, m := range messageMap {
if replies, ok := repliesMap[m.ID]; ok {
m.Replies = make([]Message, len(replies))
for idx, reply := range replies {
m.Replies[idx] = *reply
}
}
}
// Build the path // Build the path
var path []api.Message var path []Message
nextID := &message.ID nextID := &message.ID
for { for {
@ -382,12 +404,12 @@ func (s *SQLStore) buildPath(message *api.Message, getNext func(*api.Message) *u
// PathToRoot traverses the provided message's Parent until reaching the tree // PathToRoot traverses the provided message's Parent until reaching the tree
// root and returns a slice of all messages traversed in chronological order // root and returns a slice of all messages traversed in chronological order
// (starting with the root and ending with the message provided) // (starting with the root and ending with the message provided)
func (s *SQLStore) PathToRoot(message *api.Message) ([]api.Message, error) { func (s *repo) PathToRoot(message *Message) ([]Message, error) {
if message == nil || message.ID <= 0 { if message == nil || message.ID <= 0 {
return nil, fmt.Errorf("Message is nil or has invalid ID") return nil, fmt.Errorf("Message is nil or has invalid ID")
} }
path, err := s.buildPath(message, func(m *api.Message) *uint { path, err := s.buildPath(message, func(m *Message) *uint {
return m.ParentID return m.ParentID
}) })
if err != nil { if err != nil {
@ -401,33 +423,98 @@ func (s *SQLStore) PathToRoot(message *api.Message) ([]api.Message, error) {
// PathToLeaf traverses the provided message's SelectedReply until reaching a // PathToLeaf traverses the provided message's SelectedReply until reaching a
// tree leaf and returns a slice of all messages traversed in chronological // tree leaf and returns a slice of all messages traversed in chronological
// order (starting with the message provided and ending with the leaf) // order (starting with the message provided and ending with the leaf)
func (s *SQLStore) PathToLeaf(message *api.Message) ([]api.Message, error) { func (s *repo) PathToLeaf(message *Message) ([]Message, error) {
if message == nil || message.ID <= 0 { if message == nil || message.ID <= 0 {
return nil, fmt.Errorf("Message is nil or has invalid ID") return nil, fmt.Errorf("Message is nil or has invalid ID")
} }
return s.buildPath(message, func(m *api.Message) *uint { return s.buildPath(message, func(m *Message) *uint {
return m.SelectedReplyID return m.SelectedReplyID
}) })
} }
func (s *SQLStore) LatestConversationMessages() ([]api.Message, error) { func (s *repo) StartConversation(messages ...Message) (*Conversation, []Message, error) {
var latestMessages []api.Message if len(messages) == 0 {
return nil, nil, fmt.Errorf("Must provide at least 1 message")
subQuery := s.db.Model(&api.Message{}). }
Select("MAX(created_at) as max_created_at, conversation_id").
Group("conversation_id")
err := s.db.Model(&api.Message{}).
Joins("JOIN (?) as sub on messages.conversation_id = sub.conversation_id AND messages.created_at = sub.max_created_at", subQuery).
Group("messages.conversation_id").
Order("created_at DESC").
Preload("Conversation").
Find(&latestMessages).Error
// Create new conversation
conversation, err := s.CreateConversation("")
if err != nil { if err != nil {
return nil, err return nil, nil, err
}
messages[0].Conversation = conversation
// Create first message
firstMessage, err := s.SaveMessage(messages[0])
if err != nil {
return nil, nil, err
}
messages[0] = *firstMessage
// Update conversation's selected root message
conversation.RootMessages = []Message{messages[0]}
conversation.SelectedRoot = &messages[0]
conversation.LastMessageAt = messages[0].CreatedAt
// Add additional replies to conversation
if len(messages) > 1 {
newMessages, err := s.Reply(&messages[0], messages[1:]...)
if err != nil {
return nil, nil, err
}
messages = append([]Message{messages[0]}, newMessages...)
conversation.LastMessageAt = messages[len(messages)-1].CreatedAt
} }
return latestMessages, nil err = s.UpdateConversation(conversation)
return conversation, messages, err
}
// CloneConversation clones the given conversation and all of its meesages
func (s *repo) CloneConversation(toClone Conversation) (*Conversation, uint, error) {
rootMessages, err := s.GetRootMessages(toClone.ID)
if err != nil {
return nil, 0, fmt.Errorf("Could not create clone: %v", err)
}
clone, err := s.CreateConversation(toClone.Title + " - Clone")
if err != nil {
return nil, 0, fmt.Errorf("Could not create clone: %v", err)
}
var errors []error
var messageCnt uint = 0
for _, root := range rootMessages {
messageCnt++
newRoot := root
newRoot.ConversationID = &clone.ID
cloned, count, err := s.CloneBranch(newRoot)
if err != nil {
errors = append(errors, err)
continue
}
messageCnt += count
if root.ID == *toClone.SelectedRootID {
clone.SelectedRootID = &cloned.ID
if err := s.UpdateConversation(clone); err != nil {
errors = append(errors, fmt.Errorf("Could not set selected root on clone: %v", err))
}
}
}
if len(errors) > 0 {
return nil, 0, fmt.Errorf("Messages failed to be cloned: %v", errors)
}
return clone, messageCnt, nil
}
func (s *repo) GetSelectedThread(c *Conversation) ([]Message, error) {
if c.SelectedRoot == nil {
return nil, fmt.Errorf("No SelectedRoot on conversation - this is a bug")
}
return s.PathToLeaf(c.SelectedRoot)
} }

55
pkg/conversation/tools.go Normal file
View File

@ -0,0 +1,55 @@
package conversation
import (
"git.mlow.ca/mlow/lmcli/pkg/api"
)
// ApplySystemPrompt updates the contents of an existing system Message if it
// exists, or returns a new slice with the system Message prepended.
func ApplySystemPrompt(m []Message, system string, force bool) []Message {
if len(m) > 0 && m[0].Role == api.MessageRoleSystem {
if force {
m[0].Content = system
}
return m
} else {
return append([]Message{{
Role: api.MessageRoleSystem,
Content: system,
}}, m...)
}
}
func MessageToAPI(m Message) api.Message {
return api.Message{
Role: m.Role,
Content: m.Content,
ToolCalls: m.ToolCalls,
ToolResults: m.ToolResults,
}
}
func MessagesToAPI(messages []Message) []api.Message {
ret := make([]api.Message, 0, len(messages))
for _, m := range messages {
ret = append(ret, MessageToAPI(m))
}
return ret
}
func MessageFromAPI(m api.Message) Message {
return Message{
Role: m.Role,
Content: m.Content,
ToolCalls: m.ToolCalls,
ToolResults: m.ToolResults,
}
}
func MessagesFromAPI(messages []api.Message) []Message {
ret := make([]Message, 0, len(messages))
for _, m := range messages {
ret = append(ret, MessageFromAPI(m))
}
return ret
}

View File

@ -15,8 +15,7 @@ type Config struct {
Temperature *float32 `yaml:"temperature" default:"0.2"` Temperature *float32 `yaml:"temperature" default:"0.2"`
SystemPrompt string `yaml:"systemPrompt,omitempty"` SystemPrompt string `yaml:"systemPrompt,omitempty"`
SystemPromptFile string `yaml:"systemPromptFile,omitempty"` SystemPromptFile string `yaml:"systemPromptFile,omitempty"`
// CLI only Agent string `yaml:"agent"`
Agent string `yaml:"-"`
} `yaml:"defaults"` } `yaml:"defaults"`
Conversations *struct { Conversations *struct {
TitleGenerationModel *string `yaml:"titleGenerationModel" default:"gpt-3.5-turbo"` TitleGenerationModel *string `yaml:"titleGenerationModel" default:"gpt-3.5-turbo"`
@ -32,10 +31,12 @@ type Config struct {
} `yaml:"agents"` } `yaml:"agents"`
Providers []*struct { Providers []*struct {
Name string `yaml:"name,omitempty"` Name string `yaml:"name,omitempty"`
Display string `yaml:"display,omitempty"`
Kind string `yaml:"kind"` Kind string `yaml:"kind"`
BaseURL string `yaml:"baseUrl,omitempty"` BaseURL string `yaml:"baseUrl,omitempty"`
APIKey string `yaml:"apiKey,omitempty"` APIKey string `yaml:"apiKey,omitempty"`
Models []string `yaml:"models"` Models []string `yaml:"models"`
Headers map[string]string `yaml:"headers"`
} `yaml:"providers"` } `yaml:"providers"`
} }

View File

@ -1,21 +1,28 @@
package lmcli package lmcli
import ( import (
"errors"
"fmt" "fmt"
"io/fs"
"log"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"time"
"git.mlow.ca/mlow/lmcli/pkg/agents" "git.mlow.ca/mlow/lmcli/pkg/agents"
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/api/provider/anthropic" "git.mlow.ca/mlow/lmcli/pkg/provider"
"git.mlow.ca/mlow/lmcli/pkg/api/provider/google" "git.mlow.ca/mlow/lmcli/pkg/provider/anthropic"
"git.mlow.ca/mlow/lmcli/pkg/api/provider/ollama" "git.mlow.ca/mlow/lmcli/pkg/provider/google"
"git.mlow.ca/mlow/lmcli/pkg/api/provider/openai" "git.mlow.ca/mlow/lmcli/pkg/provider/ollama"
"git.mlow.ca/mlow/lmcli/pkg/provider/openai"
"git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/util" "git.mlow.ca/mlow/lmcli/pkg/util"
"git.mlow.ca/mlow/lmcli/pkg/util/tty" "git.mlow.ca/mlow/lmcli/pkg/util/tty"
"gorm.io/driver/sqlite" "gorm.io/driver/sqlite"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/logger"
) )
type Agent struct { type Agent struct {
@ -27,7 +34,7 @@ type Agent struct {
type Context struct { type Context struct {
// high level app configuration, may be mutated at runtime // high level app configuration, may be mutated at runtime
Config Config Config Config
Store ConversationStore Conversations conversation.Repo
Chroma *tty.ChromaHighlighter Chroma *tty.ChromaHighlighter
} }
@ -38,23 +45,55 @@ func NewContext() (*Context, error) {
return nil, err return nil, err
} }
databaseFile := filepath.Join(dataDir(), "conversations.db") store, err := getConversationService()
db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{
//Logger: logger.Default.LogMode(logger.Info),
})
if err != nil {
return nil, fmt.Errorf("Error establishing connection to store: %v", err)
}
store, err := NewSQLStore(db)
if err != nil { if err != nil {
return nil, err return nil, err
} }
chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style) chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style)
return &Context{*config, store, chroma}, nil return &Context{*config, store, chroma}, nil
} }
func createOrOpenAppend(path string) (*os.File, error) {
var file *os.File
if _, err := os.Stat(path); errors.Is(err, os.ErrNotExist) {
file, err = os.Create(path)
if err != nil {
return nil, err
}
} else {
file, err = os.OpenFile(path, os.O_APPEND, fs.ModeAppend)
if err != nil {
return nil, err
}
}
return file, nil
}
func getConversationService() (conversation.Repo, error) {
databaseFile := filepath.Join(dataDir(), "conversations.db")
gormLogFile, err := createOrOpenAppend(filepath.Join(dataDir(), "database.log"))
if err != nil {
return nil, fmt.Errorf("Could not open database log file: %v", err)
}
db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{
Logger: logger.New(log.New(gormLogFile, "\n", log.LstdFlags), logger.Config{
SlowThreshold: 200 * time.Millisecond,
LogLevel: logger.Info,
IgnoreRecordNotFoundError: false,
Colorful: true,
}),
})
if err != nil {
return nil, fmt.Errorf("Error establishing connection to store: %v", err)
}
repo, err := conversation.NewRepo(db)
if err != nil {
return nil, err
}
return repo, nil
}
func (c *Context) GetModels() (models []string) { func (c *Context) GetModels() (models []string) {
modelCounts := make(map[string]int) modelCounts := make(map[string]int)
for _, p := range c.Config.Providers { for _, p := range c.Config.Providers {
@ -86,7 +125,7 @@ func (c *Context) GetAgents() (agents []string) {
} }
func (c *Context) GetAgent(name string) *Agent { func (c *Context) GetAgent(name string) *Agent {
if name == "" { if name == "" || name == "none" {
return nil return nil
} }
@ -123,11 +162,10 @@ func (c *Context) DefaultSystemPrompt() string {
return c.Config.Defaults.SystemPrompt return c.Config.Defaults.SystemPrompt
} }
func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProvider, error) { func (c *Context) GetModelProvider(model string, provider string) (string, string, provider.ChatCompletionProvider, error) {
parts := strings.Split(model, "@") parts := strings.Split(model, "@")
var provider string if provider == "" && len(parts) > 1 {
if len(parts) > 1 {
model = parts[0] model = parts[0]
provider = parts[1] provider = parts[1]
} }
@ -150,7 +188,7 @@ func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProv
if p.BaseURL != "" { if p.BaseURL != "" {
url = p.BaseURL url = p.BaseURL
} }
return model, &anthropic.AnthropicClient{ return model, name, &anthropic.AnthropicClient{
BaseURL: url, BaseURL: url,
APIKey: p.APIKey, APIKey: p.APIKey,
}, nil }, nil
@ -159,7 +197,7 @@ func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProv
if p.BaseURL != "" { if p.BaseURL != "" {
url = p.BaseURL url = p.BaseURL
} }
return model, &google.Client{ return model, name, &google.Client{
BaseURL: url, BaseURL: url,
APIKey: p.APIKey, APIKey: p.APIKey,
}, nil }, nil
@ -168,7 +206,7 @@ func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProv
if p.BaseURL != "" { if p.BaseURL != "" {
url = p.BaseURL url = p.BaseURL
} }
return model, &ollama.OllamaClient{ return model, name, &ollama.OllamaClient{
BaseURL: url, BaseURL: url,
}, nil }, nil
case "openai": case "openai":
@ -176,17 +214,18 @@ func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProv
if p.BaseURL != "" { if p.BaseURL != "" {
url = p.BaseURL url = p.BaseURL
} }
return model, &openai.OpenAIClient{ return model, name, &openai.OpenAIClient{
BaseURL: url, BaseURL: url,
APIKey: p.APIKey, APIKey: p.APIKey,
Headers: p.Headers,
}, nil }, nil
default: default:
return "", nil, fmt.Errorf("unknown provider kind: %s", p.Kind) return "", "", nil, fmt.Errorf("unknown provider kind: %s", p.Kind)
} }
} }
} }
} }
return "", nil, fmt.Errorf("unknown model: %s", model) return "", "", nil, fmt.Errorf("unknown model: %s", model)
} }
func configDir() string { func configDir() string {

View File

@ -11,6 +11,7 @@ import (
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/provider"
) )
const ANTHROPIC_VERSION = "2023-06-01" const ANTHROPIC_VERSION = "2023-06-01"
@ -117,7 +118,7 @@ func convertTools(tools []api.ToolSpec) []Tool {
} }
func createChatCompletionRequest( func createChatCompletionRequest(
params api.RequestParameters, params provider.RequestParameters,
messages []api.Message, messages []api.Message,
) (string, ChatCompletionRequest) { ) (string, ChatCompletionRequest) {
requestMessages := make([]ChatCompletionMessage, 0, len(messages)) requestMessages := make([]ChatCompletionMessage, 0, len(messages))
@ -188,7 +189,8 @@ func createChatCompletionRequest(
} }
var prefill string var prefill string
if api.IsAssistantContinuation(messages) { if len(messages) > 0 && messages[len(messages)-1].Role == api.MessageRoleAssistant {
// Prompting on an assitant message, use its content as prefill
prefill = messages[len(messages)-1].Content prefill = messages[len(messages)-1].Content
} }
@ -226,7 +228,7 @@ func (c *AnthropicClient) sendRequest(ctx context.Context, r ChatCompletionReque
func (c *AnthropicClient) CreateChatCompletion( func (c *AnthropicClient) CreateChatCompletion(
ctx context.Context, ctx context.Context,
params api.RequestParameters, params provider.RequestParameters,
messages []api.Message, messages []api.Message,
) (*api.Message, error) { ) (*api.Message, error) {
if len(messages) == 0 { if len(messages) == 0 {
@ -253,9 +255,9 @@ func (c *AnthropicClient) CreateChatCompletion(
func (c *AnthropicClient) CreateChatCompletionStream( func (c *AnthropicClient) CreateChatCompletionStream(
ctx context.Context, ctx context.Context,
params api.RequestParameters, params provider.RequestParameters,
messages []api.Message, messages []api.Message,
output chan<- api.Chunk, output chan<- provider.Chunk,
) (*api.Message, error) { ) (*api.Message, error) {
if len(messages) == 0 { if len(messages) == 0 {
return nil, fmt.Errorf("can't create completion from no messages") return nil, fmt.Errorf("can't create completion from no messages")
@ -349,9 +351,10 @@ func (c *AnthropicClient) CreateChatCompletionStream(
firstChunkReceived = true firstChunkReceived = true
} }
block.Text += text block.Text += text
output <- api.Chunk{ output <- provider.Chunk{
Content: text, Content: text,
TokenCount: 1, // rough, anthropic performs some chunking
TokenCount: uint(len(strings.Split(text, " "))),
} }
} }
case "input_json_delta": case "input_json_delta":
@ -436,15 +439,9 @@ func convertResponseToMessage(resp ChatCompletionResponse) (*api.Message, error)
} }
} }
message := &api.Message{
Role: api.MessageRoleAssistant,
Content: content.String(),
ToolCalls: toolCalls,
}
if len(toolCalls) > 0 { if len(toolCalls) > 0 {
message.Role = api.MessageRoleToolCall return api.NewMessageWithToolCalls(content.String(), toolCalls), nil
} }
return message, nil return api.NewMessageWithAssistant(content.String()), nil
} }

View File

@ -11,6 +11,7 @@ import (
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/provider"
) )
type Client struct { type Client struct {
@ -172,7 +173,7 @@ func convertToolResultsToGemini(toolResults []api.ToolResult) ([]FunctionRespons
} }
func createGenerateContentRequest( func createGenerateContentRequest(
params api.RequestParameters, params provider.RequestParameters,
messages []api.Message, messages []api.Message,
) (*GenerateContentRequest, error) { ) (*GenerateContentRequest, error) {
requestContents := make([]Content, 0, len(messages)) requestContents := make([]Content, 0, len(messages))
@ -279,7 +280,7 @@ func (c *Client) sendRequest(req *http.Request) (*http.Response, error) {
func (c *Client) CreateChatCompletion( func (c *Client) CreateChatCompletion(
ctx context.Context, ctx context.Context,
params api.RequestParameters, params provider.RequestParameters,
messages []api.Message, messages []api.Message,
) (*api.Message, error) { ) (*api.Message, error) {
if len(messages) == 0 { if len(messages) == 0 {
@ -336,24 +337,17 @@ func (c *Client) CreateChatCompletion(
} }
if len(toolCalls) > 0 { if len(toolCalls) > 0 {
return &api.Message{ return api.NewMessageWithToolCalls(content, convertToolCallToAPI(toolCalls)), nil
Role: api.MessageRoleToolCall,
Content: content,
ToolCalls: convertToolCallToAPI(toolCalls),
}, nil
} }
return &api.Message{ return api.NewMessageWithAssistant(content), nil
Role: api.MessageRoleAssistant,
Content: content,
}, nil
} }
func (c *Client) CreateChatCompletionStream( func (c *Client) CreateChatCompletionStream(
ctx context.Context, ctx context.Context,
params api.RequestParameters, params provider.RequestParameters,
messages []api.Message, messages []api.Message,
output chan<- api.Chunk, output chan<- provider.Chunk,
) (*api.Message, error) { ) (*api.Message, error) {
if len(messages) == 0 { if len(messages) == 0 {
return nil, fmt.Errorf("Can't create completion from no messages") return nil, fmt.Errorf("Can't create completion from no messages")
@ -425,7 +419,7 @@ func (c *Client) CreateChatCompletionStream(
if part.FunctionCall != nil { if part.FunctionCall != nil {
toolCalls = append(toolCalls, *part.FunctionCall) toolCalls = append(toolCalls, *part.FunctionCall)
} else if part.Text != "" { } else if part.Text != "" {
output <- api.Chunk{ output <- provider.Chunk{
Content: part.Text, Content: part.Text,
TokenCount: uint(tokens), TokenCount: uint(tokens),
} }
@ -434,17 +428,9 @@ func (c *Client) CreateChatCompletionStream(
} }
} }
// If there are function calls, handle them and recurse
if len(toolCalls) > 0 { if len(toolCalls) > 0 {
return &api.Message{ return api.NewMessageWithToolCalls(content.String(), convertToolCallToAPI(toolCalls)), nil
Role: api.MessageRoleToolCall,
Content: content.String(),
ToolCalls: convertToolCallToAPI(toolCalls),
}, nil
} }
return &api.Message{ return api.NewMessageWithAssistant(content.String()), nil
Role: api.MessageRoleAssistant,
Content: content.String(),
}, nil
} }

View File

@ -11,6 +11,7 @@ import (
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/provider"
) )
type OllamaClient struct { type OllamaClient struct {
@ -42,7 +43,7 @@ type OllamaResponse struct {
} }
func createOllamaRequest( func createOllamaRequest(
params api.RequestParameters, params provider.RequestParameters,
messages []api.Message, messages []api.Message,
) OllamaRequest { ) OllamaRequest {
requestMessages := make([]OllamaMessage, 0, len(messages)) requestMessages := make([]OllamaMessage, 0, len(messages))
@ -82,7 +83,7 @@ func (c *OllamaClient) sendRequest(req *http.Request) (*http.Response, error) {
func (c *OllamaClient) CreateChatCompletion( func (c *OllamaClient) CreateChatCompletion(
ctx context.Context, ctx context.Context,
params api.RequestParameters, params provider.RequestParameters,
messages []api.Message, messages []api.Message,
) (*api.Message, error) { ) (*api.Message, error) {
if len(messages) == 0 { if len(messages) == 0 {
@ -114,17 +115,14 @@ func (c *OllamaClient) CreateChatCompletion(
return nil, err return nil, err
} }
return &api.Message{ return api.NewMessageWithAssistant(completionResp.Message.Content), nil
Role: api.MessageRoleAssistant,
Content: completionResp.Message.Content,
}, nil
} }
func (c *OllamaClient) CreateChatCompletionStream( func (c *OllamaClient) CreateChatCompletionStream(
ctx context.Context, ctx context.Context,
params api.RequestParameters, params provider.RequestParameters,
messages []api.Message, messages []api.Message,
output chan<- api.Chunk, output chan<- provider.Chunk,
) (*api.Message, error) { ) (*api.Message, error) {
if len(messages) == 0 { if len(messages) == 0 {
return nil, fmt.Errorf("Can't create completion from no messages") return nil, fmt.Errorf("Can't create completion from no messages")
@ -173,7 +171,7 @@ func (c *OllamaClient) CreateChatCompletionStream(
} }
if len(streamResp.Message.Content) > 0 { if len(streamResp.Message.Content) > 0 {
output <- api.Chunk{ output <- provider.Chunk{
Content: streamResp.Message.Content, Content: streamResp.Message.Content,
TokenCount: 1, TokenCount: 1,
} }
@ -181,8 +179,5 @@ func (c *OllamaClient) CreateChatCompletionStream(
} }
} }
return &api.Message{ return api.NewMessageWithAssistant(content.String()), nil
Role: api.MessageRoleAssistant,
Content: content.String(),
}, nil
} }

View File

@ -11,11 +11,13 @@ import (
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/provider"
) )
type OpenAIClient struct { type OpenAIClient struct {
APIKey string APIKey string
BaseURL string BaseURL string
Headers map[string]string
} }
type ChatCompletionMessage struct { type ChatCompletionMessage struct {
@ -139,7 +141,7 @@ func convertToolCallToAPI(toolCalls []ToolCall) []api.ToolCall {
} }
func createChatCompletionRequest( func createChatCompletionRequest(
params api.RequestParameters, params provider.RequestParameters,
messages []api.Message, messages []api.Message,
) ChatCompletionRequest { ) ChatCompletionRequest {
requestMessages := make([]ChatCompletionMessage, 0, len(messages)) requestMessages := make([]ChatCompletionMessage, 0, len(messages))
@ -198,6 +200,9 @@ func (c *OpenAIClient) sendRequest(ctx context.Context, r ChatCompletionRequest)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.APIKey) req.Header.Set("Authorization", "Bearer "+c.APIKey)
for header, val := range c.Headers {
req.Header.Set(header, val)
}
client := &http.Client{} client := &http.Client{}
resp, err := client.Do(req) resp, err := client.Do(req)
@ -215,7 +220,7 @@ func (c *OpenAIClient) sendRequest(ctx context.Context, r ChatCompletionRequest)
func (c *OpenAIClient) CreateChatCompletion( func (c *OpenAIClient) CreateChatCompletion(
ctx context.Context, ctx context.Context,
params api.RequestParameters, params provider.RequestParameters,
messages []api.Message, messages []api.Message,
) (*api.Message, error) { ) (*api.Message, error) {
if len(messages) == 0 { if len(messages) == 0 {
@ -248,24 +253,17 @@ func (c *OpenAIClient) CreateChatCompletion(
toolCalls := choice.Message.ToolCalls toolCalls := choice.Message.ToolCalls
if len(toolCalls) > 0 { if len(toolCalls) > 0 {
return &api.Message{ return api.NewMessageWithToolCalls(content, convertToolCallToAPI(toolCalls)), nil
Role: api.MessageRoleToolCall,
Content: content,
ToolCalls: convertToolCallToAPI(toolCalls),
}, nil
} }
return &api.Message{ return api.NewMessageWithAssistant(content), nil
Role: api.MessageRoleAssistant,
Content: content,
}, nil
} }
func (c *OpenAIClient) CreateChatCompletionStream( func (c *OpenAIClient) CreateChatCompletionStream(
ctx context.Context, ctx context.Context,
params api.RequestParameters, params provider.RequestParameters,
messages []api.Message, messages []api.Message,
output chan<- api.Chunk, output chan<- provider.Chunk,
) (*api.Message, error) { ) (*api.Message, error) {
if len(messages) == 0 { if len(messages) == 0 {
return nil, fmt.Errorf("Can't create completion from no messages") return nil, fmt.Errorf("Can't create completion from no messages")
@ -329,7 +327,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
} }
} }
if len(delta.Content) > 0 { if len(delta.Content) > 0 {
output <- api.Chunk{ output <- provider.Chunk{
Content: delta.Content, Content: delta.Content,
TokenCount: 1, TokenCount: 1,
} }
@ -338,15 +336,8 @@ func (c *OpenAIClient) CreateChatCompletionStream(
} }
if len(toolCalls) > 0 { if len(toolCalls) > 0 {
return &api.Message{ return api.NewMessageWithToolCalls(content.String(), convertToolCallToAPI(toolCalls)), nil
Role: api.MessageRoleToolCall,
Content: content.String(),
ToolCalls: convertToolCallToAPI(toolCalls),
}, nil
} }
return &api.Message{ return api.NewMessageWithAssistant(content.String()), nil
Role: api.MessageRoleAssistant,
Content: content.String(),
}, nil
} }

41
pkg/provider/provider.go Normal file
View File

@ -0,0 +1,41 @@
package provider
import (
"context"
"git.mlow.ca/mlow/lmcli/pkg/api"
)
type Chunk struct {
Content string
TokenCount uint
}
type RequestParameters struct {
Model string
MaxTokens int
Temperature float32
TopP float32
Toolbox []api.ToolSpec
}
type ChatCompletionProvider interface {
// CreateChatCompletion generates a chat completion response to the
// provided messages.
CreateChatCompletion(
ctx context.Context,
params RequestParameters,
messages []api.Message,
) (*api.Message, error)
// Like CreateChageCompletion, except the response is streamed via
// the output channel.
CreateChatCompletionStream(
ctx context.Context,
params RequestParameters,
messages []api.Message,
chunks chan<- Chunk,
) (*api.Message, error)
}

View File

@ -0,0 +1,260 @@
package list
import (
"fmt"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
"github.com/charmbracelet/bubbles/textinput"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
type Option struct {
Label string
Value interface{}
}
type OptionGroup struct {
Name string
Options []Option
}
type Model struct {
ID int
HeaderStyle lipgloss.Style
ItemStyle lipgloss.Style
SelectedStyle lipgloss.Style
ItemRender func(Option, bool) string
Width int
Height int
optionGroups []OptionGroup
selected int
filterInput textinput.Model
filteredIndices []filteredIndex
content viewport.Model
itemYOffsets []int
}
type filteredIndex struct {
groupIndex int
optionIndex int
}
type MsgOptionSelected struct {
ID int
Option Option
}
func New(opts []Option) Model {
return NewWithGroups([]OptionGroup{{Options: opts}})
}
func NewWithGroups(groups []OptionGroup) Model {
ti := textinput.New()
ti.Prompt = "/"
ti.PromptStyle = lipgloss.NewStyle().Faint(true)
m := Model{
HeaderStyle: lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("12")).Padding(1, 0, 1, 1),
ItemStyle: lipgloss.NewStyle(),
SelectedStyle: lipgloss.NewStyle().Faint(true).Foreground(lipgloss.Color("6")),
optionGroups: groups,
selected: 0,
filterInput: ti,
filteredIndices: make([]filteredIndex, 0),
content: viewport.New(0, 0),
itemYOffsets: make([]int, 0),
}
m.filterItems()
m.content.SetContent(m.renderOptionsList())
return m
}
func (m *Model) Focused() {
m.filterInput.Focused()
}
func (m *Model) Focus() {
m.filterInput.Focus()
}
func (m *Model) Blur() {
m.filterInput.Blur()
}
func (m *Model) filterItems() {
filterText := strings.ToLower(m.filterInput.Value())
var prevSelection *filteredIndex
if m.selected <= len(m.filteredIndices)-1 {
prevSelection = &m.filteredIndices[m.selected]
}
m.filteredIndices = make([]filteredIndex, 0)
for groupIndex, group := range m.optionGroups {
for optionIndex, option := range group.Options {
if filterText == "" ||
strings.Contains(strings.ToLower(option.Label), filterText) ||
(group.Name != "" && strings.Contains(strings.ToLower(group.Name), filterText)) {
m.filteredIndices = append(m.filteredIndices, filteredIndex{groupIndex, optionIndex})
}
}
}
found := false
if len(m.filteredIndices) > 0 && prevSelection != nil {
// Preserve previous selection if possible
for i, filterIdx := range m.filteredIndices {
if prevSelection.groupIndex == filterIdx.groupIndex && prevSelection.optionIndex == filterIdx.optionIndex {
m.selected = i
found = true
break
}
}
}
if !found {
m.selected = 0
}
}
func (m *Model) Update(msg tea.Msg) (Model, tea.Cmd) {
var cmd tea.Cmd
switch msg := msg.(type) {
case tea.KeyMsg:
if m.filterInput.Focused() {
switch msg.String() {
case "esc":
m.filterInput.Blur()
m.filterInput.SetValue("")
m.filterItems()
m.refreshContent()
return *m, shared.KeyHandled(msg)
case "enter":
m.filterInput.Blur()
m.refreshContent()
break
case "up", "down":
break
default:
m.filterInput, cmd = m.filterInput.Update(msg)
m.filterItems()
m.refreshContent()
return *m, cmd
}
}
switch msg.String() {
case "up", "k":
m.moveSelection(-1)
return *m, shared.KeyHandled(msg)
case "down", "j":
m.moveSelection(1)
return *m, shared.KeyHandled(msg)
case "enter":
return *m, func() tea.Msg {
idx := m.filteredIndices[m.selected]
return MsgOptionSelected{
ID: m.ID,
Option: m.optionGroups[idx.groupIndex].Options[idx.optionIndex],
}
}
case "/":
m.filterInput.Focus()
return *m, textinput.Blink
}
}
m.content, cmd = m.content.Update(msg)
return *m, cmd
}
func (m *Model) refreshContent() {
m.content.SetContent(m.renderOptionsList())
m.ensureSelectedVisible()
}
func (m *Model) ensureSelectedVisible() {
if m.selected == 0 {
m.content.GotoTop()
} else if m.selected == len(m.filteredIndices)-1 {
m.content.GotoBottom()
} else {
tuiutil.ScrollIntoView(&m.content, m.itemYOffsets[m.selected], 0)
}
}
func (m *Model) moveSelection(delta int) {
prev := m.selected
m.selected = min(len(m.filteredIndices)-1, max(0, m.selected+delta))
if prev != m.selected {
m.refreshContent()
}
}
func (m *Model) View() string {
filter := ""
if m.filterInput.Focused() {
m.filterInput.Width = m.Width
filter = m.filterInput.View()
}
contentHeight := m.Height - tuiutil.Height(filter)
m.content.Width, m.content.Height = m.Width, contentHeight
parts := []string{m.content.View()}
if filter != "" {
parts = append(parts, filter)
}
return lipgloss.JoinVertical(lipgloss.Left, parts...)
}
func (m *Model) renderOptionsList() string {
yOffset := 0
lastGroupIndex := -1
m.itemYOffsets = make([]int, len(m.filteredIndices))
var sb strings.Builder
for i, idx := range m.filteredIndices {
if idx.groupIndex != lastGroupIndex {
group := m.optionGroups[idx.groupIndex].Name
if group != "" {
headingStr := m.HeaderStyle.Render(group)
yOffset += tuiutil.Height(headingStr)
sb.WriteString(headingStr)
sb.WriteRune('\n')
}
lastGroupIndex = idx.groupIndex
}
m.itemYOffsets[i] = yOffset
option := m.optionGroups[idx.groupIndex].Options[idx.optionIndex]
var item string
if m.ItemRender != nil {
item = m.ItemRender(option, i == m.selected)
} else {
prefix := " "
if i == m.selected {
prefix = "> "
item = m.SelectedStyle.Render(option.Label)
} else {
item = m.ItemStyle.Render(option.Label)
}
item = fmt.Sprintf("%s%s", prefix, item)
}
sb.WriteString(item)
yOffset += tuiutil.Height(item)
if i < len(m.filteredIndices)-1 {
sb.WriteRune('\n')
}
}
return sb.String()
}

281
pkg/tui/model/model.go Normal file
View File

@ -0,0 +1,281 @@
package model
import (
"context"
"fmt"
"git.mlow.ca/mlow/lmcli/pkg/agents"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/provider"
"github.com/charmbracelet/lipgloss"
)
type AppModel struct {
Ctx *lmcli.Context
Conversations conversation.ConversationList
Conversation conversation.Conversation
Messages []conversation.Message
Model string
ProviderName string
Provider provider.ChatCompletionProvider
Agent *lmcli.Agent
}
func NewAppModel(ctx *lmcli.Context, initialConversation *conversation.Conversation) *AppModel {
app := &AppModel{
Ctx: ctx,
Model: *ctx.Config.Defaults.Model,
}
if initialConversation == nil {
app.NewConversation()
} else {
}
model, provider, _, _ := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
app.Model = model
app.ProviderName = provider
app.Agent = ctx.GetAgent(ctx.Config.Defaults.Agent)
return app
}
var (
defaultStyle = lipgloss.NewStyle().Faint(true)
accentStyle = defaultStyle.Foreground(lipgloss.Color("6"))
)
func (a *AppModel) ActiveModel(style lipgloss.Style) string {
defaultStyle := style.Inherit(defaultStyle)
accentStyle := style.Inherit(accentStyle)
return defaultStyle.Render(a.Model) + accentStyle.Render("@") + defaultStyle.Render(a.ProviderName)
}
type MessageCycleDirection int
const (
CycleNext MessageCycleDirection = 1
CyclePrev MessageCycleDirection = -1
)
func (m *AppModel) ClearConversation() {
m.Conversation = conversation.Conversation{}
m.Messages = []conversation.Message{}
}
func (m *AppModel) ApplySystemPrompt() {
var system string
agent := m.Ctx.GetAgent(m.Ctx.Config.Defaults.Agent)
if agent != nil && agent.SystemPrompt != "" {
system = agent.SystemPrompt
}
if system == "" {
system = m.Ctx.DefaultSystemPrompt()
}
if system != "" {
m.Messages = conversation.ApplySystemPrompt(m.Messages, system, false)
}
}
func (m *AppModel) NewConversation() {
m.ClearConversation()
m.ApplySystemPrompt()
}
func (a *AppModel) LoadConversationMessages() ([]conversation.Message, error) {
messages, err := a.Ctx.Conversations.PathToLeaf(a.Conversation.SelectedRoot)
if err != nil {
return nil, fmt.Errorf("Could not load conversation messages: %v %v", a.Conversation.SelectedRoot, err)
}
return messages, nil
}
func (a *AppModel) GenerateConversationTitle(messages []conversation.Message) (string, error) {
return cmdutil.GenerateTitle(a.Ctx, messages)
}
func (a *AppModel) CloneMessage(message conversation.Message, selected bool) (*conversation.Message, error) {
msg, _, err := a.Ctx.Conversations.CloneBranch(message)
if err != nil {
return nil, fmt.Errorf("Could not clone message: %v", err)
}
if selected {
if msg.Parent == nil {
msg.Conversation.SelectedRoot = msg
err = a.Ctx.Conversations.UpdateConversation(msg.Conversation)
} else {
msg.Parent.SelectedReply = msg
err = a.Ctx.Conversations.UpdateMessage(msg.Parent)
}
if err != nil {
return nil, fmt.Errorf("Could not update selected message: %v", err)
}
}
return msg, nil
}
func (a *AppModel) UpdateMessageContent(message *conversation.Message) error {
return a.Ctx.Conversations.UpdateMessage(message)
}
func cycleSelectedMessage(selected *conversation.Message, choices []conversation.Message, dir MessageCycleDirection) (*conversation.Message, error) {
currentIndex := -1
for i, reply := range choices {
if reply.ID == selected.ID {
currentIndex = i
break
}
}
if currentIndex < 0 {
return nil, fmt.Errorf("Selected message %d not found in choices, this is a bug", selected.ID)
}
var next int
if dir == CyclePrev {
next = (currentIndex - 1 + len(choices)) % len(choices)
} else {
next = (currentIndex + 1) % len(choices)
}
return &choices[next], nil
}
func (a *AppModel) CycleSelectedRoot(conv *conversation.Conversation, dir MessageCycleDirection) (*conversation.Message, error) {
if len(conv.RootMessages) < 2 {
return nil, nil
}
nextRoot, err := cycleSelectedMessage(conv.SelectedRoot, conv.RootMessages, dir)
if err != nil {
return nil, err
}
conv.SelectedRoot = nextRoot
err = a.Ctx.Conversations.UpdateConversation(conv)
if err != nil {
return nil, fmt.Errorf("Could not update conversation SelectedRoot: %v", err)
}
return nextRoot, nil
}
func (a *AppModel) CycleSelectedReply(message *conversation.Message, dir MessageCycleDirection) (*conversation.Message, error) {
if len(message.Replies) < 2 {
return nil, nil
}
nextReply, err := cycleSelectedMessage(message.SelectedReply, message.Replies, dir)
if err != nil {
return nil, err
}
message.SelectedReply = nextReply
err = a.Ctx.Conversations.UpdateMessage(message)
if err != nil {
return nil, fmt.Errorf("Could not update message SelectedReply: %v", err)
}
return nextReply, nil
}
func (a *AppModel) PersistMessages() ([]conversation.Message, error) {
messages := make([]conversation.Message, len(a.Messages))
for i, m := range a.Messages {
if i == 0 && m.ID == 0 {
m.Conversation = &a.Conversation
m, err := a.Ctx.Conversations.SaveMessage(m)
if err != nil {
return nil, fmt.Errorf("Could not create first message %d: %v", a.Messages[i].ID, err)
}
messages[i] = *m
// let's set the conversation root message(s), as this is the first message
m.Conversation.RootMessages = []conversation.Message{*m}
m.Conversation.SelectedRoot = &m.Conversation.RootMessages[0]
a.Ctx.Conversations.UpdateConversation(m.Conversation)
} else if m.ID > 0 {
// Existing message, update it
err := a.Ctx.Conversations.UpdateMessage(&m)
if err != nil {
return nil, fmt.Errorf("Could not update message %d: %v", a.Messages[i].ID, err)
}
messages[i] = m
} else if i > 0 {
// New message, reply to previous
replies, err := a.Ctx.Conversations.Reply(&messages[i-1], m)
if err != nil {
return nil, fmt.Errorf("Could not reply with new message: %v", err)
}
messages[i] = replies[0]
} else {
return nil, fmt.Errorf("No messages to reply to (this is a bug)")
}
}
return messages, nil
}
func (a *AppModel) PersistConversation() (conversation.Conversation, error) {
conv := a.Conversation
var err error
if a.Conversation.ID > 0 {
err = a.Ctx.Conversations.UpdateConversation(&conv)
} else {
c, e := a.Ctx.Conversations.CreateConversation("")
err = e
if e == nil && c != nil {
conv = *c
}
}
return conv, err
}
func (a *AppModel) ExecuteToolCalls(toolCalls []api.ToolCall) ([]api.ToolResult, error) {
agent := a.Ctx.GetAgent(a.Ctx.Config.Defaults.Agent)
if agent == nil {
return nil, fmt.Errorf("Attempted to execute tool calls with no agent configured")
}
return agents.ExecuteToolCalls(toolCalls, agent.Toolbox)
}
func (a *AppModel) Prompt(
messages []conversation.Message,
chatReplyChunks chan provider.Chunk,
stopSignal chan struct{},
) (*conversation.Message, error) {
model, _, p, err := a.Ctx.GetModelProvider(a.Model, a.ProviderName)
if err != nil {
return nil, err
}
params := provider.RequestParameters{
Model: model,
MaxTokens: *a.Ctx.Config.Defaults.MaxTokens,
Temperature: *a.Ctx.Config.Defaults.Temperature,
}
if a.Agent != nil {
params.Toolbox = a.Agent.Toolbox
}
ctx, cancel := context.WithCancel(context.Background())
go func() {
select {
case <-stopSignal:
cancel()
}
}()
msg, err := p.CreateChatCompletionStream(
ctx, params, conversation.MessagesToAPI(messages), chatReplyChunks,
)
if msg != nil {
msg := conversation.MessageFromAPI(*msg)
msg.Metadata.GenerationProvider = &a.ProviderName
msg.Metadata.GenerationModel = &a.Model
return &msg, err
}
return nil, err
}

View File

@ -1,52 +1,66 @@
package shared package shared
import ( import (
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
tea "github.com/charmbracelet/bubbletea" tea "github.com/charmbracelet/bubbletea"
) )
type Values struct { // An analogue to tea.Model with support for checking if the model has been
ConvShortname string // initialized before
} type ViewModel interface {
Init() tea.Cmd
Update(tea.Msg) (ViewModel, tea.Cmd)
type Shared struct { // View methods
Ctx *lmcli.Context Header(width int) string
Values *Values // Render the view's main content into a container of the given dimensions
Width int Content(width, height int) string
Height int Footer(width int) string
Err error
}
// a convenience struct for holding rendered content for indiviudal UI
// elements
type Sections struct {
Header string
Content string
Error string
Input string
Footer string
}
type (
// send to change the current state
MsgViewChange View
// sent to a state when it is entered
MsgViewEnter struct{}
// sent when an error occurs
MsgError error
)
func WrapError(err error) tea.Cmd {
return func() tea.Msg {
return MsgError(err)
}
} }
type View int type View int
const ( const (
StateChat View = iota ViewChat View = iota
StateConversations ViewConversations
//StateSettings ViewSettings
//StateHelp //StateHelp
) )
type (
// send to change the current state
MsgViewChange View
// sent to a state when it is entered, with the view we're leaving
MsgViewEnter View
// sent when a recoverable error occurs (displayed to user)
MsgError struct { Err error }
// sent when the view has handled a key input
MsgKeyHandled tea.KeyMsg
)
func ViewEnter(from View) tea.Cmd {
return func() tea.Msg {
return MsgViewEnter(from)
}
}
func ChangeView(view View) tea.Cmd {
return func() tea.Msg {
return MsgViewChange(view)
}
}
func KeyHandled(key tea.KeyMsg) tea.Cmd {
return func() tea.Msg {
return MsgKeyHandled(key)
}
}
func WrapError(err error) tea.Cmd {
return func() tea.Msg {
return MsgError{ Err: err }
}
}
func AsMsgError(err error) MsgError {
return MsgError{ Err: err }
}

View File

@ -1,132 +1,163 @@
package tui package tui
// The terminal UI for lmcli, launched from the `lmcli chat` command
// TODO:
// - change model
// - rename conversation
// - set system prompt
import ( import (
"fmt" "fmt"
"git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared" "git.mlow.ca/mlow/lmcli/pkg/tui/shared"
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
"git.mlow.ca/mlow/lmcli/pkg/tui/views/chat" "git.mlow.ca/mlow/lmcli/pkg/tui/views/chat"
"git.mlow.ca/mlow/lmcli/pkg/tui/views/conversations" "git.mlow.ca/mlow/lmcli/pkg/tui/views/conversations"
"git.mlow.ca/mlow/lmcli/pkg/tui/views/settings"
tea "github.com/charmbracelet/bubbletea" tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
) )
// Application model
type Model struct { type Model struct {
shared.Shared App *model.AppModel
state shared.View // window size
chat chat.Model width int
conversations conversations.Model height int
// errors to display
// TODO: allow dismissing errors
errs []error
activeView shared.View
views map[shared.View]shared.ViewModel
} }
func initialModel(ctx *lmcli.Context, values shared.Values) Model { func initialModel(ctx *lmcli.Context, opts LaunchOptions) *Model {
app := model.NewAppModel(ctx, opts.InitialConversation)
m := Model{ m := Model{
Shared: shared.Shared{ App: app,
Ctx: ctx, activeView: opts.InitialView,
Values: &values, views: map[shared.View]shared.ViewModel{
shared.ViewChat: chat.Chat(app),
shared.ViewConversations: conversations.Conversations(app),
shared.ViewSettings: settings.Settings(app),
}, },
} }
m.state = shared.StateChat return &m
m.chat = chat.Chat(m.Shared)
m.conversations = conversations.Conversations(m.Shared)
return m
} }
func (m Model) Init() tea.Cmd { func (m *Model) Init() tea.Cmd {
return tea.Batch(
m.conversations.Init(),
m.chat.Init(),
func() tea.Msg {
return shared.MsgViewChange(m.state)
},
)
}
func (m *Model) handleGlobalInput(msg tea.KeyMsg) (bool, tea.Cmd) {
// delegate input to the active child state first, only handling it at the
// global level if the child state does not
var cmds []tea.Cmd var cmds []tea.Cmd
switch m.state { for _, v := range m.views {
case shared.StateChat: // Init views
handled, cmd := m.chat.HandleInput(msg) cmds = append(cmds, v.Init())
cmds = append(cmds, cmd)
if handled {
m.chat, cmd = m.chat.Update(nil)
cmds = append(cmds, cmd)
return true, tea.Batch(cmds...)
} }
case shared.StateConversations: cmds = append(cmds, func() tea.Msg {
handled, cmd := m.conversations.HandleInput(msg) // Initial view change
cmds = append(cmds, cmd) return shared.MsgViewChange(m.activeView)
if handled { })
m.conversations, cmd = m.conversations.Update(nil) return tea.Batch(cmds...)
cmds = append(cmds, cmd)
return true, tea.Batch(cmds...)
} }
func (m *Model) handleGlobalInput(msg tea.KeyMsg) tea.Cmd {
view, cmd := m.views[m.activeView].Update(msg)
m.views[m.activeView] = view
if cmd != nil {
return cmd
} }
switch msg.String() { switch msg.String() {
case "ctrl+c", "ctrl+q": case "ctrl+c", "ctrl+q":
return true, tea.Quit return tea.Quit
} }
return false, nil return nil
} }
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmds []tea.Cmd
switch msg := msg.(type) { switch msg := msg.(type) {
case tea.WindowSizeMsg:
m.width, m.height = msg.Width, msg.Height
case tea.KeyMsg: case tea.KeyMsg:
handled, cmd := m.handleGlobalInput(msg) cmd := m.handleGlobalInput(msg)
if handled { if cmd != nil {
return m, cmd return m, cmd
} }
case shared.MsgViewChange: case shared.MsgViewChange:
m.state = shared.View(msg) currView := m.activeView
switch m.state { m.activeView = shared.View(msg)
case shared.StateChat: return m, tea.Batch(tea.WindowSize(), shared.ViewEnter(currView))
m.chat.HandleResize(m.Width, m.Height) case shared.MsgError:
case shared.StateConversations: m.errs = append(m.errs, msg.Err)
m.conversations.HandleResize(m.Width, m.Height)
}
return m, func() tea.Msg { return shared.MsgViewEnter(struct{}{}) }
case tea.WindowSizeMsg:
m.Width, m.Height = msg.Width, msg.Height
} }
var cmd tea.Cmd view, cmd := m.views[m.activeView].Update(msg)
switch m.state { m.views[m.activeView] = view
case shared.StateConversations: return m, cmd
m.conversations, cmd = m.conversations.Update(msg)
case shared.StateChat:
m.chat, cmd = m.chat.Update(msg)
}
if cmd != nil {
cmds = append(cmds, cmd)
} }
return m, tea.Batch(cmds...) func (m *Model) View() string {
} if m.width == 0 || m.height == 0 {
// we're dimensionless!
func (m Model) View() string {
switch m.state {
case shared.StateConversations:
return m.conversations.View()
case shared.StateChat:
return m.chat.View()
}
return "" return ""
} }
func Launch(ctx *lmcli.Context, convShortname string) error { header := m.views[m.activeView].Header(m.width)
p := tea.NewProgram(initialModel(ctx, shared.Values{ConvShortname: convShortname}), tea.WithAltScreen()) footer := m.views[m.activeView].Footer(m.width)
if _, err := p.Run(); err != nil { fixedUIHeight := tuiutil.Height(header) + tuiutil.Height(footer)
errBanners := make([]string, len(m.errs))
for idx, err := range m.errs {
errBanners[idx] = tuiutil.ErrorBanner(err, m.width)
fixedUIHeight += tuiutil.Height(errBanners[idx])
}
content := m.views[m.activeView].Content(m.width, m.height-fixedUIHeight)
sections := make([]string, 0, 4)
if header != "" {
sections = append(sections, header)
}
if content != "" {
sections = append(sections, content)
}
if footer != "" {
sections = append(sections, footer)
}
for _, errBanner := range errBanners {
sections = append(sections, errBanner)
}
return lipgloss.JoinVertical(lipgloss.Left, sections...)
}
type LaunchOptions struct {
InitialConversation *conversation.Conversation
InitialView shared.View
}
type LaunchOption func(*LaunchOptions)
func WithInitialConversation(conv *conversation.Conversation) LaunchOption {
return func(opts *LaunchOptions) {
opts.InitialConversation = conv
}
}
func WithInitialView(view shared.View) LaunchOption {
return func(opts *LaunchOptions) {
opts.InitialView = view
}
}
func Launch(ctx *lmcli.Context, options ...LaunchOption) error {
opts := &LaunchOptions{
InitialView: shared.ViewChat,
}
for _, opt := range options {
opt(opts)
}
program := tea.NewProgram(initialModel(ctx, *opts), tea.WithAltScreen())
if _, err := program.Run(); err != nil {
return fmt.Errorf("Error running program: %v", err) return fmt.Errorf("Error running program: %v", err)
} }
return nil return nil

View File

@ -54,23 +54,60 @@ func Height(str string) int {
return strings.Count(str, "\n") + 1 return strings.Count(str, "\n") + 1
} }
// truncate a string until its rendered cell width + the provided tail fits func Width(str string) int {
// within the given width if str == "" {
func TruncateToCellWidth(str string, width int, tail string) string { return 0
}
return ansi.PrintableRuneWidth(str)
}
func TruncateRightToCellWidth(str string, width int, tail string) string {
cellWidth := ansi.PrintableRuneWidth(str) cellWidth := ansi.PrintableRuneWidth(str)
if cellWidth <= width { if cellWidth <= width {
return str return str
} }
tailWidth := ansi.PrintableRuneWidth(tail) tailWidth := ansi.PrintableRuneWidth(tail)
for { if width <= tailWidth {
str = str[:len(str)-((cellWidth+tailWidth)-width)] return tail[:width]
cellWidth = ansi.PrintableRuneWidth(str)
if cellWidth+tailWidth <= max(width, 0) {
break
}
} }
targetWidth := width - tailWidth
runes := []rune(str)
for i := len(runes) - 1; i >= 0; i-- {
str = string(runes[:i])
if ansi.PrintableRuneWidth(str) <= targetWidth {
return str + tail return str + tail
} }
}
return tail
}
func TruncateLeftToCellWidth(str string, width int, tail string) string {
cellWidth := ansi.PrintableRuneWidth(str)
if cellWidth <= width {
return str
}
tailWidth := ansi.PrintableRuneWidth(tail)
if width <= tailWidth {
return tail[:width]
}
targetWidth := width - tailWidth
runes := []rune(str)
for i := 0; i < len(runes); i++ {
str = string(runes[i:])
if ansi.PrintableRuneWidth(str) <= targetWidth {
return tail + str
}
}
return tail
}
func ScrollIntoView(vp *viewport.Model, offset int, edge int) { func ScrollIntoView(vp *viewport.Model, offset int, edge int) {
currentOffset := vp.YOffset currentOffset := vp.YOffset
@ -98,4 +135,3 @@ func ErrorBanner(err error, width int) string {
Foreground(lipgloss.Color("1")). Foreground(lipgloss.Color("1")).
Render(fmt.Sprintf("%s", err)) Render(fmt.Sprintf("%s", err))
} }

View File

@ -4,7 +4,9 @@ import (
"time" "time"
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared" "git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/provider"
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
"github.com/charmbracelet/bubbles/cursor" "github.com/charmbracelet/bubbles/cursor"
"github.com/charmbracelet/bubbles/spinner" "github.com/charmbracelet/bubbles/spinner"
"github.com/charmbracelet/bubbles/textarea" "github.com/charmbracelet/bubbles/textarea"
@ -15,39 +17,35 @@ import (
// custom tea.Msg types // custom tea.Msg types
type ( type (
// sent when a conversation is (re)loaded
msgConversationLoaded struct {
conversation *api.Conversation
rootMessages []api.Message
}
// sent when a new conversation title generated // sent when a new conversation title generated
msgConversationTitleGenerated string msgConversationTitleGenerated string
// sent when the conversation has been persisted, triggers a reload of contents // sent when the conversation has been persisted, triggers a reload of contents
msgConversationPersisted struct { msgConversationPersisted conversation.Conversation
isNew bool msgMessagesPersisted []conversation.Message
conversation *api.Conversation
messages []api.Message
}
// sent when a conversation's messages are laoded // sent when a conversation's messages are laoded
msgMessagesLoaded []api.Message msgConversationMessagesLoaded struct {
messages []conversation.Message
}
// a special case of common.MsgError that stops the response waiting animation // a special case of common.MsgError that stops the response waiting animation
msgChatResponseError error msgChatResponseError struct {
Err error
}
// sent on each chunk received from LLM // sent on each chunk received from LLM
msgChatResponseChunk api.Chunk msgChatResponseChunk provider.Chunk
// sent on each completed reply // sent on each completed reply
msgChatResponse *api.Message msgChatResponse conversation.Message
// sent when the response is canceled // sent when the response is canceled
msgChatResponseCanceled struct{} msgChatResponseCanceled struct{}
// sent when results from a tool call are returned // sent when results from a tool call are returned
msgToolResults []api.ToolResult msgToolResults []api.ToolResult
// sent when the given message is made the new selected reply of its parent // sent when the given message is made the new selected reply of its parent
msgSelectedReplyCycled *api.Message msgSelectedReplyCycled *conversation.Message
// sent when the given message is made the new selected root of the current conversation // sent when the given message is made the new selected root of the current conversation
msgSelectedRootCycled *api.Message msgSelectedRootCycled *conversation.Message
// sent when a message's contents are updated and saved // sent when a message's contents are updated and saved
msgMessageUpdated *api.Message msgMessageUpdated *conversation.Message
// sent when a message is cloned, with the cloned message // sent when a message is cloned, with the cloned message
msgMessageCloned *api.Message msgMessageCloned *conversation.Message
) )
type focusState int type focusState int
@ -73,25 +71,24 @@ const (
) )
type Model struct { type Model struct {
shared.Shared // App state
shared.Sections App *model.AppModel
Height int
Width int
// app state // Chat view state
state state // current overall status of the view state state // current overall status of the view
conversation *api.Conversation
rootMessages []api.Message
messages []api.Message
selectedMessage int selectedMessage int
editorTarget editorTarget editorTarget editorTarget
stopSignal chan struct{} stopSignal chan struct{}
replyChan chan api.Message replyChan chan conversation.Message
chatReplyChunks chan api.Chunk chatReplyChunks chan provider.Chunk
persistence bool // whether we will save new messages in the conversation persistence bool // whether we will save new messages in the conversation
// ui state // UI state
focus focusState focus focusState
showDetails bool // whether various details are shown in the UI (e.g. system prompt, tool calls/results, message metadata)
wrap bool // whether message content is wrapped to viewport width wrap bool // whether message content is wrapped to viewport width
showToolResults bool // whether tool calls and results are shown
messageCache []string // cache of syntax highlighted and wrapped message content messageCache []string // cache of syntax highlighted and wrapped message content
messageOffsets []int messageOffsets []int
@ -107,53 +104,51 @@ type Model struct {
elapsed time.Duration elapsed time.Duration
} }
func Chat(shared shared.Shared) Model { func getSpinner() spinner.Model {
return spinner.New(spinner.WithSpinner(
spinner.Spinner{
Frames: []string{
"∙∙∙",
"●∙∙",
"●●∙",
"●●●",
"∙●●",
"∙∙●",
"∙∙∙",
"∙∙●",
"∙●●",
"●●●",
"●●∙",
"●∙∙",
},
FPS: 440 * time.Millisecond,
},
))
}
func Chat(app *model.AppModel) *Model {
m := Model{ m := Model{
Shared: shared, App: app,
state: idle, state: idle,
conversation: &api.Conversation{},
persistence: true, persistence: true,
stopSignal: make(chan struct{}), stopSignal: make(chan struct{}),
replyChan: make(chan api.Message), replyChan: make(chan conversation.Message),
chatReplyChunks: make(chan api.Chunk), chatReplyChunks: make(chan provider.Chunk),
wrap: true, wrap: true,
selectedMessage: -1, selectedMessage: -1,
content: viewport.New(0, 0), content: viewport.New(0, 0),
input: textarea.New(), input: textarea.New(),
spinner: spinner.New(spinner.WithSpinner( spinner: getSpinner(),
spinner.Spinner{
Frames: []string{
". ",
".. ",
"...",
".. ",
". ",
" ",
},
FPS: time.Second / 3,
},
)),
replyCursor: cursor.New(), replyCursor: cursor.New(),
} }
m.replyCursor.SetChar(" ") m.replyCursor.SetChar(" ")
m.replyCursor.Focus() m.replyCursor.Focus()
system := shared.Ctx.DefaultSystemPrompt()
agent := shared.Ctx.GetAgent(shared.Ctx.Config.Defaults.Agent)
if agent != nil && agent.SystemPrompt != "" {
system = agent.SystemPrompt
}
if system != "" {
m.messages = api.ApplySystemPrompt(m.messages, system, false)
}
m.input.Focus() m.input.Focus()
m.input.MaxHeight = 0 m.input.MaxHeight = 0
m.input.CharLimit = 0 m.input.CharLimit = 0
@ -163,11 +158,10 @@ func Chat(shared shared.Shared) Model {
m.input.FocusedStyle.CursorLine = lipgloss.NewStyle() m.input.FocusedStyle.CursorLine = lipgloss.NewStyle()
m.input.FocusedStyle.Base = inputFocusedStyle m.input.FocusedStyle.Base = inputFocusedStyle
m.input.BlurredStyle.Base = inputBlurredStyle m.input.BlurredStyle.Base = inputBlurredStyle
return &m
return m
} }
func (m Model) Init() tea.Cmd { func (m *Model) Init() tea.Cmd {
return tea.Batch( return tea.Batch(
m.waitForResponseChunk(), m.waitForResponseChunk(),
) )

136
pkg/tui/views/chat/cmds.go Normal file
View File

@ -0,0 +1,136 @@
package chat
import (
"time"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
tea "github.com/charmbracelet/bubbletea"
)
func (m *Model) waitForResponseChunk() tea.Cmd {
return func() tea.Msg {
return msgChatResponseChunk(<-m.chatReplyChunks)
}
}
func (m *Model) loadConversationMessages() tea.Cmd {
return func() tea.Msg {
messages, err := m.App.LoadConversationMessages()
if err != nil {
return shared.AsMsgError(err)
}
return msgConversationMessagesLoaded{messages}
}
}
func (m *Model) generateConversationTitle() tea.Cmd {
return func() tea.Msg {
title, err := m.App.GenerateConversationTitle(m.App.Messages)
if err != nil {
return shared.AsMsgError(err)
}
return msgConversationTitleGenerated(title)
}
}
func (m *Model) cloneMessage(message conversation.Message, selected bool) tea.Cmd {
return func() tea.Msg {
msg, err := m.App.CloneMessage(message, selected)
if err != nil {
return shared.WrapError(err)
}
return msgMessageCloned(msg)
}
}
func (m *Model) updateMessageContent(message *conversation.Message) tea.Cmd {
return func() tea.Msg {
err := m.App.UpdateMessageContent(message)
if err != nil {
return shared.WrapError(err)
}
return msgMessageUpdated(message)
}
}
func (m *Model) cycleSelectedRoot(conv *conversation.Conversation, dir model.MessageCycleDirection) tea.Cmd {
if len(conv.RootMessages) < 2 {
return nil
}
return func() tea.Msg {
nextRoot, err := m.App.CycleSelectedRoot(conv, dir)
if err != nil {
return shared.WrapError(err)
}
return msgSelectedRootCycled(nextRoot)
}
}
func (m *Model) cycleSelectedReply(message *conversation.Message, dir model.MessageCycleDirection) tea.Cmd {
if len(message.Replies) < 2 {
return nil
}
return func() tea.Msg {
nextReply, err := m.App.CycleSelectedReply(message, dir)
if err != nil {
return shared.WrapError(err)
}
return msgSelectedReplyCycled(nextReply)
}
}
func (m *Model) persistConversation() tea.Cmd {
return func() tea.Msg {
conversation, err := m.App.PersistConversation()
if err != nil {
return shared.AsMsgError(err)
}
return msgConversationPersisted(conversation)
}
}
func (m *Model) persistMessages() tea.Cmd {
return func() tea.Msg {
messages, err := m.App.PersistMessages()
if err != nil {
return shared.AsMsgError(err)
}
return msgMessagesPersisted(messages)
}
}
func (m *Model) executeToolCalls(toolCalls []api.ToolCall) tea.Cmd {
return func() tea.Msg {
results, err := m.App.ExecuteToolCalls(toolCalls)
if err != nil {
return shared.AsMsgError(err)
}
return msgToolResults(results)
}
}
func (m *Model) promptLLM() tea.Cmd {
m.state = pendingResponse
m.spinner = getSpinner()
m.replyCursor.Blink = false
m.startTime = time.Now()
m.elapsed = 0
m.tokenCount = 0
return tea.Batch(
m.spinner.Tick,
func() tea.Msg {
resp, err := m.App.Prompt(m.App.Messages, m.chatReplyChunks, m.stopSignal)
if err != nil {
return msgChatResponseError{Err: err}
}
return msgChatResponse(*resp)
},
)
}

View File

@ -1,308 +0,0 @@
package chat
import (
"context"
"errors"
"fmt"
"time"
"git.mlow.ca/mlow/lmcli/pkg/agents"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
tea "github.com/charmbracelet/bubbletea"
)
func (m *Model) setMessage(i int, msg api.Message) {
if i >= len(m.messages) {
panic("i out of range")
}
m.messages[i] = msg
m.messageCache[i] = m.renderMessage(i)
}
func (m *Model) addMessage(msg api.Message) {
m.messages = append(m.messages, msg)
m.messageCache = append(m.messageCache, m.renderMessage(len(m.messages)-1))
}
func (m *Model) setMessageContents(i int, content string) {
if i >= len(m.messages) {
panic("i out of range")
}
m.messages[i].Content = content
m.messageCache[i] = m.renderMessage(i)
}
func (m *Model) rebuildMessageCache() {
m.messageCache = make([]string, len(m.messages))
for i := range m.messages {
m.messageCache[i] = m.renderMessage(i)
}
}
func (m *Model) updateContent() {
atBottom := m.content.AtBottom()
m.content.SetContent(m.conversationMessagesView())
if atBottom {
// if we were at bottom before the update, scroll with the output
m.content.GotoBottom()
}
}
func (m *Model) loadConversation(shortname string) tea.Cmd {
return func() tea.Msg {
if shortname == "" {
return nil
}
c, err := m.Shared.Ctx.Store.ConversationByShortName(shortname)
if err != nil {
return shared.MsgError(fmt.Errorf("Could not lookup conversation: %v", err))
}
if c.ID == 0 {
return shared.MsgError(fmt.Errorf("Conversation not found: %s", shortname))
}
rootMessages, err := m.Shared.Ctx.Store.RootMessages(c.ID)
if err != nil {
return shared.MsgError(fmt.Errorf("Could not load conversation root messages: %v\n", err))
}
return msgConversationLoaded{c, rootMessages}
}
}
func (m *Model) loadConversationMessages() tea.Cmd {
return func() tea.Msg {
messages, err := m.Shared.Ctx.Store.PathToLeaf(m.conversation.SelectedRoot)
if err != nil {
return shared.MsgError(fmt.Errorf("Could not load conversation messages: %v\n", err))
}
return msgMessagesLoaded(messages)
}
}
func (m *Model) generateConversationTitle() tea.Cmd {
return func() tea.Msg {
title, err := cmdutil.GenerateTitle(m.Shared.Ctx, m.messages)
if err != nil {
return shared.MsgError(err)
}
return msgConversationTitleGenerated(title)
}
}
func (m *Model) updateConversationTitle(conversation *api.Conversation) tea.Cmd {
return func() tea.Msg {
err := m.Shared.Ctx.Store.UpdateConversation(conversation)
if err != nil {
return shared.WrapError(err)
}
return nil
}
}
// Clones the given message (and its descendents). If selected is true, updates
// either its parent's SelectedReply or its conversation's SelectedRoot to
// point to the new clone
func (m *Model) cloneMessage(message api.Message, selected bool) tea.Cmd {
return func() tea.Msg {
msg, _, err := m.Ctx.Store.CloneBranch(message)
if err != nil {
return shared.WrapError(fmt.Errorf("Could not clone message: %v", err))
}
if selected {
if msg.Parent == nil {
msg.Conversation.SelectedRoot = msg
err = m.Shared.Ctx.Store.UpdateConversation(msg.Conversation)
} else {
msg.Parent.SelectedReply = msg
err = m.Shared.Ctx.Store.UpdateMessage(msg.Parent)
}
if err != nil {
return shared.WrapError(fmt.Errorf("Could not update selected message: %v", err))
}
}
return msgMessageCloned(msg)
}
}
func (m *Model) updateMessageContent(message *api.Message) tea.Cmd {
return func() tea.Msg {
err := m.Shared.Ctx.Store.UpdateMessage(message)
if err != nil {
return shared.WrapError(fmt.Errorf("Could not update message: %v", err))
}
return msgMessageUpdated(message)
}
}
func cycleSelectedMessage(selected *api.Message, choices []api.Message, dir MessageCycleDirection) (*api.Message, error) {
currentIndex := -1
for i, reply := range choices {
if reply.ID == selected.ID {
currentIndex = i
break
}
}
if currentIndex < 0 {
// this should probably be an assert
return nil, fmt.Errorf("Selected message %d not found in choices, this is a bug", selected.ID)
}
var next int
if dir == CyclePrev {
// Wrap around to the last reply if at the beginning
next = (currentIndex - 1 + len(choices)) % len(choices)
} else {
// Wrap around to the first reply if at the end
next = (currentIndex + 1) % len(choices)
}
return &choices[next], nil
}
func (m *Model) cycleSelectedRoot(conv *api.Conversation, dir MessageCycleDirection) tea.Cmd {
if len(m.rootMessages) < 2 {
return nil
}
return func() tea.Msg {
nextRoot, err := cycleSelectedMessage(conv.SelectedRoot, m.rootMessages, dir)
if err != nil {
return shared.WrapError(err)
}
conv.SelectedRoot = nextRoot
err = m.Shared.Ctx.Store.UpdateConversation(conv)
if err != nil {
return shared.WrapError(fmt.Errorf("Could not update conversation SelectedRoot: %v", err))
}
return msgSelectedRootCycled(nextRoot)
}
}
func (m *Model) cycleSelectedReply(message *api.Message, dir MessageCycleDirection) tea.Cmd {
if len(message.Replies) < 2 {
return nil
}
return func() tea.Msg {
nextReply, err := cycleSelectedMessage(message.SelectedReply, message.Replies, dir)
if err != nil {
return shared.WrapError(err)
}
message.SelectedReply = nextReply
err = m.Shared.Ctx.Store.UpdateMessage(message)
if err != nil {
return shared.WrapError(fmt.Errorf("Could not update message SelectedReply: %v", err))
}
return msgSelectedReplyCycled(nextReply)
}
}
func (m *Model) persistConversation() tea.Cmd {
conversation := m.conversation
messages := m.messages
var err error
if conversation.ID == 0 {
return func() tea.Msg {
// Start a new conversation with all messages so far
conversation, messages, err = m.Shared.Ctx.Store.StartConversation(messages...)
if err != nil {
return shared.MsgError(fmt.Errorf("Could not start new conversation: %v", err))
}
return msgConversationPersisted{true, conversation, messages}
}
}
return func() tea.Msg {
// else, we'll handle updating an existing conversation's messages
for i := range messages {
if messages[i].ID > 0 {
// message has an ID, update it
err := m.Shared.Ctx.Store.UpdateMessage(&messages[i])
if err != nil {
return shared.MsgError(err)
}
} else if i > 0 {
// messages is new, so add it as a reply to previous message
saved, err := m.Shared.Ctx.Store.Reply(&messages[i-1], messages[i])
if err != nil {
return shared.MsgError(err)
}
messages[i] = saved[0]
} else {
// message has no id and no previous messages to add it to
// this shouldn't happen?
return fmt.Errorf("Error: no messages to reply to")
}
}
return msgConversationPersisted{false, conversation, messages}
}
}
func (m *Model) executeToolCalls(toolCalls []api.ToolCall) tea.Cmd {
return func() tea.Msg {
agent := m.Shared.Ctx.GetAgent(m.Shared.Ctx.Config.Defaults.Agent)
if agent == nil {
return shared.MsgError(fmt.Errorf("Attempted to execute tool calls with no agent configured"))
}
results, err := agents.ExecuteToolCalls(toolCalls, agent.Toolbox)
if err != nil {
return shared.MsgError(err)
}
return msgToolResults(results)
}
}
func (m *Model) promptLLM() tea.Cmd {
m.state = pendingResponse
m.replyCursor.Blink = false
m.startTime = time.Now()
m.elapsed = 0
m.tokenCount = 0
return func() tea.Msg {
model, provider, err := m.Shared.Ctx.GetModelProvider(*m.Shared.Ctx.Config.Defaults.Model)
if err != nil {
return shared.MsgError(err)
}
params := api.RequestParameters{
Model: model,
MaxTokens: *m.Shared.Ctx.Config.Defaults.MaxTokens,
Temperature: *m.Shared.Ctx.Config.Defaults.Temperature,
}
agent := m.Shared.Ctx.GetAgent(m.Shared.Ctx.Config.Defaults.Agent)
if agent != nil {
params.Toolbox = agent.Toolbox
}
ctx, cancel := context.WithCancel(context.Background())
go func() {
select {
case <-m.stopSignal:
cancel()
}
}()
resp, err := provider.CreateChatCompletionStream(
ctx, params, m.messages, m.chatReplyChunks,
)
if errors.Is(err, context.Canceled) {
return msgChatResponseCanceled(struct{}{})
}
if err != nil {
return msgChatResponseError(err)
}
return msgChatResponse(resp)
}
}

View File

@ -5,29 +5,24 @@ import (
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared" "git.mlow.ca/mlow/lmcli/pkg/tui/shared"
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util" tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
tea "github.com/charmbracelet/bubbletea" tea "github.com/charmbracelet/bubbletea"
) )
type MessageCycleDirection int func (m *Model) handleInput(msg tea.KeyMsg) tea.Cmd {
const (
CycleNext MessageCycleDirection = 1
CyclePrev MessageCycleDirection = -1
)
func (m *Model) HandleInput(msg tea.KeyMsg) (bool, tea.Cmd) {
switch m.focus { switch m.focus {
case focusInput: case focusInput:
consumed, cmd := m.handleInputKey(msg) cmd := m.handleInputKey(msg)
if consumed { if cmd != nil {
return true, cmd return cmd
} }
case focusMessages: case focusMessages:
consumed, cmd := m.handleMessagesKey(msg) cmd := m.handleMessagesKey(msg)
if consumed { if cmd != nil {
return true, cmd return cmd
} }
} }
@ -35,126 +30,152 @@ func (m *Model) HandleInput(msg tea.KeyMsg) (bool, tea.Cmd) {
case "esc": case "esc":
if m.state == pendingResponse { if m.state == pendingResponse {
m.stopSignal <- struct{}{} m.stopSignal <- struct{}{}
return true, nil return shared.KeyHandled(msg)
} }
return true, func() tea.Msg { return func() tea.Msg {
return shared.MsgViewChange(shared.StateConversations) return shared.MsgViewChange(shared.ViewConversations)
} }
case "ctrl+c": case "ctrl+c":
if m.state == pendingResponse { if m.state == pendingResponse {
m.stopSignal <- struct{}{} m.stopSignal <- struct{}{}
return true, nil return shared.KeyHandled(msg)
}
case "ctrl+g":
if m.state == pendingResponse {
m.stopSignal <- struct{}{}
return shared.KeyHandled(msg)
}
return func() tea.Msg {
return shared.MsgViewChange(shared.ViewSettings)
} }
case "ctrl+p": case "ctrl+p":
m.persistence = !m.persistence m.persistence = !m.persistence
return true, nil return shared.KeyHandled(msg)
case "ctrl+t": case "ctrl+t":
m.showToolResults = !m.showToolResults m.showDetails = !m.showDetails
m.rebuildMessageCache() m.rebuildMessageCache()
m.updateContent() m.updateContent()
return true, nil return shared.KeyHandled(msg)
case "ctrl+w": case "ctrl+w":
m.wrap = !m.wrap m.wrap = !m.wrap
m.rebuildMessageCache() m.rebuildMessageCache()
m.updateContent() m.updateContent()
return true, nil return shared.KeyHandled(msg)
case "ctrl+n":
m.App.NewConversation()
m.rebuildMessageCache()
m.updateContent()
return shared.KeyHandled(msg)
} }
return false, nil return nil
}
func (m *Model) scrollSelection(dir int) {
if m.selectedMessage+dir < 0 || m.selectedMessage+dir >= len(m.App.Messages) {
return
}
newIdx := m.selectedMessage
for i := newIdx + dir; i >= 0 && i < len(m.App.Messages); i += dir {
if !m.showDetails && m.App.Messages[i].Role.IsSystem() {
continue
}
newIdx = i
break
}
if newIdx != m.selectedMessage {
m.selectedMessage = newIdx
m.updateContent()
}
yOffset := m.messageOffsets[m.selectedMessage]
tuiutil.ScrollIntoView(&m.content, yOffset, m.content.Height/2)
} }
// handleMessagesKey handles input when the messages pane is focused // handleMessagesKey handles input when the messages pane is focused
func (m *Model) handleMessagesKey(msg tea.KeyMsg) (bool, tea.Cmd) { func (m *Model) handleMessagesKey(msg tea.KeyMsg) tea.Cmd {
switch msg.String() { switch msg.String() {
case "tab", "enter": case "tab", "enter":
m.focus = focusInput m.focus = focusInput
m.updateContent() m.updateContent()
m.input.Focus() m.input.Focus()
return true, nil return shared.KeyHandled(msg)
case "e": case "e":
if m.selectedMessage < len(m.messages) { if m.selectedMessage < len(m.App.Messages) {
m.editorTarget = selectedMessage m.editorTarget = selectedMessage
return true, tuiutil.OpenTempfileEditor( return tuiutil.OpenTempfileEditor(
"message.*.md", "message.*.md",
m.messages[m.selectedMessage].Content, m.App.Messages[m.selectedMessage].Content,
"# Edit the message below\n", "# Edit the message below\n",
) )
} }
return false, nil return nil
case "ctrl+k": case "ctrl+k", "ctrl+up":
if m.selectedMessage > 0 && len(m.messages) == len(m.messageOffsets) { if m.selectedMessage > 0 {
m.selectedMessage-- m.scrollSelection(-1)
m.updateContent()
offset := m.messageOffsets[m.selectedMessage]
tuiutil.ScrollIntoView(&m.content, offset, m.content.Height/2)
} }
return true, nil return shared.KeyHandled(msg)
case "ctrl+j": case "ctrl+j", "ctrl+down":
if m.selectedMessage < len(m.messages)-1 && len(m.messages) == len(m.messageOffsets) { if m.selectedMessage < len(m.App.Messages)-1 {
m.selectedMessage++ m.scrollSelection(1)
m.updateContent()
offset := m.messageOffsets[m.selectedMessage]
tuiutil.ScrollIntoView(&m.content, offset, m.content.Height/2)
} }
return true, nil return shared.KeyHandled(msg)
case "ctrl+h", "ctrl+l": case "ctrl+h", "ctrl+left", "ctrl+l", "ctrl+right":
dir := CyclePrev dir := model.CyclePrev
if msg.String() == "ctrl+l" { if msg.String() == "ctrl+l" || msg.String() == "ctrl+right" {
dir = CycleNext dir = model.CycleNext
} }
var cmd tea.Cmd var cmd tea.Cmd
if m.selectedMessage == 0 { if m.selectedMessage == 0 {
cmd = m.cycleSelectedRoot(m.conversation, dir) cmd = m.cycleSelectedRoot(&m.App.Conversation, dir)
} else if m.selectedMessage > 0 { } else if m.selectedMessage > 0 {
cmd = m.cycleSelectedReply(&m.messages[m.selectedMessage-1], dir) cmd = m.cycleSelectedReply(&m.App.Messages[m.selectedMessage-1], dir)
} }
return cmd
return cmd != nil, cmd
case "ctrl+r": case "ctrl+r":
// resubmit the conversation with all messages up until and including the selected message // prompt the model with all messages up to and including the selected message
if m.state == idle && m.selectedMessage < len(m.messages) { if m.state == idle && m.selectedMessage < len(m.App.Messages) {
m.messages = m.messages[:m.selectedMessage+1] m.App.Messages = m.App.Messages[:m.selectedMessage+1]
m.messageCache = m.messageCache[:m.selectedMessage+1] m.messageCache = m.messageCache[:m.selectedMessage+1]
cmd := m.promptLLM() cmd := m.promptLLM()
m.updateContent() m.updateContent()
m.content.GotoBottom() m.content.GotoBottom()
return true, cmd return cmd
} }
} }
return false, nil return nil
} }
// handleInputKey handles input when the input textarea is focused // handleInputKey handles input when the input textarea is focused
func (m *Model) handleInputKey(msg tea.KeyMsg) (bool, tea.Cmd) { func (m *Model) handleInputKey(msg tea.KeyMsg) tea.Cmd {
switch msg.String() { switch msg.String() {
case "esc": case "esc":
m.focus = focusMessages m.focus = focusMessages
if len(m.messages) > 0 { if len(m.App.Messages) > 0 {
if m.selectedMessage < 0 || m.selectedMessage >= len(m.messages) { if m.selectedMessage < 0 || m.selectedMessage >= len(m.App.Messages) {
m.selectedMessage = len(m.messages) - 1 m.selectedMessage = len(m.App.Messages) - 1
} }
offset := m.messageOffsets[m.selectedMessage] offset := m.messageOffsets[m.selectedMessage]
tuiutil.ScrollIntoView(&m.content, offset, m.content.Height/2) tuiutil.ScrollIntoView(&m.content, offset, m.content.Height/2)
} }
m.updateContent() m.updateContent()
m.input.Blur() m.input.Blur()
return true, nil return shared.KeyHandled(msg)
case "ctrl+s": case "ctrl+s":
// TODO: call a "handleSend" function which returns a tea.Cmd
if m.state != idle { if m.state != idle {
return false, nil return nil
} }
input := strings.TrimSpace(m.input.Value()) input := strings.TrimSpace(m.input.Value())
if input == "" { if input == "" {
return true, nil return shared.KeyHandled(msg)
} }
if len(m.messages) > 0 && m.messages[len(m.messages)-1].Role == api.MessageRoleUser { if len(m.App.Messages) > 0 && m.App.Messages[len(m.App.Messages)-1].Role.IsUser() {
return true, shared.WrapError(fmt.Errorf("Can't reply to a user message")) return shared.WrapError(fmt.Errorf("Can't reply to a user message"))
} }
m.addMessage(api.Message{ m.addMessage(conversation.Message{
Role: api.MessageRoleUser, Role: api.MessageRoleUser,
Content: input, Content: input,
}) })
@ -170,11 +191,11 @@ func (m *Model) handleInputKey(msg tea.KeyMsg) (bool, tea.Cmd) {
m.updateContent() m.updateContent()
m.content.GotoBottom() m.content.GotoBottom()
return true, tea.Batch(cmds...) return tea.Batch(cmds...)
case "ctrl+e": case "ctrl+e":
cmd := tuiutil.OpenTempfileEditor("message.*.md", m.input.Value(), "# Edit your input below\n") cmd := tuiutil.OpenTempfileEditor("message.*.md", m.input.Value(), "# Edit your input below\n")
m.editorTarget = input m.editorTarget = input
return true, cmd return cmd
} }
return false, nil return nil
} }

View File

@ -5,57 +5,87 @@ import (
"time" "time"
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared" "git.mlow.ca/mlow/lmcli/pkg/tui/shared"
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util" tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
"github.com/charmbracelet/bubbles/cursor" "github.com/charmbracelet/bubbles/cursor"
tea "github.com/charmbracelet/bubbletea" tea "github.com/charmbracelet/bubbletea"
) )
func (m *Model) HandleResize(width, height int) { func (m *Model) setMessage(i int, msg conversation.Message) {
m.Width, m.Height = width, height if i >= len(m.App.Messages) {
m.content.Width = width panic("i out of range")
m.input.SetWidth(width - m.input.FocusedStyle.Base.GetHorizontalFrameSize()) }
if len(m.messages) > 0 { m.App.Messages[i] = msg
m.messageCache[i] = m.renderMessage(i)
}
func (m *Model) addMessage(msg conversation.Message) {
m.App.Messages = append(m.App.Messages, msg)
m.messageCache = append(m.messageCache, m.renderMessage(len(m.App.Messages)-1))
}
func (m *Model) setMessageContents(i int, content string) {
if i >= len(m.App.Messages) {
panic("i out of range")
}
m.App.Messages[i].Content = content
m.messageCache[i] = m.renderMessage(i)
}
func (m *Model) rebuildMessageCache() {
m.messageCache = make([]string, len(m.App.Messages))
for i := range m.App.Messages {
m.messageCache[i] = m.renderMessage(i)
}
}
func (m *Model) updateContent() {
atBottom := m.content.AtBottom()
m.content.SetContent(m.conversationMessagesView())
if atBottom {
m.content.GotoBottom()
}
}
func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
inputHandled := false
var cmds []tea.Cmd
switch msg := msg.(type) {
case tea.KeyMsg:
cmd := m.handleInput(msg)
if cmd != nil {
inputHandled = true
cmds = append(cmds, cmd)
}
case tea.WindowSizeMsg:
m.Width, m.Height = msg.Width, msg.Height
m.content.Width = msg.Width
m.input.SetWidth(msg.Width - m.input.FocusedStyle.Base.GetHorizontalFrameSize())
if len(m.App.Messages) > 0 {
m.rebuildMessageCache() m.rebuildMessageCache()
m.updateContent() m.updateContent()
} }
}
func (m *Model) waitForResponseChunk() tea.Cmd {
return func() tea.Msg {
return msgChatResponseChunk(<-m.chatReplyChunks)
}
}
func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
var cmds []tea.Cmd
switch msg := msg.(type) {
case tea.WindowSizeMsg:
m.HandleResize(msg.Width, msg.Height)
case shared.MsgViewEnter: case shared.MsgViewEnter:
// wake up spinners and cursors // wake up spinners and cursors
cmds = append(cmds, cursor.Blink, m.spinner.Tick) cmds = append(cmds, cursor.Blink, m.spinner.Tick)
if m.Shared.Values.ConvShortname != "" { // Refresh view
// (re)load conversation contents
cmds = append(cmds, m.loadConversation(m.Shared.Values.ConvShortname))
if m.conversation.ShortName.String != m.Shared.Values.ConvShortname {
// clear existing messages if we're loading a new conversation
m.messages = []api.Message{}
m.selectedMessage = 0
}
}
m.rebuildMessageCache() m.rebuildMessageCache()
m.updateContent() m.updateContent()
if m.App.Conversation.ID > 0 {
// (re)load conversation contents
cmds = append(cmds, m.loadConversationMessages())
}
case tuiutil.MsgTempfileEditorClosed: case tuiutil.MsgTempfileEditorClosed:
contents := string(msg) contents := string(msg)
switch m.editorTarget { switch m.editorTarget {
case input: case input:
m.input.SetValue(contents) m.input.SetValue(contents)
case selectedMessage: case selectedMessage:
toEdit := m.messages[m.selectedMessage] toEdit := m.App.Messages[m.selectedMessage]
if toEdit.Content != contents { if toEdit.Content != contents {
toEdit.Content = contents toEdit.Content = contents
m.setMessage(m.selectedMessage, toEdit) m.setMessage(m.selectedMessage, toEdit)
@ -65,19 +95,12 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
} }
} }
} }
case msgConversationLoaded: case msgConversationMessagesLoaded:
m.conversation = msg.conversation m.App.Messages = msg.messages
m.rootMessages = msg.rootMessages
m.selectedMessage = -1
if len(m.rootMessages) > 0 {
cmds = append(cmds, m.loadConversationMessages())
}
case msgMessagesLoaded:
m.messages = msg
if m.selectedMessage == -1 { if m.selectedMessage == -1 {
m.selectedMessage = len(msg) - 1 m.selectedMessage = len(msg.messages) - 1
} else { } else {
m.selectedMessage = min(m.selectedMessage, len(m.messages)) m.selectedMessage = min(m.selectedMessage, len(m.App.Messages))
} }
m.rebuildMessageCache() m.rebuildMessageCache()
m.updateContent() m.updateContent()
@ -88,13 +111,13 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
break break
} }
last := len(m.messages) - 1 last := len(m.App.Messages) - 1
if last >= 0 && m.messages[last].Role.IsAssistant() { if last >= 0 && m.App.Messages[last].Role.IsAssistant() {
// append chunk to existing message // append chunk to existing message
m.setMessageContents(last, m.messages[last].Content+msg.Content) m.setMessageContents(last, m.App.Messages[last].Content+msg.Content)
} else { } else {
// use chunk in a new message // use chunk in a new message
m.addMessage(api.Message{ m.addMessage(conversation.Message{
Role: api.MessageRoleAssistant, Role: api.MessageRoleAssistant,
Content: msg.Content, Content: msg.Content,
}) })
@ -110,25 +133,24 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
case msgChatResponse: case msgChatResponse:
m.state = idle m.state = idle
reply := (*api.Message)(msg) reply := conversation.Message(msg)
reply.Content = strings.TrimSpace(reply.Content) reply.Content = strings.TrimSpace(reply.Content)
last := len(m.messages) - 1 last := len(m.App.Messages) - 1
if last < 0 { if last < 0 {
panic("Unexpected empty messages handling msgAssistantReply") panic("Unexpected empty messages handling msgAssistantReply")
} }
if m.messages[last].Role.IsAssistant() { if m.App.Messages[last].Role.IsAssistant() {
// TODO: handle continuations gracefully - some models support them well, others fail horribly. // TODO: handle continuations gracefully - only some models support them
m.setMessage(last, *reply) m.setMessage(last, reply)
} else { } else {
m.addMessage(*reply) m.addMessage(reply)
} }
switch reply.Role { if reply.Role == api.MessageRoleToolCall {
case api.MessageRoleToolCall:
// TODO: user confirmation before execution // TODO: user confirmation before execution
// m.state = waitingForConfirmation // m.state = confirmToolUse
cmds = append(cmds, m.executeToolCalls(reply.ToolCalls)) cmds = append(cmds, m.executeToolCalls(reply.ToolCalls))
} }
@ -136,31 +158,29 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
cmds = append(cmds, m.persistConversation()) cmds = append(cmds, m.persistConversation())
} }
if m.conversation.Title == "" { if m.App.Conversation.Title == "" && len(m.App.Messages) > 0 {
cmds = append(cmds, m.generateConversationTitle()) cmds = append(cmds, m.generateConversationTitle())
} }
m.updateContent()
case msgChatResponseCanceled: case msgChatResponseCanceled:
m.state = idle m.state = idle
m.updateContent() m.updateContent()
case msgChatResponseError: case msgChatResponseError:
m.state = idle m.state = idle
m.Shared.Err = error(msg)
m.updateContent() m.updateContent()
return m, shared.WrapError(msg.Err)
case msgToolResults: case msgToolResults:
last := len(m.messages) - 1 last := len(m.App.Messages) - 1
if last < 0 { if last < 0 {
panic("Unexpected empty messages handling msgAssistantReply") panic("Unexpected empty messages handling msgAssistantReply")
} }
if m.messages[last].Role != api.MessageRoleToolCall { if m.App.Messages[last].Role != api.MessageRoleToolCall {
panic("Previous message not a tool call, unexpected") panic("Previous message not a tool call, unexpected")
} }
m.addMessage(api.Message{ m.addMessage(conversation.Message{
Role: api.MessageRoleToolResult, Role: api.MessageRoleToolResult,
ToolResults: api.ToolResults(msg), ToolResults: conversation.ToolResults(msg),
}) })
if m.persistence { if m.persistence {
@ -170,30 +190,25 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
m.updateContent() m.updateContent()
case msgConversationTitleGenerated: case msgConversationTitleGenerated:
title := string(msg) title := string(msg)
m.conversation.Title = title m.App.Conversation.Title = title
if m.persistence { if m.persistence && m.App.Conversation.ID > 0 {
cmds = append(cmds, m.updateConversationTitle(m.conversation)) cmds = append(cmds, m.persistConversation())
} }
case cursor.BlinkMsg: case cursor.BlinkMsg:
if m.state == pendingResponse { if m.state == pendingResponse {
// ensure we show the updated "wait for response" cursor blink state // ensure we show the updated "wait for response" cursor blink state
last := len(m.messages)-1 last := len(m.App.Messages) - 1
m.messageCache[last] = m.renderMessage(last) m.messageCache[last] = m.renderMessage(last)
m.updateContent() m.updateContent()
} }
case msgConversationPersisted: case msgConversationPersisted:
m.conversation = msg.conversation m.App.Conversation = conversation.Conversation(msg)
m.messages = msg.messages cmds = append(cmds, m.persistMessages())
if msg.isNew { case msgMessagesPersisted:
m.rootMessages = []api.Message{m.messages[0]} m.App.Messages = msg
}
m.rebuildMessageCache() m.rebuildMessageCache()
m.updateContent() m.updateContent()
case msgMessageCloned: case msgMessageCloned:
if msg.Parent == nil {
m.conversation = msg.Conversation
m.rootMessages = append(m.rootMessages, *msg)
}
cmds = append(cmds, m.loadConversationMessages()) cmds = append(cmds, m.loadConversationMessages())
case msgSelectedRootCycled, msgSelectedReplyCycled, msgMessageUpdated: case msgSelectedRootCycled, msgSelectedReplyCycled, msgMessageUpdated:
cmds = append(cmds, m.loadConversationMessages()) cmds = append(cmds, m.loadConversationMessages())
@ -210,38 +225,22 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
} }
prevInputLineCnt := m.input.LineCount() prevInputLineCnt := m.input.LineCount()
inputCaptured := false
if !inputHandled {
m.input, cmd = m.input.Update(msg) m.input, cmd = m.input.Update(msg)
if cmd != nil { if cmd != nil {
inputCaptured = true inputHandled = true
cmds = append(cmds, cmd) cmds = append(cmds, cmd)
} }
}
if !inputCaptured { if !inputHandled {
m.content, cmd = m.content.Update(msg) m.content, cmd = m.content.Update(msg)
if cmd != nil { if cmd != nil {
cmds = append(cmds, cmd) cmds = append(cmds, cmd)
} }
} }
// update views once window dimensions are known
if m.Width > 0 {
m.Header = m.headerView()
m.Footer = m.footerView()
m.Error = tuiutil.ErrorBanner(m.Err, m.Width)
fixedHeight := tuiutil.Height(m.Header) + tuiutil.Height(m.Error) + tuiutil.Height(m.Footer)
// calculate clamped input height to accomodate input text
// minimum 4 lines, maximum half of content area
newHeight := max(4, min((m.Height-fixedHeight-1)/2, m.input.LineCount()))
m.input.SetHeight(newHeight)
m.Input = m.input.View()
// remaining height towards content
m.content.Height = m.Height - fixedHeight - tuiutil.Height(m.Input)
m.Content = m.content.View()
}
// this is a pretty nasty hack to ensure the input area viewport doesn't // this is a pretty nasty hack to ensure the input area viewport doesn't
// scroll below its content, which can happen when the input viewport // scroll below its content, which can happen when the input viewport
// height has grown, or previously entered lines have been deleted // height has grown, or previously entered lines have been deleted
@ -264,5 +263,8 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
} }
} }
if len(cmds) > 0 {
return m, tea.Batch(cmds...) return m, tea.Batch(cmds...)
} }
return m, nil
}

View File

@ -6,6 +6,7 @@ import (
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/tui/styles" "git.mlow.ca/mlow/lmcli/pkg/tui/styles"
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util" tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
"github.com/charmbracelet/lipgloss" "github.com/charmbracelet/lipgloss"
@ -16,15 +17,19 @@ import (
// styles // styles
var ( var (
boldStyle = lipgloss.NewStyle().Bold(true)
faintStyle = lipgloss.NewStyle().Faint(true)
boldFaintStyle = lipgloss.NewStyle().Faint(true).Bold(true)
messageHeadingStyle = lipgloss.NewStyle(). messageHeadingStyle = lipgloss.NewStyle().
MarginTop(1). MarginTop(1).
MarginBottom(1). MarginBottom(1)
PaddingLeft(1).
Bold(true)
userStyle = lipgloss.NewStyle().Faint(true).Foreground(lipgloss.Color("10")) userStyle = boldFaintStyle.Foreground(lipgloss.Color("10"))
assistantStyle = lipgloss.NewStyle().Faint(true).Foreground(lipgloss.Color("12")) assistantStyle = boldFaintStyle.Foreground(lipgloss.Color("12"))
systemStyle = boldStyle.Foreground(lipgloss.Color("8"))
messageStyle = lipgloss.NewStyle(). messageStyle = lipgloss.NewStyle().
PaddingLeft(2). PaddingLeft(2).
@ -37,40 +42,14 @@ var (
Faint(true). Faint(true).
Border(lipgloss.RoundedBorder(), true, true, true, false) Border(lipgloss.RoundedBorder(), true, true, true, false)
footerStyle = lipgloss.NewStyle() footerStyle = lipgloss.NewStyle().Padding(0, 1)
) )
func (m Model) View() string { func (m *Model) renderMessageHeading(i int, message *conversation.Message) string {
if m.Width == 0 {
return ""
}
sections := make([]string, 0, 6)
if m.Header != "" {
sections = append(sections, m.Header)
}
sections = append(sections, m.Content)
if m.Error != "" {
sections = append(sections, m.Error)
}
sections = append(sections, m.Input)
if m.Footer != "" {
sections = append(sections, m.Footer)
}
return lipgloss.JoinVertical(lipgloss.Left, sections...)
}
func (m *Model) renderMessageHeading(i int, message *api.Message) string {
icon := ""
friendly := message.Role.FriendlyRole() friendly := message.Role.FriendlyRole()
style := lipgloss.NewStyle().Faint(true).Bold(true) style := systemStyle
switch message.Role { switch message.Role {
case api.MessageRoleSystem:
icon = "⚙️"
case api.MessageRoleUser: case api.MessageRoleUser:
style = userStyle style = userStyle
case api.MessageRoleAssistant: case api.MessageRoleAssistant:
@ -78,70 +57,75 @@ func (m *Model) renderMessageHeading(i int, message *api.Message) string {
case api.MessageRoleToolCall: case api.MessageRoleToolCall:
style = assistantStyle style = assistantStyle
friendly = api.MessageRoleAssistant.FriendlyRole() friendly = api.MessageRoleAssistant.FriendlyRole()
case api.MessageRoleSystem:
case api.MessageRoleToolResult: case api.MessageRoleToolResult:
icon = "🔧"
} }
user := style.Render(icon + friendly) user := style.Render(friendly)
var prefix string var prefix, suffix string
var suffix string
faint := lipgloss.NewStyle().Faint(true) if i == m.selectedMessage && m.focus == focusMessages {
prefix = "> "
} else {
prefix = " "
}
if i == 0 && len(m.rootMessages) > 1 && m.conversation.SelectedRootID != nil { if i == 0 && m.App.Conversation.SelectedRootID != nil && len(m.App.Conversation.RootMessages) > 1 {
selectedRootIndex := 0 selectedRootIndex := 0
for j, reply := range m.rootMessages { for j, reply := range m.App.Conversation.RootMessages {
if reply.ID == *m.conversation.SelectedRootID { if reply.ID == *m.App.Conversation.SelectedRootID {
selectedRootIndex = j selectedRootIndex = j
break break
} }
} }
suffix += faint.Render(fmt.Sprintf(" <%d/%d>", selectedRootIndex+1, len(m.rootMessages))) suffix += faintStyle.Render(fmt.Sprintf(" <%d/%d>", selectedRootIndex+1, len(m.App.Conversation.RootMessages)))
} }
if i > 0 && len(m.messages[i-1].Replies) > 1 { if i > 0 && len(m.App.Messages[i-1].Replies) > 1 {
// Find the selected reply index // Find the selected reply index
selectedReplyIndex := 0 selectedReplyIndex := 0
for j, reply := range m.messages[i-1].Replies { for j, reply := range m.App.Messages[i-1].Replies {
if reply.ID == *m.messages[i-1].SelectedReplyID { if reply.ID == *m.App.Messages[i-1].SelectedReplyID {
selectedReplyIndex = j selectedReplyIndex = j
break break
} }
} }
suffix += faint.Render(fmt.Sprintf(" <%d/%d>", selectedReplyIndex+1, len(m.messages[i-1].Replies))) suffix += faintStyle.Render(fmt.Sprintf(" <%d/%d>", selectedReplyIndex+1, len(m.App.Messages[i-1].Replies)))
}
if m.focus == focusMessages {
if i == m.selectedMessage {
prefix = "> "
}
} }
if message.ID == 0 { if message.ID == 0 {
suffix += faint.Render(" (not saved)") suffix += faintStyle.Render(" (not saved)")
} }
return messageHeadingStyle.Render(prefix + user + suffix) heading := prefix + user + suffix
if message.Metadata.GenerationModel != nil && m.showDetails {
heading += faintStyle.Render(
fmt.Sprintf(" | %s", *message.Metadata.GenerationModel),
)
}
return messageHeadingStyle.Render(heading)
} }
// renderMessages renders the message at the given index as it should be shown // renderMessages renders the message at the given index as it should be shown
// *at this moment* - we render differently depending on the current application // *at this moment* - we render differently depending on the current application
// state (window size, etc, etc). // state (window size, etc, etc).
func (m *Model) renderMessage(i int) string { func (m *Model) renderMessage(i int) string {
msg := &m.messages[i] msg := &m.App.Messages[i]
// Write message contents // Write message contents
sb := &strings.Builder{} sb := &strings.Builder{}
sb.Grow(len(msg.Content) * 2) sb.Grow(len(msg.Content) * 2)
if msg.Content != "" { if msg.Content != "" {
err := m.Shared.Ctx.Chroma.Highlight(sb, msg.Content) err := m.App.Ctx.Chroma.Highlight(sb, msg.Content)
if err != nil { if err != nil {
sb.Reset() sb.Reset()
sb.WriteString(msg.Content) sb.WriteString(msg.Content)
} }
} }
isLast := i == len(m.messages)-1 isLast := i == len(m.App.Messages)-1
isAssistant := msg.Role == api.MessageRoleAssistant isAssistant := msg.Role == api.MessageRoleAssistant
if m.state == pendingResponse && isLast && isAssistant { if m.state == pendingResponse && isLast && isAssistant {
@ -167,7 +151,7 @@ func (m *Model) renderMessage(i int) string {
var toolResults []renderedResult var toolResults []renderedResult
for _, result := range msg.ToolResults { for _, result := range msg.ToolResults {
if m.showToolResults { if m.showDetails {
var jsonResult interface{} var jsonResult interface{}
err := json.Unmarshal([]byte(result.Result), &jsonResult) err := json.Unmarshal([]byte(result.Result), &jsonResult)
if err != nil { if err != nil {
@ -205,7 +189,7 @@ func (m *Model) renderMessage(i int) string {
if msg.Content != "" { if msg.Content != "" {
sb.WriteString("\n\n") sb.WriteString("\n\n")
} }
_ = m.Shared.Ctx.Chroma.HighlightLang(sb, toolString, "yaml") _ = m.App.Ctx.Chroma.HighlightLang(sb, toolString, "yaml")
} }
content := strings.TrimRight(sb.String(), "\n") content := strings.TrimRight(sb.String(), "\n")
@ -223,13 +207,17 @@ func (m *Model) renderMessage(i int) string {
// render the conversation into a string // render the conversation into a string
func (m *Model) conversationMessagesView() string { func (m *Model) conversationMessagesView() string {
sb := strings.Builder{} m.messageOffsets = make([]int, len(m.App.Messages))
m.messageOffsets = make([]int, len(m.messages))
lineCnt := 1 lineCnt := 1
for i, message := range m.messages {
sb := strings.Builder{}
for i, message := range m.App.Messages {
m.messageOffsets[i] = lineCnt m.messageOffsets[i] = lineCnt
if !m.showDetails && message.Role.IsSystem() {
continue
}
heading := m.renderMessageHeading(i, &message) heading := m.renderMessageHeading(i, &message)
sb.WriteString(heading) sb.WriteString(heading)
sb.WriteString("\n") sb.WriteString("\n")
@ -242,80 +230,136 @@ func (m *Model) conversationMessagesView() string {
} }
// Render a placeholder for the incoming assistant reply // Render a placeholder for the incoming assistant reply
if m.state == pendingResponse && m.messages[len(m.messages)-1].Role != api.MessageRoleAssistant { if m.state == pendingResponse && m.App.Messages[len(m.App.Messages)-1].Role != api.MessageRoleAssistant {
heading := m.renderMessageHeading(-1, &api.Message{ heading := m.renderMessageHeading(-1, &conversation.Message{
Role: api.MessageRoleAssistant, Role: api.MessageRoleAssistant,
Metadata: conversation.MessageMeta{
GenerationModel: &m.App.Model,
},
}) })
sb.WriteString(heading) sb.WriteString(heading)
sb.WriteString("\n") sb.WriteString("\n")
sb.WriteString(messageStyle.Width(0).Render(m.replyCursor.View())) sb.WriteString(messageStyle.Width(0).Render(m.replyCursor.View()))
sb.WriteString("\n") sb.WriteString("\n")
} }
return sb.String() return sb.String()
} }
func (m *Model) headerView() string { func (m *Model) Content(width, height int) string {
// calculate clamped input height to accomodate input text
// minimum 4 lines, maximum half of content area
inputHeight := max(4, min(height/2, m.input.LineCount()))
m.input.SetHeight(inputHeight)
input := m.input.View()
// remaining height towards content
m.content.Width, m.content.Height = width, height-tuiutil.Height(input)
content := m.content.View()
return lipgloss.JoinVertical(lipgloss.Left, content, input)
}
func (m *Model) Header(width int) string {
titleStyle := lipgloss.NewStyle().Bold(true) titleStyle := lipgloss.NewStyle().Bold(true)
var title string var title string
if m.conversation != nil && m.conversation.Title != "" { if m.App.Conversation.Title != "" {
title = m.conversation.Title title = m.App.Conversation.Title
} else { } else {
title = "Untitled" title = "Untitled"
} }
title = tuiutil.TruncateToCellWidth(title, m.Width-styles.Header.GetHorizontalPadding(), "...") title = tuiutil.TruncateRightToCellWidth(title, width-styles.Header.GetHorizontalPadding(), "...")
header := titleStyle.Render(title) header := titleStyle.Render(title)
return styles.Header.Width(m.Width).Render(header) return styles.Header.Width(width).Render(header)
} }
func (m *Model) footerView() string { func (m *Model) Footer(width int) string {
segmentStyle := lipgloss.NewStyle().PaddingLeft(1).PaddingRight(1).Faint(true) segmentStyle := lipgloss.NewStyle().Faint(true)
segmentSeparator := "|" segmentSeparator := segmentStyle.Render(" | ")
savingStyle := segmentStyle.Copy().Bold(true) // Left segments
saving := "" leftSegments := make([]string, 0, 4)
if m.persistence {
saving = savingStyle.Foreground(lipgloss.Color("2")).Render("✅💾") if m.state == pendingResponse {
leftSegments = append(leftSegments, segmentStyle.Render(m.spinner.View()))
} else { } else {
saving = savingStyle.Foreground(lipgloss.Color("1")).Render("❌💾") leftSegments = append(leftSegments, segmentStyle.Render("∙∙∙"))
} }
var status string
switch m.state {
case pendingResponse:
status = "Press ctrl+c to cancel" + m.spinner.View()
default:
status = "Press ctrl+s to send"
}
leftSegments := []string{
saving,
segmentStyle.Render(status),
}
rightSegments := []string{}
if m.elapsed > 0 && m.tokenCount > 0 { if m.elapsed > 0 && m.tokenCount > 0 {
throughput := fmt.Sprintf("%.0f t/sec", float64(m.tokenCount)/m.elapsed.Seconds()) throughput := fmt.Sprintf("%.0f t/sec", float64(m.tokenCount)/m.elapsed.Seconds())
rightSegments = append(rightSegments, segmentStyle.Render(throughput)) leftSegments = append(leftSegments, segmentStyle.Render(throughput))
} }
model := fmt.Sprintf("Model: %s", *m.Shared.Ctx.Config.Defaults.Model) // var status string
rightSegments = append(rightSegments, segmentStyle.Render(model)) // switch m.state {
// case pendingResponse:
// status = "Press ctrl+c to cancel"
// default:
// status = "Press ctrl+s to send"
// }
// leftSegments = append(leftSegments, segmentStyle.Render(status))
// Right segments
rightSegments := make([]string, 0, 8)
if m.App.Agent != nil {
rightSegments = append(rightSegments, segmentStyle.Render(m.App.Agent.Name))
}
model := segmentStyle.Render(m.App.ActiveModel(lipgloss.NewStyle()))
rightSegments = append(rightSegments, model)
savingStyle := segmentStyle.Bold(true)
saving := ""
if m.persistence {
saving = savingStyle.Foreground(lipgloss.Color("2")).Render("💾✅")
} else {
saving = savingStyle.Foreground(lipgloss.Color("1")).Render("💾❌")
}
rightSegments = append(rightSegments, saving)
return m.layoutFooter(width, leftSegments, rightSegments, segmentSeparator)
}
func (m *Model) layoutFooter(
width int,
leftSegments []string,
rightSegments []string,
segmentSeparator string,
) string {
left := strings.Join(leftSegments, segmentSeparator) left := strings.Join(leftSegments, segmentSeparator)
right := strings.Join(rightSegments, segmentSeparator) right := strings.Join(rightSegments, segmentSeparator)
totalWidth := lipgloss.Width(left) + lipgloss.Width(right) leftWidth := tuiutil.Width(left)
remaining := m.Width - totalWidth rightWidth := tuiutil.Width(right)
sepWidth := tuiutil.Width(segmentSeparator)
frameWidth := footerStyle.GetHorizontalFrameSize()
var padding string availableWidth := width - frameWidth - leftWidth - rightWidth
if remaining > 0 {
padding = strings.Repeat(" ", remaining) if availableWidth >= sepWidth {
// Everything fits
padding := strings.Repeat(" ", availableWidth)
return footerStyle.Render(left + padding + right)
} }
footer := left + padding + right // Inserted between left and right segments when they're being truncated
if remaining < 0 { div := "..."
footer = tuiutil.TruncateToCellWidth(footer, m.Width, "...")
totalAvailableWidth := width - frameWidth
availableTruncWidth := totalAvailableWidth - len(div)
minVisibleLength := 3
if availableTruncWidth < 2*minVisibleLength {
minVisibleLength = availableTruncWidth / 2
} }
return footerStyle.Width(m.Width).Render(footer)
leftProportion := float64(leftWidth) / float64(leftWidth+rightWidth)
newLeftWidth := int(max(float64(minVisibleLength), leftProportion*float64(availableTruncWidth)))
newRightWidth := totalAvailableWidth - newLeftWidth
truncatedLeft := faintStyle.Render(tuiutil.TruncateRightToCellWidth(left, newLeftWidth, ""))
truncatedRight := faintStyle.Render(tuiutil.TruncateLeftToCellWidth(right, newRightWidth, "..."))
return footerStyle.Width(width).Render(truncatedLeft + truncatedRight)
} }

View File

@ -5,8 +5,9 @@ import (
"strings" "strings"
"time" "time"
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/tui/bubbles" "git.mlow.ca/mlow/lmcli/pkg/tui/bubbles"
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared" "git.mlow.ca/mlow/lmcli/pkg/tui/shared"
"git.mlow.ca/mlow/lmcli/pkg/tui/styles" "git.mlow.ca/mlow/lmcli/pkg/tui/styles"
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util" tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
@ -16,67 +17,56 @@ import (
"github.com/charmbracelet/lipgloss" "github.com/charmbracelet/lipgloss"
) )
type loadedConversation struct {
conv api.Conversation
lastReply api.Message
}
type ( type (
// sent when conversation list is loaded // sent when conversation list is loaded
msgConversationsLoaded ([]loadedConversation) msgConversationsLoaded conversation.ConversationList
// sent when a conversation is selected // sent when a single conversation is loaded
msgConversationSelected api.Conversation msgConversationLoaded conversation.Conversation
// sent when a conversation is deleted // sent when a conversation is deleted
msgConversationDeleted struct{} msgConversationDeleted struct{}
) )
// Prompt payloads
type (
deleteConversationPayload api.Conversation
)
type Model struct { type Model struct {
shared.Shared App *model.AppModel
shared.Sections width int
height int
conversations []loadedConversation cursor int
cursor int // index of the currently selected conversation itemOffsets []int // conversation y offsets
itemOffsets []int // keeps track of the viewport y offset of each rendered item
content viewport.Model content viewport.Model
confirmPrompt bubbles.ConfirmPrompt confirmPrompt bubbles.ConfirmPrompt
} }
func Conversations(shared shared.Shared) Model { func Conversations(app *model.AppModel) *Model {
viewport.New(0, 0)
m := Model{ m := Model{
Shared: shared, App: app,
content: viewport.New(0, 0), content: viewport.New(0, 0),
} }
return m return &m
} }
func (m *Model) HandleInput(msg tea.KeyMsg) (bool, tea.Cmd) { func (m *Model) handleInput(msg tea.KeyMsg) tea.Cmd {
if m.confirmPrompt.Focused() { if m.confirmPrompt.Focused() {
var cmd tea.Cmd var cmd tea.Cmd
m.confirmPrompt, cmd = m.confirmPrompt.Update(msg) m.confirmPrompt, cmd = m.confirmPrompt.Update(msg)
if cmd != nil { if cmd != nil {
return true, cmd return cmd
} }
} }
conversations := m.App.Conversations.Items
switch msg.String() { switch msg.String() {
case "enter": case "enter":
if len(m.conversations) > 0 && m.cursor < len(m.conversations) { if len(conversations) > 0 && m.cursor < len(conversations) {
return true, func() tea.Msg { return m.loadConversation(conversations[m.cursor].ID)
return msgConversationSelected(m.conversations[m.cursor].conv)
}
} }
case "j", "down": case "j", "down":
if m.cursor < len(m.conversations)-1 { if m.cursor < len(conversations)-1 {
m.cursor++ m.cursor++
if m.cursor == len(m.conversations)-1 { if m.cursor == len(conversations)-1 {
// if last conversation, simply scroll to the bottom
m.content.GotoBottom() m.content.GotoBottom()
} else { } else {
// this hack positions the *next* conversatoin slightly // this hack positions the *next* conversatoin slightly
@ -86,10 +76,10 @@ func (m *Model) HandleInput(msg tea.KeyMsg) (bool, tea.Cmd) {
} }
m.content.SetContent(m.renderConversationList()) m.content.SetContent(m.renderConversationList())
} else { } else {
m.cursor = len(m.conversations) - 1 m.cursor = len(conversations) - 1
m.content.GotoBottom() m.content.GotoBottom()
} }
return true, nil return shared.KeyHandled(msg)
case "k", "up": case "k", "up":
if m.cursor > 0 { if m.cursor > 0 {
m.cursor-- m.cursor--
@ -103,23 +93,24 @@ func (m *Model) HandleInput(msg tea.KeyMsg) (bool, tea.Cmd) {
m.cursor = 0 m.cursor = 0
m.content.GotoTop() m.content.GotoTop()
} }
return true, nil return shared.KeyHandled(msg)
case "n": case "n":
// new conversation m.App.NewConversation()
return shared.ChangeView(shared.ViewChat)
case "d": case "d":
if !m.confirmPrompt.Focused() && len(m.conversations) > 0 && m.cursor < len(m.conversations) { if !m.confirmPrompt.Focused() && len(conversations) > 0 && m.cursor < len(conversations) {
title := m.conversations[m.cursor].conv.Title title := conversations[m.cursor].Title
if title == "" { if title == "" {
title = "(untitled)" title = "(untitled)"
} }
m.confirmPrompt = bubbles.NewConfirmPrompt( m.confirmPrompt = bubbles.NewConfirmPrompt(
fmt.Sprintf("Delete '%s'?", title), fmt.Sprintf("Delete '%s'?", title),
deleteConversationPayload(m.conversations[m.cursor].conv), conversations[m.cursor],
) )
m.confirmPrompt.Style = lipgloss.NewStyle(). m.confirmPrompt.Style = lipgloss.NewStyle().
Bold(true). Bold(true).
Foreground(lipgloss.Color("3")) Foreground(lipgloss.Color("3"))
return true, nil return shared.KeyHandled(msg)
} }
case "c": case "c":
// copy/clone conversation // copy/clone conversation
@ -128,122 +119,115 @@ func (m *Model) HandleInput(msg tea.KeyMsg) (bool, tea.Cmd) {
case "shift+r": case "shift+r":
// show prompt to generate name for conversation // show prompt to generate name for conversation
} }
return false, nil
}
func (m Model) Init() tea.Cmd {
return nil return nil
} }
func (m *Model) HandleResize(width, height int) { func (m *Model) Init() tea.Cmd {
m.Width, m.Height = width, height return nil
m.content.Width = width
} }
func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) { func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
isInput := false
inputHandled := false
var cmds []tea.Cmd var cmds []tea.Cmd
switch msg := msg.(type) { switch msg := msg.(type) {
case tea.KeyMsg:
isInput = true
cmd := m.handleInput(msg)
if cmd != nil {
cmds = append(cmds, cmd)
inputHandled = true
}
case shared.MsgViewEnter: case shared.MsgViewEnter:
cmds = append(cmds, m.loadConversations()) cmds = append(cmds, m.loadConversations())
m.content.SetContent(m.renderConversationList()) m.content.SetContent(m.renderConversationList())
case tea.WindowSizeMsg: case tea.WindowSizeMsg:
m.HandleResize(msg.Width, msg.Height) m.width, m.height = msg.Width, msg.Height
m.content.SetContent(m.renderConversationList()) m.content.SetContent(m.renderConversationList())
case msgConversationsLoaded: case msgConversationsLoaded:
m.conversations = msg m.App.Conversations = conversation.ConversationList(msg)
m.cursor = max(0, min(len(m.conversations), m.cursor)) m.cursor = max(0, min(len(m.App.Conversations.Items), m.cursor))
m.content.SetContent(m.renderConversationList()) m.content.SetContent(m.renderConversationList())
case msgConversationSelected: case msgConversationLoaded:
m.Values.ConvShortname = msg.ShortName.String m.App.ClearConversation()
m.App.Conversation = conversation.Conversation(msg)
cmds = append(cmds, func() tea.Msg { cmds = append(cmds, func() tea.Msg {
return shared.MsgViewChange(shared.StateChat) return shared.MsgViewChange(shared.ViewChat)
}) })
case bubbles.MsgConfirmPromptAnswered: case bubbles.MsgConfirmPromptAnswered:
m.confirmPrompt.Blur() m.confirmPrompt.Blur()
if msg.Value { if msg.Value {
switch payload := msg.Payload.(type) { conv, ok := msg.Payload.(conversation.ConversationListItem)
case deleteConversationPayload: if ok {
cmds = append(cmds, m.deleteConversation(api.Conversation(payload))) cmds = append(cmds, m.deleteConversation(conv))
} }
} }
case msgConversationDeleted: case msgConversationDeleted:
cmds = append(cmds, m.loadConversations()) cmds = append(cmds, m.loadConversations())
} }
var cmd tea.Cmd if !isInput || !inputHandled {
m.content, cmd = m.content.Update(msg) content, cmd := m.content.Update(msg)
m.content = content
if cmd != nil { if cmd != nil {
cmds = append(cmds, cmd) cmds = append(cmds, cmd)
} }
}
if m.Width > 0 { if len(cmds) > 0 {
wrap := lipgloss.NewStyle().Width(m.Width)
m.Header = m.headerView()
m.Footer = "" // TODO: "Press ? for help"
if m.confirmPrompt.Focused() {
m.Footer = wrap.Render(m.confirmPrompt.View())
}
m.Error = tuiutil.ErrorBanner(m.Err, m.Width)
fixedHeight := tuiutil.Height(m.Header) + tuiutil.Height(m.Error) + tuiutil.Height(m.Footer)
m.content.Height = m.Height - fixedHeight
m.Content = m.content.View()
}
return m, tea.Batch(cmds...) return m, tea.Batch(cmds...)
} }
return m, nil
}
func (m *Model) loadConversations() tea.Cmd { func (m *Model) loadConversations() tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
messages, err := m.Ctx.Store.LatestConversationMessages() list, err := m.App.Ctx.Conversations.LoadConversationList()
if err != nil { if err != nil {
return shared.MsgError(fmt.Errorf("Could not load conversations: %v", err)) return shared.AsMsgError(fmt.Errorf("Could not load conversations: %v", err))
} }
return msgConversationsLoaded(list)
loaded := make([]loadedConversation, len(messages))
for i, m := range messages {
loaded[i].lastReply = m
loaded[i].conv = *m.Conversation
}
return msgConversationsLoaded(loaded)
} }
} }
func (m *Model) deleteConversation(conv api.Conversation) tea.Cmd { func (m *Model) loadConversation(conversationID uint) tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
err := m.Ctx.Store.DeleteConversation(&conv) conversation, err := m.App.Ctx.Conversations.GetConversationByID(conversationID)
if err != nil { if err != nil {
return shared.MsgError(fmt.Errorf("Could not delete conversation: %v", err)) return shared.AsMsgError(fmt.Errorf("Could not load conversation %d: %v", conversationID, err))
}
return msgConversationLoaded(*conversation)
}
}
func (m *Model) deleteConversation(conv conversation.ConversationListItem) tea.Cmd {
return func() tea.Msg {
err := m.App.Ctx.Conversations.DeleteConversationById(conv.ID)
if err != nil {
return shared.AsMsgError(fmt.Errorf("Could not delete conversation: %v", err))
} }
return msgConversationDeleted{} return msgConversationDeleted{}
} }
} }
func (m Model) View() string { func (m *Model) Header(width int) string {
if m.Width == 0 {
return ""
}
sections := make([]string, 0, 6)
if m.Header != "" {
sections = append(sections, m.Header)
}
sections = append(sections, m.Content)
if m.Error != "" {
sections = append(sections, m.Error)
}
if m.Footer != "" {
sections = append(sections, m.Footer)
}
return lipgloss.JoinVertical(lipgloss.Left, sections...)
}
func (m *Model) headerView() string {
titleStyle := lipgloss.NewStyle().Bold(true) titleStyle := lipgloss.NewStyle().Bold(true)
header := titleStyle.Render("Conversations") header := titleStyle.Render("Conversations")
return styles.Header.Width(m.Width).Render(header) return styles.Header.Width(width).Render(header)
}
func (m *Model) Content(width int, height int) string {
m.content.Width, m.content.Height = width, height
return m.content.View()
}
func (m *Model) Footer(width int) string {
if m.confirmPrompt.Focused() {
return lipgloss.NewStyle().Width(width).Render(m.confirmPrompt.View())
}
return ""
} }
func (m *Model) renderConversationList() string { func (m *Model) renderConversationList() string {
@ -289,12 +273,12 @@ func (m *Model) renderConversationList() string {
sb strings.Builder sb strings.Builder
) )
m.itemOffsets = make([]int, len(m.conversations)) m.itemOffsets = make([]int, len(m.App.Conversations.Items))
sb.WriteRune('\n') sb.WriteRune('\n')
currentOffset += 1 currentOffset += 1
for i, c := range m.conversations { for i, c := range m.App.Conversations.Items {
lastReplyAge := now.Sub(c.lastReply.CreatedAt) lastReplyAge := now.Sub(c.LastMessageAt)
var category string var category string
for _, g := range categories { for _, g := range categories {
@ -313,15 +297,15 @@ func (m *Model) renderConversationList() string {
sb.WriteRune('\n') sb.WriteRune('\n')
} }
tStyle := titleStyle.Copy() tStyle := titleStyle
if c.conv.Title == "" { if c.Title == "" {
tStyle = tStyle.Inherit(untitledStyle).SetString("(untitled)") tStyle = tStyle.Inherit(untitledStyle).SetString("(untitled)")
} }
if i == m.cursor { if i == m.cursor {
tStyle = tStyle.Inherit(selectedStyle) tStyle = tStyle.Inherit(selectedStyle)
} }
title := tStyle.Width(m.Width - 3).PaddingLeft(2).Render(c.conv.Title) title := tStyle.Width(m.width - 3).PaddingLeft(2).Render(c.Title)
if i == m.cursor { if i == m.cursor {
title = ">" + title[1:] title = ">" + title[1:]
} }
@ -334,7 +318,7 @@ func (m *Model) renderConversationList() string {
)) ))
sb.WriteString(item) sb.WriteString(item)
currentOffset += tuiutil.Height(item) currentOffset += tuiutil.Height(item)
if i < len(m.conversations)-1 { if i < len(m.App.Conversations.Items)-1 {
sb.WriteRune('\n') sb.WriteRune('\n')
} }
} }

View File

@ -0,0 +1,137 @@
package settings
import (
"strings"
"git.mlow.ca/mlow/lmcli/pkg/tui/bubbles/list"
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
"git.mlow.ca/mlow/lmcli/pkg/tui/styles"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
type Model struct {
App *model.AppModel
prevView shared.View
content viewport.Model
modelList list.Model
width int
height int
}
type modelOpt struct {
provider string
model string
}
const (
modelListId int = iota + 1
)
func Settings(app *model.AppModel) *Model {
m := &Model{
App: app,
content: viewport.New(0, 0),
}
return m
}
func (m *Model) Init() tea.Cmd {
m.modelList = list.NewWithGroups(m.getModelOptions())
m.modelList.ID = modelListId
return nil
}
func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
var cmd tea.Cmd
switch msg := msg.(type) {
case tea.KeyMsg:
m.modelList, cmd = m.modelList.Update(msg)
if cmd != nil {
return m, cmd
}
switch msg.String() {
case "esc":
return m, func() tea.Msg {
return shared.MsgViewChange(m.prevView)
}
}
case shared.MsgViewEnter:
m.prevView = shared.View(msg)
m.modelList.Focus()
m.content.SetContent(m.renderContent())
case tea.WindowSizeMsg:
m.width, m.height = msg.Width, msg.Height
m.content.Width = msg.Width
m.content.Height = msg.Height
m.content.SetContent(m.renderContent())
case list.MsgOptionSelected:
switch msg.ID {
case modelListId:
if modelOpt, ok := msg.Option.Value.(modelOpt); ok {
m.App.Model = modelOpt.model
m.App.ProviderName = modelOpt.provider
}
return m, shared.ChangeView(m.prevView)
}
}
m.modelList, cmd = m.modelList.Update(msg)
if cmd != nil {
return m, cmd
}
m.content.SetContent(m.renderContent())
return m, nil
}
func (m *Model) getModelOptions() []list.OptionGroup {
modelOpts := []list.OptionGroup{}
for _, p := range m.App.Ctx.Config.Providers {
provider := p.Name
if provider == "" {
provider = p.Kind
}
providerLabel := p.Display
if providerLabel == "" {
providerLabel = strings.ToUpper(provider[:1]) + provider[1:]
}
group := list.OptionGroup{
Name: providerLabel,
}
for _, model := range p.Models {
group.Options = append(group.Options, list.Option{
Label: model,
Value: modelOpt{provider, model},
})
}
modelOpts = append(modelOpts, group)
}
return modelOpts
}
func (m *Model) Header(width int) string {
boldStyle := lipgloss.NewStyle().Bold(true)
// TODO: update header depending on active settings mode (model, agent, etc)
header := boldStyle.Render("Model selection")
return styles.Header.Width(width).Render(header)
}
func (m *Model) Content(width, height int) string {
// TODO: see Header()
currentModel := " Active model: " + m.App.ActiveModel(lipgloss.NewStyle())
m.modelList.Width, m.modelList.Height = width, height - 2
return "\n" + currentModel + "\n" + m.modelList.View()
}
func (m *Model) Footer(width int) string {
return ""
}
func (m *Model) renderContent() string {
return m.modelList.View()
}

View File

@ -138,7 +138,7 @@ func SetStructDefaults(data interface{}) bool {
// Get the "default" struct tag // Get the "default" struct tag
defaultTag, ok := v.Type().Field(i).Tag.Lookup("default") defaultTag, ok := v.Type().Field(i).Tag.Lookup("default")
if (!ok) { if !ok {
continue continue
} }